query/optimizer/
count_wildcard.rs1use datafusion::datasource::DefaultTableSource;
16use datafusion_common::tree_node::{
17 Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeVisitor,
18};
19use datafusion_common::{Column, Result as DataFusionResult, ScalarValue};
20use datafusion_expr::expr::{AggregateFunction, WindowFunction};
21use datafusion_expr::utils::COUNT_STAR_EXPANSION;
22use datafusion_expr::{col, lit, Expr, LogicalPlan, WindowFunctionDefinition};
23use datafusion_optimizer::utils::NamePreserver;
24use datafusion_optimizer::AnalyzerRule;
25use datafusion_sql::TableReference;
26use table::table::adapter::DfTableProviderAdapter;
27
/// DataFusion analyzer rule that rewrites `count(*)`-style calls (a `count` with no
/// argument, a wildcard, or the literal `1`) in aggregate and window expressions into
/// `count(<time index column>)` when a usable time index column is visible to the plan
/// node, and into a `count` over the `COUNT_STAR_EXPANSION` literal otherwise.
#[derive(Debug)]
pub struct CountWildcardToTimeIndexRule;
35
36impl AnalyzerRule for CountWildcardToTimeIndexRule {
37 fn name(&self) -> &str {
38 "count_wildcard_to_time_index_rule"
39 }
40
41 fn analyze(
42 &self,
43 plan: LogicalPlan,
44 _config: &datafusion::config::ConfigOptions,
45 ) -> DataFusionResult<LogicalPlan> {
46 plan.transform_down_with_subqueries(&Self::analyze_internal)
47 .data()
48 }
49}
50
51impl CountWildcardToTimeIndexRule {
52 fn analyze_internal(plan: LogicalPlan) -> DataFusionResult<Transformed<LogicalPlan>> {
53 let name_preserver = NamePreserver::new(&plan);
54 let new_arg = if let Some(time_index) = Self::try_find_time_index_col(&plan) {
55 vec![col(time_index)]
56 } else {
57 vec![lit(COUNT_STAR_EXPANSION)]
58 };
59 plan.map_expressions(|expr| {
60 let original_name = name_preserver.save(&expr);
61 let transformed_expr = expr.transform_up(|expr| match expr {
62 Expr::WindowFunction(mut window_function)
63 if Self::is_count_star_window_aggregate(&window_function) =>
64 {
65 window_function.params.args.clone_from(&new_arg);
66 Ok(Transformed::yes(Expr::WindowFunction(window_function)))
67 }
68 Expr::AggregateFunction(mut aggregate_function)
69 if Self::is_count_star_aggregate(&aggregate_function) =>
70 {
71 aggregate_function.params.args.clone_from(&new_arg);
72 Ok(Transformed::yes(Expr::AggregateFunction(
73 aggregate_function,
74 )))
75 }
76 _ => Ok(Transformed::no(expr)),
77 })?;
78 Ok(transformed_expr.update_data(|data| original_name.restore(data)))
79 })
80 }
81
82 fn try_find_time_index_col(plan: &LogicalPlan) -> Option<Column> {
83 let mut finder = TimeIndexFinder::default();
84 plan.visit(&mut finder).unwrap();
86 let col = finder.into_column();
87
88 if let Some(col) = &col {
90 let mut is_valid = false;
91 for input in plan.inputs() {
92 if input.schema().has_column(col) {
93 is_valid = true;
94 break;
95 }
96 }
97 if !is_valid {
98 return None;
99 }
100 }
101
102 col
103 }
104}
105
106impl CountWildcardToTimeIndexRule {
108 #[expect(deprecated)]
109 fn args_at_most_wildcard_or_literal_one(args: &[Expr]) -> bool {
110 match args {
111 [] => true,
112 [Expr::Literal(ScalarValue::Int64(Some(v)), _)] => *v == 1,
113 [Expr::Wildcard { .. }] => true,
114 _ => false,
115 }
116 }
117
118 fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool {
119 let args = &aggregate_function.params.args;
120 matches!(aggregate_function,
121 AggregateFunction {
122 func,
123 ..
124 } if func.name() == "count" && Self::args_at_most_wildcard_or_literal_one(args))
125 }
126
127 fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool {
128 let args = &window_function.params.args;
129 matches!(window_function.fun,
130 WindowFunctionDefinition::AggregateUDF(ref udaf)
131 if udaf.name() == "count" && Self::args_at_most_wildcard_or_literal_one(args))
132 }
133}
134
135#[derive(Default)]
136struct TimeIndexFinder {
137 time_index_col: Option<String>,
138 table_alias: Option<TableReference>,
139}
140
141impl TreeNodeVisitor<'_> for TimeIndexFinder {
142 type Node = LogicalPlan;
143
144 fn f_down(&mut self, node: &Self::Node) -> DataFusionResult<TreeNodeRecursion> {
145 if let LogicalPlan::SubqueryAlias(subquery_alias) = node {
146 self.table_alias = Some(subquery_alias.alias.clone());
147 }
148
149 if let LogicalPlan::TableScan(table_scan) = &node {
150 if let Some(source) = table_scan
151 .source
152 .as_any()
153 .downcast_ref::<DefaultTableSource>()
154 {
155 if let Some(adapter) = source
156 .table_provider
157 .as_any()
158 .downcast_ref::<DfTableProviderAdapter>()
159 {
160 let table_info = adapter.table().table_info();
161 self.table_alias
162 .get_or_insert(TableReference::bare(table_info.name.clone()));
163 self.time_index_col = table_info
164 .meta
165 .schema
166 .timestamp_column()
167 .map(|c| c.name.clone());
168
169 return Ok(TreeNodeRecursion::Stop);
170 }
171 }
172 }
173
174 Ok(TreeNodeRecursion::Continue)
175 }
176
177 fn f_up(&mut self, _node: &Self::Node) -> DataFusionResult<TreeNodeRecursion> {
178 Ok(TreeNodeRecursion::Stop)
179 }
180}
181
182impl TimeIndexFinder {
183 fn into_column(self) -> Option<Column> {
184 self.time_index_col
185 .map(|c| Column::new(self.table_alias, c))
186 }
187}
188
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use datafusion::functions_aggregate::count::count_all;
    use datafusion_expr::LogicalPlanBuilder;
    use table::table::numbers::NumbersTable;

    use super::*;

    /// A quoted mixed-case subquery alias must reach the finder with its case intact.
    /// The numbers table has no time index column.
    #[test]
    fn uppercase_table_name() {
        let table = NumbersTable::table_with_name(0, "AbCdE".to_string());
        let provider = Arc::new(DfTableProviderAdapter::new(table));
        let source = Arc::new(DefaultTableSource::new(provider));

        let builder = LogicalPlanBuilder::scan_with_filters("t", source, None, vec![]).unwrap();
        let builder = builder
            .aggregate(Vec::<Expr>::new(), vec![count_all()])
            .unwrap();
        let plan = builder.alias(r#""FgHiJ""#).unwrap().build().unwrap();

        let mut visitor = TimeIndexFinder::default();
        plan.visit(&mut visitor).unwrap();

        assert!(visitor.time_index_col.is_none());
        assert_eq!(visitor.table_alias, Some(TableReference::bare("FgHiJ")));
    }
}