query/query_engine/
default_serializer.rs1use std::sync::Arc;
16
17use common_error::ext::BoxedError;
18use common_function::function::FunctionContext;
19use common_function::function_registry::FUNCTION_REGISTRY;
20use common_query::error::RegisterUdfSnafu;
21use common_query::logical_plan::SubstraitPlanDecoder;
22use datafusion::catalog::CatalogProviderList;
23use datafusion::common::DataFusionError;
24use datafusion::error::Result;
25use datafusion::execution::context::SessionState;
26use datafusion::execution::registry::SerializerRegistry;
27use datafusion::execution::{FunctionRegistry, SessionStateBuilder};
28use datafusion::logical_expr::LogicalPlan;
29use datafusion_expr::UserDefinedLogicalNode;
30use greptime_proto::substrait_extension::MergeScan as PbMergeScan;
31use promql::functions::{
32 AbsentOverTime, AvgOverTime, Changes, CountOverTime, Delta, Deriv, HoltWinters, IDelta,
33 Increase, LastOverTime, MaxOverTime, MinOverTime, PredictLinear, PresentOverTime,
34 QuantileOverTime, Rate, Resets, Round, StddevOverTime, StdvarOverTime, SumOverTime,
35 quantile_udaf,
36};
37use prost::Message;
38use session::context::QueryContextRef;
39use snafu::ResultExt;
40use substrait::extension_serializer::ExtensionSerializer;
41use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
42
43use crate::dist_plan::MergeScanLogicalPlan;
44
/// Substrait extension serializer used by the query engine.
///
/// It encodes [`MergeScanLogicalPlan`] nodes itself and delegates every other
/// user-defined logical node to [`ExtensionSerializer`]. The type carries no state,
/// so it is cheap to copy and can be default-constructed.
#[derive(Debug, Clone, Copy, Default)]
pub struct DefaultSerializer;
48
49impl SerializerRegistry for DefaultSerializer {
50 fn serialize_logical_plan(&self, node: &dyn UserDefinedLogicalNode) -> Result<Vec<u8>> {
51 if node.name() == MergeScanLogicalPlan::name() {
52 let merge_scan = node
53 .as_any()
54 .downcast_ref::<MergeScanLogicalPlan>()
55 .expect("Failed to downcast to MergeScanLogicalPlan");
56
57 let input = merge_scan.input();
58 let is_placeholder = merge_scan.is_placeholder();
59 let input = DFLogicalSubstraitConvertor
60 .encode(input, DefaultSerializer)
61 .map_err(|e| DataFusionError::External(Box::new(e)))?
62 .to_vec();
63
64 Ok(PbMergeScan {
65 is_placeholder,
66 input,
67 }
68 .encode_to_vec())
69 } else {
70 ExtensionSerializer.serialize_logical_plan(node)
71 }
72 }
73
74 fn deserialize_logical_plan(
75 &self,
76 name: &str,
77 bytes: &[u8],
78 ) -> Result<Arc<dyn UserDefinedLogicalNode>> {
79 if name == MergeScanLogicalPlan::name() {
80 Err(DataFusionError::Substrait(format!(
83 "Unsupported plan node: {name}"
84 )))
85 } else {
86 ExtensionSerializer.deserialize_logical_plan(name, bytes)
87 }
88 }
89}
90
/// Decodes Substrait-encoded messages into DataFusion [`LogicalPlan`]s.
///
/// Each decode starts from a clone of `session_state`, augmented with a caller-supplied
/// catalog list and the registered functions. `query_ctx` is handed to the scalar
/// functions provided from `FUNCTION_REGISTRY`.
pub struct DefaultPlanDecoder {
    /// Base session state. It is cloned for every decode and also used, unmodified, to
    /// optimize decoded plans.
    session_state: SessionState,
    /// Query context passed to scalar functions through their `FunctionContext`.
    query_ctx: QueryContextRef,
}
96
97impl DefaultPlanDecoder {
98 pub fn new(
99 session_state: SessionState,
100 query_ctx: &QueryContextRef,
101 ) -> crate::error::Result<Self> {
102 Ok(Self {
103 session_state,
104 query_ctx: query_ctx.clone(),
105 })
106 }
107}
108
109#[async_trait::async_trait]
110impl SubstraitPlanDecoder for DefaultPlanDecoder {
111 async fn decode(
112 &self,
113 message: bytes::Bytes,
114 catalog_list: Arc<dyn CatalogProviderList>,
115 optimize: bool,
116 ) -> common_query::error::Result<LogicalPlan> {
117 let mut session_state = SessionStateBuilder::new_from_existing(self.session_state.clone())
119 .with_catalog_list(catalog_list)
120 .build();
121 for func in FUNCTION_REGISTRY.scalar_functions() {
128 let udf = func.provide(FunctionContext {
129 query_ctx: self.query_ctx.clone(),
130 state: Default::default(),
131 });
132 session_state
133 .register_udf(Arc::new(udf))
134 .context(RegisterUdfSnafu { name: func.name() })?;
135 }
136
137 for func in FUNCTION_REGISTRY.aggregate_functions() {
138 let name = func.name().to_string();
139 session_state
140 .register_udaf(Arc::new(func))
141 .context(RegisterUdfSnafu { name })?;
142 }
143
144 let _ = session_state.register_udaf(quantile_udaf());
145
146 let _ = session_state.register_udf(Arc::new(IDelta::<false>::scalar_udf()));
147 let _ = session_state.register_udf(Arc::new(IDelta::<true>::scalar_udf()));
148 let _ = session_state.register_udf(Arc::new(Rate::scalar_udf()));
149 let _ = session_state.register_udf(Arc::new(Increase::scalar_udf()));
150 let _ = session_state.register_udf(Arc::new(Delta::scalar_udf()));
151 let _ = session_state.register_udf(Arc::new(Resets::scalar_udf()));
152 let _ = session_state.register_udf(Arc::new(Changes::scalar_udf()));
153 let _ = session_state.register_udf(Arc::new(Deriv::scalar_udf()));
154 let _ = session_state.register_udf(Arc::new(Round::scalar_udf()));
155 let _ = session_state.register_udf(Arc::new(AvgOverTime::scalar_udf()));
156 let _ = session_state.register_udf(Arc::new(MinOverTime::scalar_udf()));
157 let _ = session_state.register_udf(Arc::new(MaxOverTime::scalar_udf()));
158 let _ = session_state.register_udf(Arc::new(SumOverTime::scalar_udf()));
159 let _ = session_state.register_udf(Arc::new(CountOverTime::scalar_udf()));
160 let _ = session_state.register_udf(Arc::new(LastOverTime::scalar_udf()));
161 let _ = session_state.register_udf(Arc::new(AbsentOverTime::scalar_udf()));
162 let _ = session_state.register_udf(Arc::new(PresentOverTime::scalar_udf()));
163 let _ = session_state.register_udf(Arc::new(StddevOverTime::scalar_udf()));
164 let _ = session_state.register_udf(Arc::new(StdvarOverTime::scalar_udf()));
165 let _ = session_state.register_udf(Arc::new(QuantileOverTime::scalar_udf()));
166 let _ = session_state.register_udf(Arc::new(PredictLinear::scalar_udf()));
167 let _ = session_state.register_udf(Arc::new(HoltWinters::scalar_udf()));
168
169 let logical_plan = DFLogicalSubstraitConvertor
170 .decode(message, session_state)
171 .await
172 .map_err(BoxedError::new)
173 .context(common_query::error::DecodePlanSnafu)?;
174
175 if optimize {
176 self.session_state
177 .optimize(&logical_plan)
178 .map_err(Into::into)
179 } else {
180 Ok(logical_plan)
181 }
182 }
183}
184
#[cfg(test)]
mod tests {
    use datafusion::catalog::TableProvider;
    use datafusion_expr::{LogicalPlanBuilder, LogicalTableSource, col, lit};
    use datatypes::arrow::datatypes::SchemaRef;
    use session::context::QueryContext;

    use super::*;
    use crate::QueryEngineFactory;
    use crate::dummy_catalog::DummyCatalogList;
    use crate::optimizer::test_util::mock_table_provider;
    use crate::options::QueryOptions;

    /// Builds `Filter: k0 = 'hello'` over an unprojected scan of a table named `devices`.
    fn mock_plan(schema: SchemaRef) -> LogicalPlan {
        let source = Arc::new(LogicalTableSource::new(schema));
        LogicalPlanBuilder::scan("devices", source, None)
            .and_then(|builder| builder.filter(col("k0").eq(lit("hello"))))
            .and_then(|builder| builder.build())
            .unwrap()
    }

    /// Round-trips a plan through `DefaultSerializer` encoding and the engine's plan decoder.
    #[tokio::test]
    async fn test_serializer_decode_plan() {
        let catalog_manager = catalog::memory::new_memory_catalog_manager().unwrap();
        let factory = QueryEngineFactory::new(
            catalog_manager,
            None,
            None,
            None,
            None,
            false,
            QueryOptions::default(),
        );
        let engine = factory.query_engine();

        // Encode a filter-over-scan plan with the default serializer.
        let provider = Arc::new(mock_table_provider(1.into()));
        let original_plan = mock_plan(provider.schema().clone());
        let encoded = DFLogicalSubstraitConvertor
            .encode(&original_plan, DefaultSerializer)
            .unwrap();

        // Decode it back, resolving `devices` through a catalog list serving `provider`.
        let decoder = engine
            .engine_context(QueryContext::arc())
            .new_plan_decoder()
            .unwrap();
        let catalogs = Arc::new(DummyCatalogList::with_table_provider(provider));
        let decoded = decoder.decode(encoded, catalogs, false).await.unwrap();

        let expected = "Filter: devices.k0 = Utf8(\"hello\")\n  TableScan: devices";
        assert_eq!(expected, decoded.to_string());
    }
}