query/optimizer/
scan_hint.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashSet;
16
17use api::v1::SemanticType;
18use arrow_schema::SortOptions;
19use common_function::aggrs::aggr_wrapper::aggr_state_func_name;
20use common_recordbatch::OrderOption;
21use datafusion::datasource::DefaultTableSource;
22use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion, TreeNodeVisitor};
23use datafusion_common::{Column, Result};
24use datafusion_expr::expr::Sort;
25use datafusion_expr::{Expr, LogicalPlan, utils};
26use datafusion_optimizer::{OptimizerConfig, OptimizerRule};
27use store_api::storage::{TimeSeriesDistribution, TimeSeriesRowSelector};
28
29use crate::dummy_catalog::DummyTableProvider;
30
/// This rule will traverse the plan to collect necessary hints for leaf
/// table scan node and set them in [`ScanRequest`]. Hints include:
/// - the nearest order requirement to the leaf table scan node as ordering hint
///   (only when every sort key is a bare column). If the sort keys are a prefix of
///   the primary key, or the whole primary key followed by the time index column,
///   a [`TimeSeriesDistribution::PerSeries`] distribution hint is set as well.
/// - the group by columns when all aggregate functions are `last_value` as
///   time series row selector hint.
///
/// Hints are only applied to table scans whose provider is a [`DummyTableProvider`],
/// which is the provider used in the region server.
///
/// [`ScanRequest`]: store_api::storage::ScanRequest
#[derive(Debug)]
pub struct ScanHintRule;
40
41impl OptimizerRule for ScanHintRule {
42    fn name(&self) -> &str {
43        "ScanHintRule"
44    }
45
46    fn rewrite(
47        &self,
48        plan: LogicalPlan,
49        _config: &dyn OptimizerConfig,
50    ) -> Result<Transformed<LogicalPlan>> {
51        Self::optimize(plan)
52    }
53}
54
55impl ScanHintRule {
56    fn optimize(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
57        let mut visitor = ScanHintVisitor::default();
58        let _ = plan.visit(&mut visitor)?;
59
60        if visitor.need_rewrite() {
61            plan.transform_down(&|plan| Self::set_hints(plan, &visitor))
62        } else {
63            Ok(Transformed::no(plan))
64        }
65    }
66
67    fn set_hints(plan: LogicalPlan, visitor: &ScanHintVisitor) -> Result<Transformed<LogicalPlan>> {
68        match &plan {
69            LogicalPlan::TableScan(table_scan) => {
70                let mut transformed = false;
71                if let Some(source) = table_scan
72                    .source
73                    .as_any()
74                    .downcast_ref::<DefaultTableSource>()
75                {
76                    // The provider in the region server is [DummyTableProvider].
77                    if let Some(adapter) = source
78                        .table_provider
79                        .as_any()
80                        .downcast_ref::<DummyTableProvider>()
81                    {
82                        // set order_hint
83                        if let Some(order_expr) = &visitor.order_expr {
84                            Self::set_order_hint(adapter, order_expr);
85                        }
86
87                        // set time series selector hint
88                        if let Some((group_by_cols, order_by_col)) = &visitor.ts_row_selector {
89                            Self::set_time_series_row_selector_hint(
90                                adapter,
91                                group_by_cols,
92                                order_by_col,
93                            );
94                        }
95
96                        transformed = true;
97                    }
98                }
99                if transformed {
100                    Ok(Transformed::yes(plan))
101                } else {
102                    Ok(Transformed::no(plan))
103                }
104            }
105            _ => Ok(Transformed::no(plan)),
106        }
107    }
108
109    fn set_order_hint(adapter: &DummyTableProvider, order_expr: &Vec<Sort>) {
110        let mut opts = Vec::with_capacity(order_expr.len());
111        for sort in order_expr {
112            let name = match sort.expr.try_as_col() {
113                Some(col) => col.name.clone(),
114                None => return,
115            };
116            opts.push(OrderOption {
117                name,
118                options: SortOptions {
119                    descending: !sort.asc,
120                    nulls_first: sort.nulls_first,
121                },
122            });
123        }
124        adapter.with_ordering_hint(&opts);
125
126        let mut sort_expr_cursor = order_expr.iter().filter_map(|s| s.expr.try_as_col());
127        let region_metadata = adapter.region_metadata();
128        // ignore table without pk
129        if region_metadata.primary_key.is_empty() {
130            return;
131        }
132        let mut pk_column_iter = region_metadata.primary_key_columns();
133        let mut curr_sort_expr = sort_expr_cursor.next();
134        let mut curr_pk_col = pk_column_iter.next();
135
136        while let (Some(sort_expr), Some(pk_col)) = (curr_sort_expr, curr_pk_col) {
137            if sort_expr.name == pk_col.column_schema.name {
138                curr_sort_expr = sort_expr_cursor.next();
139                curr_pk_col = pk_column_iter.next();
140            } else {
141                return;
142            }
143        }
144
145        let next_remaining = sort_expr_cursor.next();
146        match (curr_sort_expr, next_remaining) {
147            (Some(expr), None)
148                if expr.name == region_metadata.time_index_column().column_schema.name =>
149            {
150                adapter.with_distribution(TimeSeriesDistribution::PerSeries);
151            }
152            (None, _) => adapter.with_distribution(TimeSeriesDistribution::PerSeries),
153            (Some(_), _) => {}
154        }
155    }
156
157    fn set_time_series_row_selector_hint(
158        adapter: &DummyTableProvider,
159        group_by_cols: &HashSet<Column>,
160        order_by_col: &Column,
161    ) {
162        let region_metadata = adapter.region_metadata();
163        let mut should_set_selector_hint = true;
164        // check if order_by column is time index
165        if let Some(column_metadata) = region_metadata.column_by_name(&order_by_col.name) {
166            if column_metadata.semantic_type != SemanticType::Timestamp {
167                should_set_selector_hint = false;
168            }
169        } else {
170            should_set_selector_hint = false;
171        }
172
173        // check if all group_by columns are primary key
174        for col in group_by_cols {
175            let Some(column_metadata) = region_metadata.column_by_name(&col.name) else {
176                should_set_selector_hint = false;
177                break;
178            };
179            if column_metadata.semantic_type != SemanticType::Tag {
180                should_set_selector_hint = false;
181                break;
182            }
183        }
184
185        if should_set_selector_hint {
186            adapter.with_time_series_selector_hint(TimeSeriesRowSelector::LastRow);
187        }
188    }
189}
190
/// Traverse and fetch hints.
///
/// Filled by a top-down walk over the logical plan. `ScanHintRule` then applies the
/// collected hints to the plan's table scans.
#[derive(Default)]
struct ScanHintVisitor {
    /// The closest order requirement to the leaf node.
    ///
    /// Holds the sort expressions of the last `Sort` node visited. Because the walk is
    /// top-down, this is the `Sort` nearest to the leaves.
    order_expr: Option<Vec<Sort>>,
    /// Row selection on time series distribution.
    /// This field stores saved `group_by` columns when all aggregate functions are `last_value`
    /// and the `order_by` column which should be time index.
    ///
    /// It is cleared again when, for example, a subquery, a node with several inputs
    /// (such as a join), or a filter referencing non-`group_by` columns is met below
    /// the aggregate.
    ts_row_selector: Option<(HashSet<Column>, Column)>,
}
201
202impl TreeNodeVisitor<'_> for ScanHintVisitor {
203    type Node = LogicalPlan;
204
205    fn f_down(&mut self, node: &Self::Node) -> Result<TreeNodeRecursion> {
206        // Get order requirement from sort plan
207        if let LogicalPlan::Sort(sort) = node {
208            self.order_expr = Some(sort.expr.clone());
209        }
210
211        // Get time series row selector from aggr plan
212        if let LogicalPlan::Aggregate(aggregate) = node {
213            let mut is_all_last_value = !aggregate.aggr_expr.is_empty();
214            let mut order_by_expr = None;
215            for expr in &aggregate.aggr_expr {
216                // check function name
217                let Expr::AggregateFunction(func) = expr else {
218                    is_all_last_value = false;
219                    break;
220                };
221                if (func.func.name() != "last_value"
222                    && func.func.name() != aggr_state_func_name("last_value"))
223                    || func.params.filter.is_some()
224                    || func.params.distinct
225                {
226                    is_all_last_value = false;
227                    break;
228                }
229                // check order by requirement
230                let order_by = &func.params.order_by;
231                if let Some(first_order_by) = order_by.first()
232                    && order_by.len() == 1
233                {
234                    if let Some(existing_order_by) = &order_by_expr {
235                        if existing_order_by != first_order_by {
236                            is_all_last_value = false;
237                            break;
238                        }
239                    } else {
240                        // only allow `order by xxx [ASC]`, xxx is a bare column reference so `last_value()` is the max
241                        // value of the column.
242                        if !first_order_by.asc || !matches!(&first_order_by.expr, Expr::Column(_)) {
243                            is_all_last_value = false;
244                            break;
245                        }
246                        order_by_expr = Some(first_order_by.clone());
247                    }
248                }
249            }
250            is_all_last_value &= order_by_expr.is_some();
251            if is_all_last_value {
252                // make sure all the exprs are DIRECT `col` and collect them
253                let mut group_by_cols = HashSet::with_capacity(aggregate.group_expr.len());
254                for expr in &aggregate.group_expr {
255                    if let Expr::Column(col) = expr {
256                        group_by_cols.insert(col.clone());
257                    } else {
258                        is_all_last_value = false;
259                        break;
260                    }
261                }
262                // Safety: checked in the above loop
263                let order_by_expr = order_by_expr.unwrap();
264                let Expr::Column(order_by_col) = order_by_expr.expr else {
265                    unreachable!()
266                };
267                if is_all_last_value {
268                    self.ts_row_selector = Some((group_by_cols, order_by_col));
269                }
270            }
271        }
272
273        if self.ts_row_selector.is_some()
274            && (matches!(node, LogicalPlan::Subquery(_)) || node.inputs().len() > 1)
275        {
276            // clean previous time series selector hint when encounter subqueries or join
277            self.ts_row_selector = None;
278        }
279
280        if let LogicalPlan::Filter(filter) = node
281            && let Some(group_by_exprs) = &self.ts_row_selector
282        {
283            let mut filter_referenced_cols = HashSet::default();
284            utils::expr_to_columns(&filter.predicate, &mut filter_referenced_cols)?;
285            // ensure only group_by columns are used in filter
286            if !filter_referenced_cols.is_subset(&group_by_exprs.0) {
287                self.ts_row_selector = None;
288            }
289        }
290
291        Ok(TreeNodeRecursion::Continue)
292    }
293}
294
295impl ScanHintVisitor {
296    fn need_rewrite(&self) -> bool {
297        self.order_expr.is_some() || self.ts_row_selector.is_some()
298    }
299}
300
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use datafusion::functions_aggregate::first_last::last_value_udaf;
    use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
    use datafusion_expr::{LogicalPlanBuilder, col};
    use datafusion_optimizer::OptimizerContext;
    use store_api::storage::RegionId;

    use super::*;
    use crate::optimizer::test_util::mock_table_provider;

    /// Applies [`ScanHintRule`] once to `plan` with a default optimizer context.
    fn run_scan_hint_rule(plan: LogicalPlan) {
        let context = OptimizerContext::default();
        ScanHintRule.rewrite(plan, &context).unwrap();
    }

    #[test]
    fn set_order_hint() {
        let provider = Arc::new(mock_table_provider(RegionId::new(1, 1)));
        let source = Arc::new(DefaultTableSource::new(provider.clone()));
        // The ascending sort is the one closest to the scan.
        let plan = LogicalPlanBuilder::scan("t", source, None)
            .and_then(|builder| builder.sort(vec![col("ts").sort(true, false)]))
            .and_then(|builder| builder.sort(vec![col("ts").sort(false, true)]))
            .and_then(|builder| builder.build())
            .unwrap();

        run_scan_hint_rule(plan);

        // should read the first (with `.sort(true, false)`) sort option
        let scan_req = provider.scan_request();
        let ordering = scan_req.output_ordering.as_ref().unwrap();
        let expected = OrderOption {
            name: "ts".to_string(),
            options: SortOptions {
                descending: false,
                nulls_first: false,
            },
        };
        assert_eq!(expected, ordering[0]);
    }

    #[test]
    fn set_time_series_row_selector_hint() {
        let provider = Arc::new(mock_table_provider(RegionId::new(1, 1)));
        let source = Arc::new(DefaultTableSource::new(provider.clone()));
        // last_value(v0 ORDER BY ts ASC NULLS FIRST), grouped by k0
        let last_value_v0 = Expr::AggregateFunction(AggregateFunction {
            func: last_value_udaf(),
            params: AggregateFunctionParams {
                args: vec![col("v0")],
                distinct: false,
                filter: None,
                order_by: vec![Sort {
                    expr: col("ts"),
                    asc: true,
                    nulls_first: true,
                }],
                null_treatment: None,
            },
        });
        let plan = LogicalPlanBuilder::scan("t", source, None)
            .and_then(|builder| builder.aggregate(vec![col("k0")], vec![last_value_v0]))
            .and_then(|builder| builder.build())
            .unwrap();

        run_scan_hint_rule(plan);

        let scan_req = provider.scan_request();
        assert!(scan_req.series_row_selector.is_some());
    }
}