query/dist_plan/analyzer.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

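//! Distributed planner analyzer: an [`AnalyzerRule`] that rewrites logical plans
//! so that as much of the plan as possible is pushed down to datanodes, wrapping
//! the pushed-down part in a [`MergeScanLogicalPlan`].
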
use std::collections::{BTreeSet, HashSet};
use std::sync::Arc;

use common_telemetry::debug;
use datafusion::config::{ConfigExtension, ExtensionOptions};
use datafusion::datasource::DefaultTableSource;
use datafusion::error::Result as DfResult;
use datafusion_common::Column;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
use datafusion_expr::expr::{Exists, InSubquery};
use datafusion_expr::utils::expr_to_columns;
use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder, Subquery, col as col_fn};
use datafusion_optimizer::analyzer::AnalyzerRule;
use datafusion_optimizer::simplify_expressions::SimplifyExpressions;
use datafusion_optimizer::{OptimizerContext, OptimizerRule};
use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
use table::metadata::TableType;
use table::table::adapter::DfTableProviderAdapter;

use crate::dist_plan::analyzer::utils::aliased_columns_for;
use crate::dist_plan::commutativity::{
    Categorizer, Commutativity, partial_commutative_transformer,
};
use crate::dist_plan::merge_scan::MergeScanLogicalPlan;
use crate::metrics::PUSH_DOWN_FALLBACK_ERRORS_TOTAL;
use crate::plan::ExtractExpr;
use crate::query_engine::DefaultSerializer;

#[cfg(test)]
mod test;

mod fallback;
mod utils;

pub(crate) use utils::AliasMapping;

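/// Options for the distributed planner, read back in
/// [`DistPlannerAnalyzer::analyze`] via `config.extensions.get::<DistPlannerOptions>()`.
///
/// A minimal registration sketch, assuming the standard DataFusion
/// `Extensions::insert` API:
/// ```ignore
/// let mut config = ConfigOptions::new();
/// config.extensions.insert(DistPlannerOptions {
///     allow_query_fallback: true,
/// });
/// assert!(config.extensions.get::<DistPlannerOptions>().is_some());
/// ```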
#[derive(Debug, Clone)]
pub struct DistPlannerOptions {
    pub allow_query_fallback: bool,
}

impl ConfigExtension for DistPlannerOptions {
    const PREFIX: &'static str = "dist_planner";
}

impl ExtensionOptions for DistPlannerOptions {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

    fn cloned(&self) -> Box<dyn ExtensionOptions> {
        Box::new(self.clone())
    }

    fn set(&mut self, key: &str, value: &str) -> DfResult<()> {
        Err(datafusion_common::DataFusionError::NotImplemented(format!(
            "DistPlannerOptions does not support set key: {key} with value: {value}"
        )))
    }

    fn entries(&self) -> Vec<datafusion::config::ConfigEntry> {
        vec![datafusion::config::ConfigEntry {
            key: "allow_query_fallback".to_string(),
            value: Some(self.allow_query_fallback.to_string()),
            description: "Allow queries to fall back to the fallback plan rewriter when push-down fails",
        }]
    }
}

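/// The analyzer rule that rewrites a logical plan into a distributed plan: it first
/// tries `try_push_down` to push down as many nodes as possible, and, when that
/// fails and `allow_query_fallback` is enabled, falls back to `use_fallback`, which
/// only pushes down table scans.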
#[derive(Debug)]
pub struct DistPlannerAnalyzer;

impl AnalyzerRule for DistPlannerAnalyzer {
    fn name(&self) -> &str {
        "DistPlannerAnalyzer"
    }

    fn analyze(
        &self,
        plan: LogicalPlan,
        config: &ConfigOptions,
    ) -> datafusion_common::Result<LogicalPlan> {
        // preprocess the input plan
        let optimizer_context = OptimizerContext::new();
        let plan = SimplifyExpressions::new()
            .rewrite(plan, &optimizer_context)?
            .data;

        let opt = config.extensions.get::<DistPlannerOptions>();
        let allow_fallback = opt.map(|o| o.allow_query_fallback).unwrap_or(false);

        let result = match self.try_push_down(plan.clone()) {
            Ok(plan) => plan,
            Err(err) => {
                if allow_fallback {
                    common_telemetry::warn!(err; "Failed to push down plan, using fallback plan rewriter for plan: {plan}");
                    // if push down failed, use fallback plan rewriter
                    PUSH_DOWN_FALLBACK_ERRORS_TOTAL.inc();
                    self.use_fallback(plan)?
                } else {
                    return Err(err);
                }
            }
        };

        Ok(result)
    }
}

impl DistPlannerAnalyzer {
    /// Try to push down as many nodes as possible
    fn try_push_down(&self, plan: LogicalPlan) -> DfResult<LogicalPlan> {
        let plan = plan.transform(&Self::inspect_plan_with_subquery)?;
        let mut rewriter = PlanRewriter::default();
        let result = plan.data.rewrite(&mut rewriter)?.data;
        Ok(result)
    }

    /// Use the fallback plan rewriter to rewrite the plan, only pushing down table scan nodes
    fn use_fallback(&self, plan: LogicalPlan) -> DfResult<LogicalPlan> {
        let mut rewriter = fallback::FallbackPlanRewriter;
        let result = plan.rewrite(&mut rewriter)?.data;
        Ok(result)
    }

    fn inspect_plan_with_subquery(plan: LogicalPlan) -> DfResult<Transformed<LogicalPlan>> {
        // Workaround for https://github.com/GreptimeTeam/greptimedb/issues/5469 and https://github.com/GreptimeTeam/greptimedb/issues/5799
        // FIXME(yingwen): Remove the `Limit` plan once we update DataFusion.
        if let LogicalPlan::Limit(_) | LogicalPlan::Distinct(_) = &plan {
            return Ok(Transformed::no(plan));
        }

        let exprs = plan
            .expressions_consider_join()
            .into_iter()
            .map(|e| e.transform(&Self::transform_subquery).map(|x| x.data))
            .collect::<DfResult<Vec<_>>>()?;

        // Some plans are specially treated (we should not call `with_new_exprs` on them)
        if !matches!(plan, LogicalPlan::Unnest(_)) {
            let inputs = plan.inputs().into_iter().cloned().collect::<Vec<_>>();
            Ok(Transformed::yes(plan.with_new_exprs(exprs, inputs)?))
        } else {
            Ok(Transformed::no(plan))
        }
    }

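    /// Rewrite subquery expressions (`EXISTS`, `IN (subquery)` and scalar subqueries)
    /// so that their inner plans are also distributed, by running
    /// [`Self::handle_subquery`] on each subquery plan.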
    fn transform_subquery(expr: Expr) -> DfResult<Transformed<Expr>> {
        match expr {
            Expr::Exists(exists) => Ok(Transformed::yes(Expr::Exists(Exists {
                subquery: Self::handle_subquery(exists.subquery)?,
                negated: exists.negated,
            }))),
            Expr::InSubquery(in_subquery) => Ok(Transformed::yes(Expr::InSubquery(InSubquery {
                expr: in_subquery.expr,
                subquery: Self::handle_subquery(in_subquery.subquery)?,
                negated: in_subquery.negated,
            }))),
            Expr::ScalarSubquery(scalar_subquery) => Ok(Transformed::yes(Expr::ScalarSubquery(
                Self::handle_subquery(scalar_subquery)?,
            ))),

            _ => Ok(Transformed::no(expr)),
        }
    }

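    /// Rewrite a subquery plan with a fresh [`PlanRewriter`]. If the rewritten plan
    /// has an `Extension` node (the merge scan) at its root, wrap it in a trivial
    /// projection, because DataFusion doesn't support an `Extension` as the first
    /// plan of a subquery.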
    fn handle_subquery(subquery: Subquery) -> DfResult<Subquery> {
        let mut rewriter = PlanRewriter::default();
        let mut rewrote_subquery = subquery
            .subquery
            .as_ref()
            .clone()
            .rewrite(&mut rewriter)?
            .data;
        // Workaround: DataFusion doesn't support the first plan in a subquery to be an Extension
        if matches!(rewrote_subquery, LogicalPlan::Extension(_)) {
            let output_schema = rewrote_subquery.schema().clone();
            let project_exprs = output_schema
                .fields()
                .iter()
                .map(|f| col_fn(f.name()))
                .collect::<Vec<_>>();
            rewrote_subquery = LogicalPlanBuilder::from(rewrote_subquery)
                .project(project_exprs)?
                .build()?;
        }

        Ok(Subquery {
            subquery: Arc::new(rewrote_subquery),
            outer_ref_columns: subquery.outer_ref_columns,
            spans: Default::default(),
        })
    }
}

/// Status of the rewriter, marking whether the current pass is expanded
#[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
enum RewriterStatus {
    #[default]
    Unexpanded,
    Expanded,
}

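/// Bottom-up rewriter that decides, node by node, how much of the plan can be pushed
/// down to datanodes. When a parent node can't commute with the merge scan, it calls
/// [`Self::expand`] to wrap everything below that node in a [`MergeScanLogicalPlan`].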
#[derive(Debug, Default)]
struct PlanRewriter {
    /// Current level in the tree
    level: usize,
    /// Simulated stack for the `rewrite` recursion
    stack: Vec<(LogicalPlan, usize)>,
    /// Stages to be expanded; they will be added as parent nodes of the merge scan one by one
    stage: Vec<LogicalPlan>,
    status: RewriterStatus,
    /// Partition columns of the table in the current pass
    partition_cols: Option<AliasMapping>,
    /// Use the stack level as a scope to determine whether a column requirement applies.
    /// For example, for a logical plan like:
    /// ```ignore
    /// 1: Projection: t.number
    /// 2: Sort: t.pk1+t.pk2
    /// 3: Projection: t.number, t.pk1, t.pk2
    /// ```
    /// `Sort` will make a column requirement for `t.pk1` at level 2,
    /// which makes the `Projection` at level 1 add a reference to `t.pk1` as well.
    /// The expanded plan will then be
    /// ```ignore
    /// Projection: t.number
    ///   MergeSort: t.pk1
    ///     MergeScan: remote_input=
    /// Projection: t.number, "t.pk1+t.pk2" <--- the original `Projection` at level 1 gets `t.pk1+t.pk2` added
    ///  Sort: t.pk1+t.pk2
    ///    Projection: t.number, t.pk1, t.pk2
    /// ```
    /// so that `MergeSort` can have `t.pk1` as input.
    /// Meanwhile the `Projection` at level 3 doesn't need to add any new column, because 3 > 2
    /// and column requirements at level 2 are not applicable to level 3.
    ///
    /// See more details in tests `expand_proj_step_aggr` and `expand_proj_sort_proj`.
    ///
    /// TODO(discord9): a simpler solution to track column requirements for merge scan
    column_requirements: Vec<(HashSet<Column>, usize)>,
    /// Whether to expand on the next call.
    /// This is used to handle the case where a plan is transformed but needs to be expanded from its
    /// parent node. For example, an Aggregate plan is split into two parts, one in the frontend and
    /// one in the datanode, and needs to be expanded from the parent node of the Aggregate plan.
    expand_on_next_call: bool,
    /// Expand on the next partial/conditional/transformed commutative plan.
    /// This is used to handle the case where a plan is transformed but we still
    /// need to push down as many nodes as possible before the next partial/conditional/transformed
    /// commutative plan, e.g.:
    /// ```ignore
    /// Limit:
    ///     Sort:
    /// ```
    /// where `Limit` is partial commutative and `Sort` is conditional commutative.
    /// In this case, we need to expand the `Limit` plan,
    /// so that we can push down the `Sort` plan as much as possible.
    expand_on_next_part_cond_trans_commutative: bool,
    new_child_plan: Option<LogicalPlan>,
}

impl PlanRewriter {
    fn get_parent(&self) -> Option<&LogicalPlan> {
        // level starts from 1, so it's safe to subtract 1
        self.stack
            .iter()
            .rev()
            .find(|(_, level)| *level == self.level - 1)
            .map(|(node, _)| node)
    }

    /// Return true if we should stop and expand. The input plan is the parent node of the current node.
    fn should_expand(&mut self, plan: &LogicalPlan) -> bool {
        debug!(
            "Check should_expand at level: {} with stack:\n{}",
            self.level,
            self.stack
                .iter()
                .map(|(p, l)| format!("{l}:{}{}", "  ".repeat(l - 1), p.display()))
                .collect::<Vec<String>>()
                .join("\n"),
        );
        if DFLogicalSubstraitConvertor
            .encode(plan, DefaultSerializer)
            .is_err()
        {
            return true;
        }

        if self.expand_on_next_call {
            self.expand_on_next_call = false;
            return true;
        }

        if self.expand_on_next_part_cond_trans_commutative {
            let comm = Categorizer::check_plan(plan, self.partition_cols.clone());
            match comm {
                Commutativity::PartialCommutative => {
                    // a small difference is that for a partial commutative plan, we still need to
                    // push it down (so `Limit` can be pushed down)

                    // notice that `Limit` needs to be expanded as well to keep the query correct,
                    // i.e. `Limit fetch=10` needs to be pushed down to the leaf node
                    self.expand_on_next_part_cond_trans_commutative = false;
                    self.expand_on_next_call = true;
                }
                Commutativity::ConditionalCommutative(_)
                | Commutativity::TransformedCommutative { .. } => {
                    // again a new node that can be pushed down; just push down
                    // now and avoid further expansion
                    self.expand_on_next_part_cond_trans_commutative = false;
                    return true;
                }
                _ => (),
            }
        }

        match Categorizer::check_plan(plan, self.partition_cols.clone()) {
            Commutativity::Commutative => {}
            Commutativity::PartialCommutative => {
                if let Some(plan) = partial_commutative_transformer(plan) {
                    // note this plan is the parent of the current node, hence `self.level - 1`
                    // when updating column requirements
                    self.update_column_requirements(&plan, self.level - 1);
                    self.expand_on_next_part_cond_trans_commutative = true;
                    self.stage.push(plan)
                }
            }
            Commutativity::ConditionalCommutative(transformer) => {
                if let Some(transformer) = transformer
                    && let Some(plan) = transformer(plan)
                {
                    // note this plan is the parent of the current node, hence `self.level - 1`
                    // when updating column requirements
                    self.update_column_requirements(&plan, self.level - 1);
                    self.expand_on_next_part_cond_trans_commutative = true;
                    self.stage.push(plan)
                }
            }
            Commutativity::TransformedCommutative { transformer } => {
                if let Some(transformer) = transformer
                    && let Some(transformer_actions) = transformer(plan)
                {
                    debug!(
                        "PlanRewriter: transformed plan: {}\n from {plan}",
                        transformer_actions
                            .extra_parent_plans
                            .iter()
                            .enumerate()
                            .map(|(i, p)| format!(
                                "Extra {i}-th parent plan from parent to child = {}",
                                p.display()
                            ))
                            .collect::<Vec<_>>()
                            .join("\n")
                    );
                    if let Some(new_child_plan) = &transformer_actions.new_child_plan {
                        debug!("PlanRewriter: new child plan: {}", new_child_plan);
                    }
                    if let Some(last_stage) = transformer_actions.extra_parent_plans.last() {
                        // update the column requirements from the last stage;
                        // note the current plan's parent is where we need to apply the column requirements
                        self.update_column_requirements(last_stage, self.level - 1);
                    }
                    self.stage
                        .extend(transformer_actions.extra_parent_plans.into_iter().rev());
                    self.expand_on_next_call = true;
                    self.new_child_plan = transformer_actions.new_child_plan;
                }
            }
            Commutativity::NonCommutative
            | Commutativity::Unimplemented
            | Commutativity::Unsupported => {
                return true;
            }
        }

        false
    }

    /// Update the column requirements for the given plan; `plan_level` is the level of the plan
    /// in the stack, which is used to determine whether the column requirements are applicable
    /// to other plans in the stack.
    fn update_column_requirements(&mut self, plan: &LogicalPlan, plan_level: usize) {
        debug!(
            "PlanRewriter: update column requirements for plan: {plan}\n with old column_requirements: {:?}",
            self.column_requirements
        );
        let mut container = HashSet::new();
        for expr in plan.expressions() {
            // this method won't fail
            let _ = expr_to_columns(&expr, &mut container);
        }

        self.column_requirements.push((container, plan_level));
        debug!(
            "PlanRewriter: updated column requirements: {:?}",
            self.column_requirements
        );
    }

    fn is_expanded(&self) -> bool {
        self.status == RewriterStatus::Expanded
    }

    fn set_expanded(&mut self) {
        self.status = RewriterStatus::Expanded;
    }

    fn set_unexpanded(&mut self) {
        self.status = RewriterStatus::Unexpanded;
    }

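    /// Track partition columns for the current pass: when the node is a table scan
    /// over a base table, initialize `partition_cols` from the table's partition key;
    /// otherwise, if `partition_cols` is already set, follow the columns through the
    /// current node's aliases.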
    fn maybe_set_partitions(&mut self, plan: &LogicalPlan) -> DfResult<()> {
        if let Some(part_cols) = &mut self.partition_cols {
            // update partition column aliases
            let child = plan.inputs().first().cloned().ok_or_else(|| {
                datafusion_common::DataFusionError::Internal(format!(
                    "PlanRewriter: maybe_set_partitions: plan has no child: {plan}"
                ))
            })?;

            for (_col_name, alias_set) in part_cols.iter_mut() {
                let aliased_cols = aliased_columns_for(
                    &alias_set.clone().into_iter().collect(),
                    plan,
                    Some(child),
                )?;
                *alias_set = aliased_cols.into_values().flatten().collect();
            }

            debug!(
                "PlanRewriter: maybe_set_partitions: updated partition columns: {:?} at plan: {}",
                part_cols,
                plan.display()
            );

            return Ok(());
        }

        if let LogicalPlan::TableScan(table_scan) = plan
            && let Some(source) = table_scan
                .source
                .as_any()
                .downcast_ref::<DefaultTableSource>()
            && let Some(provider) = source
                .table_provider
                .as_any()
                .downcast_ref::<DfTableProviderAdapter>()
        {
            let table = provider.table();
            if table.table_type() == TableType::Base {
                let info = table.table_info();
                let partition_key_indices = info.meta.partition_key_indices.clone();
                let schema = info.meta.schema.clone();
                let mut partition_cols = partition_key_indices
                    .into_iter()
                    .map(|index| schema.column_name_by_index(index).to_string())
                    .collect::<Vec<String>>();

                let partition_rules = table.partition_rules();
                let exist_phy_part_cols_not_in_logical_table = partition_rules
                    .map(|r| !r.extra_phy_cols_not_in_logical_table.is_empty())
                    .unwrap_or(false);

                if exist_phy_part_cols_not_in_logical_table && partition_cols.is_empty() {
                    // There are physical partition columns that are not in the logical table, and
                    // the logical partition columns are empty, so add a placeholder to prevent
                    // certain optimizations. This makes sure the partition columns the optimizer
                    // sees are not empty, distinguishing a non-partitioned table from a partitioned
                    // table whose physical partition columns are all absent from the logical table.
                    // If `partition_cols` is not empty, there is no need for the placeholder: the
                    // subset of physical partition columns can still be used for certain
                    // optimizations, and it works as if the missing columns were always null.
                    partition_cols.push("__OTHER_PHYSICAL_PART_COLS_PLACEHOLDER__".to_string());
                }
                self.partition_cols = Some(
                    partition_cols
                        .into_iter()
                        .map(|c| {
                            let index = plan
                                .schema()
                                .index_of_column_by_name(None, &c)
                                .ok_or_else(|| {
                                    datafusion_common::DataFusionError::Internal(format!(
                                        "PlanRewriter: maybe_set_partitions: column {c} not found in schema of plan: {plan}"
                                    ))
                                })?;
                            let column =
                                plan.schema().columns().get(index).cloned().ok_or_else(|| {
                                    datafusion_common::DataFusionError::Internal(format!(
                                        "PlanRewriter: maybe_set_partitions: column index {index} out of bounds in schema of plan: {plan}"
                                    ))
                                })?;
                            Ok((c.clone(), BTreeSet::from([column])))
                        })
                        .collect::<DfResult<AliasMapping>>()?,
                );
            }
        }

        Ok(())
    }

    /// pop one stack item and reduce the level by 1
    fn pop_stack(&mut self) {
        self.level -= 1;
        self.stack.pop();
    }

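    /// Expand the plan rooted at `on_node`: enforce the collected column requirements,
    /// wrap the node in a [`MergeScanLogicalPlan`], re-apply the staged parent plans,
    /// and finally project back to the original schema. Conceptually, for
    /// `stage = [s1, s2]` (a sketch; `s1`/`s2` are hypothetical staged nodes) the
    /// result is:
    /// ```ignore
    /// Projection: <original schema of on_node>
    ///   s2
    ///     s1
    ///       MergeScan: remote_input=<on_node>
    /// ```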
    fn expand(&mut self, mut on_node: LogicalPlan) -> DfResult<LogicalPlan> {
        // store the schema before expanding; the new child plan might have a different
        // schema, so don't use its schema here
        let schema = on_node.schema().clone();
        if let Some(new_child_plan) = self.new_child_plan.take() {
            // if there is a new child plan, use it as the new root
            on_node = new_child_plan;
        }
        let mut rewriter = EnforceDistRequirementRewriter::new(
            std::mem::take(&mut self.column_requirements),
            self.level,
        );
        debug!(
            "PlanRewriter: enforce column requirements for node: {on_node} with rewriter: {rewriter:?}"
        );
        on_node = on_node.rewrite(&mut rewriter)?.data;
        debug!(
            "PlanRewriter: after enforced column requirements with rewriter: {rewriter:?} for node:\n{on_node}"
        );

        debug!(
            "PlanRewriter: expand on node: {on_node} with partition col alias mapping: {:?}",
            self.partition_cols
        );

        // add merge scan as the new root
        let mut node = MergeScanLogicalPlan::new(
            on_node,
            false,
            // at this stage, the partition cols should be set;
            // treat it as non-partitioned if None
            self.partition_cols.clone().unwrap_or_default(),
        )
        .into_logical_plan();

        // expand stages
        for new_stage in self.stage.drain(..) {
            node = new_stage
                .with_new_exprs(new_stage.expressions_consider_join(), vec![node.clone()])?;
        }
        self.set_expanded();

        // recover the schema; this makes sure the schema after expansion is the same as
        // the old node's, because the raw top node might have extra columns after
        // expansion, e.g. sorting columns for a `Sort` node
        let node = LogicalPlanBuilder::from(node)
            .project(schema.iter().map(|(qualifier, field)| {
                Expr::Column(Column::new(qualifier.cloned(), field.name()))
            }))?
            .build()?;

        Ok(node)
    }
}

/// Implementation of the [`TreeNodeRewriter`] trait, responsible for rewriting
/// logical plans to enforce various requirements for distributed queries.
///
/// Requirements enforced by this rewriter:
/// - Enforce column requirements for `LogicalPlan::Projection` nodes. Makes sure the
///   required columns are available in the sub plan.
#[derive(Debug)]
struct EnforceDistRequirementRewriter {
    /// Only enforce column requirements in the scope of the expanding node:
    /// a requirement recorded at `level` is only applied to nodes with `cur_level <= level`.
    /// TODO(discord9): a simpler solution to track column requirements for merge scan
    column_requirements: Vec<(HashSet<Column>, usize)>,
    /// Only apply column requirements recorded at levels >= `cur_level`.
    /// This avoids applying column requirements that are not needed for the current node,
    /// i.e. when the node is not in the scope of the column requirements.
    /// For example, for this plan:
    /// ```ignore
    /// Aggregate: min(t.number)
    ///   Projection: t.number
    /// ```
    /// when on the `Projection` node, we don't need to apply the column requirements of the
    /// `Aggregate` node, because the `Projection` node is not in the scope of the `Aggregate` node.
    cur_level: usize,
}

impl EnforceDistRequirementRewriter {
    fn new(column_requirements: Vec<(HashSet<Column>, usize)>, cur_level: usize) -> Self {
        Self {
            column_requirements,
            cur_level,
        }
    }
}

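// Level bookkeeping: `cur_level` starts at the level of the node being expanded,
// is incremented on every descent (`f_down`) and decremented on every ascent
// (`f_up`), so inside `f_up` it equals the current node's level relative to the
// original plan stack. A requirement recorded at `level` is then applied to all
// nodes with `cur_level <= level`.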
impl TreeNodeRewriter for EnforceDistRequirementRewriter {
    type Node = LogicalPlan;

    fn f_down(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        // check that the node doesn't have multiple children, e.g. join/subquery
        if node.inputs().len() > 1 {
            return Err(datafusion_common::DataFusionError::Internal(
                "EnforceDistRequirementRewriter: node with multiple inputs is not supported"
                    .to_string(),
            ));
        }
        self.cur_level += 1;
        Ok(Transformed::no(node))
    }

    fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        self.cur_level -= 1;
        // first get all applicable column requirements
        let mut applicable_column_requirements = self
            .column_requirements
            .iter()
            .filter(|(_, level)| *level >= self.cur_level)
            .map(|(cols, _)| cols.clone())
            .reduce(|mut acc, cols| {
                acc.extend(cols);
                acc
            })
            .unwrap_or_default();

        debug!(
            "EnforceDistRequirementRewriter: applicable column requirements at level {} = {:?} for node {}",
            self.cur_level,
            applicable_column_requirements,
            node.display()
        );

        // make sure every projection in the applicable scope has the required columns
        if let LogicalPlan::Projection(ref projection) = node {
            for expr in &projection.expr {
                let (qualifier, name) = expr.qualified_name();
                let column = Column::new(qualifier, name);
                applicable_column_requirements.remove(&column);
            }
            if applicable_column_requirements.is_empty() {
                return Ok(Transformed::no(node));
            }

            let mut new_exprs = projection.expr.clone();
            for col in &applicable_column_requirements {
                new_exprs.push(Expr::Column(col.clone()));
            }
            let new_node =
                node.with_new_exprs(new_exprs, node.inputs().into_iter().cloned().collect())?;
            debug!(
                "EnforceDistRequirementRewriter: added missing columns {:?} to projection node from old node: \n{node}\n Making new node: \n{new_node}",
                applicable_column_requirements
            );

            // still need to continue to the next projection if applicable
            return Ok(Transformed::yes(new_node));
        }
        Ok(Transformed::no(node))
    }
}

impl TreeNodeRewriter for PlanRewriter {
    type Node = LogicalPlan;

    /// descend
    fn f_down(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        self.level += 1;
        self.stack.push((node.clone(), self.level));
        // descending will clear the stage
        self.stage.clear();
        self.set_unexpanded();
        self.partition_cols = None;
        Ok(Transformed::no(node))
    }

    /// ascend
    ///
    /// Be sure to call `pop_stack` before returning
    fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        // only expand once on each ascent
        if self.is_expanded() {
            self.pop_stack();
            return Ok(Transformed::no(node));
        }

        // only expand when the leaf is a table scan
        if node.inputs().is_empty() && !matches!(node, LogicalPlan::TableScan(_)) {
            self.set_expanded();
            self.pop_stack();
            return Ok(Transformed::no(node));
        }

        self.maybe_set_partitions(&node)?;

        let Some(parent) = self.get_parent() else {
            debug!("PlanRewriter: expand now since no parent is found for node: {node}");
            let node = self.expand(node);
            debug!(
                "PlanRewriter: expanded plan: {}",
                match &node {
                    Ok(n) => n.to_string(),
                    Err(e) => format!("Error expanding plan: {e}"),
                }
            );
            let node = node?;
            self.pop_stack();
            return Ok(Transformed::yes(node));
        };

        // TODO(ruihang): avoid this clone
        let parent = parent.clone();

        if self.should_expand(&parent) {
            // TODO(ruihang): does this work for nodes with multiple children?
            debug!(
                "PlanRewriter: should expand child:\n {node}\n of parent: {}",
                parent.display()
            );
            let node = self.expand(node);
            debug!(
                "PlanRewriter: expanded plan: {}",
                match &node {
                    Ok(n) => n.to_string(),
                    Err(e) => format!("Error expanding plan: {e}"),
                }
            );
            let node = node?;
            self.pop_stack();
            return Ok(Transformed::yes(node));
        }

        self.pop_stack();
        Ok(Transformed::no(node))
    }
}
751}