1#![warn(unused)]
18
19use std::collections::{HashMap, HashSet};
20use std::sync::Arc;
21
22use common_error::ext::BoxedError;
23use common_telemetry::debug;
24use datafusion::config::ConfigOptions;
25use datafusion::error::DataFusionError;
26use datafusion::functions_aggregate::count::count_udaf;
27use datafusion::functions_aggregate::sum::sum_udaf;
28use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
29use datafusion::optimizer::common_subexpr_eliminate::CommonSubexprEliminate;
30use datafusion::optimizer::optimize_projections::OptimizeProjections;
31use datafusion::optimizer::simplify_expressions::SimplifyExpressions;
32use datafusion::optimizer::utils::NamePreserver;
33use datafusion::optimizer::{Analyzer, AnalyzerRule, Optimizer, OptimizerContext};
34use datafusion_common::tree_node::{
35 Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor,
36};
37use datafusion_common::{Column, DFSchema, ScalarValue};
38use datafusion_expr::utils::merge_schema;
39use datafusion_expr::{
40 BinaryExpr, ColumnarValue, Expr, Literal, Operator, Projection, ScalarFunctionArgs,
41 ScalarUDFImpl, Signature, TypeSignature, Volatility,
42};
43use query::QueryEngine;
44use query::optimizer::count_wildcard::CountWildcardToTimeIndexRule;
45use query::parser::QueryLanguageParser;
46use query::query_engine::DefaultSerializer;
47use session::context::QueryContextRef;
48use snafu::ResultExt;
49use substrait::DFLogicalSubstraitConvertor;
52
53use crate::adapter::FlownodeContext;
54use crate::error::{DatafusionSnafu, Error, ExternalSnafu, UnexpectedSnafu};
55use crate::expr::{TUMBLE_END, TUMBLE_START};
56use crate::plan::TypedPlan;
57
58pub async fn apply_df_optimizer(
60 plan: datafusion_expr::LogicalPlan,
61 query_ctx: &QueryContextRef,
62) -> Result<datafusion_expr::LogicalPlan, Error> {
63 let cfg = query_ctx.create_config_options();
64 let analyzer = Analyzer::with_rules(vec![
65 Arc::new(CountWildcardToTimeIndexRule),
66 Arc::new(AvgExpandRule),
67 Arc::new(TumbleExpandRule),
68 Arc::new(CheckGroupByRule::new()),
69 Arc::new(TypeCoercion::new()),
70 ]);
71 let plan = analyzer
72 .execute_and_check(plan, &cfg, |p, r| {
73 debug!("After apply rule {}, get plan: \n{:?}", r.name(), p);
74 })
75 .context(DatafusionSnafu {
76 context: "Fail to apply analyzer",
77 })?;
78
79 let ctx = OptimizerContext::new();
80 let optimizer = Optimizer::with_rules(vec![
81 Arc::new(OptimizeProjections::new()),
82 Arc::new(CommonSubexprEliminate::new()),
83 Arc::new(SimplifyExpressions::new()),
84 ]);
85 let plan = optimizer
86 .optimize(plan, &ctx, |_, _| {})
87 .context(DatafusionSnafu {
88 context: "Fail to apply optimizer",
89 })?;
90
91 Ok(plan)
92}
93
94pub async fn sql_to_flow_plan(
97 ctx: &mut FlownodeContext,
98 engine: &Arc<dyn QueryEngine>,
99 sql: &str,
100) -> Result<TypedPlan, Error> {
101 let query_ctx = ctx.query_context.clone().ok_or_else(|| {
102 UnexpectedSnafu {
103 reason: "Query context is missing",
104 }
105 .build()
106 })?;
107 let stmt = QueryLanguageParser::parse_sql(sql, &query_ctx)
108 .map_err(BoxedError::new)
109 .context(ExternalSnafu)?;
110 let plan = engine
111 .planner()
112 .plan(&stmt, query_ctx.clone())
113 .await
114 .map_err(BoxedError::new)
115 .context(ExternalSnafu)?;
116
117 let opted_plan = apply_df_optimizer(plan, &query_ctx).await?;
118
119 let sub_plan = DFLogicalSubstraitConvertor {}
121 .to_sub_plan(&opted_plan, DefaultSerializer)
122 .map_err(BoxedError::new)
123 .context(ExternalSnafu)?;
124
125 let flow_plan = TypedPlan::from_substrait_plan(ctx, &sub_plan).await?;
126
127 Ok(flow_plan)
128}
129
/// Analyzer rule that expands `avg(x)` into `sum(x) / count(x)` so downstream
/// flow execution only needs the simpler aggregates.
#[derive(Debug)]
struct AvgExpandRule;
132
133impl AnalyzerRule for AvgExpandRule {
134 fn analyze(
135 &self,
136 plan: datafusion_expr::LogicalPlan,
137 _config: &ConfigOptions,
138 ) -> datafusion_common::Result<datafusion_expr::LogicalPlan> {
139 let transformed = plan
140 .transform_up_with_subqueries(expand_avg_analyzer)?
141 .data
142 .transform_down_with_subqueries(put_aggr_to_proj_analyzer)?
143 .data;
144 Ok(transformed)
145 }
146
147 fn name(&self) -> &str {
148 "avg_expand"
149 }
150}
151
/// For a projection directly above an aggregate, hoists aggregate calls that
/// are nested inside composite projection expressions (e.g. the
/// `sum(x)/count(x)` produced by avg expansion) into the aggregate node
/// itself, and rewrites the projection to reference them by column name.
fn put_aggr_to_proj_analyzer(
    plan: datafusion_expr::LogicalPlan,
) -> Result<Transformed<datafusion_expr::LogicalPlan>, DataFusionError> {
    if let datafusion_expr::LogicalPlan::Projection(proj) = &plan
        && let datafusion_expr::LogicalPlan::Aggregate(aggr) = proj.input.as_ref()
    {
        // Maps the old composite expression's alias name to its rewritten form
        // (inner aggregates replaced by column references).
        let mut replace_old_proj_exprs = HashMap::new();
        let mut expanded_aggr_exprs = vec![];
        for aggr_expr in &aggr.aggr_expr {
            // NOTE(review): `is_composite` is written but never read — looks
            // like a removal candidate; confirm nothing else relies on it.
            let mut is_composite = false;
            if let Expr::AggregateFunction(_) = &aggr_expr {
                // Already a bare aggregate call; keep as-is.
                expanded_aggr_exprs.push(aggr_expr.clone());
            } else {
                // Composite expression: pull each inner aggregate call out into
                // the aggregate node and replace it with a column reference to
                // that aggregate's output.
                let old_name = aggr_expr.name_for_alias()?;
                let new_proj_expr = aggr_expr
                    .clone()
                    .transform(|ch| {
                        if let Expr::AggregateFunction(_) = &ch {
                            is_composite = true;
                            expanded_aggr_exprs.push(ch.clone());
                            Ok(Transformed::yes(Expr::Column(Column::from_qualified_name(
                                ch.name_for_alias()?,
                            ))))
                        } else {
                            Ok(Transformed::no(ch))
                        }
                    })?
                    .data;
                replace_old_proj_exprs.insert(old_name, new_proj_expr);
            }
        }

        // Only rebuild when hoisting actually produced more aggregate exprs.
        if expanded_aggr_exprs.len() > aggr.aggr_expr.len() {
            let mut aggr = aggr.clone();
            aggr.aggr_expr = expanded_aggr_exprs;
            let mut aggr_plan = datafusion_expr::LogicalPlan::Aggregate(aggr);
            // The aggregate's output schema changed; recompute it.
            aggr_plan = aggr_plan.recompute_schema()?;

            // Two-step rewrite of each projection expression: first replace a
            // whole expression whose name matches, then replace any matching
            // sub-expressions within it.
            let mut new_proj_exprs = proj.expr.clone();
            for proj_expr in new_proj_exprs.iter_mut() {
                if let Some(new_proj_expr) =
                    replace_old_proj_exprs.get(&proj_expr.name_for_alias()?)
                {
                    *proj_expr = new_proj_expr.clone();
                }
                *proj_expr = proj_expr
                    .clone()
                    .transform(|expr| {
                        if let Some(new_expr) = replace_old_proj_exprs.get(&expr.name_for_alias()?)
                        {
                            Ok(Transformed::yes(new_expr.clone()))
                        } else {
                            Ok(Transformed::no(expr))
                        }
                    })?
                    .data;
            }
            let proj = datafusion_expr::LogicalPlan::Projection(Projection::try_new(
                new_proj_exprs,
                Arc::new(aggr_plan),
            )?);
            return Ok(Transformed::yes(proj));
        }
    }
    Ok(Transformed::no(plan))
}
231
/// Rewrites every `avg` aggregate in the plan's expressions via
/// [`ExpandAvgRewriter`], then recomputes the plan schema.
fn expand_avg_analyzer(
    plan: datafusion_expr::LogicalPlan,
) -> Result<Transformed<datafusion_expr::LogicalPlan>, DataFusionError> {
    // Merge the input schemas; a table scan additionally contributes its
    // source's qualified schema.
    let mut schema = merge_schema(&plan.inputs());

    if let datafusion_expr::LogicalPlan::TableScan(ts) = &plan {
        let source_schema =
            DFSchema::try_from_qualified_schema(ts.table_name.clone(), &ts.source.schema())?;
        schema.merge(&source_schema);
    }

    let mut expr_rewrite = ExpandAvgRewriter::new(&schema);

    // Preserve each expression's original display name across the rewrite so
    // output column names stay stable.
    let name_preserver = NamePreserver::new(&plan);
    plan.map_expressions(|expr| {
        let original_name = name_preserver.save(&expr);
        Ok(expr
            .rewrite(&mut expr_rewrite)?
            .update_data(|expr| original_name.restore(expr)))
    })?
    .map_data(|plan| plan.recompute_schema())
}
256
/// Expression rewriter that expands `avg(x)` into
/// `CASE WHEN count(x) != 0 THEN cast(sum(x) AS Float64) / count(x) ELSE NULL END`.
pub(crate) struct ExpandAvgRewriter<'a> {
    // Input schema; currently unused by the rewrite itself.
    #[allow(unused)]
    pub(crate) schema: &'a DFSchema,
}

impl<'a> ExpandAvgRewriter<'a> {
    /// Creates a rewriter over the given schema.
    fn new(schema: &'a DFSchema) -> Self {
        Self { schema }
    }
}
273
274impl TreeNodeRewriter for ExpandAvgRewriter<'_> {
275 type Node = Expr;
276
277 fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>, DataFusionError> {
278 if let Expr::AggregateFunction(aggr_func) = &expr
279 && aggr_func.func.name() == "avg"
280 {
281 let sum_expr = {
282 let mut tmp = aggr_func.clone();
283 tmp.func = sum_udaf();
284 Expr::AggregateFunction(tmp)
285 };
286 let sum_cast = {
287 let mut tmp = sum_expr.clone();
288 tmp = Expr::Cast(datafusion_expr::Cast {
289 expr: Box::new(tmp),
290 data_type: arrow_schema::DataType::Float64,
291 });
292 tmp
293 };
294
295 let count_expr = {
296 let mut tmp = aggr_func.clone();
297 tmp.func = count_udaf();
298
299 Expr::AggregateFunction(tmp)
300 };
301 let count_expr_ref =
302 Expr::Column(Column::from_qualified_name(count_expr.name_for_alias()?));
303
304 let div = BinaryExpr::new(Box::new(sum_cast), Operator::Divide, Box::new(count_expr));
305 let div_expr = Box::new(Expr::BinaryExpr(div));
306
307 let zero = Box::new(0.lit());
308 let not_zero = BinaryExpr::new(Box::new(count_expr_ref), Operator::NotEq, zero.clone());
309 let not_zero = Box::new(Expr::BinaryExpr(not_zero));
310 let null = Box::new(Expr::Literal(ScalarValue::Null, None));
311
312 let case_when =
313 datafusion_expr::Case::new(None, vec![(not_zero, div_expr)], Some(null));
314 let case_when_expr = Expr::Case(case_when);
315
316 return Ok(Transformed::yes(case_when_expr));
317 }
318
319 Ok(Transformed::no(expr))
320 }
321}
322
323#[derive(Debug)]
325struct TumbleExpandRule;
326
327impl AnalyzerRule for TumbleExpandRule {
328 fn analyze(
329 &self,
330 plan: datafusion_expr::LogicalPlan,
331 _config: &ConfigOptions,
332 ) -> datafusion_common::Result<datafusion_expr::LogicalPlan> {
333 let transformed = plan
334 .transform_up_with_subqueries(expand_tumble_analyzer)?
335 .data;
336 Ok(transformed)
337 }
338
339 fn name(&self) -> &str {
340 "tumble_expand"
341 }
342}
343
/// Expands a `tumble(ts, window[, start_time])` call in an aggregate's
/// group-by into a `tumble_start`/`tumble_end` pair, and patches the
/// projection above the aggregate accordingly.
///
/// If the projection referenced the original `tumble(...)` expression, that
/// reference becomes two column references (start and end). If it referenced
/// neither, both columns are appended aliased as `window_start`/`window_end`
/// so later passes can still locate the window bounds.
fn expand_tumble_analyzer(
    plan: datafusion_expr::LogicalPlan,
) -> Result<Transformed<datafusion_expr::LogicalPlan>, DataFusionError> {
    if let datafusion_expr::LogicalPlan::Projection(proj) = &plan
        && let datafusion_expr::LogicalPlan::Aggregate(aggr) = proj.input.as_ref()
    {
        let mut new_group_expr = vec![];
        // Maps the original `tumble(...)` alias -> (start column, end column).
        let mut alias_to_expand = HashMap::new();
        let mut encountered_tumble = false;
        for expr in aggr.group_expr.iter() {
            match expr {
                datafusion_expr::Expr::ScalarFunction(func) if func.name() == "tumble" => {
                    encountered_tumble = true;

                    // `tumble_start` with the original call's arguments.
                    let tumble_start = TumbleExpand::new(TUMBLE_START);
                    let tumble_start = datafusion_expr::expr::ScalarFunction::new_udf(
                        Arc::new(tumble_start.into()),
                        func.args.clone(),
                    );
                    let tumble_start = datafusion_expr::Expr::ScalarFunction(tumble_start);
                    let start_col_name = tumble_start.name_for_alias()?;
                    new_group_expr.push(tumble_start);

                    // `tumble_end` with the same arguments.
                    let tumble_end = TumbleExpand::new(TUMBLE_END);
                    let tumble_end = datafusion_expr::expr::ScalarFunction::new_udf(
                        Arc::new(tumble_end.into()),
                        func.args.clone(),
                    );
                    let tumble_end = datafusion_expr::Expr::ScalarFunction(tumble_end);
                    let end_col_name = tumble_end.name_for_alias()?;
                    new_group_expr.push(tumble_end);

                    alias_to_expand.insert(expr.name_for_alias()?, (start_col_name, end_col_name));
                }
                _ => new_group_expr.push(expr.clone()),
            }
        }
        if !encountered_tumble {
            // No tumble call in the group-by; leave the plan untouched.
            return Ok(Transformed::no(plan));
        }
        let mut new_aggr = aggr.clone();
        new_aggr.group_expr = new_group_expr;
        // Group-by outputs changed, so the schema must be recomputed.
        let new_aggr = datafusion_expr::LogicalPlan::Aggregate(new_aggr).recompute_schema()?;
        // Rewrite the projection: a reference to the old tumble expression
        // becomes two column references.
        let mut new_proj_expr = vec![];
        let mut have_expanded = false;

        for proj_expr in proj.expr.iter() {
            if let Some((start_col_name, end_col_name)) =
                alias_to_expand.get(&proj_expr.name_for_alias()?)
            {
                let start_col = Column::from_qualified_name(start_col_name);
                let end_col = Column::from_qualified_name(end_col_name);
                new_proj_expr.push(datafusion_expr::Expr::Column(start_col));
                new_proj_expr.push(datafusion_expr::Expr::Column(end_col));
                have_expanded = true;
            } else {
                new_proj_expr.push(proj_expr.clone());
            }
        }

        // The select list never mentioned the tumble expression: append the
        // window bounds under fixed aliases so they are still projected.
        if !have_expanded {
            for (start_col_name, end_col_name) in alias_to_expand.values() {
                let start_col = Column::from_qualified_name(start_col_name);
                let end_col = Column::from_qualified_name(end_col_name);
                new_proj_expr.push(datafusion_expr::Expr::Column(start_col).alias("window_start"));
                new_proj_expr.push(datafusion_expr::Expr::Column(end_col).alias("window_end"));
            }
        }

        let new_proj = datafusion_expr::LogicalPlan::Projection(Projection::try_new(
            new_proj_expr,
            Arc::new(new_aggr),
        )?);
        return Ok(Transformed::yes(new_proj));
    }

    Ok(Transformed::no(plan))
}
427
428#[derive(Debug)]
431pub struct TumbleExpand {
432 signature: Signature,
433 name: String,
434}
435
436impl TumbleExpand {
437 pub fn new(name: &str) -> Self {
438 Self {
439 signature: Signature::new(TypeSignature::UserDefined, Volatility::Immutable),
440 name: name.to_string(),
441 }
442 }
443}
444
445impl ScalarUDFImpl for TumbleExpand {
446 fn as_any(&self) -> &dyn std::any::Any {
447 self
448 }
449
450 fn name(&self) -> &str {
451 &self.name
452 }
453
454 fn signature(&self) -> &Signature {
456 &self.signature
457 }
458
459 fn coerce_types(
460 &self,
461 arg_types: &[arrow_schema::DataType],
462 ) -> datafusion_common::Result<Vec<arrow_schema::DataType>> {
463 match (arg_types.first(), arg_types.get(1), arg_types.get(2)) {
464 (Some(ts), Some(window), opt) => {
465 use arrow_schema::DataType::*;
466 if !matches!(ts, Date32 | Timestamp(_, _)) {
467 return Err(DataFusionError::Plan(
468 format!("Expect timestamp column as first arg for tumble_start, found {:?}", ts)
469 ));
470 }
471 if !matches!(window, Utf8 | Interval(_)) {
472 return Err(DataFusionError::Plan(
473 format!("Expect second arg for window size's type being interval for tumble_start, found {:?}", window),
474 ));
475 }
476
477 if let Some(start_time) = opt
478 && !matches!(start_time, Utf8 | Date32 | Timestamp(_, _)){
479 return Err(DataFusionError::Plan(
480 format!("Expect start_time to either be date, timestamp or string, found {:?}", start_time)
481 ));
482 }
483
484 Ok(arg_types.to_vec())
485 }
486 _ => Err(DataFusionError::Plan(
487 "Expect tumble function have at least two arg(timestamp column and window size) and a third optional arg for starting time".to_string(),
488 )),
489 }
490 }
491
492 fn return_type(
493 &self,
494 arg_types: &[arrow_schema::DataType],
495 ) -> Result<arrow_schema::DataType, DataFusionError> {
496 arg_types.first().cloned().ok_or_else(|| {
497 DataFusionError::Plan(
498 "Expect tumble function have at least two arg(timestamp column and window size)"
499 .to_string(),
500 )
501 })
502 }
503
504 fn invoke_with_args(
505 &self,
506 _args: ScalarFunctionArgs,
507 ) -> datafusion_common::Result<ColumnarValue> {
508 Err(DataFusionError::Plan(
509 "This function should not be executed by datafusion".to_string(),
510 ))
511 }
512}
513
/// Analyzer rule that rejects plans where a group-by expression does not also
/// appear in the select list above it.
#[derive(Debug)]
struct CheckGroupByRule {}

impl CheckGroupByRule {
    /// Creates a new instance of the rule.
    pub fn new() -> Self {
        Self {}
    }
}
523
524impl AnalyzerRule for CheckGroupByRule {
525 fn analyze(
526 &self,
527 plan: datafusion_expr::LogicalPlan,
528 _config: &ConfigOptions,
529 ) -> datafusion_common::Result<datafusion_expr::LogicalPlan> {
530 let transformed = plan
531 .transform_up_with_subqueries(check_group_by_analyzer)?
532 .data;
533 Ok(transformed)
534 }
535
536 fn name(&self) -> &str {
537 "check_groupby"
538 }
539}
540
541fn check_group_by_analyzer(
543 plan: datafusion_expr::LogicalPlan,
544) -> Result<Transformed<datafusion_expr::LogicalPlan>, DataFusionError> {
545 if let datafusion_expr::LogicalPlan::Projection(proj) = &plan
546 && let datafusion_expr::LogicalPlan::Aggregate(aggr) = proj.input.as_ref()
547 {
548 let mut found_column_used = FindColumn::new();
549 proj.expr
550 .iter()
551 .map(|i| i.visit(&mut found_column_used))
552 .count();
553 for expr in aggr.group_expr.iter() {
554 if !found_column_used
555 .names_for_alias
556 .contains(&expr.name_for_alias()?)
557 {
558 return Err(DataFusionError::Plan(format!(
559 "Expect {} expr in group by also exist in select list, but select list only contain {:?}",
560 expr.name_for_alias()?,
561 found_column_used.names_for_alias
562 )));
563 }
564 }
565 }
566
567 Ok(Transformed::no(plan))
568}
569
/// Expression visitor that collects the alias names of every column
/// expression it encounters.
#[derive(Debug, Default)]
struct FindColumn {
    // Alias names of all columns seen so far.
    names_for_alias: HashSet<String>,
}

impl FindColumn {
    /// Creates an empty collector.
    fn new() -> Self {
        Default::default()
    }
}
581
582impl TreeNodeVisitor<'_> for FindColumn {
583 type Node = datafusion_expr::Expr;
584 fn f_down(
585 &mut self,
586 node: &datafusion_expr::Expr,
587 ) -> Result<TreeNodeRecursion, DataFusionError> {
588 if let datafusion_expr::Expr::Column(_) = node {
589 self.names_for_alias.insert(node.name_for_alias()?);
590 }
591 Ok(TreeNodeRecursion::Continue)
592 }
593}