query/optimizer/
count_wildcard.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use datafusion::datasource::DefaultTableSource;
16use datafusion_common::tree_node::{
17    Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeVisitor,
18};
19use datafusion_common::{Column, Result as DataFusionResult, ScalarValue};
20use datafusion_expr::expr::{AggregateFunction, WindowFunction};
21use datafusion_expr::utils::COUNT_STAR_EXPANSION;
22use datafusion_expr::{col, lit, Expr, LogicalPlan, WindowFunctionDefinition};
23use datafusion_optimizer::utils::NamePreserver;
24use datafusion_optimizer::AnalyzerRule;
25use datafusion_sql::TableReference;
26use table::table::adapter::DfTableProviderAdapter;
27
/// Analyzer rule that makes `count(*)` count the table's TIME INDEX column.
///
/// It stands in for DataFusion's [`CountWildcardRule`]: reading the time index is preferred
/// because it is faster to read than the PRIMARY KEY columns.
///
/// [`CountWildcardRule`]: datafusion::optimizer::analyzer::CountWildcardRule
#[derive(Debug)]
pub struct CountWildcardToTimeIndexRule;
35
36impl AnalyzerRule for CountWildcardToTimeIndexRule {
37    fn name(&self) -> &str {
38        "count_wildcard_to_time_index_rule"
39    }
40
41    fn analyze(
42        &self,
43        plan: LogicalPlan,
44        _config: &datafusion::config::ConfigOptions,
45    ) -> DataFusionResult<LogicalPlan> {
46        plan.transform_down_with_subqueries(&Self::analyze_internal)
47            .data()
48    }
49}
50
51impl CountWildcardToTimeIndexRule {
52    fn analyze_internal(plan: LogicalPlan) -> DataFusionResult<Transformed<LogicalPlan>> {
53        let name_preserver = NamePreserver::new(&plan);
54        let new_arg = if let Some(time_index) = Self::try_find_time_index_col(&plan) {
55            vec![col(time_index)]
56        } else {
57            vec![lit(COUNT_STAR_EXPANSION)]
58        };
59        plan.map_expressions(|expr| {
60            let original_name = name_preserver.save(&expr);
61            let transformed_expr = expr.transform_up(|expr| match expr {
62                Expr::WindowFunction(mut window_function)
63                    if Self::is_count_star_window_aggregate(&window_function) =>
64                {
65                    window_function.params.args.clone_from(&new_arg);
66                    Ok(Transformed::yes(Expr::WindowFunction(window_function)))
67                }
68                Expr::AggregateFunction(mut aggregate_function)
69                    if Self::is_count_star_aggregate(&aggregate_function) =>
70                {
71                    aggregate_function.params.args.clone_from(&new_arg);
72                    Ok(Transformed::yes(Expr::AggregateFunction(
73                        aggregate_function,
74                    )))
75                }
76                _ => Ok(Transformed::no(expr)),
77            })?;
78            Ok(transformed_expr.update_data(|data| original_name.restore(data)))
79        })
80    }
81
82    fn try_find_time_index_col(plan: &LogicalPlan) -> Option<Column> {
83        let mut finder = TimeIndexFinder::default();
84        // Safety: `TimeIndexFinder` won't throw error.
85        plan.visit(&mut finder).unwrap();
86        let col = finder.into_column();
87
88        // check if the time index is a valid column as for current plan
89        if let Some(col) = &col {
90            let mut is_valid = false;
91            for input in plan.inputs() {
92                if input.schema().has_column(col) {
93                    is_valid = true;
94                    break;
95                }
96            }
97            if !is_valid {
98                return None;
99            }
100        }
101
102        col
103    }
104}
105
106/// Utility functions from the original rule.
107impl CountWildcardToTimeIndexRule {
108    #[expect(deprecated)]
109    fn args_at_most_wildcard_or_literal_one(args: &[Expr]) -> bool {
110        match args {
111            [] => true,
112            [Expr::Literal(ScalarValue::Int64(Some(v)), _)] => *v == 1,
113            [Expr::Wildcard { .. }] => true,
114            _ => false,
115        }
116    }
117
118    fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool {
119        let args = &aggregate_function.params.args;
120        matches!(aggregate_function,
121            AggregateFunction {
122                func,
123                ..
124            } if func.name() == "count" && Self::args_at_most_wildcard_or_literal_one(args))
125    }
126
127    fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool {
128        let args = &window_function.params.args;
129        matches!(window_function.fun,
130                WindowFunctionDefinition::AggregateUDF(ref udaf)
131                    if udaf.name() == "count" && Self::args_at_most_wildcard_or_literal_one(args))
132    }
133}
134
135#[derive(Default)]
136struct TimeIndexFinder {
137    time_index_col: Option<String>,
138    table_alias: Option<TableReference>,
139}
140
141impl TreeNodeVisitor<'_> for TimeIndexFinder {
142    type Node = LogicalPlan;
143
144    fn f_down(&mut self, node: &Self::Node) -> DataFusionResult<TreeNodeRecursion> {
145        if let LogicalPlan::SubqueryAlias(subquery_alias) = node {
146            self.table_alias = Some(subquery_alias.alias.clone());
147        }
148
149        if let LogicalPlan::TableScan(table_scan) = &node {
150            if let Some(source) = table_scan
151                .source
152                .as_any()
153                .downcast_ref::<DefaultTableSource>()
154            {
155                if let Some(adapter) = source
156                    .table_provider
157                    .as_any()
158                    .downcast_ref::<DfTableProviderAdapter>()
159                {
160                    let table_info = adapter.table().table_info();
161                    self.table_alias
162                        .get_or_insert(TableReference::bare(table_info.name.clone()));
163                    self.time_index_col = table_info
164                        .meta
165                        .schema
166                        .timestamp_column()
167                        .map(|c| c.name.clone());
168
169                    return Ok(TreeNodeRecursion::Stop);
170                }
171            }
172        }
173
174        Ok(TreeNodeRecursion::Continue)
175    }
176
177    fn f_up(&mut self, _node: &Self::Node) -> DataFusionResult<TreeNodeRecursion> {
178        Ok(TreeNodeRecursion::Stop)
179    }
180}
181
182impl TimeIndexFinder {
183    fn into_column(self) -> Option<Column> {
184        self.time_index_col
185            .map(|c| Column::new(self.table_alias, c))
186    }
187}
188
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use datafusion::functions_aggregate::count::count_all;
    use datafusion_expr::LogicalPlanBuilder;
    use table::table::numbers::NumbersTable;

    use super::*;

    /// A quoted (case-preserving) alias above the scan must be recorded verbatim, and a table
    /// without a time index must not report one.
    #[test]
    fn uppercase_table_name() {
        let provider =
            DfTableProviderAdapter::new(NumbersTable::table_with_name(0, "AbCdE".to_string()));
        let source = Arc::new(DefaultTableSource::new(Arc::new(provider)));

        let plan = LogicalPlanBuilder::scan_with_filters("t", source, None, vec![])
            .and_then(|builder| builder.aggregate(Vec::<Expr>::new(), vec![count_all()]))
            .and_then(|builder| builder.alias(r#""FgHiJ""#))
            .and_then(LogicalPlanBuilder::build)
            .unwrap();

        let mut finder = TimeIndexFinder::default();
        plan.visit(&mut finder).unwrap();

        let TimeIndexFinder {
            time_index_col,
            table_alias,
        } = finder;
        assert_eq!(table_alias, Some(TableReference::bare("FgHiJ")));
        assert!(time_index_col.is_none());
    }
}
221}