query/optimizer/count_wildcard.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use datafusion::datasource::DefaultTableSource;
use datafusion_common::tree_node::{
    Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeVisitor,
};
use datafusion_common::{Column, Result as DataFusionResult, ScalarValue};
use datafusion_expr::expr::{AggregateFunction, WindowFunction};
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use datafusion_expr::{Expr, LogicalPlan, WindowFunctionDefinition, col, lit};
use datafusion_optimizer::AnalyzerRule;
use datafusion_optimizer::utils::NamePreserver;
use datafusion_sql::TableReference;
use table::table::adapter::DfTableProviderAdapter;

/// A replacement for DataFusion's [`CountWildcardRule`]. This rule prefers
/// the TIME INDEX column when expanding a wildcard count, since it is
/// cheaper to read than the PRIMARY KEY columns.
///
/// [`CountWildcardRule`]: datafusion::optimizer::analyzer::CountWildcardRule
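///
/// For example (illustrative only, assuming a table `t` whose TIME INDEX
/// column is `ts`): `SELECT count(*) FROM t` is analyzed so that the
/// aggregate's argument becomes `t.ts` instead of the literal `1` that
/// DataFusion's default rule would substitute. The original output name is
/// preserved via [`NamePreserver`], so the result column is still `count(*)`.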
#[derive(Debug)]
pub struct CountWildcardToTimeIndexRule;

impl AnalyzerRule for CountWildcardToTimeIndexRule {
    fn name(&self) -> &str {
        "count_wildcard_to_time_index_rule"
    }

    fn analyze(
        &self,
        plan: LogicalPlan,
        _config: &datafusion::config::ConfigOptions,
    ) -> DataFusionResult<LogicalPlan> {
        plan.transform_down_with_subqueries(&Self::analyze_internal)
            .data()
    }
}

impl CountWildcardToTimeIndexRule {
    fn analyze_internal(plan: LogicalPlan) -> DataFusionResult<Transformed<LogicalPlan>> {
        let name_preserver = NamePreserver::new(&plan);
        // Prefer counting the time index column; otherwise fall back to the
        // literal `1` used by DataFusion's count-wildcard expansion.
        let new_arg = if let Some(time_index) = Self::try_find_time_index_col(&plan) {
            vec![col(time_index)]
        } else {
            vec![lit(COUNT_STAR_EXPANSION)]
        };
        plan.map_expressions(|expr| {
            let original_name = name_preserver.save(&expr);
            let transformed_expr = expr.transform_up(|expr| match expr {
                Expr::WindowFunction(mut window_function)
                    if Self::is_count_star_window_aggregate(&window_function) =>
                {
                    window_function.params.args.clone_from(&new_arg);
                    Ok(Transformed::yes(Expr::WindowFunction(window_function)))
                }
                Expr::AggregateFunction(mut aggregate_function)
                    if Self::is_count_star_aggregate(&aggregate_function) =>
                {
                    aggregate_function.params.args.clone_from(&new_arg);
                    Ok(Transformed::yes(Expr::AggregateFunction(
                        aggregate_function,
                    )))
                }
                _ => Ok(Transformed::no(expr)),
            })?;
            // Restore the original output name so downstream schemas stay unchanged.
            Ok(transformed_expr.update_data(|data| original_name.restore(data)))
        })
    }

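    /// Tries to locate the TIME INDEX column of the table scanned by `plan`.
    /// Returns `None` when the plan has more than one input or when the found
    /// column is not part of the input schema; callers then fall back to
    /// `count(1)`.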
    fn try_find_time_index_col(plan: &LogicalPlan) -> Option<Column> {
        let mut finder = TimeIndexFinder::default();
        // Safety: `TimeIndexFinder` never returns an error.
        plan.visit(&mut finder).unwrap();
        let col = finder.into_column();

        // Check that the time index is a valid column for the current plan.
        if let Some(col) = &col {
            let mut is_valid = false;
            // if more than one input, we give up and just use `count(1)`
            if plan.inputs().len() > 1 {
                return None;
            }
            for input in plan.inputs() {
                if input.schema().has_column(col) {
                    is_valid = true;
                    break;
                }
            }
            if !is_valid {
                return None;
            }
        }

        col
    }
}

/// Utility functions adapted from DataFusion's original `CountWildcardRule`.
impl CountWildcardToTimeIndexRule {
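    /// Returns true when `args` is empty, a single wildcard, or the single
    /// literal `1`, i.e. the argument shapes produced by a `count(*)` call.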
    #[expect(deprecated)]
    fn args_at_most_wildcard_or_literal_one(args: &[Expr]) -> bool {
        match args {
            [] => true,
            [Expr::Literal(ScalarValue::Int64(Some(v)), _)] => *v == 1,
            [Expr::Wildcard { .. }] => true,
            _ => false,
        }
    }

    fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool {
        let args = &aggregate_function.params.args;
        matches!(aggregate_function,
            AggregateFunction {
                func,
                ..
            } if func.name() == "count" && Self::args_at_most_wildcard_or_literal_one(args))
    }

    fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool {
        let args = &window_function.params.args;
        matches!(window_function.fun,
                WindowFunctionDefinition::AggregateUDF(ref udaf)
                    if udaf.name() == "count" && Self::args_at_most_wildcard_or_literal_one(args))
    }
}

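/// Visitor that walks the plan looking for a `TableScan` over a GreptimeDB
/// table, recording the table's TIME INDEX column name together with the
/// table reference (or an enclosing `SubqueryAlias`) used to qualify it.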
#[derive(Default)]
struct TimeIndexFinder {
    time_index_col: Option<String>,
    table_alias: Option<TableReference>,
}

impl TreeNodeVisitor<'_> for TimeIndexFinder {
    type Node = LogicalPlan;

    fn f_down(&mut self, node: &Self::Node) -> DataFusionResult<TreeNodeRecursion> {
        if let LogicalPlan::SubqueryAlias(subquery_alias) = node {
            self.table_alias = Some(subquery_alias.alias.clone());
        }

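        // Unwrap the `DefaultTableSource` to reach the GreptimeDB table
        // provider adapter, then read the TIME INDEX column from its metadata.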
        if let LogicalPlan::TableScan(table_scan) = &node
            && let Some(source) = table_scan
                .source
                .as_any()
                .downcast_ref::<DefaultTableSource>()
            && let Some(adapter) = source
                .table_provider
                .as_any()
                .downcast_ref::<DfTableProviderAdapter>()
        {
            let table_info = adapter.table().table_info();
            self.table_alias
                .get_or_insert(table_scan.table_name.clone());
            self.time_index_col = table_info
                .meta
                .schema
                .timestamp_column()
                .map(|c| c.name.clone());

            return Ok(TreeNodeRecursion::Stop);
        }

        if node.inputs().len() > 1 {
            // if more than one input, we give up and just use `count(1)`
            return Ok(TreeNodeRecursion::Stop);
        }

        Ok(TreeNodeRecursion::Continue)
    }

    fn f_up(&mut self, _node: &Self::Node) -> DataFusionResult<TreeNodeRecursion> {
        // Once the traversal starts going back up, either a table scan or a
        // multi-input plan has already been reached, so stop searching.
        Ok(TreeNodeRecursion::Stop)
    }
}

impl TimeIndexFinder {
    fn into_column(self) -> Option<Column> {
        self.time_index_col
            .map(|c| Column::new(self.table_alias, c))
    }
}

#[cfg(test)]
mod test {
    use std::sync::Arc;

    use common_catalog::consts::DEFAULT_CATALOG_NAME;
    use common_error::ext::{BoxedError, ErrorExt, StackError};
    use common_error::status_code::StatusCode;
    use common_recordbatch::SendableRecordBatchStream;
    use datafusion::functions_aggregate::count::count_all;
    use datafusion_common::Column;
    use datafusion_expr::LogicalPlanBuilder;
    use datafusion_sql::TableReference;
    use datatypes::data_type::ConcreteDataType;
    use datatypes::schema::{ColumnSchema, SchemaBuilder};
    use store_api::data_source::DataSource;
    use store_api::storage::ScanRequest;
    use table::metadata::{FilterPushDownType, TableInfoBuilder, TableMetaBuilder, TableType};
    use table::table::numbers::NumbersTable;
    use table::{Table, TableRef};

    use super::*;

    #[test]
    fn uppercase_table_name() {
        let numbers_table = NumbersTable::table_with_name(0, "AbCdE".to_string());
        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(numbers_table),
        )));

        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
            .unwrap()
            .aggregate(Vec::<Expr>::new(), vec![count_all()])
            .unwrap()
            .alias(r#""FgHiJ""#)
            .unwrap()
            .build()
            .unwrap();

        let mut finder = TimeIndexFinder::default();
        plan.visit(&mut finder).unwrap();

        assert_eq!(finder.table_alias, Some(TableReference::bare("FgHiJ")));
        assert!(finder.time_index_col.is_none());
    }

    #[test]
    fn bare_table_name_time_index() {
        let table_ref = TableReference::bare("multi_partitioned_test_1");
        let table =
            build_time_index_table("multi_partitioned_test_1", "public", DEFAULT_CATALOG_NAME);
        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(table),
        )));

        let plan =
            LogicalPlanBuilder::scan_with_filters(table_ref.clone(), table_source, None, vec![])
                .unwrap()
                .aggregate(Vec::<Expr>::new(), vec![count_all()])
                .unwrap()
                .build()
                .unwrap();

        let time_index = CountWildcardToTimeIndexRule::try_find_time_index_col(&plan);
        assert_eq!(
            time_index,
            Some(Column::new(Some(table_ref), "greptime_timestamp"))
        );
    }

    #[test]
    fn schema_qualified_table_name_time_index() {
        let table_ref = TableReference::partial("telemetry_events", "multi_partitioned_test_1");
        let table = build_time_index_table(
            "multi_partitioned_test_1",
            "telemetry_events",
            DEFAULT_CATALOG_NAME,
        );
        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(table),
        )));

        let plan =
            LogicalPlanBuilder::scan_with_filters(table_ref.clone(), table_source, None, vec![])
                .unwrap()
                .aggregate(Vec::<Expr>::new(), vec![count_all()])
                .unwrap()
                .build()
                .unwrap();

        let time_index = CountWildcardToTimeIndexRule::try_find_time_index_col(&plan);
        assert_eq!(
            time_index,
            Some(Column::new(Some(table_ref), "greptime_timestamp"))
        );
    }

    #[test]
    fn fully_qualified_table_name_time_index() {
        let table_ref = TableReference::full(
            "telemetry_catalog",
            "telemetry_events",
            "multi_partitioned_test_1",
        );
        let table = build_time_index_table(
            "multi_partitioned_test_1",
            "telemetry_events",
            "telemetry_catalog",
        );
        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(table),
        )));

        let plan =
            LogicalPlanBuilder::scan_with_filters(table_ref.clone(), table_source, None, vec![])
                .unwrap()
                .aggregate(Vec::<Expr>::new(), vec![count_all()])
                .unwrap()
                .build()
                .unwrap();

        let time_index = CountWildcardToTimeIndexRule::try_find_time_index_col(&plan);
        assert_eq!(
            time_index,
            Some(Column::new(Some(table_ref), "greptime_timestamp"))
        );
    }

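    // A minimal end-to-end sketch (added for illustration, not part of the
    // original tests): run the whole analyzer rule over a `count(*)` aggregate
    // and check that the rewritten plan references the time index column. The
    // assertion is deliberately loose because the exact plan display format
    // depends on the DataFusion version in use.
    #[test]
    fn analyze_rewrites_count_star_to_time_index() {
        let table = build_time_index_table("t", "public", DEFAULT_CATALOG_NAME);
        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(table),
        )));

        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
            .unwrap()
            .aggregate(Vec::<Expr>::new(), vec![count_all()])
            .unwrap()
            .build()
            .unwrap();

        let analyzed = CountWildcardToTimeIndexRule
            .analyze(plan, &datafusion::config::ConfigOptions::default())
            .unwrap();
        assert!(
            analyzed
                .display_indent()
                .to_string()
                .contains("greptime_timestamp")
        );
    }
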
    fn build_time_index_table(table_name: &str, schema_name: &str, catalog_name: &str) -> TableRef {
        let column_schemas = vec![
            ColumnSchema::new(
                "greptime_timestamp",
                ConcreteDataType::timestamp_nanosecond_datatype(),
                false,
            )
            .with_time_index(true),
        ];
        let schema = SchemaBuilder::try_from_columns(column_schemas)
            .unwrap()
            .build()
            .unwrap();
        let meta = TableMetaBuilder::new_external_table()
            .schema(Arc::new(schema))
            .next_column_id(1)
            .build()
            .unwrap();
        let info = TableInfoBuilder::new(table_name.to_string(), meta)
            .table_id(1)
            .table_version(0)
            .catalog_name(catalog_name)
            .schema_name(schema_name)
            .table_type(TableType::Base)
            .build()
            .unwrap();
        let data_source = Arc::new(DummyDataSource);
        Arc::new(Table::new(
            Arc::new(info),
            FilterPushDownType::Unsupported,
            data_source,
        ))
    }

    struct DummyDataSource;

    impl DataSource for DummyDataSource {
        fn get_stream(
            &self,
            _request: ScanRequest,
        ) -> Result<SendableRecordBatchStream, BoxedError> {
            Err(BoxedError::new(DummyDataSourceError))
        }
    }

    #[derive(Debug)]
    struct DummyDataSourceError;

    impl std::fmt::Display for DummyDataSourceError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "dummy data source error")
        }
    }

    impl std::error::Error for DummyDataSourceError {}

    impl StackError for DummyDataSourceError {
        fn debug_fmt(&self, _: usize, _: &mut Vec<String>) {}

        fn next(&self) -> Option<&dyn StackError> {
            None
        }
    }

    impl ErrorExt for DummyDataSourceError {
        fn status_code(&self) -> StatusCode {
            StatusCode::Internal
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }
}