// flow/df_optimizer.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Datafusion optimizer for flow plan
16
17#![warn(unused)]
18
19use std::collections::{HashMap, HashSet};
20use std::sync::Arc;
21
22use common_error::ext::BoxedError;
23use common_telemetry::debug;
24use datafusion::config::ConfigOptions;
25use datafusion::error::DataFusionError;
26use datafusion::functions_aggregate::count::count_udaf;
27use datafusion::functions_aggregate::sum::sum_udaf;
28use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
29use datafusion::optimizer::common_subexpr_eliminate::CommonSubexprEliminate;
30use datafusion::optimizer::optimize_projections::OptimizeProjections;
31use datafusion::optimizer::simplify_expressions::SimplifyExpressions;
32use datafusion::optimizer::unwrap_cast_in_comparison::UnwrapCastInComparison;
33use datafusion::optimizer::utils::NamePreserver;
34use datafusion::optimizer::{Analyzer, AnalyzerRule, Optimizer, OptimizerContext};
35use datafusion_common::tree_node::{
36    Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor,
37};
38use datafusion_common::{Column, DFSchema, ScalarValue};
39use datafusion_expr::utils::merge_schema;
40use datafusion_expr::{
41    BinaryExpr, ColumnarValue, Expr, Operator, Projection, ScalarFunctionArgs, ScalarUDFImpl,
42    Signature, TypeSignature, Volatility,
43};
44use query::optimizer::count_wildcard::CountWildcardToTimeIndexRule;
45use query::parser::QueryLanguageParser;
46use query::query_engine::DefaultSerializer;
47use query::QueryEngine;
48use snafu::ResultExt;
49/// note here we are using the `substrait_proto_df` crate from the `substrait` module and
50/// rename it to `substrait_proto`
51use substrait::DFLogicalSubstraitConvertor;
52
53use crate::adapter::FlownodeContext;
54use crate::error::{DatafusionSnafu, Error, ExternalSnafu, UnexpectedSnafu};
55use crate::expr::{TUMBLE_END, TUMBLE_START};
56use crate::plan::TypedPlan;
57
58// TODO(discord9): use `Analyzer` to manage rules if more `AnalyzerRule` is needed
59pub async fn apply_df_optimizer(
60    plan: datafusion_expr::LogicalPlan,
61) -> Result<datafusion_expr::LogicalPlan, Error> {
62    let cfg = ConfigOptions::new();
63    let analyzer = Analyzer::with_rules(vec![
64        Arc::new(CountWildcardToTimeIndexRule),
65        Arc::new(AvgExpandRule),
66        Arc::new(TumbleExpandRule),
67        Arc::new(CheckGroupByRule::new()),
68        Arc::new(TypeCoercion::new()),
69    ]);
70    let plan = analyzer
71        .execute_and_check(plan, &cfg, |p, r| {
72            debug!("After apply rule {}, get plan: \n{:?}", r.name(), p);
73        })
74        .context(DatafusionSnafu {
75            context: "Fail to apply analyzer",
76        })?;
77
78    let ctx = OptimizerContext::new();
79    let optimizer = Optimizer::with_rules(vec![
80        Arc::new(OptimizeProjections::new()),
81        Arc::new(CommonSubexprEliminate::new()),
82        Arc::new(SimplifyExpressions::new()),
83        Arc::new(UnwrapCastInComparison::new()),
84    ]);
85    let plan = optimizer
86        .optimize(plan, &ctx, |_, _| {})
87        .context(DatafusionSnafu {
88            context: "Fail to apply optimizer",
89        })?;
90
91    Ok(plan)
92}
93
94/// To reuse existing code for parse sql, the sql is first parsed into a datafusion logical plan,
95/// then to a substrait plan, and finally to a flow plan.
96pub async fn sql_to_flow_plan(
97    ctx: &mut FlownodeContext,
98    engine: &Arc<dyn QueryEngine>,
99    sql: &str,
100) -> Result<TypedPlan, Error> {
101    let query_ctx = ctx.query_context.clone().ok_or_else(|| {
102        UnexpectedSnafu {
103            reason: "Query context is missing",
104        }
105        .build()
106    })?;
107    let stmt = QueryLanguageParser::parse_sql(sql, &query_ctx)
108        .map_err(BoxedError::new)
109        .context(ExternalSnafu)?;
110    let plan = engine
111        .planner()
112        .plan(&stmt, query_ctx)
113        .await
114        .map_err(BoxedError::new)
115        .context(ExternalSnafu)?;
116
117    let opted_plan = apply_df_optimizer(plan).await?;
118
119    // TODO(discord9): add df optimization
120    let sub_plan = DFLogicalSubstraitConvertor {}
121        .to_sub_plan(&opted_plan, DefaultSerializer)
122        .map_err(BoxedError::new)
123        .context(ExternalSnafu)?;
124
125    let flow_plan = TypedPlan::from_substrait_plan(ctx, &sub_plan).await?;
126
127    Ok(flow_plan)
128}
129
130#[derive(Debug)]
131struct AvgExpandRule;
132
133impl AnalyzerRule for AvgExpandRule {
134    fn analyze(
135        &self,
136        plan: datafusion_expr::LogicalPlan,
137        _config: &ConfigOptions,
138    ) -> datafusion_common::Result<datafusion_expr::LogicalPlan> {
139        let transformed = plan
140            .transform_up_with_subqueries(expand_avg_analyzer)?
141            .data
142            .transform_down_with_subqueries(put_aggr_to_proj_analyzer)?
143            .data;
144        Ok(transformed)
145    }
146
147    fn name(&self) -> &str {
148        "avg_expand"
149    }
150}
151
152/// lift aggr's composite aggr_expr to outer proj, and leave aggr only with simple direct aggr expr
153/// i.e.
154/// ```ignore
155/// proj: avg(x)
156/// -- aggr: [sum(x)/count(x) as avg(x)]
157/// ```
158/// becomes:
159/// ```ignore
160/// proj: sum(x)/count(x) as avg(x)
161/// -- aggr: [sum(x), count(x)]
162/// ```
163fn put_aggr_to_proj_analyzer(
164    plan: datafusion_expr::LogicalPlan,
165) -> Result<Transformed<datafusion_expr::LogicalPlan>, DataFusionError> {
166    if let datafusion_expr::LogicalPlan::Projection(proj) = &plan {
167        if let datafusion_expr::LogicalPlan::Aggregate(aggr) = proj.input.as_ref() {
168            let mut replace_old_proj_exprs = HashMap::new();
169            let mut expanded_aggr_exprs = vec![];
170            for aggr_expr in &aggr.aggr_expr {
171                let mut is_composite = false;
172                if let Expr::AggregateFunction(_) = &aggr_expr {
173                    expanded_aggr_exprs.push(aggr_expr.clone());
174                } else {
175                    let old_name = aggr_expr.name_for_alias()?;
176                    let new_proj_expr = aggr_expr
177                        .clone()
178                        .transform(|ch| {
179                            if let Expr::AggregateFunction(_) = &ch {
180                                is_composite = true;
181                                expanded_aggr_exprs.push(ch.clone());
182                                Ok(Transformed::yes(Expr::Column(Column::from_qualified_name(
183                                    ch.name_for_alias()?,
184                                ))))
185                            } else {
186                                Ok(Transformed::no(ch))
187                            }
188                        })?
189                        .data;
190                    replace_old_proj_exprs.insert(old_name, new_proj_expr);
191                }
192            }
193
194            if expanded_aggr_exprs.len() > aggr.aggr_expr.len() {
195                let mut aggr = aggr.clone();
196                aggr.aggr_expr = expanded_aggr_exprs;
197                let mut aggr_plan = datafusion_expr::LogicalPlan::Aggregate(aggr);
198                // important to recompute schema after changing aggr_expr
199                aggr_plan = aggr_plan.recompute_schema()?;
200
201                // reconstruct proj with new proj_exprs
202                let mut new_proj_exprs = proj.expr.clone();
203                for proj_expr in new_proj_exprs.iter_mut() {
204                    if let Some(new_proj_expr) =
205                        replace_old_proj_exprs.get(&proj_expr.name_for_alias()?)
206                    {
207                        *proj_expr = new_proj_expr.clone();
208                    }
209                    *proj_expr = proj_expr
210                        .clone()
211                        .transform(|expr| {
212                            if let Some(new_expr) =
213                                replace_old_proj_exprs.get(&expr.name_for_alias()?)
214                            {
215                                Ok(Transformed::yes(new_expr.clone()))
216                            } else {
217                                Ok(Transformed::no(expr))
218                            }
219                        })?
220                        .data;
221                }
222                let proj = datafusion_expr::LogicalPlan::Projection(Projection::try_new(
223                    new_proj_exprs,
224                    Arc::new(aggr_plan),
225                )?);
226                return Ok(Transformed::yes(proj));
227            }
228        }
229    }
230    Ok(Transformed::no(plan))
231}
232
233/// expand `avg(<expr>)` function into `cast(sum((<expr>) AS f64)/count((<expr>)`
234fn expand_avg_analyzer(
235    plan: datafusion_expr::LogicalPlan,
236) -> Result<Transformed<datafusion_expr::LogicalPlan>, DataFusionError> {
237    let mut schema = merge_schema(&plan.inputs());
238
239    if let datafusion_expr::LogicalPlan::TableScan(ts) = &plan {
240        let source_schema =
241            DFSchema::try_from_qualified_schema(ts.table_name.clone(), &ts.source.schema())?;
242        schema.merge(&source_schema);
243    }
244
245    let mut expr_rewrite = ExpandAvgRewriter::new(&schema);
246
247    let name_preserver = NamePreserver::new(&plan);
248    // apply coercion rewrite all expressions in the plan individually
249    plan.map_expressions(|expr| {
250        let original_name = name_preserver.save(&expr);
251        Ok(expr
252            .rewrite(&mut expr_rewrite)?
253            .update_data(|expr| original_name.restore(expr)))
254    })?
255    .map_data(|plan| plan.recompute_schema())
256}
257
258/// rewrite `avg(<expr>)` function into `CASE WHEN count(<expr>) !=0 THEN  cast(sum((<expr>) AS avg_return_type)/count((<expr>) ELSE 0`
259///
260/// TODO(discord9): support avg return type decimal128
261///
262/// see impl details at https://github.com/apache/datafusion/blob/4ad4f90d86c57226a4e0fb1f79dfaaf0d404c273/datafusion/expr/src/type_coercion/aggregates.rs#L457-L462
263pub(crate) struct ExpandAvgRewriter<'a> {
264    /// schema of the plan
265    #[allow(unused)]
266    pub(crate) schema: &'a DFSchema,
267}
268
269impl<'a> ExpandAvgRewriter<'a> {
270    fn new(schema: &'a DFSchema) -> Self {
271        Self { schema }
272    }
273}
274
275impl TreeNodeRewriter for ExpandAvgRewriter<'_> {
276    type Node = Expr;
277
278    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>, DataFusionError> {
279        if let Expr::AggregateFunction(aggr_func) = &expr {
280            if aggr_func.func.name() == "avg" {
281                let sum_expr = {
282                    let mut tmp = aggr_func.clone();
283                    tmp.func = sum_udaf();
284                    Expr::AggregateFunction(tmp)
285                };
286                let sum_cast = {
287                    let mut tmp = sum_expr.clone();
288                    tmp = Expr::Cast(datafusion_expr::Cast {
289                        expr: Box::new(tmp),
290                        data_type: arrow_schema::DataType::Float64,
291                    });
292                    tmp
293                };
294
295                let count_expr = {
296                    let mut tmp = aggr_func.clone();
297                    tmp.func = count_udaf();
298
299                    Expr::AggregateFunction(tmp)
300                };
301                let count_expr_ref =
302                    Expr::Column(Column::from_qualified_name(count_expr.name_for_alias()?));
303
304                let div =
305                    BinaryExpr::new(Box::new(sum_cast), Operator::Divide, Box::new(count_expr));
306                let div_expr = Box::new(Expr::BinaryExpr(div));
307
308                let zero = Box::new(Expr::Literal(ScalarValue::Int64(Some(0))));
309                let not_zero =
310                    BinaryExpr::new(Box::new(count_expr_ref), Operator::NotEq, zero.clone());
311                let not_zero = Box::new(Expr::BinaryExpr(not_zero));
312                let null = Box::new(Expr::Literal(ScalarValue::Null));
313
314                let case_when =
315                    datafusion_expr::Case::new(None, vec![(not_zero, div_expr)], Some(null));
316                let case_when_expr = Expr::Case(case_when);
317
318                return Ok(Transformed::yes(case_when_expr));
319            }
320        }
321
322        Ok(Transformed::no(expr))
323    }
324}
325
326/// expand tumble in aggr expr to tumble_start and tumble_end with column name like `window_start`
327#[derive(Debug)]
328struct TumbleExpandRule;
329
330impl AnalyzerRule for TumbleExpandRule {
331    fn analyze(
332        &self,
333        plan: datafusion_expr::LogicalPlan,
334        _config: &ConfigOptions,
335    ) -> datafusion_common::Result<datafusion_expr::LogicalPlan> {
336        let transformed = plan
337            .transform_up_with_subqueries(expand_tumble_analyzer)?
338            .data;
339        Ok(transformed)
340    }
341
342    fn name(&self) -> &str {
343        "tumble_expand"
344    }
345}
346
347/// expand `tumble` in aggr expr to `tumble_start` and `tumble_end`, also expand related alias and column ref
348///
349/// will add `tumble_start` and `tumble_end` to outer projection if not exist before
fn expand_tumble_analyzer(
    plan: datafusion_expr::LogicalPlan,
) -> Result<Transformed<datafusion_expr::LogicalPlan>, DataFusionError> {
    // Only Projection-over-Aggregate shapes are rewritten; anything else passes
    // through unchanged.
    if let datafusion_expr::LogicalPlan::Projection(proj) = &plan {
        if let datafusion_expr::LogicalPlan::Aggregate(aggr) = proj.input.as_ref() {
            // Rewritten group-by list: each `tumble(..)` becomes two exprs.
            let mut new_group_expr = vec![];
            // old `tumble(..)` alias -> (start column name, end column name),
            // used below to patch projection references.
            let mut alias_to_expand = HashMap::new();
            let mut encountered_tumble = false;
            for expr in aggr.group_expr.iter() {
                match expr {
                    datafusion_expr::Expr::ScalarFunction(func) if func.name() == "tumble" => {
                        encountered_tumble = true;

                        // `tumble_start(<same args>)` replaces the original call...
                        let tumble_start = TumbleExpand::new(TUMBLE_START);
                        let tumble_start = datafusion_expr::expr::ScalarFunction::new_udf(
                            Arc::new(tumble_start.into()),
                            func.args.clone(),
                        );
                        let tumble_start = datafusion_expr::Expr::ScalarFunction(tumble_start);
                        let start_col_name = tumble_start.name_for_alias()?;
                        new_group_expr.push(tumble_start);

                        // ...and `tumble_end(<same args>)` is added alongside it.
                        let tumble_end = TumbleExpand::new(TUMBLE_END);
                        let tumble_end = datafusion_expr::expr::ScalarFunction::new_udf(
                            Arc::new(tumble_end.into()),
                            func.args.clone(),
                        );
                        let tumble_end = datafusion_expr::Expr::ScalarFunction(tumble_end);
                        let end_col_name = tumble_end.name_for_alias()?;
                        new_group_expr.push(tumble_end);

                        alias_to_expand
                            .insert(expr.name_for_alias()?, (start_col_name, end_col_name));
                    }
                    // non-tumble group-by exprs are kept verbatim
                    _ => new_group_expr.push(expr.clone()),
                }
            }
            if !encountered_tumble {
                return Ok(Transformed::no(plan));
            }
            let mut new_aggr = aggr.clone();
            new_aggr.group_expr = new_group_expr;
            // recompute schema since the group-by list (and thus output) changed
            let new_aggr = datafusion_expr::LogicalPlan::Aggregate(new_aggr).recompute_schema()?;
            // replace alias in projection if needed, and add new column ref if necessary
            let mut new_proj_expr = vec![];
            let mut have_expanded = false;

            for proj_expr in proj.expr.iter() {
                if let Some((start_col_name, end_col_name)) =
                    alias_to_expand.get(&proj_expr.name_for_alias()?)
                {
                    // projection referenced the old `tumble(..)` output: replace
                    // it with refs to the two new window-boundary columns
                    let start_col = Column::from_qualified_name(start_col_name);
                    let end_col = Column::from_qualified_name(end_col_name);
                    new_proj_expr.push(datafusion_expr::Expr::Column(start_col));
                    new_proj_expr.push(datafusion_expr::Expr::Column(end_col));
                    have_expanded = true;
                } else {
                    new_proj_expr.push(proj_expr.clone());
                }
            }

            // append to end of projection if not exist
            // NOTE(review): `alias_to_expand` is a HashMap, so with more than one
            // tumble expression the append order below follows hash iteration
            // order — presumably a single tumble per aggregate is the only case
            // that occurs in practice; confirm.
            if !have_expanded {
                for (start_col_name, end_col_name) in alias_to_expand.values() {
                    let start_col = Column::from_qualified_name(start_col_name);
                    let end_col = Column::from_qualified_name(end_col_name);
                    new_proj_expr
                        .push(datafusion_expr::Expr::Column(start_col).alias("window_start"));
                    new_proj_expr.push(datafusion_expr::Expr::Column(end_col).alias("window_end"));
                }
            }

            let new_proj = datafusion_expr::LogicalPlan::Projection(Projection::try_new(
                new_proj_expr,
                Arc::new(new_aggr),
            )?);
            return Ok(Transformed::yes(new_proj));
        }
    }

    Ok(Transformed::no(plan))
}
432
433/// This is a placeholder for tumble_start and tumble_end function, so that datafusion can
434/// recognize them as scalar function
435#[derive(Debug)]
436pub struct TumbleExpand {
437    signature: Signature,
438    name: String,
439}
440
441impl TumbleExpand {
442    pub fn new(name: &str) -> Self {
443        Self {
444            signature: Signature::new(TypeSignature::UserDefined, Volatility::Immutable),
445            name: name.to_string(),
446        }
447    }
448}
449
450impl ScalarUDFImpl for TumbleExpand {
451    fn as_any(&self) -> &dyn std::any::Any {
452        self
453    }
454
455    fn name(&self) -> &str {
456        &self.name
457    }
458
459    /// elide the signature for now
460    fn signature(&self) -> &Signature {
461        &self.signature
462    }
463
464    fn coerce_types(
465        &self,
466        arg_types: &[arrow_schema::DataType],
467    ) -> datafusion_common::Result<Vec<arrow_schema::DataType>> {
468        match (arg_types.first(), arg_types.get(1), arg_types.get(2)) {
469            (Some(ts), Some(window), opt) => {
470                use arrow_schema::DataType::*;
471                if !matches!(ts, Date32 | Timestamp(_, _)) {
472                    return Err(DataFusionError::Plan(
473                        format!("Expect timestamp column as first arg for tumble_start, found {:?}", ts)
474                    ));
475                }
476                if !matches!(window, Utf8 | Interval(_)) {
477                    return Err(DataFusionError::Plan(
478                        format!("Expect second arg for window size's type being interval for tumble_start, found {:?}", window),
479                    ));
480                }
481
482                if let Some(start_time) = opt{
483                    if !matches!(start_time,  Utf8 | Date32 | Timestamp(_, _)){
484                        return Err(DataFusionError::Plan(
485                            format!("Expect start_time to either be date, timestamp or string, found {:?}", start_time)
486                        ));
487                    }
488                }
489
490                Ok(arg_types.to_vec())
491            }
492            _ => Err(DataFusionError::Plan(
493                "Expect tumble function have at least two arg(timestamp column and window size) and a third optional arg for starting time".to_string(),
494            )),
495        }
496    }
497
498    fn return_type(
499        &self,
500        arg_types: &[arrow_schema::DataType],
501    ) -> Result<arrow_schema::DataType, DataFusionError> {
502        arg_types.first().cloned().ok_or_else(|| {
503            DataFusionError::Plan(
504                "Expect tumble function have at least two arg(timestamp column and window size)"
505                    .to_string(),
506            )
507        })
508    }
509
510    fn invoke_with_args(
511        &self,
512        _args: ScalarFunctionArgs,
513    ) -> datafusion_common::Result<ColumnarValue> {
514        Err(DataFusionError::Plan(
515            "This function should not be executed by datafusion".to_string(),
516        ))
517    }
518}
519
520/// This rule check all group by exprs, and make sure they are also in select clause in a aggr query
#[derive(Debug)]
struct CheckGroupByRule {}

impl CheckGroupByRule {
    /// Construct the rule; it carries no state.
    pub fn new() -> Self {
        CheckGroupByRule {}
    }
}
529
530impl AnalyzerRule for CheckGroupByRule {
531    fn analyze(
532        &self,
533        plan: datafusion_expr::LogicalPlan,
534        _config: &ConfigOptions,
535    ) -> datafusion_common::Result<datafusion_expr::LogicalPlan> {
536        let transformed = plan
537            .transform_up_with_subqueries(check_group_by_analyzer)?
538            .data;
539        Ok(transformed)
540    }
541
542    fn name(&self) -> &str {
543        "check_groupby"
544    }
545}
546
547/// make sure everything in group by's expr is in select
548fn check_group_by_analyzer(
549    plan: datafusion_expr::LogicalPlan,
550) -> Result<Transformed<datafusion_expr::LogicalPlan>, DataFusionError> {
551    if let datafusion_expr::LogicalPlan::Projection(proj) = &plan {
552        if let datafusion_expr::LogicalPlan::Aggregate(aggr) = proj.input.as_ref() {
553            let mut found_column_used = FindColumn::new();
554            proj.expr
555                .iter()
556                .map(|i| i.visit(&mut found_column_used))
557                .count();
558            for expr in aggr.group_expr.iter() {
559                if !found_column_used
560                    .names_for_alias
561                    .contains(&expr.name_for_alias()?)
562                {
563                    return Err(DataFusionError::Plan(format!("Expect {} expr in group by also exist in select list, but select list only contain {:?}",expr.name_for_alias()?, found_column_used.names_for_alias)));
564                }
565            }
566        }
567    }
568
569    Ok(Transformed::no(plan))
570}
571
572/// Find all column names in a plan
#[derive(Debug, Default)]
struct FindColumn {
    // set of `name_for_alias` strings of every column ref seen so far
    names_for_alias: HashSet<String>,
}

impl FindColumn {
    /// Start with an empty set of seen names.
    fn new() -> Self {
        FindColumn::default()
    }
}
583
584impl TreeNodeVisitor<'_> for FindColumn {
585    type Node = datafusion_expr::Expr;
586    fn f_down(
587        &mut self,
588        node: &datafusion_expr::Expr,
589    ) -> Result<TreeNodeRecursion, DataFusionError> {
590        if let datafusion_expr::Expr::Column(_) = node {
591            self.names_for_alias.insert(node.name_for_alias()?);
592        }
593        Ok(TreeNodeRecursion::Continue)
594    }
595}