query/query_engine/
default_serializer.rs1use std::sync::Arc;
16
17use common_error::ext::BoxedError;
18use common_function::function::FunctionContext;
19use common_function::function_registry::FUNCTION_REGISTRY;
20use common_query::error::RegisterUdfSnafu;
21use common_query::logical_plan::SubstraitPlanDecoder;
22use datafusion::catalog::CatalogProviderList;
23use datafusion::common::DataFusionError;
24use datafusion::error::Result;
25use datafusion::execution::context::SessionState;
26use datafusion::execution::registry::SerializerRegistry;
27use datafusion::execution::{FunctionRegistry, SessionStateBuilder};
28use datafusion::logical_expr::LogicalPlan;
29use datafusion_expr::UserDefinedLogicalNode;
30use greptime_proto::substrait_extension::MergeScan as PbMergeScan;
31use promql::functions::{
32 quantile_udaf, AbsentOverTime, AvgOverTime, Changes, CountOverTime, Delta, Deriv, HoltWinters,
33 IDelta, Increase, LastOverTime, MaxOverTime, MinOverTime, PredictLinear, PresentOverTime,
34 QuantileOverTime, Rate, Resets, Round, StddevOverTime, StdvarOverTime, SumOverTime,
35};
36use prost::Message;
37use session::context::QueryContextRef;
38use snafu::ResultExt;
39use substrait::extension_serializer::ExtensionSerializer;
40use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
41
42use crate::dist_plan::MergeScanLogicalPlan;
43
/// Substrait extension serializer used by the query engine.
///
/// `MergeScanLogicalPlan` nodes are encoded by this type itself; every other
/// user-defined logical node is handed over to `ExtensionSerializer`.
#[derive(Debug)]
pub struct DefaultSerializer;
47
48impl SerializerRegistry for DefaultSerializer {
49 fn serialize_logical_plan(&self, node: &dyn UserDefinedLogicalNode) -> Result<Vec<u8>> {
50 if node.name() == MergeScanLogicalPlan::name() {
51 let merge_scan = node
52 .as_any()
53 .downcast_ref::<MergeScanLogicalPlan>()
54 .expect("Failed to downcast to MergeScanLogicalPlan");
55
56 let input = merge_scan.input();
57 let is_placeholder = merge_scan.is_placeholder();
58 let input = DFLogicalSubstraitConvertor
59 .encode(input, DefaultSerializer)
60 .map_err(|e| DataFusionError::External(Box::new(e)))?
61 .to_vec();
62
63 Ok(PbMergeScan {
64 is_placeholder,
65 input,
66 }
67 .encode_to_vec())
68 } else {
69 ExtensionSerializer.serialize_logical_plan(node)
70 }
71 }
72
73 fn deserialize_logical_plan(
74 &self,
75 name: &str,
76 bytes: &[u8],
77 ) -> Result<Arc<dyn UserDefinedLogicalNode>> {
78 if name == MergeScanLogicalPlan::name() {
79 Err(DataFusionError::Substrait(format!(
82 "Unsupported plan node: {name}"
83 )))
84 } else {
85 ExtensionSerializer.deserialize_logical_plan(name, bytes)
86 }
87 }
88}
89
90pub struct DefaultPlanDecoder {
92 session_state: SessionState,
93 query_ctx: QueryContextRef,
94}
95
96impl DefaultPlanDecoder {
97 pub fn new(
98 session_state: SessionState,
99 query_ctx: &QueryContextRef,
100 ) -> crate::error::Result<Self> {
101 Ok(Self {
102 session_state,
103 query_ctx: query_ctx.clone(),
104 })
105 }
106}
107
108#[async_trait::async_trait]
109impl SubstraitPlanDecoder for DefaultPlanDecoder {
110 async fn decode(
111 &self,
112 message: bytes::Bytes,
113 catalog_list: Arc<dyn CatalogProviderList>,
114 optimize: bool,
115 ) -> common_query::error::Result<LogicalPlan> {
116 let mut session_state = SessionStateBuilder::new_from_existing(self.session_state.clone())
118 .with_catalog_list(catalog_list)
119 .build();
120 for func in FUNCTION_REGISTRY.scalar_functions() {
127 let udf = func.provide(FunctionContext {
128 query_ctx: self.query_ctx.clone(),
129 state: Default::default(),
130 });
131 session_state
132 .register_udf(Arc::new(udf))
133 .context(RegisterUdfSnafu { name: func.name() })?;
134 }
135
136 for func in FUNCTION_REGISTRY.aggregate_functions() {
137 let name = func.name().to_string();
138 session_state
139 .register_udaf(Arc::new(func))
140 .context(RegisterUdfSnafu { name })?;
141 }
142
143 let _ = session_state.register_udaf(quantile_udaf());
144
145 let _ = session_state.register_udf(Arc::new(IDelta::<false>::scalar_udf()));
146 let _ = session_state.register_udf(Arc::new(IDelta::<true>::scalar_udf()));
147 let _ = session_state.register_udf(Arc::new(Rate::scalar_udf()));
148 let _ = session_state.register_udf(Arc::new(Increase::scalar_udf()));
149 let _ = session_state.register_udf(Arc::new(Delta::scalar_udf()));
150 let _ = session_state.register_udf(Arc::new(Resets::scalar_udf()));
151 let _ = session_state.register_udf(Arc::new(Changes::scalar_udf()));
152 let _ = session_state.register_udf(Arc::new(Deriv::scalar_udf()));
153 let _ = session_state.register_udf(Arc::new(Round::scalar_udf()));
154 let _ = session_state.register_udf(Arc::new(AvgOverTime::scalar_udf()));
155 let _ = session_state.register_udf(Arc::new(MinOverTime::scalar_udf()));
156 let _ = session_state.register_udf(Arc::new(MaxOverTime::scalar_udf()));
157 let _ = session_state.register_udf(Arc::new(SumOverTime::scalar_udf()));
158 let _ = session_state.register_udf(Arc::new(CountOverTime::scalar_udf()));
159 let _ = session_state.register_udf(Arc::new(LastOverTime::scalar_udf()));
160 let _ = session_state.register_udf(Arc::new(AbsentOverTime::scalar_udf()));
161 let _ = session_state.register_udf(Arc::new(PresentOverTime::scalar_udf()));
162 let _ = session_state.register_udf(Arc::new(StddevOverTime::scalar_udf()));
163 let _ = session_state.register_udf(Arc::new(StdvarOverTime::scalar_udf()));
164 let _ = session_state.register_udf(Arc::new(QuantileOverTime::scalar_udf()));
165 let _ = session_state.register_udf(Arc::new(PredictLinear::scalar_udf()));
166 let _ = session_state.register_udf(Arc::new(HoltWinters::scalar_udf()));
167
168 let logical_plan = DFLogicalSubstraitConvertor
169 .decode(message, session_state)
170 .await
171 .map_err(BoxedError::new)
172 .context(common_query::error::DecodePlanSnafu)?;
173
174 if optimize {
175 self.session_state
176 .optimize(&logical_plan)
177 .context(common_query::error::GeneralDataFusionSnafu)
178 } else {
179 Ok(logical_plan)
180 }
181 }
182}
183
#[cfg(test)]
mod tests {
    use datafusion::catalog::TableProvider;
    use datafusion_expr::{col, lit, LogicalPlanBuilder, LogicalTableSource};
    use datatypes::arrow::datatypes::SchemaRef;
    use session::context::QueryContext;

    use super::*;
    use crate::dummy_catalog::DummyCatalogList;
    use crate::optimizer::test_util::mock_table_provider;
    use crate::options::QueryOptions;
    use crate::QueryEngineFactory;

    /// Builds a `k0 = 'hello'` filter over a scan of a `devices` table with `schema`.
    fn mock_plan(schema: SchemaRef) -> LogicalPlan {
        let source = Arc::new(LogicalTableSource::new(schema));
        LogicalPlanBuilder::scan("devices", source, None)
            .and_then(|builder| builder.filter(col("k0").eq(lit("hello"))))
            .and_then(|builder| builder.build())
            .unwrap()
    }

    /// Round-trips a simple plan: it is encoded with `DefaultSerializer`, decoded with
    /// the engine's plan decoder, and checked to render back to the same text.
    #[tokio::test]
    async fn test_serializer_decode_plan() {
        let memory_catalogs = catalog::memory::new_memory_catalog_manager().unwrap();
        let engine_factory = QueryEngineFactory::new(
            memory_catalogs,
            None,
            None,
            None,
            None,
            false,
            QueryOptions::default(),
        );
        let engine = engine_factory.query_engine();

        let provider = Arc::new(mock_table_provider(1.into()));
        let original_plan = mock_plan(provider.schema().clone());
        let encoded = DFLogicalSubstraitConvertor
            .encode(&original_plan, DefaultSerializer)
            .unwrap();

        let decoder = engine
            .engine_context(QueryContext::arc())
            .new_plan_decoder()
            .unwrap();
        let dummy_catalogs = Arc::new(DummyCatalogList::with_table_provider(provider));
        let decoded_plan = decoder
            .decode(encoded, dummy_catalogs, false)
            .await
            .unwrap();

        let expected = "Filter: devices.k0 = Utf8(\"hello\")\n  TableScan: devices";
        assert_eq!(expected, decoded_plan.to_string());
    }
}