1#![warn(unused)]
18
19use std::collections::{HashMap, HashSet};
20use std::sync::Arc;
21
22use common_error::ext::BoxedError;
23use common_telemetry::debug;
24use datafusion::config::ConfigOptions;
25use datafusion::error::DataFusionError;
26use datafusion::functions_aggregate::count::count_udaf;
27use datafusion::functions_aggregate::sum::sum_udaf;
28use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
29use datafusion::optimizer::common_subexpr_eliminate::CommonSubexprEliminate;
30use datafusion::optimizer::optimize_projections::OptimizeProjections;
31use datafusion::optimizer::simplify_expressions::SimplifyExpressions;
32use datafusion::optimizer::unwrap_cast_in_comparison::UnwrapCastInComparison;
33use datafusion::optimizer::utils::NamePreserver;
34use datafusion::optimizer::{Analyzer, AnalyzerRule, Optimizer, OptimizerContext};
35use datafusion_common::tree_node::{
36 Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor,
37};
38use datafusion_common::{Column, DFSchema, ScalarValue};
39use datafusion_expr::utils::merge_schema;
40use datafusion_expr::{
41 BinaryExpr, ColumnarValue, Expr, Operator, Projection, ScalarFunctionArgs, ScalarUDFImpl,
42 Signature, TypeSignature, Volatility,
43};
44use query::optimizer::count_wildcard::CountWildcardToTimeIndexRule;
45use query::parser::QueryLanguageParser;
46use query::query_engine::DefaultSerializer;
47use query::QueryEngine;
48use snafu::ResultExt;
49use substrait::DFLogicalSubstraitConvertor;
52
53use crate::adapter::FlownodeContext;
54use crate::error::{DatafusionSnafu, Error, ExternalSnafu, UnexpectedSnafu};
55use crate::expr::{TUMBLE_END, TUMBLE_START};
56use crate::plan::TypedPlan;
57
58pub async fn apply_df_optimizer(
60 plan: datafusion_expr::LogicalPlan,
61) -> Result<datafusion_expr::LogicalPlan, Error> {
62 let cfg = ConfigOptions::new();
63 let analyzer = Analyzer::with_rules(vec![
64 Arc::new(CountWildcardToTimeIndexRule),
65 Arc::new(AvgExpandRule),
66 Arc::new(TumbleExpandRule),
67 Arc::new(CheckGroupByRule::new()),
68 Arc::new(TypeCoercion::new()),
69 ]);
70 let plan = analyzer
71 .execute_and_check(plan, &cfg, |p, r| {
72 debug!("After apply rule {}, get plan: \n{:?}", r.name(), p);
73 })
74 .context(DatafusionSnafu {
75 context: "Fail to apply analyzer",
76 })?;
77
78 let ctx = OptimizerContext::new();
79 let optimizer = Optimizer::with_rules(vec![
80 Arc::new(OptimizeProjections::new()),
81 Arc::new(CommonSubexprEliminate::new()),
82 Arc::new(SimplifyExpressions::new()),
83 Arc::new(UnwrapCastInComparison::new()),
84 ]);
85 let plan = optimizer
86 .optimize(plan, &ctx, |_, _| {})
87 .context(DatafusionSnafu {
88 context: "Fail to apply optimizer",
89 })?;
90
91 Ok(plan)
92}
93
94pub async fn sql_to_flow_plan(
97 ctx: &mut FlownodeContext,
98 engine: &Arc<dyn QueryEngine>,
99 sql: &str,
100) -> Result<TypedPlan, Error> {
101 let query_ctx = ctx.query_context.clone().ok_or_else(|| {
102 UnexpectedSnafu {
103 reason: "Query context is missing",
104 }
105 .build()
106 })?;
107 let stmt = QueryLanguageParser::parse_sql(sql, &query_ctx)
108 .map_err(BoxedError::new)
109 .context(ExternalSnafu)?;
110 let plan = engine
111 .planner()
112 .plan(&stmt, query_ctx)
113 .await
114 .map_err(BoxedError::new)
115 .context(ExternalSnafu)?;
116
117 let opted_plan = apply_df_optimizer(plan).await?;
118
119 let sub_plan = DFLogicalSubstraitConvertor {}
121 .to_sub_plan(&opted_plan, DefaultSerializer)
122 .map_err(BoxedError::new)
123 .context(ExternalSnafu)?;
124
125 let flow_plan = TypedPlan::from_substrait_plan(ctx, &sub_plan).await?;
126
127 Ok(flow_plan)
128}
129
130#[derive(Debug)]
131struct AvgExpandRule;
132
133impl AnalyzerRule for AvgExpandRule {
134 fn analyze(
135 &self,
136 plan: datafusion_expr::LogicalPlan,
137 _config: &ConfigOptions,
138 ) -> datafusion_common::Result<datafusion_expr::LogicalPlan> {
139 let transformed = plan
140 .transform_up_with_subqueries(expand_avg_analyzer)?
141 .data
142 .transform_down_with_subqueries(put_aggr_to_proj_analyzer)?
143 .data;
144 Ok(transformed)
145 }
146
147 fn name(&self) -> &str {
148 "avg_expand"
149 }
150}
151
152fn put_aggr_to_proj_analyzer(
164 plan: datafusion_expr::LogicalPlan,
165) -> Result<Transformed<datafusion_expr::LogicalPlan>, DataFusionError> {
166 if let datafusion_expr::LogicalPlan::Projection(proj) = &plan {
167 if let datafusion_expr::LogicalPlan::Aggregate(aggr) = proj.input.as_ref() {
168 let mut replace_old_proj_exprs = HashMap::new();
169 let mut expanded_aggr_exprs = vec![];
170 for aggr_expr in &aggr.aggr_expr {
171 let mut is_composite = false;
172 if let Expr::AggregateFunction(_) = &aggr_expr {
173 expanded_aggr_exprs.push(aggr_expr.clone());
174 } else {
175 let old_name = aggr_expr.name_for_alias()?;
176 let new_proj_expr = aggr_expr
177 .clone()
178 .transform(|ch| {
179 if let Expr::AggregateFunction(_) = &ch {
180 is_composite = true;
181 expanded_aggr_exprs.push(ch.clone());
182 Ok(Transformed::yes(Expr::Column(Column::from_qualified_name(
183 ch.name_for_alias()?,
184 ))))
185 } else {
186 Ok(Transformed::no(ch))
187 }
188 })?
189 .data;
190 replace_old_proj_exprs.insert(old_name, new_proj_expr);
191 }
192 }
193
194 if expanded_aggr_exprs.len() > aggr.aggr_expr.len() {
195 let mut aggr = aggr.clone();
196 aggr.aggr_expr = expanded_aggr_exprs;
197 let mut aggr_plan = datafusion_expr::LogicalPlan::Aggregate(aggr);
198 aggr_plan = aggr_plan.recompute_schema()?;
200
201 let mut new_proj_exprs = proj.expr.clone();
203 for proj_expr in new_proj_exprs.iter_mut() {
204 if let Some(new_proj_expr) =
205 replace_old_proj_exprs.get(&proj_expr.name_for_alias()?)
206 {
207 *proj_expr = new_proj_expr.clone();
208 }
209 *proj_expr = proj_expr
210 .clone()
211 .transform(|expr| {
212 if let Some(new_expr) =
213 replace_old_proj_exprs.get(&expr.name_for_alias()?)
214 {
215 Ok(Transformed::yes(new_expr.clone()))
216 } else {
217 Ok(Transformed::no(expr))
218 }
219 })?
220 .data;
221 }
222 let proj = datafusion_expr::LogicalPlan::Projection(Projection::try_new(
223 new_proj_exprs,
224 Arc::new(aggr_plan),
225 )?);
226 return Ok(Transformed::yes(proj));
227 }
228 }
229 }
230 Ok(Transformed::no(plan))
231}
232
233fn expand_avg_analyzer(
235 plan: datafusion_expr::LogicalPlan,
236) -> Result<Transformed<datafusion_expr::LogicalPlan>, DataFusionError> {
237 let mut schema = merge_schema(&plan.inputs());
238
239 if let datafusion_expr::LogicalPlan::TableScan(ts) = &plan {
240 let source_schema =
241 DFSchema::try_from_qualified_schema(ts.table_name.clone(), &ts.source.schema())?;
242 schema.merge(&source_schema);
243 }
244
245 let mut expr_rewrite = ExpandAvgRewriter::new(&schema);
246
247 let name_preserver = NamePreserver::new(&plan);
248 plan.map_expressions(|expr| {
250 let original_name = name_preserver.save(&expr);
251 Ok(expr
252 .rewrite(&mut expr_rewrite)?
253 .update_data(|expr| original_name.restore(expr)))
254 })?
255 .map_data(|plan| plan.recompute_schema())
256}
257
/// Expression rewriter that replaces `avg(x)` with an equivalent
/// `CASE WHEN count(x) != 0 THEN CAST(sum(x) AS Float64) / count(x) ELSE NULL END`.
pub(crate) struct ExpandAvgRewriter<'a> {
    /// Merged input schema of the plan node being rewritten. Not read by the
    /// rewrite itself at present, hence the `allow(unused)`.
    #[allow(unused)]
    pub(crate) schema: &'a DFSchema,
}

impl<'a> ExpandAvgRewriter<'a> {
    /// Create a rewriter over the given input schema.
    fn new(schema: &'a DFSchema) -> Self {
        Self { schema }
    }
}
274
275impl TreeNodeRewriter for ExpandAvgRewriter<'_> {
276 type Node = Expr;
277
278 fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>, DataFusionError> {
279 if let Expr::AggregateFunction(aggr_func) = &expr {
280 if aggr_func.func.name() == "avg" {
281 let sum_expr = {
282 let mut tmp = aggr_func.clone();
283 tmp.func = sum_udaf();
284 Expr::AggregateFunction(tmp)
285 };
286 let sum_cast = {
287 let mut tmp = sum_expr.clone();
288 tmp = Expr::Cast(datafusion_expr::Cast {
289 expr: Box::new(tmp),
290 data_type: arrow_schema::DataType::Float64,
291 });
292 tmp
293 };
294
295 let count_expr = {
296 let mut tmp = aggr_func.clone();
297 tmp.func = count_udaf();
298
299 Expr::AggregateFunction(tmp)
300 };
301 let count_expr_ref =
302 Expr::Column(Column::from_qualified_name(count_expr.name_for_alias()?));
303
304 let div =
305 BinaryExpr::new(Box::new(sum_cast), Operator::Divide, Box::new(count_expr));
306 let div_expr = Box::new(Expr::BinaryExpr(div));
307
308 let zero = Box::new(Expr::Literal(ScalarValue::Int64(Some(0))));
309 let not_zero =
310 BinaryExpr::new(Box::new(count_expr_ref), Operator::NotEq, zero.clone());
311 let not_zero = Box::new(Expr::BinaryExpr(not_zero));
312 let null = Box::new(Expr::Literal(ScalarValue::Null));
313
314 let case_when =
315 datafusion_expr::Case::new(None, vec![(not_zero, div_expr)], Some(null));
316 let case_when_expr = Expr::Case(case_when);
317
318 return Ok(Transformed::yes(case_when_expr));
319 }
320 }
321
322 Ok(Transformed::no(expr))
323 }
324}
325
326#[derive(Debug)]
328struct TumbleExpandRule;
329
330impl AnalyzerRule for TumbleExpandRule {
331 fn analyze(
332 &self,
333 plan: datafusion_expr::LogicalPlan,
334 _config: &ConfigOptions,
335 ) -> datafusion_common::Result<datafusion_expr::LogicalPlan> {
336 let transformed = plan
337 .transform_up_with_subqueries(expand_tumble_analyzer)?
338 .data;
339 Ok(transformed)
340 }
341
342 fn name(&self) -> &str {
343 "tumble_expand"
344 }
345}
346
/// Expand a `tumble(ts, window[, start])` call appearing in the GROUP BY of
/// an aggregate into two group expressions — `tumble_start(...)` and
/// `tumble_end(...)` — and fix up the projection above to select both window
/// bounds.
///
/// Only fires on a `Projection` directly over an `Aggregate`; plans without
/// a `tumble` group expression are returned unchanged.
fn expand_tumble_analyzer(
    plan: datafusion_expr::LogicalPlan,
) -> Result<Transformed<datafusion_expr::LogicalPlan>, DataFusionError> {
    if let datafusion_expr::LogicalPlan::Projection(proj) = &plan {
        if let datafusion_expr::LogicalPlan::Aggregate(aggr) = proj.input.as_ref() {
            let mut new_group_expr = vec![];
            // Maps the original `tumble(...)` expr name to the names of the
            // generated (start, end) column pair, for projection fix-up below.
            let mut alias_to_expand = HashMap::new();
            let mut encountered_tumble = false;
            for expr in aggr.group_expr.iter() {
                match expr {
                    datafusion_expr::Expr::ScalarFunction(func) if func.name() == "tumble" => {
                        encountered_tumble = true;

                        // Replace the single `tumble` call with
                        // `tumble_start(...)` carrying the same arguments...
                        let tumble_start = TumbleExpand::new(TUMBLE_START);
                        let tumble_start = datafusion_expr::expr::ScalarFunction::new_udf(
                            Arc::new(tumble_start.into()),
                            func.args.clone(),
                        );
                        let tumble_start = datafusion_expr::Expr::ScalarFunction(tumble_start);
                        let start_col_name = tumble_start.name_for_alias()?;
                        new_group_expr.push(tumble_start);

                        // ...plus a matching `tumble_end(...)`.
                        let tumble_end = TumbleExpand::new(TUMBLE_END);
                        let tumble_end = datafusion_expr::expr::ScalarFunction::new_udf(
                            Arc::new(tumble_end.into()),
                            func.args.clone(),
                        );
                        let tumble_end = datafusion_expr::Expr::ScalarFunction(tumble_end);
                        let end_col_name = tumble_end.name_for_alias()?;
                        new_group_expr.push(tumble_end);

                        alias_to_expand
                            .insert(expr.name_for_alias()?, (start_col_name, end_col_name));
                    }
                    // Non-tumble group expressions pass through untouched.
                    _ => new_group_expr.push(expr.clone()),
                }
            }
            if !encountered_tumble {
                return Ok(Transformed::no(plan));
            }
            let mut new_aggr = aggr.clone();
            new_aggr.group_expr = new_group_expr;
            // Group keys changed, so the aggregate's schema must be rebuilt.
            let new_aggr = datafusion_expr::LogicalPlan::Aggregate(new_aggr).recompute_schema()?;
            let mut new_proj_expr = vec![];
            let mut have_expanded = false;

            // Where the projection selected the old `tumble(...)` expr,
            // select the two new window-bound columns instead.
            for proj_expr in proj.expr.iter() {
                if let Some((start_col_name, end_col_name)) =
                    alias_to_expand.get(&proj_expr.name_for_alias()?)
                {
                    let start_col = Column::from_qualified_name(start_col_name);
                    let end_col = Column::from_qualified_name(end_col_name);
                    new_proj_expr.push(datafusion_expr::Expr::Column(start_col));
                    new_proj_expr.push(datafusion_expr::Expr::Column(end_col));
                    have_expanded = true;
                } else {
                    new_proj_expr.push(proj_expr.clone());
                }
            }

            // If the projection never referenced the tumble expr, still
            // append the window bounds (aliased `window_start`/`window_end`)
            // so downstream consumers can see the window columns.
            if !have_expanded {
                for (start_col_name, end_col_name) in alias_to_expand.values() {
                    let start_col = Column::from_qualified_name(start_col_name);
                    let end_col = Column::from_qualified_name(end_col_name);
                    new_proj_expr
                        .push(datafusion_expr::Expr::Column(start_col).alias("window_start"));
                    new_proj_expr.push(datafusion_expr::Expr::Column(end_col).alias("window_end"));
                }
            }

            let new_proj = datafusion_expr::LogicalPlan::Projection(Projection::try_new(
                new_proj_expr,
                Arc::new(new_aggr),
            )?);
            return Ok(Transformed::yes(new_proj));
        }
    }

    Ok(Transformed::no(plan))
}
432
433#[derive(Debug)]
436pub struct TumbleExpand {
437 signature: Signature,
438 name: String,
439}
440
441impl TumbleExpand {
442 pub fn new(name: &str) -> Self {
443 Self {
444 signature: Signature::new(TypeSignature::UserDefined, Volatility::Immutable),
445 name: name.to_string(),
446 }
447 }
448}
449
impl ScalarUDFImpl for TumbleExpand {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Either `TUMBLE_START` or `TUMBLE_END`, chosen at construction.
    fn name(&self) -> &str {
        &self.name
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Validate argument types — a time-like first argument, a window size,
    /// and an optional start time — and return them unchanged on success.
    ///
    /// NOTE(review): the error messages say "tumble_start" regardless of
    /// which function this instance represents — consider using `self.name`.
    fn coerce_types(
        &self,
        arg_types: &[arrow_schema::DataType],
    ) -> datafusion_common::Result<Vec<arrow_schema::DataType>> {
        match (arg_types.first(), arg_types.get(1), arg_types.get(2)) {
            (Some(ts), Some(window), opt) => {
                use arrow_schema::DataType::*;
                // First arg: the timestamp column to window over.
                if !matches!(ts, Date32 | Timestamp(_, _)) {
                    return Err(DataFusionError::Plan(
                        format!("Expect timestamp column as first arg for tumble_start, found {:?}", ts)
                    ));
                }
                // Second arg: the window size, as a string or interval.
                if !matches!(window, Utf8 | Interval(_)) {
                    return Err(DataFusionError::Plan(
                        format!("Expect second arg for window size's type being interval for tumble_start, found {:?}", window),
                    ));
                }

                // Optional third arg: the window start/alignment time.
                if let Some(start_time) = opt{
                    if !matches!(start_time, Utf8 | Date32 | Timestamp(_, _)){
                        return Err(DataFusionError::Plan(
                            format!("Expect start_time to either be date, timestamp or string, found {:?}", start_time)
                        ));
                    }
                }

                Ok(arg_types.to_vec())
            }
            _ => Err(DataFusionError::Plan(
                "Expect tumble function have at least two arg(timestamp column and window size) and a third optional arg for starting time".to_string(),
            )),
        }
    }

    /// The window bound has the same type as the input timestamp column.
    fn return_type(
        &self,
        arg_types: &[arrow_schema::DataType],
    ) -> Result<arrow_schema::DataType, DataFusionError> {
        arg_types.first().cloned().ok_or_else(|| {
            DataFusionError::Plan(
                "Expect tumble function have at least two arg(timestamp column and window size)"
                    .to_string(),
            )
        })
    }

    /// This UDF is a planning-time placeholder and must never be executed;
    /// invocation is always an error.
    fn invoke_with_args(
        &self,
        _args: ScalarFunctionArgs,
    ) -> datafusion_common::Result<ColumnarValue> {
        Err(DataFusionError::Plan(
            "This function should not be executed by datafusion".to_string(),
        ))
    }
}
519
520#[derive(Debug)]
522struct CheckGroupByRule {}
523
524impl CheckGroupByRule {
525 pub fn new() -> Self {
526 Self {}
527 }
528}
529
530impl AnalyzerRule for CheckGroupByRule {
531 fn analyze(
532 &self,
533 plan: datafusion_expr::LogicalPlan,
534 _config: &ConfigOptions,
535 ) -> datafusion_common::Result<datafusion_expr::LogicalPlan> {
536 let transformed = plan
537 .transform_up_with_subqueries(check_group_by_analyzer)?
538 .data;
539 Ok(transformed)
540 }
541
542 fn name(&self) -> &str {
543 "check_groupby"
544 }
545}
546
547fn check_group_by_analyzer(
549 plan: datafusion_expr::LogicalPlan,
550) -> Result<Transformed<datafusion_expr::LogicalPlan>, DataFusionError> {
551 if let datafusion_expr::LogicalPlan::Projection(proj) = &plan {
552 if let datafusion_expr::LogicalPlan::Aggregate(aggr) = proj.input.as_ref() {
553 let mut found_column_used = FindColumn::new();
554 proj.expr
555 .iter()
556 .map(|i| i.visit(&mut found_column_used))
557 .count();
558 for expr in aggr.group_expr.iter() {
559 if !found_column_used
560 .names_for_alias
561 .contains(&expr.name_for_alias()?)
562 {
563 return Err(DataFusionError::Plan(format!("Expect {} expr in group by also exist in select list, but select list only contain {:?}",expr.name_for_alias()?, found_column_used.names_for_alias)));
564 }
565 }
566 }
567 }
568
569 Ok(Transformed::no(plan))
570}
571
572#[derive(Debug, Default)]
574struct FindColumn {
575 names_for_alias: HashSet<String>,
576}
577
578impl FindColumn {
579 fn new() -> Self {
580 Default::default()
581 }
582}
583
584impl TreeNodeVisitor<'_> for FindColumn {
585 type Node = datafusion_expr::Expr;
586 fn f_down(
587 &mut self,
588 node: &datafusion_expr::Expr,
589 ) -> Result<TreeNodeRecursion, DataFusionError> {
590 if let datafusion_expr::Expr::Column(_) = node {
591 self.names_for_alias.insert(node.name_for_alias()?);
592 }
593 Ok(TreeNodeRecursion::Continue)
594 }
595}