query/optimizer/
scan_hint.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashSet;
16
17use api::v1::SemanticType;
18use arrow_schema::SortOptions;
19use common_function::aggrs::aggr_wrapper::aggr_state_func_name;
20use common_recordbatch::OrderOption;
21use datafusion::datasource::DefaultTableSource;
22use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion, TreeNodeVisitor};
23use datafusion_common::{Column, Result};
24use datafusion_expr::expr::Sort;
25use datafusion_expr::{Expr, LogicalPlan, utils};
26use datafusion_optimizer::{OptimizerConfig, OptimizerRule};
27use store_api::metric_engine_consts::DATA_SCHEMA_TSID_COLUMN_NAME;
28use store_api::storage::{TimeSeriesDistribution, TimeSeriesRowSelector};
29
30use crate::dummy_catalog::DummyTableProvider;
31#[cfg(feature = "vector_index")]
32mod vector_search;
33#[cfg(feature = "vector_index")]
34use vector_search::VectorSearchState;
35
/// Optimizer rule that pushes scan hints down to leaf table scans.
///
/// The rule walks the logical plan, gathers hints along the way and stores them in the
/// [`ScanRequest`] of every leaf table scan backed by a [`DummyTableProvider`]. The hints are:
/// - an ordering hint taken from the sort requirement nearest to the leaf table scan, which
///   may additionally enable the per-series distribution hint;
/// - a time series row selector hint, built from the group by columns when all aggregate
///   functions are `last_value`;
/// - a vector search hint, only when the `vector_index` feature is enabled.
///
/// [`ScanRequest`]: store_api::storage::ScanRequest
#[derive(Debug)]
pub struct ScanHintRule;
45
46impl OptimizerRule for ScanHintRule {
47    fn name(&self) -> &str {
48        "ScanHintRule"
49    }
50
51    fn rewrite(
52        &self,
53        plan: LogicalPlan,
54        _config: &dyn OptimizerConfig,
55    ) -> Result<Transformed<LogicalPlan>> {
56        Self::optimize(plan)
57    }
58}
59
60impl ScanHintRule {
61    fn optimize(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
62        let mut visitor = ScanHintVisitor::default();
63        let _ = plan.visit(&mut visitor)?;
64
65        if visitor.need_rewrite() {
66            plan.transform_down(&mut |plan| Self::set_hints(plan, &mut visitor))
67        } else {
68            Ok(Transformed::no(plan))
69        }
70    }
71
72    fn set_hints(
73        plan: LogicalPlan,
74        visitor: &mut ScanHintVisitor,
75    ) -> Result<Transformed<LogicalPlan>> {
76        match &plan {
77            LogicalPlan::TableScan(table_scan) => {
78                let mut transformed = false;
79                if let Some(source) = table_scan
80                    .source
81                    .as_any()
82                    .downcast_ref::<DefaultTableSource>()
83                {
84                    // The provider in the region server is [DummyTableProvider].
85                    if let Some(adapter) = source
86                        .table_provider
87                        .as_any()
88                        .downcast_ref::<DummyTableProvider>()
89                    {
90                        // set order_hint
91                        if let Some(order_expr) = &visitor.order_expr {
92                            Self::set_order_hint(adapter, order_expr);
93                        }
94
95                        // set time series selector hint
96                        if let Some((group_by_cols, order_by_col)) = &visitor.ts_row_selector {
97                            Self::set_time_series_row_selector_hint(
98                                adapter,
99                                group_by_cols,
100                                order_by_col,
101                            );
102                        }
103
104                        #[cfg(feature = "vector_index")]
105                        if let Some(vector_request) = visitor
106                            .vector_search
107                            .take_vector_request_from_dummy(adapter, &table_scan.table_name)
108                        {
109                            adapter.with_vector_search_hint(vector_request);
110                        }
111                        transformed = true;
112                    }
113                }
114                if transformed {
115                    Ok(Transformed::yes(plan))
116                } else {
117                    Ok(Transformed::no(plan))
118                }
119            }
120            _ => Ok(Transformed::no(plan)),
121        }
122    }
123
124    fn set_order_hint(adapter: &DummyTableProvider, order_expr: &Vec<Sort>) {
125        let mut opts = Vec::with_capacity(order_expr.len());
126        for sort in order_expr {
127            let name = match sort.expr.try_as_col() {
128                Some(col) => col.name.clone(),
129                None => return,
130            };
131            opts.push(OrderOption {
132                name,
133                options: SortOptions {
134                    descending: !sort.asc,
135                    nulls_first: sort.nulls_first,
136                },
137            });
138        }
139        adapter.with_ordering_hint(&opts);
140
141        let region_metadata = adapter.region_metadata();
142        let time_index_name = region_metadata
143            .time_index_column()
144            .column_schema
145            .name
146            .as_str();
147        let sort_cols = order_expr
148            .iter()
149            .filter_map(|s| s.expr.try_as_col())
150            .collect::<Vec<_>>();
151
152        // Special-case metric engine: when the nearest sort requirement is `__tsid, <time index>`,
153        // we can safely enable per-series distribution hint so the region can use `SeriesScan`.
154        //
155        // This pattern is produced by promql planning when `__tsid` is available and is used as the
156        // series identifier (instead of expanding to all tag columns).
157        if sort_cols.len() == 2
158            && sort_cols[0].name == DATA_SCHEMA_TSID_COLUMN_NAME
159            && sort_cols[1].name == time_index_name
160        {
161            adapter.with_distribution(TimeSeriesDistribution::PerSeries);
162            return;
163        }
164
165        let mut sort_expr_cursor = sort_cols.into_iter();
166        // ignore table without pk
167        if region_metadata.primary_key.is_empty() {
168            return;
169        }
170        let mut pk_column_iter = region_metadata.primary_key_columns();
171        let mut curr_sort_expr = sort_expr_cursor.next();
172        let mut curr_pk_col = pk_column_iter.next();
173
174        while let (Some(sort_expr), Some(pk_col)) = (curr_sort_expr, curr_pk_col) {
175            if sort_expr.name == pk_col.column_schema.name {
176                curr_sort_expr = sort_expr_cursor.next();
177                curr_pk_col = pk_column_iter.next();
178            } else {
179                return;
180            }
181        }
182
183        let next_remaining = sort_expr_cursor.next();
184        match (curr_sort_expr, next_remaining) {
185            (Some(expr), None)
186                if expr.name == region_metadata.time_index_column().column_schema.name =>
187            {
188                adapter.with_distribution(TimeSeriesDistribution::PerSeries);
189            }
190            (None, _) => adapter.with_distribution(TimeSeriesDistribution::PerSeries),
191            (Some(_), _) => {}
192        }
193    }
194
195    fn set_time_series_row_selector_hint(
196        adapter: &DummyTableProvider,
197        group_by_cols: &HashSet<Column>,
198        order_by_col: &Column,
199    ) {
200        let region_metadata = adapter.region_metadata();
201        let mut should_set_selector_hint = true;
202        // check if order_by column is time index
203        if let Some(column_metadata) = region_metadata.column_by_name(&order_by_col.name) {
204            if column_metadata.semantic_type != SemanticType::Timestamp {
205                should_set_selector_hint = false;
206            }
207        } else {
208            should_set_selector_hint = false;
209        }
210
211        // check if all group_by columns are primary key
212        for col in group_by_cols {
213            let Some(column_metadata) = region_metadata.column_by_name(&col.name) else {
214                should_set_selector_hint = false;
215                break;
216            };
217            if column_metadata.semantic_type != SemanticType::Tag {
218                should_set_selector_hint = false;
219                break;
220            }
221        }
222
223        if should_set_selector_hint {
224            adapter.with_time_series_selector_hint(TimeSeriesRowSelector::LastRow);
225        }
226    }
227}
228
229/// Traverse and fetch hints.
230#[derive(Default)]
231struct ScanHintVisitor {
232    /// The closest order requirement to the leaf node.
233    order_expr: Option<Vec<Sort>>,
234    /// Row selection on time series distribution.
235    /// This field stores saved `group_by` columns when all aggregate functions are `last_value`
236    /// and the `order_by` column which should be time index.
237    ts_row_selector: Option<(HashSet<Column>, Column)>,
238    #[cfg(feature = "vector_index")]
239    vector_search: VectorSearchState,
240}
241
242impl TreeNodeVisitor<'_> for ScanHintVisitor {
243    type Node = LogicalPlan;
244
245    fn f_down(&mut self, node: &Self::Node) -> Result<TreeNodeRecursion> {
246        #[cfg(feature = "vector_index")]
247        if let LogicalPlan::Limit(limit) = node {
248            // Track LIMIT so vector hint k can be derived within the same input chain.
249            self.vector_search.on_limit_enter(limit);
250        }
251
252        // Get order requirement from sort plan
253        if let LogicalPlan::Sort(sort) = node {
254            self.order_expr = Some(sort.expr.clone());
255
256            #[cfg(feature = "vector_index")]
257            {
258                // Capture vector ORDER BY and TopK hints from sort nodes.
259                self.vector_search.on_sort_enter(sort);
260            }
261        }
262
263        // Get time series row selector from aggr plan
264        if let LogicalPlan::Aggregate(aggregate) = node {
265            let mut is_all_last_value = !aggregate.aggr_expr.is_empty();
266            let mut order_by_expr = None;
267            for expr in &aggregate.aggr_expr {
268                // check function name
269                let Expr::AggregateFunction(func) = expr else {
270                    is_all_last_value = false;
271                    break;
272                };
273                if (func.func.name() != "last_value"
274                    && func.func.name() != aggr_state_func_name("last_value"))
275                    || func.params.filter.is_some()
276                    || func.params.distinct
277                {
278                    is_all_last_value = false;
279                    break;
280                }
281                // check order by requirement
282                let order_by = &func.params.order_by;
283                if let Some(first_order_by) = order_by.first()
284                    && order_by.len() == 1
285                {
286                    if let Some(existing_order_by) = &order_by_expr {
287                        if existing_order_by != first_order_by {
288                            is_all_last_value = false;
289                            break;
290                        }
291                    } else {
292                        // only allow `order by xxx [ASC]`, xxx is a bare column reference so `last_value()` is the max
293                        // value of the column.
294                        if !first_order_by.asc || !matches!(&first_order_by.expr, Expr::Column(_)) {
295                            is_all_last_value = false;
296                            break;
297                        }
298                        order_by_expr = Some(first_order_by.clone());
299                    }
300                }
301            }
302            is_all_last_value &= order_by_expr.is_some();
303            if is_all_last_value {
304                // make sure all the exprs are DIRECT `col` and collect them
305                let mut group_by_cols = HashSet::with_capacity(aggregate.group_expr.len());
306                for expr in &aggregate.group_expr {
307                    if let Expr::Column(col) = expr {
308                        group_by_cols.insert(col.clone());
309                    } else {
310                        is_all_last_value = false;
311                        break;
312                    }
313                }
314                // Safety: checked in the above loop
315                let order_by_expr = order_by_expr.unwrap();
316                let Expr::Column(order_by_col) = order_by_expr.expr else {
317                    unreachable!()
318                };
319                if is_all_last_value {
320                    self.ts_row_selector = Some((group_by_cols, order_by_col));
321                }
322            }
323        }
324
325        // Avoid carrying vector hints across branching inputs (join/subquery) to prevent
326        // pruning results before global ordering is applied.
327        let is_branching = matches!(node, LogicalPlan::Subquery(_)) || node.inputs().len() > 1;
328        if is_branching && self.ts_row_selector.is_some() {
329            // clean previous time series selector hint when encounter subqueries or join
330            self.ts_row_selector = None;
331        }
332        #[cfg(feature = "vector_index")]
333        if is_branching {
334            self.vector_search.on_branching_enter();
335        }
336
337        if let LogicalPlan::Filter(filter) = node
338            && let Some(group_by_exprs) = &self.ts_row_selector
339        {
340            let mut filter_referenced_cols = HashSet::default();
341            utils::expr_to_columns(&filter.predicate, &mut filter_referenced_cols)?;
342            // ensure only group_by columns are used in filter
343            if !filter_referenced_cols.is_subset(&group_by_exprs.0) {
344                self.ts_row_selector = None;
345            }
346        }
347
348        #[cfg(feature = "vector_index")]
349        if let LogicalPlan::Filter(filter) = node {
350            self.vector_search.on_filter_enter(&filter.predicate);
351        }
352
353        #[cfg(feature = "vector_index")]
354        if let LogicalPlan::TableScan(table_scan) = node {
355            // Record vector hints at leaf scans after scope checks.
356            self.vector_search.on_table_scan(table_scan);
357        }
358
359        Ok(TreeNodeRecursion::Continue)
360    }
361
362    fn f_up(&mut self, _node: &Self::Node) -> Result<TreeNodeRecursion> {
363        #[cfg(feature = "vector_index")]
364        match _node {
365            LogicalPlan::Limit(_) => {
366                self.vector_search.on_limit_exit();
367            }
368            LogicalPlan::Sort(_) => {
369                self.vector_search.on_sort_exit();
370            }
371            LogicalPlan::Filter(_) => {
372                self.vector_search.on_filter_exit();
373            }
374            LogicalPlan::Subquery(_) => {
375                self.vector_search.on_branching_exit();
376            }
377            _ if _node.inputs().len() > 1 => {
378                self.vector_search.on_branching_exit();
379            }
380            _ => {}
381        }
382
383        Ok(TreeNodeRecursion::Continue)
384    }
385}
386
387impl ScanHintVisitor {
388    fn need_rewrite(&self) -> bool {
389        let base = self.order_expr.is_some() || self.ts_row_selector.is_some();
390        #[cfg(feature = "vector_index")]
391        {
392            base || self.vector_search.need_rewrite()
393        }
394        #[cfg(not(feature = "vector_index"))]
395        {
396            base
397        }
398    }
399}
400
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use datafusion::functions_aggregate::first_last::last_value_udaf;
    use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
    use datafusion_expr::{LogicalPlanBuilder, col};
    use datafusion_optimizer::OptimizerContext;
    use store_api::metric_engine_consts::DATA_SCHEMA_TSID_COLUMN_NAME;
    use store_api::storage::RegionId;

    use super::*;
    use crate::optimizer::test_util::{mock_table_provider, mock_table_provider_with_tsid};

    /// Runs [`ScanHintRule`] over `plan` with a default optimizer context, panicking if the
    /// rule returns an error.
    fn apply_rule(plan: LogicalPlan) {
        let context = OptimizerContext::default();
        ScanHintRule.rewrite(plan, &context).unwrap();
    }

    /// With two stacked sorts, the one nearest to the scan provides the ordering hint.
    #[test]
    fn set_order_hint() {
        let provider = Arc::new(mock_table_provider(RegionId::new(1, 1)));
        let table_source = Arc::new(DefaultTableSource::new(provider.clone()));
        let inner_sort = vec![col("ts").sort(true, false)];
        let outer_sort = vec![col("ts").sort(false, true)];
        let plan = LogicalPlanBuilder::scan("t", table_source, None)
            .unwrap()
            .sort(inner_sort)
            .unwrap()
            .sort(outer_sort)
            .unwrap()
            .build()
            .unwrap();

        apply_rule(plan);

        // should read the first (with `.sort(true, false)`) sort option
        let expected = OrderOption {
            name: "ts".to_string(),
            options: SortOptions {
                descending: false,
                nulls_first: false,
            },
        };
        let scan_req = provider.scan_request();
        let ordering = scan_req.output_ordering.as_ref().unwrap();
        assert_eq!(expected, ordering[0]);
    }

    /// `last_value(v0 ORDER BY ts)` grouped by `k0` yields a row selector hint.
    #[test]
    fn set_time_series_row_selector_hint() {
        let provider = Arc::new(mock_table_provider(RegionId::new(1, 1)));
        let table_source = Arc::new(DefaultTableSource::new(provider.clone()));
        let last_value = Expr::AggregateFunction(AggregateFunction {
            func: last_value_udaf(),
            params: AggregateFunctionParams {
                args: vec![col("v0")],
                distinct: false,
                filter: None,
                order_by: vec![Sort {
                    expr: col("ts"),
                    asc: true,
                    nulls_first: true,
                }],
                null_treatment: None,
            },
        });
        let plan = LogicalPlanBuilder::scan("t", table_source, None)
            .unwrap()
            .aggregate(vec![col("k0")], vec![last_value])
            .unwrap()
            .build()
            .unwrap();

        apply_rule(plan);

        let scan_req = provider.scan_request();
        let _ = scan_req
            .series_row_selector
            .expect("time series row selector hint should be set");
    }

    /// Sorting by `__tsid` then the time index enables the per-series distribution hint.
    #[test]
    fn set_order_hint_sets_per_series_distribution_for_tsid_sort() {
        let provider = Arc::new(mock_table_provider_with_tsid(RegionId::new(1, 1)));
        let table_source = Arc::new(DefaultTableSource::new(provider.clone()));
        let sort_keys = vec![
            col(DATA_SCHEMA_TSID_COLUMN_NAME).sort(true, true),
            col("ts").sort(true, true),
        ];
        let plan = LogicalPlanBuilder::scan("t", table_source, None)
            .unwrap()
            .sort(sort_keys)
            .unwrap()
            .build()
            .unwrap();

        apply_rule(plan);

        let scan_req = provider.scan_request();
        assert_eq!(
            scan_req.distribution,
            Some(TimeSeriesDistribution::PerSeries)
        );
    }
}