query/optimizer/
scan_hint.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::collections::HashSet;
16
17use api::v1::SemanticType;
18use arrow_schema::SortOptions;
19use common_function::aggrs::aggr_wrapper::aggr_state_func_name;
20use common_recordbatch::OrderOption;
21use datafusion::datasource::DefaultTableSource;
22use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion, TreeNodeVisitor};
23use datafusion_common::{Column, Result};
24use datafusion_expr::expr::Sort;
25use datafusion_expr::{Expr, LogicalPlan, utils};
26use datafusion_optimizer::{OptimizerConfig, OptimizerRule};
27use store_api::metric_engine_consts::DATA_SCHEMA_TSID_COLUMN_NAME;
28use store_api::storage::{TimeSeriesDistribution, TimeSeriesRowSelector};
29
30use crate::dummy_catalog::DummyTableProvider;
31#[cfg(feature = "vector_index")]
32mod vector_search;
33#[cfg(feature = "vector_index")]
34use vector_search::VectorSearchState;
35
/// Optimizer rule that derives scan hints from the logical plan and hands them to the
/// leaf table scans, where they end up in the [`ScanRequest`].
///
/// The plan is walked once to collect the hints and, if any were found, a second time to
/// apply them to every table scan backed by a `DummyTableProvider`. The hints are:
/// - an ordering hint, taken from the sort node visited last on the way down;
/// - a per-series distribution hint, when that ordering follows the primary key
///   (optionally followed by the time index) or is `__tsid, <time index>`;
/// - a "last row" time series row selector hint, when every aggregate function of an
///   aggregation is `last_value` ordered by the time index and grouped by tag columns;
/// - with the `vector_index` feature enabled, a vector search hint.
///
/// [`ScanRequest`]: store_api::storage::ScanRequest
#[derive(Debug)]
pub struct ScanHintRule;
45
46impl OptimizerRule for ScanHintRule {
47    fn name(&self) -> &str {
48        "ScanHintRule"
49    }
50
51    fn rewrite(
52        &self,
53        plan: LogicalPlan,
54        _config: &dyn OptimizerConfig,
55    ) -> Result<Transformed<LogicalPlan>> {
56        Self::optimize(plan)
57    }
58}
59
60impl ScanHintRule {
61    fn optimize(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
62        let mut visitor = ScanHintVisitor::default();
63        let _ = plan.visit(&mut visitor)?;
64
65        if visitor.need_rewrite() {
66            plan.transform_down(&mut |plan| Self::set_hints(plan, &mut visitor))
67        } else {
68            Ok(Transformed::no(plan))
69        }
70    }
71
72    fn set_hints(
73        plan: LogicalPlan,
74        visitor: &mut ScanHintVisitor,
75    ) -> Result<Transformed<LogicalPlan>> {
76        match &plan {
77            LogicalPlan::TableScan(table_scan) => {
78                let mut transformed = false;
79                if let Some(source) = table_scan
80                    .source
81                    .as_any()
82                    .downcast_ref::<DefaultTableSource>()
83                {
84                    // The provider in the region server is [DummyTableProvider].
85                    if let Some(adapter) = source
86                        .table_provider
87                        .as_any()
88                        .downcast_ref::<DummyTableProvider>()
89                    {
90                        // set order_hint
91                        if let Some(order_expr) = &visitor.order_expr {
92                            Self::set_order_hint(adapter, order_expr);
93                        }
94
95                        // set time series selector hint
96                        if let Some((group_by_cols, order_by_col)) = &visitor.ts_row_selector {
97                            Self::set_time_series_row_selector_hint(
98                                adapter,
99                                group_by_cols,
100                                order_by_col,
101                            );
102                        }
103
104                        #[cfg(feature = "vector_index")]
105                        if let Some(vector_request) = visitor
106                            .vector_search
107                            .take_vector_request_from_dummy(adapter, &table_scan.table_name)
108                        {
109                            adapter.with_vector_search_hint(vector_request);
110                        }
111                        transformed = true;
112                    }
113                }
114                if transformed {
115                    Ok(Transformed::yes(plan))
116                } else {
117                    Ok(Transformed::no(plan))
118                }
119            }
120            _ => Ok(Transformed::no(plan)),
121        }
122    }
123
124    fn set_order_hint(adapter: &DummyTableProvider, order_expr: &Vec<Sort>) {
125        let mut opts = Vec::with_capacity(order_expr.len());
126        for sort in order_expr {
127            let name = match sort.expr.try_as_col() {
128                Some(col) => col.name.clone(),
129                None => return,
130            };
131            opts.push(OrderOption {
132                name,
133                options: SortOptions {
134                    descending: !sort.asc,
135                    nulls_first: sort.nulls_first,
136                },
137            });
138        }
139        adapter.with_ordering_hint(&opts);
140
141        let region_metadata = adapter.region_metadata();
142        let time_index_name = region_metadata
143            .time_index_column()
144            .column_schema
145            .name
146            .as_str();
147        let sort_cols = order_expr
148            .iter()
149            .filter_map(|s| s.expr.try_as_col())
150            .collect::<Vec<_>>();
151
152        // Special-case metric engine: when the nearest sort requirement is `__tsid, <time index>`,
153        // we can safely enable per-series distribution hint so the region can use `SeriesScan`.
154        //
155        // This pattern is produced by promql planning when `__tsid` is available and is used as the
156        // series identifier (instead of expanding to all tag columns).
157        if sort_cols.len() == 2
158            && sort_cols[0].name == DATA_SCHEMA_TSID_COLUMN_NAME
159            && sort_cols[1].name == time_index_name
160        {
161            adapter.with_distribution(TimeSeriesDistribution::PerSeries);
162            return;
163        }
164
165        let mut sort_expr_cursor = sort_cols.into_iter();
166        // ignore table without pk
167        if region_metadata.primary_key.is_empty() {
168            return;
169        }
170        let mut pk_column_iter = region_metadata.primary_key_columns();
171        let mut curr_sort_expr = sort_expr_cursor.next();
172        let mut curr_pk_col = pk_column_iter.next();
173
174        while let (Some(sort_expr), Some(pk_col)) = (curr_sort_expr, curr_pk_col) {
175            if sort_expr.name == pk_col.column_schema.name {
176                curr_sort_expr = sort_expr_cursor.next();
177                curr_pk_col = pk_column_iter.next();
178            } else {
179                return;
180            }
181        }
182
183        let next_remaining = sort_expr_cursor.next();
184        match (curr_sort_expr, next_remaining) {
185            (Some(expr), None)
186                if expr.name == region_metadata.time_index_column().column_schema.name =>
187            {
188                adapter.with_distribution(TimeSeriesDistribution::PerSeries);
189            }
190            (None, _) => adapter.with_distribution(TimeSeriesDistribution::PerSeries),
191            (Some(_), _) => {}
192        }
193    }
194
195    fn set_time_series_row_selector_hint(
196        adapter: &DummyTableProvider,
197        group_by_cols: &HashSet<Column>,
198        order_by_col: &Column,
199    ) {
200        let region_metadata = adapter.region_metadata();
201        let mut should_set_selector_hint = true;
202        // check if order_by column is time index
203        if let Some(column_metadata) = region_metadata.column_by_name(&order_by_col.name) {
204            if column_metadata.semantic_type != SemanticType::Timestamp {
205                should_set_selector_hint = false;
206            }
207        } else {
208            should_set_selector_hint = false;
209        }
210
211        // check if all group_by columns are primary key
212        for col in group_by_cols {
213            let Some(column_metadata) = region_metadata.column_by_name(&col.name) else {
214                should_set_selector_hint = false;
215                break;
216            };
217            if column_metadata.semantic_type != SemanticType::Tag {
218                should_set_selector_hint = false;
219                break;
220            }
221        }
222
223        if should_set_selector_hint {
224            adapter.with_time_series_selector_hint(TimeSeriesRowSelector::LastRow);
225        }
226    }
227}
228
229/// Traverse and fetch hints.
230#[derive(Default)]
231struct ScanHintVisitor {
232    /// The closest order requirement to the leaf node.
233    order_expr: Option<Vec<Sort>>,
234    /// Row selection on time series distribution.
235    /// This field stores saved `group_by` columns when all aggregate functions are `last_value`
236    /// and the `order_by` column which should be time index.
237    ts_row_selector: Option<(HashSet<Column>, Column)>,
238    #[cfg(feature = "vector_index")]
239    vector_search: VectorSearchState,
240}
241
242impl TreeNodeVisitor<'_> for ScanHintVisitor {
243    type Node = LogicalPlan;
244
245    fn f_down(&mut self, node: &Self::Node) -> Result<TreeNodeRecursion> {
246        #[cfg(feature = "vector_index")]
247        if let LogicalPlan::Limit(limit) = node {
248            // Track LIMIT so vector hint k can be derived within the same input chain.
249            self.vector_search.on_limit_enter(limit);
250        }
251
252        // Get order requirement from sort plan
253        if let LogicalPlan::Sort(sort) = node {
254            self.order_expr = Some(sort.expr.clone());
255
256            #[cfg(feature = "vector_index")]
257            {
258                // Capture vector ORDER BY and TopK hints from sort nodes.
259                self.vector_search.on_sort_enter(sort);
260            }
261        }
262
263        // Get time series row selector from aggr plan
264        if let LogicalPlan::Aggregate(aggregate) = node {
265            let mut is_all_last_value = !aggregate.aggr_expr.is_empty();
266            let mut order_by_expr = None;
267            for expr in &aggregate.aggr_expr {
268                // check function name
269                let Expr::AggregateFunction(func) = expr else {
270                    is_all_last_value = false;
271                    break;
272                };
273                if (func.func.name() != "last_value"
274                    && func.func.name() != aggr_state_func_name("last_value"))
275                    || func.params.filter.is_some()
276                    || func.params.distinct
277                {
278                    is_all_last_value = false;
279                    break;
280                }
281                // check order by requirement
282                let order_by = &func.params.order_by;
283                if let Some(first_order_by) = order_by.first()
284                    && order_by.len() == 1
285                {
286                    if let Some(existing_order_by) = &order_by_expr {
287                        if existing_order_by != first_order_by {
288                            is_all_last_value = false;
289                            break;
290                        }
291                    } else {
292                        // only allow `order by xxx [ASC]`, xxx is a bare column reference so `last_value()` is the max
293                        // value of the column.
294                        if !first_order_by.asc || !matches!(&first_order_by.expr, Expr::Column(_)) {
295                            is_all_last_value = false;
296                            break;
297                        }
298                        order_by_expr = Some(first_order_by.clone());
299                    }
300                }
301            }
302            is_all_last_value &= order_by_expr.is_some();
303            if is_all_last_value {
304                // make sure all the exprs are DIRECT `col` and collect them
305                let mut group_by_cols = HashSet::with_capacity(aggregate.group_expr.len());
306                for expr in &aggregate.group_expr {
307                    if let Expr::Column(col) = expr {
308                        group_by_cols.insert(col.clone());
309                    } else {
310                        is_all_last_value = false;
311                        break;
312                    }
313                }
314                // Safety: checked in the above loop
315                let order_by_expr = order_by_expr.unwrap();
316                let Expr::Column(order_by_col) = order_by_expr.expr else {
317                    unreachable!()
318                };
319                if is_all_last_value {
320                    self.ts_row_selector = Some((group_by_cols, order_by_col));
321                }
322            }
323        }
324
325        // Avoid carrying vector hints across branching inputs (join/subquery) to prevent
326        // pruning results before global ordering is applied. Only treat a subquery as a
327        // barrier when it contains non-inlineable operators.
328        let is_branching_for_ts = matches!(
329            node,
330            LogicalPlan::Subquery(_) | LogicalPlan::SubqueryAlias(_)
331        ) || node.inputs().len() > 1;
332        if is_branching_for_ts && self.ts_row_selector.is_some() {
333            // clean previous time series selector hint when encounter subqueries or join
334            self.ts_row_selector = None;
335        }
336        #[cfg(feature = "vector_index")]
337        if is_branching_for_vector(node) {
338            self.vector_search.on_branching_enter();
339        }
340
341        if let LogicalPlan::Filter(filter) = node
342            && let Some(group_by_exprs) = &self.ts_row_selector
343        {
344            let mut filter_referenced_cols = HashSet::default();
345            utils::expr_to_columns(&filter.predicate, &mut filter_referenced_cols)?;
346            // ensure only group_by columns are used in filter
347            if !filter_referenced_cols.is_subset(&group_by_exprs.0) {
348                self.ts_row_selector = None;
349            }
350        }
351
352        #[cfg(feature = "vector_index")]
353        if let LogicalPlan::Filter(filter) = node {
354            self.vector_search.on_filter_enter(&filter.predicate);
355        }
356
357        #[cfg(feature = "vector_index")]
358        if let LogicalPlan::TableScan(table_scan) = node {
359            // Record vector hints at leaf scans after scope checks.
360            self.vector_search.on_table_scan(table_scan);
361        }
362
363        Ok(TreeNodeRecursion::Continue)
364    }
365
366    fn f_up(&mut self, _node: &Self::Node) -> Result<TreeNodeRecursion> {
367        #[cfg(feature = "vector_index")]
368        match _node {
369            LogicalPlan::Limit(_) => {
370                self.vector_search.on_limit_exit();
371            }
372            LogicalPlan::Sort(_) => {
373                self.vector_search.on_sort_exit();
374            }
375            LogicalPlan::Filter(_) => {
376                self.vector_search.on_filter_exit();
377            }
378            LogicalPlan::Subquery(_) | LogicalPlan::SubqueryAlias(_) => {
379                if is_branching_for_vector(_node) {
380                    self.vector_search.on_branching_exit();
381                }
382            }
383            _ if _node.inputs().len() > 1 => {
384                self.vector_search.on_branching_exit();
385            }
386            _ => {}
387        }
388
389        Ok(TreeNodeRecursion::Continue)
390    }
391}
392
393impl ScanHintVisitor {
394    fn need_rewrite(&self) -> bool {
395        let base = self.order_expr.is_some() || self.ts_row_selector.is_some();
396        #[cfg(feature = "vector_index")]
397        {
398            base || self.vector_search.need_rewrite()
399        }
400        #[cfg(not(feature = "vector_index"))]
401        {
402            base
403        }
404    }
405}
406
/// Tells whether `node` is a barrier for vector search hints.
///
/// A node with more than one input always is. A `Subquery` or `SubqueryAlias` is one only
/// when the plan it wraps contains an operator listed in [`has_non_inlineable_ops`].
#[cfg(feature = "vector_index")]
fn is_branching_for_vector(node: &LogicalPlan) -> bool {
    let wrapped_plan = match node {
        _ if node.inputs().len() > 1 => return true,
        LogicalPlan::Subquery(subquery) => subquery.subquery.as_ref(),
        LogicalPlan::SubqueryAlias(alias) => alias.input.as_ref(),
        _ => return false,
    };

    has_non_inlineable_ops(wrapped_plan)
}
419
/// Recursively checks whether `plan` or any plan below it is a `Limit`, `Sort`,
/// `Distinct`, `Aggregate`, `Window`, `Union` or `Join` node.
#[cfg(feature = "vector_index")]
fn has_non_inlineable_ops(plan: &LogicalPlan) -> bool {
    let is_non_inlineable = matches!(
        plan,
        LogicalPlan::Limit(_)
            | LogicalPlan::Sort(_)
            | LogicalPlan::Distinct(_)
            | LogicalPlan::Aggregate(_)
            | LogicalPlan::Window(_)
            | LogicalPlan::Union(_)
            | LogicalPlan::Join(_)
    );

    // Only descend into the inputs when the node itself is not already a match.
    is_non_inlineable
        || plan
            .inputs()
            .into_iter()
            .any(|input| has_non_inlineable_ops(input))
}
443
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use datafusion::functions_aggregate::first_last::last_value_udaf;
    use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
    use datafusion_expr::{LogicalPlanBuilder, col};
    use datafusion_optimizer::OptimizerContext;
    use store_api::metric_engine_consts::DATA_SCHEMA_TSID_COLUMN_NAME;
    use store_api::storage::RegionId;

    use super::*;
    use crate::optimizer::test_util::{mock_table_provider, mock_table_provider_with_tsid};

    /// Runs [`ScanHintRule`] over `plan` with a default optimizer context and panics if
    /// the rule fails. The hints are observed afterwards through the table provider.
    fn apply_rule(plan: LogicalPlan) {
        let context = OptimizerContext::default();
        ScanHintRule.rewrite(plan, &context).unwrap();
    }

    /// With two stacked sort nodes, the one directly above the scan provides the hint.
    #[test]
    fn set_order_hint() {
        let provider = Arc::new(mock_table_provider(RegionId::new(1, 1)));
        let table_source = Arc::new(DefaultTableSource::new(provider.clone()));
        let plan = LogicalPlanBuilder::scan("t", table_source, None)
            .unwrap()
            .sort(vec![col("ts").sort(true, false)])
            .unwrap()
            .sort(vec![col("ts").sort(false, true)])
            .unwrap()
            .build()
            .unwrap();

        apply_rule(plan);

        // should read the first (with `.sort(true, false)`) sort option
        let scan_req = provider.scan_request();
        assert_eq!(
            OrderOption {
                name: "ts".to_string(),
                options: SortOptions {
                    descending: false,
                    nulls_first: false
                }
            },
            scan_req.output_ordering.as_ref().unwrap()[0]
        );
    }

    /// `last_value(v0 ORDER BY ts) GROUP BY k0` selects the last row of each series.
    #[test]
    fn set_time_series_row_selector_hint() {
        let provider = Arc::new(mock_table_provider(RegionId::new(1, 1)));
        let table_source = Arc::new(DefaultTableSource::new(provider.clone()));
        let plan = LogicalPlanBuilder::scan("t", table_source, None)
            .unwrap()
            .aggregate(
                vec![col("k0")],
                vec![Expr::AggregateFunction(AggregateFunction {
                    func: last_value_udaf(),
                    params: AggregateFunctionParams {
                        args: vec![col("v0")],
                        distinct: false,
                        filter: None,
                        order_by: vec![Sort {
                            expr: col("ts"),
                            asc: true,
                            nulls_first: true,
                        }],
                        null_treatment: None,
                    },
                })],
            )
            .unwrap()
            .build()
            .unwrap();

        apply_rule(plan);

        // Check the selector itself, not merely that some selector was set.
        let scan_req = provider.scan_request();
        assert!(matches!(
            scan_req.series_row_selector,
            Some(TimeSeriesRowSelector::LastRow)
        ));
    }

    /// Sorting by `__tsid, ts` enables the per-series distribution hint.
    #[test]
    fn set_order_hint_sets_per_series_distribution_for_tsid_sort() {
        let provider = Arc::new(mock_table_provider_with_tsid(RegionId::new(1, 1)));
        let table_source = Arc::new(DefaultTableSource::new(provider.clone()));
        let plan = LogicalPlanBuilder::scan("t", table_source, None)
            .unwrap()
            .sort(vec![
                col(DATA_SCHEMA_TSID_COLUMN_NAME).sort(true, true),
                col("ts").sort(true, true),
            ])
            .unwrap()
            .build()
            .unwrap();

        apply_rule(plan);

        let scan_req = provider.scan_request();
        assert_eq!(
            scan_req.distribution,
            Some(TimeSeriesDistribution::PerSeries)
        );
    }
}
545}