query/optimizer/
scan_hint.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::collections::HashSet;
16
17use api::v1::SemanticType;
18use arrow_schema::SortOptions;
19use common_recordbatch::OrderOption;
20use datafusion::datasource::DefaultTableSource;
21use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion, TreeNodeVisitor};
22use datafusion_common::{Column, Result};
23use datafusion_expr::expr::Sort;
24use datafusion_expr::{utils, Expr, LogicalPlan};
25use datafusion_optimizer::{OptimizerConfig, OptimizerRule};
26use store_api::storage::{TimeSeriesDistribution, TimeSeriesRowSelector};
27
28use crate::dummy_catalog::DummyTableProvider;
29
/// Optimizer rule that walks the logical plan, collects hints for the leaf
/// table scan node and sets them in [`ScanRequest`]. The hints are:
/// - the order requirement nearest to the leaf table scan node, used as the
///   ordering hint;
/// - the group by columns when all aggregate functions are `last_value`, used
///   as the time series row selector hint.
///
/// [`ScanRequest`]: store_api::storage::ScanRequest
#[derive(Debug)]
pub struct ScanHintRule;
39
40impl OptimizerRule for ScanHintRule {
41    fn name(&self) -> &str {
42        "ScanHintRule"
43    }
44
45    fn rewrite(
46        &self,
47        plan: LogicalPlan,
48        _config: &dyn OptimizerConfig,
49    ) -> Result<Transformed<LogicalPlan>> {
50        Self::optimize(plan)
51    }
52}
53
54impl ScanHintRule {
55    fn optimize(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
56        let mut visitor = ScanHintVisitor::default();
57        let _ = plan.visit(&mut visitor)?;
58
59        if visitor.need_rewrite() {
60            plan.transform_down(&|plan| Self::set_hints(plan, &visitor))
61        } else {
62            Ok(Transformed::no(plan))
63        }
64    }
65
66    fn set_hints(plan: LogicalPlan, visitor: &ScanHintVisitor) -> Result<Transformed<LogicalPlan>> {
67        match &plan {
68            LogicalPlan::TableScan(table_scan) => {
69                let mut transformed = false;
70                if let Some(source) = table_scan
71                    .source
72                    .as_any()
73                    .downcast_ref::<DefaultTableSource>()
74                {
75                    // The provider in the region server is [DummyTableProvider].
76                    if let Some(adapter) = source
77                        .table_provider
78                        .as_any()
79                        .downcast_ref::<DummyTableProvider>()
80                    {
81                        // set order_hint
82                        if let Some(order_expr) = &visitor.order_expr {
83                            Self::set_order_hint(adapter, order_expr);
84                        }
85
86                        // set time series selector hint
87                        if let Some((group_by_cols, order_by_col)) = &visitor.ts_row_selector {
88                            Self::set_time_series_row_selector_hint(
89                                adapter,
90                                group_by_cols,
91                                order_by_col,
92                            );
93                        }
94
95                        transformed = true;
96                    }
97                }
98                if transformed {
99                    Ok(Transformed::yes(plan))
100                } else {
101                    Ok(Transformed::no(plan))
102                }
103            }
104            _ => Ok(Transformed::no(plan)),
105        }
106    }
107
108    fn set_order_hint(adapter: &DummyTableProvider, order_expr: &Vec<Sort>) {
109        let mut opts = Vec::with_capacity(order_expr.len());
110        for sort in order_expr {
111            let name = match sort.expr.try_as_col() {
112                Some(col) => col.name.clone(),
113                None => return,
114            };
115            opts.push(OrderOption {
116                name,
117                options: SortOptions {
118                    descending: !sort.asc,
119                    nulls_first: sort.nulls_first,
120                },
121            });
122        }
123        adapter.with_ordering_hint(&opts);
124
125        let mut sort_expr_cursor = order_expr.iter().filter_map(|s| s.expr.try_as_col());
126        let region_metadata = adapter.region_metadata();
127        // ignore table without pk
128        if region_metadata.primary_key.is_empty() {
129            return;
130        }
131        let mut pk_column_iter = region_metadata.primary_key_columns();
132        let mut curr_sort_expr = sort_expr_cursor.next();
133        let mut curr_pk_col = pk_column_iter.next();
134
135        while let (Some(sort_expr), Some(pk_col)) = (curr_sort_expr, curr_pk_col) {
136            if sort_expr.name == pk_col.column_schema.name {
137                curr_sort_expr = sort_expr_cursor.next();
138                curr_pk_col = pk_column_iter.next();
139            } else {
140                return;
141            }
142        }
143
144        let next_remaining = sort_expr_cursor.next();
145        match (curr_sort_expr, next_remaining) {
146            (Some(expr), None)
147                if expr.name == region_metadata.time_index_column().column_schema.name =>
148            {
149                adapter.with_distribution(TimeSeriesDistribution::PerSeries);
150            }
151            (None, _) => adapter.with_distribution(TimeSeriesDistribution::PerSeries),
152            (Some(_), _) => {}
153        }
154    }
155
156    fn set_time_series_row_selector_hint(
157        adapter: &DummyTableProvider,
158        group_by_cols: &HashSet<Column>,
159        order_by_col: &Column,
160    ) {
161        let region_metadata = adapter.region_metadata();
162        let mut should_set_selector_hint = true;
163        // check if order_by column is time index
164        if let Some(column_metadata) = region_metadata.column_by_name(&order_by_col.name) {
165            if column_metadata.semantic_type != SemanticType::Timestamp {
166                should_set_selector_hint = false;
167            }
168        } else {
169            should_set_selector_hint = false;
170        }
171
172        // check if all group_by columns are primary key
173        for col in group_by_cols {
174            let Some(column_metadata) = region_metadata.column_by_name(&col.name) else {
175                should_set_selector_hint = false;
176                break;
177            };
178            if column_metadata.semantic_type != SemanticType::Tag {
179                should_set_selector_hint = false;
180                break;
181            }
182        }
183
184        if should_set_selector_hint {
185            adapter.with_time_series_selector_hint(TimeSeriesRowSelector::LastRow);
186        }
187    }
188}
189
190/// Traverse and fetch hints.
191#[derive(Default)]
192struct ScanHintVisitor {
193    /// The closest order requirement to the leaf node.
194    order_expr: Option<Vec<Sort>>,
195    /// Row selection on time series distribution.
196    /// This field stores saved `group_by` columns when all aggregate functions are `last_value`
197    /// and the `order_by` column which should be time index.
198    ts_row_selector: Option<(HashSet<Column>, Column)>,
199}
200
201impl TreeNodeVisitor<'_> for ScanHintVisitor {
202    type Node = LogicalPlan;
203
204    fn f_down(&mut self, node: &Self::Node) -> Result<TreeNodeRecursion> {
205        // Get order requirement from sort plan
206        if let LogicalPlan::Sort(sort) = node {
207            self.order_expr = Some(sort.expr.clone());
208        }
209
210        // Get time series row selector from aggr plan
211        if let LogicalPlan::Aggregate(aggregate) = node {
212            let mut is_all_last_value = !aggregate.aggr_expr.is_empty();
213            let mut order_by_expr = None;
214            for expr in &aggregate.aggr_expr {
215                // check function name
216                let Expr::AggregateFunction(func) = expr else {
217                    is_all_last_value = false;
218                    break;
219                };
220                if func.func.name() != "last_value" || func.filter.is_some() || func.distinct {
221                    is_all_last_value = false;
222                    break;
223                }
224                // check order by requirement
225                if let Some(order_by) = &func.order_by
226                    && let Some(first_order_by) = order_by.first()
227                    && order_by.len() == 1
228                {
229                    if let Some(existing_order_by) = &order_by_expr {
230                        if existing_order_by != first_order_by {
231                            is_all_last_value = false;
232                            break;
233                        }
234                    } else {
235                        // only allow `order by xxx [ASC]`, xxx is a bare column reference so `last_value()` is the max
236                        // value of the column.
237                        if !first_order_by.asc || !matches!(&first_order_by.expr, Expr::Column(_)) {
238                            is_all_last_value = false;
239                            break;
240                        }
241                        order_by_expr = Some(first_order_by.clone());
242                    }
243                }
244            }
245            is_all_last_value &= order_by_expr.is_some();
246            if is_all_last_value {
247                // make sure all the exprs are DIRECT `col` and collect them
248                let mut group_by_cols = HashSet::with_capacity(aggregate.group_expr.len());
249                for expr in &aggregate.group_expr {
250                    if let Expr::Column(col) = expr {
251                        group_by_cols.insert(col.clone());
252                    } else {
253                        is_all_last_value = false;
254                        break;
255                    }
256                }
257                // Safety: checked in the above loop
258                let order_by_expr = order_by_expr.unwrap();
259                let Expr::Column(order_by_col) = order_by_expr.expr else {
260                    unreachable!()
261                };
262                if is_all_last_value {
263                    self.ts_row_selector = Some((group_by_cols, order_by_col));
264                }
265            }
266        }
267
268        if self.ts_row_selector.is_some()
269            && (matches!(node, LogicalPlan::Subquery(_)) || node.inputs().len() > 1)
270        {
271            // clean previous time series selector hint when encounter subqueries or join
272            self.ts_row_selector = None;
273        }
274
275        if let LogicalPlan::Filter(filter) = node
276            && let Some(group_by_exprs) = &self.ts_row_selector
277        {
278            let mut filter_referenced_cols = HashSet::default();
279            utils::expr_to_columns(&filter.predicate, &mut filter_referenced_cols)?;
280            // ensure only group_by columns are used in filter
281            if !filter_referenced_cols.is_subset(&group_by_exprs.0) {
282                self.ts_row_selector = None;
283            }
284        }
285
286        Ok(TreeNodeRecursion::Continue)
287    }
288}
289
290impl ScanHintVisitor {
291    fn need_rewrite(&self) -> bool {
292        self.order_expr.is_some() || self.ts_row_selector.is_some()
293    }
294}
295
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use datafusion::functions_aggregate::first_last::last_value_udaf;
    use datafusion_expr::expr::AggregateFunction;
    use datafusion_expr::{col, LogicalPlanBuilder};
    use datafusion_optimizer::OptimizerContext;
    use store_api::storage::RegionId;

    use super::*;
    use crate::optimizer::test_util::mock_table_provider;

    /// Two stacked sorts over a scan: the one nearest to the scan must win.
    #[test]
    fn set_order_hint() {
        let provider = Arc::new(mock_table_provider(RegionId::new(1, 1)));
        let table_source = Arc::new(DefaultTableSource::new(provider.clone()));
        let plan = LogicalPlanBuilder::scan("t", table_source, None)
            .unwrap()
            .sort(vec![col("ts").sort(true, false)])
            .unwrap()
            .sort(vec![col("ts").sort(false, true)])
            .unwrap()
            .build()
            .unwrap();

        let context = OptimizerContext::default();
        assert!(ScanHintRule.supports_rewrite());
        ScanHintRule.rewrite(plan, &context).unwrap();

        // should read the first (with `.sort(true, false)`) sort option
        let scan_req = provider.scan_request();
        assert_eq!(
            OrderOption {
                name: "ts".to_string(),
                options: SortOptions {
                    descending: false,
                    nulls_first: false
                }
            },
            scan_req.output_ordering.as_ref().unwrap()[0]
        );
    }

    /// `last_value(v0 ORDER BY ts)` grouped by `k0` must request the last-row selector.
    #[test]
    fn set_time_series_row_selector_hint() {
        let provider = Arc::new(mock_table_provider(RegionId::new(1, 1)));
        let table_source = Arc::new(DefaultTableSource::new(provider.clone()));
        let plan = LogicalPlanBuilder::scan("t", table_source, None)
            .unwrap()
            .aggregate(
                vec![col("k0")],
                vec![Expr::AggregateFunction(AggregateFunction {
                    func: last_value_udaf(),
                    args: vec![col("v0")],
                    distinct: false,
                    filter: None,
                    order_by: Some(vec![Sort {
                        expr: col("ts"),
                        asc: true,
                        nulls_first: true,
                    }]),
                    null_treatment: None,
                })],
            )
            .unwrap()
            .build()
            .unwrap();

        let context = OptimizerContext::default();
        assert!(ScanHintRule.supports_rewrite());
        ScanHintRule.rewrite(plan, &context).unwrap();

        // Check the selector itself instead of merely unwrapping it.
        let scan_req = provider.scan_request();
        assert!(matches!(
            scan_req.series_row_selector,
            Some(TimeSeriesRowSelector::LastRow)
        ));
    }
}