query/optimizer/
scan_hint.rs1use std::collections::HashSet;
16
17use api::v1::SemanticType;
18use arrow_schema::SortOptions;
19use common_function::aggrs::aggr_wrapper::aggr_state_func_name;
20use common_recordbatch::OrderOption;
21use datafusion::datasource::DefaultTableSource;
22use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion, TreeNodeVisitor};
23use datafusion_common::{Column, Result};
24use datafusion_expr::expr::Sort;
25use datafusion_expr::{Expr, LogicalPlan, utils};
26use datafusion_optimizer::{OptimizerConfig, OptimizerRule};
27use store_api::storage::{TimeSeriesDistribution, TimeSeriesRowSelector};
28
29use crate::dummy_catalog::DummyTableProvider;
30
/// Optimizer rule that derives scan hints from a logical plan and records
/// them on the table providers behind its `TableScan` nodes.
///
/// The hints are an output ordering taken from `Sort` nodes, a per-series
/// distribution request, and a "last row" time series selector taken from
/// `last_value` aggregations. Only scans whose source is a
/// [`DefaultTableSource`] wrapping a [`DummyTableProvider`] are touched.
#[derive(Debug)]
pub struct ScanHintRule;
40
impl OptimizerRule for ScanHintRule {
    /// Returns the fixed name of this rule, `"ScanHintRule"`.
    fn name(&self) -> &str {
        "ScanHintRule"
    }

    /// Applies the rule to `plan` by delegating to [`ScanHintRule::optimize`].
    ///
    /// The optimizer configuration is not consulted.
    fn rewrite(
        &self,
        plan: LogicalPlan,
        _config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>> {
        Self::optimize(plan)
    }
}
54
55impl ScanHintRule {
56 fn optimize(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
57 let mut visitor = ScanHintVisitor::default();
58 let _ = plan.visit(&mut visitor)?;
59
60 if visitor.need_rewrite() {
61 plan.transform_down(&|plan| Self::set_hints(plan, &visitor))
62 } else {
63 Ok(Transformed::no(plan))
64 }
65 }
66
67 fn set_hints(plan: LogicalPlan, visitor: &ScanHintVisitor) -> Result<Transformed<LogicalPlan>> {
68 match &plan {
69 LogicalPlan::TableScan(table_scan) => {
70 let mut transformed = false;
71 if let Some(source) = table_scan
72 .source
73 .as_any()
74 .downcast_ref::<DefaultTableSource>()
75 {
76 if let Some(adapter) = source
78 .table_provider
79 .as_any()
80 .downcast_ref::<DummyTableProvider>()
81 {
82 if let Some(order_expr) = &visitor.order_expr {
84 Self::set_order_hint(adapter, order_expr);
85 }
86
87 if let Some((group_by_cols, order_by_col)) = &visitor.ts_row_selector {
89 Self::set_time_series_row_selector_hint(
90 adapter,
91 group_by_cols,
92 order_by_col,
93 );
94 }
95
96 transformed = true;
97 }
98 }
99 if transformed {
100 Ok(Transformed::yes(plan))
101 } else {
102 Ok(Transformed::no(plan))
103 }
104 }
105 _ => Ok(Transformed::no(plan)),
106 }
107 }
108
109 fn set_order_hint(adapter: &DummyTableProvider, order_expr: &Vec<Sort>) {
110 let mut opts = Vec::with_capacity(order_expr.len());
111 for sort in order_expr {
112 let name = match sort.expr.try_as_col() {
113 Some(col) => col.name.clone(),
114 None => return,
115 };
116 opts.push(OrderOption {
117 name,
118 options: SortOptions {
119 descending: !sort.asc,
120 nulls_first: sort.nulls_first,
121 },
122 });
123 }
124 adapter.with_ordering_hint(&opts);
125
126 let mut sort_expr_cursor = order_expr.iter().filter_map(|s| s.expr.try_as_col());
127 let region_metadata = adapter.region_metadata();
128 if region_metadata.primary_key.is_empty() {
130 return;
131 }
132 let mut pk_column_iter = region_metadata.primary_key_columns();
133 let mut curr_sort_expr = sort_expr_cursor.next();
134 let mut curr_pk_col = pk_column_iter.next();
135
136 while let (Some(sort_expr), Some(pk_col)) = (curr_sort_expr, curr_pk_col) {
137 if sort_expr.name == pk_col.column_schema.name {
138 curr_sort_expr = sort_expr_cursor.next();
139 curr_pk_col = pk_column_iter.next();
140 } else {
141 return;
142 }
143 }
144
145 let next_remaining = sort_expr_cursor.next();
146 match (curr_sort_expr, next_remaining) {
147 (Some(expr), None)
148 if expr.name == region_metadata.time_index_column().column_schema.name =>
149 {
150 adapter.with_distribution(TimeSeriesDistribution::PerSeries);
151 }
152 (None, _) => adapter.with_distribution(TimeSeriesDistribution::PerSeries),
153 (Some(_), _) => {}
154 }
155 }
156
157 fn set_time_series_row_selector_hint(
158 adapter: &DummyTableProvider,
159 group_by_cols: &HashSet<Column>,
160 order_by_col: &Column,
161 ) {
162 let region_metadata = adapter.region_metadata();
163 let mut should_set_selector_hint = true;
164 if let Some(column_metadata) = region_metadata.column_by_name(&order_by_col.name) {
166 if column_metadata.semantic_type != SemanticType::Timestamp {
167 should_set_selector_hint = false;
168 }
169 } else {
170 should_set_selector_hint = false;
171 }
172
173 for col in group_by_cols {
175 let Some(column_metadata) = region_metadata.column_by_name(&col.name) else {
176 should_set_selector_hint = false;
177 break;
178 };
179 if column_metadata.semantic_type != SemanticType::Tag {
180 should_set_selector_hint = false;
181 break;
182 }
183 }
184
185 if should_set_selector_hint {
186 adapter.with_time_series_selector_hint(TimeSeriesRowSelector::LastRow);
187 }
188 }
189}
190
/// Plan visitor that collects the hints [`ScanHintRule`] later records on
/// table scans.
#[derive(Default)]
struct ScanHintVisitor {
    /// Sort expressions of the most recently visited `Sort` node; each `Sort`
    /// encountered during the top-down visit overwrites the previous value.
    order_expr: Option<Vec<Sort>>,
    /// Group-by columns and the single ORDER BY column captured from an
    /// `Aggregate` node made up of `last_value` calls. Cleared again when a
    /// subquery, a node with more than one input, or a filter referencing
    /// columns outside the group-by set is visited afterwards.
    ts_row_selector: Option<(HashSet<Column>, Column)>,
}
201
202impl TreeNodeVisitor<'_> for ScanHintVisitor {
203 type Node = LogicalPlan;
204
205 fn f_down(&mut self, node: &Self::Node) -> Result<TreeNodeRecursion> {
206 if let LogicalPlan::Sort(sort) = node {
208 self.order_expr = Some(sort.expr.clone());
209 }
210
211 if let LogicalPlan::Aggregate(aggregate) = node {
213 let mut is_all_last_value = !aggregate.aggr_expr.is_empty();
214 let mut order_by_expr = None;
215 for expr in &aggregate.aggr_expr {
216 let Expr::AggregateFunction(func) = expr else {
218 is_all_last_value = false;
219 break;
220 };
221 if (func.func.name() != "last_value"
222 && func.func.name() != aggr_state_func_name("last_value"))
223 || func.params.filter.is_some()
224 || func.params.distinct
225 {
226 is_all_last_value = false;
227 break;
228 }
229 let order_by = &func.params.order_by;
231 if let Some(first_order_by) = order_by.first()
232 && order_by.len() == 1
233 {
234 if let Some(existing_order_by) = &order_by_expr {
235 if existing_order_by != first_order_by {
236 is_all_last_value = false;
237 break;
238 }
239 } else {
240 if !first_order_by.asc || !matches!(&first_order_by.expr, Expr::Column(_)) {
243 is_all_last_value = false;
244 break;
245 }
246 order_by_expr = Some(first_order_by.clone());
247 }
248 }
249 }
250 is_all_last_value &= order_by_expr.is_some();
251 if is_all_last_value {
252 let mut group_by_cols = HashSet::with_capacity(aggregate.group_expr.len());
254 for expr in &aggregate.group_expr {
255 if let Expr::Column(col) = expr {
256 group_by_cols.insert(col.clone());
257 } else {
258 is_all_last_value = false;
259 break;
260 }
261 }
262 let order_by_expr = order_by_expr.unwrap();
264 let Expr::Column(order_by_col) = order_by_expr.expr else {
265 unreachable!()
266 };
267 if is_all_last_value {
268 self.ts_row_selector = Some((group_by_cols, order_by_col));
269 }
270 }
271 }
272
273 if self.ts_row_selector.is_some()
274 && (matches!(node, LogicalPlan::Subquery(_)) || node.inputs().len() > 1)
275 {
276 self.ts_row_selector = None;
278 }
279
280 if let LogicalPlan::Filter(filter) = node
281 && let Some(group_by_exprs) = &self.ts_row_selector
282 {
283 let mut filter_referenced_cols = HashSet::default();
284 utils::expr_to_columns(&filter.predicate, &mut filter_referenced_cols)?;
285 if !filter_referenced_cols.is_subset(&group_by_exprs.0) {
287 self.ts_row_selector = None;
288 }
289 }
290
291 Ok(TreeNodeRecursion::Continue)
292 }
293}
294
295impl ScanHintVisitor {
296 fn need_rewrite(&self) -> bool {
297 self.order_expr.is_some() || self.ts_row_selector.is_some()
298 }
299}
300
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use datafusion::functions_aggregate::first_last::last_value_udaf;
    use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
    use datafusion_expr::{LogicalPlanBuilder, col};
    use datafusion_optimizer::OptimizerContext;
    use store_api::storage::RegionId;

    use super::*;
    use crate::optimizer::test_util::mock_table_provider;

    /// With two stacked sorts, the ordering hint reflects the sort that sits
    /// directly above the scan.
    #[test]
    fn set_order_hint() {
        let provider = Arc::new(mock_table_provider(RegionId::new(1, 1)));
        let source = Arc::new(DefaultTableSource::new(provider.clone()));

        let inner_sort = vec![col("ts").sort(true, false)];
        let outer_sort = vec![col("ts").sort(false, true)];
        let plan = LogicalPlanBuilder::scan("t", source, None)
            .unwrap()
            .sort(inner_sort)
            .unwrap()
            .sort(outer_sort)
            .unwrap()
            .build()
            .unwrap();

        let optimizer_context = OptimizerContext::default();
        ScanHintRule.rewrite(plan, &optimizer_context).unwrap();

        let expected = OrderOption {
            name: "ts".to_string(),
            options: SortOptions {
                descending: false,
                nulls_first: false,
            },
        };
        let request = provider.scan_request();
        let orderings = request.output_ordering.as_ref().unwrap();
        assert_eq!(expected, orderings[0]);
    }

    /// `last_value(v0 ORDER BY ts)` grouped by `k0` yields a row selector on
    /// the scan request.
    #[test]
    fn set_time_series_row_selector_hint() {
        let provider = Arc::new(mock_table_provider(RegionId::new(1, 1)));
        let source = Arc::new(DefaultTableSource::new(provider.clone()));

        let order_by_ts = Sort {
            expr: col("ts"),
            asc: true,
            nulls_first: true,
        };
        let last_value_of_v0 = Expr::AggregateFunction(AggregateFunction {
            func: last_value_udaf(),
            params: AggregateFunctionParams {
                args: vec![col("v0")],
                distinct: false,
                filter: None,
                order_by: vec![order_by_ts],
                null_treatment: None,
            },
        });
        let plan = LogicalPlanBuilder::scan("t", source, None)
            .unwrap()
            .aggregate(vec![col("k0")], vec![last_value_of_v0])
            .unwrap()
            .build()
            .unwrap();

        let optimizer_context = OptimizerContext::default();
        ScanHintRule.rewrite(plan, &optimizer_context).unwrap();

        // Panics if no selector was recorded.
        let request = provider.scan_request();
        let _ = request.series_row_selector.unwrap();
    }
}