1use datafusion::datasource::DefaultTableSource;
16use datafusion_common::tree_node::{
17 Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeVisitor,
18};
19use datafusion_common::{Column, Result as DataFusionResult, ScalarValue};
20use datafusion_expr::expr::{AggregateFunction, WindowFunction};
21use datafusion_expr::utils::COUNT_STAR_EXPANSION;
22use datafusion_expr::{Expr, LogicalPlan, WindowFunctionDefinition, col, lit};
23use datafusion_optimizer::AnalyzerRule;
24use datafusion_optimizer::utils::NamePreserver;
25use datafusion_sql::TableReference;
26use table::table::adapter::DfTableProviderAdapter;
27
/// Analyzer rule that rewrites `count(*)` / `count(1)` aggregate and window
/// calls to count over the table's time-index column when one can be resolved
/// from the plan's table scan; otherwise the argument falls back to the
/// standard [`COUNT_STAR_EXPANSION`] literal.
#[derive(Debug)]
pub struct CountWildcardToTimeIndexRule;
35
36impl AnalyzerRule for CountWildcardToTimeIndexRule {
37 fn name(&self) -> &str {
38 "count_wildcard_to_time_index_rule"
39 }
40
41 fn analyze(
42 &self,
43 plan: LogicalPlan,
44 _config: &datafusion::config::ConfigOptions,
45 ) -> DataFusionResult<LogicalPlan> {
46 plan.transform_down_with_subqueries(&Self::analyze_internal)
47 .data()
48 }
49}
50
51impl CountWildcardToTimeIndexRule {
52 fn analyze_internal(plan: LogicalPlan) -> DataFusionResult<Transformed<LogicalPlan>> {
53 let name_preserver = NamePreserver::new(&plan);
54 let new_arg = if let Some(time_index) = Self::try_find_time_index_col(&plan) {
55 vec![col(time_index)]
56 } else {
57 vec![lit(COUNT_STAR_EXPANSION)]
58 };
59 plan.map_expressions(|expr| {
60 let original_name = name_preserver.save(&expr);
61 let transformed_expr = expr.transform_up(|expr| match expr {
62 Expr::WindowFunction(mut window_function)
63 if Self::is_count_star_window_aggregate(&window_function) =>
64 {
65 window_function.params.args.clone_from(&new_arg);
66 Ok(Transformed::yes(Expr::WindowFunction(window_function)))
67 }
68 Expr::AggregateFunction(mut aggregate_function)
69 if Self::is_count_star_aggregate(&aggregate_function) =>
70 {
71 aggregate_function.params.args.clone_from(&new_arg);
72 Ok(Transformed::yes(Expr::AggregateFunction(
73 aggregate_function,
74 )))
75 }
76 _ => Ok(Transformed::no(expr)),
77 })?;
78 Ok(transformed_expr.update_data(|data| original_name.restore(data)))
79 })
80 }
81
82 fn try_find_time_index_col(plan: &LogicalPlan) -> Option<Column> {
83 let mut finder = TimeIndexFinder::default();
84 plan.visit(&mut finder).unwrap();
86 let col = finder.into_column();
87
88 if let Some(col) = &col {
90 let mut is_valid = false;
91 if plan.inputs().len() > 1 {
93 return None;
94 }
95 for input in plan.inputs() {
96 if input.schema().has_column(col) {
97 is_valid = true;
98 break;
99 }
100 }
101 if !is_valid {
102 return None;
103 }
104 }
105
106 col
107 }
108}
109
110impl CountWildcardToTimeIndexRule {
112 #[expect(deprecated)]
113 fn args_at_most_wildcard_or_literal_one(args: &[Expr]) -> bool {
114 match args {
115 [] => true,
116 [Expr::Literal(ScalarValue::Int64(Some(v)), _)] => *v == 1,
117 [Expr::Wildcard { .. }] => true,
118 _ => false,
119 }
120 }
121
122 fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool {
123 let args = &aggregate_function.params.args;
124 matches!(aggregate_function,
125 AggregateFunction {
126 func,
127 ..
128 } if func.name() == "count" && Self::args_at_most_wildcard_or_literal_one(args))
129 }
130
131 fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool {
132 let args = &window_function.params.args;
133 matches!(window_function.fun,
134 WindowFunctionDefinition::AggregateUDF(ref udaf)
135 if udaf.name() == "count" && Self::args_at_most_wildcard_or_literal_one(args))
136 }
137}
138
/// Plan visitor that records the time-index column of the first table scan
/// backed by a [`DfTableProviderAdapter`], along with the table reference
/// the column should be qualified with.
#[derive(Default)]
struct TimeIndexFinder {
    /// Name of the time-index column, if a suitable table scan was found.
    time_index_col: Option<String>,
    /// Qualifier for the column: the subquery alias seen on the way down if
    /// any, otherwise the scanned table's own name.
    table_alias: Option<TableReference>,
}
144
145impl TreeNodeVisitor<'_> for TimeIndexFinder {
146 type Node = LogicalPlan;
147
148 fn f_down(&mut self, node: &Self::Node) -> DataFusionResult<TreeNodeRecursion> {
149 if let LogicalPlan::SubqueryAlias(subquery_alias) = node {
150 self.table_alias = Some(subquery_alias.alias.clone());
151 }
152
153 if let LogicalPlan::TableScan(table_scan) = &node
154 && let Some(source) = table_scan
155 .source
156 .as_any()
157 .downcast_ref::<DefaultTableSource>()
158 && let Some(adapter) = source
159 .table_provider
160 .as_any()
161 .downcast_ref::<DfTableProviderAdapter>()
162 {
163 let table_info = adapter.table().table_info();
164 self.table_alias
165 .get_or_insert(table_scan.table_name.clone());
166 self.time_index_col = table_info
167 .meta
168 .schema
169 .timestamp_column()
170 .map(|c| c.name.clone());
171
172 return Ok(TreeNodeRecursion::Stop);
173 }
174
175 if node.inputs().len() > 1 {
176 return Ok(TreeNodeRecursion::Stop);
178 }
179
180 Ok(TreeNodeRecursion::Continue)
181 }
182
183 fn f_up(&mut self, _node: &Self::Node) -> DataFusionResult<TreeNodeRecursion> {
184 Ok(TreeNodeRecursion::Stop)
185 }
186}
187
188impl TimeIndexFinder {
189 fn into_column(self) -> Option<Column> {
190 self.time_index_col
191 .map(|c| Column::new(self.table_alias, c))
192 }
193}
194
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use common_catalog::consts::DEFAULT_CATALOG_NAME;
    use common_error::ext::{BoxedError, ErrorExt, StackError};
    use common_error::status_code::StatusCode;
    use common_recordbatch::SendableRecordBatchStream;
    use datafusion::functions_aggregate::count::count_all;
    use datafusion_common::Column;
    use datafusion_expr::LogicalPlanBuilder;
    use datafusion_sql::TableReference;
    use datatypes::data_type::ConcreteDataType;
    use datatypes::schema::{ColumnSchema, SchemaBuilder};
    use store_api::data_source::DataSource;
    use store_api::storage::ScanRequest;
    use table::metadata::{FilterPushDownType, TableInfoBuilder, TableMetaBuilder, TableType};
    use table::table::numbers::NumbersTable;
    use table::{Table, TableRef};

    use super::*;

    /// A quoted, mixed-case subquery alias is captured verbatim by the
    /// finder; the numbers table has no time index, so none is recorded.
    #[test]
    fn uppercase_table_name() {
        let numbers_table = NumbersTable::table_with_name(0, "AbCdE".to_string());
        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(numbers_table),
        )));

        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
            .unwrap()
            .aggregate(Vec::<Expr>::new(), vec![count_all()])
            .unwrap()
            .alias(r#""FgHiJ""#)
            .unwrap()
            .build()
            .unwrap();

        let mut finder = TimeIndexFinder::default();
        plan.visit(&mut finder).unwrap();

        // Quotes are stripped: the recorded alias is the bare identifier.
        assert_eq!(finder.table_alias, Some(TableReference::bare("FgHiJ")));
        assert!(finder.time_index_col.is_none());
    }

    /// With a bare (unqualified) table reference, the time-index column is
    /// found and qualified with that bare reference.
    #[test]
    fn bare_table_name_time_index() {
        let table_ref = TableReference::bare("multi_partitioned_test_1");
        let table =
            build_time_index_table("multi_partitioned_test_1", "public", DEFAULT_CATALOG_NAME);
        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(table),
        )));

        let plan =
            LogicalPlanBuilder::scan_with_filters(table_ref.clone(), table_source, None, vec![])
                .unwrap()
                .aggregate(Vec::<Expr>::new(), vec![count_all()])
                .unwrap()
                .build()
                .unwrap();

        let time_index = CountWildcardToTimeIndexRule::try_find_time_index_col(&plan);
        assert_eq!(
            time_index,
            Some(Column::new(Some(table_ref), "greptime_timestamp"))
        );
    }

    /// A schema-qualified (`schema.table`) reference is preserved as the
    /// qualifier of the resolved time-index column.
    #[test]
    fn schema_qualified_table_name_time_index() {
        let table_ref = TableReference::partial("telemetry_events", "multi_partitioned_test_1");
        let table = build_time_index_table(
            "multi_partitioned_test_1",
            "telemetry_events",
            DEFAULT_CATALOG_NAME,
        );
        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(table),
        )));

        let plan =
            LogicalPlanBuilder::scan_with_filters(table_ref.clone(), table_source, None, vec![])
                .unwrap()
                .aggregate(Vec::<Expr>::new(), vec![count_all()])
                .unwrap()
                .build()
                .unwrap();

        let time_index = CountWildcardToTimeIndexRule::try_find_time_index_col(&plan);
        assert_eq!(
            time_index,
            Some(Column::new(Some(table_ref), "greptime_timestamp"))
        );
    }

    /// A fully-qualified (`catalog.schema.table`) reference is preserved as
    /// the qualifier of the resolved time-index column.
    #[test]
    fn fully_qualified_table_name_time_index() {
        let table_ref = TableReference::full(
            "telemetry_catalog",
            "telemetry_events",
            "multi_partitioned_test_1",
        );
        let table = build_time_index_table(
            "multi_partitioned_test_1",
            "telemetry_events",
            "telemetry_catalog",
        );
        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(table),
        )));

        let plan =
            LogicalPlanBuilder::scan_with_filters(table_ref.clone(), table_source, None, vec![])
                .unwrap()
                .aggregate(Vec::<Expr>::new(), vec![count_all()])
                .unwrap()
                .build()
                .unwrap();

        let time_index = CountWildcardToTimeIndexRule::try_find_time_index_col(&plan);
        assert_eq!(
            time_index,
            Some(Column::new(Some(table_ref), "greptime_timestamp"))
        );
    }

    /// Builds a one-column external table whose single column
    /// (`greptime_timestamp`, non-null nanosecond timestamp) is the time
    /// index, backed by a data source that always errors on scan.
    fn build_time_index_table(table_name: &str, schema_name: &str, catalog_name: &str) -> TableRef {
        let column_schemas = vec![
            ColumnSchema::new(
                "greptime_timestamp",
                ConcreteDataType::timestamp_nanosecond_datatype(),
                false,
            )
            .with_time_index(true),
        ];
        let schema = SchemaBuilder::try_from_columns(column_schemas)
            .unwrap()
            .build()
            .unwrap();
        let meta = TableMetaBuilder::new_external_table()
            .schema(Arc::new(schema))
            .next_column_id(1)
            .build()
            .unwrap();
        let info = TableInfoBuilder::new(table_name.to_string(), meta)
            .table_id(1)
            .table_version(0)
            .catalog_name(catalog_name)
            .schema_name(schema_name)
            .table_type(TableType::Base)
            .build()
            .unwrap();
        let data_source = Arc::new(DummyDataSource);
        Arc::new(Table::new(
            Arc::new(info),
            FilterPushDownType::Unsupported,
            data_source,
        ))
    }

    /// Data source stub for tables that are never actually scanned in these
    /// tests; `get_stream` always fails.
    struct DummyDataSource;

    impl DataSource for DummyDataSource {
        fn get_stream(
            &self,
            _request: ScanRequest,
        ) -> Result<SendableRecordBatchStream, BoxedError> {
            Err(BoxedError::new(DummyDataSourceError))
        }
    }

    /// Minimal error type so [`DummyDataSource`] can satisfy the
    /// `DataSource` contract.
    #[derive(Debug)]
    struct DummyDataSourceError;

    impl std::fmt::Display for DummyDataSourceError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "dummy data source error")
        }
    }

    impl std::error::Error for DummyDataSourceError {}

    impl StackError for DummyDataSourceError {
        fn debug_fmt(&self, _: usize, _: &mut Vec<String>) {}

        fn next(&self) -> Option<&dyn StackError> {
            None
        }
    }

    impl ErrorExt for DummyDataSourceError {
        fn status_code(&self) -> StatusCode {
            StatusCode::Internal
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }
}