query/dist_plan/analyzer/
utils.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{HashMap, HashSet};
16
17use datafusion::datasource::DefaultTableSource;
18use datafusion_common::Column;
19use datafusion_expr::{Expr, LogicalPlan, TableScan};
20use table::metadata::TableType;
21use table::table::adapter::DfTableProviderAdapter;
22
23/// Mapping of original column in table to all the alias at current node
24pub type AliasMapping = HashMap<String, HashSet<Column>>;
25
26/// tracking aliases for the source table columns in the plan
27#[derive(Debug, Clone)]
28pub struct AliasTracker {
29    /// mapping from the original table name to the alias used in the plan
30    /// notice how one column might have multiple aliases in the plan
31    ///
32    pub mapping: AliasMapping,
33}
34
35impl AliasTracker {
36    pub fn new(table_scan: &TableScan) -> Option<Self> {
37        if let Some(source) = table_scan
38            .source
39            .as_any()
40            .downcast_ref::<DefaultTableSource>()
41        {
42            if let Some(provider) = source
43                .table_provider
44                .as_any()
45                .downcast_ref::<DfTableProviderAdapter>()
46            {
47                if provider.table().table_type() == TableType::Base {
48                    let info = provider.table().table_info();
49                    let schema = info.meta.schema.clone();
50                    let col_schema = schema.column_schemas();
51                    let mapping = col_schema
52                        .iter()
53                        .map(|col| {
54                            (
55                                col.name.clone(),
56                                HashSet::from_iter(std::iter::once(Column::new_unqualified(
57                                    col.name.clone(),
58                                ))),
59                            )
60                        })
61                        .collect();
62                    return Some(Self { mapping });
63                }
64            }
65        }
66
67        None
68    }
69
70    /// update alias for original columns
71    ///
72    /// only handle `Alias` with column in `Projection` node
73    pub fn update_alias(&mut self, node: &LogicalPlan) {
74        if let LogicalPlan::Projection(projection) = node {
75            // first collect all the alias mapping, i.e. the col_a AS b AS c AS d become `a->d`
76            // notice one column might have multiple aliases
77            let mut alias_mapping: AliasMapping = HashMap::new();
78            for expr in &projection.expr {
79                if let Expr::Alias(alias) = expr {
80                    let outer_alias = alias.clone();
81                    let mut cur_alias = alias.clone();
82                    while let Expr::Alias(alias) = *cur_alias.expr {
83                        cur_alias = alias;
84                    }
85                    if let Expr::Column(column) = *cur_alias.expr {
86                        alias_mapping
87                            .entry(column.name.clone())
88                            .or_default()
89                            .insert(Column::new(outer_alias.relation, outer_alias.name));
90                    }
91                } else if let Expr::Column(column) = expr {
92                    // identity mapping
93                    alias_mapping
94                        .entry(column.name.clone())
95                        .or_default()
96                        .insert(column.clone());
97                }
98            }
99
100            // update mapping using `alias_mapping`
101            let mut new_mapping = HashMap::new();
102            for (table_col_name, cur_columns) in std::mem::take(&mut self.mapping) {
103                let new_aliases = {
104                    let mut new_aliases = HashSet::new();
105                    for cur_column in &cur_columns {
106                        let new_alias_for_cur_column = alias_mapping
107                            .get(cur_column.name())
108                            .cloned()
109                            .unwrap_or_default();
110
111                        for new_alias in new_alias_for_cur_column {
112                            let is_table_ref_eq = match (&new_alias.relation, &cur_column.relation)
113                            {
114                                (Some(o), Some(c)) => o.resolved_eq(c),
115                                _ => true,
116                            };
117                            // is the same column if both name and table ref is eq
118                            if is_table_ref_eq {
119                                new_aliases.insert(new_alias.clone());
120                            }
121                        }
122                    }
123                    new_aliases
124                };
125
126                new_mapping.insert(table_col_name, new_aliases);
127            }
128
129            self.mapping = new_mapping;
130            common_telemetry::debug!(
131                "Updating alias tracker to {:?} using node: \n{node}",
132                self.mapping
133            );
134        }
135    }
136
137    pub fn get_all_alias_for_col(&self, col_name: &str) -> Option<&HashSet<Column>> {
138        self.mapping.get(col_name)
139    }
140
141    #[allow(unused)]
142    pub fn is_alias_for(&self, original_col: &str, cur_col: &Column) -> bool {
143        self.mapping
144            .get(original_col)
145            .map(|cols| cols.contains(cur_col))
146            .unwrap_or(false)
147    }
148}
149
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_telemetry::init_default_ut_logging;
    use datafusion::error::Result as DfResult;
    use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
    use datafusion_expr::{col, LogicalPlanBuilder};

    use super::*;
    use crate::dist_plan::analyzer::test::TestTable;

    /// Visitor that drives an [`AliasTracker`] bottom-up over a plan and records the tracked
    /// mapping after each node, starting from the first table scan encountered.
    #[derive(Debug)]
    struct TrackerTester {
        alias_tracker: Option<AliasTracker>,
        mapping_at_each_level: Vec<AliasMapping>,
    }

    impl TrackerTester {
        /// Push a snapshot of the current mapping (empty if no tracker could be created).
        fn record_mapping(&mut self) {
            let mapping = self
                .alias_tracker
                .as_ref()
                .map(|tracker| tracker.mapping.clone())
                .unwrap_or_default();
            self.mapping_at_each_level.push(mapping);
        }
    }

    impl TreeNodeVisitor<'_> for TrackerTester {
        type Node = LogicalPlan;

        fn f_up(&mut self, node: &LogicalPlan) -> DfResult<TreeNodeRecursion> {
            if let Some(alias_tracker) = &mut self.alias_tracker {
                alias_tracker.update_alias(node);
                self.record_mapping();
            } else if let LogicalPlan::TableScan(table_scan) = node {
                self.alias_tracker = AliasTracker::new(table_scan);
                self.record_mapping();
            }
            Ok(TreeNodeRecursion::Continue)
        }
    }

    /// Start a plan that scans a fresh test table registered under the name `t`.
    fn scan_test_table() -> LogicalPlanBuilder {
        let test_table = TestTable::table_with_name(0, "numbers".to_string());
        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(test_table),
        )));
        LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![]).unwrap()
    }

    /// Visit `plan` and return the tracked mapping recorded at each level.
    fn mappings_at_each_level(plan: &LogicalPlan) -> Vec<AliasMapping> {
        let mut tracker_tester = TrackerTester {
            alias_tracker: None,
            mapping_at_each_level: Vec::new(),
        };
        plan.visit(&mut tracker_tester).unwrap();
        tracker_tester.mapping_at_each_level
    }

    #[test]
    fn proj_alias_tracker() {
        // use logging for better debugging
        init_default_ut_logging();
        let plan = scan_test_table()
            .project(vec![
                col("number"),
                col("pk3").alias("pk1"),
                col("pk2").alias("pk3"),
            ])
            .unwrap()
            .project(vec![
                col("number"),
                col("pk1").alias("pk2"),
                col("pk3").alias("pk1"),
            ])
            .unwrap()
            .build()
            .unwrap();

        assert_eq!(
            mappings_at_each_level(&plan),
            vec![
                HashMap::from([
                    ("number".to_string(), HashSet::from(["number".into()])),
                    ("pk1".to_string(), HashSet::from(["pk1".into()])),
                    ("pk2".to_string(), HashSet::from(["pk2".into()])),
                    ("pk3".to_string(), HashSet::from(["pk3".into()])),
                    ("ts".to_string(), HashSet::from(["ts".into()]))
                ]),
                HashMap::from([
                    ("number".to_string(), HashSet::from(["t.number".into()])),
                    ("pk1".to_string(), HashSet::from([])),
                    ("pk2".to_string(), HashSet::from(["pk3".into()])),
                    ("pk3".to_string(), HashSet::from(["pk1".into()])),
                    ("ts".to_string(), HashSet::from([]))
                ]),
                HashMap::from([
                    ("number".to_string(), HashSet::from(["t.number".into()])),
                    ("pk1".to_string(), HashSet::from([])),
                    ("pk2".to_string(), HashSet::from(["pk1".into()])),
                    ("pk3".to_string(), HashSet::from(["pk2".into()])),
                    ("ts".to_string(), HashSet::from([]))
                ])
            ]
        );
    }

    #[test]
    fn proj_multi_alias_tracker() {
        // use logging for better debugging
        init_default_ut_logging();
        let plan = scan_test_table()
            .project(vec![
                col("number"),
                col("pk3").alias("pk1"),
                col("pk3").alias("pk2"),
            ])
            .unwrap()
            .project(vec![
                col("number"),
                col("pk2").alias("pk4"),
                col("pk1").alias("pk5"),
            ])
            .unwrap()
            .build()
            .unwrap();

        assert_eq!(
            mappings_at_each_level(&plan),
            vec![
                HashMap::from([
                    ("number".to_string(), HashSet::from(["number".into()])),
                    ("pk1".to_string(), HashSet::from(["pk1".into()])),
                    ("pk2".to_string(), HashSet::from(["pk2".into()])),
                    ("pk3".to_string(), HashSet::from(["pk3".into()])),
                    ("ts".to_string(), HashSet::from(["ts".into()]))
                ]),
                HashMap::from([
                    ("number".to_string(), HashSet::from(["t.number".into()])),
                    ("pk1".to_string(), HashSet::from([])),
                    ("pk2".to_string(), HashSet::from([])),
                    (
                        "pk3".to_string(),
                        HashSet::from(["pk1".into(), "pk2".into()])
                    ),
                    ("ts".to_string(), HashSet::from([]))
                ]),
                HashMap::from([
                    ("number".to_string(), HashSet::from(["t.number".into()])),
                    ("pk1".to_string(), HashSet::from([])),
                    ("pk2".to_string(), HashSet::from([])),
                    (
                        "pk3".to_string(),
                        HashSet::from(["pk4".into(), "pk5".into()])
                    ),
                    ("ts".to_string(), HashSet::from([]))
                ])
            ]
        );
    }
}