query/query_engine/
default_serializer.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use common_error::ext::BoxedError;
18use common_function::function::FunctionContext;
19use common_function::function_registry::FUNCTION_REGISTRY;
20use common_query::error::RegisterUdfSnafu;
21use common_query::logical_plan::SubstraitPlanDecoder;
22use datafusion::catalog::CatalogProviderList;
23use datafusion::common::DataFusionError;
24use datafusion::error::Result;
25use datafusion::execution::context::SessionState;
26use datafusion::execution::registry::SerializerRegistry;
27use datafusion::execution::{FunctionRegistry, SessionStateBuilder};
28use datafusion::logical_expr::LogicalPlan;
29use datafusion_expr::UserDefinedLogicalNode;
30use greptime_proto::substrait_extension::MergeScan as PbMergeScan;
31use promql::functions::{
32    AbsentOverTime, AvgOverTime, Changes, CountOverTime, Delta, Deriv, HoltWinters, IDelta,
33    Increase, LastOverTime, MaxOverTime, MinOverTime, PredictLinear, PresentOverTime,
34    QuantileOverTime, Rate, Resets, Round, StddevOverTime, StdvarOverTime, SumOverTime,
35    quantile_udaf,
36};
37use prost::Message;
38use session::context::QueryContextRef;
39use snafu::ResultExt;
40use substrait::extension_serializer::ExtensionSerializer;
41use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
42
43use crate::dist_plan::MergeScanLogicalPlan;
44
/// A [`SerializerRegistry`] that extends
/// [`substrait::extension_serializer::ExtensionSerializer`] with support for serializing
/// [`MergeScanLogicalPlan`] nodes.
///
/// Only serialization of [`MergeScanLogicalPlan`] is supported; asking this registry to
/// deserialize such a node yields an "Unsupported plan node" error. Every other node is
/// forwarded to the underlying `ExtensionSerializer`.
#[derive(Debug)]
pub struct DefaultSerializer;
48
49impl SerializerRegistry for DefaultSerializer {
50    fn serialize_logical_plan(&self, node: &dyn UserDefinedLogicalNode) -> Result<Vec<u8>> {
51        if node.name() == MergeScanLogicalPlan::name() {
52            let merge_scan = node
53                .as_any()
54                .downcast_ref::<MergeScanLogicalPlan>()
55                .expect("Failed to downcast to MergeScanLogicalPlan");
56
57            let input = merge_scan.input();
58            let is_placeholder = merge_scan.is_placeholder();
59            let input = DFLogicalSubstraitConvertor
60                .encode(input, DefaultSerializer)
61                .map_err(|e| DataFusionError::External(Box::new(e)))?
62                .to_vec();
63
64            Ok(PbMergeScan {
65                is_placeholder,
66                input,
67            }
68            .encode_to_vec())
69        } else {
70            ExtensionSerializer.serialize_logical_plan(node)
71        }
72    }
73
74    fn deserialize_logical_plan(
75        &self,
76        name: &str,
77        bytes: &[u8],
78    ) -> Result<Arc<dyn UserDefinedLogicalNode>> {
79        if name == MergeScanLogicalPlan::name() {
80            // TODO(dennis): missing `session_state` to decode the logical plan in `MergeScanLogicalPlan`,
81            // so we only save the unoptimized logical plan for view currently.
82            Err(DataFusionError::Substrait(format!(
83                "Unsupported plan node: {name}"
84            )))
85        } else {
86            ExtensionSerializer.deserialize_logical_plan(name, bytes)
87        }
88    }
89}
90
/// The datafusion [`LogicalPlan`] decoder.
///
/// Implements [`SubstraitPlanDecoder`] on top of a DataFusion [`SessionState`].
pub struct DefaultPlanDecoder {
    /// Base session state: cloned as the starting point of every decode, and used
    /// directly when the decoded plan is optimized.
    session_state: SessionState,
    /// Query context passed to each scalar function created from the function registry.
    query_ctx: QueryContextRef,
}
96
97impl DefaultPlanDecoder {
98    pub fn new(
99        session_state: SessionState,
100        query_ctx: &QueryContextRef,
101    ) -> crate::error::Result<Self> {
102        Ok(Self {
103            session_state,
104            query_ctx: query_ctx.clone(),
105        })
106    }
107}
108
109#[async_trait::async_trait]
110impl SubstraitPlanDecoder for DefaultPlanDecoder {
111    async fn decode(
112        &self,
113        message: bytes::Bytes,
114        catalog_list: Arc<dyn CatalogProviderList>,
115        optimize: bool,
116    ) -> common_query::error::Result<LogicalPlan> {
117        // The session_state already has the `DefaultSerialzier` as `SerializerRegistry`.
118        let mut session_state = SessionStateBuilder::new_from_existing(self.session_state.clone())
119            .with_catalog_list(catalog_list)
120            .build();
121        // Substrait decoder will look up the UDFs in SessionState, so we need to register them
122        // Note: the query context must be passed to set the timezone
123        // We MUST register the UDFs after we build the session state, otherwise the UDFs will be lost
124        // if they have the same name as the default UDFs or their alias.
125        // e.g. The default UDF `to_char()` has an alias `date_format()`, if we register a UDF with the name `date_format()`
126        // before we build the session state, the UDF will be lost.
127        for func in FUNCTION_REGISTRY.scalar_functions() {
128            let udf = func.provide(FunctionContext {
129                query_ctx: self.query_ctx.clone(),
130                state: Default::default(),
131            });
132            session_state
133                .register_udf(Arc::new(udf))
134                .context(RegisterUdfSnafu { name: func.name() })?;
135        }
136
137        for func in FUNCTION_REGISTRY.aggregate_functions() {
138            let name = func.name().to_string();
139            session_state
140                .register_udaf(Arc::new(func))
141                .context(RegisterUdfSnafu { name })?;
142        }
143
144        let _ = session_state.register_udaf(quantile_udaf());
145
146        let _ = session_state.register_udf(Arc::new(IDelta::<false>::scalar_udf()));
147        let _ = session_state.register_udf(Arc::new(IDelta::<true>::scalar_udf()));
148        let _ = session_state.register_udf(Arc::new(Rate::scalar_udf()));
149        let _ = session_state.register_udf(Arc::new(Increase::scalar_udf()));
150        let _ = session_state.register_udf(Arc::new(Delta::scalar_udf()));
151        let _ = session_state.register_udf(Arc::new(Resets::scalar_udf()));
152        let _ = session_state.register_udf(Arc::new(Changes::scalar_udf()));
153        let _ = session_state.register_udf(Arc::new(Deriv::scalar_udf()));
154        let _ = session_state.register_udf(Arc::new(Round::scalar_udf()));
155        let _ = session_state.register_udf(Arc::new(AvgOverTime::scalar_udf()));
156        let _ = session_state.register_udf(Arc::new(MinOverTime::scalar_udf()));
157        let _ = session_state.register_udf(Arc::new(MaxOverTime::scalar_udf()));
158        let _ = session_state.register_udf(Arc::new(SumOverTime::scalar_udf()));
159        let _ = session_state.register_udf(Arc::new(CountOverTime::scalar_udf()));
160        let _ = session_state.register_udf(Arc::new(LastOverTime::scalar_udf()));
161        let _ = session_state.register_udf(Arc::new(AbsentOverTime::scalar_udf()));
162        let _ = session_state.register_udf(Arc::new(PresentOverTime::scalar_udf()));
163        let _ = session_state.register_udf(Arc::new(StddevOverTime::scalar_udf()));
164        let _ = session_state.register_udf(Arc::new(StdvarOverTime::scalar_udf()));
165        let _ = session_state.register_udf(Arc::new(QuantileOverTime::scalar_udf()));
166        let _ = session_state.register_udf(Arc::new(PredictLinear::scalar_udf()));
167        let _ = session_state.register_udf(Arc::new(HoltWinters::scalar_udf()));
168
169        let logical_plan = DFLogicalSubstraitConvertor
170            .decode(message, session_state)
171            .await
172            .map_err(BoxedError::new)
173            .context(common_query::error::DecodePlanSnafu)?;
174
175        if optimize {
176            self.session_state
177                .optimize(&logical_plan)
178                .map_err(Into::into)
179        } else {
180            Ok(logical_plan)
181        }
182    }
183}
184
#[cfg(test)]
mod tests {
    use datafusion::catalog::TableProvider;
    use datafusion_expr::{LogicalPlanBuilder, LogicalTableSource, col, lit};
    use datatypes::arrow::datatypes::SchemaRef;
    use session::context::QueryContext;

    use super::*;
    use crate::QueryEngineFactory;
    use crate::dummy_catalog::DummyCatalogList;
    use crate::optimizer::test_util::mock_table_provider;
    use crate::options::QueryOptions;

    /// Builds a `k0 = "hello"` filter over a full scan of a `devices` table with `schema`.
    fn mock_plan(schema: SchemaRef) -> LogicalPlan {
        let source = Arc::new(LogicalTableSource::new(schema));

        LogicalPlanBuilder::scan("devices", source, None)
            .unwrap()
            .filter(col("k0").eq(lit("hello")))
            .unwrap()
            .build()
            .unwrap()
    }

    /// Encodes a plan with `DefaultSerializer` and checks that the plan decoder obtained
    /// from the query engine turns the bytes back into the same (unoptimized) plan.
    #[tokio::test]
    async fn test_serializer_decode_plan() {
        let catalog_manager = catalog::memory::new_memory_catalog_manager().unwrap();
        let engine = QueryEngineFactory::new(
            catalog_manager,
            None,
            None,
            None,
            None,
            false,
            QueryOptions::default(),
        )
        .query_engine();

        let provider = Arc::new(mock_table_provider(1.into()));
        let original_plan = mock_plan(provider.schema().clone());

        let encoded = DFLogicalSubstraitConvertor
            .encode(&original_plan, DefaultSerializer)
            .unwrap();

        let decoder = engine
            .engine_context(QueryContext::arc())
            .new_plan_decoder()
            .unwrap();
        let dummy_catalogs = Arc::new(DummyCatalogList::with_table_provider(provider));

        let decoded_plan = decoder
            .decode(encoded, dummy_catalogs, false)
            .await
            .unwrap();

        assert_eq!(
            "Filter: devices.k0 = Utf8(\"hello\")
  TableScan: devices",
            decoded_plan.to_string(),
        );
    }
}