query/optimizer/
scan_hint.rs1use std::collections::HashSet;
16
17use api::v1::SemanticType;
18use arrow_schema::SortOptions;
19use common_function::aggrs::aggr_wrapper::aggr_state_func_name;
20use common_recordbatch::OrderOption;
21use datafusion::datasource::DefaultTableSource;
22use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion, TreeNodeVisitor};
23use datafusion_common::{Column, Result};
24use datafusion_expr::expr::Sort;
25use datafusion_expr::{Expr, LogicalPlan, utils};
26use datafusion_optimizer::{OptimizerConfig, OptimizerRule};
27use store_api::metric_engine_consts::DATA_SCHEMA_TSID_COLUMN_NAME;
28use store_api::storage::{TimeSeriesDistribution, TimeSeriesRowSelector};
29
30use crate::dummy_catalog::DummyTableProvider;
31#[cfg(feature = "vector_index")]
32mod vector_search;
33#[cfg(feature = "vector_index")]
34use vector_search::VectorSearchState;
35
/// Optimizer rule that forwards scan hints to table scans backed by a
/// [`DummyTableProvider`].
///
/// The hints are gathered from the logical plan by `ScanHintVisitor` and cover
/// the requested output ordering, the time-series distribution, the last-row
/// selector and, when the `vector_index` feature is enabled, vector search
/// requests.
#[derive(Debug)]
pub struct ScanHintRule;
45
46impl OptimizerRule for ScanHintRule {
47 fn name(&self) -> &str {
48 "ScanHintRule"
49 }
50
51 fn rewrite(
52 &self,
53 plan: LogicalPlan,
54 _config: &dyn OptimizerConfig,
55 ) -> Result<Transformed<LogicalPlan>> {
56 Self::optimize(plan)
57 }
58}
59
60impl ScanHintRule {
61 fn optimize(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
62 let mut visitor = ScanHintVisitor::default();
63 let _ = plan.visit(&mut visitor)?;
64
65 if visitor.need_rewrite() {
66 plan.transform_down(&mut |plan| Self::set_hints(plan, &mut visitor))
67 } else {
68 Ok(Transformed::no(plan))
69 }
70 }
71
72 fn set_hints(
73 plan: LogicalPlan,
74 visitor: &mut ScanHintVisitor,
75 ) -> Result<Transformed<LogicalPlan>> {
76 match &plan {
77 LogicalPlan::TableScan(table_scan) => {
78 let mut transformed = false;
79 if let Some(source) = table_scan
80 .source
81 .as_any()
82 .downcast_ref::<DefaultTableSource>()
83 {
84 if let Some(adapter) = source
86 .table_provider
87 .as_any()
88 .downcast_ref::<DummyTableProvider>()
89 {
90 if let Some(order_expr) = &visitor.order_expr {
92 Self::set_order_hint(adapter, order_expr);
93 }
94
95 if let Some((group_by_cols, order_by_col)) = &visitor.ts_row_selector {
97 Self::set_time_series_row_selector_hint(
98 adapter,
99 group_by_cols,
100 order_by_col,
101 );
102 }
103
104 #[cfg(feature = "vector_index")]
105 if let Some(vector_request) = visitor
106 .vector_search
107 .take_vector_request_from_dummy(adapter, &table_scan.table_name)
108 {
109 adapter.with_vector_search_hint(vector_request);
110 }
111 transformed = true;
112 }
113 }
114 if transformed {
115 Ok(Transformed::yes(plan))
116 } else {
117 Ok(Transformed::no(plan))
118 }
119 }
120 _ => Ok(Transformed::no(plan)),
121 }
122 }
123
124 fn set_order_hint(adapter: &DummyTableProvider, order_expr: &Vec<Sort>) {
125 let mut opts = Vec::with_capacity(order_expr.len());
126 for sort in order_expr {
127 let name = match sort.expr.try_as_col() {
128 Some(col) => col.name.clone(),
129 None => return,
130 };
131 opts.push(OrderOption {
132 name,
133 options: SortOptions {
134 descending: !sort.asc,
135 nulls_first: sort.nulls_first,
136 },
137 });
138 }
139 adapter.with_ordering_hint(&opts);
140
141 let region_metadata = adapter.region_metadata();
142 let time_index_name = region_metadata
143 .time_index_column()
144 .column_schema
145 .name
146 .as_str();
147 let sort_cols = order_expr
148 .iter()
149 .filter_map(|s| s.expr.try_as_col())
150 .collect::<Vec<_>>();
151
152 if sort_cols.len() == 2
158 && sort_cols[0].name == DATA_SCHEMA_TSID_COLUMN_NAME
159 && sort_cols[1].name == time_index_name
160 {
161 adapter.with_distribution(TimeSeriesDistribution::PerSeries);
162 return;
163 }
164
165 let mut sort_expr_cursor = sort_cols.into_iter();
166 if region_metadata.primary_key.is_empty() {
168 return;
169 }
170 let mut pk_column_iter = region_metadata.primary_key_columns();
171 let mut curr_sort_expr = sort_expr_cursor.next();
172 let mut curr_pk_col = pk_column_iter.next();
173
174 while let (Some(sort_expr), Some(pk_col)) = (curr_sort_expr, curr_pk_col) {
175 if sort_expr.name == pk_col.column_schema.name {
176 curr_sort_expr = sort_expr_cursor.next();
177 curr_pk_col = pk_column_iter.next();
178 } else {
179 return;
180 }
181 }
182
183 let next_remaining = sort_expr_cursor.next();
184 match (curr_sort_expr, next_remaining) {
185 (Some(expr), None)
186 if expr.name == region_metadata.time_index_column().column_schema.name =>
187 {
188 adapter.with_distribution(TimeSeriesDistribution::PerSeries);
189 }
190 (None, _) => adapter.with_distribution(TimeSeriesDistribution::PerSeries),
191 (Some(_), _) => {}
192 }
193 }
194
195 fn set_time_series_row_selector_hint(
196 adapter: &DummyTableProvider,
197 group_by_cols: &HashSet<Column>,
198 order_by_col: &Column,
199 ) {
200 let region_metadata = adapter.region_metadata();
201 let mut should_set_selector_hint = true;
202 if let Some(column_metadata) = region_metadata.column_by_name(&order_by_col.name) {
204 if column_metadata.semantic_type != SemanticType::Timestamp {
205 should_set_selector_hint = false;
206 }
207 } else {
208 should_set_selector_hint = false;
209 }
210
211 for col in group_by_cols {
213 let Some(column_metadata) = region_metadata.column_by_name(&col.name) else {
214 should_set_selector_hint = false;
215 break;
216 };
217 if column_metadata.semantic_type != SemanticType::Tag {
218 should_set_selector_hint = false;
219 break;
220 }
221 }
222
223 if should_set_selector_hint {
224 adapter.with_time_series_selector_hint(TimeSeriesRowSelector::LastRow);
225 }
226 }
227}
228
229#[derive(Default)]
231struct ScanHintVisitor {
232 order_expr: Option<Vec<Sort>>,
234 ts_row_selector: Option<(HashSet<Column>, Column)>,
238 #[cfg(feature = "vector_index")]
239 vector_search: VectorSearchState,
240}
241
242impl TreeNodeVisitor<'_> for ScanHintVisitor {
243 type Node = LogicalPlan;
244
245 fn f_down(&mut self, node: &Self::Node) -> Result<TreeNodeRecursion> {
246 #[cfg(feature = "vector_index")]
247 if let LogicalPlan::Limit(limit) = node {
248 self.vector_search.on_limit_enter(limit);
250 }
251
252 if let LogicalPlan::Sort(sort) = node {
254 self.order_expr = Some(sort.expr.clone());
255
256 #[cfg(feature = "vector_index")]
257 {
258 self.vector_search.on_sort_enter(sort);
260 }
261 }
262
263 if let LogicalPlan::Aggregate(aggregate) = node {
265 let mut is_all_last_value = !aggregate.aggr_expr.is_empty();
266 let mut order_by_expr = None;
267 for expr in &aggregate.aggr_expr {
268 let Expr::AggregateFunction(func) = expr else {
270 is_all_last_value = false;
271 break;
272 };
273 if (func.func.name() != "last_value"
274 && func.func.name() != aggr_state_func_name("last_value"))
275 || func.params.filter.is_some()
276 || func.params.distinct
277 {
278 is_all_last_value = false;
279 break;
280 }
281 let order_by = &func.params.order_by;
283 if let Some(first_order_by) = order_by.first()
284 && order_by.len() == 1
285 {
286 if let Some(existing_order_by) = &order_by_expr {
287 if existing_order_by != first_order_by {
288 is_all_last_value = false;
289 break;
290 }
291 } else {
292 if !first_order_by.asc || !matches!(&first_order_by.expr, Expr::Column(_)) {
295 is_all_last_value = false;
296 break;
297 }
298 order_by_expr = Some(first_order_by.clone());
299 }
300 }
301 }
302 is_all_last_value &= order_by_expr.is_some();
303 if is_all_last_value {
304 let mut group_by_cols = HashSet::with_capacity(aggregate.group_expr.len());
306 for expr in &aggregate.group_expr {
307 if let Expr::Column(col) = expr {
308 group_by_cols.insert(col.clone());
309 } else {
310 is_all_last_value = false;
311 break;
312 }
313 }
314 let order_by_expr = order_by_expr.unwrap();
316 let Expr::Column(order_by_col) = order_by_expr.expr else {
317 unreachable!()
318 };
319 if is_all_last_value {
320 self.ts_row_selector = Some((group_by_cols, order_by_col));
321 }
322 }
323 }
324
325 let is_branching = matches!(node, LogicalPlan::Subquery(_)) || node.inputs().len() > 1;
328 if is_branching && self.ts_row_selector.is_some() {
329 self.ts_row_selector = None;
331 }
332 #[cfg(feature = "vector_index")]
333 if is_branching {
334 self.vector_search.on_branching_enter();
335 }
336
337 if let LogicalPlan::Filter(filter) = node
338 && let Some(group_by_exprs) = &self.ts_row_selector
339 {
340 let mut filter_referenced_cols = HashSet::default();
341 utils::expr_to_columns(&filter.predicate, &mut filter_referenced_cols)?;
342 if !filter_referenced_cols.is_subset(&group_by_exprs.0) {
344 self.ts_row_selector = None;
345 }
346 }
347
348 #[cfg(feature = "vector_index")]
349 if let LogicalPlan::Filter(filter) = node {
350 self.vector_search.on_filter_enter(&filter.predicate);
351 }
352
353 #[cfg(feature = "vector_index")]
354 if let LogicalPlan::TableScan(table_scan) = node {
355 self.vector_search.on_table_scan(table_scan);
357 }
358
359 Ok(TreeNodeRecursion::Continue)
360 }
361
362 fn f_up(&mut self, _node: &Self::Node) -> Result<TreeNodeRecursion> {
363 #[cfg(feature = "vector_index")]
364 match _node {
365 LogicalPlan::Limit(_) => {
366 self.vector_search.on_limit_exit();
367 }
368 LogicalPlan::Sort(_) => {
369 self.vector_search.on_sort_exit();
370 }
371 LogicalPlan::Filter(_) => {
372 self.vector_search.on_filter_exit();
373 }
374 LogicalPlan::Subquery(_) => {
375 self.vector_search.on_branching_exit();
376 }
377 _ if _node.inputs().len() > 1 => {
378 self.vector_search.on_branching_exit();
379 }
380 _ => {}
381 }
382
383 Ok(TreeNodeRecursion::Continue)
384 }
385}
386
387impl ScanHintVisitor {
388 fn need_rewrite(&self) -> bool {
389 let base = self.order_expr.is_some() || self.ts_row_selector.is_some();
390 #[cfg(feature = "vector_index")]
391 {
392 base || self.vector_search.need_rewrite()
393 }
394 #[cfg(not(feature = "vector_index"))]
395 {
396 base
397 }
398 }
399}
400
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use datafusion::functions_aggregate::first_last::last_value_udaf;
    use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
    use datafusion_expr::{LogicalPlanBuilder, col};
    use datafusion_optimizer::OptimizerContext;
    use store_api::metric_engine_consts::DATA_SCHEMA_TSID_COLUMN_NAME;
    use store_api::storage::RegionId;

    use super::*;
    use crate::optimizer::test_util::{mock_table_provider, mock_table_provider_with_tsid};

    /// Runs [`ScanHintRule`] over `plan` with a default optimizer context and
    /// panics if the rule returns an error.
    fn apply_scan_hint_rule(plan: LogicalPlan) {
        let optimizer_context = OptimizerContext::default();
        ScanHintRule.rewrite(plan, &optimizer_context).unwrap();
    }

    /// With two stacked sorts, the ordering hint comes from the sort that sits
    /// directly above the scan.
    #[test]
    fn set_order_hint() {
        let provider = Arc::new(mock_table_provider(RegionId::new(1, 1)));
        let table_source = Arc::new(DefaultTableSource::new(provider.clone()));
        let sort_above_scan = vec![col("ts").sort(true, false)];
        let topmost_sort = vec![col("ts").sort(false, true)];
        let plan = LogicalPlanBuilder::scan("t", table_source, None)
            .unwrap()
            .sort(sort_above_scan)
            .unwrap()
            .sort(topmost_sort)
            .unwrap()
            .build()
            .unwrap();

        apply_scan_hint_rule(plan);

        let expected = OrderOption {
            name: "ts".to_string(),
            options: SortOptions {
                descending: false,
                nulls_first: false,
            },
        };
        let scan_req = provider.scan_request();
        let output_ordering = scan_req.output_ordering.as_ref().unwrap();
        assert_eq!(expected, output_ordering[0]);
    }

    /// A `last_value(v0 ORDER BY ts)` aggregate grouped by `k0` produces a
    /// series row selector on the scan request.
    #[test]
    fn set_time_series_row_selector_hint() {
        let provider = Arc::new(mock_table_provider(RegionId::new(1, 1)));
        let table_source = Arc::new(DefaultTableSource::new(provider.clone()));
        let last_value_by_ts = Expr::AggregateFunction(AggregateFunction {
            func: last_value_udaf(),
            params: AggregateFunctionParams {
                args: vec![col("v0")],
                distinct: false,
                filter: None,
                order_by: vec![Sort {
                    expr: col("ts"),
                    asc: true,
                    nulls_first: true,
                }],
                null_treatment: None,
            },
        });
        let plan = LogicalPlanBuilder::scan("t", table_source, None)
            .unwrap()
            .aggregate(vec![col("k0")], vec![last_value_by_ts])
            .unwrap()
            .build()
            .unwrap();

        apply_scan_hint_rule(plan);

        let scan_req = provider.scan_request();
        assert!(scan_req.series_row_selector.is_some());
    }

    /// Sorting by the tsid column followed by the time index requests the
    /// per-series distribution.
    #[test]
    fn set_order_hint_sets_per_series_distribution_for_tsid_sort() {
        let provider = Arc::new(mock_table_provider_with_tsid(RegionId::new(1, 1)));
        let table_source = Arc::new(DefaultTableSource::new(provider.clone()));
        let tsid_then_ts = vec![
            col(DATA_SCHEMA_TSID_COLUMN_NAME).sort(true, true),
            col("ts").sort(true, true),
        ];
        let plan = LogicalPlanBuilder::scan("t", table_source, None)
            .unwrap()
            .sort(tsid_then_ts)
            .unwrap()
            .build()
            .unwrap();

        apply_scan_hint_rule(plan);

        let scan_req = provider.scan_request();
        assert_eq!(
            scan_req.distribution,
            Some(TimeSeriesDistribution::PerSeries)
        );
    }
}