query/dist_plan/analyzer.rs
1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashSet;
16use std::sync::Arc;
17
18use common_telemetry::debug;
19use datafusion::datasource::DefaultTableSource;
20use datafusion::error::Result as DfResult;
21use datafusion_common::config::ConfigOptions;
22use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
23use datafusion_common::Column;
24use datafusion_expr::expr::{Exists, InSubquery};
25use datafusion_expr::utils::expr_to_columns;
26use datafusion_expr::{col as col_fn, Expr, LogicalPlan, LogicalPlanBuilder, Subquery};
27use datafusion_optimizer::analyzer::AnalyzerRule;
28use datafusion_optimizer::simplify_expressions::SimplifyExpressions;
29use datafusion_optimizer::{OptimizerContext, OptimizerRule};
30use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
31use table::metadata::TableType;
32use table::table::adapter::DfTableProviderAdapter;
33
34use crate::dist_plan::commutativity::{
35    partial_commutative_transformer, Categorizer, Commutativity,
36};
37use crate::dist_plan::merge_scan::MergeScanLogicalPlan;
38use crate::plan::ExtractExpr;
39use crate::query_engine::DefaultSerializer;
40
/// Analyzer rule that rewrites a logical plan for distributed execution:
/// it walks the plan, decides how far operators can be pushed toward the
/// storage nodes (see [`PlanRewriter`]) and inserts a merge-scan boundary
/// at the point where push-down must stop.
#[derive(Debug)]
pub struct DistPlannerAnalyzer;
43
44impl AnalyzerRule for DistPlannerAnalyzer {
45    fn name(&self) -> &str {
46        "DistPlannerAnalyzer"
47    }
48
49    fn analyze(
50        &self,
51        plan: LogicalPlan,
52        _config: &ConfigOptions,
53    ) -> datafusion_common::Result<LogicalPlan> {
54        // preprocess the input plan
55        let optimizer_context = OptimizerContext::new();
56        let plan = SimplifyExpressions::new()
57            .rewrite(plan, &optimizer_context)?
58            .data;
59
60        let plan = plan.transform(&Self::inspect_plan_with_subquery)?;
61        let mut rewriter = PlanRewriter::default();
62        let result = plan.data.rewrite(&mut rewriter)?.data;
63
64        Ok(result)
65    }
66}
67
68impl DistPlannerAnalyzer {
69    fn inspect_plan_with_subquery(plan: LogicalPlan) -> DfResult<Transformed<LogicalPlan>> {
70        // Workaround for https://github.com/GreptimeTeam/greptimedb/issues/5469 and https://github.com/GreptimeTeam/greptimedb/issues/5799
71        // FIXME(yingwen): Remove the `Limit` plan once we update DataFusion.
72        if let LogicalPlan::Limit(_) | LogicalPlan::Distinct(_) = &plan {
73            return Ok(Transformed::no(plan));
74        }
75
76        let exprs = plan
77            .expressions_consider_join()
78            .into_iter()
79            .map(|e| e.transform(&Self::transform_subquery).map(|x| x.data))
80            .collect::<DfResult<Vec<_>>>()?;
81
82        // Some plans that are special treated (should not call `with_new_exprs` on them)
83        if !matches!(plan, LogicalPlan::Unnest(_)) {
84            let inputs = plan.inputs().into_iter().cloned().collect::<Vec<_>>();
85            Ok(Transformed::yes(plan.with_new_exprs(exprs, inputs)?))
86        } else {
87            Ok(Transformed::no(plan))
88        }
89    }
90
91    fn transform_subquery(expr: Expr) -> DfResult<Transformed<Expr>> {
92        match expr {
93            Expr::Exists(exists) => Ok(Transformed::yes(Expr::Exists(Exists {
94                subquery: Self::handle_subquery(exists.subquery)?,
95                negated: exists.negated,
96            }))),
97            Expr::InSubquery(in_subquery) => Ok(Transformed::yes(Expr::InSubquery(InSubquery {
98                expr: in_subquery.expr,
99                subquery: Self::handle_subquery(in_subquery.subquery)?,
100                negated: in_subquery.negated,
101            }))),
102            Expr::ScalarSubquery(scalar_subquery) => Ok(Transformed::yes(Expr::ScalarSubquery(
103                Self::handle_subquery(scalar_subquery)?,
104            ))),
105
106            _ => Ok(Transformed::no(expr)),
107        }
108    }
109
110    fn handle_subquery(subquery: Subquery) -> DfResult<Subquery> {
111        let mut rewriter = PlanRewriter::default();
112        let mut rewrote_subquery = subquery
113            .subquery
114            .as_ref()
115            .clone()
116            .rewrite(&mut rewriter)?
117            .data;
118        // Workaround. DF doesn't support the first plan in subquery to be an Extension
119        if matches!(rewrote_subquery, LogicalPlan::Extension(_)) {
120            let output_schema = rewrote_subquery.schema().clone();
121            let project_exprs = output_schema
122                .fields()
123                .iter()
124                .map(|f| col_fn(f.name()))
125                .collect::<Vec<_>>();
126            rewrote_subquery = LogicalPlanBuilder::from(rewrote_subquery)
127                .project(project_exprs)?
128                .build()?;
129        }
130
131        Ok(Subquery {
132            subquery: Arc::new(rewrote_subquery),
133            outer_ref_columns: subquery.outer_ref_columns,
134        })
135    }
136}
137
/// Status of the rewriter to mark if the current pass is expanded
#[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
enum RewriterStatus {
    /// No merge-scan boundary has been inserted in the current pass yet.
    #[default]
    Unexpanded,
    /// A merge-scan boundary was already inserted; further nodes in this pass
    /// are left untouched.
    Expanded,
}
145
/// Tree-node rewriter that decides where to cut the plan and insert a
/// merge scan (see [`MergeScanLogicalPlan`]).
#[derive(Debug, Default)]
struct PlanRewriter {
    /// Current level in the tree
    level: usize,
    /// Simulated stack for the `rewrite` recursion
    stack: Vec<(LogicalPlan, usize)>,
    /// Stages to be expanded, will be added as parent node of merge scan one by one
    stage: Vec<LogicalPlan>,
    /// Whether the merge scan has already been inserted in the current pass
    status: RewriterStatus,
    /// Partition columns of the table in current pass
    partition_cols: Option<Vec<String>>,
    /// Columns that must remain visible in sub plans; enforced by
    /// `EnforceDistRequirementRewriter` when expanding
    column_requirements: HashSet<Column>,
    /// Forces `should_expand` to return true on its next invocation; set by
    /// the `TransformedCommutative` branch
    expand_on_next_call: bool,
    /// Replacement child plan produced by a `TransformedCommutative`
    /// transformer; if set, used as the root of the expansion instead of the
    /// visited node
    new_child_plan: Option<LogicalPlan>,
}
161
impl PlanRewriter {
    /// Returns the parent of the node currently being visited: the stack entry
    /// recorded exactly one level above the current one, if any.
    fn get_parent(&self) -> Option<&LogicalPlan> {
        // level starts from 1, it's safe to minus by 1
        self.stack
            .iter()
            .rev()
            .find(|(_, level)| *level == self.level - 1)
            .map(|(node, _)| node)
    }

    /// Return true if should stop and expand. The input plan is the parent node of current node
    fn should_expand(&mut self, plan: &LogicalPlan) -> bool {
        // A plan that can't be encoded to substrait can't be shipped to remote
        // nodes, so expansion has to happen below it.
        if DFLogicalSubstraitConvertor
            .encode(plan, DefaultSerializer)
            .is_err()
        {
            return true;
        }
        // Set by a previous `TransformedCommutative` branch; consume the flag once.
        if self.expand_on_next_call {
            self.expand_on_next_call = false;
            return true;
        }
        match Categorizer::check_plan(plan, self.partition_cols.clone()) {
            // Fully commutative: the parent can be pushed down as-is, keep ascending.
            Commutativity::Commutative => {}
            // Partially commutative: stash a transformed copy of the parent so
            // it gets re-applied on top of the merge scan during `expand`.
            Commutativity::PartialCommutative => {
                if let Some(plan) = partial_commutative_transformer(plan) {
                    self.update_column_requirements(&plan);
                    self.stage.push(plan)
                }
            }
            // Same idea, but the transformation is plan-specific and optional.
            Commutativity::ConditionalCommutative(transformer) => {
                if let Some(transformer) = transformer
                    && let Some(plan) = transformer(plan)
                {
                    self.update_column_requirements(&plan);
                    self.stage.push(plan)
                }
            }
            // The transformer may produce several extra parent stages (pushed in
            // reverse so they are applied in order during `expand`) and an
            // optional replacement child plan; expansion is forced on the next call.
            Commutativity::TransformedCommutative { transformer } => {
                if let Some(transformer) = transformer
                    && let Some(transformer_actions) = transformer(plan)
                {
                    debug!(
                        "PlanRewriter: transformed plan: {:#?}\n from {plan}",
                        transformer_actions.extra_parent_plans
                    );
                    if let Some(last_stage) = transformer_actions.extra_parent_plans.last() {
                        // update the column requirements from the last stage
                        self.update_column_requirements(last_stage);
                    }
                    self.stage
                        .extend(transformer_actions.extra_parent_plans.into_iter().rev());
                    self.expand_on_next_call = true;
                    self.new_child_plan = transformer_actions.new_child_plan;
                }
            }
            // Not pushable at all: stop here and expand the child.
            Commutativity::NonCommutative
            | Commutativity::Unimplemented
            | Commutativity::Unsupported => {
                return true;
            }
        }

        false
    }

    /// Records every column referenced by `plan`'s expressions into
    /// `column_requirements`, to be enforced later by
    /// [`EnforceDistRequirementRewriter`].
    fn update_column_requirements(&mut self, plan: &LogicalPlan) {
        let mut container = HashSet::new();
        for expr in plan.expressions() {
            // this method won't fail
            let _ = expr_to_columns(&expr, &mut container);
        }

        for col in container {
            self.column_requirements.insert(col);
        }
    }

    /// Whether a merge scan was already inserted in the current pass.
    fn is_expanded(&self) -> bool {
        self.status == RewriterStatus::Expanded
    }

    fn set_expanded(&mut self) {
        self.status = RewriterStatus::Expanded;
    }

    fn set_unexpanded(&mut self) {
        self.status = RewriterStatus::Unexpanded;
    }

    /// Lazily captures the partition columns from the first base-table scan
    /// seen in the current pass; non-base tables (e.g. views) are skipped.
    fn maybe_set_partitions(&mut self, plan: &LogicalPlan) {
        if self.partition_cols.is_some() {
            // only need to set once
            return;
        }

        if let LogicalPlan::TableScan(table_scan) = plan {
            if let Some(source) = table_scan
                .source
                .as_any()
                .downcast_ref::<DefaultTableSource>()
            {
                if let Some(provider) = source
                    .table_provider
                    .as_any()
                    .downcast_ref::<DfTableProviderAdapter>()
                {
                    if provider.table().table_type() == TableType::Base {
                        // Resolve partition key indices to column names.
                        let info = provider.table().table_info();
                        let partition_key_indices = info.meta.partition_key_indices.clone();
                        let schema = info.meta.schema.clone();
                        let partition_cols = partition_key_indices
                            .into_iter()
                            .map(|index| schema.column_name_by_index(index).to_string())
                            .collect::<Vec<String>>();
                        self.partition_cols = Some(partition_cols);
                    }
                }
            }
        }
    }

    /// pop one stack item and reduce the level by 1
    fn pop_stack(&mut self) {
        self.level -= 1;
        self.stack.pop();
    }

    /// Wraps `on_node` in a merge scan, re-applies the staged parent plans on
    /// top of it, and projects back to the pre-expansion schema.
    fn expand(&mut self, mut on_node: LogicalPlan) -> DfResult<LogicalPlan> {
        if let Some(new_child_plan) = self.new_child_plan.take() {
            // if there is a new child plan, use it as the new root
            on_node = new_child_plan;
        }
        // store schema before expand
        let schema = on_node.schema().clone();
        // Make sure all required columns survive projections inside the sub plan.
        let mut rewriter = EnforceDistRequirementRewriter {
            column_requirements: std::mem::take(&mut self.column_requirements),
        };
        on_node = on_node.rewrite(&mut rewriter)?.data;

        // add merge scan as the new root
        let mut node = MergeScanLogicalPlan::new(
            on_node,
            false,
            // at this stage, the partition cols should be set
            // treat it as non-partitioned if None
            self.partition_cols.clone().unwrap_or_default(),
        )
        .into_logical_plan();

        // expand stages
        for new_stage in self.stage.drain(..) {
            node = new_stage
                .with_new_exprs(new_stage.expressions_consider_join(), vec![node.clone()])?;
        }
        self.set_expanded();

        // recover the schema
        let node = LogicalPlanBuilder::from(node)
            .project(schema.iter().map(|(qualifier, field)| {
                Expr::Column(Column::new(qualifier.cloned(), field.name()))
            }))?
            .build()?;

        Ok(node)
    }
}
329
/// Implementation of the [`TreeNodeRewriter`] trait which is responsible for rewriting
/// logical plans to enforce various requirement for distributed query.
///
/// Requirements enforced by this rewriter:
/// - Enforce column requirements for `LogicalPlan::Projection` nodes. Makes sure the
///   required columns are available in the sub plan.
struct EnforceDistRequirementRewriter {
    /// Columns that must stay visible through projections; consumed (taken)
    /// by the first projection encountered during the top-down walk.
    column_requirements: HashSet<Column>,
}
339
impl TreeNodeRewriter for EnforceDistRequirementRewriter {
    type Node = LogicalPlan;

    /// Top-down visit: extends the first projection found with any required
    /// columns it would otherwise drop.
    fn f_down(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        if let LogicalPlan::Projection(ref projection) = node {
            // Take ownership of the requirements. They are intentionally not
            // put back, so only the topmost projection is ever adjusted.
            let mut column_requirements = std::mem::take(&mut self.column_requirements);
            if column_requirements.is_empty() {
                return Ok(Transformed::no(node));
            }

            // Drop requirements already satisfied by the projection's own outputs.
            for expr in &projection.expr {
                let (qualifier, name) = expr.qualified_name();
                let column = Column::new(qualifier, name);
                column_requirements.remove(&column);
            }
            if column_requirements.is_empty() {
                return Ok(Transformed::no(node));
            }

            // Append the still-missing columns to the projection list.
            let mut new_exprs = projection.expr.clone();
            for col in &column_requirements {
                new_exprs.push(Expr::Column(col.clone()));
            }
            let new_node =
                node.with_new_exprs(new_exprs, node.inputs().into_iter().cloned().collect())?;
            return Ok(Transformed::yes(new_node));
        }

        Ok(Transformed::no(node))
    }

    /// Nothing to do on the way back up.
    fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        Ok(Transformed::no(node))
    }
}
375
impl TreeNodeRewriter for PlanRewriter {
    type Node = LogicalPlan;

    /// descend
    fn f_down<'a>(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        self.level += 1;
        self.stack.push((node.clone(), self.level));
        // descending will clear the stage
        self.stage.clear();
        self.set_unexpanded();
        self.partition_cols = None;
        Ok(Transformed::no(node))
    }

    /// ascend
    ///
    /// Be sure to call `pop_stack` before returning
    fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        // only expand once on each ascending
        if self.is_expanded() {
            self.pop_stack();
            return Ok(Transformed::no(node));
        }

        // only expand when the leaf is table scan
        if node.inputs().is_empty() && !matches!(node, LogicalPlan::TableScan(_)) {
            self.set_expanded();
            self.pop_stack();
            return Ok(Transformed::no(node));
        }

        self.maybe_set_partitions(&node);

        // No parent means `node` is the root: expand the whole plan here.
        let Some(parent) = self.get_parent() else {
            let node = self.expand(node)?;
            self.pop_stack();
            return Ok(Transformed::yes(node));
        };

        let parent = parent.clone();

        // TODO(ruihang): avoid this clone
        if self.should_expand(&parent) {
            // TODO(ruihang): does this work for nodes with multiple children?;
            debug!("PlanRewriter: should expand child:\n {node}\n Of Parent: {parent}");
            let node = self.expand(node);
            debug!(
                "PlanRewriter: expanded plan: {}",
                match &node {
                    Ok(n) => n.to_string(),
                    Err(e) => format!("Error expanding plan: {e}"),
                }
            );
            let node = node?;
            self.pop_stack();
            return Ok(Transformed::yes(node));
        }

        self.pop_stack();
        Ok(Transformed::no(node))
    }
}
438
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use datafusion::datasource::DefaultTableSource;
    use datafusion::functions_aggregate::expr_fn::avg;
    use datafusion_common::JoinType;
    use datafusion_expr::{col, lit, Expr, LogicalPlanBuilder};
    use table::table::adapter::DfTableProviderAdapter;
    use table::table::numbers::NumbersTable;

    use super::*;

    // Filter + projection + distinct should all be pushed below the merge scan.
    #[ignore = "Projection is disabled for https://github.com/apache/arrow-datafusion/issues/6489"]
    #[test]
    fn transform_simple_projection_filter() {
        let numbers_table = NumbersTable::table(0);
        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(numbers_table),
        )));

        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
            .unwrap()
            .filter(col("number").lt(lit(10)))
            .unwrap()
            .project(vec![col("number")])
            .unwrap()
            .distinct()
            .unwrap()
            .build()
            .unwrap();

        let config = ConfigOptions::default();
        let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
        let expected = [
            "Distinct:",
            "  MergeScan [is_placeholder=false]",
            "    Distinct:",
            "      Projection: t.number",
            "        Filter: t.number < Int32(10)",
            "          TableScan: t",
        ]
        .join("\n");
        assert_eq!(expected, result.to_string());
    }

    // An aggregate over a scan collapses into a projection above a merge scan.
    #[test]
    fn transform_aggregator() {
        let numbers_table = NumbersTable::table(0);
        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(numbers_table),
        )));

        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
            .unwrap()
            .aggregate(Vec::<Expr>::new(), vec![avg(col("number"))])
            .unwrap()
            .build()
            .unwrap();

        let config = ConfigOptions::default();
        let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
        let expected = "Projection: avg(t.number)\
        \n  MergeScan [is_placeholder=false]";
        assert_eq!(expected, result.to_string());
    }

    // Distinct followed by sort is fully absorbed into the merge scan side.
    #[test]
    fn transform_distinct_order() {
        let numbers_table = NumbersTable::table(0);
        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(numbers_table),
        )));

        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
            .unwrap()
            .distinct()
            .unwrap()
            .sort(vec![col("number").sort(true, false)])
            .unwrap()
            .build()
            .unwrap();

        let config = ConfigOptions::default();
        let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
        let expected = ["Projection: t.number", "  MergeScan [is_placeholder=false]"].join("\n");
        assert_eq!(expected, result.to_string());
    }

    // A single limit over a scan is pushed below the merge scan.
    #[test]
    fn transform_single_limit() {
        let numbers_table = NumbersTable::table(0);
        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(numbers_table),
        )));

        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
            .unwrap()
            .limit(0, Some(1))
            .unwrap()
            .build()
            .unwrap();

        let config = ConfigOptions::default();
        let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
        let expected = "Projection: t.number\
        \n  MergeScan [is_placeholder=false]";
        assert_eq!(expected, result.to_string());
    }

    // A join stops push-down: each join input gets its own merge scan while the
    // join and limit stay above.
    #[test]
    fn transform_unalighed_join_with_alias() {
        let left = NumbersTable::table(0);
        let right = NumbersTable::table(1);
        let left_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(left),
        )));
        let right_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(right),
        )));

        let right_plan = LogicalPlanBuilder::scan_with_filters("t", right_source, None, vec![])
            .unwrap()
            .alias("right")
            .unwrap()
            .build()
            .unwrap();

        let plan = LogicalPlanBuilder::scan_with_filters("t", left_source, None, vec![])
            .unwrap()
            .join_on(
                right_plan,
                JoinType::LeftSemi,
                vec![col("t.number").eq(col("right.number"))],
            )
            .unwrap()
            .limit(0, Some(1))
            .unwrap()
            .build()
            .unwrap();

        let config = ConfigOptions::default();
        let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
        let expected = [
            "Limit: skip=0, fetch=1",
            "  LeftSemi Join:  Filter: t.number = right.number",
            "    Projection: t.number",
            "      MergeScan [is_placeholder=false]",
            "    SubqueryAlias: right",
            "      Projection: t.number",
            "        MergeScan [is_placeholder=false]",
        ]
        .join("\n");
        assert_eq!(expected, result.to_string());
    }
}