1use std::collections::HashSet;
16
17use api::v1::SemanticType;
18use arrow_schema::SortOptions;
19use common_function::aggrs::aggr_wrapper::aggr_state_func_name;
20use common_recordbatch::OrderOption;
21use datafusion::datasource::DefaultTableSource;
22use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion, TreeNodeVisitor};
23use datafusion_common::{Column, Result};
24use datafusion_expr::expr::Sort;
25use datafusion_expr::{Expr, LogicalPlan, utils};
26use datafusion_optimizer::{OptimizerConfig, OptimizerRule};
27use store_api::metric_engine_consts::DATA_SCHEMA_TSID_COLUMN_NAME;
28use store_api::storage::{TimeSeriesDistribution, TimeSeriesRowSelector};
29
30use crate::dummy_catalog::DummyTableProvider;
31#[cfg(feature = "vector_index")]
32mod vector_search;
33#[cfg(feature = "vector_index")]
34use vector_search::VectorSearchState;
35
/// Optimizer rule that inspects a logical plan and passes scan hints to the
/// [`DummyTableProvider`] behind each table scan.
///
/// The hints are an ordering hint (plus a time-series distribution), a
/// time-series row selector and, when the `vector_index` feature is enabled,
/// a vector search hint.
#[derive(Debug)]
pub struct ScanHintRule;
45
46impl OptimizerRule for ScanHintRule {
47 fn name(&self) -> &str {
48 "ScanHintRule"
49 }
50
51 fn rewrite(
52 &self,
53 plan: LogicalPlan,
54 _config: &dyn OptimizerConfig,
55 ) -> Result<Transformed<LogicalPlan>> {
56 Self::optimize(plan)
57 }
58}
59
60impl ScanHintRule {
61 fn optimize(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
62 let mut visitor = ScanHintVisitor::default();
63 let _ = plan.visit(&mut visitor)?;
64
65 if visitor.need_rewrite() {
66 plan.transform_down(&mut |plan| Self::set_hints(plan, &mut visitor))
67 } else {
68 Ok(Transformed::no(plan))
69 }
70 }
71
72 fn set_hints(
73 plan: LogicalPlan,
74 visitor: &mut ScanHintVisitor,
75 ) -> Result<Transformed<LogicalPlan>> {
76 match &plan {
77 LogicalPlan::TableScan(table_scan) => {
78 let mut transformed = false;
79 if let Some(source) = table_scan
80 .source
81 .as_any()
82 .downcast_ref::<DefaultTableSource>()
83 {
84 if let Some(adapter) = source
86 .table_provider
87 .as_any()
88 .downcast_ref::<DummyTableProvider>()
89 {
90 if let Some(order_expr) = &visitor.order_expr {
92 Self::set_order_hint(adapter, order_expr);
93 }
94
95 if let Some((group_by_cols, order_by_col)) = &visitor.ts_row_selector {
97 Self::set_time_series_row_selector_hint(
98 adapter,
99 group_by_cols,
100 order_by_col,
101 );
102 }
103
104 #[cfg(feature = "vector_index")]
105 if let Some(vector_request) = visitor
106 .vector_search
107 .take_vector_request_from_dummy(adapter, &table_scan.table_name)
108 {
109 adapter.with_vector_search_hint(vector_request);
110 }
111 transformed = true;
112 }
113 }
114 if transformed {
115 Ok(Transformed::yes(plan))
116 } else {
117 Ok(Transformed::no(plan))
118 }
119 }
120 _ => Ok(Transformed::no(plan)),
121 }
122 }
123
124 fn set_order_hint(adapter: &DummyTableProvider, order_expr: &Vec<Sort>) {
125 let mut opts = Vec::with_capacity(order_expr.len());
126 for sort in order_expr {
127 let name = match sort.expr.try_as_col() {
128 Some(col) => col.name.clone(),
129 None => return,
130 };
131 opts.push(OrderOption {
132 name,
133 options: SortOptions {
134 descending: !sort.asc,
135 nulls_first: sort.nulls_first,
136 },
137 });
138 }
139 adapter.with_ordering_hint(&opts);
140
141 let region_metadata = adapter.region_metadata();
142 let time_index_name = region_metadata
143 .time_index_column()
144 .column_schema
145 .name
146 .as_str();
147 let sort_cols = order_expr
148 .iter()
149 .filter_map(|s| s.expr.try_as_col())
150 .collect::<Vec<_>>();
151
152 if sort_cols.len() == 2
158 && sort_cols[0].name == DATA_SCHEMA_TSID_COLUMN_NAME
159 && sort_cols[1].name == time_index_name
160 {
161 adapter.with_distribution(TimeSeriesDistribution::PerSeries);
162 return;
163 }
164
165 let mut sort_expr_cursor = sort_cols.into_iter();
166 if region_metadata.primary_key.is_empty() {
168 return;
169 }
170 let mut pk_column_iter = region_metadata.primary_key_columns();
171 let mut curr_sort_expr = sort_expr_cursor.next();
172 let mut curr_pk_col = pk_column_iter.next();
173
174 while let (Some(sort_expr), Some(pk_col)) = (curr_sort_expr, curr_pk_col) {
175 if sort_expr.name == pk_col.column_schema.name {
176 curr_sort_expr = sort_expr_cursor.next();
177 curr_pk_col = pk_column_iter.next();
178 } else {
179 return;
180 }
181 }
182
183 let next_remaining = sort_expr_cursor.next();
184 match (curr_sort_expr, next_remaining) {
185 (Some(expr), None)
186 if expr.name == region_metadata.time_index_column().column_schema.name =>
187 {
188 adapter.with_distribution(TimeSeriesDistribution::PerSeries);
189 }
190 (None, _) => adapter.with_distribution(TimeSeriesDistribution::PerSeries),
191 (Some(_), _) => {}
192 }
193 }
194
195 fn set_time_series_row_selector_hint(
196 adapter: &DummyTableProvider,
197 group_by_cols: &HashSet<Column>,
198 order_by_col: &Column,
199 ) {
200 let region_metadata = adapter.region_metadata();
201 let mut should_set_selector_hint = true;
202 if let Some(column_metadata) = region_metadata.column_by_name(&order_by_col.name) {
204 if column_metadata.semantic_type != SemanticType::Timestamp {
205 should_set_selector_hint = false;
206 }
207 } else {
208 should_set_selector_hint = false;
209 }
210
211 for col in group_by_cols {
213 let Some(column_metadata) = region_metadata.column_by_name(&col.name) else {
214 should_set_selector_hint = false;
215 break;
216 };
217 if column_metadata.semantic_type != SemanticType::Tag {
218 should_set_selector_hint = false;
219 break;
220 }
221 }
222
223 if should_set_selector_hint {
224 adapter.with_time_series_selector_hint(TimeSeriesRowSelector::LastRow);
225 }
226 }
227}
228
/// Plan visitor that gathers the information needed to set scan hints.
#[derive(Default)]
struct ScanHintVisitor {
    /// Sort expressions of the `Sort` node visited most recently; each `Sort`
    /// node encountered on the way down overwrites the previous value.
    order_expr: Option<Vec<Sort>>,
    /// `(group-by columns, order-by column)` taken from an aggregate that
    /// consists only of `last_value` calls. Cleared again when a subquery,
    /// subquery alias or multi-input node is entered, or when a filter
    /// references a column outside the group-by columns.
    ts_row_selector: Option<(HashSet<Column>, Column)>,
    /// State used to derive the vector search hint.
    #[cfg(feature = "vector_index")]
    vector_search: VectorSearchState,
}
241
242impl TreeNodeVisitor<'_> for ScanHintVisitor {
243 type Node = LogicalPlan;
244
245 fn f_down(&mut self, node: &Self::Node) -> Result<TreeNodeRecursion> {
246 #[cfg(feature = "vector_index")]
247 if let LogicalPlan::Limit(limit) = node {
248 self.vector_search.on_limit_enter(limit);
250 }
251
252 if let LogicalPlan::Sort(sort) = node {
254 self.order_expr = Some(sort.expr.clone());
255
256 #[cfg(feature = "vector_index")]
257 {
258 self.vector_search.on_sort_enter(sort);
260 }
261 }
262
263 if let LogicalPlan::Aggregate(aggregate) = node {
265 let mut is_all_last_value = !aggregate.aggr_expr.is_empty();
266 let mut order_by_expr = None;
267 for expr in &aggregate.aggr_expr {
268 let Expr::AggregateFunction(func) = expr else {
270 is_all_last_value = false;
271 break;
272 };
273 if (func.func.name() != "last_value"
274 && func.func.name() != aggr_state_func_name("last_value"))
275 || func.params.filter.is_some()
276 || func.params.distinct
277 {
278 is_all_last_value = false;
279 break;
280 }
281 let order_by = &func.params.order_by;
283 if let Some(first_order_by) = order_by.first()
284 && order_by.len() == 1
285 {
286 if let Some(existing_order_by) = &order_by_expr {
287 if existing_order_by != first_order_by {
288 is_all_last_value = false;
289 break;
290 }
291 } else {
292 if !first_order_by.asc || !matches!(&first_order_by.expr, Expr::Column(_)) {
295 is_all_last_value = false;
296 break;
297 }
298 order_by_expr = Some(first_order_by.clone());
299 }
300 }
301 }
302 is_all_last_value &= order_by_expr.is_some();
303 if is_all_last_value {
304 let mut group_by_cols = HashSet::with_capacity(aggregate.group_expr.len());
306 for expr in &aggregate.group_expr {
307 if let Expr::Column(col) = expr {
308 group_by_cols.insert(col.clone());
309 } else {
310 is_all_last_value = false;
311 break;
312 }
313 }
314 let order_by_expr = order_by_expr.unwrap();
316 let Expr::Column(order_by_col) = order_by_expr.expr else {
317 unreachable!()
318 };
319 if is_all_last_value {
320 self.ts_row_selector = Some((group_by_cols, order_by_col));
321 }
322 }
323 }
324
325 let is_branching_for_ts = matches!(
329 node,
330 LogicalPlan::Subquery(_) | LogicalPlan::SubqueryAlias(_)
331 ) || node.inputs().len() > 1;
332 if is_branching_for_ts && self.ts_row_selector.is_some() {
333 self.ts_row_selector = None;
335 }
336 #[cfg(feature = "vector_index")]
337 if is_branching_for_vector(node) {
338 self.vector_search.on_branching_enter();
339 }
340
341 if let LogicalPlan::Filter(filter) = node
342 && let Some(group_by_exprs) = &self.ts_row_selector
343 {
344 let mut filter_referenced_cols = HashSet::default();
345 utils::expr_to_columns(&filter.predicate, &mut filter_referenced_cols)?;
346 if !filter_referenced_cols.is_subset(&group_by_exprs.0) {
348 self.ts_row_selector = None;
349 }
350 }
351
352 #[cfg(feature = "vector_index")]
353 if let LogicalPlan::Filter(filter) = node {
354 self.vector_search.on_filter_enter(&filter.predicate);
355 }
356
357 #[cfg(feature = "vector_index")]
358 if let LogicalPlan::TableScan(table_scan) = node {
359 self.vector_search.on_table_scan(table_scan);
361 }
362
363 Ok(TreeNodeRecursion::Continue)
364 }
365
366 fn f_up(&mut self, _node: &Self::Node) -> Result<TreeNodeRecursion> {
367 #[cfg(feature = "vector_index")]
368 match _node {
369 LogicalPlan::Limit(_) => {
370 self.vector_search.on_limit_exit();
371 }
372 LogicalPlan::Sort(_) => {
373 self.vector_search.on_sort_exit();
374 }
375 LogicalPlan::Filter(_) => {
376 self.vector_search.on_filter_exit();
377 }
378 LogicalPlan::Subquery(_) | LogicalPlan::SubqueryAlias(_) => {
379 if is_branching_for_vector(_node) {
380 self.vector_search.on_branching_exit();
381 }
382 }
383 _ if _node.inputs().len() > 1 => {
384 self.vector_search.on_branching_exit();
385 }
386 _ => {}
387 }
388
389 Ok(TreeNodeRecursion::Continue)
390 }
391}
392
393impl ScanHintVisitor {
394 fn need_rewrite(&self) -> bool {
395 let base = self.order_expr.is_some() || self.ts_row_selector.is_some();
396 #[cfg(feature = "vector_index")]
397 {
398 base || self.vector_search.need_rewrite()
399 }
400 #[cfg(not(feature = "vector_index"))]
401 {
402 base
403 }
404 }
405}
406
/// Returns `true` when `node` counts as a branching point for the vector
/// search state: any node with more than one input, or a `Subquery` /
/// `SubqueryAlias` whose inner plan satisfies [`has_non_inlineable_ops`].
#[cfg(feature = "vector_index")]
fn is_branching_for_vector(node: &LogicalPlan) -> bool {
    match node {
        _ if node.inputs().len() > 1 => true,
        LogicalPlan::Subquery(subquery) => has_non_inlineable_ops(subquery.subquery.as_ref()),
        LogicalPlan::SubqueryAlias(alias) => has_non_inlineable_ops(alias.input.as_ref()),
        _ => false,
    }
}
419
/// Returns `true` when `plan` itself, or any plan reachable through its
/// inputs, is a `Limit`, `Sort`, `Distinct`, `Aggregate`, `Window`, `Union`
/// or `Join` node.
#[cfg(feature = "vector_index")]
fn has_non_inlineable_ops(plan: &LogicalPlan) -> bool {
    match plan {
        LogicalPlan::Limit(_)
        | LogicalPlan::Sort(_)
        | LogicalPlan::Distinct(_)
        | LogicalPlan::Aggregate(_)
        | LogicalPlan::Window(_)
        | LogicalPlan::Union(_)
        | LogicalPlan::Join(_) => true,
        // Otherwise the answer depends solely on the inputs.
        _ => plan
            .inputs()
            .into_iter()
            .any(|child| has_non_inlineable_ops(child)),
    }
}
443
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use datafusion::functions_aggregate::first_last::last_value_udaf;
    use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
    use datafusion_expr::{LogicalPlanBuilder, col};
    use datafusion_optimizer::OptimizerContext;
    use store_api::metric_engine_consts::DATA_SCHEMA_TSID_COLUMN_NAME;
    use store_api::storage::RegionId;

    use super::*;
    use crate::optimizer::test_util::{mock_table_provider, mock_table_provider_with_tsid};

    /// With two stacked sorts, the ordering hint comes from the sort closest
    /// to the scan (the one the visitor reaches last).
    #[test]
    fn set_order_hint() {
        let provider = Arc::new(mock_table_provider(RegionId::new(1, 1)));
        let table_source = Arc::new(DefaultTableSource::new(provider.clone()));
        let plan = LogicalPlanBuilder::scan("t", table_source, None)
            .unwrap()
            .sort(vec![col("ts").sort(true, false)])
            .unwrap()
            .sort(vec![col("ts").sort(false, true)])
            .unwrap()
            .build()
            .unwrap();

        let context = OptimizerContext::default();
        ScanHintRule.rewrite(plan, &context).unwrap();

        let scan_req = provider.scan_request();
        assert_eq!(
            OrderOption {
                name: "ts".to_string(),
                options: SortOptions {
                    descending: false,
                    nulls_first: false
                }
            },
            scan_req.output_ordering.as_ref().unwrap()[0]
        );
    }

    /// `last_value(v0 ORDER BY ts)` grouped by `k0` requests the `LastRow`
    /// selector.
    #[test]
    fn set_time_series_row_selector_hint() {
        let provider = Arc::new(mock_table_provider(RegionId::new(1, 1)));
        let table_source = Arc::new(DefaultTableSource::new(provider.clone()));
        let plan = LogicalPlanBuilder::scan("t", table_source, None)
            .unwrap()
            .aggregate(
                vec![col("k0")],
                vec![Expr::AggregateFunction(AggregateFunction {
                    func: last_value_udaf(),
                    params: AggregateFunctionParams {
                        args: vec![col("v0")],
                        distinct: false,
                        filter: None,
                        order_by: vec![Sort {
                            expr: col("ts"),
                            asc: true,
                            nulls_first: true,
                        }],
                        null_treatment: None,
                    },
                })],
            )
            .unwrap()
            .build()
            .unwrap();

        let context = OptimizerContext::default();
        ScanHintRule.rewrite(plan, &context).unwrap();

        let scan_req = provider.scan_request();
        // Check the selector value itself, not merely that one is present.
        assert!(matches!(
            scan_req.series_row_selector,
            Some(TimeSeriesRowSelector::LastRow)
        ));
    }

    /// Sorting by `[tsid, ts]` requests the `PerSeries` distribution.
    #[test]
    fn set_order_hint_sets_per_series_distribution_for_tsid_sort() {
        let provider = Arc::new(mock_table_provider_with_tsid(RegionId::new(1, 1)));
        let table_source = Arc::new(DefaultTableSource::new(provider.clone()));
        let plan = LogicalPlanBuilder::scan("t", table_source, None)
            .unwrap()
            .sort(vec![
                col(DATA_SCHEMA_TSID_COLUMN_NAME).sort(true, true),
                col("ts").sort(true, true),
            ])
            .unwrap()
            .build()
            .unwrap();

        let context = OptimizerContext::default();
        ScanHintRule.rewrite(plan, &context).unwrap();

        let scan_req = provider.scan_request();
        assert_eq!(
            scan_req.distribution,
            Some(TimeSeriesDistribution::PerSeries)
        );
    }
}