query/optimizer/
scan_hint.rs1use std::collections::HashSet;
16
17use api::v1::SemanticType;
18use arrow_schema::SortOptions;
19use common_recordbatch::OrderOption;
20use datafusion::datasource::DefaultTableSource;
21use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion, TreeNodeVisitor};
22use datafusion_common::{Column, Result};
23use datafusion_expr::expr::Sort;
24use datafusion_expr::{utils, Expr, LogicalPlan};
25use datafusion_optimizer::{OptimizerConfig, OptimizerRule};
26use store_api::storage::{TimeSeriesDistribution, TimeSeriesRowSelector};
27
28use crate::dummy_catalog::DummyTableProvider;
29
/// Optimizer rule that derives scan hints from a logical plan.
///
/// The rule walks the plan once to collect the sort expressions of `Sort`
/// nodes and the shape of `last_value` aggregates, then walks it again and
/// records the collected information on every table scan that is backed by a
/// [`DummyTableProvider`].
#[derive(Debug)]
pub struct ScanHintRule;
39
40impl OptimizerRule for ScanHintRule {
41 fn name(&self) -> &str {
42 "ScanHintRule"
43 }
44
45 fn rewrite(
46 &self,
47 plan: LogicalPlan,
48 _config: &dyn OptimizerConfig,
49 ) -> Result<Transformed<LogicalPlan>> {
50 Self::optimize(plan)
51 }
52}
53
54impl ScanHintRule {
55 fn optimize(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
56 let mut visitor = ScanHintVisitor::default();
57 let _ = plan.visit(&mut visitor)?;
58
59 if visitor.need_rewrite() {
60 plan.transform_down(&|plan| Self::set_hints(plan, &visitor))
61 } else {
62 Ok(Transformed::no(plan))
63 }
64 }
65
66 fn set_hints(plan: LogicalPlan, visitor: &ScanHintVisitor) -> Result<Transformed<LogicalPlan>> {
67 match &plan {
68 LogicalPlan::TableScan(table_scan) => {
69 let mut transformed = false;
70 if let Some(source) = table_scan
71 .source
72 .as_any()
73 .downcast_ref::<DefaultTableSource>()
74 {
75 if let Some(adapter) = source
77 .table_provider
78 .as_any()
79 .downcast_ref::<DummyTableProvider>()
80 {
81 if let Some(order_expr) = &visitor.order_expr {
83 Self::set_order_hint(adapter, order_expr);
84 }
85
86 if let Some((group_by_cols, order_by_col)) = &visitor.ts_row_selector {
88 Self::set_time_series_row_selector_hint(
89 adapter,
90 group_by_cols,
91 order_by_col,
92 );
93 }
94
95 transformed = true;
96 }
97 }
98 if transformed {
99 Ok(Transformed::yes(plan))
100 } else {
101 Ok(Transformed::no(plan))
102 }
103 }
104 _ => Ok(Transformed::no(plan)),
105 }
106 }
107
108 fn set_order_hint(adapter: &DummyTableProvider, order_expr: &Vec<Sort>) {
109 let mut opts = Vec::with_capacity(order_expr.len());
110 for sort in order_expr {
111 let name = match sort.expr.try_as_col() {
112 Some(col) => col.name.clone(),
113 None => return,
114 };
115 opts.push(OrderOption {
116 name,
117 options: SortOptions {
118 descending: !sort.asc,
119 nulls_first: sort.nulls_first,
120 },
121 });
122 }
123 adapter.with_ordering_hint(&opts);
124
125 let mut sort_expr_cursor = order_expr.iter().filter_map(|s| s.expr.try_as_col());
126 let region_metadata = adapter.region_metadata();
127 if region_metadata.primary_key.is_empty() {
129 return;
130 }
131 let mut pk_column_iter = region_metadata.primary_key_columns();
132 let mut curr_sort_expr = sort_expr_cursor.next();
133 let mut curr_pk_col = pk_column_iter.next();
134
135 while let (Some(sort_expr), Some(pk_col)) = (curr_sort_expr, curr_pk_col) {
136 if sort_expr.name == pk_col.column_schema.name {
137 curr_sort_expr = sort_expr_cursor.next();
138 curr_pk_col = pk_column_iter.next();
139 } else {
140 return;
141 }
142 }
143
144 let next_remaining = sort_expr_cursor.next();
145 match (curr_sort_expr, next_remaining) {
146 (Some(expr), None)
147 if expr.name == region_metadata.time_index_column().column_schema.name =>
148 {
149 adapter.with_distribution(TimeSeriesDistribution::PerSeries);
150 }
151 (None, _) => adapter.with_distribution(TimeSeriesDistribution::PerSeries),
152 (Some(_), _) => {}
153 }
154 }
155
156 fn set_time_series_row_selector_hint(
157 adapter: &DummyTableProvider,
158 group_by_cols: &HashSet<Column>,
159 order_by_col: &Column,
160 ) {
161 let region_metadata = adapter.region_metadata();
162 let mut should_set_selector_hint = true;
163 if let Some(column_metadata) = region_metadata.column_by_name(&order_by_col.name) {
165 if column_metadata.semantic_type != SemanticType::Timestamp {
166 should_set_selector_hint = false;
167 }
168 } else {
169 should_set_selector_hint = false;
170 }
171
172 for col in group_by_cols {
174 let Some(column_metadata) = region_metadata.column_by_name(&col.name) else {
175 should_set_selector_hint = false;
176 break;
177 };
178 if column_metadata.semantic_type != SemanticType::Tag {
179 should_set_selector_hint = false;
180 break;
181 }
182 }
183
184 if should_set_selector_hint {
185 adapter.with_time_series_selector_hint(TimeSeriesRowSelector::LastRow);
186 }
187 }
188}
189
/// Plan visitor that gathers the information `ScanHintRule` turns into scan hints.
#[derive(Default)]
struct ScanHintVisitor {
    /// Sort expressions of the most recently visited `Sort` node, if any.
    order_expr: Option<Vec<Sort>>,
    /// Group-by columns and order-by column taken from an aggregate made up
    /// solely of `last_value` calls, if one was seen. It is cleared again when
    /// a subquery, a node with more than one input, or a filter referencing
    /// columns outside the group-by set is visited afterwards.
    ts_row_selector: Option<(HashSet<Column>, Column)>,
}
200
201impl TreeNodeVisitor<'_> for ScanHintVisitor {
202 type Node = LogicalPlan;
203
204 fn f_down(&mut self, node: &Self::Node) -> Result<TreeNodeRecursion> {
205 if let LogicalPlan::Sort(sort) = node {
207 self.order_expr = Some(sort.expr.clone());
208 }
209
210 if let LogicalPlan::Aggregate(aggregate) = node {
212 let mut is_all_last_value = !aggregate.aggr_expr.is_empty();
213 let mut order_by_expr = None;
214 for expr in &aggregate.aggr_expr {
215 let Expr::AggregateFunction(func) = expr else {
217 is_all_last_value = false;
218 break;
219 };
220 if func.func.name() != "last_value"
221 || func.params.filter.is_some()
222 || func.params.distinct
223 {
224 is_all_last_value = false;
225 break;
226 }
227 let order_by = &func.params.order_by;
229 if let Some(first_order_by) = order_by.first()
230 && order_by.len() == 1
231 {
232 if let Some(existing_order_by) = &order_by_expr {
233 if existing_order_by != first_order_by {
234 is_all_last_value = false;
235 break;
236 }
237 } else {
238 if !first_order_by.asc || !matches!(&first_order_by.expr, Expr::Column(_)) {
241 is_all_last_value = false;
242 break;
243 }
244 order_by_expr = Some(first_order_by.clone());
245 }
246 }
247 }
248 is_all_last_value &= order_by_expr.is_some();
249 if is_all_last_value {
250 let mut group_by_cols = HashSet::with_capacity(aggregate.group_expr.len());
252 for expr in &aggregate.group_expr {
253 if let Expr::Column(col) = expr {
254 group_by_cols.insert(col.clone());
255 } else {
256 is_all_last_value = false;
257 break;
258 }
259 }
260 let order_by_expr = order_by_expr.unwrap();
262 let Expr::Column(order_by_col) = order_by_expr.expr else {
263 unreachable!()
264 };
265 if is_all_last_value {
266 self.ts_row_selector = Some((group_by_cols, order_by_col));
267 }
268 }
269 }
270
271 if self.ts_row_selector.is_some()
272 && (matches!(node, LogicalPlan::Subquery(_)) || node.inputs().len() > 1)
273 {
274 self.ts_row_selector = None;
276 }
277
278 if let LogicalPlan::Filter(filter) = node
279 && let Some(group_by_exprs) = &self.ts_row_selector
280 {
281 let mut filter_referenced_cols = HashSet::default();
282 utils::expr_to_columns(&filter.predicate, &mut filter_referenced_cols)?;
283 if !filter_referenced_cols.is_subset(&group_by_exprs.0) {
285 self.ts_row_selector = None;
286 }
287 }
288
289 Ok(TreeNodeRecursion::Continue)
290 }
291}
292
293impl ScanHintVisitor {
294 fn need_rewrite(&self) -> bool {
295 self.order_expr.is_some() || self.ts_row_selector.is_some()
296 }
297}
298
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use datafusion::functions_aggregate::first_last::last_value_udaf;
    use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
    use datafusion_expr::{col, LogicalPlanBuilder};
    use datafusion_optimizer::OptimizerContext;
    use store_api::storage::RegionId;

    use super::*;
    use crate::optimizer::test_util::mock_table_provider;

    /// With two stacked sorts, the scan request carries the ordering of the
    /// sort that sits directly above the scan.
    #[test]
    fn set_order_hint() {
        let provider = Arc::new(mock_table_provider(RegionId::new(1, 1)));
        let source = Arc::new(DefaultTableSource::new(provider.clone()));

        let sort_above_scan = vec![col("ts").sort(true, false)];
        let outermost_sort = vec![col("ts").sort(false, true)];
        let plan = LogicalPlanBuilder::scan("t", source, None)
            .unwrap()
            .sort(sort_above_scan)
            .unwrap()
            .sort(outermost_sort)
            .unwrap()
            .build()
            .unwrap();

        ScanHintRule
            .rewrite(plan, &OptimizerContext::default())
            .unwrap();

        let expected = OrderOption {
            name: "ts".to_string(),
            options: SortOptions {
                descending: false,
                nulls_first: false,
            },
        };
        let scan_req = provider.scan_request();
        let ordering = scan_req.output_ordering.as_ref().unwrap();
        assert_eq!(expected, ordering[0]);
    }

    /// A `last_value(v0 ORDER BY ts)` aggregate grouped by `k0` makes the rule
    /// set a series row selector on the scan request.
    #[test]
    fn set_time_series_row_selector_hint() {
        let provider = Arc::new(mock_table_provider(RegionId::new(1, 1)));
        let source = Arc::new(DefaultTableSource::new(provider.clone()));

        let order_by_ts = Sort {
            expr: col("ts"),
            asc: true,
            nulls_first: true,
        };
        let last_v0 = Expr::AggregateFunction(AggregateFunction {
            func: last_value_udaf(),
            params: AggregateFunctionParams {
                args: vec![col("v0")],
                distinct: false,
                filter: None,
                order_by: vec![order_by_ts],
                null_treatment: None,
            },
        });
        let plan = LogicalPlanBuilder::scan("t", source, None)
            .unwrap()
            .aggregate(vec![col("k0")], vec![last_v0])
            .unwrap()
            .build()
            .unwrap();

        ScanHintRule
            .rewrite(plan, &OptimizerContext::default())
            .unwrap();

        let scan_req = provider.scan_request();
        assert!(scan_req.series_row_selector.is_some());
    }
}