query/dist_plan/analyzer.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::{BTreeMap, BTreeSet, HashSet};
use std::sync::Arc;

use chrono::{DateTime, Utc};
use common_telemetry::debug;
use datafusion::config::{ConfigExtension, ExtensionOptions};
use datafusion::datasource::DefaultTableSource;
use datafusion::error::Result as DfResult;
use datafusion_common::Column;
use datafusion_common::alias::AliasGenerator;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
use datafusion_expr::expr::{Exists, InSubquery};
use datafusion_expr::utils::expr_to_columns;
use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder, Subquery, col as col_fn};
use datafusion_optimizer::analyzer::AnalyzerRule;
use datafusion_optimizer::simplify_expressions::SimplifyExpressions;
use datafusion_optimizer::{OptimizerConfig, OptimizerRule};
use promql::extension_plan::SeriesDivide;
use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
use table::metadata::TableType;
use table::table::adapter::DfTableProviderAdapter;

use crate::dist_plan::analyzer::utils::{aliased_columns_for, rewrite_merge_sort_exprs};
use crate::dist_plan::commutativity::{
    Categorizer, Commutativity, partial_commutative_transformer,
};
use crate::dist_plan::merge_scan::MergeScanLogicalPlan;
use crate::dist_plan::merge_sort::MergeSortLogicalPlan;
use crate::metrics::PUSH_DOWN_FALLBACK_ERRORS_TOTAL;
use crate::plan::ExtractExpr;
use crate::query_engine::DefaultSerializer;

#[cfg(test)]
mod test;

mod fallback;
mod utils;

pub(crate) use utils::AliasMapping;

/// Placeholder for other physical partition columns that are not in the logical table
const OTHER_PHY_PART_COL_PLACEHOLDER: &str = "__OTHER_PHYSICAL_PART_COLS_PLACEHOLDER__";

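/// Options controlling the distributed planner.
///
/// A minimal usage sketch (hypothetical, not taken from this crate): the
/// extension is registered on a DataFusion `ConfigOptions` and later read back
/// in [`DistPlannerAnalyzer::analyze`].
///
/// ```ignore
/// use datafusion_common::config::ConfigOptions;
///
/// let mut config = ConfigOptions::default();
/// config.extensions.insert(DistPlannerOptions {
///     allow_query_fallback: true,
/// });
/// let allow = config
///     .extensions
///     .get::<DistPlannerOptions>()
///     .map(|o| o.allow_query_fallback)
///     .unwrap_or(false);
/// assert!(allow);
/// ```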
#[derive(Debug, Clone)]
pub struct DistPlannerOptions {
    pub allow_query_fallback: bool,
}

impl ConfigExtension for DistPlannerOptions {
    const PREFIX: &'static str = "dist_planner";
}

impl ExtensionOptions for DistPlannerOptions {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

    fn cloned(&self) -> Box<dyn ExtensionOptions> {
        Box::new(self.clone())
    }

    fn set(&mut self, key: &str, value: &str) -> DfResult<()> {
        Err(datafusion_common::DataFusionError::NotImplemented(format!(
            "DistPlannerOptions does not support setting key: {key} with value: {value}"
        )))
    }

    fn entries(&self) -> Vec<datafusion::config::ConfigEntry> {
        vec![datafusion::config::ConfigEntry {
            key: "allow_query_fallback".to_string(),
            value: Some(self.allow_query_fallback.to_string()),
            description: "Allow queries to fall back to the fallback plan rewriter on push-down failure",
        }]
    }
}

#[derive(Debug)]
pub struct DistPlannerAnalyzer;

impl AnalyzerRule for DistPlannerAnalyzer {
    fn name(&self) -> &str {
        "DistPlannerAnalyzer"
    }

    fn analyze(
        &self,
        plan: LogicalPlan,
        config: &ConfigOptions,
    ) -> datafusion_common::Result<LogicalPlan> {
        let mut config = config.clone();
        // Aligned with the behavior in `datafusion_optimizer::OptimizerContext::new()`.
        config.optimizer.filter_null_join_keys = true;
        let config = Arc::new(config);

        // The `ConstEvaluator` in `SimplifyExpressions` might evaluate some UDFs early in the
        // planning stage, by executing them directly. For example, the `database()` function.
        // So the `ConfigOptions` here (which is set from the session context) should be present
        // in the UDF's `ScalarFunctionArgs`. However, the default implementation in DataFusion
        // seems to lose track of it: the `ConfigOptions` is recreated with its default values.
        // So we create a custom `OptimizerConfig` with the desired `ConfigOptions`
        // to work around the issue.
        // TODO(LFC): Maybe use DataFusion's `OptimizerContext` again
        //   once https://github.com/apache/datafusion/pull/17742 is merged.
        struct OptimizerContext {
            inner: datafusion_optimizer::OptimizerContext,
            config: Arc<ConfigOptions>,
        }

        impl OptimizerConfig for OptimizerContext {
            fn query_execution_start_time(&self) -> DateTime<Utc> {
                self.inner.query_execution_start_time()
            }

            fn alias_generator(&self) -> &Arc<AliasGenerator> {
                self.inner.alias_generator()
            }

            fn options(&self) -> Arc<ConfigOptions> {
                self.config.clone()
            }
        }

        let optimizer_context = OptimizerContext {
            inner: datafusion_optimizer::OptimizerContext::new(),
            config: config.clone(),
        };

        let plan = SimplifyExpressions::new()
            .rewrite(plan, &optimizer_context)?
            .data;

        let opt = config.extensions.get::<DistPlannerOptions>();
        let allow_fallback = opt.map(|o| o.allow_query_fallback).unwrap_or(false);

        let result = match self.try_push_down(plan.clone()) {
            Ok(plan) => plan,
            Err(err) => {
                if allow_fallback {
                    common_telemetry::warn!(err; "Failed to push down plan, using fallback plan rewriter for plan: {plan}");
                    // if push down failed, use fallback plan rewriter
                    PUSH_DOWN_FALLBACK_ERRORS_TOTAL.inc();
                    self.use_fallback(plan)?
                } else {
                    return Err(err);
                }
            }
        };

        Ok(result)
    }
}
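
// A hypothetical registration sketch: GreptimeDB wires this rule into its own
// query engine, but any DataFusion `SessionState` could carry it, roughly like
// the following (names and setup are illustrative only):
//
//     let mut state = SessionStateBuilder::new().with_default_features().build();
//     state.add_analyzer_rule(Arc::new(DistPlannerAnalyzer));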

impl DistPlannerAnalyzer {
    /// Try to push down as many nodes as possible
    fn try_push_down(&self, plan: LogicalPlan) -> DfResult<LogicalPlan> {
        let plan = plan.transform(&Self::inspect_plan_with_subquery)?;
        let mut rewriter = PlanRewriter::default();
        let result = plan.data.rewrite(&mut rewriter)?.data;
        Ok(result)
    }

    /// Use the fallback plan rewriter to rewrite the plan and only push down table scan nodes
    fn use_fallback(&self, plan: LogicalPlan) -> DfResult<LogicalPlan> {
        let mut rewriter = fallback::FallbackPlanRewriter;
        let result = plan.rewrite(&mut rewriter)?.data;
        Ok(result)
    }

    fn inspect_plan_with_subquery(plan: LogicalPlan) -> DfResult<Transformed<LogicalPlan>> {
        // Workaround for https://github.com/GreptimeTeam/greptimedb/issues/5469 and https://github.com/GreptimeTeam/greptimedb/issues/5799
        // FIXME(yingwen): Remove the `Limit` plan once we update DataFusion.
        if let LogicalPlan::Limit(_) | LogicalPlan::Distinct(_) = &plan {
            return Ok(Transformed::no(plan));
        }

        let exprs = plan
            .expressions_consider_join()
            .into_iter()
            .map(|e| e.transform(&Self::transform_subquery).map(|x| x.data))
            .collect::<DfResult<Vec<_>>>()?;

        // Some plans are specially treated (we should not call `with_new_exprs` on them)
        if !matches!(plan, LogicalPlan::Unnest(_)) {
            let inputs = plan.inputs().into_iter().cloned().collect::<Vec<_>>();
            Ok(Transformed::yes(plan.with_new_exprs(exprs, inputs)?))
        } else {
            Ok(Transformed::no(plan))
        }
    }

    fn transform_subquery(expr: Expr) -> DfResult<Transformed<Expr>> {
        match expr {
            Expr::Exists(exists) => Ok(Transformed::yes(Expr::Exists(Exists {
                subquery: Self::handle_subquery(exists.subquery)?,
                negated: exists.negated,
            }))),
            Expr::InSubquery(in_subquery) => Ok(Transformed::yes(Expr::InSubquery(InSubquery {
                expr: in_subquery.expr,
                subquery: Self::handle_subquery(in_subquery.subquery)?,
                negated: in_subquery.negated,
            }))),
            Expr::ScalarSubquery(scalar_subquery) => Ok(Transformed::yes(Expr::ScalarSubquery(
                Self::handle_subquery(scalar_subquery)?,
            ))),

            _ => Ok(Transformed::no(expr)),
        }
    }

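    /// Rewrite a subquery plan with [`PlanRewriter`], then, if the rewritten
    /// root is an `Extension` node (e.g. a merge scan), wrap it in a trivial
    /// `Projection`, since DataFusion cannot handle an `Extension` as the
    /// first plan of a subquery. Plan shapes, sketched (not real `EXPLAIN`
    /// output):
    ///
    /// ```ignore
    /// // before:               // after:
    /// MergeScan [...]          Projection: a, b
    ///                            MergeScan [...]
    /// ```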
    fn handle_subquery(subquery: Subquery) -> DfResult<Subquery> {
        let mut rewriter = PlanRewriter::default();
        let mut rewrote_subquery = subquery
            .subquery
            .as_ref()
            .clone()
            .rewrite(&mut rewriter)?
            .data;
        // Workaround: DataFusion doesn't support an Extension node as the first plan in a subquery
        if matches!(rewrote_subquery, LogicalPlan::Extension(_)) {
            let output_schema = rewrote_subquery.schema().clone();
            let project_exprs = output_schema
                .fields()
                .iter()
                .map(|f| col_fn(f.name()))
                .collect::<Vec<_>>();
            rewrote_subquery = LogicalPlanBuilder::from(rewrote_subquery)
                .project(project_exprs)?
                .build()?;
        }

        Ok(Subquery {
            subquery: Arc::new(rewrote_subquery),
            outer_ref_columns: subquery.outer_ref_columns,
            spans: Default::default(),
        })
    }
}

/// Status of the rewriter, marking whether the current pass has been expanded
#[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
enum RewriterStatus {
    #[default]
    Unexpanded,
    Expanded,
}

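/// Rewrites a logical plan bottom-up for distributed execution: it pushes as
/// much of the plan as possible below a `MergeScan` node, then "expands" by
/// replaying the staged parent nodes on top of it.
///
/// A sketch of the overall effect (hypothetical plan; the real split is
/// decided by [`Categorizer`]):
///
/// ```ignore
/// // before:                    // after:
/// Aggregate: sum(t.number)      Aggregate: sum(sum(t.number))    <- frontend
///   TableScan: t                  MergeScan [...]
///                                   Aggregate: sum(t.number)     <- datanode
///                                     TableScan: t
/// ```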
#[derive(Debug, Default)]
struct PlanRewriter {
    /// Current level in the tree
    level: usize,
    /// Simulated stack for the `rewrite` recursion
    stack: Vec<(LogicalPlan, usize)>,
    /// Stages to be expanded, will be added as parent nodes of the merge scan one by one
    stage: Vec<LogicalPlan>,
    status: RewriterStatus,
    /// Partition columns of the table in the current pass
    partition_cols: Option<AliasMapping>,
    /// Use the stack depth as a scope to determine whether a column requirement applies.
    /// I.e., for a logical plan like:
    /// ```ignore
    /// 1: Projection: t.number
    /// 2: Sort: t.pk1+t.pk2
    /// 3: Projection: t.number, t.pk1, t.pk2
    /// ```
    /// `Sort` will create a column requirement for `t.pk1` at level 2,
    /// which makes the `Projection` at level 1 need to add a reference to `t.pk1` as well.
    /// The expanded plan will then be
    /// ```ignore
    /// Projection: t.number
    ///   MergeSort: t.pk1
    ///     MergeScan: remote_input=
    /// Projection: t.number, "t.pk1+t.pk2" <--- the original `Projection` at level 1 gets `t.pk1+t.pk2` added
    ///  Sort: t.pk1+t.pk2
    ///    Projection: t.number, t.pk1, t.pk2
    /// ```
    /// so that `MergeSort` can have `t.pk1` as input.
    /// Meanwhile the `Projection` at level 3 doesn't need to add any new column, because 3 > 2
    /// and column requirements at level 2 are not applicable to level 3.
    ///
    /// see more details in tests `expand_proj_step_aggr` and `expand_proj_sort_proj`
    ///
    /// TODO(discord9): a simpler solution to track column requirements for merge scan
    column_requirements: Vec<(HashSet<Column>, usize)>,
    /// Whether to expand on the next call.
    /// This is used to handle the case where a plan is transformed but needs to be expanded from its
    /// parent node. For example, an Aggregate plan is split into two parts between frontend and datanode,
    /// and needs to be expanded from the parent node of the Aggregate plan.
    expand_on_next_call: bool,
    /// Expand on the next partial/conditional/transformed commutative plan.
    /// This is used to handle the case where a plan is transformed but we still
    /// need to push down as many nodes as possible before the next partial/conditional/transformed
    /// commutative plan. I.e.
    /// ```ignore
    /// Limit:
    ///     Sort:
    /// ```
    /// where `Limit` is partial commutative and `Sort` is conditional commutative.
    /// In this case, we need to expand the `Limit` plan,
    /// so that we can push down the `Sort` plan as much as possible.
    expand_on_next_part_cond_trans_commutative: bool,
    new_child_plan: Option<LogicalPlan>,
}

impl PlanRewriter {
    fn get_parent(&self) -> Option<&LogicalPlan> {
        // level starts from 1, so it's safe to subtract 1
        self.stack
            .iter()
            .rev()
            .find(|(_, level)| *level == self.level - 1)
            .map(|(node, _)| node)
    }

    /// Return true if we should stop and expand. The input plan is the parent node of the current node
    fn should_expand(&mut self, plan: &LogicalPlan) -> DfResult<bool> {
        debug!(
            "Check should_expand at level: {} with stack:\n{}",
            self.level,
            self.stack
                .iter()
                .map(|(p, l)| format!("{l}:{}{}", "  ".repeat(l - 1), p.display()))
                .collect::<Vec<String>>()
                .join("\n"),
        );
        if let Err(e) = DFLogicalSubstraitConvertor.encode(plan, DefaultSerializer) {
            debug!(
                "PlanRewriter: plan cannot be converted to substrait with error={e:?}, expanding now: {plan}"
            );
            return Ok(true);
        }

        if self.expand_on_next_call {
            self.expand_on_next_call = false;
            debug!("PlanRewriter: expand_on_next_call is true, expanding now");
            return Ok(true);
        }

        if self.expand_on_next_part_cond_trans_commutative {
            let comm = Categorizer::check_plan(plan, self.partition_cols.clone())?;
            match comm {
                Commutativity::PartialCommutative => {
                    // A small difference is that for a partial commutative plan, we still need to
                    // push it down (so `Limit` can be pushed down).

                    // Notice how `Limit` needs to be expanded as well to keep the query correct,
                    // i.e. `Limit fetch=10` needs to be pushed down to the leaf node.
                    self.expand_on_next_part_cond_trans_commutative = false;
                    self.expand_on_next_call = true;
                }
                Commutativity::ConditionalCommutative(_)
                | Commutativity::TransformedCommutative { .. } => {
                    // Again a new node that can be pushed down; just push it down
                    // now and avoid further expansion.
                    self.expand_on_next_part_cond_trans_commutative = false;
                    debug!(
                        "PlanRewriter: meet a new conditional/transformed commutative plan, expanding now: {plan}"
                    );
                    return Ok(true);
                }
                _ => (),
            }
        }

        match Categorizer::check_plan(plan, self.partition_cols.clone())? {
            Commutativity::Commutative => {
                // PATCH: we should reconsider SORT's commutativity instead of doing this trick.
                // Explanation: for a fully commutative SeriesDivide, its child Sort plan only serves it. I.e., that
                //   Sort plan is also fully commutative, instead of conditional commutative, so we can safely remove
                //   the generated MergeSort from the stage.
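                // An illustrative shape (hypothetical):
                //
                //     SeriesDivide: tags=[t.tag]    <- fully commutative
                //       Sort: t.tag ASC             <- exists only to feed SeriesDivide
                //
                // Here the MergeSort previously staged for the Sort node is dropped.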
                if let LogicalPlan::Extension(ext_a) = plan
                    && ext_a.node.name() == SeriesDivide::name()
                    && let Some(LogicalPlan::Extension(ext_b)) = self.stage.last()
                    && ext_b.node.name() == MergeSortLogicalPlan::name()
                {
                    // Revert the last `ConditionalCommutative` result for the Sort plan in this case.
                    // `update_column_requirements` is left unchanged because Sort won't generate
                    // new columns or remove existing columns.
                    self.stage.pop();
                    self.expand_on_next_part_cond_trans_commutative = false;
                }
            }
            Commutativity::PartialCommutative => {
                if let Some(plan) = partial_commutative_transformer(plan) {
                    // notice this plan is the parent of the current node, hence `self.level - 1` when updating column requirements
                    self.update_column_requirements(&plan, self.level - 1);
                    self.expand_on_next_part_cond_trans_commutative = true;
                    self.stage.push(plan)
                }
            }
            Commutativity::ConditionalCommutative(transformer) => {
                if let Some(transformer) = transformer
                    && let Some(plan) = transformer(plan)
                {
                    // notice this plan is the parent of the current node, hence `self.level - 1` when updating column requirements
                    self.update_column_requirements(&plan, self.level - 1);
                    self.expand_on_next_part_cond_trans_commutative = true;
                    self.stage.push(plan)
                }
            }
            Commutativity::TransformedCommutative { transformer } => {
                if let Some(transformer) = transformer {
                    let transformer_actions = transformer(plan)?;
                    debug!(
                        "PlanRewriter: transformed plan: {}\n from {plan}",
                        transformer_actions
                            .extra_parent_plans
                            .iter()
                            .enumerate()
                            .map(|(i, p)| format!(
                                "Extra {i}-th parent plan from parent to child = {}",
                                p.display()
                            ))
                            .collect::<Vec<_>>()
                            .join("\n")
                    );
                    if let Some(new_child_plan) = &transformer_actions.new_child_plan {
                        debug!("PlanRewriter: new child plan: {}", new_child_plan);
                    }
                    if let Some(last_stage) = transformer_actions.extra_parent_plans.last() {
                        // update the column requirements from the last stage;
                        // notice the current plan's parent is where we need to apply the column requirements
                        self.update_column_requirements(last_stage, self.level - 1);
                    }
                    self.stage
                        .extend(transformer_actions.extra_parent_plans.into_iter().rev());
                    self.expand_on_next_call = true;
                    self.new_child_plan = transformer_actions.new_child_plan;
                }
            }
            Commutativity::NonCommutative
            | Commutativity::Unimplemented
            | Commutativity::Unsupported => {
                debug!("PlanRewriter: meet a non-commutative plan, expanding now: {plan}");
                return Ok(true);
            }
        }

        Ok(false)
    }

    /// Update the column requirements for the current plan. `plan_level` is the level of the plan
    /// in the stack, which is used to determine whether the column requirements are applicable
    /// to other plans in the stack.
    fn update_column_requirements(&mut self, plan: &LogicalPlan, plan_level: usize) {
        debug!(
            "PlanRewriter: update column requirements for plan: {plan}\n with old column_requirements: {:?}",
            self.column_requirements
        );
        let mut container = HashSet::new();
        for expr in plan.expressions() {
            // this method won't fail
            let _ = expr_to_columns(&expr, &mut container);
        }

        self.column_requirements.push((container, plan_level));
        debug!(
            "PlanRewriter: updated column requirements: {:?}",
            self.column_requirements
        );
    }

    fn is_expanded(&self) -> bool {
        self.status == RewriterStatus::Expanded
    }

    fn set_expanded(&mut self) {
        self.status = RewriterStatus::Expanded;
    }

    fn set_unexpanded(&mut self) {
        self.status = RewriterStatus::Unexpanded;
    }

    fn maybe_set_partitions(&mut self, plan: &LogicalPlan) -> DfResult<()> {
        if let Some(part_cols) = &mut self.partition_cols {
            // update partition aliases
            let child = plan.inputs().first().cloned().ok_or_else(|| {
                datafusion_common::DataFusionError::Internal(format!(
                    "PlanRewriter: maybe_set_partitions: plan has no child: {plan}"
                ))
            })?;

            for (_col_name, alias_set) in part_cols.iter_mut() {
                let aliased_cols = aliased_columns_for(
                    &alias_set.clone().into_iter().collect(),
                    plan,
                    Some(child),
                )?;
                *alias_set = aliased_cols.into_values().flatten().collect();
            }

            debug!(
                "PlanRewriter: maybe_set_partitions: updated partition columns: {:?} at plan: {}",
                part_cols,
                plan.display()
            );

            return Ok(());
        }

        if let LogicalPlan::TableScan(table_scan) = plan
            && let Some(source) = table_scan
                .source
                .as_any()
                .downcast_ref::<DefaultTableSource>()
            && let Some(provider) = source
                .table_provider
                .as_any()
                .downcast_ref::<DfTableProviderAdapter>()
        {
            let table = provider.table();
            if table.table_type() == TableType::Base {
                let info = table.table_info();
                let partition_key_indices = info.meta.partition_key_indices.clone();
                let schema = info.meta.schema.clone();
                let mut partition_cols = partition_key_indices
                    .into_iter()
                    .map(|index| schema.column_name_by_index(index).to_string())
                    .collect::<Vec<String>>();

                let partition_rules = table.partition_rules();
                let exist_phy_part_cols_not_in_logical_table = partition_rules
                    .map(|r| !r.extra_phy_cols_not_in_logical_table.is_empty())
                    .unwrap_or(false);

                if exist_phy_part_cols_not_in_logical_table && partition_cols.is_empty() {
                    // There are physical partition columns that are not in the logical table, and the
                    // logical partition columns are empty, so add a placeholder to prevent certain
                    // optimizations: it makes sure the final partition columns (that the optimizer sees)
                    // are not empty. Note that if `partition_cols` is not empty, no placeholder is needed:
                    // a subset of the physical partition columns can still drive those optimizations, and
                    // it works as if the missing columns were always null.
                    // This distinguishes a non-partitioned table from a partitioned table whose physical
                    // partition columns are all absent from the logical table.
                    partition_cols.push(OTHER_PHY_PART_COL_PLACEHOLDER.to_string());
                }
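                // An illustrative (hypothetical) mapping for a table partitioned on `host`
                // when the scan's projection keeps it un-aliased:
                //     {"host" => {Column { relation: "t", name: "host" }}}
                // and, for a placeholder entry, an empty alias set:
                //     {"__OTHER_PHYSICAL_PART_COLS_PLACEHOLDER__" => {}}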
                self.partition_cols = Some(
                    partition_cols
                        .into_iter()
                        .map(|c| {
                            if c == OTHER_PHY_PART_COL_PLACEHOLDER {
                                // for the placeholder, just return an empty alias set
                                return Ok((c.clone(), BTreeSet::new()));
                            }
                            let index = if let Some(c) =
                                plan.schema().index_of_column_by_name(None, &c)
                            {
                                c
                            } else {
                                // the `projection` field of `TableScan` doesn't contain the partition columns;
                                // this is similar to not having an alias, hence return an empty alias set
                                return Ok((c.clone(), BTreeSet::new()));
                            };
                            let column = plan.schema().columns().get(index).cloned().ok_or_else(|| {
                                datafusion_common::DataFusionError::Internal(format!(
                                    "PlanRewriter: maybe_set_partitions: column index {index} out of bounds in schema of plan: {plan}"
                                ))
                            })?;
                            Ok((c.clone(), BTreeSet::from([column])))
                        })
                        .collect::<DfResult<AliasMapping>>()?,
                );
            }
        }

        Ok(())
    }

    /// pop one stack item and reduce the level by 1
    fn pop_stack(&mut self) {
        self.level -= 1;
        self.stack.pop();
    }

    fn expand(&mut self, mut on_node: LogicalPlan) -> DfResult<LogicalPlan> {
        // Store the schema before expanding; the new child plan might have a different schema, so don't use its schema.
        let schema = on_node.schema().clone();
        if let Some(new_child_plan) = self.new_child_plan.take() {
            // if there is a new child plan, use it as the new root
            on_node = new_child_plan;
        }
        let mut rewriter = EnforceDistRequirementRewriter::new(
            std::mem::take(&mut self.column_requirements),
            self.level,
        );
        debug!(
            "PlanRewriter: enforce column requirements for node: {on_node} with rewriter: {rewriter:?}"
        );
        on_node = on_node.rewrite(&mut rewriter)?.data;
        debug!(
            "PlanRewriter: after enforcing column requirements with rewriter: {rewriter:?} for node:\n{on_node}"
        );

        debug!(
            "PlanRewriter: expand on node: {on_node} with partition col alias mapping: {:?}",
            self.partition_cols
        );

        // add merge scan as the new root
        let mut node = MergeScanLogicalPlan::new(
            on_node.clone(),
            false,
            // at this stage, the partition cols should be set;
            // treat it as non-partitioned if None
            self.partition_cols.clone().unwrap_or_default(),
        )
        .into_logical_plan();

        // expand stages
        for new_stage in self.stage.drain(..) {
            // track aliases for merge sort's sort exprs
            let new_stage = if let LogicalPlan::Extension(ext) = &new_stage
                && let Some(merge_sort) = ext.node.as_any().downcast_ref::<MergeSortLogicalPlan>()
            {
                // TODO(discord9): change `on_node` to `node` once alias tracking is supported for merge scan
                rewrite_merge_sort_exprs(merge_sort, &on_node)?
            } else {
                new_stage
            };
            node = new_stage
                .with_new_exprs(new_stage.expressions_consider_join(), vec![node.clone()])?;
        }
        self.set_expanded();

        // Recover the schema: this makes sure the schema after expansion is the same as the old node's,
        // because the raw top node might have extra columns after expansion, e.g. sorting columns for a `Sort` node.
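        // E.g. (hypothetical): an expanded `Sort: t.pk1` may expose `t.pk1` in
        // its output even though the original node only produced `t.number`;
        // the projection below trims the output back to the original schema.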
        let node = LogicalPlanBuilder::from(node)
            .project(schema.iter().map(|(qualifier, field)| {
                Expr::Column(Column::new(qualifier.cloned(), field.name()))
            }))?
            .build()?;

        Ok(node)
    }
}

/// Implementation of the [`TreeNodeRewriter`] trait, responsible for rewriting
/// logical plans to enforce various requirements for distributed queries.
///
/// Requirements enforced by this rewriter:
/// - Column requirements for `LogicalPlan::Projection` nodes: makes sure the
///   required columns are available in the sub plan.
///
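/// An illustrative rewrite (a sketch; column names are made up): with a
/// requirement on `t.pk1` recorded by a staged `MergeSort`, a projection that
/// dropped it gets the column added back:
///
/// ```ignore
/// // before:                // after:
/// Projection: t.number      Projection: t.number, t.pk1
///   TableScan: t              TableScan: t
/// ```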
#[derive(Debug)]
struct EnforceDistRequirementRewriter {
    /// Only enforce column requirements after the expanding node in question,
    /// meaning only nodes with `cur_level` <= `level` will consider adding those column requirements.
    /// TODO(discord9): a simpler solution to track column requirements for merge scan
    column_requirements: Vec<(HashSet<Column>, usize)>,
    /// Only apply column requirements with level >= `cur_level`.
    /// This avoids applying column requirements that are not needed
    /// for the current node, i.e. a node that is not in the scope of the column requirements.
    /// E.g., for this plan:
    /// ```ignore
    /// Aggregate: min(t.number)
    ///   Projection: t.number
    /// ```
    /// when on the `Projection` node, we don't need to apply the column requirements of the `Aggregate` node,
    /// because the `Projection` node is not in the scope of the `Aggregate` node.
    cur_level: usize,
    plan_per_level: BTreeMap<usize, LogicalPlan>,
}

impl EnforceDistRequirementRewriter {
    fn new(column_requirements: Vec<(HashSet<Column>, usize)>, cur_level: usize) -> Self {
        Self {
            column_requirements,
            cur_level,
            plan_per_level: BTreeMap::new(),
        }
    }

    /// Return a mapping from (original column, level) to aliased columns in the current node for all
    /// applicable column requirements,
    /// i.e. only column requirements with level >= `cur_level` will be considered.
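    ///
    /// An illustrative result (hypothetical names): with a requirement on `t.pk1`
    /// recorded at level 2 and a current node `Projection: t.number, t.pk1 AS p`,
    /// the returned mapping would contain `{(t.pk1, 2) => {p}}`.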
    fn get_current_applicable_column_requirements(
        &self,
        node: &LogicalPlan,
    ) -> DfResult<BTreeMap<(Column, usize), BTreeSet<Column>>> {
        let col_req_per_level = self
            .column_requirements
            .iter()
            .filter(|(_, level)| *level >= self.cur_level)
            .collect::<Vec<_>>();

        // Track aliases for columns and use the aliased columns instead:
        // these are the aliased column requirements at the current level.
        let mut result_alias_mapping = BTreeMap::new();
        let Some(child) = node.inputs().first().cloned() else {
            return Ok(Default::default());
        };
        for (col_req, level) in col_req_per_level {
            if let Some(original) = self.plan_per_level.get(level) {
                // query for aliases in the current plan
                let aliased_cols =
                    aliased_columns_for(&col_req.iter().cloned().collect(), node, Some(original))?;
                for original_col in col_req {
                    let aliased_cols = aliased_cols.get(original_col).cloned();
                    if let Some(cols) = aliased_cols
                        && !cols.is_empty()
                    {
                        result_alias_mapping.insert((original_col.clone(), *level), cols);
                    } else {
                        // If no aliased column is found in the current node, there should be an alias
                        // in the child node, as promised by the enforcement of column requirements
                        // (which inserts the required columns into the child node),
                        // so we can find the alias in the child node.
                        // If not found, it's an internal error.
                        let aliases_in_child = aliased_columns_for(
                            &[original_col.clone()].into(),
                            child,
                            Some(original),
                        )?;
                        let Some(aliases) = aliases_in_child
                            .get(original_col)
                            .cloned()
                            .filter(|a| !a.is_empty())
                        else {
                            return Err(datafusion_common::DataFusionError::Internal(format!(
                                "EnforceDistRequirementRewriter: no alias found for required column {original_col} in child plan {child} from original plan {original}",
                            )));
                        };

                        result_alias_mapping.insert((original_col.clone(), *level), aliases);
                    }
                }
            }
        }
        Ok(result_alias_mapping)
    }
}

impl TreeNodeRewriter for EnforceDistRequirementRewriter {
    type Node = LogicalPlan;

    fn f_down(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        // check that node doesn't have multiple children, i.e. join/subquery
        if node.inputs().len() > 1 {
            return Err(datafusion_common::DataFusionError::Internal(
                "EnforceDistRequirementRewriter: node with multiple inputs is not supported"
                    .to_string(),
            ));
        }
        self.plan_per_level.insert(self.cur_level, node.clone());
        self.cur_level += 1;
        Ok(Transformed::no(node))
    }

    fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        self.cur_level -= 1;
        // first get all applicable column requirements, then
        // make sure every projection in the applicable scope has the required columns
        if let LogicalPlan::Projection(ref projection) = node {
            let mut applicable_column_requirements =
                self.get_current_applicable_column_requirements(&node)?;

            debug!(
                "EnforceDistRequirementRewriter: applicable column requirements at level {} = {:?} for node {}",
                self.cur_level,
                applicable_column_requirements,
                node.display()
            );

            for expr in &projection.expr {
                let (qualifier, name) = expr.qualified_name();
                let column = Column::new(qualifier, name);
                applicable_column_requirements.retain(|_col_level, alias_set| {
                    // remove all columns that are already in the projection exprs
                    !alias_set.contains(&column)
                });
            }
            if applicable_column_requirements.is_empty() {
                return Ok(Transformed::no(node));
            }

            let mut new_exprs = projection.expr.clone();
            for (col, alias_set) in &applicable_column_requirements {
                // use the first alias in the alias set as the column to add
                new_exprs.push(Expr::Column(alias_set.first().cloned().ok_or_else(
                    || {
                        datafusion_common::DataFusionError::Internal(
                            format!("EnforceDistRequirementRewriter: alias set is empty, for column {col:?} in node {node}"),
                        )
                    },
                )?));
            }
            let new_node =
                node.with_new_exprs(new_exprs, node.inputs().into_iter().cloned().collect())?;
            debug!(
                "EnforceDistRequirementRewriter: added missing columns {:?} to projection node from old node: \n{node}\n Making new node: \n{new_node}",
                applicable_column_requirements
            );

            // update the plan for later use
            self.plan_per_level.insert(self.cur_level, new_node.clone());

            // still need to continue for the next projection if applicable
            return Ok(Transformed::yes(new_node));
        }
        Ok(Transformed::no(node))
    }
}

impl TreeNodeRewriter for PlanRewriter {
    type Node = LogicalPlan;

    /// descend
    fn f_down(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        self.level += 1;
        self.stack.push((node.clone(), self.level));
        // descending will clear the stage
        self.stage.clear();
        self.set_unexpanded();
        self.partition_cols = None;
        Ok(Transformed::no(node))
    }

    /// ascend
    ///
    /// Be sure to call `pop_stack` before returning
    fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        // only expand once on each ascent
        if self.is_expanded() {
            self.pop_stack();
            return Ok(Transformed::no(node));
        }

        // only expand when the leaf is a table scan
        if node.inputs().is_empty() && !matches!(node, LogicalPlan::TableScan(_)) {
            self.set_expanded();
            self.pop_stack();
            return Ok(Transformed::no(node));
        }

        self.maybe_set_partitions(&node)?;

        let Some(parent) = self.get_parent() else {
            debug!("PlanRewriter: no parent found, expanding now for node: {node}");
            let node = self.expand(node);
            debug!(
                "PlanRewriter: expanded plan: {}",
                match &node {
                    Ok(n) => n.to_string(),
                    Err(e) => format!("Error expanding plan: {e}"),
                }
            );
            let node = node?;
            self.pop_stack();
            return Ok(Transformed::yes(node));
        };

        let parent = parent.clone();

        if self.should_expand(&parent)? {
            // TODO(ruihang): does this work for nodes with multiple children?
            debug!(
                "PlanRewriter: should expand child:\n {node}\n of parent: {}",
                parent.display()
            );
            let node = self.expand(node);
            debug!(
                "PlanRewriter: expanded plan: {}",
                match &node {
                    Ok(n) => n.to_string(),
                    Err(e) => format!("Error expanding plan: {e}"),
                }
            );
            let node = node?;
            self.pop_stack();
            return Ok(Transformed::yes(node));
        }

        self.pop_stack();
        Ok(Transformed::no(node))
    }
}