1#![warn(unused)]
18
19use std::collections::{HashMap, HashSet};
20use std::sync::Arc;
21
22use common_error::ext::BoxedError;
23use common_telemetry::debug;
24use datafusion::config::ConfigOptions;
25use datafusion::error::DataFusionError;
26use datafusion::functions_aggregate::count::count_udaf;
27use datafusion::functions_aggregate::sum::sum_udaf;
28use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
29use datafusion::optimizer::common_subexpr_eliminate::CommonSubexprEliminate;
30use datafusion::optimizer::optimize_projections::OptimizeProjections;
31use datafusion::optimizer::simplify_expressions::SimplifyExpressions;
32use datafusion::optimizer::utils::NamePreserver;
33use datafusion::optimizer::{Analyzer, AnalyzerRule, Optimizer, OptimizerContext};
34use datafusion_common::tree_node::{
35 Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor,
36};
37use datafusion_common::{Column, DFSchema, ScalarValue};
38use datafusion_expr::utils::merge_schema;
39use datafusion_expr::{
40 BinaryExpr, ColumnarValue, Expr, Literal, Operator, Projection, ScalarFunctionArgs,
41 ScalarUDFImpl, Signature, TypeSignature, Volatility,
42};
43use query::optimizer::count_wildcard::CountWildcardToTimeIndexRule;
44use query::parser::QueryLanguageParser;
45use query::query_engine::DefaultSerializer;
46use query::QueryEngine;
47use snafu::ResultExt;
48use substrait::DFLogicalSubstraitConvertor;
51
52use crate::adapter::FlownodeContext;
53use crate::error::{DatafusionSnafu, Error, ExternalSnafu, UnexpectedSnafu};
54use crate::expr::{TUMBLE_END, TUMBLE_START};
55use crate::plan::TypedPlan;
56
57pub async fn apply_df_optimizer(
59 plan: datafusion_expr::LogicalPlan,
60) -> Result<datafusion_expr::LogicalPlan, Error> {
61 let cfg = ConfigOptions::new();
62 let analyzer = Analyzer::with_rules(vec![
63 Arc::new(CountWildcardToTimeIndexRule),
64 Arc::new(AvgExpandRule),
65 Arc::new(TumbleExpandRule),
66 Arc::new(CheckGroupByRule::new()),
67 Arc::new(TypeCoercion::new()),
68 ]);
69 let plan = analyzer
70 .execute_and_check(plan, &cfg, |p, r| {
71 debug!("After apply rule {}, get plan: \n{:?}", r.name(), p);
72 })
73 .context(DatafusionSnafu {
74 context: "Fail to apply analyzer",
75 })?;
76
77 let ctx = OptimizerContext::new();
78 let optimizer = Optimizer::with_rules(vec![
79 Arc::new(OptimizeProjections::new()),
80 Arc::new(CommonSubexprEliminate::new()),
81 Arc::new(SimplifyExpressions::new()),
82 ]);
83 let plan = optimizer
84 .optimize(plan, &ctx, |_, _| {})
85 .context(DatafusionSnafu {
86 context: "Fail to apply optimizer",
87 })?;
88
89 Ok(plan)
90}
91
92pub async fn sql_to_flow_plan(
95 ctx: &mut FlownodeContext,
96 engine: &Arc<dyn QueryEngine>,
97 sql: &str,
98) -> Result<TypedPlan, Error> {
99 let query_ctx = ctx.query_context.clone().ok_or_else(|| {
100 UnexpectedSnafu {
101 reason: "Query context is missing",
102 }
103 .build()
104 })?;
105 let stmt = QueryLanguageParser::parse_sql(sql, &query_ctx)
106 .map_err(BoxedError::new)
107 .context(ExternalSnafu)?;
108 let plan = engine
109 .planner()
110 .plan(&stmt, query_ctx)
111 .await
112 .map_err(BoxedError::new)
113 .context(ExternalSnafu)?;
114
115 let opted_plan = apply_df_optimizer(plan).await?;
116
117 let sub_plan = DFLogicalSubstraitConvertor {}
119 .to_sub_plan(&opted_plan, DefaultSerializer)
120 .map_err(BoxedError::new)
121 .context(ExternalSnafu)?;
122
123 let flow_plan = TypedPlan::from_substrait_plan(ctx, &sub_plan).await?;
124
125 Ok(flow_plan)
126}
127
128#[derive(Debug)]
129struct AvgExpandRule;
130
131impl AnalyzerRule for AvgExpandRule {
132 fn analyze(
133 &self,
134 plan: datafusion_expr::LogicalPlan,
135 _config: &ConfigOptions,
136 ) -> datafusion_common::Result<datafusion_expr::LogicalPlan> {
137 let transformed = plan
138 .transform_up_with_subqueries(expand_avg_analyzer)?
139 .data
140 .transform_down_with_subqueries(put_aggr_to_proj_analyzer)?
141 .data;
142 Ok(transformed)
143 }
144
145 fn name(&self) -> &str {
146 "avg_expand"
147 }
148}
149
150fn put_aggr_to_proj_analyzer(
162 plan: datafusion_expr::LogicalPlan,
163) -> Result<Transformed<datafusion_expr::LogicalPlan>, DataFusionError> {
164 if let datafusion_expr::LogicalPlan::Projection(proj) = &plan {
165 if let datafusion_expr::LogicalPlan::Aggregate(aggr) = proj.input.as_ref() {
166 let mut replace_old_proj_exprs = HashMap::new();
167 let mut expanded_aggr_exprs = vec![];
168 for aggr_expr in &aggr.aggr_expr {
169 let mut is_composite = false;
170 if let Expr::AggregateFunction(_) = &aggr_expr {
171 expanded_aggr_exprs.push(aggr_expr.clone());
172 } else {
173 let old_name = aggr_expr.name_for_alias()?;
174 let new_proj_expr = aggr_expr
175 .clone()
176 .transform(|ch| {
177 if let Expr::AggregateFunction(_) = &ch {
178 is_composite = true;
179 expanded_aggr_exprs.push(ch.clone());
180 Ok(Transformed::yes(Expr::Column(Column::from_qualified_name(
181 ch.name_for_alias()?,
182 ))))
183 } else {
184 Ok(Transformed::no(ch))
185 }
186 })?
187 .data;
188 replace_old_proj_exprs.insert(old_name, new_proj_expr);
189 }
190 }
191
192 if expanded_aggr_exprs.len() > aggr.aggr_expr.len() {
193 let mut aggr = aggr.clone();
194 aggr.aggr_expr = expanded_aggr_exprs;
195 let mut aggr_plan = datafusion_expr::LogicalPlan::Aggregate(aggr);
196 aggr_plan = aggr_plan.recompute_schema()?;
198
199 let mut new_proj_exprs = proj.expr.clone();
201 for proj_expr in new_proj_exprs.iter_mut() {
202 if let Some(new_proj_expr) =
203 replace_old_proj_exprs.get(&proj_expr.name_for_alias()?)
204 {
205 *proj_expr = new_proj_expr.clone();
206 }
207 *proj_expr = proj_expr
208 .clone()
209 .transform(|expr| {
210 if let Some(new_expr) =
211 replace_old_proj_exprs.get(&expr.name_for_alias()?)
212 {
213 Ok(Transformed::yes(new_expr.clone()))
214 } else {
215 Ok(Transformed::no(expr))
216 }
217 })?
218 .data;
219 }
220 let proj = datafusion_expr::LogicalPlan::Projection(Projection::try_new(
221 new_proj_exprs,
222 Arc::new(aggr_plan),
223 )?);
224 return Ok(Transformed::yes(proj));
225 }
226 }
227 }
228 Ok(Transformed::no(plan))
229}
230
231fn expand_avg_analyzer(
233 plan: datafusion_expr::LogicalPlan,
234) -> Result<Transformed<datafusion_expr::LogicalPlan>, DataFusionError> {
235 let mut schema = merge_schema(&plan.inputs());
236
237 if let datafusion_expr::LogicalPlan::TableScan(ts) = &plan {
238 let source_schema =
239 DFSchema::try_from_qualified_schema(ts.table_name.clone(), &ts.source.schema())?;
240 schema.merge(&source_schema);
241 }
242
243 let mut expr_rewrite = ExpandAvgRewriter::new(&schema);
244
245 let name_preserver = NamePreserver::new(&plan);
246 plan.map_expressions(|expr| {
248 let original_name = name_preserver.save(&expr);
249 Ok(expr
250 .rewrite(&mut expr_rewrite)?
251 .update_data(|expr| original_name.restore(expr)))
252 })?
253 .map_data(|plan| plan.recompute_schema())
254}
255
/// Expression rewriter that expands `avg(e)` aggregate calls into a
/// `CASE WHEN count(e) != 0 THEN cast(sum(e) AS Float64) / count(e) ELSE NULL END`
/// construction (see the `TreeNodeRewriter` impl below).
pub(crate) struct ExpandAvgRewriter<'a> {
    /// Schema of the plan whose expressions are being rewritten.
    /// NOTE(review): not read by `f_up` today; kept for schema-aware rewrites.
    #[allow(unused)]
    pub(crate) schema: &'a DFSchema,
}

impl<'a> ExpandAvgRewriter<'a> {
    /// Create a rewriter bound to the given plan schema.
    fn new(schema: &'a DFSchema) -> Self {
        Self { schema }
    }
}
272
273impl TreeNodeRewriter for ExpandAvgRewriter<'_> {
274 type Node = Expr;
275
276 fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>, DataFusionError> {
277 if let Expr::AggregateFunction(aggr_func) = &expr {
278 if aggr_func.func.name() == "avg" {
279 let sum_expr = {
280 let mut tmp = aggr_func.clone();
281 tmp.func = sum_udaf();
282 Expr::AggregateFunction(tmp)
283 };
284 let sum_cast = {
285 let mut tmp = sum_expr.clone();
286 tmp = Expr::Cast(datafusion_expr::Cast {
287 expr: Box::new(tmp),
288 data_type: arrow_schema::DataType::Float64,
289 });
290 tmp
291 };
292
293 let count_expr = {
294 let mut tmp = aggr_func.clone();
295 tmp.func = count_udaf();
296
297 Expr::AggregateFunction(tmp)
298 };
299 let count_expr_ref =
300 Expr::Column(Column::from_qualified_name(count_expr.name_for_alias()?));
301
302 let div =
303 BinaryExpr::new(Box::new(sum_cast), Operator::Divide, Box::new(count_expr));
304 let div_expr = Box::new(Expr::BinaryExpr(div));
305
306 let zero = Box::new(0.lit());
307 let not_zero =
308 BinaryExpr::new(Box::new(count_expr_ref), Operator::NotEq, zero.clone());
309 let not_zero = Box::new(Expr::BinaryExpr(not_zero));
310 let null = Box::new(Expr::Literal(ScalarValue::Null, None));
311
312 let case_when =
313 datafusion_expr::Case::new(None, vec![(not_zero, div_expr)], Some(null));
314 let case_when_expr = Expr::Case(case_when);
315
316 return Ok(Transformed::yes(case_when_expr));
317 }
318 }
319
320 Ok(Transformed::no(expr))
321 }
322}
323
324#[derive(Debug)]
326struct TumbleExpandRule;
327
328impl AnalyzerRule for TumbleExpandRule {
329 fn analyze(
330 &self,
331 plan: datafusion_expr::LogicalPlan,
332 _config: &ConfigOptions,
333 ) -> datafusion_common::Result<datafusion_expr::LogicalPlan> {
334 let transformed = plan
335 .transform_up_with_subqueries(expand_tumble_analyzer)?
336 .data;
337 Ok(transformed)
338 }
339
340 fn name(&self) -> &str {
341 "tumble_expand"
342 }
343}
344
/// Expand a `tumble(ts, window[, start])` call in an aggregate's group-by
/// into two group-by expressions, `tumble_start(..)` and `tumble_end(..)`,
/// and patch the projection above the aggregate accordingly.
///
/// If the projection referenced the original `tumble(..)` output, that
/// reference is replaced by both boundary columns; if nothing referenced it,
/// the boundaries are appended aliased as `window_start` / `window_end` so
/// they still appear in the output.
fn expand_tumble_analyzer(
    plan: datafusion_expr::LogicalPlan,
) -> Result<Transformed<datafusion_expr::LogicalPlan>, DataFusionError> {
    if let datafusion_expr::LogicalPlan::Projection(proj) = &plan {
        if let datafusion_expr::LogicalPlan::Aggregate(aggr) = proj.input.as_ref() {
            let mut new_group_expr = vec![];
            // old `tumble(..)` alias -> (start column name, end column name)
            let mut alias_to_expand = HashMap::new();
            let mut encountered_tumble = false;
            for expr in aggr.group_expr.iter() {
                match expr {
                    datafusion_expr::Expr::ScalarFunction(func) if func.name() == "tumble" => {
                        encountered_tumble = true;

                        // Build `tumble_start(...)` with the same arguments
                        // as the original `tumble(...)` call.
                        let tumble_start = TumbleExpand::new(TUMBLE_START);
                        let tumble_start = datafusion_expr::expr::ScalarFunction::new_udf(
                            Arc::new(tumble_start.into()),
                            func.args.clone(),
                        );
                        let tumble_start = datafusion_expr::Expr::ScalarFunction(tumble_start);
                        let start_col_name = tumble_start.name_for_alias()?;
                        new_group_expr.push(tumble_start);

                        // And the matching `tumble_end(...)`.
                        let tumble_end = TumbleExpand::new(TUMBLE_END);
                        let tumble_end = datafusion_expr::expr::ScalarFunction::new_udf(
                            Arc::new(tumble_end.into()),
                            func.args.clone(),
                        );
                        let tumble_end = datafusion_expr::Expr::ScalarFunction(tumble_end);
                        let end_col_name = tumble_end.name_for_alias()?;
                        new_group_expr.push(tumble_end);

                        // Remember how to rewrite projection references to
                        // the original tumble output.
                        alias_to_expand
                            .insert(expr.name_for_alias()?, (start_col_name, end_col_name));
                    }
                    // Non-tumble group-by expressions pass through unchanged.
                    _ => new_group_expr.push(expr.clone()),
                }
            }
            // No tumble call in the group-by: leave the plan untouched.
            if !encountered_tumble {
                return Ok(Transformed::no(plan));
            }
            let mut new_aggr = aggr.clone();
            new_aggr.group_expr = new_group_expr;
            // Group-by columns changed, so the aggregate schema must be rebuilt.
            let new_aggr = datafusion_expr::LogicalPlan::Aggregate(new_aggr).recompute_schema()?;
            let mut new_proj_expr = vec![];
            let mut have_expanded = false;

            // Replace projection references to the old `tumble(..)` output
            // with the start/end boundary columns.
            for proj_expr in proj.expr.iter() {
                if let Some((start_col_name, end_col_name)) =
                    alias_to_expand.get(&proj_expr.name_for_alias()?)
                {
                    let start_col = Column::from_qualified_name(start_col_name);
                    let end_col = Column::from_qualified_name(end_col_name);
                    new_proj_expr.push(datafusion_expr::Expr::Column(start_col));
                    new_proj_expr.push(datafusion_expr::Expr::Column(end_col));
                    have_expanded = true;
                } else {
                    new_proj_expr.push(proj_expr.clone());
                }
            }

            // Nothing in the projection referenced the tumble output: append
            // the boundary columns so the window still shows up downstream.
            if !have_expanded {
                for (start_col_name, end_col_name) in alias_to_expand.values() {
                    let start_col = Column::from_qualified_name(start_col_name);
                    let end_col = Column::from_qualified_name(end_col_name);
                    new_proj_expr
                        .push(datafusion_expr::Expr::Column(start_col).alias("window_start"));
                    new_proj_expr.push(datafusion_expr::Expr::Column(end_col).alias("window_end"));
                }
            }

            let new_proj = datafusion_expr::LogicalPlan::Projection(Projection::try_new(
                new_proj_expr,
                Arc::new(new_aggr),
            )?);
            return Ok(Transformed::yes(new_proj));
        }
    }

    Ok(Transformed::no(plan))
}
430
431#[derive(Debug)]
434pub struct TumbleExpand {
435 signature: Signature,
436 name: String,
437}
438
439impl TumbleExpand {
440 pub fn new(name: &str) -> Self {
441 Self {
442 signature: Signature::new(TypeSignature::UserDefined, Volatility::Immutable),
443 name: name.to_string(),
444 }
445 }
446}
447
448impl ScalarUDFImpl for TumbleExpand {
449 fn as_any(&self) -> &dyn std::any::Any {
450 self
451 }
452
453 fn name(&self) -> &str {
454 &self.name
455 }
456
457 fn signature(&self) -> &Signature {
459 &self.signature
460 }
461
462 fn coerce_types(
463 &self,
464 arg_types: &[arrow_schema::DataType],
465 ) -> datafusion_common::Result<Vec<arrow_schema::DataType>> {
466 match (arg_types.first(), arg_types.get(1), arg_types.get(2)) {
467 (Some(ts), Some(window), opt) => {
468 use arrow_schema::DataType::*;
469 if !matches!(ts, Date32 | Timestamp(_, _)) {
470 return Err(DataFusionError::Plan(
471 format!("Expect timestamp column as first arg for tumble_start, found {:?}", ts)
472 ));
473 }
474 if !matches!(window, Utf8 | Interval(_)) {
475 return Err(DataFusionError::Plan(
476 format!("Expect second arg for window size's type being interval for tumble_start, found {:?}", window),
477 ));
478 }
479
480 if let Some(start_time) = opt{
481 if !matches!(start_time, Utf8 | Date32 | Timestamp(_, _)){
482 return Err(DataFusionError::Plan(
483 format!("Expect start_time to either be date, timestamp or string, found {:?}", start_time)
484 ));
485 }
486 }
487
488 Ok(arg_types.to_vec())
489 }
490 _ => Err(DataFusionError::Plan(
491 "Expect tumble function have at least two arg(timestamp column and window size) and a third optional arg for starting time".to_string(),
492 )),
493 }
494 }
495
496 fn return_type(
497 &self,
498 arg_types: &[arrow_schema::DataType],
499 ) -> Result<arrow_schema::DataType, DataFusionError> {
500 arg_types.first().cloned().ok_or_else(|| {
501 DataFusionError::Plan(
502 "Expect tumble function have at least two arg(timestamp column and window size)"
503 .to_string(),
504 )
505 })
506 }
507
508 fn invoke_with_args(
509 &self,
510 _args: ScalarFunctionArgs,
511 ) -> datafusion_common::Result<ColumnarValue> {
512 Err(DataFusionError::Plan(
513 "This function should not be executed by datafusion".to_string(),
514 ))
515 }
516}
517
518#[derive(Debug)]
520struct CheckGroupByRule {}
521
522impl CheckGroupByRule {
523 pub fn new() -> Self {
524 Self {}
525 }
526}
527
528impl AnalyzerRule for CheckGroupByRule {
529 fn analyze(
530 &self,
531 plan: datafusion_expr::LogicalPlan,
532 _config: &ConfigOptions,
533 ) -> datafusion_common::Result<datafusion_expr::LogicalPlan> {
534 let transformed = plan
535 .transform_up_with_subqueries(check_group_by_analyzer)?
536 .data;
537 Ok(transformed)
538 }
539
540 fn name(&self) -> &str {
541 "check_groupby"
542 }
543}
544
545fn check_group_by_analyzer(
547 plan: datafusion_expr::LogicalPlan,
548) -> Result<Transformed<datafusion_expr::LogicalPlan>, DataFusionError> {
549 if let datafusion_expr::LogicalPlan::Projection(proj) = &plan {
550 if let datafusion_expr::LogicalPlan::Aggregate(aggr) = proj.input.as_ref() {
551 let mut found_column_used = FindColumn::new();
552 proj.expr
553 .iter()
554 .map(|i| i.visit(&mut found_column_used))
555 .count();
556 for expr in aggr.group_expr.iter() {
557 if !found_column_used
558 .names_for_alias
559 .contains(&expr.name_for_alias()?)
560 {
561 return Err(DataFusionError::Plan(format!("Expect {} expr in group by also exist in select list, but select list only contain {:?}",expr.name_for_alias()?, found_column_used.names_for_alias)));
562 }
563 }
564 }
565 }
566
567 Ok(Transformed::no(plan))
568}
569
570#[derive(Debug, Default)]
572struct FindColumn {
573 names_for_alias: HashSet<String>,
574}
575
576impl FindColumn {
577 fn new() -> Self {
578 Default::default()
579 }
580}
581
582impl TreeNodeVisitor<'_> for FindColumn {
583 type Node = datafusion_expr::Expr;
584 fn f_down(
585 &mut self,
586 node: &datafusion_expr::Expr,
587 ) -> Result<TreeNodeRecursion, DataFusionError> {
588 if let datafusion_expr::Expr::Column(_) = node {
589 self.names_for_alias.insert(node.name_for_alias()?);
590 }
591 Ok(TreeNodeRecursion::Continue)
592 }
593}