flow/
df_optimizer.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! DataFusion optimizer for flow plan
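//!
//! The analyzer rules applied here are `CountWildcardToTimeIndexRule`, `AvgExpandRule`,
//! `TumbleExpandRule`, `CheckGroupByRule` and DataFusion's `TypeCoercion`, followed by a small
//! set of optimizer rules (`OptimizeProjections`, `CommonSubexprEliminate`, `SimplifyExpressions`).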

#![warn(unused)]

use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use common_error::ext::BoxedError;
use common_telemetry::debug;
use datafusion::config::ConfigOptions;
use datafusion::error::DataFusionError;
use datafusion::functions_aggregate::count::count_udaf;
use datafusion::functions_aggregate::sum::sum_udaf;
use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
use datafusion::optimizer::common_subexpr_eliminate::CommonSubexprEliminate;
use datafusion::optimizer::optimize_projections::OptimizeProjections;
use datafusion::optimizer::simplify_expressions::SimplifyExpressions;
use datafusion::optimizer::utils::NamePreserver;
use datafusion::optimizer::{Analyzer, AnalyzerRule, Optimizer, OptimizerContext};
use datafusion_common::tree_node::{
    Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor,
};
use datafusion_common::{Column, DFSchema, ScalarValue};
use datafusion_expr::utils::merge_schema;
use datafusion_expr::{
    BinaryExpr, ColumnarValue, Expr, Literal, Operator, Projection, ScalarFunctionArgs,
    ScalarUDFImpl, Signature, TypeSignature, Volatility,
};
use query::optimizer::count_wildcard::CountWildcardToTimeIndexRule;
use query::parser::QueryLanguageParser;
use query::query_engine::DefaultSerializer;
use query::QueryEngine;
use snafu::ResultExt;
// Note: here we use the `substrait_proto_df` crate from the `substrait` module and
// rename it to `substrait_proto`.
use substrait::DFLogicalSubstraitConvertor;

use crate::adapter::FlownodeContext;
use crate::error::{DatafusionSnafu, Error, ExternalSnafu, UnexpectedSnafu};
use crate::expr::{TUMBLE_END, TUMBLE_START};
use crate::plan::TypedPlan;

// TODO(discord9): use `Analyzer` to manage rules if more `AnalyzerRule`s are needed
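/// Apply a fixed set of DataFusion analyzer and optimizer rules to the logical plan, so that
/// it only contains constructs the flow plan can handle (e.g. `avg` expanded into `sum`/`count`,
/// `tumble` expanded into window start/end columns).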
pub async fn apply_df_optimizer(
    plan: datafusion_expr::LogicalPlan,
) -> Result<datafusion_expr::LogicalPlan, Error> {
    let cfg = ConfigOptions::new();
    let analyzer = Analyzer::with_rules(vec![
        Arc::new(CountWildcardToTimeIndexRule),
        Arc::new(AvgExpandRule),
        Arc::new(TumbleExpandRule),
        Arc::new(CheckGroupByRule::new()),
        Arc::new(TypeCoercion::new()),
    ]);
    let plan = analyzer
        .execute_and_check(plan, &cfg, |p, r| {
            debug!("After apply rule {}, get plan: \n{:?}", r.name(), p);
        })
        .context(DatafusionSnafu {
            context: "Fail to apply analyzer",
        })?;

    let ctx = OptimizerContext::new();
    let optimizer = Optimizer::with_rules(vec![
        Arc::new(OptimizeProjections::new()),
        Arc::new(CommonSubexprEliminate::new()),
        Arc::new(SimplifyExpressions::new()),
    ]);
    let plan = optimizer
        .optimize(plan, &ctx, |_, _| {})
        .context(DatafusionSnafu {
            context: "Fail to apply optimizer",
        })?;

    Ok(plan)
}

/// To reuse existing code for parsing SQL, the SQL is first parsed into a DataFusion logical
/// plan, then converted to a substrait plan, and finally to a flow plan.
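///
/// Roughly (the names below are the ones used in this function):
///
/// ```ignore
/// sql --QueryLanguageParser::parse_sql--> stmt --planner().plan--> plan
///     --apply_df_optimizer--> opted_plan --to_sub_plan--> sub_plan
///     --TypedPlan::from_substrait_plan--> flow_plan
/// ```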
pub async fn sql_to_flow_plan(
    ctx: &mut FlownodeContext,
    engine: &Arc<dyn QueryEngine>,
    sql: &str,
) -> Result<TypedPlan, Error> {
    let query_ctx = ctx.query_context.clone().ok_or_else(|| {
        UnexpectedSnafu {
            reason: "Query context is missing",
        }
        .build()
    })?;
    let stmt = QueryLanguageParser::parse_sql(sql, &query_ctx)
        .map_err(BoxedError::new)
        .context(ExternalSnafu)?;
    let plan = engine
        .planner()
        .plan(&stmt, query_ctx)
        .await
        .map_err(BoxedError::new)
        .context(ExternalSnafu)?;

    let opted_plan = apply_df_optimizer(plan).await?;

    // TODO(discord9): add df optimization
    let sub_plan = DFLogicalSubstraitConvertor {}
        .to_sub_plan(&opted_plan, DefaultSerializer)
        .map_err(BoxedError::new)
        .context(ExternalSnafu)?;

    let flow_plan = TypedPlan::from_substrait_plan(ctx, &sub_plan).await?;

    Ok(flow_plan)
}

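/// Expand `avg(<expr>)` into `sum(<expr>) / count(<expr>)` so that the flow plan only has to
/// deal with simple aggregate functions.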
#[derive(Debug)]
struct AvgExpandRule;

impl AnalyzerRule for AvgExpandRule {
    fn analyze(
        &self,
        plan: datafusion_expr::LogicalPlan,
        _config: &ConfigOptions,
    ) -> datafusion_common::Result<datafusion_expr::LogicalPlan> {
        let transformed = plan
            .transform_up_with_subqueries(expand_avg_analyzer)?
            .data
            .transform_down_with_subqueries(put_aggr_to_proj_analyzer)?
            .data;
        Ok(transformed)
    }

    fn name(&self) -> &str {
        "avg_expand"
    }
}

/// Lift composite aggregate exprs from the aggregation node into the outer projection, leaving
/// the aggregation node with only simple, direct aggregate exprs,
/// i.e.
/// ```ignore
/// proj: avg(x)
/// -- aggr: [sum(x)/count(x) as avg(x)]
/// ```
/// becomes:
/// ```ignore
/// proj: sum(x)/count(x) as avg(x)
/// -- aggr: [sum(x), count(x)]
/// ```
fn put_aggr_to_proj_analyzer(
    plan: datafusion_expr::LogicalPlan,
) -> Result<Transformed<datafusion_expr::LogicalPlan>, DataFusionError> {
    if let datafusion_expr::LogicalPlan::Projection(proj) = &plan {
        if let datafusion_expr::LogicalPlan::Aggregate(aggr) = proj.input.as_ref() {
            let mut replace_old_proj_exprs = HashMap::new();
            let mut expanded_aggr_exprs = vec![];
            for aggr_expr in &aggr.aggr_expr {
                let mut is_composite = false;
                if let Expr::AggregateFunction(_) = &aggr_expr {
                    expanded_aggr_exprs.push(aggr_expr.clone());
                } else {
                    let old_name = aggr_expr.name_for_alias()?;
                    let new_proj_expr = aggr_expr
                        .clone()
                        .transform(|ch| {
                            if let Expr::AggregateFunction(_) = &ch {
                                is_composite = true;
                                expanded_aggr_exprs.push(ch.clone());
                                Ok(Transformed::yes(Expr::Column(Column::from_qualified_name(
                                    ch.name_for_alias()?,
                                ))))
                            } else {
                                Ok(Transformed::no(ch))
                            }
                        })?
                        .data;
                    replace_old_proj_exprs.insert(old_name, new_proj_expr);
                }
            }

            if expanded_aggr_exprs.len() > aggr.aggr_expr.len() {
                let mut aggr = aggr.clone();
                aggr.aggr_expr = expanded_aggr_exprs;
                let mut aggr_plan = datafusion_expr::LogicalPlan::Aggregate(aggr);
                // important to recompute schema after changing aggr_expr
                aggr_plan = aggr_plan.recompute_schema()?;

                // reconstruct proj with new proj_exprs
                let mut new_proj_exprs = proj.expr.clone();
                for proj_expr in new_proj_exprs.iter_mut() {
                    if let Some(new_proj_expr) =
                        replace_old_proj_exprs.get(&proj_expr.name_for_alias()?)
                    {
                        *proj_expr = new_proj_expr.clone();
                    }
                    *proj_expr = proj_expr
                        .clone()
                        .transform(|expr| {
                            if let Some(new_expr) =
                                replace_old_proj_exprs.get(&expr.name_for_alias()?)
                            {
                                Ok(Transformed::yes(new_expr.clone()))
                            } else {
                                Ok(Transformed::no(expr))
                            }
                        })?
                        .data;
                }
                let proj = datafusion_expr::LogicalPlan::Projection(Projection::try_new(
                    new_proj_exprs,
                    Arc::new(aggr_plan),
                )?);
                return Ok(Transformed::yes(proj));
            }
        }
    }
    Ok(Transformed::no(plan))
}

/// Expand `avg(<expr>)` into `cast(sum(<expr>) AS Float64) / count(<expr>)` (see [`ExpandAvgRewriter`])
fn expand_avg_analyzer(
    plan: datafusion_expr::LogicalPlan,
) -> Result<Transformed<datafusion_expr::LogicalPlan>, DataFusionError> {
    let mut schema = merge_schema(&plan.inputs());

    if let datafusion_expr::LogicalPlan::TableScan(ts) = &plan {
        let source_schema =
            DFSchema::try_from_qualified_schema(ts.table_name.clone(), &ts.source.schema())?;
        schema.merge(&source_schema);
    }

    let mut expr_rewrite = ExpandAvgRewriter::new(&schema);

    let name_preserver = NamePreserver::new(&plan);
    // rewrite all expressions in the plan individually, preserving the original expr names
    plan.map_expressions(|expr| {
        let original_name = name_preserver.save(&expr);
        Ok(expr
            .rewrite(&mut expr_rewrite)?
            .update_data(|expr| original_name.restore(expr)))
    })?
    .map_data(|plan| plan.recompute_schema())
}

/// Rewrite `avg(<expr>)` into `CASE WHEN count(<expr>) != 0 THEN cast(sum(<expr>) AS Float64) / count(<expr>) ELSE NULL END`
///
/// TODO(discord9): support avg return type decimal128
///
/// see impl details at https://github.com/apache/datafusion/blob/4ad4f90d86c57226a4e0fb1f79dfaaf0d404c273/datafusion/expr/src/type_coercion/aggregates.rs#L457-L462
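///
/// For example (an illustrative sketch; the actual qualified column names depend on the input plan):
///
/// ```ignore
/// avg(x)
/// -- becomes -->
/// CASE WHEN count(x) != 0 THEN CAST(sum(x) AS Float64) / count(x) ELSE NULL END
/// ```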
pub(crate) struct ExpandAvgRewriter<'a> {
    /// schema of the plan
    #[allow(unused)]
    pub(crate) schema: &'a DFSchema,
}

impl<'a> ExpandAvgRewriter<'a> {
    fn new(schema: &'a DFSchema) -> Self {
        Self { schema }
    }
}

impl TreeNodeRewriter for ExpandAvgRewriter<'_> {
    type Node = Expr;

    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>, DataFusionError> {
        if let Expr::AggregateFunction(aggr_func) = &expr {
            if aggr_func.func.name() == "avg" {
                let sum_expr = {
                    let mut tmp = aggr_func.clone();
                    tmp.func = sum_udaf();
                    Expr::AggregateFunction(tmp)
                };
                let sum_cast = {
                    let mut tmp = sum_expr.clone();
                    tmp = Expr::Cast(datafusion_expr::Cast {
                        expr: Box::new(tmp),
                        data_type: arrow_schema::DataType::Float64,
                    });
                    tmp
                };

                let count_expr = {
                    let mut tmp = aggr_func.clone();
                    tmp.func = count_udaf();

                    Expr::AggregateFunction(tmp)
                };
                let count_expr_ref =
                    Expr::Column(Column::from_qualified_name(count_expr.name_for_alias()?));

                let div =
                    BinaryExpr::new(Box::new(sum_cast), Operator::Divide, Box::new(count_expr));
                let div_expr = Box::new(Expr::BinaryExpr(div));

                let zero = Box::new(0.lit());
                let not_zero =
                    BinaryExpr::new(Box::new(count_expr_ref), Operator::NotEq, zero.clone());
                let not_zero = Box::new(Expr::BinaryExpr(not_zero));
                let null = Box::new(Expr::Literal(ScalarValue::Null, None));

                let case_when =
                    datafusion_expr::Case::new(None, vec![(not_zero, div_expr)], Some(null));
                let case_when_expr = Expr::Case(case_when);

                return Ok(Transformed::yes(case_when_expr));
            }
        }

        Ok(Transformed::no(expr))
    }
}

/// Expand `tumble` in the group by exprs into `tumble_start` and `tumble_end`, with output
/// column names like `window_start` and `window_end`
#[derive(Debug)]
struct TumbleExpandRule;

impl AnalyzerRule for TumbleExpandRule {
    fn analyze(
        &self,
        plan: datafusion_expr::LogicalPlan,
        _config: &ConfigOptions,
    ) -> datafusion_common::Result<datafusion_expr::LogicalPlan> {
        let transformed = plan
            .transform_up_with_subqueries(expand_tumble_analyzer)?
            .data;
        Ok(transformed)
    }

    fn name(&self) -> &str {
        "tumble_expand"
    }
}

/// Expand `tumble` in the group by exprs into `tumble_start` and `tumble_end`, and also expand
/// related aliases and column refs.
///
/// Will add `tumble_start` and `tumble_end` to the outer projection if they did not exist before.
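///
/// i.e. (an illustrative sketch, window arguments elided):
/// ```ignore
/// proj: tumble(ts, ..)
/// -- aggr: group_by=[tumble(ts, ..)]
/// ```
/// becomes:
/// ```ignore
/// proj: tumble_start(ts, ..), tumble_end(ts, ..)
/// -- aggr: group_by=[tumble_start(ts, ..), tumble_end(ts, ..)]
/// ```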
fn expand_tumble_analyzer(
    plan: datafusion_expr::LogicalPlan,
) -> Result<Transformed<datafusion_expr::LogicalPlan>, DataFusionError> {
    if let datafusion_expr::LogicalPlan::Projection(proj) = &plan {
        if let datafusion_expr::LogicalPlan::Aggregate(aggr) = proj.input.as_ref() {
            let mut new_group_expr = vec![];
            let mut alias_to_expand = HashMap::new();
            let mut encountered_tumble = false;
            for expr in aggr.group_expr.iter() {
                match expr {
                    datafusion_expr::Expr::ScalarFunction(func) if func.name() == "tumble" => {
                        encountered_tumble = true;

                        let tumble_start = TumbleExpand::new(TUMBLE_START);
                        let tumble_start = datafusion_expr::expr::ScalarFunction::new_udf(
                            Arc::new(tumble_start.into()),
                            func.args.clone(),
                        );
                        let tumble_start = datafusion_expr::Expr::ScalarFunction(tumble_start);
                        let start_col_name = tumble_start.name_for_alias()?;
                        new_group_expr.push(tumble_start);

                        let tumble_end = TumbleExpand::new(TUMBLE_END);
                        let tumble_end = datafusion_expr::expr::ScalarFunction::new_udf(
                            Arc::new(tumble_end.into()),
                            func.args.clone(),
                        );
                        let tumble_end = datafusion_expr::Expr::ScalarFunction(tumble_end);
                        let end_col_name = tumble_end.name_for_alias()?;
                        new_group_expr.push(tumble_end);

                        alias_to_expand
                            .insert(expr.name_for_alias()?, (start_col_name, end_col_name));
                    }
                    _ => new_group_expr.push(expr.clone()),
                }
            }
            if !encountered_tumble {
                return Ok(Transformed::no(plan));
            }
            let mut new_aggr = aggr.clone();
            new_aggr.group_expr = new_group_expr;
            let new_aggr = datafusion_expr::LogicalPlan::Aggregate(new_aggr).recompute_schema()?;
            // replace alias in projection if needed, and add new column ref if necessary
            let mut new_proj_expr = vec![];
            let mut have_expanded = false;

            for proj_expr in proj.expr.iter() {
                if let Some((start_col_name, end_col_name)) =
                    alias_to_expand.get(&proj_expr.name_for_alias()?)
                {
                    let start_col = Column::from_qualified_name(start_col_name);
                    let end_col = Column::from_qualified_name(end_col_name);
                    new_proj_expr.push(datafusion_expr::Expr::Column(start_col));
                    new_proj_expr.push(datafusion_expr::Expr::Column(end_col));
                    have_expanded = true;
                } else {
                    new_proj_expr.push(proj_expr.clone());
                }
            }

            // append to the end of the projection if not already present
            if !have_expanded {
                for (start_col_name, end_col_name) in alias_to_expand.values() {
                    let start_col = Column::from_qualified_name(start_col_name);
                    let end_col = Column::from_qualified_name(end_col_name);
                    new_proj_expr
                        .push(datafusion_expr::Expr::Column(start_col).alias("window_start"));
                    new_proj_expr.push(datafusion_expr::Expr::Column(end_col).alias("window_end"));
                }
            }

            let new_proj = datafusion_expr::LogicalPlan::Projection(Projection::try_new(
                new_proj_expr,
                Arc::new(new_aggr),
            )?);
            return Ok(Transformed::yes(new_proj));
        }
    }

    Ok(Transformed::no(plan))
}

/// A placeholder UDF for the `tumble_start` and `tumble_end` functions, so that DataFusion can
/// recognize them as scalar functions
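///
/// It is only used during planning; `expand_tumble_analyzer` wraps it roughly like this
/// (an illustrative sketch):
///
/// ```ignore
/// let udf: ScalarUDF = TumbleExpand::new(TUMBLE_START).into();
/// let expr = ScalarFunction::new_udf(Arc::new(udf), args);
/// ```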
#[derive(Debug)]
pub struct TumbleExpand {
    signature: Signature,
    name: String,
}

impl TumbleExpand {
    pub fn new(name: &str) -> Self {
        Self {
            signature: Signature::new(TypeSignature::UserDefined, Volatility::Immutable),
            name: name.to_string(),
        }
    }
}

impl ScalarUDFImpl for TumbleExpand {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn name(&self) -> &str {
        &self.name
    }

    /// The signature is user-defined; actual argument checking happens in `coerce_types`
    fn signature(&self) -> &Signature {
        &self.signature
    }

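    /// Check and coerce the argument types: a timestamp-like first argument, a window size
    /// (string or interval) as the second argument, and an optional start time as the third.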
    fn coerce_types(
        &self,
        arg_types: &[arrow_schema::DataType],
    ) -> datafusion_common::Result<Vec<arrow_schema::DataType>> {
        match (arg_types.first(), arg_types.get(1), arg_types.get(2)) {
            (Some(ts), Some(window), opt) => {
                use arrow_schema::DataType::*;
                if !matches!(ts, Date32 | Timestamp(_, _)) {
                    return Err(DataFusionError::Plan(format!(
                        "Expect timestamp column as first arg for tumble_start, found {:?}",
                        ts
                    )));
                }
                if !matches!(window, Utf8 | Interval(_)) {
                    return Err(DataFusionError::Plan(format!(
                        "Expect second arg for window size's type being interval for tumble_start, found {:?}",
                        window
                    )));
                }

                if let Some(start_time) = opt {
                    if !matches!(start_time, Utf8 | Date32 | Timestamp(_, _)) {
                        return Err(DataFusionError::Plan(format!(
                            "Expect start_time to either be date, timestamp or string, found {:?}",
                            start_time
                        )));
                    }
                }

                Ok(arg_types.to_vec())
            }
            _ => Err(DataFusionError::Plan(
                "Expect tumble function have at least two arg(timestamp column and window size) and a third optional arg for starting time".to_string(),
            )),
        }
    }

    fn return_type(
        &self,
        arg_types: &[arrow_schema::DataType],
    ) -> Result<arrow_schema::DataType, DataFusionError> {
        arg_types.first().cloned().ok_or_else(|| {
            DataFusionError::Plan(
                "Expect tumble function have at least two arg(timestamp column and window size)"
                    .to_string(),
            )
        })
    }

    fn invoke_with_args(
        &self,
        _args: ScalarFunctionArgs,
    ) -> datafusion_common::Result<ColumnarValue> {
        Err(DataFusionError::Plan(
            "This function should not be executed by datafusion".to_string(),
        ))
    }
}

/// This rule checks all group by exprs, and makes sure they also appear in the select clause of an aggregate query
#[derive(Debug)]
struct CheckGroupByRule {}

impl CheckGroupByRule {
    pub fn new() -> Self {
        Self {}
    }
}

impl AnalyzerRule for CheckGroupByRule {
    fn analyze(
        &self,
        plan: datafusion_expr::LogicalPlan,
        _config: &ConfigOptions,
    ) -> datafusion_common::Result<datafusion_expr::LogicalPlan> {
        let transformed = plan
            .transform_up_with_subqueries(check_group_by_analyzer)?
            .data;
        Ok(transformed)
    }

    fn name(&self) -> &str {
        "check_groupby"
    }
}

/// Make sure every expr in group by also appears in the select list
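///
/// For example (illustrative; table and column names are made up):
///
/// ```ignore
/// SELECT sum(number) FROM numbers GROUP BY ts;     -- rejected: `ts` is not in the select list
/// SELECT ts, sum(number) FROM numbers GROUP BY ts; -- accepted
/// ```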
fn check_group_by_analyzer(
    plan: datafusion_expr::LogicalPlan,
) -> Result<Transformed<datafusion_expr::LogicalPlan>, DataFusionError> {
    if let datafusion_expr::LogicalPlan::Projection(proj) = &plan {
        if let datafusion_expr::LogicalPlan::Aggregate(aggr) = proj.input.as_ref() {
            let mut found_column_used = FindColumn::new();
            for proj_expr in proj.expr.iter() {
                proj_expr.visit(&mut found_column_used)?;
            }
            for expr in aggr.group_expr.iter() {
                if !found_column_used
                    .names_for_alias
                    .contains(&expr.name_for_alias()?)
                {
                    return Err(DataFusionError::Plan(format!(
                        "Expect {} expr in group by also exist in select list, but select list only contain {:?}",
                        expr.name_for_alias()?,
                        found_column_used.names_for_alias
                    )));
                }
            }
        }
    }

    Ok(Transformed::no(plan))
}

/// Find all column names referenced in an expression (recorded by their `name_for_alias`)
#[derive(Debug, Default)]
struct FindColumn {
    names_for_alias: HashSet<String>,
}

impl FindColumn {
    fn new() -> Self {
        Default::default()
    }
}

impl TreeNodeVisitor<'_> for FindColumn {
    type Node = datafusion_expr::Expr;
    fn f_down(
        &mut self,
        node: &datafusion_expr::Expr,
    ) -> Result<TreeNodeRecursion, DataFusionError> {
        if let datafusion_expr::Expr::Column(_) = node {
            self.names_for_alias.insert(node.name_for_alias()?);
        }
        Ok(TreeNodeRecursion::Continue)
    }
}