flow/
df_optimizer.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Datafusion optimizer for flow plan
16
17#![warn(unused)]
18
19use std::collections::HashSet;
20use std::sync::Arc;
21
22use common_error::ext::BoxedError;
23use common_telemetry::debug;
24use datafusion::config::ConfigOptions;
25use datafusion::error::DataFusionError;
26use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
27use datafusion::optimizer::common_subexpr_eliminate::CommonSubexprEliminate;
28use datafusion::optimizer::optimize_projections::OptimizeProjections;
29use datafusion::optimizer::simplify_expressions::SimplifyExpressions;
30use datafusion::optimizer::{Analyzer, AnalyzerRule, Optimizer, OptimizerContext};
31use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion, TreeNodeVisitor};
32use query::QueryEngine;
33use query::optimizer::count_wildcard::CountWildcardToTimeIndexRule;
34use query::parser::QueryLanguageParser;
35use query::query_engine::DefaultSerializer;
36use session::context::QueryContextRef;
37use snafu::ResultExt;
38/// note here we are using the `substrait_proto_df` crate from the `substrait` module and
39/// rename it to `substrait_proto`
40use substrait::DFLogicalSubstraitConvertor;
41
42use crate::adapter::FlownodeContext;
43use crate::error::{DatafusionSnafu, Error, ExternalSnafu, UnexpectedSnafu};
44use crate::plan::TypedPlan;
45
46// TODO(discord9): use `Analyzer` to manage rules if more `AnalyzerRule` is needed
47pub async fn apply_df_optimizer(
48    plan: datafusion_expr::LogicalPlan,
49    query_ctx: &QueryContextRef,
50) -> Result<datafusion_expr::LogicalPlan, Error> {
51    let cfg = query_ctx.create_config_options();
52    let analyzer = Analyzer::with_rules(vec![
53        Arc::new(CountWildcardToTimeIndexRule),
54        Arc::new(CheckGroupByRule::new()),
55        Arc::new(TypeCoercion::new()),
56    ]);
57    let plan = analyzer
58        .execute_and_check(plan, &cfg, |p, r| {
59            debug!("After apply rule {}, get plan: \n{:?}", r.name(), p);
60        })
61        .context(DatafusionSnafu {
62            context: "Fail to apply analyzer",
63        })?;
64
65    let ctx = OptimizerContext::new();
66    let optimizer = Optimizer::with_rules(vec![
67        Arc::new(OptimizeProjections::new()),
68        Arc::new(CommonSubexprEliminate::new()),
69        Arc::new(SimplifyExpressions::new()),
70    ]);
71    let plan = optimizer
72        .optimize(plan, &ctx, |_, _| {})
73        .context(DatafusionSnafu {
74            context: "Fail to apply optimizer",
75        })?;
76
77    Ok(plan)
78}
79
80/// To reuse existing code for parse sql, the sql is first parsed into a datafusion logical plan,
81/// then to a substrait plan, and finally to a flow plan.
82pub async fn sql_to_flow_plan(
83    ctx: &mut FlownodeContext,
84    engine: &Arc<dyn QueryEngine>,
85    sql: &str,
86) -> Result<TypedPlan, Error> {
87    let query_ctx = ctx.query_context.clone().ok_or_else(|| {
88        UnexpectedSnafu {
89            reason: "Query context is missing",
90        }
91        .build()
92    })?;
93    let stmt = QueryLanguageParser::parse_sql(sql, &query_ctx)
94        .map_err(BoxedError::new)
95        .context(ExternalSnafu)?;
96    let plan = engine
97        .planner()
98        .plan(&stmt, query_ctx.clone())
99        .await
100        .map_err(BoxedError::new)
101        .context(ExternalSnafu)?;
102
103    let opted_plan = apply_df_optimizer(plan, &query_ctx).await?;
104
105    // TODO(discord9): add df optimization
106    let sub_plan = DFLogicalSubstraitConvertor {}
107        .to_sub_plan(&opted_plan, DefaultSerializer)
108        .map_err(BoxedError::new)
109        .context(ExternalSnafu)?;
110
111    let flow_plan = TypedPlan::from_substrait_plan(ctx, &sub_plan).await?;
112
113    Ok(flow_plan)
114}
115
/// Analyzer rule that checks every `GROUP BY` expression of an aggregate query
/// also appears in its select list.
///
/// The rule carries no state; the check itself lives in `check_group_by_analyzer`.
#[derive(Debug)]
struct CheckGroupByRule;

impl CheckGroupByRule {
    /// Creates the (stateless) rule.
    pub fn new() -> Self {
        Self
    }
}
125
126impl AnalyzerRule for CheckGroupByRule {
127    fn analyze(
128        &self,
129        plan: datafusion_expr::LogicalPlan,
130        _config: &ConfigOptions,
131    ) -> datafusion_common::Result<datafusion_expr::LogicalPlan> {
132        let transformed = plan
133            .transform_up_with_subqueries(check_group_by_analyzer)?
134            .data;
135        Ok(transformed)
136    }
137
138    fn name(&self) -> &str {
139        "check_groupby"
140    }
141}
142
143/// make sure everything in group by's expr is in select
144fn check_group_by_analyzer(
145    plan: datafusion_expr::LogicalPlan,
146) -> Result<Transformed<datafusion_expr::LogicalPlan>, DataFusionError> {
147    if let datafusion_expr::LogicalPlan::Projection(proj) = &plan
148        && let datafusion_expr::LogicalPlan::Aggregate(aggr) = proj.input.as_ref()
149    {
150        let mut found_column_used = FindColumn::new();
151        proj.expr
152            .iter()
153            .map(|i| i.visit(&mut found_column_used))
154            .count();
155        for expr in aggr.group_expr.iter() {
156            if !found_column_used
157                .names_for_alias
158                .contains(&expr.name_for_alias()?)
159            {
160                return Err(DataFusionError::Plan(format!(
161                    "Expect {} expr in group by also exist in select list, but select list only contain {:?}",
162                    expr.name_for_alias()?,
163                    found_column_used.names_for_alias
164                )));
165            }
166        }
167    }
168
169    Ok(Transformed::no(plan))
170}
171
/// Visitor state that accumulates the alias names of the column expressions it
/// encounters.
#[derive(Debug, Default)]
struct FindColumn {
    /// Alias names of the columns seen so far (duplicates collapse in the set).
    names_for_alias: HashSet<String>,
}

impl FindColumn {
    /// Creates a collector with an empty name set.
    fn new() -> Self {
        Self::default()
    }
}
183
184impl TreeNodeVisitor<'_> for FindColumn {
185    type Node = datafusion_expr::Expr;
186    fn f_down(
187        &mut self,
188        node: &datafusion_expr::Expr,
189    ) -> Result<TreeNodeRecursion, DataFusionError> {
190        if let datafusion_expr::Expr::Column(_) = node {
191            self.names_for_alias.insert(node.name_for_alias()?);
192        }
193        Ok(TreeNodeRecursion::Continue)
194    }
195}