query/dist_plan/analyzer.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashSet;
use std::sync::Arc;

use common_telemetry::debug;
use datafusion::datasource::DefaultTableSource;
use datafusion::error::Result as DfResult;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
use datafusion_common::Column;
use datafusion_expr::expr::{Exists, InSubquery};
use datafusion_expr::utils::expr_to_columns;
use datafusion_expr::{col as col_fn, Expr, LogicalPlan, LogicalPlanBuilder, Subquery};
use datafusion_optimizer::analyzer::AnalyzerRule;
use datafusion_optimizer::simplify_expressions::SimplifyExpressions;
use datafusion_optimizer::{OptimizerContext, OptimizerRule};
use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
use table::metadata::TableType;
use table::table::adapter::DfTableProviderAdapter;

use crate::dist_plan::commutativity::{
    partial_commutative_transformer, Categorizer, Commutativity,
};
use crate::dist_plan::merge_scan::MergeScanLogicalPlan;
use crate::plan::ExtractExpr;
use crate::query_engine::DefaultSerializer;
#[derive(Debug)]
pub struct DistPlannerAnalyzer;

impl AnalyzerRule for DistPlannerAnalyzer {
    fn name(&self) -> &str {
        "DistPlannerAnalyzer"
    }

    fn analyze(
        &self,
        plan: LogicalPlan,
        _config: &ConfigOptions,
    ) -> datafusion_common::Result<LogicalPlan> {
        // preprocess the input plan
        let optimizer_context = OptimizerContext::new();
        let plan = SimplifyExpressions::new()
            .rewrite(plan, &optimizer_context)?
            .data;

        let plan = plan.transform(&Self::inspect_plan_with_subquery)?;
        let mut rewriter = PlanRewriter::default();
        let result = plan.data.rewrite(&mut rewriter)?.data;

        Ok(result)
    }
}

impl DistPlannerAnalyzer {
    fn inspect_plan_with_subquery(plan: LogicalPlan) -> DfResult<Transformed<LogicalPlan>> {
        // Workaround for https://github.com/GreptimeTeam/greptimedb/issues/5469 and https://github.com/GreptimeTeam/greptimedb/issues/5799
        // FIXME(yingwen): Remove the `Limit` plan once we update DataFusion.
        if let LogicalPlan::Limit(_) | LogicalPlan::Distinct(_) = &plan {
            return Ok(Transformed::no(plan));
        }

        let exprs = plan
            .expressions_consider_join()
            .into_iter()
            .map(|e| e.transform(&Self::transform_subquery).map(|x| x.data))
            .collect::<DfResult<Vec<_>>>()?;

        // Some plans are specially treated; we should not call `with_new_exprs` on them.
        if !matches!(plan, LogicalPlan::Unnest(_)) {
            let inputs = plan.inputs().into_iter().cloned().collect::<Vec<_>>();
            Ok(Transformed::yes(plan.with_new_exprs(exprs, inputs)?))
        } else {
            Ok(Transformed::no(plan))
        }
    }

    fn transform_subquery(expr: Expr) -> DfResult<Transformed<Expr>> {
        match expr {
            Expr::Exists(exists) => Ok(Transformed::yes(Expr::Exists(Exists {
                subquery: Self::handle_subquery(exists.subquery)?,
                negated: exists.negated,
            }))),
            Expr::InSubquery(in_subquery) => Ok(Transformed::yes(Expr::InSubquery(InSubquery {
                expr: in_subquery.expr,
                subquery: Self::handle_subquery(in_subquery.subquery)?,
                negated: in_subquery.negated,
            }))),
            Expr::ScalarSubquery(scalar_subquery) => Ok(Transformed::yes(Expr::ScalarSubquery(
                Self::handle_subquery(scalar_subquery)?,
            ))),

            _ => Ok(Transformed::no(expr)),
        }
    }

    fn handle_subquery(subquery: Subquery) -> DfResult<Subquery> {
        let mut rewriter = PlanRewriter::default();
        let mut rewrote_subquery = subquery
            .subquery
            .as_ref()
            .clone()
            .rewrite(&mut rewriter)?
            .data;
        // Workaround: DataFusion doesn't support an Extension node as the first plan
        // in a subquery, so wrap the rewritten subquery in a trivial projection.
        if matches!(rewrote_subquery, LogicalPlan::Extension(_)) {
            let output_schema = rewrote_subquery.schema().clone();
            let project_exprs = output_schema
                .fields()
                .iter()
                .map(|f| col_fn(f.name()))
                .collect::<Vec<_>>();
            rewrote_subquery = LogicalPlanBuilder::from(rewrote_subquery)
                .project(project_exprs)?
                .build()?;
        }

        Ok(Subquery {
            subquery: Arc::new(rewrote_subquery),
            outer_ref_columns: subquery.outer_ref_columns,
        })
    }
}

/// Status of the rewriter, marking whether the current pass has been expanded
#[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
enum RewriterStatus {
    #[default]
    Unexpanded,
    Expanded,
}

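/// Bottom-up rewriter that decides where to split the plan: while ascending, it
/// keeps pushing nodes toward the datanodes as long as [`Categorizer`] deems
/// them commutative with the merge, and expands into a [`MergeScanLogicalPlan`]
/// (re-applying any staged plans on top) once it hits a node that must stay on
/// the frontend.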
#[derive(Debug, Default)]
struct PlanRewriter {
    /// Current level in the tree
    level: usize,
    /// Simulated stack for the `rewrite` recursion
    stack: Vec<(LogicalPlan, usize)>,
    /// Stages to be expanded, will be added as parent nodes of the merge scan one by one
    stage: Vec<LogicalPlan>,
    status: RewriterStatus,
    /// Partition columns of the table in the current pass
    partition_cols: Option<Vec<String>>,
    column_requirements: HashSet<Column>,
    /// Whether to expand on the next call.
    /// This is used to handle the case where a plan is transformed but needs to be expanded from its
    /// parent node. For example, an Aggregate plan is split into two parts, one for the frontend and
    /// one for the datanode, and needs to be expanded from the parent node of the Aggregate plan.
    expand_on_next_call: bool,
    /// Expand on the next partially/conditionally/transformed commutative plan.
    /// This is used to handle the case where a plan is transformed, but we still
    /// need to push down as many nodes as possible before the next partially/conditionally/transformed
    /// commutative plan. I.e.:
    /// ```
    /// Limit:
    ///     Sort:
    /// ```
    /// where `Limit` is partially commutative and `Sort` is conditionally commutative.
    /// In this case, we need to expand the `Limit` plan
    /// so that we can push down the `Sort` plan as much as possible.
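    ///
    /// Illustrative sketch (an assumption based on how [`PlanRewriter::expand`]
    /// re-applies staged plans; the exact shape depends on the commutativity
    /// transformers): the plan above expands roughly into
    /// ```text
    /// Projection:            <- recovers the original schema
    ///     Limit:             <- staged copy re-applied on the frontend
    ///         Sort:          <- staged copy re-applied on the frontend
    ///             MergeScan  <- executes the original plan on datanodes
    /// ```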
    expand_on_next_part_cond_trans_commutative: bool,
    new_child_plan: Option<LogicalPlan>,
}

impl PlanRewriter {
    fn get_parent(&self) -> Option<&LogicalPlan> {
        // level starts from 1, so it's safe to subtract 1
        self.stack
            .iter()
            .rev()
            .find(|(_, level)| *level == self.level - 1)
            .map(|(node, _)| node)
    }

    /// Returns true if we should stop and expand. The input plan is the parent node of the current node.
    fn should_expand(&mut self, plan: &LogicalPlan) -> bool {
        if DFLogicalSubstraitConvertor
            .encode(plan, DefaultSerializer)
            .is_err()
        {
            return true;
        }

        if self.expand_on_next_call {
            self.expand_on_next_call = false;
            return true;
        }

        if self.expand_on_next_part_cond_trans_commutative {
            let comm = Categorizer::check_plan(plan, self.partition_cols.clone());
            match comm {
                Commutativity::PartialCommutative => {
                    // a small difference is that for partial commutative, we still need to
                    // expand on the next call (so `Limit` can be pushed down)
                    self.expand_on_next_part_cond_trans_commutative = false;
                    self.expand_on_next_call = true;
                }
                Commutativity::ConditionalCommutative(_)
                | Commutativity::TransformedCommutative { .. } => {
                    // for conditional commutative and transformed commutative, we can
                    // expand now
                    self.expand_on_next_part_cond_trans_commutative = false;
                    return true;
                }
                _ => (),
            }
        }

        match Categorizer::check_plan(plan, self.partition_cols.clone()) {
            Commutativity::Commutative => {}
            Commutativity::PartialCommutative => {
                if let Some(plan) = partial_commutative_transformer(plan) {
                    self.update_column_requirements(&plan);
                    self.expand_on_next_part_cond_trans_commutative = true;
                    self.stage.push(plan)
                }
            }
            Commutativity::ConditionalCommutative(transformer) => {
                if let Some(transformer) = transformer
                    && let Some(plan) = transformer(plan)
                {
                    self.update_column_requirements(&plan);
                    self.expand_on_next_part_cond_trans_commutative = true;
                    self.stage.push(plan)
                }
            }
            Commutativity::TransformedCommutative { transformer } => {
                if let Some(transformer) = transformer
                    && let Some(transformer_actions) = transformer(plan)
                {
                    debug!(
                        "PlanRewriter: transformed plan: {:?}\n from {plan}",
                        transformer_actions.extra_parent_plans
                    );
                    if let Some(last_stage) = transformer_actions.extra_parent_plans.last() {
                        // update the column requirements from the last stage
                        self.update_column_requirements(last_stage);
                    }
                    self.stage
                        .extend(transformer_actions.extra_parent_plans.into_iter().rev());
                    self.expand_on_next_call = true;
                    self.new_child_plan = transformer_actions.new_child_plan;
                }
            }
            Commutativity::NonCommutative
            | Commutativity::Unimplemented
            | Commutativity::Unsupported => {
                return true;
            }
        }

        false
    }

    fn update_column_requirements(&mut self, plan: &LogicalPlan) {
        debug!(
            "PlanRewriter: update column requirements for plan: {plan}\n with column_requirements: {:?}",
            self.column_requirements
        );
        let mut container = HashSet::new();
        for expr in plan.expressions() {
            // this method won't fail
            let _ = expr_to_columns(&expr, &mut container);
        }

        for col in container {
            self.column_requirements.insert(col);
        }
        debug!(
            "PlanRewriter: updated column requirements: {:?}",
            self.column_requirements
        );
    }

    fn is_expanded(&self) -> bool {
        self.status == RewriterStatus::Expanded
    }

    fn set_expanded(&mut self) {
        self.status = RewriterStatus::Expanded;
    }

    fn set_unexpanded(&mut self) {
        self.status = RewriterStatus::Unexpanded;
    }

    fn maybe_set_partitions(&mut self, plan: &LogicalPlan) {
        if self.partition_cols.is_some() {
            // only need to set once
            return;
        }

        if let LogicalPlan::TableScan(table_scan) = plan {
            if let Some(source) = table_scan
                .source
                .as_any()
                .downcast_ref::<DefaultTableSource>()
            {
                if let Some(provider) = source
                    .table_provider
                    .as_any()
                    .downcast_ref::<DfTableProviderAdapter>()
                {
                    if provider.table().table_type() == TableType::Base {
                        let info = provider.table().table_info();
                        let partition_key_indices = info.meta.partition_key_indices.clone();
                        let schema = info.meta.schema.clone();
                        let partition_cols = partition_key_indices
                            .into_iter()
                            .map(|index| schema.column_name_by_index(index).to_string())
                            .collect::<Vec<String>>();
                        self.partition_cols = Some(partition_cols);
                    }
                }
            }
        }
    }

    /// pop one stack item and reduce the level by 1
    fn pop_stack(&mut self) {
        self.level -= 1;
        self.stack.pop();
    }

    fn expand(&mut self, mut on_node: LogicalPlan) -> DfResult<LogicalPlan> {
        if let Some(new_child_plan) = self.new_child_plan.take() {
            // if there is a new child plan, use it as the new root
            on_node = new_child_plan;
        }
        // store schema before expand
        let schema = on_node.schema().clone();
        let mut rewriter = EnforceDistRequirementRewriter {
            column_requirements: std::mem::take(&mut self.column_requirements),
        };
        on_node = on_node.rewrite(&mut rewriter)?.data;

        // add merge scan as the new root
        let mut node = MergeScanLogicalPlan::new(
            on_node,
            false,
            // at this stage, the partition cols should be set
            // treat it as non-partitioned if None
            self.partition_cols.clone().unwrap_or_default(),
        )
        .into_logical_plan();

        // expand stages
        for new_stage in self.stage.drain(..) {
            node = new_stage
                .with_new_exprs(new_stage.expressions_consider_join(), vec![node.clone()])?;
        }
        self.set_expanded();

        // recover the schema
        let node = LogicalPlanBuilder::from(node)
            .project(schema.iter().map(|(qualifier, field)| {
                Expr::Column(Column::new(qualifier.cloned(), field.name()))
            }))?
            .build()?;

        Ok(node)
    }
}

/// Implementation of the [`TreeNodeRewriter`] trait, responsible for rewriting
/// logical plans to enforce various requirements for distributed queries.
///
/// Requirements enforced by this rewriter:
/// - Enforce column requirements for `LogicalPlan::Projection` nodes. Makes sure the
///   required columns are available in the sub plan.
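///
/// Illustrative example (hypothetical column names): if a staged plan requires
/// column `t.b` but a projection in the sub plan only outputs `t.a`, `f_down`
/// extends that projection:
/// ```text
/// Projection: t.a   =>   Projection: t.a, t.b
/// ```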
struct EnforceDistRequirementRewriter {
    column_requirements: HashSet<Column>,
}

impl TreeNodeRewriter for EnforceDistRequirementRewriter {
    type Node = LogicalPlan;

    fn f_down(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        if let LogicalPlan::Projection(ref projection) = node {
            let mut column_requirements = std::mem::take(&mut self.column_requirements);
            if column_requirements.is_empty() {
                return Ok(Transformed::no(node));
            }

            for expr in &projection.expr {
                let (qualifier, name) = expr.qualified_name();
                let column = Column::new(qualifier, name);
                column_requirements.remove(&column);
            }
            if column_requirements.is_empty() {
                return Ok(Transformed::no(node));
            }

            let mut new_exprs = projection.expr.clone();
            for col in &column_requirements {
                new_exprs.push(Expr::Column(col.clone()));
            }
            let new_node =
                node.with_new_exprs(new_exprs, node.inputs().into_iter().cloned().collect())?;
            return Ok(Transformed::yes(new_node));
        }

        Ok(Transformed::no(node))
    }

    fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        Ok(Transformed::no(node))
    }
}

impl TreeNodeRewriter for PlanRewriter {
    type Node = LogicalPlan;

    /// descend
    fn f_down(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        self.level += 1;
        self.stack.push((node.clone(), self.level));
        // descending will clear the stage
        self.stage.clear();
        self.set_unexpanded();
        self.partition_cols = None;
        Ok(Transformed::no(node))
    }

    /// ascend
    ///
    /// Be sure to call `pop_stack` before returning
    fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        // only expand once on each ascent
        if self.is_expanded() {
            self.pop_stack();
            return Ok(Transformed::no(node));
        }

        // only expand when the leaf is a table scan
        if node.inputs().is_empty() && !matches!(node, LogicalPlan::TableScan(_)) {
            self.set_expanded();
            self.pop_stack();
            return Ok(Transformed::no(node));
        }

        self.maybe_set_partitions(&node);

        let Some(parent) = self.get_parent() else {
            let node = self.expand(node)?;
            self.pop_stack();
            return Ok(Transformed::yes(node));
        };

        let parent = parent.clone();

        // TODO(ruihang): avoid this clone
        if self.should_expand(&parent) {
            // TODO(ruihang): does this work for nodes with multiple children?
            debug!("PlanRewriter: should expand child:\n {node}\n Of Parent: {parent}");
            let node = self.expand(node);
            debug!(
                "PlanRewriter: expanded plan: {}",
                match &node {
                    Ok(n) => n.to_string(),
                    Err(e) => format!("Error expanding plan: {e}"),
                }
            );
            let node = node?;
            self.pop_stack();
            return Ok(Transformed::yes(node));
        }

        self.pop_stack();
        Ok(Transformed::no(node))
    }
}

#[cfg(test)]
mod test {
    use std::sync::Arc;

    use datafusion::datasource::DefaultTableSource;
    use datafusion::functions_aggregate::expr_fn::avg;
    use datafusion_common::JoinType;
    use datafusion_expr::{col, lit, Expr, LogicalPlanBuilder};
    use table::table::adapter::DfTableProviderAdapter;
    use table::table::numbers::NumbersTable;

    use super::*;

    #[ignore = "Projection is disabled for https://github.com/apache/arrow-datafusion/issues/6489"]
    #[test]
    fn transform_simple_projection_filter() {
        let numbers_table = NumbersTable::table(0);
        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(numbers_table),
        )));

        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
            .unwrap()
            .filter(col("number").lt(lit(10)))
            .unwrap()
            .project(vec![col("number")])
            .unwrap()
            .distinct()
            .unwrap()
            .build()
            .unwrap();

        let config = ConfigOptions::default();
        let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
        let expected = [
            "Distinct:",
            "  MergeScan [is_placeholder=false]",
            "    Distinct:",
            "      Projection: t.number",
            "        Filter: t.number < Int32(10)",
            "          TableScan: t",
        ]
        .join("\n");
        assert_eq!(expected, result.to_string());
    }

    #[test]
    fn transform_aggregator() {
        let numbers_table = NumbersTable::table(0);
        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(numbers_table),
        )));

        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
            .unwrap()
            .aggregate(Vec::<Expr>::new(), vec![avg(col("number"))])
            .unwrap()
            .build()
            .unwrap();

        let config = ConfigOptions::default();
        let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
        let expected = "Projection: avg(t.number)\
        \n  MergeScan [is_placeholder=false]";
        assert_eq!(expected, result.to_string());
    }

    #[test]
    fn transform_distinct_order() {
        let numbers_table = NumbersTable::table(0);
        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(numbers_table),
        )));

        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
            .unwrap()
            .distinct()
            .unwrap()
            .sort(vec![col("number").sort(true, false)])
            .unwrap()
            .build()
            .unwrap();

        let config = ConfigOptions::default();
        let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
        let expected = ["Projection: t.number", "  MergeScan [is_placeholder=false]"].join("\n");
        assert_eq!(expected, result.to_string());
    }

    #[test]
    fn transform_single_limit() {
        let numbers_table = NumbersTable::table(0);
        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(numbers_table),
        )));

        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
            .unwrap()
            .limit(0, Some(1))
            .unwrap()
            .build()
            .unwrap();

        let config = ConfigOptions::default();
        let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
        let expected = "Projection: t.number\
        \n  MergeScan [is_placeholder=false]";
        assert_eq!(expected, result.to_string());
    }
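
    // A hedged sketch (not part of the original test suite): exercises the
    // `expand_on_next_part_cond_trans_commutative` path documented on
    // `PlanRewriter`, where a partially commutative `Limit` sits above a
    // conditionally commutative `Sort`. The exact expanded plan text depends
    // on the commutativity transformers, so instead of pinning the full plan
    // we only assert that a MergeScan is produced.
    #[test]
    fn transform_sort_limit() {
        let numbers_table = NumbersTable::table(0);
        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(numbers_table),
        )));

        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
            .unwrap()
            .sort(vec![col("number").sort(true, false)])
            .unwrap()
            .limit(0, Some(1))
            .unwrap()
            .build()
            .unwrap();

        let config = ConfigOptions::default();
        let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
        assert!(result
            .to_string()
            .contains("MergeScan [is_placeholder=false]"));
    }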

    #[test]
    fn transform_unaligned_join_with_alias() {
        let left = NumbersTable::table(0);
        let right = NumbersTable::table(1);
        let left_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(left),
        )));
        let right_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(right),
        )));

        let right_plan = LogicalPlanBuilder::scan_with_filters("t", right_source, None, vec![])
            .unwrap()
            .alias("right")
            .unwrap()
            .build()
            .unwrap();

        let plan = LogicalPlanBuilder::scan_with_filters("t", left_source, None, vec![])
            .unwrap()
            .join_on(
                right_plan,
                JoinType::LeftSemi,
                vec![col("t.number").eq(col("right.number"))],
            )
            .unwrap()
            .limit(0, Some(1))
            .unwrap()
            .build()
            .unwrap();

        let config = ConfigOptions::default();
        let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
        let expected = [
            "Limit: skip=0, fetch=1",
            "  LeftSemi Join:  Filter: t.number = right.number",
            "    Projection: t.number",
            "      MergeScan [is_placeholder=false]",
            "    SubqueryAlias: right",
            "      Projection: t.number",
            "        MergeScan [is_placeholder=false]",
        ]
        .join("\n");
        assert_eq!(expected, result.to_string());
    }
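
    // A hedged sketch (not part of the original test suite): runs the analyzer
    // over a filter containing an EXISTS subquery to exercise
    // `DistPlannerAnalyzer::handle_subquery`. The printed plan for subqueries
    // varies across DataFusion versions, so the test only checks that the
    // analysis succeeds. `datafusion_expr::exists` builds the EXISTS expression.
    #[test]
    fn transform_exists_subquery() {
        let outer = NumbersTable::table(0);
        let inner = NumbersTable::table(1);
        let outer_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(outer),
        )));
        let inner_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(inner),
        )));

        let subquery = LogicalPlanBuilder::scan_with_filters("inner_t", inner_source, None, vec![])
            .unwrap()
            .build()
            .unwrap();

        let plan = LogicalPlanBuilder::scan_with_filters("t", outer_source, None, vec![])
            .unwrap()
            .filter(datafusion_expr::exists(Arc::new(subquery)))
            .unwrap()
            .build()
            .unwrap();

        let config = ConfigOptions::default();
        let result = DistPlannerAnalyzer {}.analyze(plan, &config);
        assert!(result.is_ok());
    }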
}