query/query_engine/
default_serializer.rs1use std::sync::Arc;
16
17use common_error::ext::BoxedError;
18use common_function::aggr::{GeoPathAccumulator, HllState, UddSketchState};
19use common_function::function_registry::FUNCTION_REGISTRY;
20use common_function::scalars::udf::create_udf;
21use common_query::error::RegisterUdfSnafu;
22use common_query::logical_plan::SubstraitPlanDecoder;
23use datafusion::catalog::CatalogProviderList;
24use datafusion::common::DataFusionError;
25use datafusion::error::Result;
26use datafusion::execution::context::SessionState;
27use datafusion::execution::registry::SerializerRegistry;
28use datafusion::execution::{FunctionRegistry, SessionStateBuilder};
29use datafusion::logical_expr::LogicalPlan;
30use datafusion_expr::UserDefinedLogicalNode;
31use greptime_proto::substrait_extension::MergeScan as PbMergeScan;
32use promql::functions::{
33 quantile_udaf, AbsentOverTime, AvgOverTime, Changes, CountOverTime, Delta, Deriv, IDelta,
34 Increase, LastOverTime, MaxOverTime, MinOverTime, PresentOverTime, Rate, Resets, Round,
35 StddevOverTime, StdvarOverTime, SumOverTime,
36};
37use prost::Message;
38use session::context::QueryContextRef;
39use snafu::ResultExt;
40use substrait::extension_serializer::ExtensionSerializer;
41use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
42
43use crate::dist_plan::MergeScanLogicalPlan;
44
/// Stateless serializer registry for user-defined logical plan nodes.
///
/// Its [`SerializerRegistry`] implementation encodes [`MergeScanLogicalPlan`]
/// nodes itself and forwards every other node to [`ExtensionSerializer`].
#[derive(Debug)]
pub struct DefaultSerializer;
48
49impl SerializerRegistry for DefaultSerializer {
50 fn serialize_logical_plan(&self, node: &dyn UserDefinedLogicalNode) -> Result<Vec<u8>> {
51 if node.name() == MergeScanLogicalPlan::name() {
52 let merge_scan = node
53 .as_any()
54 .downcast_ref::<MergeScanLogicalPlan>()
55 .expect("Failed to downcast to MergeScanLogicalPlan");
56
57 let input = merge_scan.input();
58 let is_placeholder = merge_scan.is_placeholder();
59 let input = DFLogicalSubstraitConvertor
60 .encode(input, DefaultSerializer)
61 .map_err(|e| DataFusionError::External(Box::new(e)))?
62 .to_vec();
63
64 Ok(PbMergeScan {
65 is_placeholder,
66 input,
67 }
68 .encode_to_vec())
69 } else {
70 ExtensionSerializer.serialize_logical_plan(node)
71 }
72 }
73
74 fn deserialize_logical_plan(
75 &self,
76 name: &str,
77 bytes: &[u8],
78 ) -> Result<Arc<dyn UserDefinedLogicalNode>> {
79 if name == MergeScanLogicalPlan::name() {
80 Err(DataFusionError::Substrait(format!(
83 "Unsupported plan node: {name}"
84 )))
85 } else {
86 ExtensionSerializer.deserialize_logical_plan(name, bytes)
87 }
88 }
89}
90
91pub struct DefaultPlanDecoder {
93 session_state: SessionState,
94 query_ctx: QueryContextRef,
95}
96
97impl DefaultPlanDecoder {
98 pub fn new(
99 session_state: SessionState,
100 query_ctx: &QueryContextRef,
101 ) -> crate::error::Result<Self> {
102 Ok(Self {
103 session_state,
104 query_ctx: query_ctx.clone(),
105 })
106 }
107}
108
109#[async_trait::async_trait]
110impl SubstraitPlanDecoder for DefaultPlanDecoder {
111 async fn decode(
112 &self,
113 message: bytes::Bytes,
114 catalog_list: Arc<dyn CatalogProviderList>,
115 optimize: bool,
116 ) -> common_query::error::Result<LogicalPlan> {
117 let mut session_state = SessionStateBuilder::new_from_existing(self.session_state.clone())
119 .with_catalog_list(catalog_list)
120 .build();
121 for func in FUNCTION_REGISTRY.functions() {
128 let udf = Arc::new(create_udf(
129 func.clone(),
130 self.query_ctx.clone(),
131 Default::default(),
132 ));
133 session_state
134 .register_udf(udf)
135 .context(RegisterUdfSnafu { name: func.name() })?;
136 let _ = session_state.register_udaf(Arc::new(UddSketchState::state_udf_impl()));
137 let _ = session_state.register_udaf(Arc::new(UddSketchState::merge_udf_impl()));
138 let _ = session_state.register_udaf(Arc::new(HllState::state_udf_impl()));
139 let _ = session_state.register_udaf(Arc::new(HllState::merge_udf_impl()));
140 let _ = session_state.register_udaf(Arc::new(GeoPathAccumulator::udf_impl()));
141 let _ = session_state.register_udaf(quantile_udaf());
142
143 let _ = session_state.register_udf(Arc::new(IDelta::<false>::scalar_udf()));
144 let _ = session_state.register_udf(Arc::new(IDelta::<true>::scalar_udf()));
145 let _ = session_state.register_udf(Arc::new(Rate::scalar_udf()));
146 let _ = session_state.register_udf(Arc::new(Increase::scalar_udf()));
147 let _ = session_state.register_udf(Arc::new(Delta::scalar_udf()));
148 let _ = session_state.register_udf(Arc::new(Resets::scalar_udf()));
149 let _ = session_state.register_udf(Arc::new(Changes::scalar_udf()));
150 let _ = session_state.register_udf(Arc::new(Deriv::scalar_udf()));
151 let _ = session_state.register_udf(Arc::new(Round::scalar_udf()));
152 let _ = session_state.register_udf(Arc::new(AvgOverTime::scalar_udf()));
153 let _ = session_state.register_udf(Arc::new(MinOverTime::scalar_udf()));
154 let _ = session_state.register_udf(Arc::new(MaxOverTime::scalar_udf()));
155 let _ = session_state.register_udf(Arc::new(SumOverTime::scalar_udf()));
156 let _ = session_state.register_udf(Arc::new(CountOverTime::scalar_udf()));
157 let _ = session_state.register_udf(Arc::new(LastOverTime::scalar_udf()));
158 let _ = session_state.register_udf(Arc::new(AbsentOverTime::scalar_udf()));
159 let _ = session_state.register_udf(Arc::new(PresentOverTime::scalar_udf()));
160 let _ = session_state.register_udf(Arc::new(StddevOverTime::scalar_udf()));
161 let _ = session_state.register_udf(Arc::new(StdvarOverTime::scalar_udf()));
162 }
164 let logical_plan = DFLogicalSubstraitConvertor
165 .decode(message, session_state)
166 .await
167 .map_err(BoxedError::new)
168 .context(common_query::error::DecodePlanSnafu)?;
169
170 if optimize {
171 self.session_state
172 .optimize(&logical_plan)
173 .context(common_query::error::GeneralDataFusionSnafu)
174 } else {
175 Ok(logical_plan)
176 }
177 }
178}
179
#[cfg(test)]
mod tests {
    use datafusion::catalog::TableProvider;
    use datafusion_expr::{col, lit, LogicalPlanBuilder, LogicalTableSource};
    use datatypes::arrow::datatypes::SchemaRef;
    use session::context::QueryContext;

    use super::*;
    use crate::dummy_catalog::DummyCatalogList;
    use crate::optimizer::test_util::mock_table_provider;
    use crate::options::QueryOptions;
    use crate::QueryEngineFactory;

    /// Builds `Filter(k0 = "hello")` over a full scan of a table named `devices`.
    fn mock_plan(schema: SchemaRef) -> LogicalPlan {
        let source = Arc::new(LogicalTableSource::new(schema));
        let scan = LogicalPlanBuilder::scan("devices", source, None).unwrap();
        let predicate = col("k0").eq(lit("hello"));
        let filtered = scan.filter(predicate).unwrap();
        filtered.build().unwrap()
    }

    /// Encodes a plan with `DefaultSerializer`, decodes it through the engine's
    /// plan decoder, and checks the rendered plan.
    #[tokio::test]
    async fn test_serializer_decode_plan() {
        let catalog_manager = catalog::memory::new_memory_catalog_manager().unwrap();
        let engine = QueryEngineFactory::new(
            catalog_manager,
            None,
            None,
            None,
            None,
            false,
            QueryOptions::default(),
        )
        .query_engine();

        let provider = Arc::new(mock_table_provider(1.into()));
        let original_plan = mock_plan(provider.schema().clone());

        let encoded = DFLogicalSubstraitConvertor
            .encode(&original_plan, DefaultSerializer)
            .unwrap();

        let decoder = engine
            .engine_context(QueryContext::arc())
            .new_plan_decoder()
            .unwrap();
        let dummy_catalogs = Arc::new(DummyCatalogList::with_table_provider(provider));

        let decoded = decoder
            .decode(encoded, dummy_catalogs, false)
            .await
            .unwrap();

        // The continuation line's leading two spaces are part of the expected
        // text (child nodes are indented in the plan's display output).
        assert_eq!(
            "Filter: devices.k0 = Utf8(\"hello\")
  TableScan: devices",
            decoded.to_string(),
        );
    }
}