query/dist_plan/
analyzer.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::{BTreeMap, BTreeSet, HashSet};
use std::sync::Arc;

use chrono::{DateTime, Utc};
use common_telemetry::debug;
use datafusion::config::{ConfigExtension, ExtensionOptions};
use datafusion::datasource::DefaultTableSource;
use datafusion::error::Result as DfResult;
use datafusion_common::Column;
use datafusion_common::alias::AliasGenerator;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
use datafusion_expr::expr::{Exists, InSubquery};
use datafusion_expr::utils::expr_to_columns;
use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder, Subquery, col as col_fn};
use datafusion_optimizer::analyzer::AnalyzerRule;
use datafusion_optimizer::simplify_expressions::SimplifyExpressions;
use datafusion_optimizer::{OptimizerConfig, OptimizerRule};
use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
use table::metadata::TableType;
use table::table::adapter::DfTableProviderAdapter;

use crate::dist_plan::analyzer::utils::{aliased_columns_for, rewrite_merge_sort_exprs};
use crate::dist_plan::commutativity::{
    Categorizer, Commutativity, partial_commutative_transformer,
};
use crate::dist_plan::merge_scan::MergeScanLogicalPlan;
use crate::dist_plan::merge_sort::MergeSortLogicalPlan;
use crate::metrics::PUSH_DOWN_FALLBACK_ERRORS_TOTAL;
use crate::plan::ExtractExpr;
use crate::query_engine::DefaultSerializer;

#[cfg(test)]
mod test;

mod fallback;
mod utils;

pub(crate) use utils::AliasMapping;

/// Placeholder for other physical partition columns that are not in the logical table
const OTHER_PHY_PART_COL_PLACEHOLDER: &str = "__OTHER_PHYSICAL_PART_COLS_PLACEHOLDER__";

#[derive(Debug, Clone)]
pub struct DistPlannerOptions {
    pub allow_query_fallback: bool,
}

impl ConfigExtension for DistPlannerOptions {
    const PREFIX: &'static str = "dist_planner";
}

impl ExtensionOptions for DistPlannerOptions {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

    fn cloned(&self) -> Box<dyn ExtensionOptions> {
        Box::new(self.clone())
    }

    fn set(&mut self, key: &str, value: &str) -> DfResult<()> {
        Err(datafusion_common::DataFusionError::NotImplemented(format!(
            "DistPlannerOptions does not support setting key: {key} with value: {value}"
        )))
    }

    fn entries(&self) -> Vec<datafusion::config::ConfigEntry> {
        vec![datafusion::config::ConfigEntry {
            key: "allow_query_fallback".to_string(),
            value: Some(self.allow_query_fallback.to_string()),
            description: "Allow queries to fall back to the fallback plan rewriter when push-down fails",
        }]
    }
}
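
// A hedged sketch (assumption: caller-side code, not part of this module's API):
// these options reach `DistPlannerAnalyzer::analyze` through the session's
// `ConfigOptions` extensions, roughly like
//
// ```ignore
// let mut config = ConfigOptions::new();
// config.extensions.insert(DistPlannerOptions {
//     allow_query_fallback: true,
// });
// // `analyze` later reads it back:
// let allow = config
//     .extensions
//     .get::<DistPlannerOptions>()
//     .map(|o| o.allow_query_fallback)
//     .unwrap_or(false);
// ```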

#[derive(Debug)]
pub struct DistPlannerAnalyzer;

impl AnalyzerRule for DistPlannerAnalyzer {
    fn name(&self) -> &str {
        "DistPlannerAnalyzer"
    }

    fn analyze(
        &self,
        plan: LogicalPlan,
        config: &ConfigOptions,
    ) -> datafusion_common::Result<LogicalPlan> {
        let mut config = config.clone();
        // Aligned with the behavior in `datafusion_optimizer::OptimizerContext::new()`.
        config.optimizer.filter_null_join_keys = true;
        let config = Arc::new(config);

        // The `ConstEvaluator` in `SimplifyExpressions` might evaluate some UDFs early in the
        // planning stage, by executing them directly. For example, the `database()` function.
        // So the `ConfigOptions` here (which is set from the session context) should be present
        // in the UDF's `ScalarFunctionArgs`. However, the default implementation in DataFusion
        // seems to lose track of it: the `ConfigOptions` is recreated with its default values again.
        // So we create a custom `OptimizerConfig` with the desired `ConfigOptions`
        // to work around the issue.
        // TODO(LFC): Maybe use DataFusion's `OptimizerContext` again
        //   once https://github.com/apache/datafusion/pull/17742 is merged.
        struct OptimizerContext {
            inner: datafusion_optimizer::OptimizerContext,
            config: Arc<ConfigOptions>,
        }

        impl OptimizerConfig for OptimizerContext {
            fn query_execution_start_time(&self) -> DateTime<Utc> {
                self.inner.query_execution_start_time()
            }

            fn alias_generator(&self) -> &Arc<AliasGenerator> {
                self.inner.alias_generator()
            }

            fn options(&self) -> Arc<ConfigOptions> {
                self.config.clone()
            }
        }

        let optimizer_context = OptimizerContext {
            inner: datafusion_optimizer::OptimizerContext::new(),
            config: config.clone(),
        };

        let plan = SimplifyExpressions::new()
            .rewrite(plan, &optimizer_context)?
            .data;

        let opt = config.extensions.get::<DistPlannerOptions>();
        let allow_fallback = opt.map(|o| o.allow_query_fallback).unwrap_or(false);

        let result = match self.try_push_down(plan.clone()) {
            Ok(plan) => plan,
            Err(err) => {
                if allow_fallback {
                    common_telemetry::warn!(err; "Failed to push down plan, using fallback plan rewriter for plan: {plan}");
                    // if push-down failed, use the fallback plan rewriter
                    PUSH_DOWN_FALLBACK_ERRORS_TOTAL.inc();
                    self.use_fallback(plan)?
                } else {
                    return Err(err);
                }
            }
        };

        Ok(result)
    }
}
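
// Hedged usage sketch (assumption: direct invocation, outside the query engine's
// analyzer pipeline where this rule is normally registered):
//
// ```ignore
// let analyzed = DistPlannerAnalyzer.analyze(logical_plan, &config_options)?;
// ```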

impl DistPlannerAnalyzer {
    /// Try to push down as many nodes as possible
    fn try_push_down(&self, plan: LogicalPlan) -> DfResult<LogicalPlan> {
        let plan = plan.transform(&Self::inspect_plan_with_subquery)?;
        let mut rewriter = PlanRewriter::default();
        let result = plan.data.rewrite(&mut rewriter)?.data;
        Ok(result)
    }

    /// Use the fallback plan rewriter to rewrite the plan, pushing down only table scan nodes
    fn use_fallback(&self, plan: LogicalPlan) -> DfResult<LogicalPlan> {
        let mut rewriter = fallback::FallbackPlanRewriter;
        let result = plan.rewrite(&mut rewriter)?.data;
        Ok(result)
    }

    fn inspect_plan_with_subquery(plan: LogicalPlan) -> DfResult<Transformed<LogicalPlan>> {
        // Workaround for https://github.com/GreptimeTeam/greptimedb/issues/5469 and https://github.com/GreptimeTeam/greptimedb/issues/5799
        // FIXME(yingwen): Remove the `Limit` plan once we update DataFusion.
        if let LogicalPlan::Limit(_) | LogicalPlan::Distinct(_) = &plan {
            return Ok(Transformed::no(plan));
        }

        let exprs = plan
            .expressions_consider_join()
            .into_iter()
            .map(|e| e.transform(&Self::transform_subquery).map(|x| x.data))
            .collect::<DfResult<Vec<_>>>()?;

        // Some plans are specially treated (we should not call `with_new_exprs` on them)
        if !matches!(plan, LogicalPlan::Unnest(_)) {
            let inputs = plan.inputs().into_iter().cloned().collect::<Vec<_>>();
            Ok(Transformed::yes(plan.with_new_exprs(exprs, inputs)?))
        } else {
            Ok(Transformed::no(plan))
        }
    }

    fn transform_subquery(expr: Expr) -> DfResult<Transformed<Expr>> {
        match expr {
            Expr::Exists(exists) => Ok(Transformed::yes(Expr::Exists(Exists {
                subquery: Self::handle_subquery(exists.subquery)?,
                negated: exists.negated,
            }))),
            Expr::InSubquery(in_subquery) => Ok(Transformed::yes(Expr::InSubquery(InSubquery {
                expr: in_subquery.expr,
                subquery: Self::handle_subquery(in_subquery.subquery)?,
                negated: in_subquery.negated,
            }))),
            Expr::ScalarSubquery(scalar_subquery) => Ok(Transformed::yes(Expr::ScalarSubquery(
                Self::handle_subquery(scalar_subquery)?,
            ))),

            _ => Ok(Transformed::no(expr)),
        }
    }

    fn handle_subquery(subquery: Subquery) -> DfResult<Subquery> {
        let mut rewriter = PlanRewriter::default();
        let mut rewrote_subquery = subquery
            .subquery
            .as_ref()
            .clone()
            .rewrite(&mut rewriter)?
            .data;
        // Workaround: DataFusion doesn't support the first plan in a subquery being an Extension
        if matches!(rewrote_subquery, LogicalPlan::Extension(_)) {
            let output_schema = rewrote_subquery.schema().clone();
            let project_exprs = output_schema
                .fields()
                .iter()
                .map(|f| col_fn(f.name()))
                .collect::<Vec<_>>();
            rewrote_subquery = LogicalPlanBuilder::from(rewrote_subquery)
                .project(project_exprs)?
                .build()?;
        }

        Ok(Subquery {
            subquery: Arc::new(rewrote_subquery),
            outer_ref_columns: subquery.outer_ref_columns,
            spans: Default::default(),
        })
    }
}
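
// Illustrative only (assumed plan shapes): the projection workaround in
// `handle_subquery` turns a subquery whose root is an extension node into one
// rooted at a plain projection over the same columns:
//
//   Subquery:                        Subquery:
//     MergeScan [extension]    =>      Projection: a, b
//       ...                              MergeScan [extension]
//                                          ...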

/// Status of the rewriter, marking whether the current pass is expanded
#[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
enum RewriterStatus {
    #[default]
    Unexpanded,
    Expanded,
}

#[derive(Debug, Default)]
struct PlanRewriter {
    /// Current level in the tree
    level: usize,
    /// Simulated stack for the `rewrite` recursion
    stack: Vec<(LogicalPlan, usize)>,
    /// Stages to be expanded, will be added as parent nodes of the merge scan one by one
    stage: Vec<LogicalPlan>,
    status: RewriterStatus,
    /// Partition columns of the table in the current pass
    partition_cols: Option<AliasMapping>,
    /// Use the stack depth as the scope to determine whether a column requirement is needed,
    /// i.e. for a logical plan like:
    /// ```ignore
    /// 1: Projection: t.number
    /// 2: Sort: t.pk1+t.pk2
    /// 3: Projection: t.number, t.pk1, t.pk2
    /// ```
    /// `Sort` creates a column requirement for `t.pk1` at level 2,
    /// which makes the `Projection` at level 1 need to add a reference to `t.pk1` as well.
    /// The expanded plan will then be
    /// ```ignore
    /// Projection: t.number
    ///   MergeSort: t.pk1
    ///     MergeScan: remote_input=
    /// Projection: t.number, "t.pk1+t.pk2" <--- the original `Projection` at level 1 gets `t.pk1+t.pk2` added
    ///  Sort: t.pk1+t.pk2
    ///    Projection: t.number, t.pk1, t.pk2
    /// ```
    /// so that `MergeSort` can have `t.pk1` as input.
    /// Meanwhile the `Projection` at level 3 doesn't need to add any new column, because 3 > 2
    /// and column requirements at level 2 are not applicable to level 3.
    ///
    /// See more details in tests `expand_proj_step_aggr` and `expand_proj_sort_proj`.
    ///
    /// TODO(discord9): a simpler solution to track column requirements for merge scan
    column_requirements: Vec<(HashSet<Column>, usize)>,
    /// Whether to expand on the next call.
    /// This is used to handle the case where a plan is transformed but needs to be expanded
    /// from its parent node. For example, an Aggregate plan is split into two parts, one in the
    /// frontend and one on the datanode, and needs to be expanded from the parent node of the
    /// Aggregate plan.
    expand_on_next_call: bool,
    /// Expand on the next partial/conditional/transformed commutative plan.
    /// This is used to handle the case where a plan is transformed, but we still
    /// need to push down as many nodes as possible before the next partial/conditional/transformed
    /// commutative plan, i.e.
    /// ```ignore
    /// Limit:
    ///     Sort:
    /// ```
    /// where `Limit` is partial commutative and `Sort` is conditional commutative.
    /// In this case, we need to expand the `Limit` plan,
    /// so that we can push down the `Sort` plan as much as possible.
    expand_on_next_part_cond_trans_commutative: bool,
    new_child_plan: Option<LogicalPlan>,
}

impl PlanRewriter {
    fn get_parent(&self) -> Option<&LogicalPlan> {
        // level starts from 1, so it's safe to subtract 1
        self.stack
            .iter()
            .rev()
            .find(|(_, level)| *level == self.level - 1)
            .map(|(node, _)| node)
    }

    /// Return true if we should stop and expand. The input plan is the parent node of the current node
    fn should_expand(&mut self, plan: &LogicalPlan) -> DfResult<bool> {
        debug!(
            "Check should_expand at level: {} with stack:\n{}",
            self.level,
            self.stack
                .iter()
                .map(|(p, l)| format!("{l}:{}{}", "  ".repeat(l - 1), p.display()))
                .collect::<Vec<String>>()
                .join("\n"),
        );
        if let Err(e) = DFLogicalSubstraitConvertor.encode(plan, DefaultSerializer) {
            debug!(
                "PlanRewriter: plan cannot be converted to substrait with error={e:?}, expanding now: {plan}"
            );
            return Ok(true);
        }

        if self.expand_on_next_call {
            self.expand_on_next_call = false;
            debug!("PlanRewriter: expand_on_next_call is true, expanding now");
            return Ok(true);
        }

        if self.expand_on_next_part_cond_trans_commutative {
            let comm = Categorizer::check_plan(plan, self.partition_cols.clone())?;
            match comm {
                Commutativity::PartialCommutative => {
                    // a small difference is that for a partial commutative plan, we still need to
                    // push it down (so `Limit` can be pushed down)

                    // notice that `Limit` needs to be expanded as well to make sure the query is correct,
                    // i.e. `Limit fetch=10` needs to be pushed down to the leaf node
                    self.expand_on_next_part_cond_trans_commutative = false;
                    self.expand_on_next_call = true;
                }
                Commutativity::ConditionalCommutative(_)
                | Commutativity::TransformedCommutative { .. } => {
                    // again a new node that can be pushed down; just push down now
                    // and avoid further expansion
                    self.expand_on_next_part_cond_trans_commutative = false;
                    debug!(
                        "PlanRewriter: meet a new conditional/transformed commutative plan, expanding now: {plan}"
                    );
                    return Ok(true);
                }
                _ => (),
            }
        }

        match Categorizer::check_plan(plan, self.partition_cols.clone())? {
            Commutativity::Commutative => {}
            Commutativity::PartialCommutative => {
                if let Some(plan) = partial_commutative_transformer(plan) {
                    // notice this plan is the parent of the current node, hence `self.level - 1` when updating column requirements
                    self.update_column_requirements(&plan, self.level - 1);
                    self.expand_on_next_part_cond_trans_commutative = true;
                    self.stage.push(plan)
                }
            }
            Commutativity::ConditionalCommutative(transformer) => {
                if let Some(transformer) = transformer
                    && let Some(plan) = transformer(plan)
                {
                    // notice this plan is the parent of the current node, hence `self.level - 1` when updating column requirements
                    self.update_column_requirements(&plan, self.level - 1);
                    self.expand_on_next_part_cond_trans_commutative = true;
                    self.stage.push(plan)
                }
            }
            Commutativity::TransformedCommutative { transformer } => {
                if let Some(transformer) = transformer {
                    let transformer_actions = transformer(plan)?;
                    debug!(
                        "PlanRewriter: transformed plan: {}\n from {plan}",
                        transformer_actions
                            .extra_parent_plans
                            .iter()
                            .enumerate()
                            .map(|(i, p)| format!(
                                "Extra {i}-th parent plan from parent to child = {}",
                                p.display()
                            ))
                            .collect::<Vec<_>>()
                            .join("\n")
                    );
                    if let Some(new_child_plan) = &transformer_actions.new_child_plan {
                        debug!("PlanRewriter: new child plan: {}", new_child_plan);
                    }
                    if let Some(last_stage) = transformer_actions.extra_parent_plans.last() {
                        // update the column requirements from the last stage;
                        // notice the current plan's parent plan is where we need to apply the column requirements
                        self.update_column_requirements(last_stage, self.level - 1);
                    }
                    self.stage
                        .extend(transformer_actions.extra_parent_plans.into_iter().rev());
                    self.expand_on_next_call = true;
                    self.new_child_plan = transformer_actions.new_child_plan;
                }
            }
            Commutativity::NonCommutative
            | Commutativity::Unimplemented
            | Commutativity::Unsupported => {
                debug!("PlanRewriter: meet a non-commutative plan, expanding now: {plan}");
                return Ok(true);
            }
        }

        Ok(false)
    }
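
    // Illustrative only (assumed plan shapes): for a partial commutative `Limit`
    // over a conditional commutative `Sort`, the logic above stages copies so the
    // limit and sort also run on the datanodes, while merged results are re-limited
    // in the frontend:
    //
    //   Limit: fetch=10              Limit: fetch=10
    //     Sort: t.pk1          =>      MergeSort: t.pk1
    //       TableScan: t                 MergeScan [remote: Limit, Sort, TableScan]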

    /// Update the column requirements for the current plan; `plan_level` is the level of the plan
    /// in the stack, which is used to determine whether the column requirements are applicable
    /// to other plans in the stack.
    fn update_column_requirements(&mut self, plan: &LogicalPlan, plan_level: usize) {
        debug!(
            "PlanRewriter: update column requirements for plan: {plan}\n with old column_requirements: {:?}",
            self.column_requirements
        );
        let mut container = HashSet::new();
        for expr in plan.expressions() {
            // this method won't fail
            let _ = expr_to_columns(&expr, &mut container);
        }

        self.column_requirements.push((container, plan_level));
        debug!(
            "PlanRewriter: updated column requirements: {:?}",
            self.column_requirements
        );
    }
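
    // Illustrative note for `update_column_requirements` (assumed plan): for a
    // staged plan `Sort: t.pk1 + t.pk2` whose parent sits at level 2,
    // `expr_to_columns` collects {t.pk1, t.pk2}, so the entry pushed is
    // ({t.pk1, t.pk2}, 2).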

    fn is_expanded(&self) -> bool {
        self.status == RewriterStatus::Expanded
    }

    fn set_expanded(&mut self) {
        self.status = RewriterStatus::Expanded;
    }

    fn set_unexpanded(&mut self) {
        self.status = RewriterStatus::Unexpanded;
    }

    fn maybe_set_partitions(&mut self, plan: &LogicalPlan) -> DfResult<()> {
        if let Some(part_cols) = &mut self.partition_cols {
            // update partition aliases
            let child = plan.inputs().first().cloned().ok_or_else(|| {
                datafusion_common::DataFusionError::Internal(format!(
                    "PlanRewriter: maybe_set_partitions: plan has no child: {plan}"
                ))
            })?;

            for (_col_name, alias_set) in part_cols.iter_mut() {
                let aliased_cols = aliased_columns_for(
                    &alias_set.clone().into_iter().collect(),
                    plan,
                    Some(child),
                )?;
                *alias_set = aliased_cols.into_values().flatten().collect();
            }

            debug!(
                "PlanRewriter: maybe_set_partitions: updated partition columns: {:?} at plan: {}",
                part_cols,
                plan.display()
            );

            return Ok(());
        }

        if let LogicalPlan::TableScan(table_scan) = plan
            && let Some(source) = table_scan
                .source
                .as_any()
                .downcast_ref::<DefaultTableSource>()
            && let Some(provider) = source
                .table_provider
                .as_any()
                .downcast_ref::<DfTableProviderAdapter>()
        {
            let table = provider.table();
            if table.table_type() == TableType::Base {
                let info = table.table_info();
                let partition_key_indices = info.meta.partition_key_indices.clone();
                let schema = info.meta.schema.clone();
                let mut partition_cols = partition_key_indices
                    .into_iter()
                    .map(|index| schema.column_name_by_index(index).to_string())
                    .collect::<Vec<String>>();

                let partition_rules = table.partition_rules();
                let exist_phy_part_cols_not_in_logical_table = partition_rules
                    .map(|r| !r.extra_phy_cols_not_in_logical_table.is_empty())
                    .unwrap_or(false);

                if exist_phy_part_cols_not_in_logical_table && partition_cols.is_empty() {
                    // There are physical partition columns that are not in the logical table, and the
                    // partition columns are empty, so we add a placeholder to prevent certain optimizations.
                    // This makes sure the final partition columns (as the optimizer sees them) are not empty.
                    // Notice that if `partition_cols` is originally non-empty, there is no need to add the
                    // placeholder, as a subset of the physical partition columns can still be used for
                    // certain optimizations, and it works as if those columns were always null.
                    // This helps distinguish a non-partitioned table from a partitioned table whose physical
                    // partition columns are all absent from the logical table.
                    partition_cols.push(OTHER_PHY_PART_COL_PLACEHOLDER.to_string());
                }
                self.partition_cols = Some(
                    partition_cols
                        .into_iter()
                        .map(|c| {
                            if c == OTHER_PHY_PART_COL_PLACEHOLDER {
                                // for the placeholder, just return an empty alias set
                                return Ok((c.clone(), BTreeSet::new()));
                            }
                            let index =
                                plan.schema().index_of_column_by_name(None, &c).ok_or_else(|| {
                                    datafusion_common::DataFusionError::Internal(format!(
                                        "PlanRewriter: maybe_set_partitions: column {c} not found in schema of plan: {plan}"
                                    ))
                                })?;
                            let column = plan.schema().columns().get(index).cloned().ok_or_else(|| {
                                datafusion_common::DataFusionError::Internal(format!(
                                    "PlanRewriter: maybe_set_partitions: column index {index} out of bounds in schema of plan: {plan}"
                                ))
                            })?;
                            Ok((c.clone(), BTreeSet::from([column])))
                        })
                        .collect::<DfResult<AliasMapping>>()?,
                );
            }
        }

        Ok(())
    }
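
    // Illustrative only (assumed table and plan): for a table partitioned on
    // `host`, the mapping starts as {"host": {t.host}} at the table scan; after
    // passing through `Projection: t.host AS h`, `maybe_set_partitions` updates
    // it to {"host": {h}}.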

    /// pop one stack item and reduce the level by 1
    fn pop_stack(&mut self) {
        self.level -= 1;
        self.stack.pop();
    }

    fn expand(&mut self, mut on_node: LogicalPlan) -> DfResult<LogicalPlan> {
        // store the schema before expanding; the new child plan might have a different schema, so we don't use it
        let schema = on_node.schema().clone();
        if let Some(new_child_plan) = self.new_child_plan.take() {
            // if there is a new child plan, use it as the new root
            on_node = new_child_plan;
        }
        let mut rewriter = EnforceDistRequirementRewriter::new(
            std::mem::take(&mut self.column_requirements),
            self.level,
        );
        debug!(
            "PlanRewriter: enforce column requirements for node: {on_node} with rewriter: {rewriter:?}"
        );
        on_node = on_node.rewrite(&mut rewriter)?.data;
        debug!(
            "PlanRewriter: after enforced column requirements with rewriter: {rewriter:?} for node:\n{on_node}"
        );

        debug!(
            "PlanRewriter: expand on node: {on_node} with partition col alias mapping: {:?}",
            self.partition_cols
        );

        // add merge scan as the new root
        let mut node = MergeScanLogicalPlan::new(
            on_node.clone(),
            false,
            // at this stage, the partition cols should be set;
            // treat it as non-partitioned if None
            self.partition_cols.clone().unwrap_or_default(),
        )
        .into_logical_plan();

        // expand stages
        for new_stage in self.stage.drain(..) {
            // track aliases for merge sort's sort exprs
            let new_stage = if let LogicalPlan::Extension(ext) = &new_stage
                && let Some(merge_sort) = ext.node.as_any().downcast_ref::<MergeSortLogicalPlan>()
            {
                // TODO(discord9): change `on_node` to `node` once alias tracking is supported for merge scan
                rewrite_merge_sort_exprs(merge_sort, &on_node)?
            } else {
                new_stage
            };
            node = new_stage
                .with_new_exprs(new_stage.expressions_consider_join(), vec![node.clone()])?;
        }
        self.set_expanded();

        // recover the schema; this makes sure the schema after expansion is the same as the old node's,
        // because after expansion the raw top node might have extra columns, i.e. sorting columns for a `Sort` node
        let node = LogicalPlanBuilder::from(node)
            .project(schema.iter().map(|(qualifier, field)| {
                Expr::Column(Column::new(qualifier.cloned(), field.name()))
            }))?
            .build()?;

        Ok(node)
    }
}
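
// Illustrative only (assumed plan shapes): `expand` wraps the pushed-down part
// of the plan in a `MergeScan`, re-applies the staged parent plans on top, and
// finishes with a projection that recovers the original schema:
//
//   Sort: t.pk1                       Projection: t.number, t.pk1
//     Projection: t.number, t.pk1       MergeSort: t.pk1
//       TableScan: t              =>      MergeScan [remote_input=
//                                           Sort: t.pk1
//                                             Projection: t.number, t.pk1
//                                               TableScan: t]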

/// Implementation of the [`TreeNodeRewriter`] trait, responsible for rewriting
/// logical plans to enforce various requirements for distributed queries.
///
/// Requirements enforced by this rewriter:
/// - Enforce column requirements for `LogicalPlan::Projection` nodes. Makes sure the
///   required columns are available in the sub plan.
///
#[derive(Debug)]
struct EnforceDistRequirementRewriter {
    /// Only enforce column requirements after the expanding node in question,
    /// meaning only nodes with `cur_level` <= `level` will consider adding those column requirements.
    /// TODO(discord9): a simpler solution to track column requirements for merge scan
    column_requirements: Vec<(HashSet<Column>, usize)>,
    /// Only apply column requirements with level >= `cur_level`.
    /// This is used to avoid applying column requirements that are not needed
    /// for the current node, i.e. the node is not in the scope of the column requirements.
    /// E.g. for this plan:
    /// ```ignore
    /// Aggregate: min(t.number)
    ///   Projection: t.number
    /// ```
    /// when on the `Projection` node, we don't need to apply the column requirements of the
    /// `Aggregate` node, because the `Projection` node is not in the scope of the `Aggregate` node
    cur_level: usize,
    plan_per_level: BTreeMap<usize, LogicalPlan>,
}

impl EnforceDistRequirementRewriter {
    fn new(column_requirements: Vec<(HashSet<Column>, usize)>, cur_level: usize) -> Self {
        Self {
            column_requirements,
            cur_level,
            plan_per_level: BTreeMap::new(),
        }
    }

    /// Return a mapping from (original column, level) to the aliased columns in the current node
    /// for all applicable column requirements,
    /// i.e. only column requirements with level >= `cur_level` will be considered
    fn get_current_applicable_column_requirements(
        &self,
        node: &LogicalPlan,
    ) -> DfResult<BTreeMap<(Column, usize), BTreeSet<Column>>> {
        let col_req_per_level = self
            .column_requirements
            .iter()
            .filter(|(_, level)| *level >= self.cur_level)
            .collect::<Vec<_>>();

        // track aliases for columns and use the aliased columns instead;
        // these are the aliased column requirements at the current level
        let mut result_alias_mapping = BTreeMap::new();
        let Some(child) = node.inputs().first().cloned() else {
            return Ok(Default::default());
        };
        for (col_req, level) in col_req_per_level {
            if let Some(original) = self.plan_per_level.get(level) {
                // query for aliases in the current plan
                let aliased_cols =
                    aliased_columns_for(&col_req.iter().cloned().collect(), node, Some(original))?;
                for original_col in col_req {
                    let aliased_cols = aliased_cols.get(original_col).cloned();
                    if let Some(cols) = aliased_cols
                        && !cols.is_empty()
                    {
                        result_alias_mapping.insert((original_col.clone(), *level), cols);
                    } else {
                        // If no aliased column is found in the current node, there should be an alias
                        // in the child node, as promised by the enforced column requirements
                        // (required columns are inserted into the child node),
                        // so we can find the alias in the child node.
                        // If it is not found, that's an internal error.
                        let aliases_in_child = aliased_columns_for(
                            &[original_col.clone()].into(),
                            child,
                            Some(original),
                        )?;
                        let Some(aliases) = aliases_in_child
                            .get(original_col)
                            .cloned()
                            .filter(|a| !a.is_empty())
                        else {
                            return Err(datafusion_common::DataFusionError::Internal(format!(
                                "EnforceDistRequirementRewriter: no alias found for required column {original_col} in child plan {child} from original plan {original}",
                            )));
                        };

                        result_alias_mapping.insert((original_col.clone(), *level), aliases);
                    }
                }
            }
        }
        Ok(result_alias_mapping)
    }
}

impl TreeNodeRewriter for EnforceDistRequirementRewriter {
    type Node = LogicalPlan;

    fn f_down(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        // check that the node doesn't have multiple children, i.e. join/subquery
        if node.inputs().len() > 1 {
            return Err(datafusion_common::DataFusionError::Internal(
                "EnforceDistRequirementRewriter: node with multiple inputs is not supported"
                    .to_string(),
            ));
        }
        self.plan_per_level.insert(self.cur_level, node.clone());
        self.cur_level += 1;
        Ok(Transformed::no(node))
    }

    fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        self.cur_level -= 1;
        // first get all applicable column requirements,
        // then make sure every projection in the applicable scope has the required columns
        if let LogicalPlan::Projection(ref projection) = node {
            let mut applicable_column_requirements =
                self.get_current_applicable_column_requirements(&node)?;

            debug!(
                "EnforceDistRequirementRewriter: applicable column requirements at level {} = {:?} for node {}",
                self.cur_level,
                applicable_column_requirements,
                node.display()
            );

            for expr in &projection.expr {
                let (qualifier, name) = expr.qualified_name();
                let column = Column::new(qualifier, name);
                applicable_column_requirements.retain(|_col_level, alias_set| {
                    // remove all columns that are already in the projection exprs
                    !alias_set.contains(&column)
                });
            }
            if applicable_column_requirements.is_empty() {
                return Ok(Transformed::no(node));
            }

            let mut new_exprs = projection.expr.clone();
            for (col, alias_set) in &applicable_column_requirements {
                // use the first alias in the alias set as the column to add
                new_exprs.push(Expr::Column(alias_set.first().cloned().ok_or_else(
                    || {
                        datafusion_common::DataFusionError::Internal(
                            format!("EnforceDistRequirementRewriter: alias set is empty, for column {col:?} in node {node}"),
                        )
                    },
                )?));
            }
            let new_node =
                node.with_new_exprs(new_exprs, node.inputs().into_iter().cloned().collect())?;
            debug!(
                "EnforceDistRequirementRewriter: added missing columns {:?} to projection node from old node: \n{node}\n Making new node: \n{new_node}",
                applicable_column_requirements
            );

            // update the plan for later use
            self.plan_per_level.insert(self.cur_level, new_node.clone());

            // still need to continue for the next projection if applicable
            return Ok(Transformed::yes(new_node));
        }
        Ok(Transformed::no(node))
    }
}
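
// Illustrative only (assumed plan shapes): with a requirement on `t.pk1`
// recorded at level 2, `f_up` extends the projection above it:
//
//   1: Projection: t.number            1: Projection: t.number, t.pk1
//   2: Sort: t.pk1               =>    2: Sort: t.pk1
//   3: TableScan: t                    3: TableScan: t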

impl TreeNodeRewriter for PlanRewriter {
    type Node = LogicalPlan;

    /// descend
    fn f_down(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        self.level += 1;
        self.stack.push((node.clone(), self.level));
        // descending will clear the stage
        self.stage.clear();
        self.set_unexpanded();
        self.partition_cols = None;
        Ok(Transformed::no(node))
    }

    /// ascend
    ///
    /// Be sure to call `pop_stack` before returning
    fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        // only expand once on each ascent
        if self.is_expanded() {
            self.pop_stack();
            return Ok(Transformed::no(node));
        }

        // only expand when the leaf is a table scan
        if node.inputs().is_empty() && !matches!(node, LogicalPlan::TableScan(_)) {
            self.set_expanded();
            self.pop_stack();
            return Ok(Transformed::no(node));
        }

        self.maybe_set_partitions(&node)?;

        let Some(parent) = self.get_parent() else {
            debug!("PlanRewriter: no parent found for node: {node}, expanding now");
            let node = self.expand(node);
            debug!(
                "PlanRewriter: expanded plan: {}",
                match &node {
                    Ok(n) => n.to_string(),
                    Err(e) => format!("Error expanding plan: {e}"),
                }
            );
            let node = node?;
            self.pop_stack();
            return Ok(Transformed::yes(node));
        };

        let parent = parent.clone();

        if self.should_expand(&parent)? {
            // TODO(ruihang): does this work for nodes with multiple children?
            debug!(
                "PlanRewriter: should expand child:\n {node}\n of parent: {}",
                parent.display()
            );
            let node = self.expand(node);
            debug!(
                "PlanRewriter: expanded plan: {}",
                match &node {
                    Ok(n) => n.to_string(),
                    Err(e) => format!("Error expanding plan: {e}"),
                }
            );
            let node = node?;
            self.pop_stack();
            return Ok(Transformed::yes(node));
        }

        self.pop_stack();
        Ok(Transformed::no(node))
    }
}
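
// End-to-end sketch (assumed plan shapes, in the spirit of the cases in `mod test`):
// the analyzer keeps the part above the expansion point in the frontend, inserts
// `MergeSort`/`MergeScan`, and ships the rest to the datanodes as remote input:
//
//   Projection: t.number                Projection: t.number
//     Sort: t.pk1                 =>      MergeSort: t.pk1
//       TableScan: t                        MergeScan [remote_input=
//                                             Sort: t.pk1
//                                               TableScan: t]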