query/query_engine/
default_serializer.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use common_error::ext::BoxedError;
18use common_function::aggr::{GeoPathAccumulator, HllState, UddSketchState};
19use common_function::function_registry::FUNCTION_REGISTRY;
20use common_function::scalars::udf::create_udf;
21use common_query::error::RegisterUdfSnafu;
22use common_query::logical_plan::SubstraitPlanDecoder;
23use datafusion::catalog::CatalogProviderList;
24use datafusion::common::DataFusionError;
25use datafusion::error::Result;
26use datafusion::execution::context::SessionState;
27use datafusion::execution::registry::SerializerRegistry;
28use datafusion::execution::{FunctionRegistry, SessionStateBuilder};
29use datafusion::logical_expr::LogicalPlan;
30use datafusion_expr::UserDefinedLogicalNode;
31use greptime_proto::substrait_extension::MergeScan as PbMergeScan;
32use promql::functions::{
33    quantile_udaf, AbsentOverTime, AvgOverTime, Changes, CountOverTime, Delta, Deriv, IDelta,
34    Increase, LastOverTime, MaxOverTime, MinOverTime, PresentOverTime, Rate, Resets, Round,
35    StddevOverTime, StdvarOverTime, SumOverTime,
36};
37use prost::Message;
38use session::context::QueryContextRef;
39use snafu::ResultExt;
40use substrait::extension_serializer::ExtensionSerializer;
41use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
42
43use crate::dist_plan::MergeScanLogicalPlan;
44
/// Extended [`substrait::extension_serializer::ExtensionSerializer`] but supports
/// [`MergeScanLogicalPlan`] serialization.
///
/// Registered on the session state as its `SerializerRegistry` so that
/// distributed-plan nodes survive a Substrait encode/decode round trip.
#[derive(Debug)]
pub struct DefaultSerializer;
48
49impl SerializerRegistry for DefaultSerializer {
50    fn serialize_logical_plan(&self, node: &dyn UserDefinedLogicalNode) -> Result<Vec<u8>> {
51        if node.name() == MergeScanLogicalPlan::name() {
52            let merge_scan = node
53                .as_any()
54                .downcast_ref::<MergeScanLogicalPlan>()
55                .expect("Failed to downcast to MergeScanLogicalPlan");
56
57            let input = merge_scan.input();
58            let is_placeholder = merge_scan.is_placeholder();
59            let input = DFLogicalSubstraitConvertor
60                .encode(input, DefaultSerializer)
61                .map_err(|e| DataFusionError::External(Box::new(e)))?
62                .to_vec();
63
64            Ok(PbMergeScan {
65                is_placeholder,
66                input,
67            }
68            .encode_to_vec())
69        } else {
70            ExtensionSerializer.serialize_logical_plan(node)
71        }
72    }
73
74    fn deserialize_logical_plan(
75        &self,
76        name: &str,
77        bytes: &[u8],
78    ) -> Result<Arc<dyn UserDefinedLogicalNode>> {
79        if name == MergeScanLogicalPlan::name() {
80            // TODO(dennis): missing `session_state` to decode the logical plan in `MergeScanLogicalPlan`,
81            // so we only save the unoptimized logical plan for view currently.
82            Err(DataFusionError::Substrait(format!(
83                "Unsupported plan node: {name}"
84            )))
85        } else {
86            ExtensionSerializer.deserialize_logical_plan(name, bytes)
87        }
88    }
89}
90
/// The datafusion [`LogicalPlan`] decoder.
pub struct DefaultPlanDecoder {
    // Base session state; cloned into a fresh builder per `decode` call and
    // also used to optimize the decoded plan when requested.
    session_state: SessionState,
    // Query context passed to `create_udf` during decoding; carries session
    // settings such as the timezone.
    query_ctx: QueryContextRef,
}
96
97impl DefaultPlanDecoder {
98    pub fn new(
99        session_state: SessionState,
100        query_ctx: &QueryContextRef,
101    ) -> crate::error::Result<Self> {
102        Ok(Self {
103            session_state,
104            query_ctx: query_ctx.clone(),
105        })
106    }
107}
108
109#[async_trait::async_trait]
110impl SubstraitPlanDecoder for DefaultPlanDecoder {
111    async fn decode(
112        &self,
113        message: bytes::Bytes,
114        catalog_list: Arc<dyn CatalogProviderList>,
115        optimize: bool,
116    ) -> common_query::error::Result<LogicalPlan> {
117        // The session_state already has the `DefaultSerialzier` as `SerializerRegistry`.
118        let mut session_state = SessionStateBuilder::new_from_existing(self.session_state.clone())
119            .with_catalog_list(catalog_list)
120            .build();
121        // Substrait decoder will look up the UDFs in SessionState, so we need to register them
122        // Note: the query context must be passed to set the timezone
123        // We MUST register the UDFs after we build the session state, otherwise the UDFs will be lost
124        // if they have the same name as the default UDFs or their alias.
125        // e.g. The default UDF `to_char()` has an alias `date_format()`, if we register a UDF with the name `date_format()`
126        // before we build the session state, the UDF will be lost.
127        for func in FUNCTION_REGISTRY.functions() {
128            let udf = Arc::new(create_udf(
129                func.clone(),
130                self.query_ctx.clone(),
131                Default::default(),
132            ));
133            session_state
134                .register_udf(udf)
135                .context(RegisterUdfSnafu { name: func.name() })?;
136            let _ = session_state.register_udaf(Arc::new(UddSketchState::state_udf_impl()));
137            let _ = session_state.register_udaf(Arc::new(UddSketchState::merge_udf_impl()));
138            let _ = session_state.register_udaf(Arc::new(HllState::state_udf_impl()));
139            let _ = session_state.register_udaf(Arc::new(HllState::merge_udf_impl()));
140            let _ = session_state.register_udaf(Arc::new(GeoPathAccumulator::udf_impl()));
141            let _ = session_state.register_udaf(quantile_udaf());
142
143            let _ = session_state.register_udf(Arc::new(IDelta::<false>::scalar_udf()));
144            let _ = session_state.register_udf(Arc::new(IDelta::<true>::scalar_udf()));
145            let _ = session_state.register_udf(Arc::new(Rate::scalar_udf()));
146            let _ = session_state.register_udf(Arc::new(Increase::scalar_udf()));
147            let _ = session_state.register_udf(Arc::new(Delta::scalar_udf()));
148            let _ = session_state.register_udf(Arc::new(Resets::scalar_udf()));
149            let _ = session_state.register_udf(Arc::new(Changes::scalar_udf()));
150            let _ = session_state.register_udf(Arc::new(Deriv::scalar_udf()));
151            let _ = session_state.register_udf(Arc::new(Round::scalar_udf()));
152            let _ = session_state.register_udf(Arc::new(AvgOverTime::scalar_udf()));
153            let _ = session_state.register_udf(Arc::new(MinOverTime::scalar_udf()));
154            let _ = session_state.register_udf(Arc::new(MaxOverTime::scalar_udf()));
155            let _ = session_state.register_udf(Arc::new(SumOverTime::scalar_udf()));
156            let _ = session_state.register_udf(Arc::new(CountOverTime::scalar_udf()));
157            let _ = session_state.register_udf(Arc::new(LastOverTime::scalar_udf()));
158            let _ = session_state.register_udf(Arc::new(AbsentOverTime::scalar_udf()));
159            let _ = session_state.register_udf(Arc::new(PresentOverTime::scalar_udf()));
160            let _ = session_state.register_udf(Arc::new(StddevOverTime::scalar_udf()));
161            let _ = session_state.register_udf(Arc::new(StdvarOverTime::scalar_udf()));
162            // TODO(ruihang): add quantile_over_time, predict_linear, holt_winters, round
163        }
164        let logical_plan = DFLogicalSubstraitConvertor
165            .decode(message, session_state)
166            .await
167            .map_err(BoxedError::new)
168            .context(common_query::error::DecodePlanSnafu)?;
169
170        if optimize {
171            self.session_state
172                .optimize(&logical_plan)
173                .context(common_query::error::GeneralDataFusionSnafu)
174        } else {
175            Ok(logical_plan)
176        }
177    }
178}
179
#[cfg(test)]
mod tests {
    use datafusion::catalog::TableProvider;
    use datafusion_expr::{col, lit, LogicalPlanBuilder, LogicalTableSource};
    use datatypes::arrow::datatypes::SchemaRef;
    use session::context::QueryContext;

    use super::*;
    use crate::dummy_catalog::DummyCatalogList;
    use crate::optimizer::test_util::mock_table_provider;
    use crate::options::QueryOptions;
    use crate::QueryEngineFactory;

    /// Builds `SELECT * FROM devices WHERE k0 = 'hello'` as a logical plan
    /// over the given schema.
    fn mock_plan(schema: SchemaRef) -> LogicalPlan {
        let source = Arc::new(LogicalTableSource::new(schema));
        LogicalPlanBuilder::scan("devices", source, None)
            .unwrap()
            .filter(col("k0").eq(lit("hello")))
            .unwrap()
            .build()
            .unwrap()
    }

    #[tokio::test]
    async fn test_serializer_decode_plan() {
        let catalog_manager = catalog::memory::new_memory_catalog_manager().unwrap();
        let factory = QueryEngineFactory::new(
            catalog_manager,
            None,
            None,
            None,
            None,
            false,
            QueryOptions::default(),
        );
        let engine = factory.query_engine();

        let provider = Arc::new(mock_table_provider(1.into()));
        let original_plan = mock_plan(provider.schema().clone());

        // Encode with our serializer, then decode (without optimizing) and
        // check the plan survives the round trip.
        let encoded = DFLogicalSubstraitConvertor
            .encode(&original_plan, DefaultSerializer)
            .unwrap();

        let decoder = engine
            .engine_context(QueryContext::arc())
            .new_plan_decoder()
            .unwrap();
        let catalogs = Arc::new(DummyCatalogList::with_table_provider(provider));

        let decoded = decoder.decode(encoded, catalogs, false).await.unwrap();

        assert_eq!(
            "Filter: devices.k0 = Utf8(\"hello\")
  TableScan: devices",
            decoded.to_string(),
        );
    }
}
245}