query/dist_plan/analyzer.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashSet;
use std::sync::Arc;

use datafusion::datasource::DefaultTableSource;
use datafusion::error::Result as DfResult;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
use datafusion_common::Column;
use datafusion_expr::expr::{Exists, InSubquery};
use datafusion_expr::utils::expr_to_columns;
use datafusion_expr::{col as col_fn, Expr, LogicalPlan, LogicalPlanBuilder, Subquery};
use datafusion_optimizer::analyzer::AnalyzerRule;
use datafusion_optimizer::simplify_expressions::SimplifyExpressions;
use datafusion_optimizer::{OptimizerContext, OptimizerRule};
use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
use table::metadata::TableType;
use table::table::adapter::DfTableProviderAdapter;

use crate::dist_plan::commutativity::{
    partial_commutative_transformer, Categorizer, Commutativity,
};
use crate::dist_plan::merge_scan::MergeScanLogicalPlan;
use crate::plan::ExtractExpr;
use crate::query_engine::DefaultSerializer;

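/// [`AnalyzerRule`] that rewrites a logical plan for distributed execution.
///
/// It pushes the commutative prefix of the plan down to datanodes behind a
/// [`MergeScanLogicalPlan`] and keeps the rest of the plan on the frontend.
/// A hedged usage sketch, based on the tests at the bottom of this file:
///
/// ```ignore
/// let config = ConfigOptions::default();
/// let distributed_plan = DistPlannerAnalyzer {}.analyze(plan, &config)?;
/// // e.g. an aggregate over a table scan becomes roughly:
/// //   Projection: avg(t.number)
/// //     MergeScan [is_placeholder=false]
/// ```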
#[derive(Debug)]
pub struct DistPlannerAnalyzer;

impl AnalyzerRule for DistPlannerAnalyzer {
    fn name(&self) -> &str {
        "DistPlannerAnalyzer"
    }

    fn analyze(
        &self,
        plan: LogicalPlan,
        _config: &ConfigOptions,
    ) -> datafusion_common::Result<LogicalPlan> {
        // preprocess the input plan
        let optimizer_context = OptimizerContext::new();
        let plan = SimplifyExpressions::new()
            .rewrite(plan, &optimizer_context)?
            .data;

        let plan = plan.transform(&Self::inspect_plan_with_subquery)?;
        let mut rewriter = PlanRewriter::default();
        let result = plan.data.rewrite(&mut rewriter)?.data;

        Ok(result)
    }
}

impl DistPlannerAnalyzer {
    fn inspect_plan_with_subquery(plan: LogicalPlan) -> DfResult<Transformed<LogicalPlan>> {
        // Workaround for https://github.com/GreptimeTeam/greptimedb/issues/5469 and https://github.com/GreptimeTeam/greptimedb/issues/5799
        // FIXME(yingwen): Remove the `Limit` plan once we update DataFusion.
        if let LogicalPlan::Limit(_) | LogicalPlan::Distinct(_) = &plan {
            return Ok(Transformed::no(plan));
        }

        let exprs = plan
            .expressions_consider_join()
            .into_iter()
            .map(|e| e.transform(&Self::transform_subquery).map(|x| x.data))
            .collect::<DfResult<Vec<_>>>()?;

        // Some plans are specially treated; we should not call `with_new_exprs` on them.
        if !matches!(plan, LogicalPlan::Unnest(_)) {
            let inputs = plan.inputs().into_iter().cloned().collect::<Vec<_>>();
            Ok(Transformed::yes(plan.with_new_exprs(exprs, inputs)?))
        } else {
            Ok(Transformed::no(plan))
        }
    }

    fn transform_subquery(expr: Expr) -> DfResult<Transformed<Expr>> {
        match expr {
            Expr::Exists(exists) => Ok(Transformed::yes(Expr::Exists(Exists {
                subquery: Self::handle_subquery(exists.subquery)?,
                negated: exists.negated,
            }))),
            Expr::InSubquery(in_subquery) => Ok(Transformed::yes(Expr::InSubquery(InSubquery {
                expr: in_subquery.expr,
                subquery: Self::handle_subquery(in_subquery.subquery)?,
                negated: in_subquery.negated,
            }))),
            Expr::ScalarSubquery(scalar_subquery) => Ok(Transformed::yes(Expr::ScalarSubquery(
                Self::handle_subquery(scalar_subquery)?,
            ))),

            _ => Ok(Transformed::no(expr)),
        }
    }

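    /// Rewrite the subquery's plan with a fresh [`PlanRewriter`]. If the
    /// rewritten subquery's root is an `Extension` node, wrap it in a trivial
    /// projection, since DataFusion doesn't accept an extension node as the
    /// root of a subquery.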
    fn handle_subquery(subquery: Subquery) -> DfResult<Subquery> {
        let mut rewriter = PlanRewriter::default();
        let mut rewrote_subquery = subquery
            .subquery
            .as_ref()
            .clone()
            .rewrite(&mut rewriter)?
            .data;
        // Workaround: DataFusion doesn't support an Extension node as the first plan of a subquery.
        if matches!(rewrote_subquery, LogicalPlan::Extension(_)) {
            let output_schema = rewrote_subquery.schema().clone();
            let project_exprs = output_schema
                .fields()
                .iter()
                .map(|f| col_fn(f.name()))
                .collect::<Vec<_>>();
            rewrote_subquery = LogicalPlanBuilder::from(rewrote_subquery)
                .project(project_exprs)?
                .build()?;
        }

        Ok(Subquery {
            subquery: Arc::new(rewrote_subquery),
            outer_ref_columns: subquery.outer_ref_columns,
        })
    }
}

/// Status of the rewriter, marking whether the current pass has been expanded.
#[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
enum RewriterStatus {
    #[default]
    Unexpanded,
    Expanded,
}

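/// Bottom-up rewriter that turns the commutative part of a plan into a
/// [`MergeScanLogicalPlan`]. While ascending, commutative parent nodes are
/// staged for re-application; the first non-commutative parent triggers
/// `expand`, which inserts the merge scan and replays the staged nodes on top.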
#[derive(Debug, Default)]
struct PlanRewriter {
    /// Current level in the tree
    level: usize,
    /// Simulated stack for the `rewrite` recursion
    stack: Vec<(LogicalPlan, usize)>,
    /// Stages to be expanded
    stage: Vec<LogicalPlan>,
    status: RewriterStatus,
    /// Partition columns of the table in the current pass
    partition_cols: Option<Vec<String>>,
    column_requirements: HashSet<Column>,
}

impl PlanRewriter {
    fn get_parent(&self) -> Option<&LogicalPlan> {
        // level starts from 1, so it's safe to subtract 1
        self.stack
            .iter()
            .rev()
            .find(|(_, level)| *level == self.level - 1)
            .map(|(node, _)| node)
    }

    /// Returns true if we should stop descending and expand here. The input plan is the parent node of the current node.
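    ///
    /// Expansion is forced when the parent can't be encoded to substrait, or
    /// when [`Categorizer`] deems it non-commutative, unimplemented, or
    /// unsupported. Commutative parents pass through; partially or
    /// conditionally commutative parents push a transformed copy onto
    /// `self.stage` to be replayed above the merge scan.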
    fn should_expand(&mut self, plan: &LogicalPlan) -> bool {
        if DFLogicalSubstraitConvertor
            .encode(plan, DefaultSerializer)
            .is_err()
        {
            return true;
        }
        match Categorizer::check_plan(plan, self.partition_cols.clone()) {
            Commutativity::Commutative => {}
            Commutativity::PartialCommutative => {
                if let Some(plan) = partial_commutative_transformer(plan) {
                    self.update_column_requirements(&plan);
                    self.stage.push(plan)
                }
            }
            Commutativity::ConditionalCommutative(transformer) => {
                if let Some(transformer) = transformer
                    && let Some(plan) = transformer(plan)
                {
                    self.update_column_requirements(&plan);
                    self.stage.push(plan)
                }
            }
            Commutativity::TransformedCommutative(transformer) => {
                if let Some(transformer) = transformer
                    && let Some(plan) = transformer(plan)
                {
                    self.update_column_requirements(&plan);
                    self.stage.push(plan)
                }
            }
            Commutativity::NonCommutative
            | Commutativity::Unimplemented
            | Commutativity::Unsupported => {
                return true;
            }
        }

        false
    }

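    /// Collect every column referenced by `plan`'s expressions into
    /// `column_requirements`, so that `EnforceDistRequirementRewriter` can
    /// later make sure those columns survive intermediate projections.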
    fn update_column_requirements(&mut self, plan: &LogicalPlan) {
        let mut container = HashSet::new();
        for expr in plan.expressions() {
            // `expr_to_columns` won't fail here
            let _ = expr_to_columns(&expr, &mut container);
        }

        for col in container {
            self.column_requirements.insert(col);
        }
    }

    fn is_expanded(&self) -> bool {
        self.status == RewriterStatus::Expanded
    }

    fn set_expanded(&mut self) {
        self.status = RewriterStatus::Expanded;
    }

    fn set_unexpanded(&mut self) {
        self.status = RewriterStatus::Unexpanded;
    }

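    /// Record the partition columns of the first base-table scan encountered
    /// in this pass; subsequent scans are ignored.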
    fn maybe_set_partitions(&mut self, plan: &LogicalPlan) {
        if self.partition_cols.is_some() {
            // only need to set once
            return;
        }

        if let LogicalPlan::TableScan(table_scan) = plan {
            if let Some(source) = table_scan
                .source
                .as_any()
                .downcast_ref::<DefaultTableSource>()
            {
                if let Some(provider) = source
                    .table_provider
                    .as_any()
                    .downcast_ref::<DfTableProviderAdapter>()
                {
                    if provider.table().table_type() == TableType::Base {
                        let info = provider.table().table_info();
                        let partition_key_indices = info.meta.partition_key_indices.clone();
                        let schema = info.meta.schema.clone();
                        let partition_cols = partition_key_indices
                            .into_iter()
                            .map(|index| schema.column_name_by_index(index).to_string())
                            .collect::<Vec<String>>();
                        self.partition_cols = Some(partition_cols);
                    }
                }
            }
        }
    }

    /// Pop one stack item and decrease the level by 1.
    fn pop_stack(&mut self) {
        self.level -= 1;
        self.stack.pop();
    }

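    /// Expand the subtree rooted at `on_node` into a distributed plan:
    /// enforce column requirements, wrap the node in a
    /// [`MergeScanLogicalPlan`], replay the staged nodes on top of the merge
    /// scan, and finally re-project to recover the original output schema.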
    fn expand(&mut self, mut on_node: LogicalPlan) -> DfResult<LogicalPlan> {
        // store the schema before expanding
        let schema = on_node.schema().clone();
        let mut rewriter = EnforceDistRequirementRewriter {
            column_requirements: std::mem::take(&mut self.column_requirements),
        };
        on_node = on_node.rewrite(&mut rewriter)?.data;

        // add merge scan as the new root
        let mut node = MergeScanLogicalPlan::new(
            on_node,
            false,
            // at this stage the partition cols should already be set;
            // treat the table as non-partitioned if None
            self.partition_cols.clone().unwrap_or_default(),
        )
        .into_logical_plan();

        // expand stages: replay each staged node on top of the merge scan
        for new_stage in self.stage.drain(..) {
            node = new_stage
                .with_new_exprs(new_stage.expressions_consider_join(), vec![node.clone()])?;
        }
        self.set_expanded();

        // recover the schema
        let node = LogicalPlanBuilder::from(node)
            .project(schema.iter().map(|(qualifier, field)| {
                Expr::Column(Column::new(qualifier.cloned(), field.name()))
            }))?
            .build()?;

        Ok(node)
    }
}

/// Implementation of the [`TreeNodeRewriter`] trait, responsible for rewriting
/// logical plans to enforce various requirements for distributed queries.
///
/// Requirements enforced by this rewriter:
/// - Enforce column requirements for `LogicalPlan::Projection` nodes. Makes sure the
///   required columns are available in the sub plan.
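///
/// A hedged sketch of the effect: if `column_requirements` contains `t.a` but
/// a projection only emits `t.b`, the projection is rewritten to emit
/// `t.b, t.a`, so that `t.a` stays available to the plan above.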
struct EnforceDistRequirementRewriter {
    column_requirements: HashSet<Column>,
}

impl TreeNodeRewriter for EnforceDistRequirementRewriter {
    type Node = LogicalPlan;

    fn f_down(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        if let LogicalPlan::Projection(ref projection) = node {
            let mut column_requirements = std::mem::take(&mut self.column_requirements);
            if column_requirements.is_empty() {
                return Ok(Transformed::no(node));
            }

            for expr in &projection.expr {
                let (qualifier, name) = expr.qualified_name();
                let column = Column::new(qualifier, name);
                column_requirements.remove(&column);
            }
            if column_requirements.is_empty() {
                return Ok(Transformed::no(node));
            }

            let mut new_exprs = projection.expr.clone();
            for col in &column_requirements {
                new_exprs.push(Expr::Column(col.clone()));
            }
            let new_node =
                node.with_new_exprs(new_exprs, node.inputs().into_iter().cloned().collect())?;
            return Ok(Transformed::yes(new_node));
        }

        Ok(Transformed::no(node))
    }

    fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        Ok(Transformed::no(node))
    }
}

impl TreeNodeRewriter for PlanRewriter {
    type Node = LogicalPlan;

    /// descend
    fn f_down(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        self.level += 1;
        self.stack.push((node.clone(), self.level));
        // descending clears the stage
        self.stage.clear();
        self.set_unexpanded();
        self.partition_cols = None;
        Ok(Transformed::no(node))
    }

    /// ascend
    ///
    /// Be sure to call `pop_stack` before returning.
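    ///
    /// Expansion happens at most once per ascent: either at the root (when
    /// there is no parent) or at the first parent that `should_expand`
    /// refuses to push down.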
    fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        // only expand once on each ascent
        if self.is_expanded() {
            self.pop_stack();
            return Ok(Transformed::no(node));
        }

        // only expand when the leaf is a table scan
        if node.inputs().is_empty() && !matches!(node, LogicalPlan::TableScan(_)) {
            self.set_expanded();
            self.pop_stack();
            return Ok(Transformed::no(node));
        }

        self.maybe_set_partitions(&node);

        let Some(parent) = self.get_parent() else {
            let node = self.expand(node)?;
            self.pop_stack();
            return Ok(Transformed::yes(node));
        };

        // TODO(ruihang): avoid this clone
        if self.should_expand(&parent.clone()) {
            // TODO(ruihang): does this work for nodes with multiple children?
            let node = self.expand(node)?;
            self.pop_stack();
            return Ok(Transformed::yes(node));
        }

        self.pop_stack();
        Ok(Transformed::no(node))
    }
}

#[cfg(test)]
mod test {
    use std::sync::Arc;

    use datafusion::datasource::DefaultTableSource;
    use datafusion::functions_aggregate::expr_fn::avg;
    use datafusion_common::JoinType;
    use datafusion_expr::{col, lit, Expr, LogicalPlanBuilder};
    use table::table::adapter::DfTableProviderAdapter;
    use table::table::numbers::NumbersTable;

    use super::*;

    #[ignore = "Projection is disabled for https://github.com/apache/arrow-datafusion/issues/6489"]
    #[test]
    fn transform_simple_projection_filter() {
        let numbers_table = NumbersTable::table(0);
        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(numbers_table),
        )));

        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
            .unwrap()
            .filter(col("number").lt(lit(10)))
            .unwrap()
            .project(vec![col("number")])
            .unwrap()
            .distinct()
            .unwrap()
            .build()
            .unwrap();

        let config = ConfigOptions::default();
        let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
        let expected = [
            "Distinct:",
            "  MergeScan [is_placeholder=false]",
            "    Distinct:",
            "      Projection: t.number",
            "        Filter: t.number < Int32(10)",
            "          TableScan: t",
        ]
        .join("\n");
        assert_eq!(expected, result.to_string());
    }

    #[test]
    fn transform_aggregator() {
        let numbers_table = NumbersTable::table(0);
        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(numbers_table),
        )));

        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
            .unwrap()
            .aggregate(Vec::<Expr>::new(), vec![avg(col("number"))])
            .unwrap()
            .build()
            .unwrap();

        let config = ConfigOptions::default();
        let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
        let expected = "Projection: avg(t.number)\
        \n  MergeScan [is_placeholder=false]";
        assert_eq!(expected, result.to_string());
    }

    #[test]
    fn transform_distinct_order() {
        let numbers_table = NumbersTable::table(0);
        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(numbers_table),
        )));

        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
            .unwrap()
            .distinct()
            .unwrap()
            .sort(vec![col("number").sort(true, false)])
            .unwrap()
            .build()
            .unwrap();

        let config = ConfigOptions::default();
        let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
        let expected = ["Projection: t.number", "  MergeScan [is_placeholder=false]"].join("\n");
        assert_eq!(expected, result.to_string());
    }

    #[test]
    fn transform_single_limit() {
        let numbers_table = NumbersTable::table(0);
        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(numbers_table),
        )));

        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
            .unwrap()
            .limit(0, Some(1))
            .unwrap()
            .build()
            .unwrap();

        let config = ConfigOptions::default();
        let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
        let expected = "Projection: t.number\
        \n  MergeScan [is_placeholder=false]";
        assert_eq!(expected, result.to_string());
    }

    #[test]
    fn transform_unaligned_join_with_alias() {
        let left = NumbersTable::table(0);
        let right = NumbersTable::table(1);
        let left_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(left),
        )));
        let right_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(right),
        )));

        let right_plan = LogicalPlanBuilder::scan_with_filters("t", right_source, None, vec![])
            .unwrap()
            .alias("right")
            .unwrap()
            .build()
            .unwrap();

        let plan = LogicalPlanBuilder::scan_with_filters("t", left_source, None, vec![])
            .unwrap()
            .join_on(
                right_plan,
                JoinType::LeftSemi,
                vec![col("t.number").eq(col("right.number"))],
            )
            .unwrap()
            .limit(0, Some(1))
            .unwrap()
            .build()
            .unwrap();

        let config = ConfigOptions::default();
        let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
        let expected = [
            "Limit: skip=0, fetch=1",
            "  LeftSemi Join:  Filter: t.number = right.number",
            "    Projection: t.number",
            "      MergeScan [is_placeholder=false]",
            "    SubqueryAlias: right",
            "      Projection: t.number",
            "        MergeScan [is_placeholder=false]",
        ]
        .join("\n");
        assert_eq!(expected, result.to_string());
    }
}