1use std::collections::HashSet;
16use std::sync::Arc;
17
18use common_function::aggrs::aggr_wrapper::{StateMergeHelper, is_all_aggr_exprs_steppable};
19#[cfg(feature = "vector_index")]
20use common_function::scalars::vector::distance::{
21 VEC_COS_DISTANCE, VEC_DOT_PRODUCT, VEC_L2SQ_DISTANCE,
22};
23use common_telemetry::debug;
24use datafusion::error::Result as DfResult;
25#[cfg(feature = "vector_index")]
26use datafusion_common::DataFusionError;
27use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
28#[cfg(feature = "vector_index")]
29use datafusion_expr::Sort;
30use datafusion_expr::{Expr, LogicalPlan, UserDefinedLogicalNode};
31use promql::extension_plan::{
32 EmptyMetric, InstantManipulate, RangeManipulate, SeriesDivide, SeriesNormalize,
33};
34use store_api::metric_engine_consts::DATA_SCHEMA_TSID_COLUMN_NAME;
35
36use crate::dist_plan::MergeScanLogicalPlan;
37use crate::dist_plan::analyzer::AliasMapping;
38use crate::dist_plan::merge_sort::{MergeSortLogicalPlan, merge_sort_transformer};
39
/// Returns `true` when `sort` orders by exactly one expression and that
/// expression is a call to one of the vector distance functions
/// (`VEC_L2SQ_DISTANCE`, `VEC_COS_DISTANCE`, `VEC_DOT_PRODUCT`), compared
/// case-insensitively by lowercasing the function name.
#[cfg(feature = "vector_index")]
fn is_vector_sort(sort: &Sort) -> bool {
    // Anything other than a single sort key is never treated as a vector sort.
    let [single] = sort.expr.as_slice() else {
        return false;
    };
    if let Expr::ScalarFunction(func) = &single.expr {
        let lowered = func.name().to_lowercase();
        matches!(
            lowered.as_str(),
            VEC_L2SQ_DISTANCE | VEC_COS_DISTANCE | VEC_DOT_PRODUCT
        )
    } else {
        false
    }
}
54
/// Stage transformer for vector-distance sorts.
///
/// For a `Sort` node it returns a `MergeSortLogicalPlan` built from the sort's
/// input, sort expressions and `fetch` as the extra parent plan, and a clone of
/// the original `Sort` as the new child plan.
///
/// # Errors
/// Returns `DataFusionError::Internal` when `plan` is not a `LogicalPlan::Sort`.
#[cfg(feature = "vector_index")]
fn vector_sort_transformer(plan: &LogicalPlan) -> DfResult<TransformerAction> {
    match plan {
        LogicalPlan::Sort(sort) => {
            let merge_sort =
                MergeSortLogicalPlan::new(sort.input.clone(), sort.expr.clone(), sort.fetch)
                    .into_logical_plan();
            Ok(TransformerAction {
                extra_parent_plans: vec![merge_sort],
                new_child_plan: Some(LogicalPlan::Sort(sort.clone())),
            })
        }
        _ => Err(DataFusionError::Internal(format!(
            "vector_sort_transformer expects Sort, got {plan}"
        ))),
    }
}
70
/// Result of [`step_aggr_to_upper_aggr`].
///
/// The steppable-aggregate transformer built in [`Categorizer::check_plan`]
/// copies both fields one-to-one into a [`TransformerAction`].
pub struct StepTransformAction {
    /// Set by `step_aggr_to_upper_aggr` to the upper (merge) aggregate plan.
    extra_parent_plans: Vec<LogicalPlan>,
    /// Set by `step_aggr_to_upper_aggr` to the lower (state) aggregate plan.
    new_child_plan: Option<LogicalPlan>,
}
75
76pub fn step_aggr_to_upper_aggr(
86 aggr_plan: &LogicalPlan,
87) -> datafusion_common::Result<StepTransformAction> {
88 let LogicalPlan::Aggregate(input_aggr) = aggr_plan else {
89 return Err(datafusion_common::DataFusionError::Plan(
90 "step_aggr_to_upper_aggr only accepts Aggregate plan".to_string(),
91 ));
92 };
93 if !is_all_aggr_exprs_steppable(&input_aggr.aggr_expr) {
94 return Err(datafusion_common::DataFusionError::NotImplemented(format!(
95 "Some aggregate expressions are not steppable in [{}]",
96 input_aggr
97 .aggr_expr
98 .iter()
99 .map(|e| e.to_string())
100 .collect::<Vec<_>>()
101 .join(", ")
102 )));
103 }
104
105 let step_aggr_plan = StateMergeHelper::split_aggr_node(input_aggr.clone())?;
106
107 let ret = StepTransformAction {
109 extra_parent_plans: vec![step_aggr_plan.upper_merge.clone()],
110 new_child_plan: Some(step_aggr_plan.lower_state.clone()),
111 };
112 Ok(ret)
113}
114
/// Verdict produced by [`Categorizer`] for a single logical plan node.
// NOTE(review): every variant below is constructed in `Categorizer::check_plan`,
// so this `dead_code` allow may be stale. Confirm before removing it.
#[allow(dead_code)]
pub enum Commutativity {
    /// Returned e.g. for `TableScan`, `SubqueryAlias`, `Unnest`, and for nodes
    /// checked with an empty partition-column mapping (`Sort`, `Distinct`).
    Commutative,
    /// Returned for `Distinct` with partition columns, and for a `Limit` that
    /// has a `fetch` and no `skip` when partition columns are present.
    PartialCommutative,
    /// Commutative under a condition. It may carry a [`Transformer`]; the
    /// `Sort` case carries `merge_sort_transformer`, and the partition-aligned
    /// `Aggregate` case carries `None`.
    ConditionalCommutative(Option<Transformer>),
    /// Commutative after a [`StageTransformer`] rewrite. Used for steppable
    /// aggregates and, with the `vector_index` feature, vector-distance sorts.
    TransformedCommutative {
        transformer: Option<StageTransformer>,
    },
    /// Returned e.g. for `Join`, `EmptyRelation`, and aggregates whose group
    /// keys do not cover the partition columns.
    NonCommutative,
    /// Not handled yet. Returned e.g. for `Window`, `Union`, and any node
    /// whose expressions contain a subquery.
    Unimplemented,
    /// Returned for statements/DML/DDL-like nodes and for unknown extension nodes.
    Unsupported,
}
129
/// Field-less namespace for the commutativity checks in this module; all of its
/// functions are associated functions that take no `self`.
pub struct Categorizer {}
131
132impl Categorizer {
133 pub fn check_plan(
134 plan: &LogicalPlan,
135 partition_cols: Option<AliasMapping>,
136 ) -> DfResult<Commutativity> {
137 if has_subquery(plan)? {
140 return Ok(Commutativity::Unimplemented);
141 }
142
143 let partition_cols = partition_cols.unwrap_or_default();
144
145 let comm = match plan {
146 LogicalPlan::Projection(proj) => {
147 for expr in &proj.expr {
148 let commutativity = Self::check_expr(expr);
149 if !matches!(commutativity, Commutativity::Commutative) {
150 return Ok(commutativity);
151 }
152 }
153 Commutativity::Commutative
154 }
155 LogicalPlan::Filter(filter) => Self::check_expr(&filter.predicate),
157 LogicalPlan::Window(_) => Commutativity::Unimplemented,
158 LogicalPlan::Aggregate(aggr) => {
159 let is_all_steppable = is_all_aggr_exprs_steppable(&aggr.aggr_expr);
160 let matches_partition = Self::check_partition(&aggr.group_expr, &partition_cols);
161 if !matches_partition && is_all_steppable {
162 debug!("Plan is steppable: {plan}");
163 return Ok(Commutativity::TransformedCommutative {
164 transformer: Some(Arc::new(|plan: &LogicalPlan| {
165 debug!("Before Step optimize: {plan}");
166 let ret = step_aggr_to_upper_aggr(plan);
167 ret.inspect_err(|err| {
168 common_telemetry::error!("Failed to step aggregate plan: {err:?}");
169 })
170 .map(|s| TransformerAction {
171 extra_parent_plans: s.extra_parent_plans,
172 new_child_plan: s.new_child_plan,
173 })
174 })),
175 });
176 }
177 if !matches_partition {
178 return Ok(Commutativity::NonCommutative);
179 }
180 for expr in &aggr.aggr_expr {
181 let commutativity = Self::check_expr(expr);
182 if !matches!(commutativity, Commutativity::Commutative) {
183 return Ok(commutativity);
184 }
185 }
186 Commutativity::ConditionalCommutative(None)
191 }
192 LogicalPlan::Sort(_sort) => {
193 if partition_cols.is_empty() {
194 return Ok(Commutativity::Commutative);
195 }
196
197 #[cfg(feature = "vector_index")]
200 if is_vector_sort(_sort) {
201 return Ok(Commutativity::TransformedCommutative {
202 transformer: Some(Arc::new(vector_sort_transformer)),
203 });
204 }
205 Commutativity::ConditionalCommutative(Some(Arc::new(merge_sort_transformer)))
206 }
207 LogicalPlan::Join(_) => Commutativity::NonCommutative,
208 LogicalPlan::Repartition(_) => {
209 Commutativity::Unimplemented
211 }
212 LogicalPlan::Union(_) => Commutativity::Unimplemented,
213 LogicalPlan::TableScan(_) => Commutativity::Commutative,
214 LogicalPlan::EmptyRelation(_) => Commutativity::NonCommutative,
215 LogicalPlan::Subquery(_) => Commutativity::Unimplemented,
216 LogicalPlan::SubqueryAlias(_) => Commutativity::Commutative,
217 LogicalPlan::Limit(limit) => {
218 if partition_cols.is_empty() && limit.fetch.is_some() {
221 Commutativity::Commutative
222 } else if limit.skip.is_none() && limit.fetch.is_some() {
223 Commutativity::PartialCommutative
224 } else {
225 Commutativity::Unimplemented
226 }
227 }
228 LogicalPlan::Extension(extension) => {
229 Self::check_extension_plan(extension.node.as_ref() as _, &partition_cols)
230 }
231 LogicalPlan::Distinct(_) => {
232 if partition_cols.is_empty() {
233 Commutativity::Commutative
234 } else {
235 Commutativity::PartialCommutative
236 }
237 }
238 LogicalPlan::Unnest(_) => Commutativity::Commutative,
239 LogicalPlan::Statement(_) => Commutativity::Unsupported,
240 LogicalPlan::Values(_) => Commutativity::Unsupported,
241 LogicalPlan::Explain(_) => Commutativity::Unsupported,
242 LogicalPlan::Analyze(_) => Commutativity::Unsupported,
243 LogicalPlan::DescribeTable(_) => Commutativity::Unsupported,
244 LogicalPlan::Dml(_) => Commutativity::Unsupported,
245 LogicalPlan::Ddl(_) => Commutativity::Unsupported,
246 LogicalPlan::Copy(_) => Commutativity::Unsupported,
247 LogicalPlan::RecursiveQuery(_) => Commutativity::Unsupported,
248 };
249
250 Ok(comm)
251 }
252
253 pub fn check_extension_plan(
254 plan: &dyn UserDefinedLogicalNode,
255 partition_cols: &AliasMapping,
256 ) -> Commutativity {
257 match plan.name() {
258 name if name == SeriesDivide::name() => {
259 let series_divide = plan.as_any().downcast_ref::<SeriesDivide>().unwrap();
260 if series_divide
263 .tags()
264 .iter()
265 .any(|tag| tag == DATA_SCHEMA_TSID_COLUMN_NAME)
266 {
267 return Commutativity::Commutative;
268 }
269
270 let tags = series_divide.tags().iter().collect::<HashSet<_>>();
271
272 for all_alias in partition_cols.values() {
273 let all_alias = all_alias.iter().map(|c| &c.name).collect::<HashSet<_>>();
274 if tags.intersection(&all_alias).count() == 0 {
275 return Commutativity::NonCommutative;
276 }
277 }
278
279 Commutativity::Commutative
280 }
281 name if name == SeriesNormalize::name()
282 || name == InstantManipulate::name()
283 || name == RangeManipulate::name() =>
284 {
285 Commutativity::Commutative
288 }
289 name if name == EmptyMetric::name()
290 || name == MergeScanLogicalPlan::name()
291 || name == MergeSortLogicalPlan::name() =>
292 {
293 Commutativity::Unimplemented
294 }
295 _ => Commutativity::Unsupported,
296 }
297 }
298
299 pub fn check_expr(expr: &Expr) -> Commutativity {
300 #[allow(deprecated)]
301 match expr {
302 Expr::Column(_)
303 | Expr::ScalarVariable(_, _)
304 | Expr::Literal(_, _)
305 | Expr::BinaryExpr(_)
306 | Expr::Not(_)
307 | Expr::IsNotNull(_)
308 | Expr::IsNull(_)
309 | Expr::IsTrue(_)
310 | Expr::IsFalse(_)
311 | Expr::IsNotTrue(_)
312 | Expr::IsNotFalse(_)
313 | Expr::Negative(_)
314 | Expr::Between(_)
315 | Expr::Exists(_)
316 | Expr::InList(_)
317 | Expr::Case(_) => Commutativity::Commutative,
318 Expr::ScalarFunction(_udf) => Commutativity::Commutative,
319 Expr::AggregateFunction(_udaf) => Commutativity::Commutative,
320
321 Expr::Like(_)
322 | Expr::SimilarTo(_)
323 | Expr::IsUnknown(_)
324 | Expr::IsNotUnknown(_)
325 | Expr::Cast(_)
326 | Expr::TryCast(_)
327 | Expr::WindowFunction(_)
328 | Expr::InSubquery(_)
329 | Expr::ScalarSubquery(_)
330 | Expr::Wildcard { .. } => Commutativity::Unimplemented,
331
332 Expr::Alias(alias) => Self::check_expr(&alias.expr),
333
334 Expr::Unnest(_)
335 | Expr::GroupingSet(_)
336 | Expr::Placeholder(_)
337 | Expr::OuterReferenceColumn(_, _) => Commutativity::Unimplemented,
338 }
339 }
340
341 fn check_partition(exprs: &[Expr], partition_cols: &AliasMapping) -> bool {
348 let mut ref_cols = HashSet::new();
349 for expr in exprs {
350 expr.add_column_refs(&mut ref_cols);
351 }
352 let ref_cols = ref_cols
353 .into_iter()
354 .map(|c| c.name.clone())
355 .collect::<HashSet<_>>();
356 for all_alias in partition_cols.values() {
357 let all_alias = all_alias
358 .iter()
359 .map(|c| c.name.clone())
360 .collect::<HashSet<_>>();
361 if ref_cols.intersection(&all_alias).count() == 0 {
364 return false;
365 }
366 }
367
368 true
369 }
370}
371
#[cfg(test)]
mod tests {
    use std::collections::{BTreeMap, BTreeSet};

    use datafusion_common::Column;
    use datafusion_expr::LogicalPlanBuilder;

    use super::*;

    /// A `SeriesDivide` keyed on the TSID column is commutative even when the
    /// table's partition column is not one of its tags.
    #[test]
    fn series_divide_by_tsid_is_commutative() {
        let divide_tags = vec![DATA_SCHEMA_TSID_COLUMN_NAME.to_string()];
        let time_index = "ts".to_string();
        let empty_input = LogicalPlanBuilder::empty(false).build().unwrap();
        let divide_node = SeriesDivide::new(divide_tags, time_index, empty_input);

        let partition_col = "some_partition_col";
        let aliases = BTreeSet::from([Column::from_name(partition_col)]);
        let partition_cols: AliasMapping =
            BTreeMap::from([(partition_col.to_string(), aliases)]);

        let verdict = Categorizer::check_extension_plan(&divide_node, &partition_cols);
        assert!(matches!(verdict, Commutativity::Commutative));
    }
}
399
/// Optional rewrite carried by [`Commutativity::ConditionalCommutative`]: takes a
/// plan node and may return a plan. `merge_sort_transformer` (used for `Sort`)
/// and [`partial_commutative_transformer`] have this shape.
pub type Transformer = Arc<dyn Fn(&LogicalPlan) -> Option<LogicalPlan>>;
401
/// Fallible rewrite carried by [`Commutativity::TransformedCommutative`],
/// producing a [`TransformerAction`] (e.g. the steppable-aggregate split or,
/// with `vector_index`, `vector_sort_transformer`).
pub type StageTransformer = Arc<dyn Fn(&LogicalPlan) -> DfResult<TransformerAction>>;
404
/// Output of a [`StageTransformer`].
pub struct TransformerAction {
    /// Extra plans for the parent side of the rewritten node. For example, the
    /// upper merge aggregate for steppable aggregates, or a
    /// `MergeSortLogicalPlan` for vector sorts.
    pub extra_parent_plans: Vec<LogicalPlan>,
    /// Replacement for the node itself. For example, the lower state aggregate,
    /// or a clone of the original `Sort` for vector sorts.
    // NOTE(review): how `None` is interpreted is decided by the caller of the
    // transformer, not shown here. Confirm.
    pub new_child_plan: Option<LogicalPlan>,
}
419
420pub fn partial_commutative_transformer(plan: &LogicalPlan) -> Option<LogicalPlan> {
421 Some(plan.clone())
422}
423
424fn has_subquery(plan: &LogicalPlan) -> DfResult<bool> {
425 let mut found = false;
426 plan.apply_expressions(|e| {
427 e.apply(|x| {
428 if matches!(
429 x,
430 Expr::Exists(_) | Expr::InSubquery(_) | Expr::ScalarSubquery(_)
431 ) {
432 found = true;
433 Ok(TreeNodeRecursion::Stop)
434 } else {
435 Ok(TreeNodeRecursion::Continue)
436 }
437 })
438 })?;
439 Ok(found)
440}
441
#[cfg(test)]
mod test {
    use datafusion_expr::{LogicalPlanBuilder, Sort};

    use super::*;

    /// A `Sort` checked against an empty (but present) partition-column mapping
    /// is classified as fully commutative.
    #[test]
    fn sort_on_empty_partition() {
        let input = Arc::new(LogicalPlanBuilder::empty(false).build().unwrap());
        let sort_plan = LogicalPlan::Sort(Sort {
            expr: Vec::new(),
            input,
            fetch: None,
        });

        let verdict = Categorizer::check_plan(&sort_plan, Some(AliasMapping::default())).unwrap();
        assert!(matches!(verdict, Commutativity::Commutative));
    }
}
460}