query/dist_plan/analyzer/
utils.rs1use std::collections::{HashMap, HashSet};
16
17use datafusion::datasource::DefaultTableSource;
18use datafusion_common::Column;
19use datafusion_expr::{Expr, LogicalPlan, TableScan};
20use table::metadata::TableType;
21use table::table::adapter::DfTableProviderAdapter;
22
23pub type AliasMapping = HashMap<String, HashSet<Column>>;
25
26#[derive(Debug, Clone)]
28pub struct AliasTracker {
29 pub mapping: AliasMapping,
33}
34
35impl AliasTracker {
36 pub fn new(table_scan: &TableScan) -> Option<Self> {
37 if let Some(source) = table_scan
38 .source
39 .as_any()
40 .downcast_ref::<DefaultTableSource>()
41 {
42 if let Some(provider) = source
43 .table_provider
44 .as_any()
45 .downcast_ref::<DfTableProviderAdapter>()
46 {
47 if provider.table().table_type() == TableType::Base {
48 let info = provider.table().table_info();
49 let schema = info.meta.schema.clone();
50 let col_schema = schema.column_schemas();
51 let mapping = col_schema
52 .iter()
53 .map(|col| {
54 (
55 col.name.clone(),
56 HashSet::from_iter(std::iter::once(Column::new_unqualified(
57 col.name.clone(),
58 ))),
59 )
60 })
61 .collect();
62 return Some(Self { mapping });
63 }
64 }
65 }
66
67 None
68 }
69
70 pub fn update_alias(&mut self, node: &LogicalPlan) {
74 if let LogicalPlan::Projection(projection) = node {
75 let mut alias_mapping: AliasMapping = HashMap::new();
78 for expr in &projection.expr {
79 if let Expr::Alias(alias) = expr {
80 let outer_alias = alias.clone();
81 let mut cur_alias = alias.clone();
82 while let Expr::Alias(alias) = *cur_alias.expr {
83 cur_alias = alias;
84 }
85 if let Expr::Column(column) = *cur_alias.expr {
86 alias_mapping
87 .entry(column.name.clone())
88 .or_default()
89 .insert(Column::new(outer_alias.relation, outer_alias.name));
90 }
91 } else if let Expr::Column(column) = expr {
92 alias_mapping
94 .entry(column.name.clone())
95 .or_default()
96 .insert(column.clone());
97 }
98 }
99
100 let mut new_mapping = HashMap::new();
102 for (table_col_name, cur_columns) in std::mem::take(&mut self.mapping) {
103 let new_aliases = {
104 let mut new_aliases = HashSet::new();
105 for cur_column in &cur_columns {
106 let new_alias_for_cur_column = alias_mapping
107 .get(cur_column.name())
108 .cloned()
109 .unwrap_or_default();
110
111 for new_alias in new_alias_for_cur_column {
112 let is_table_ref_eq = match (&new_alias.relation, &cur_column.relation)
113 {
114 (Some(o), Some(c)) => o.resolved_eq(c),
115 _ => true,
116 };
117 if is_table_ref_eq {
119 new_aliases.insert(new_alias.clone());
120 }
121 }
122 }
123 new_aliases
124 };
125
126 new_mapping.insert(table_col_name, new_aliases);
127 }
128
129 self.mapping = new_mapping;
130 common_telemetry::debug!(
131 "Updating alias tracker to {:?} using node: \n{node}",
132 self.mapping
133 );
134 }
135 }
136
137 pub fn get_all_alias_for_col(&self, col_name: &str) -> Option<&HashSet<Column>> {
138 self.mapping.get(col_name)
139 }
140
141 #[allow(unused)]
142 pub fn is_alias_for(&self, original_col: &str, cur_col: &Column) -> bool {
143 self.mapping
144 .get(original_col)
145 .map(|cols| cols.contains(cur_col))
146 .unwrap_or(false)
147 }
148}
149
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_telemetry::init_default_ut_logging;
    use datafusion::error::Result as DfResult;
    use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
    use datafusion_expr::{col, LogicalPlanBuilder};

    use super::*;
    use crate::dist_plan::analyzer::test::TestTable;

    /// Visitor that drives an [`AliasTracker`] bottom-up over a plan and records a snapshot of
    /// its mapping after every node, starting from the first table scan it meets.
    #[derive(Debug)]
    struct TrackerTester {
        alias_tracker: Option<AliasTracker>,
        mapping_at_each_level: Vec<AliasMapping>,
    }

    impl TreeNodeVisitor<'_> for TrackerTester {
        type Node = LogicalPlan;

        fn f_up(&mut self, node: &LogicalPlan) -> DfResult<TreeNodeRecursion> {
            if let Some(alias_tracker) = &mut self.alias_tracker {
                alias_tracker.update_alias(node);
            } else if let LogicalPlan::TableScan(table_scan) = node {
                self.alias_tracker = AliasTracker::new(table_scan);
            } else {
                // Nothing is tracked until the first table scan has been seen.
                return Ok(TreeNodeRecursion::Continue);
            }

            // A scan that cannot be tracked still records an (empty) snapshot.
            self.mapping_at_each_level.push(
                self.alias_tracker
                    .as_ref()
                    .map(|tracker| tracker.mapping.clone())
                    .unwrap_or_default(),
            );
            Ok(TreeNodeRecursion::Continue)
        }
    }

    /// Builds a scan named `t` over the `numbers` test table.
    fn scan_numbers() -> LogicalPlanBuilder {
        let test_table = TestTable::table_with_name(0, "numbers".to_string());
        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(test_table),
        )));
        LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![]).unwrap()
    }

    /// Visits `plan` with a fresh [`TrackerTester`] and returns it.
    fn track_plan(plan: &LogicalPlan) -> TrackerTester {
        let mut tracker_tester = TrackerTester {
            alias_tracker: None,
            mapping_at_each_level: Vec::new(),
        };
        plan.visit(&mut tracker_tester).unwrap();
        tracker_tester
    }

    #[test]
    fn proj_alias_tracker() {
        init_default_ut_logging();
        let plan = scan_numbers()
            .project(vec![
                col("number"),
                col("pk3").alias("pk1"),
                col("pk2").alias("pk3"),
            ])
            .unwrap()
            .project(vec![
                col("number"),
                col("pk1").alias("pk2"),
                col("pk3").alias("pk1"),
            ])
            .unwrap()
            .build()
            .unwrap();

        let tracker_tester = track_plan(&plan);

        assert_eq!(
            tracker_tester.mapping_at_each_level,
            vec![
                HashMap::from([
                    ("number".to_string(), HashSet::from(["number".into()])),
                    ("pk1".to_string(), HashSet::from(["pk1".into()])),
                    ("pk2".to_string(), HashSet::from(["pk2".into()])),
                    ("pk3".to_string(), HashSet::from(["pk3".into()])),
                    ("ts".to_string(), HashSet::from(["ts".into()]))
                ]),
                HashMap::from([
                    ("number".to_string(), HashSet::from(["t.number".into()])),
                    ("pk1".to_string(), HashSet::from([])),
                    ("pk2".to_string(), HashSet::from(["pk3".into()])),
                    ("pk3".to_string(), HashSet::from(["pk1".into()])),
                    ("ts".to_string(), HashSet::from([]))
                ]),
                HashMap::from([
                    ("number".to_string(), HashSet::from(["t.number".into()])),
                    ("pk1".to_string(), HashSet::from([])),
                    ("pk2".to_string(), HashSet::from(["pk1".into()])),
                    ("pk3".to_string(), HashSet::from(["pk2".into()])),
                    ("ts".to_string(), HashSet::from([]))
                ])
            ]
        );
    }

    #[test]
    fn proj_multi_alias_tracker() {
        init_default_ut_logging();
        let plan = scan_numbers()
            .project(vec![
                col("number"),
                col("pk3").alias("pk1"),
                col("pk3").alias("pk2"),
            ])
            .unwrap()
            .project(vec![
                col("number"),
                col("pk2").alias("pk4"),
                col("pk1").alias("pk5"),
            ])
            .unwrap()
            .build()
            .unwrap();

        let tracker_tester = track_plan(&plan);

        assert_eq!(
            tracker_tester.mapping_at_each_level,
            vec![
                HashMap::from([
                    ("number".to_string(), HashSet::from(["number".into()])),
                    ("pk1".to_string(), HashSet::from(["pk1".into()])),
                    ("pk2".to_string(), HashSet::from(["pk2".into()])),
                    ("pk3".to_string(), HashSet::from(["pk3".into()])),
                    ("ts".to_string(), HashSet::from(["ts".into()]))
                ]),
                HashMap::from([
                    ("number".to_string(), HashSet::from(["t.number".into()])),
                    ("pk1".to_string(), HashSet::from([])),
                    ("pk2".to_string(), HashSet::from([])),
                    (
                        "pk3".to_string(),
                        HashSet::from(["pk1".into(), "pk2".into()])
                    ),
                    ("ts".to_string(), HashSet::from([]))
                ]),
                HashMap::from([
                    ("number".to_string(), HashSet::from(["t.number".into()])),
                    ("pk1".to_string(), HashSet::from([])),
                    ("pk2".to_string(), HashSet::from([])),
                    (
                        "pk3".to_string(),
                        HashSet::from(["pk4".into(), "pk5".into()])
                    ),
                    ("ts".to_string(), HashSet::from([]))
                ])
            ]
        );
    }

    #[test]
    fn alias_lookup() {
        init_default_ut_logging();
        let plan = scan_numbers()
            .project(vec![col("number"), col("pk3").alias("pk1")])
            .unwrap()
            .build()
            .unwrap();

        let tracker = track_plan(&plan)
            .alias_tracker
            .expect("the test table is a base table and must be tracked");

        let pk1: Column = "pk1".into();
        assert!(tracker.is_alias_for("pk3", &pk1));
        assert!(!tracker.is_alias_for("pk1", &pk1));
        assert!(!tracker.is_alias_for("missing", &pk1));
        assert_eq!(
            tracker.get_all_alias_for_col("pk3"),
            Some(&HashSet::from([pk1.clone()]))
        );
        assert_eq!(tracker.get_all_alias_for_col("ts"), Some(&HashSet::new()));
        assert_eq!(tracker.get_all_alias_for_col("missing"), None);
    }
}