query/dist_plan/analyzer.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::{BTreeMap, BTreeSet, HashSet};
use std::sync::Arc;

use common_telemetry::debug;
use datafusion::config::{ConfigExtension, ExtensionOptions};
use datafusion::datasource::DefaultTableSource;
use datafusion::error::Result as DfResult;
use datafusion_common::Column;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
use datafusion_expr::expr::{Exists, InSubquery};
use datafusion_expr::utils::expr_to_columns;
use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder, Subquery, col as col_fn};
use datafusion_optimizer::analyzer::AnalyzerRule;
use promql::extension_plan::SeriesDivide;
use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
use table::metadata::TableType;
use table::table::adapter::DfTableProviderAdapter;

use crate::dist_plan::analyzer::utils::{
    PatchOptimizerContext, PlanTreeExpressionSimplifier, aliased_columns_for,
    rewrite_merge_sort_exprs,
};
use crate::dist_plan::commutativity::{
    Categorizer, Commutativity, partial_commutative_transformer,
};
use crate::dist_plan::merge_scan::MergeScanLogicalPlan;
use crate::dist_plan::merge_sort::MergeSortLogicalPlan;
use crate::metrics::PUSH_DOWN_FALLBACK_ERRORS_TOTAL;
use crate::plan::ExtractExpr;
use crate::query_engine::DefaultSerializer;

#[cfg(test)]
mod test;

mod fallback;
pub(crate) mod utils;

pub(crate) use utils::AliasMapping;

/// Placeholder for other physical partition columns that are not in the logical table
const OTHER_PHY_PART_COL_PLACEHOLDER: &str = "__OTHER_PHYSICAL_PART_COLS_PLACEHOLDER__";

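/// Options for the distributed planner, exposed as a DataFusion config
/// extension under the [`DistPlannerOptions::PREFIX`] (`dist_planner`) prefix.
///
/// A minimal sketch of how the extension can be registered and read back
/// (a plain `ConfigOptions` is used here for illustration; the real session
/// setup lives elsewhere):
/// ```ignore
/// let mut config = ConfigOptions::new();
/// config.extensions.insert(DistPlannerOptions { allow_query_fallback: true });
/// let opt = config.extensions.get::<DistPlannerOptions>();
/// assert!(opt.map(|o| o.allow_query_fallback).unwrap_or(false));
/// ```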
#[derive(Debug, Clone)]
pub struct DistPlannerOptions {
    pub allow_query_fallback: bool,
}

impl ConfigExtension for DistPlannerOptions {
    const PREFIX: &'static str = "dist_planner";
}

impl ExtensionOptions for DistPlannerOptions {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

    fn cloned(&self) -> Box<dyn ExtensionOptions> {
        Box::new(self.clone())
    }

    fn set(&mut self, key: &str, value: &str) -> DfResult<()> {
        Err(datafusion_common::DataFusionError::NotImplemented(format!(
            "DistPlannerOptions does not support set key: {key} with value: {value}"
        )))
    }

    fn entries(&self) -> Vec<datafusion::config::ConfigEntry> {
        vec![datafusion::config::ConfigEntry {
            key: "allow_query_fallback".to_string(),
            value: Some(self.allow_query_fallback.to_string()),
            description: "Allow queries to fall back to the fallback plan rewriter when push-down fails",
        }]
    }
}

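/// [`AnalyzerRule`] that rewrites a logical plan for distributed execution:
/// it pushes as much of the plan as possible into a [`MergeScanLogicalPlan`]
/// executed on datanodes and keeps only the non-commutative parts above it.
///
/// A rough sketch of the rewrite (plan shapes are illustrative, not exact
/// display output):
/// ```ignore
/// Limit: fetch=10              Limit: fetch=10
///   Sort: t.pk1 ASC      =>      MergeSort: t.pk1 ASC
///     TableScan: t                 MergeScan [Limit; Sort; TableScan as remote input]
/// ```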
#[derive(Debug)]
pub struct DistPlannerAnalyzer;

impl AnalyzerRule for DistPlannerAnalyzer {
    fn name(&self) -> &str {
        "DistPlannerAnalyzer"
    }

    fn analyze(
        &self,
        plan: LogicalPlan,
        config: &ConfigOptions,
    ) -> datafusion_common::Result<LogicalPlan> {
        let mut config = config.clone();
        // Aligned with the behavior in `datafusion_optimizer::OptimizerContext::new()`.
        config.optimizer.filter_null_join_keys = true;
        let config = Arc::new(config);

        let optimizer_context = PatchOptimizerContext {
            inner: datafusion_optimizer::OptimizerContext::new(),
            config: config.clone(),
        };

        let plan = plan
            .rewrite_with_subqueries(&mut PlanTreeExpressionSimplifier::new(optimizer_context))?
            .data;

        let opt = config.extensions.get::<DistPlannerOptions>();
        let allow_fallback = opt.map(|o| o.allow_query_fallback).unwrap_or(false);

        let result = match self.try_push_down(plan.clone()) {
            Ok(plan) => plan,
            Err(err) => {
                if allow_fallback {
                    common_telemetry::warn!(err; "Failed to push down plan, using fallback plan rewriter for plan: {plan}");
                    // if push-down failed, use the fallback plan rewriter
                    PUSH_DOWN_FALLBACK_ERRORS_TOTAL.inc();
                    self.use_fallback(plan)?
                } else {
                    return Err(err);
                }
            }
        };

        Ok(result)
    }
}

impl DistPlannerAnalyzer {
    /// Try to push down as many nodes as possible
    fn try_push_down(&self, plan: LogicalPlan) -> DfResult<LogicalPlan> {
        let plan = plan.transform(&Self::inspect_plan_with_subquery)?;
        let mut rewriter = PlanRewriter::default();
        let result = plan.data.rewrite(&mut rewriter)?.data;
        Ok(result)
    }

    /// Use the fallback plan rewriter to rewrite the plan, pushing down only table scan nodes
    fn use_fallback(&self, plan: LogicalPlan) -> DfResult<LogicalPlan> {
        let mut rewriter = fallback::FallbackPlanRewriter;
        let result = plan.rewrite(&mut rewriter)?.data;
        Ok(result)
    }

    fn inspect_plan_with_subquery(plan: LogicalPlan) -> DfResult<Transformed<LogicalPlan>> {
        // Workaround for https://github.com/GreptimeTeam/greptimedb/issues/5469 and https://github.com/GreptimeTeam/greptimedb/issues/5799
        // FIXME(yingwen): Remove the `Limit` plan once we update DataFusion.
        if let LogicalPlan::Limit(_) | LogicalPlan::Distinct(_) = &plan {
            return Ok(Transformed::no(plan));
        }

        let exprs = plan
            .expressions_consider_join()
            .into_iter()
            .map(|e| e.transform(&Self::transform_subquery).map(|x| x.data))
            .collect::<DfResult<Vec<_>>>()?;

        // Some plans are specially treated (we should not call `with_new_exprs` on them)
        if !matches!(plan, LogicalPlan::Unnest(_)) {
            let inputs = plan.inputs().into_iter().cloned().collect::<Vec<_>>();
            Ok(Transformed::yes(plan.with_new_exprs(exprs, inputs)?))
        } else {
            Ok(Transformed::no(plan))
        }
    }

    fn transform_subquery(expr: Expr) -> DfResult<Transformed<Expr>> {
        match expr {
            Expr::Exists(exists) => Ok(Transformed::yes(Expr::Exists(Exists {
                subquery: Self::handle_subquery(exists.subquery)?,
                negated: exists.negated,
            }))),
            Expr::InSubquery(in_subquery) => Ok(Transformed::yes(Expr::InSubquery(InSubquery {
                expr: in_subquery.expr,
                subquery: Self::handle_subquery(in_subquery.subquery)?,
                negated: in_subquery.negated,
            }))),
            Expr::ScalarSubquery(scalar_subquery) => Ok(Transformed::yes(Expr::ScalarSubquery(
                Self::handle_subquery(scalar_subquery)?,
            ))),

            _ => Ok(Transformed::no(expr)),
        }
    }

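    /// Rewrite the subquery plan with a fresh [`PlanRewriter`]. If the rewritten
    /// root is an `Extension` node (e.g. a merge scan), wrap it in a trivial
    /// `Projection`, because DataFusion doesn't support an `Extension` as the
    /// first plan of a subquery.
    ///
    /// A sketch of the wrapping (illustrative column names):
    /// ```ignore
    /// MergeScan [...]    =>    Projection: a, b
    ///                            MergeScan [...]
    /// ```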
    fn handle_subquery(subquery: Subquery) -> DfResult<Subquery> {
        let mut rewriter = PlanRewriter::default();
        let mut rewrote_subquery = subquery
            .subquery
            .as_ref()
            .clone()
            .rewrite(&mut rewriter)?
            .data;
        // Workaround: DF doesn't support the first plan in a subquery to be an Extension
        if matches!(rewrote_subquery, LogicalPlan::Extension(_)) {
            let output_schema = rewrote_subquery.schema().clone();
            let project_exprs = output_schema
                .fields()
                .iter()
                .map(|f| col_fn(f.name()))
                .collect::<Vec<_>>();
            rewrote_subquery = LogicalPlanBuilder::from(rewrote_subquery)
                .project(project_exprs)?
                .build()?;
        }

        Ok(Subquery {
            subquery: Arc::new(rewrote_subquery),
            outer_ref_columns: subquery.outer_ref_columns,
            spans: Default::default(),
        })
    }
}

/// Status of the rewriter, marking whether the current pass is expanded
#[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
enum RewriterStatus {
    #[default]
    Unexpanded,
    Expanded,
}

#[derive(Debug, Default)]
struct PlanRewriter {
    /// Current level in the tree
    level: usize,
    /// Simulated stack for the `rewrite` recursion
    stack: Vec<(LogicalPlan, usize)>,
    /// Stages to be expanded; they will be added as parent nodes of the merge scan one by one
    stage: Vec<LogicalPlan>,
    status: RewriterStatus,
    /// Partition columns of the table in the current pass
    partition_cols: Option<AliasMapping>,
    /// Use the stack depth as a scope to determine whether a column requirement applies.
    /// I.e., for a logical plan like:
    /// ```ignore
    /// 1: Projection: t.number
    /// 2: Sort: t.pk1+t.pk2
    /// 3: Projection: t.number, t.pk1, t.pk2
    /// ```
    /// `Sort` will make a column requirement for `t.pk1+t.pk2` at level 2,
    /// which makes the `Projection` at level 1 need to add a ref to `t.pk1` as well,
    /// so that the expanded plan will be
    /// ```ignore
    /// Projection: t.number
    ///   MergeSort: t.pk1+t.pk2
    ///     MergeScan: remote_input=
    /// Projection: t.number, "t.pk1+t.pk2" <--- the original `Projection` at level 1 gets `t.pk1+t.pk2` added
    ///  Sort: t.pk1+t.pk2
    ///    Projection: t.number, t.pk1, t.pk2
    /// ```
    /// making `MergeSort` able to take `t.pk1+t.pk2` as input.
    /// Meanwhile the `Projection` at level 3 doesn't need to add any new column, because 3 > 2
    /// and column requirements at level 2 are not applicable to level 3.
    ///
    /// See more details in the tests `expand_proj_step_aggr` and `expand_proj_sort_proj`.
    ///
    /// TODO(discord9): a simpler solution to track column requirements for merge scan
    column_requirements: Vec<(HashSet<Column>, usize)>,
    /// Whether to expand on the next call.
    /// This is used to handle the case where a plan is transformed but needs to be expanded from its
    /// parent node. For example, an Aggregate plan is split into two parts in frontend and datanode, and needs
    /// to be expanded from the parent node of the Aggregate plan.
    expand_on_next_call: bool,
    /// Expand on the next partial/conditional/transformed commutative plan.
    /// This is used to handle the case where a plan is transformed but we still
    /// need to push down as many nodes as possible before the next partial/conditional/transformed
    /// commutative plan. I.e.
    /// ```ignore
    /// Limit:
    ///     Sort:
    /// ```
    /// where `Limit` is partial commutative and `Sort` is conditional commutative.
    /// In this case, we need to expand the `Limit` plan
    /// so that we can push down the `Sort` plan as much as possible.
    expand_on_next_part_cond_trans_commutative: bool,
    new_child_plan: Option<LogicalPlan>,
}

impl PlanRewriter {
    fn get_parent(&self) -> Option<&LogicalPlan> {
        // level starts from 1, so it's safe to subtract 1
        self.stack
            .iter()
            .rev()
            .find(|(_, level)| *level == self.level - 1)
            .map(|(node, _)| node)
    }

    /// Return true if we should stop and expand. The input plan is the parent node of the current node.
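    ///
    /// A rough sketch of the decision, keyed by the parent's [`Commutativity`]:
    /// ```ignore
    /// Commutative                  => keep descending, nothing staged
    /// PartialCommutative           => stage a copy of the parent (e.g. `Limit`)
    /// ConditionalCommutative(_)    => stage the transformed parent (e.g. `Sort` -> `MergeSort`)
    /// TransformedCommutative {..}  => stage extra parent plans, expand on next call
    /// NonCommutative/Unimplemented/Unsupported => expand now
    /// ```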
    fn should_expand(&mut self, plan: &LogicalPlan) -> DfResult<bool> {
        debug!(
            "Check should_expand at level: {}  with Stack:\n{}, ",
            self.level,
            self.stack
                .iter()
                .map(|(p, l)| format!("{l}:{}{}", "  ".repeat(l - 1), p.display()))
                .collect::<Vec<String>>()
                .join("\n"),
        );
        if let Err(e) = DFLogicalSubstraitConvertor.encode(plan, DefaultSerializer) {
            debug!(
                "PlanRewriter: plan cannot be converted to substrait with error={e:?}, expanding now: {plan}"
            );
            return Ok(true);
        }

        if self.expand_on_next_call {
            self.expand_on_next_call = false;
            debug!("PlanRewriter: expand_on_next_call is true, expanding now");
            return Ok(true);
        }

        if self.expand_on_next_part_cond_trans_commutative {
            let comm = Categorizer::check_plan(plan, self.partition_cols.clone())?;
            match comm {
                Commutativity::PartialCommutative => {
                    // a small difference is that for a partial commutative plan, we still need to
                    // push it down (so `Limit` can be pushed down)

                    // notice how `Limit` needs to be expanded as well to make sure the query is correct,
                    // i.e. `Limit fetch=10` needs to be pushed down to the leaf node
                    self.expand_on_next_part_cond_trans_commutative = false;
                    self.expand_on_next_call = true;
                }
                Commutativity::ConditionalCommutative(_)
                | Commutativity::TransformedCommutative { .. } => {
                    // again a new node that can be pushed down; we should just
                    // do the push-down now and avoid further expansion
                    self.expand_on_next_part_cond_trans_commutative = false;
                    debug!(
                        "PlanRewriter: meet a new conditional/transformed commutative plan, expanding now: {plan}"
                    );
                    return Ok(true);
                }
                _ => (),
            }
        }

        match Categorizer::check_plan(plan, self.partition_cols.clone())? {
            Commutativity::Commutative => {
                // PATCH: we should reconsider SORT's commutativity instead of doing this trick.
                // explain: for a fully commutative SeriesDivide, its child Sort plan only serves it. I.e., that
                //   Sort plan is also fully commutative, instead of conditional commutative. So we can remove
                //   the generated MergeSort from stage safely.
                if let LogicalPlan::Extension(ext_a) = plan
                    && ext_a.node.name() == SeriesDivide::name()
                    && let Some(LogicalPlan::Extension(ext_b)) = self.stage.last()
                    && ext_b.node.name() == MergeSortLogicalPlan::name()
                {
                    // revert the last `ConditionalCommutative` result for the Sort plan in this case.
                    // we also need to remove any column requirements made by the Sort plan,
                    // as they may refer to columns that (rightfully) no longer exist later,
                    // e.g. removed by an aggregate or projection
                    self.stage.pop();
                    self.expand_on_next_part_cond_trans_commutative = false;
                    self.column_requirements.clear();
                }
            }
            Commutativity::PartialCommutative => {
                if let Some(plan) = partial_commutative_transformer(plan) {
                    // notice this plan is the parent of the current node, hence `self.level - 1` when updating column requirements
                    self.update_column_requirements(&plan, self.level - 1);
                    self.expand_on_next_part_cond_trans_commutative = true;
                    self.stage.push(plan)
                }
            }
            Commutativity::ConditionalCommutative(transformer) => {
                if let Some(transformer) = transformer
                    && let Some(plan) = transformer(plan)
                {
                    // notice this plan is the parent of the current node, hence `self.level - 1` when updating column requirements
                    self.update_column_requirements(&plan, self.level - 1);
                    self.expand_on_next_part_cond_trans_commutative = true;
                    self.stage.push(plan)
                }
            }
            Commutativity::TransformedCommutative { transformer } => {
                if let Some(transformer) = transformer {
                    let transformer_actions = transformer(plan)?;
                    debug!(
                        "PlanRewriter: transformed plan: {}\n from {plan}",
                        transformer_actions
                            .extra_parent_plans
                            .iter()
                            .enumerate()
                            .map(|(i, p)| format!(
                                "Extra {i}-th parent plan from parent to child = {}",
                                p.display()
                            ))
                            .collect::<Vec<_>>()
                            .join("\n")
                    );
                    if let Some(new_child_plan) = &transformer_actions.new_child_plan {
                        debug!("PlanRewriter: new child plan: {}", new_child_plan);
                    }
                    if let Some(last_stage) = transformer_actions.extra_parent_plans.last() {
                        // update the column requirements from the last stage;
                        // notice the current plan's parent plan is where we need to apply the column requirements
                        self.update_column_requirements(last_stage, self.level - 1);
                    }
                    self.stage
                        .extend(transformer_actions.extra_parent_plans.into_iter().rev());
                    self.expand_on_next_call = true;
                    self.new_child_plan = transformer_actions.new_child_plan;
                }
            }
            Commutativity::NonCommutative
            | Commutativity::Unimplemented
            | Commutativity::Unsupported => {
                debug!("PlanRewriter: meet a non-commutative plan, expanding now: {plan}");
                return Ok(true);
            }
        }

        Ok(false)
    }

    /// Update the column requirements for the given plan; `plan_level` is the level of the plan
    /// in the stack, which is used to determine whether the column requirements are applicable
    /// to other plans in the stack.
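    ///
    /// E.g. (a sketch) for a staged `Sort: t.pk1 + t.pk2` at level 2, this records
    /// the columns collected by [`expr_to_columns`]:
    /// ```ignore
    /// column_requirements.push(({t.pk1, t.pk2}, 2));
    /// ```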
    fn update_column_requirements(&mut self, plan: &LogicalPlan, plan_level: usize) {
        debug!(
            "PlanRewriter: update column requirements for plan: {plan}\n with old column_requirements: {:?}",
            self.column_requirements
        );
        let mut container = HashSet::new();
        for expr in plan.expressions() {
            // this method won't fail
            let _ = expr_to_columns(&expr, &mut container);
        }

        self.column_requirements.push((container, plan_level));
        debug!(
            "PlanRewriter: updated column requirements: {:?}",
            self.column_requirements
        );
    }

    fn is_expanded(&self) -> bool {
        self.status == RewriterStatus::Expanded
    }

    fn set_expanded(&mut self) {
        self.status = RewriterStatus::Expanded;
    }

    fn set_unexpanded(&mut self) {
        self.status = RewriterStatus::Unexpanded;
    }

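    /// Track partition columns while ascending.
    ///
    /// On the first `TableScan` of a base table this captures the table's
    /// partition columns; on every node above it, each alias set is remapped
    /// through that node's own aliases. A sketch (illustrative names):
    /// ```ignore
    /// TableScan: t               => partition_cols = { "pk1": {t.pk1} }
    /// Projection: t.pk1 AS a     => partition_cols = { "pk1": {a} }
    /// ```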
    fn maybe_set_partitions(&mut self, plan: &LogicalPlan) -> DfResult<()> {
        if let Some(part_cols) = &mut self.partition_cols {
            // update partition aliases
            let child = plan.inputs().first().cloned().ok_or_else(|| {
                datafusion_common::DataFusionError::Internal(format!(
                    "PlanRewriter: maybe_set_partitions: plan has no child: {plan}"
                ))
            })?;

            for (_col_name, alias_set) in part_cols.iter_mut() {
                let aliased_cols = aliased_columns_for(
                    &alias_set.clone().into_iter().collect(),
                    plan,
                    Some(child),
                )?;
                *alias_set = aliased_cols.into_values().flatten().collect();
            }

            debug!(
                "PlanRewriter: maybe_set_partitions: updated partition columns: {:?} at plan: {}",
                part_cols,
                plan.display()
            );

            return Ok(());
        }

        if let LogicalPlan::TableScan(table_scan) = plan
            && let Some(source) = table_scan
                .source
                .as_any()
                .downcast_ref::<DefaultTableSource>()
            && let Some(provider) = source
                .table_provider
                .as_any()
                .downcast_ref::<DfTableProviderAdapter>()
        {
            let table = provider.table();
            if table.table_type() == TableType::Base {
                let info = table.table_info();
                let partition_key_indices = info.meta.partition_key_indices.clone();
                let schema = info.meta.schema.clone();
                let mut partition_cols = partition_key_indices
                    .into_iter()
                    .map(|index| schema.column_name_by_index(index).to_string())
                    .collect::<Vec<String>>();

                let partition_rules = table.partition_rules();
                let exist_phy_part_cols_not_in_logical_table = partition_rules
                    .map(|r| !r.extra_phy_cols_not_in_logical_table.is_empty())
                    .unwrap_or(false);

                if exist_phy_part_cols_not_in_logical_table && partition_cols.is_empty() {
                    // There are physical partition columns that are not in the logical table while
                    // the logical partition columns are empty, so we add a placeholder to make sure
                    // the final partition columns (the ones the optimizer sees) are not empty,
                    // preventing certain optimizations. Notice that if `partition_cols` is originally
                    // non-empty there is no need for the placeholder, since a subset of the physical
                    // partition columns can still be used for those optimizations, working as if the
                    // missing columns were always null. This distinguishes a non-partitioned table
                    // from a partitioned table whose physical partition columns are all absent from
                    // the logical table.
                    partition_cols.push(OTHER_PHY_PART_COL_PLACEHOLDER.to_string());
                }
                self.partition_cols = Some(
                    partition_cols
                        .into_iter()
                        .map(|c| {
                            if c == OTHER_PHY_PART_COL_PLACEHOLDER {
                                // for the placeholder, just return an empty alias set
                                return Ok((c.clone(), BTreeSet::new()));
                            }
                            let index = if let Some(c) =
                                plan.schema().index_of_column_by_name(None, &c)
                            {
                                c
                            } else {
                                // the `projection` field of `TableScan` doesn't contain the partition columns;
                                // this is similar to not having an alias, hence return an empty alias set
                                return Ok((c.clone(), BTreeSet::new()));
                            };
                            let column =
                                plan.schema().columns().get(index).cloned().ok_or_else(|| {
                                    datafusion_common::DataFusionError::Internal(format!(
                                        "PlanRewriter: maybe_set_partitions: column index {index} out of bounds in schema of plan: {plan}"
                                    ))
                                })?;
                            Ok((c.clone(), BTreeSet::from([column])))
                        })
                        .collect::<DfResult<AliasMapping>>()?,
                );
            }
        }

        Ok(())
    }

    /// Pop one stack item and decrease the level by 1
    fn pop_stack(&mut self) {
        self.level -= 1;
        self.stack.pop();
    }

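    /// Expand the accumulated [`Self::stage`] plans on top of a new
    /// [`MergeScanLogicalPlan`] wrapping `on_node`, then re-project to the
    /// schema recorded before the expansion.
    ///
    /// A sketch of the result for a staged `MergeSort` (illustrative):
    /// ```ignore
    /// Projection: <original output columns>
    ///   MergeSort: t.pk1 ASC
    ///     MergeScan [remote input = on_node]
    /// ```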
    fn expand(&mut self, mut on_node: LogicalPlan) -> DfResult<LogicalPlan> {
        // store the schema before expanding; the new child plan might have a different schema, so don't use it
        let schema = on_node.schema().clone();
        if let Some(new_child_plan) = self.new_child_plan.take() {
            // if there is a new child plan, use it as the new root
            on_node = new_child_plan;
        }
        let mut rewriter = EnforceDistRequirementRewriter::new(
            std::mem::take(&mut self.column_requirements),
            self.level,
        );
        debug!(
            "PlanRewriter: enforce column requirements for node: {on_node} with rewriter: {rewriter:?}"
        );
        on_node = on_node.rewrite(&mut rewriter)?.data;
        debug!(
            "PlanRewriter: after enforced column requirements with rewriter: {rewriter:?} for node:\n{on_node}"
        );

        debug!(
            "PlanRewriter: expand on node: {on_node} with partition col alias mapping: {:?}",
            self.partition_cols
        );

        // add merge scan as the new root
        let mut node = MergeScanLogicalPlan::new(
            on_node.clone(),
            false,
            // at this stage, the partition cols should be set;
            // treat it as non-partitioned if None
            self.partition_cols.clone().unwrap_or_default(),
        )
        .into_logical_plan();

        // expand stages
        for new_stage in self.stage.drain(..) {
            // track aliases for merge sort's sort exprs
            let new_stage = if let LogicalPlan::Extension(ext) = &new_stage
                && let Some(merge_sort) = ext.node.as_any().downcast_ref::<MergeSortLogicalPlan>()
            {
                // TODO(discord9): change `on_node` to `node` once alias tracking is supported for merge scan
                rewrite_merge_sort_exprs(merge_sort, &on_node)?
            } else {
                new_stage
            };
            node = new_stage
                .with_new_exprs(new_stage.expressions_consider_join(), vec![node.clone()])?;
        }
        self.set_expanded();

        // recover the schema; this makes sure the schema after expansion is the same as the old node's,
        // because after expansion the raw top node might have extra columns, i.e. sorting columns for a `Sort` node
        let node = LogicalPlanBuilder::from(node)
            .project(schema.iter().map(|(qualifier, field)| {
                Expr::Column(Column::new(qualifier.cloned(), field.name()))
            }))?
            .build()?;

        Ok(node)
    }
}

/// Implementation of the [`TreeNodeRewriter`] trait, responsible for rewriting
/// logical plans to enforce various requirements for distributed queries.
///
/// Requirements enforced by this rewriter:
/// - Enforce column requirements for `LogicalPlan::Projection` nodes. Makes sure the
///   required columns are available in the sub plan.
///
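/// A sketch of the projection fix-up (assuming `t.pk1` is required by a staged
/// `MergeSort` above this subtree; plan shapes are illustrative):
/// ```ignore
/// Projection: t.number            Projection: t.number, t.pk1
///   Sort: t.pk1 ASC        =>       Sort: t.pk1 ASC
///     TableScan: t                    TableScan: t
/// ```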
#[derive(Debug)]
struct EnforceDistRequirementRewriter {
    /// Only enforce column requirements after the expanding node in question,
    /// meaning only nodes with `cur_level` <= `level` will consider adding those column requirements.
    /// TODO(discord9): a simpler solution to track column requirements for merge scan
    column_requirements: Vec<(HashSet<Column>, usize)>,
    /// Only apply column requirements with level >= `cur_level`.
    /// This is used to avoid applying column requirements that are not needed
    /// for the current node, i.e. when the node is not in the scope of the column requirements.
    /// E.g., for this plan:
    /// ```ignore
    /// Aggregate: min(t.number)
    ///   Projection: t.number
    /// ```
    /// when on the `Projection` node, we don't need to apply the column requirements of the `Aggregate` node,
    /// because the `Projection` node is not in the scope of the `Aggregate` node.
    cur_level: usize,
    plan_per_level: BTreeMap<usize, LogicalPlan>,
}

impl EnforceDistRequirementRewriter {
    fn new(column_requirements: Vec<(HashSet<Column>, usize)>, cur_level: usize) -> Self {
        debug!(
            "Create EnforceDistRequirementRewriter with column_requirements: {:?} at cur_level: {}",
            column_requirements, cur_level
        );
        Self {
            column_requirements,
            cur_level,
            plan_per_level: BTreeMap::new(),
        }
    }

    /// Return a mapping from (original column, level) to aliased columns in the current node for all
    /// applicable column requirements,
    /// i.e. only column requirements with level >= `cur_level` will be considered.
    fn get_current_applicable_column_requirements(
        &self,
        node: &LogicalPlan,
    ) -> DfResult<BTreeMap<(Column, usize), BTreeSet<Column>>> {
        let col_req_per_level = self
            .column_requirements
            .iter()
            .filter(|(_, level)| *level >= self.cur_level)
            .collect::<Vec<_>>();

        // track aliases for columns and use the aliased columns instead
        // (aliased column requirements at the current level)
        let mut result_alias_mapping = BTreeMap::new();
        let Some(child) = node.inputs().first().cloned() else {
            return Ok(Default::default());
        };
        for (col_req, level) in col_req_per_level {
            if let Some(original) = self.plan_per_level.get(level) {
                // query for aliases in the current plan
                let aliased_cols =
                    aliased_columns_for(&col_req.iter().cloned().collect(), node, Some(original))?;
                for original_col in col_req {
                    let aliased_cols = aliased_cols.get(original_col).cloned();
                    if let Some(cols) = aliased_cols
                        && !cols.is_empty()
                    {
                        result_alias_mapping.insert((original_col.clone(), *level), cols);
                    } else {
                        // If no aliased column is found in the current node, there should be an alias
                        // in the child node, as promised by the enforced column requirements
                        // (required columns are inserted into the child node), so we can find the
                        // alias there. If it's not found, that's an internal error.
                        let aliases_in_child = aliased_columns_for(
                            &[original_col.clone()].into(),
                            child,
                            Some(original),
                        )?;
                        let Some(aliases) = aliases_in_child
                            .get(original_col)
                            .cloned()
                            .filter(|a| !a.is_empty())
                        else {
                            return Err(datafusion_common::DataFusionError::Internal(format!(
                                "EnforceDistRequirementRewriter: no alias found for required column {original_col} at level {level} in current node's child plan: \n{child} from original plan: \n{original}",
                            )));
                        };

                        result_alias_mapping.insert((original_col.clone(), *level), aliases);
                    }
                }
            }
        }
        Ok(result_alias_mapping)
    }
}

impl TreeNodeRewriter for EnforceDistRequirementRewriter {
    type Node = LogicalPlan;

    fn f_down(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        // check that the node doesn't have multiple children, i.e. join/subquery
        if node.inputs().len() > 1 {
            return Err(datafusion_common::DataFusionError::Internal(
                "EnforceDistRequirementRewriter: node with multiple inputs is not supported"
                    .to_string(),
            ));
        }
        self.plan_per_level.insert(self.cur_level, node.clone());
        self.cur_level += 1;
        Ok(Transformed::no(node))
    }

    fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        self.cur_level -= 1;
        // first get all applicable column requirements,
        // then make sure every projection in the applicable scope has the required columns
        if let LogicalPlan::Projection(ref projection) = node {
            let mut applicable_column_requirements =
                self.get_current_applicable_column_requirements(&node)?;

            debug!(
                "EnforceDistRequirementRewriter: applicable column requirements at level {} = {:?} for node {}",
                self.cur_level,
                applicable_column_requirements,
                node.display()
            );

            for expr in &projection.expr {
                let (qualifier, name) = expr.qualified_name();
                let column = Column::new(qualifier, name);
                applicable_column_requirements.retain(|_col_level, alias_set| {
                    // remove all columns that are already in the projection exprs
                    !alias_set.contains(&column)
                });
            }
            if applicable_column_requirements.is_empty() {
                return Ok(Transformed::no(node));
            }

            let mut new_exprs = projection.expr.clone();
            for (col, alias_set) in &applicable_column_requirements {
                // use the first alias in the alias set as the column to add
                new_exprs.push(Expr::Column(alias_set.first().cloned().ok_or_else(
                    || {
                        datafusion_common::DataFusionError::Internal(
                            format!("EnforceDistRequirementRewriter: alias set is empty, for column {col:?} in node {node}"),
                        )
                    },
                )?));
            }
            let new_node =
                node.with_new_exprs(new_exprs, node.inputs().into_iter().cloned().collect())?;
            debug!(
                "EnforceDistRequirementRewriter: added missing columns {:?} to projection node from old node: \n{node}\n Making new node: \n{new_node}",
                applicable_column_requirements
            );

            // update the plan for later use
            self.plan_per_level.insert(self.cur_level, new_node.clone());

            // still need to continue for the next projection if applicable
            return Ok(Transformed::yes(new_node));
        }
        Ok(Transformed::no(node))
    }
}

impl TreeNodeRewriter for PlanRewriter {
    type Node = LogicalPlan;

    /// descend
    fn f_down(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        self.level += 1;
        self.stack.push((node.clone(), self.level));
        // descending clears the stage
        self.stage.clear();
        self.set_unexpanded();
        self.partition_cols = None;
        Ok(Transformed::no(node))
    }

    /// ascend
    ///
    /// Be sure to call `pop_stack` before returning
    fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        // only expand once on each ascent
        if self.is_expanded() {
            self.pop_stack();
            return Ok(Transformed::no(node));
        }

        // only expand when the leaf is a table scan
        if node.inputs().is_empty() && !matches!(node, LogicalPlan::TableScan(_)) {
            self.set_expanded();
            self.pop_stack();
            return Ok(Transformed::no(node));
        }

        self.maybe_set_partitions(&node)?;

        let Some(parent) = self.get_parent() else {
            debug!("PlanRewriter: expand now since no parent is found for node: {node}");
            let node = self.expand(node);
            debug!(
                "PlanRewriter: expanded plan: {}",
                match &node {
                    Ok(n) => n.to_string(),
                    Err(e) => format!("Error expanding plan: {e}"),
                }
            );
            let node = node?;
            self.pop_stack();
            return Ok(Transformed::yes(node));
        };

        let parent = parent.clone();

        if self.should_expand(&parent)? {
            // TODO(ruihang): does this work for nodes with multiple children?
            debug!(
                "PlanRewriter: should expand child:\n {node}\n Of Parent: {}",
                parent.display()
            );
            let node = self.expand(node);
            debug!(
                "PlanRewriter: expanded plan: {}",
                match &node {
                    Ok(n) => n.to_string(),
                    Err(e) => format!("Error expanding plan: {e}"),
                }
            );
            let node = node?;
            self.pop_stack();
            return Ok(Transformed::yes(node));
        }

        self.pop_stack();
        Ok(Transformed::no(node))
    }
}