query/dist_plan/analyzer/
utils.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
16use std::sync::Arc;
17
18use arrow::array::ArrayRef;
19use arrow_schema::{ArrowError, DataType};
20use chrono::{DateTime, Utc};
21use datafusion::common::alias::AliasGenerator;
22use datafusion::config::ConfigOptions;
23use datafusion::error::Result as DfResult;
24use datafusion_common::Column;
25use datafusion_common::tree_node::{Transformed, TreeNode as _, TreeNodeRewriter};
26use datafusion_expr::expr::Alias;
27use datafusion_expr::{Expr, Extension, LogicalPlan};
28use datafusion_optimizer::simplify_expressions::SimplifyExpressions;
29use datafusion_optimizer::{OptimizerConfig, OptimizerRule as _};
30
31use crate::dist_plan::merge_sort::MergeSortLogicalPlan;
32use crate::plan::ExtractExpr as _;
33
34/// The `ConstEvaluator` in `SimplifyExpressions` might evaluate some UDFs early in the
35/// planning stage, by executing them directly. For example, the `database()` function.
36/// So the `ConfigOptions` here (which is set from the session context) should be present
37/// in the UDF's `ScalarFunctionArgs`. However, the default implementation in DataFusion
38/// seems to lost track on it: the `ConfigOptions` is recreated with its default values again.
39/// So we create a custom `OptimizerConfig` with the desired `ConfigOptions`
40/// to walk around the issue.
41/// TODO(LFC): Maybe use DataFusion's `OptimizerContext` again
42///   once https://github.com/apache/datafusion/pull/17742 is merged.
pub(crate) struct PatchOptimizerContext {
    /// The wrapped DataFusion optimizer context; all trait calls except
    /// `options()` delegate to it unchanged.
    pub(crate) inner: datafusion_optimizer::OptimizerContext,
    /// The session-level `ConfigOptions` returned from `options()`, so UDFs
    /// evaluated during constant folding see the session's settings instead of
    /// the default-constructed options the inner context would provide.
    pub(crate) config: Arc<ConfigOptions>,
}
47
impl OptimizerConfig for PatchOptimizerContext {
    /// Delegates to the inner context.
    fn query_execution_start_time(&self) -> DateTime<Utc> {
        self.inner.query_execution_start_time()
    }

    /// Delegates to the inner context.
    fn alias_generator(&self) -> &Arc<AliasGenerator> {
        self.inner.alias_generator()
    }

    /// Returns the session-provided `ConfigOptions` instead of the inner
    /// context's options — this override is the whole point of this wrapper
    /// (see the struct-level comment).
    fn options(&self) -> Arc<ConfigOptions> {
        self.config.clone()
    }
}
61
/// Simplify all expressions recursively in the plan tree
/// while keeping the output schema unchanged.
pub(crate) struct PlanTreeExpressionSimplifier {
    // Optimizer config handed to `SimplifyExpressions`; carries the session's
    // `ConfigOptions` (see `PatchOptimizerContext`).
    optimizer_context: PatchOptimizerContext,
}

impl PlanTreeExpressionSimplifier {
    /// Creates a simplifier that uses the given optimizer context.
    pub fn new(optimizer_context: PatchOptimizerContext) -> Self {
        Self { optimizer_context }
    }
}
73
74impl TreeNodeRewriter for PlanTreeExpressionSimplifier {
75    type Node = LogicalPlan;
76    fn f_down(&mut self, plan: Self::Node) -> DfResult<Transformed<Self::Node>> {
77        let simp = SimplifyExpressions::new()
78            .rewrite(plan, &self.optimizer_context)?
79            .data;
80        Ok(Transformed::yes(simp))
81    }
82}
83
84/// A patch for substrait simply throw timezone away, so when decoding, if columns have different timezone then expected schema, use expected schema's timezone
85pub fn patch_batch_timezone(
86    expected_schema: arrow_schema::SchemaRef,
87    columns: Vec<ArrayRef>,
88) -> Result<arrow::record_batch::RecordBatch, ArrowError> {
89    let patched_columns: Vec<ArrayRef> = expected_schema
90        .fields()
91        .iter()
92        .zip(columns.into_iter())
93        .map(|(expected_field, column)| {
94            let expected_type = expected_field.data_type();
95            let actual_type = column.data_type();
96
97            // Check if both are timestamp types with different timezones
98            match (expected_type, actual_type) {
99                (
100                    DataType::Timestamp(expected_unit, expected_tz),
101                    DataType::Timestamp(actual_unit, actual_tz),
102                ) if expected_unit == actual_unit && expected_tz != actual_tz => {
103                    // Cast the column to the expected timezone
104                    arrow::compute::cast(&column, expected_type)
105                }
106                _ => Ok(column),
107            }
108        })
109        .collect::<Result<Vec<_>, _>>()?;
110
111    arrow::record_batch::RecordBatch::try_new(expected_schema.clone(), patched_columns)
112}
113
114fn rewrite_column(
115    mapping: &BTreeMap<Column, BTreeSet<Column>>,
116    original_node: &LogicalPlan,
117    alias_node: &LogicalPlan,
118) -> impl Fn(Expr) -> DfResult<Transformed<Expr>> {
119    move |e: Expr| {
120        if let Expr::Column(col) = e {
121            if let Some(aliased_cols) = mapping.get(&col) {
122                // if multiple alias is available, just use first one
123                if let Some(aliased_col) = aliased_cols.iter().next() {
124                    Ok(Transformed::yes(Expr::Column(aliased_col.clone())))
125                } else {
126                    Err(datafusion_common::DataFusionError::Internal(format!(
127                        "PlanRewriter: expand: column {col} from {original_node}\n has empty alias set in plan: {alias_node}\n but expect at least one alias",
128                    )))
129                }
130            } else {
131                Err(datafusion_common::DataFusionError::Internal(format!(
132                    "PlanRewriter: expand: column {col} from {original_node}\n has no alias in plan: {alias_node}",
133                )))
134            }
135        } else {
136            Ok(Transformed::no(e))
137        }
138    }
139}
140
/// Rewrite the expressions of the given merge sort plan from original columns(at merge sort's input plan) to aliased columns at the given aliased node
pub fn rewrite_merge_sort_exprs(
    merge_sort: &MergeSortLogicalPlan,
    aliased_node: &LogicalPlan,
) -> DfResult<LogicalPlan> {
    // Wrap the custom node as a `LogicalPlan::Extension` so the generic
    // `LogicalPlan` APIs (`inputs`, `with_new_exprs`, ...) can be used on it.
    let merge_sort = LogicalPlan::Extension(Extension {
        node: Arc::new(merge_sort.clone()),
    });

    // tracking alias for sort exprs,
    let sort_input = merge_sort.inputs().first().cloned().ok_or_else(|| {
        datafusion_common::DataFusionError::Internal(format!(
            "PlanRewriter: expand: merge sort stage has no input: {merge_sort}"
        ))
    })?;
    // Collect every column referenced by the sort expressions; `BTreeSet`
    // keeps the order deterministic.
    let sort_exprs = merge_sort.expressions_consider_join();
    let column_refs = sort_exprs
        .iter()
        .flat_map(|e| e.column_refs().into_iter().cloned())
        .collect::<BTreeSet<_>>();
    // Resolve how each referenced column is named at `aliased_node`,
    // tracing aliases from the merge sort's input up to that node.
    let column_alias_mapping = aliased_columns_for(&column_refs, aliased_node, Some(sort_input))?;
    // Rewrite each sort expression to use the aliased column names.
    let aliased_sort_exprs = sort_exprs
        .into_iter()
        .map(|e| {
            e.transform(rewrite_column(
                &column_alias_mapping,
                &merge_sort,
                aliased_node,
            ))
        })
        .map(|e| e.map(|e| e.data))
        .collect::<DfResult<Vec<_>>>()?;
    // Rebuild the merge sort node with the rewritten expressions but the
    // same inputs.
    let new_merge_sort = merge_sort.with_new_exprs(
        aliased_sort_exprs,
        merge_sort.inputs().into_iter().cloned().collect(),
    )?;
    Ok(new_merge_sort)
}
179
180/// Return all the original columns(at original node) for the given aliased columns at the aliased node
181///
182/// if `original_node` is None, it means original columns are from leaf node
183///
184/// Return value use `BTreeMap` to have deterministic order for choose first alias when multiple alias exist
185#[allow(unused)]
186pub fn original_column_for(
187    aliased_columns: &BTreeSet<Column>,
188    aliased_node: LogicalPlan,
189    original_node: Option<Arc<LogicalPlan>>,
190) -> DfResult<BTreeMap<Column, Column>> {
191    let schema_cols: BTreeSet<Column> = aliased_node.schema().columns().iter().cloned().collect();
192    let cur_aliases: BTreeMap<Column, Column> = aliased_columns
193        .iter()
194        .filter(|c| schema_cols.contains(c))
195        .map(|c| (c.clone(), c.clone()))
196        .collect();
197
198    if cur_aliases.is_empty() {
199        return Ok(BTreeMap::new());
200    }
201
202    original_column_for_inner(cur_aliases, &aliased_node, &original_node)
203}
204
/// Walk down from `node` toward `original_node` (or the leaf when `None`),
/// translating each alias back through one plan layer at a time.
///
/// `cur_aliases` maps each starting (aliased) column to its name at the node
/// currently being visited; entries that cannot be traced through a layer are
/// silently dropped from the result.
fn original_column_for_inner(
    mut cur_aliases: BTreeMap<Column, Column>,
    node: &LogicalPlan,
    original_node: &Option<Arc<LogicalPlan>>,
) -> DfResult<BTreeMap<Column, Column>> {
    let mut current_node = node;

    loop {
        // Base case: check if we've reached the target node
        if let Some(original_node) = original_node
            && *current_node == **original_node
        {
            return Ok(cur_aliases);
        } else if current_node.inputs().is_empty() {
            // leaf node reached
            return Ok(cur_aliases);
        }

        // Validate node has exactly one child
        if current_node.inputs().len() != 1 {
            return Err(datafusion::error::DataFusionError::Internal(format!(
                "only accept plan with at most one child, found: {}",
                current_node
            )));
        }

        // Get alias layer and update aliases
        let layer = get_alias_layer_from_node(current_node)?;
        let mut new_aliases = BTreeMap::new();
        for (start_alias, cur_alias) in cur_aliases {
            // Columns with no pre-image in this layer are dropped here.
            if let Some(old_column) = layer.get_old_from_new(cur_alias.clone()) {
                new_aliases.insert(start_alias, old_column);
            }
        }

        // Move to child node and continue iteration
        cur_aliases = new_aliases;
        current_node = current_node.inputs()[0];
    }
}
245
246/// Return all the aliased columns(at aliased node) for the given original columns(at original node)
247///
248/// if `original_node` is None, it means original columns are from leaf node
249///
250/// Return value use `BTreeMap` to have deterministic order for choose first alias when multiple alias exist
251pub fn aliased_columns_for(
252    original_columns: &BTreeSet<Column>,
253    aliased_node: &LogicalPlan,
254    original_node: Option<&LogicalPlan>,
255) -> DfResult<BTreeMap<Column, BTreeSet<Column>>> {
256    let initial_aliases: BTreeMap<Column, BTreeSet<Column>> = {
257        if let Some(original) = &original_node {
258            let schema_cols: BTreeSet<Column> = original.schema().columns().into_iter().collect();
259            original_columns
260                .iter()
261                .filter(|c| schema_cols.contains(c))
262                .map(|c| (c.clone(), [c.clone()].into()))
263                .collect()
264        } else {
265            original_columns
266                .iter()
267                .map(|c| (c.clone(), [c.clone()].into()))
268                .collect()
269        }
270    };
271
272    if initial_aliases.is_empty() {
273        return Ok(BTreeMap::new());
274    }
275
276    aliased_columns_for_inner(initial_aliases, aliased_node, original_node)
277}
278
/// Trace original columns upward from `original_node` (or the leaf) to `node`,
/// accumulating every alias each column acquires along the way.
///
/// Works in two phases: first descend from `node` to the target collecting the
/// single-child path, then replay the alias layers bottom-up (reverse order)
/// over `cur_aliases`. Columns whose alias set becomes empty in some layer are
/// dropped from the result.
fn aliased_columns_for_inner(
    cur_aliases: BTreeMap<Column, BTreeSet<Column>>,
    node: &LogicalPlan,
    original_node: Option<&LogicalPlan>,
) -> DfResult<BTreeMap<Column, BTreeSet<Column>>> {
    // First, collect the path from current node to the target node
    let mut path = Vec::new();
    let mut current_node = node;

    // Descend to the target node, collecting nodes along the way
    loop {
        // Base case: check if we've reached the target node
        if let Some(original_node) = original_node
            && *current_node == *original_node
        {
            break;
        } else if current_node.inputs().is_empty() {
            // leaf node reached
            break;
        }

        // Validate node has exactly one child
        if current_node.inputs().len() != 1 {
            return Err(datafusion::error::DataFusionError::Internal(format!(
                "only accept plan with at most one child, found: {}",
                current_node
            )));
        }

        // Add current node to path and move to child
        path.push(current_node);
        current_node = current_node.inputs()[0];
    }

    // Now apply alias layers in reverse order (from original to aliased)
    let mut result = cur_aliases;
    for &node_in_path in path.iter().rev() {
        let layer = get_alias_layer_from_node(node_in_path)?;
        let mut new_aliases = BTreeMap::new();
        for (original_column, cur_alias_set) in result {
            // A column may fan out to several aliases in one layer; collect
            // the union over all of its current names.
            let mut new_alias_set = BTreeSet::new();
            for cur_alias in cur_alias_set {
                new_alias_set.extend(layer.get_new_from_old(cur_alias.clone()));
            }
            if !new_alias_set.is_empty() {
                new_aliases.insert(original_column, new_alias_set);
            }
        }
        result = new_aliases;
    }

    Ok(result)
}
332
333/// Return a mapping of original column to all the aliased columns in current node of the plan
334/// TODO(discord9): also support merge scan node
335fn get_alias_layer_from_node(node: &LogicalPlan) -> DfResult<AliasLayer> {
336    match node {
337        LogicalPlan::Projection(proj) => Ok(get_alias_layer_from_exprs(&proj.expr)),
338        LogicalPlan::Aggregate(aggr) => Ok(get_alias_layer_from_exprs(&aggr.group_expr)),
339        LogicalPlan::SubqueryAlias(subquery_alias) => {
340            let mut layer = AliasLayer::default();
341            let old_columns = subquery_alias.input.schema().columns();
342            for old_column in old_columns {
343                let new_column = Column::new(
344                    Some(subquery_alias.alias.clone()),
345                    old_column.name().to_string(),
346                );
347                // mapping from old_column to new_column
348                layer.insert_alias(old_column, [new_column].into());
349            }
350            Ok(layer)
351        }
352        LogicalPlan::TableScan(scan) => {
353            let columns = scan.projected_schema.columns();
354            let mut layer = AliasLayer::default();
355            for col in columns {
356                layer.insert_alias(col.clone(), [col.clone()].into());
357            }
358            Ok(layer)
359        }
360        _ => {
361            let input_schema = node
362                .inputs()
363                .first()
364                .ok_or_else(|| {
365                    datafusion::error::DataFusionError::Internal(format!(
366                        "only accept plan with at most one child, found: {}",
367                        node
368                    ))
369                })?
370                .schema();
371            let output_schema = node.schema();
372            // only accept at most one child plan, and if not one of the above nodes,
373            // also shouldn't modify the schema or else alias scope tracker can't support them
374            if node.inputs().len() > 1 {
375                Err(datafusion::error::DataFusionError::Internal(format!(
376                    "only accept plan with at most one child, found: {}",
377                    node
378                )))
379            } else if node.inputs().len() == 1 {
380                if input_schema != output_schema {
381                    let input_columns = input_schema.columns();
382                    let all_input_is_in_output = input_columns
383                        .iter()
384                        .all(|c| output_schema.is_column_from_schema(c));
385                    if all_input_is_in_output {
386                        // all input is in output, so it's just adding some columns, we can do identity mapping for input columns
387                        let mut layer = AliasLayer::default();
388                        for col in input_columns {
389                            layer.insert_alias(col.clone(), [col.clone()].into());
390                        }
391                        Ok(layer)
392                    } else {
393                        // otherwise use the intersection of input and output
394                        // TODO(discord9): maybe just make this case unsupported for now?
395                        common_telemetry::debug!(
396                            "Might be unsupported plan for alias tracking, track alias anyway: {}",
397                            node
398                        );
399                        let input_columns = input_schema.columns();
400                        let output_columns =
401                            output_schema.columns().into_iter().collect::<HashSet<_>>();
402                        let common_columns: HashSet<Column> = input_columns
403                            .iter()
404                            .filter(|c| output_columns.contains(c))
405                            .cloned()
406                            .collect();
407
408                        let mut layer = AliasLayer::default();
409                        for col in &common_columns {
410                            layer.insert_alias(col.clone(), [col.clone()].into());
411                        }
412                        Ok(layer)
413                    }
414                } else {
415                    // identity mapping
416                    let mut layer = AliasLayer::default();
417                    for col in output_schema.columns() {
418                        layer.insert_alias(col.clone(), [col.clone()].into());
419                    }
420                    Ok(layer)
421                }
422            } else {
423                // unknown plan with no input, error msg
424                Err(datafusion::error::DataFusionError::Internal(format!(
425                    "Unsupported plan with no input: {}",
426                    node
427                )))
428            }
429        }
430    }
431}
432
433fn get_alias_layer_from_exprs(exprs: &[Expr]) -> AliasLayer {
434    let mut alias_mapping: HashMap<Column, HashSet<Column>> = HashMap::new();
435    for expr in exprs {
436        if let Expr::Alias(alias) = expr {
437            if let Some(column) = get_alias_original_column(alias) {
438                alias_mapping
439                    .entry(column.clone())
440                    .or_default()
441                    .insert(Column::new(alias.relation.clone(), alias.name.clone()));
442            }
443        } else if let Expr::Column(column) = expr {
444            // identity mapping
445            alias_mapping
446                .entry(column.clone())
447                .or_default()
448                .insert(column.clone());
449        }
450    }
451    let mut layer = AliasLayer::default();
452    for (old_column, new_columns) in alias_mapping {
453        layer.insert_alias(old_column, new_columns);
454    }
455    layer
456}
457
#[derive(Default, Debug, Clone)]
struct AliasLayer {
    /// Maps each old (pre-alias) column to the set of columns it is known as
    /// after this plan node. Note: the key is a full `Column`, not just a
    /// field name; lookups match by name and, when both sides are qualified,
    /// by resolved relation — see `get_new_from_old` / `get_old_from_new`.
    old_to_new: BTreeMap<Column, HashSet<Column>>,
}
463
464impl AliasLayer {
465    pub fn insert_alias(&mut self, old_column: Column, new_columns: HashSet<Column>) {
466        self.old_to_new
467            .entry(old_column)
468            .or_default()
469            .extend(new_columns);
470    }
471
472    pub fn get_new_from_old(&self, old_column: Column) -> HashSet<Column> {
473        let mut res_cols = HashSet::new();
474        for (old, new_cols) in self.old_to_new.iter() {
475            if old.name() == old_column.name() {
476                match (&old.relation, &old_column.relation) {
477                    (Some(o), Some(c)) => {
478                        if o.resolved_eq(c) {
479                            res_cols.extend(new_cols.clone());
480                        }
481                    }
482                    _ => {
483                        // if any of the two relation is None, meaning not fully qualified, just match name
484                        res_cols.extend(new_cols.clone());
485                    }
486                }
487            }
488        }
489        res_cols
490    }
491
492    pub fn get_old_from_new(&self, new_column: Column) -> Option<Column> {
493        for (old, new_set) in &self.old_to_new {
494            if new_set.iter().any(|n| {
495                if n.name() != new_column.name() {
496                    return false;
497                }
498                match (&n.relation, &new_column.relation) {
499                    (Some(r1), Some(r2)) => r1.resolved_eq(r2),
500                    _ => true,
501                }
502            }) {
503                return Some(old.clone());
504            }
505        }
506        None
507    }
508}
509
510fn get_alias_original_column(alias: &Alias) -> Option<Column> {
511    let mut cur_alias = alias;
512    while let Expr::Alias(inner_alias) = cur_alias.expr.as_ref() {
513        cur_alias = inner_alias;
514    }
515    if let Expr::Column(column) = cur_alias.expr.as_ref() {
516        return Some(column.clone());
517    }
518
519    None
520}
521
/// Mapping from an original table column's name to all of its aliased
/// columns at the current plan node.
pub type AliasMapping = BTreeMap<String, BTreeSet<Column>>;
524
525#[cfg(test)]
526mod tests {
527    use std::sync::Arc;
528
529    use common_telemetry::init_default_ut_logging;
530    use datafusion::datasource::DefaultTableSource;
531    use datafusion::functions_aggregate::min_max::{max, min};
532    use datafusion_expr::{LogicalPlanBuilder, col};
533    use pretty_assertions::assert_eq;
534    use table::table::adapter::DfTableProviderAdapter;
535
536    use super::*;
537    use crate::dist_plan::analyzer::test::TestTable;
538
539    fn qcol(name: &str) -> Column {
540        Column::from_qualified_name(name)
541    }
542
543    #[test]
544    fn proj_multi_layered_alias_tracker() {
545        // use logging for better debugging
546        init_default_ut_logging();
547        let test_table = TestTable::table_with_name(0, "t".to_string());
548        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
549            DfTableProviderAdapter::new(test_table),
550        )));
551        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
552            .unwrap()
553            .project(vec![
554                col("number"),
555                col("pk3").alias("pk1"),
556                col("pk3").alias("pk2"),
557            ])
558            .unwrap()
559            .project(vec![
560                col("number"),
561                col("pk2").alias("pk4"),
562                col("pk1").alias("pk5"),
563            ])
564            .unwrap()
565            .build()
566            .unwrap();
567
568        let child = plan.inputs()[0].clone();
569
570        assert_eq!(
571            aliased_columns_for(&[qcol("pk1"), qcol("pk2")].into(), &plan, Some(&child)).unwrap(),
572            [
573                (qcol("pk1"), [qcol("pk5")].into()),
574                (qcol("pk2"), [qcol("pk4")].into())
575            ]
576            .into()
577        );
578
579        // columns not in the plan should return empty mapping
580        assert_eq!(
581            aliased_columns_for(&[qcol("pk1"), qcol("pk2")].into(), &plan, Some(&plan)).unwrap(),
582            [].into()
583        );
584
585        assert_eq!(
586            aliased_columns_for(&[qcol("t.pk3")].into(), &plan, Some(&child)).unwrap(),
587            [].into()
588        );
589
590        assert_eq!(
591            original_column_for(&[qcol("pk5"), qcol("pk4")].into(), plan.clone(), None).unwrap(),
592            [(qcol("pk5"), qcol("t.pk3")), (qcol("pk4"), qcol("t.pk3"))].into()
593        );
594
595        assert_eq!(
596            aliased_columns_for(&[qcol("pk3")].into(), &plan, None).unwrap(),
597            [(qcol("pk3"), [qcol("pk5"), qcol("pk4")].into())].into()
598        );
599        assert_eq!(
600            original_column_for(&[qcol("pk1"), qcol("pk2")].into(), child.clone(), None).unwrap(),
601            [(qcol("pk1"), qcol("t.pk3")), (qcol("pk2"), qcol("t.pk3"))].into()
602        );
603
604        assert_eq!(
605            aliased_columns_for(&[qcol("pk3")].into(), &child, None).unwrap(),
606            [(qcol("pk3"), [qcol("pk1"), qcol("pk2")].into())].into()
607        );
608
609        assert_eq!(
610            original_column_for(
611                &[qcol("pk4"), qcol("pk5")].into(),
612                plan.clone(),
613                Some(Arc::new(child.clone()))
614            )
615            .unwrap(),
616            [(qcol("pk4"), qcol("pk2")), (qcol("pk5"), qcol("pk1"))].into()
617        );
618    }
619
620    #[test]
621    fn sort_subquery_alias_layered_tracker() {
622        let test_table = TestTable::table_with_name(0, "t".to_string());
623        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
624            DfTableProviderAdapter::new(test_table),
625        )));
626
627        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
628            .unwrap()
629            .sort(vec![col("t.number").sort(true, false)])
630            .unwrap()
631            .alias("a")
632            .unwrap()
633            .build()
634            .unwrap();
635
636        let sort_plan = plan.inputs()[0].clone();
637        let scan_plan = sort_plan.inputs()[0].clone();
638
639        // Test aliased_columns_for from scan to final plan
640        assert_eq!(
641            aliased_columns_for(&[qcol("t.number")].into(), &plan, Some(&scan_plan)).unwrap(),
642            [(qcol("t.number"), [qcol("a.number")].into())].into()
643        );
644
645        // Test aliased_columns_for from sort to final plan
646        assert_eq!(
647            aliased_columns_for(&[qcol("t.number")].into(), &plan, Some(&sort_plan)).unwrap(),
648            [(qcol("t.number"), [qcol("a.number")].into())].into()
649        );
650
651        // Test aliased_columns_for from leaf to final plan
652        assert_eq!(
653            aliased_columns_for(&[qcol("t.number")].into(), &plan, None).unwrap(),
654            [(qcol("t.number"), [qcol("a.number")].into())].into()
655        );
656
657        // Test original_column_for from final plan to scan
658        assert_eq!(
659            original_column_for(
660                &[qcol("a.number")].into(),
661                plan.clone(),
662                Some(Arc::new(scan_plan.clone()))
663            )
664            .unwrap(),
665            [(qcol("a.number"), qcol("t.number"))].into()
666        );
667
668        // Test original_column_for from final plan to sort
669        assert_eq!(
670            original_column_for(
671                &[qcol("a.number")].into(),
672                plan.clone(),
673                Some(Arc::new(sort_plan.clone()))
674            )
675            .unwrap(),
676            [(qcol("a.number"), qcol("t.number"))].into()
677        );
678    }
679
680    #[test]
681    fn proj_alias_layered_tracker() {
682        // use logging for better debugging
683        init_default_ut_logging();
684        let test_table = TestTable::table_with_name(0, "t".to_string());
685        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
686            DfTableProviderAdapter::new(test_table),
687        )));
688        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
689            .unwrap()
690            .project(vec![
691                col("number"),
692                col("pk3").alias("pk1"),
693                col("pk2").alias("pk3"),
694            ])
695            .unwrap()
696            .project(vec![
697                col("number"),
698                col("pk1").alias("pk2"),
699                col("pk3").alias("pk1"),
700            ])
701            .unwrap()
702            .build()
703            .unwrap();
704
705        let first_proj = plan.inputs()[0].clone();
706        let scan_plan = first_proj.inputs()[0].clone();
707
708        // Test original_column_for from final plan to scan
709        assert_eq!(
710            original_column_for(
711                &[qcol("pk1")].into(),
712                plan.clone(),
713                Some(Arc::new(scan_plan.clone()))
714            )
715            .unwrap(),
716            [(qcol("pk1"), qcol("t.pk2"))].into()
717        );
718
719        // Test original_column_for from final plan to first projection
720        assert_eq!(
721            original_column_for(
722                &[qcol("pk1")].into(),
723                plan.clone(),
724                Some(Arc::new(first_proj.clone()))
725            )
726            .unwrap(),
727            [(qcol("pk1"), qcol("pk3"))].into()
728        );
729
730        // Test original_column_for from final plan to leaf
731        assert_eq!(
732            original_column_for(
733                &[qcol("pk1")].into(),
734                plan.clone(),
735                Some(Arc::new(plan.clone()))
736            )
737            .unwrap(),
738            [(qcol("pk1"), qcol("pk1"))].into()
739        );
740
741        // Test aliased_columns_for from scan to first projection
742        assert_eq!(
743            aliased_columns_for(&[qcol("t.pk2")].into(), &first_proj, Some(&scan_plan)).unwrap(),
744            [(qcol("t.pk2"), [qcol("pk3")].into())].into()
745        );
746
747        // Test aliased_columns_for from first projection to final plan
748        assert_eq!(
749            aliased_columns_for(&[qcol("pk3")].into(), &plan, Some(&first_proj)).unwrap(),
750            [(qcol("pk3"), [qcol("pk1")].into())].into()
751        );
752
753        // Test aliased_columns_for from scan to final plan
754        assert_eq!(
755            aliased_columns_for(&[qcol("t.pk2")].into(), &plan, Some(&scan_plan)).unwrap(),
756            [(qcol("t.pk2"), [qcol("pk1")].into())].into()
757        );
758
759        // Test aliased_columns_for from leaf to final plan
760        assert_eq!(
761            aliased_columns_for(&[qcol("pk2")].into(), &plan, None).unwrap(),
762            [(qcol("pk2"), [qcol("pk1")].into())].into()
763        );
764    }
765
766    #[test]
767    fn proj_alias_relation_layered_tracker() {
768        // use logging for better debugging
769        init_default_ut_logging();
770        let test_table = TestTable::table_with_name(0, "t".to_string());
771        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
772            DfTableProviderAdapter::new(test_table),
773        )));
774        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
775            .unwrap()
776            .project(vec![
777                col("number"),
778                col("pk3").alias_qualified(Some("b"), "pk1"),
779                col("pk2").alias_qualified(Some("a"), "pk1"),
780            ])
781            .unwrap()
782            .build()
783            .unwrap();
784
785        let scan_plan = plan.inputs()[0].clone();
786
787        // Test aliased_columns_for from scan to projection
788        assert_eq!(
789            aliased_columns_for(&[qcol("t.pk2")].into(), &plan, Some(&scan_plan)).unwrap(),
790            [(qcol("t.pk2"), [qcol("a.pk1")].into())].into()
791        );
792    }
793
    #[test]
    fn proj_alias_aliased_aggr() {
        // use logging for better debugging
        init_default_ut_logging();
        let test_table = TestTable::table_with_name(0, "t".to_string());
        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(test_table),
        )));
        // Plan shape: scan -> proj (pk1 as pk3, pk2 as pk4)
        //   -> proj (pk3 as pk42, pk4 as pk43)
        //   -> aggregate grouped by the twice-renamed columns.
        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
            .unwrap()
            .project(vec![
                col("number"),
                col("pk1").alias("pk3"),
                col("pk2").alias("pk4"),
            ])
            .unwrap()
            .project(vec![
                col("number"),
                col("pk3").alias("pk42"),
                col("pk4").alias("pk43"),
            ])
            .unwrap()
            .aggregate(vec![col("pk42"), col("pk43")], vec![min(col("number"))])
            .unwrap()
            .build()
            .unwrap();

        let aggr_plan = plan.clone();
        let second_proj = aggr_plan.inputs()[0].clone();
        let first_proj = second_proj.inputs()[0].clone();
        let scan_plan = first_proj.inputs()[0].clone();

        // Test aliased_columns_for from scan to final plan: t.pk1 survives both
        // renames (pk1 -> pk3 -> pk42) plus the aggregate's identity mapping.
        assert_eq!(
            aliased_columns_for(&[qcol("t.pk1")].into(), &plan, Some(&scan_plan)).unwrap(),
            [(qcol("t.pk1"), [qcol("pk42")].into())].into()
        );

        // Test aliased_columns_for from leaf (`None` start plan) to the first
        // projection only: at that level pk1 is aliased to pk3.
        assert_eq!(
            aliased_columns_for(&[Column::from_name("pk1")].into(), &first_proj, None).unwrap(),
            [(Column::from_name("pk1"), [qcol("pk3")].into())].into()
        );
    }
838
839    #[test]
840    fn aggr_aggr_alias() {
841        // use logging for better debugging
842        init_default_ut_logging();
843        let test_table = TestTable::table_with_name(0, "t".to_string());
844        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
845            DfTableProviderAdapter::new(test_table),
846        )));
847        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
848            .unwrap()
849            .aggregate(vec![col("pk1"), col("pk2")], vec![max(col("number"))])
850            .unwrap()
851            .aggregate(
852                vec![col("pk1"), col("pk2")],
853                vec![min(col("max(t.number)"))],
854            )
855            .unwrap()
856            .build()
857            .unwrap();
858
859        let second_aggr = plan.clone();
860        let first_aggr = second_aggr.inputs()[0].clone();
861        let scan_plan = first_aggr.inputs()[0].clone();
862
863        // Test aliased_columns_for from scan to final plan (identity mapping for aggregates)
864        assert_eq!(
865            aliased_columns_for(&[qcol("t.pk1")].into(), &plan, Some(&scan_plan)).unwrap(),
866            [(qcol("t.pk1"), [qcol("t.pk1")].into())].into()
867        );
868
869        // Test aliased_columns_for from scan to first aggregate
870        assert_eq!(
871            aliased_columns_for(&[qcol("t.pk1")].into(), &first_aggr, Some(&scan_plan)).unwrap(),
872            [(qcol("t.pk1"), [qcol("t.pk1")].into())].into()
873        );
874
875        // Test aliased_columns_for from first aggregate to final plan
876        assert_eq!(
877            aliased_columns_for(&[qcol("t.pk1")].into(), &plan, Some(&first_aggr)).unwrap(),
878            [(qcol("t.pk1"), [qcol("t.pk1")].into())].into()
879        );
880
881        // Test aliased_columns_for from leaf to final plan
882        assert_eq!(
883            aliased_columns_for(&[Column::from_name("pk1")].into(), &plan, None).unwrap(),
884            [(Column::from_name("pk1"), [qcol("t.pk1")].into())].into()
885        );
886    }
887
888    #[test]
889    fn aggr_aggr_alias_projection() {
890        // use logging for better debugging
891        init_default_ut_logging();
892        let test_table = TestTable::table_with_name(0, "t".to_string());
893        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
894            DfTableProviderAdapter::new(test_table),
895        )));
896        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
897            .unwrap()
898            .aggregate(vec![col("pk1"), col("pk2")], vec![max(col("number"))])
899            .unwrap()
900            .aggregate(
901                vec![col("pk1"), col("pk2")],
902                vec![min(col("max(t.number)"))],
903            )
904            .unwrap()
905            .project(vec![
906                col("pk1").alias("pk11"),
907                col("pk2").alias("pk22"),
908                col("min(max(t.number))").alias("min_max_number"),
909            ])
910            .unwrap()
911            .build()
912            .unwrap();
913
914        let proj_plan = plan.clone();
915        let second_aggr = proj_plan.inputs()[0].clone();
916
917        // Test original_column_for from projection to second aggregate for aggr gen column
918        assert_eq!(
919            original_column_for(
920                &[Column::from_name("min_max_number")].into(),
921                plan.clone(),
922                Some(Arc::new(second_aggr.clone()))
923            )
924            .unwrap(),
925            [(
926                Column::from_name("min_max_number"),
927                Column::from_name("min(max(t.number))")
928            )]
929            .into()
930        );
931
932        // Test aliased_columns_for from second aggregate to projection
933        assert_eq!(
934            aliased_columns_for(
935                &[Column::from_name("min(max(t.number))")].into(),
936                &plan,
937                Some(&second_aggr)
938            )
939            .unwrap(),
940            [(
941                Column::from_name("min(max(t.number))"),
942                [Column::from_name("min_max_number")].into()
943            )]
944            .into()
945        );
946    }
947}