query/query_engine/
default_serializer.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use common_error::ext::BoxedError;
18use common_function::function::FunctionContext;
19use common_function::function_registry::FUNCTION_REGISTRY;
20use common_query::error::RegisterUdfSnafu;
21use common_query::logical_plan::SubstraitPlanDecoder;
22use datafusion::catalog::CatalogProviderList;
23use datafusion::common::DataFusionError;
24use datafusion::error::Result;
25use datafusion::execution::context::SessionState;
26use datafusion::execution::registry::SerializerRegistry;
27use datafusion::execution::{FunctionRegistry, SessionStateBuilder};
28use datafusion::logical_expr::LogicalPlan;
29use datafusion_expr::UserDefinedLogicalNode;
30use greptime_proto::substrait_extension::MergeScan as PbMergeScan;
31use promql::functions::{
32    quantile_udaf, AbsentOverTime, AvgOverTime, Changes, CountOverTime, Delta, Deriv, HoltWinters,
33    IDelta, Increase, LastOverTime, MaxOverTime, MinOverTime, PredictLinear, PresentOverTime,
34    QuantileOverTime, Rate, Resets, Round, StddevOverTime, StdvarOverTime, SumOverTime,
35};
36use prost::Message;
37use session::context::QueryContextRef;
38use snafu::ResultExt;
39use substrait::extension_serializer::ExtensionSerializer;
40use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
41
42use crate::dist_plan::MergeScanLogicalPlan;
43
44/// Extended [`substrait::extension_serializer::ExtensionSerializer`] but supports [`MergeScanLogicalPlan`] serialization.
45#[derive(Debug)]
46pub struct DefaultSerializer;
47
48impl SerializerRegistry for DefaultSerializer {
49    fn serialize_logical_plan(&self, node: &dyn UserDefinedLogicalNode) -> Result<Vec<u8>> {
50        if node.name() == MergeScanLogicalPlan::name() {
51            let merge_scan = node
52                .as_any()
53                .downcast_ref::<MergeScanLogicalPlan>()
54                .expect("Failed to downcast to MergeScanLogicalPlan");
55
56            let input = merge_scan.input();
57            let is_placeholder = merge_scan.is_placeholder();
58            let input = DFLogicalSubstraitConvertor
59                .encode(input, DefaultSerializer)
60                .map_err(|e| DataFusionError::External(Box::new(e)))?
61                .to_vec();
62
63            Ok(PbMergeScan {
64                is_placeholder,
65                input,
66            }
67            .encode_to_vec())
68        } else {
69            ExtensionSerializer.serialize_logical_plan(node)
70        }
71    }
72
73    fn deserialize_logical_plan(
74        &self,
75        name: &str,
76        bytes: &[u8],
77    ) -> Result<Arc<dyn UserDefinedLogicalNode>> {
78        if name == MergeScanLogicalPlan::name() {
79            // TODO(dennis): missing `session_state` to decode the logical plan in `MergeScanLogicalPlan`,
80            // so we only save the unoptimized logical plan for view currently.
81            Err(DataFusionError::Substrait(format!(
82                "Unsupported plan node: {name}"
83            )))
84        } else {
85            ExtensionSerializer.deserialize_logical_plan(name, bytes)
86        }
87    }
88}
89
90/// The datafusion `[LogicalPlan]` decoder.
91pub struct DefaultPlanDecoder {
92    session_state: SessionState,
93    query_ctx: QueryContextRef,
94}
95
96impl DefaultPlanDecoder {
97    pub fn new(
98        session_state: SessionState,
99        query_ctx: &QueryContextRef,
100    ) -> crate::error::Result<Self> {
101        Ok(Self {
102            session_state,
103            query_ctx: query_ctx.clone(),
104        })
105    }
106}
107
108#[async_trait::async_trait]
109impl SubstraitPlanDecoder for DefaultPlanDecoder {
110    async fn decode(
111        &self,
112        message: bytes::Bytes,
113        catalog_list: Arc<dyn CatalogProviderList>,
114        optimize: bool,
115    ) -> common_query::error::Result<LogicalPlan> {
116        // The session_state already has the `DefaultSerialzier` as `SerializerRegistry`.
117        let mut session_state = SessionStateBuilder::new_from_existing(self.session_state.clone())
118            .with_catalog_list(catalog_list)
119            .build();
120        // Substrait decoder will look up the UDFs in SessionState, so we need to register them
121        // Note: the query context must be passed to set the timezone
122        // We MUST register the UDFs after we build the session state, otherwise the UDFs will be lost
123        // if they have the same name as the default UDFs or their alias.
124        // e.g. The default UDF `to_char()` has an alias `date_format()`, if we register a UDF with the name `date_format()`
125        // before we build the session state, the UDF will be lost.
126        for func in FUNCTION_REGISTRY.scalar_functions() {
127            let udf = func.provide(FunctionContext {
128                query_ctx: self.query_ctx.clone(),
129                state: Default::default(),
130            });
131            session_state
132                .register_udf(Arc::new(udf))
133                .context(RegisterUdfSnafu { name: func.name() })?;
134        }
135
136        for func in FUNCTION_REGISTRY.aggregate_functions() {
137            let name = func.name().to_string();
138            session_state
139                .register_udaf(Arc::new(func))
140                .context(RegisterUdfSnafu { name })?;
141        }
142
143        let _ = session_state.register_udaf(quantile_udaf());
144
145        let _ = session_state.register_udf(Arc::new(IDelta::<false>::scalar_udf()));
146        let _ = session_state.register_udf(Arc::new(IDelta::<true>::scalar_udf()));
147        let _ = session_state.register_udf(Arc::new(Rate::scalar_udf()));
148        let _ = session_state.register_udf(Arc::new(Increase::scalar_udf()));
149        let _ = session_state.register_udf(Arc::new(Delta::scalar_udf()));
150        let _ = session_state.register_udf(Arc::new(Resets::scalar_udf()));
151        let _ = session_state.register_udf(Arc::new(Changes::scalar_udf()));
152        let _ = session_state.register_udf(Arc::new(Deriv::scalar_udf()));
153        let _ = session_state.register_udf(Arc::new(Round::scalar_udf()));
154        let _ = session_state.register_udf(Arc::new(AvgOverTime::scalar_udf()));
155        let _ = session_state.register_udf(Arc::new(MinOverTime::scalar_udf()));
156        let _ = session_state.register_udf(Arc::new(MaxOverTime::scalar_udf()));
157        let _ = session_state.register_udf(Arc::new(SumOverTime::scalar_udf()));
158        let _ = session_state.register_udf(Arc::new(CountOverTime::scalar_udf()));
159        let _ = session_state.register_udf(Arc::new(LastOverTime::scalar_udf()));
160        let _ = session_state.register_udf(Arc::new(AbsentOverTime::scalar_udf()));
161        let _ = session_state.register_udf(Arc::new(PresentOverTime::scalar_udf()));
162        let _ = session_state.register_udf(Arc::new(StddevOverTime::scalar_udf()));
163        let _ = session_state.register_udf(Arc::new(StdvarOverTime::scalar_udf()));
164        let _ = session_state.register_udf(Arc::new(QuantileOverTime::scalar_udf()));
165        let _ = session_state.register_udf(Arc::new(PredictLinear::scalar_udf()));
166        let _ = session_state.register_udf(Arc::new(HoltWinters::scalar_udf()));
167
168        let logical_plan = DFLogicalSubstraitConvertor
169            .decode(message, session_state)
170            .await
171            .map_err(BoxedError::new)
172            .context(common_query::error::DecodePlanSnafu)?;
173
174        if optimize {
175            self.session_state
176                .optimize(&logical_plan)
177                .context(common_query::error::GeneralDataFusionSnafu)
178        } else {
179            Ok(logical_plan)
180        }
181    }
182}
183
#[cfg(test)]
mod tests {
    use datafusion::catalog::TableProvider;
    use datafusion_expr::{col, lit, LogicalPlanBuilder, LogicalTableSource};
    use datatypes::arrow::datatypes::SchemaRef;
    use session::context::QueryContext;

    use super::*;
    use crate::dummy_catalog::DummyCatalogList;
    use crate::optimizer::test_util::mock_table_provider;
    use crate::options::QueryOptions;
    use crate::QueryEngineFactory;

    /// Builds a `k0 = "hello"` filter on top of an unprojected scan of the `devices` table.
    fn mock_plan(schema: SchemaRef) -> LogicalPlan {
        let source = Arc::new(LogicalTableSource::new(schema));
        let predicate = col("k0").eq(lit("hello"));

        LogicalPlanBuilder::scan("devices", source, None)
            .and_then(|scan| scan.filter(predicate))
            .and_then(|filtered| filtered.build())
            .unwrap()
    }

    /// Encodes a plan with `DefaultSerializer` and checks that the engine's plan decoder
    /// restores the same plan when decoding without optimization.
    #[tokio::test]
    async fn test_serializer_decode_plan() {
        // Encode the plan under test.
        let provider = Arc::new(mock_table_provider(1.into()));
        let original = mock_plan(provider.schema().clone());
        let encoded = DFLogicalSubstraitConvertor
            .encode(&original, DefaultSerializer)
            .unwrap();

        // Obtain a decoder from a query engine backed by an in-memory catalog.
        let memory_catalog = catalog::memory::new_memory_catalog_manager().unwrap();
        let factory = QueryEngineFactory::new(
            memory_catalog,
            None,
            None,
            None,
            None,
            false,
            QueryOptions::default(),
        );
        let decoder = factory
            .query_engine()
            .engine_context(QueryContext::arc())
            .new_plan_decoder()
            .unwrap();

        // Decode against a catalog list that resolves tables to the mock provider.
        let providers = Arc::new(DummyCatalogList::with_table_provider(provider));
        let decoded = decoder.decode(encoded, providers, false).await.unwrap();

        assert_eq!(
            "Filter: devices.k0 = Utf8(\"hello\")
  TableScan: devices",
            decoded.to_string(),
        );
    }
}