query/optimizer/
scan_hint.rs1use std::collections::HashSet;
16
17use api::v1::SemanticType;
18use arrow_schema::SortOptions;
19use common_recordbatch::OrderOption;
20use datafusion::datasource::DefaultTableSource;
21use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion, TreeNodeVisitor};
22use datafusion_common::{Column, Result};
23use datafusion_expr::expr::Sort;
24use datafusion_expr::{utils, Expr, LogicalPlan};
25use datafusion_optimizer::{OptimizerConfig, OptimizerRule};
26use store_api::storage::{TimeSeriesDistribution, TimeSeriesRowSelector};
27
28use crate::dummy_catalog::DummyTableProvider;
29
/// Optimizer rule that inspects a logical plan for sort requirements and
/// `last_value` aggregations and forwards them as hints to table scans backed
/// by a [`DummyTableProvider`].
#[derive(Debug)]
pub struct ScanHintRule;
39
40impl OptimizerRule for ScanHintRule {
41 fn name(&self) -> &str {
42 "ScanHintRule"
43 }
44
45 fn rewrite(
46 &self,
47 plan: LogicalPlan,
48 _config: &dyn OptimizerConfig,
49 ) -> Result<Transformed<LogicalPlan>> {
50 Self::optimize(plan)
51 }
52}
53
54impl ScanHintRule {
55 fn optimize(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
56 let mut visitor = ScanHintVisitor::default();
57 let _ = plan.visit(&mut visitor)?;
58
59 if visitor.need_rewrite() {
60 plan.transform_down(&|plan| Self::set_hints(plan, &visitor))
61 } else {
62 Ok(Transformed::no(plan))
63 }
64 }
65
66 fn set_hints(plan: LogicalPlan, visitor: &ScanHintVisitor) -> Result<Transformed<LogicalPlan>> {
67 match &plan {
68 LogicalPlan::TableScan(table_scan) => {
69 let mut transformed = false;
70 if let Some(source) = table_scan
71 .source
72 .as_any()
73 .downcast_ref::<DefaultTableSource>()
74 {
75 if let Some(adapter) = source
77 .table_provider
78 .as_any()
79 .downcast_ref::<DummyTableProvider>()
80 {
81 if let Some(order_expr) = &visitor.order_expr {
83 Self::set_order_hint(adapter, order_expr);
84 }
85
86 if let Some((group_by_cols, order_by_col)) = &visitor.ts_row_selector {
88 Self::set_time_series_row_selector_hint(
89 adapter,
90 group_by_cols,
91 order_by_col,
92 );
93 }
94
95 transformed = true;
96 }
97 }
98 if transformed {
99 Ok(Transformed::yes(plan))
100 } else {
101 Ok(Transformed::no(plan))
102 }
103 }
104 _ => Ok(Transformed::no(plan)),
105 }
106 }
107
108 fn set_order_hint(adapter: &DummyTableProvider, order_expr: &Vec<Sort>) {
109 let mut opts = Vec::with_capacity(order_expr.len());
110 for sort in order_expr {
111 let name = match sort.expr.try_as_col() {
112 Some(col) => col.name.clone(),
113 None => return,
114 };
115 opts.push(OrderOption {
116 name,
117 options: SortOptions {
118 descending: !sort.asc,
119 nulls_first: sort.nulls_first,
120 },
121 });
122 }
123 adapter.with_ordering_hint(&opts);
124
125 let mut sort_expr_cursor = order_expr.iter().filter_map(|s| s.expr.try_as_col());
126 let region_metadata = adapter.region_metadata();
127 if region_metadata.primary_key.is_empty() {
129 return;
130 }
131 let mut pk_column_iter = region_metadata.primary_key_columns();
132 let mut curr_sort_expr = sort_expr_cursor.next();
133 let mut curr_pk_col = pk_column_iter.next();
134
135 while let (Some(sort_expr), Some(pk_col)) = (curr_sort_expr, curr_pk_col) {
136 if sort_expr.name == pk_col.column_schema.name {
137 curr_sort_expr = sort_expr_cursor.next();
138 curr_pk_col = pk_column_iter.next();
139 } else {
140 return;
141 }
142 }
143
144 let next_remaining = sort_expr_cursor.next();
145 match (curr_sort_expr, next_remaining) {
146 (Some(expr), None)
147 if expr.name == region_metadata.time_index_column().column_schema.name =>
148 {
149 adapter.with_distribution(TimeSeriesDistribution::PerSeries);
150 }
151 (None, _) => adapter.with_distribution(TimeSeriesDistribution::PerSeries),
152 (Some(_), _) => {}
153 }
154 }
155
156 fn set_time_series_row_selector_hint(
157 adapter: &DummyTableProvider,
158 group_by_cols: &HashSet<Column>,
159 order_by_col: &Column,
160 ) {
161 let region_metadata = adapter.region_metadata();
162 let mut should_set_selector_hint = true;
163 if let Some(column_metadata) = region_metadata.column_by_name(&order_by_col.name) {
165 if column_metadata.semantic_type != SemanticType::Timestamp {
166 should_set_selector_hint = false;
167 }
168 } else {
169 should_set_selector_hint = false;
170 }
171
172 for col in group_by_cols {
174 let Some(column_metadata) = region_metadata.column_by_name(&col.name) else {
175 should_set_selector_hint = false;
176 break;
177 };
178 if column_metadata.semantic_type != SemanticType::Tag {
179 should_set_selector_hint = false;
180 break;
181 }
182 }
183
184 if should_set_selector_hint {
185 adapter.with_time_series_selector_hint(TimeSeriesRowSelector::LastRow);
186 }
187 }
188}
189
190#[derive(Default)]
192struct ScanHintVisitor {
193 order_expr: Option<Vec<Sort>>,
195 ts_row_selector: Option<(HashSet<Column>, Column)>,
199}
200
201impl TreeNodeVisitor<'_> for ScanHintVisitor {
202 type Node = LogicalPlan;
203
204 fn f_down(&mut self, node: &Self::Node) -> Result<TreeNodeRecursion> {
205 if let LogicalPlan::Sort(sort) = node {
207 self.order_expr = Some(sort.expr.clone());
208 }
209
210 if let LogicalPlan::Aggregate(aggregate) = node {
212 let mut is_all_last_value = !aggregate.aggr_expr.is_empty();
213 let mut order_by_expr = None;
214 for expr in &aggregate.aggr_expr {
215 let Expr::AggregateFunction(func) = expr else {
217 is_all_last_value = false;
218 break;
219 };
220 if func.func.name() != "last_value" || func.filter.is_some() || func.distinct {
221 is_all_last_value = false;
222 break;
223 }
224 if let Some(order_by) = &func.order_by
226 && let Some(first_order_by) = order_by.first()
227 && order_by.len() == 1
228 {
229 if let Some(existing_order_by) = &order_by_expr {
230 if existing_order_by != first_order_by {
231 is_all_last_value = false;
232 break;
233 }
234 } else {
235 if !first_order_by.asc || !matches!(&first_order_by.expr, Expr::Column(_)) {
238 is_all_last_value = false;
239 break;
240 }
241 order_by_expr = Some(first_order_by.clone());
242 }
243 }
244 }
245 is_all_last_value &= order_by_expr.is_some();
246 if is_all_last_value {
247 let mut group_by_cols = HashSet::with_capacity(aggregate.group_expr.len());
249 for expr in &aggregate.group_expr {
250 if let Expr::Column(col) = expr {
251 group_by_cols.insert(col.clone());
252 } else {
253 is_all_last_value = false;
254 break;
255 }
256 }
257 let order_by_expr = order_by_expr.unwrap();
259 let Expr::Column(order_by_col) = order_by_expr.expr else {
260 unreachable!()
261 };
262 if is_all_last_value {
263 self.ts_row_selector = Some((group_by_cols, order_by_col));
264 }
265 }
266 }
267
268 if self.ts_row_selector.is_some()
269 && (matches!(node, LogicalPlan::Subquery(_)) || node.inputs().len() > 1)
270 {
271 self.ts_row_selector = None;
273 }
274
275 if let LogicalPlan::Filter(filter) = node
276 && let Some(group_by_exprs) = &self.ts_row_selector
277 {
278 let mut filter_referenced_cols = HashSet::default();
279 utils::expr_to_columns(&filter.predicate, &mut filter_referenced_cols)?;
280 if !filter_referenced_cols.is_subset(&group_by_exprs.0) {
282 self.ts_row_selector = None;
283 }
284 }
285
286 Ok(TreeNodeRecursion::Continue)
287 }
288}
289
290impl ScanHintVisitor {
291 fn need_rewrite(&self) -> bool {
292 self.order_expr.is_some() || self.ts_row_selector.is_some()
293 }
294}
295
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use datafusion::functions_aggregate::first_last::last_value_udaf;
    use datafusion_expr::expr::AggregateFunction;
    use datafusion_expr::{col, LogicalPlanBuilder};
    use datafusion_optimizer::OptimizerContext;
    use store_api::storage::RegionId;

    use super::*;
    use crate::optimizer::test_util::mock_table_provider;

    /// Runs [`ScanHintRule`] over `plan`, panicking if the rewrite fails.
    fn apply_rule(plan: LogicalPlan) {
        let context = OptimizerContext::default();
        assert!(ScanHintRule.supports_rewrite());
        ScanHintRule.rewrite(plan, &context).unwrap();
    }

    #[test]
    fn set_order_hint() {
        let provider = Arc::new(mock_table_provider(RegionId::new(1, 1)));
        let table_source = Arc::new(DefaultTableSource::new(provider.clone()));
        let plan = LogicalPlanBuilder::scan("t", table_source, None)
            .unwrap()
            .sort(vec![col("ts").sort(true, false)])
            .unwrap()
            .sort(vec![col("ts").sort(false, true)])
            .unwrap()
            .build()
            .unwrap();

        apply_rule(plan);

        // The visitor walks top-down and keeps the last `Sort` it meets, so the sort closest
        // to the scan (`.sort(true, false)`) is the one that ends up in the hint.
        let scan_req = provider.scan_request();
        assert_eq!(
            OrderOption {
                name: "ts".to_string(),
                options: SortOptions {
                    descending: false,
                    nulls_first: false
                }
            },
            scan_req.output_ordering.as_ref().unwrap()[0]
        );
    }

    #[test]
    fn set_time_series_row_selector_hint() {
        let provider = Arc::new(mock_table_provider(RegionId::new(1, 1)));
        let table_source = Arc::new(DefaultTableSource::new(provider.clone()));
        let last_v0_by_ts = Expr::AggregateFunction(AggregateFunction {
            func: last_value_udaf(),
            args: vec![col("v0")],
            distinct: false,
            filter: None,
            order_by: Some(vec![Sort {
                expr: col("ts"),
                asc: true,
                nulls_first: true,
            }]),
            null_treatment: None,
        });
        let plan = LogicalPlanBuilder::scan("t", table_source, None)
            .unwrap()
            .aggregate(vec![col("k0")], vec![last_v0_by_ts])
            .unwrap()
            .build()
            .unwrap();

        apply_rule(plan);

        // Check the selector's value, not merely that one was set.
        let scan_req = provider.scan_request();
        assert!(matches!(
            scan_req.series_row_selector,
            Some(TimeSeriesRowSelector::LastRow)
        ));
    }
}