1use std::collections::HashSet;
16use std::sync::Arc;
17
18use common_function::aggrs::aggr_wrapper::{StateMergeHelper, is_all_aggr_exprs_steppable};
19#[cfg(feature = "vector_index")]
20use common_function::scalars::vector::distance::{
21 VEC_COS_DISTANCE, VEC_DOT_PRODUCT, VEC_L2SQ_DISTANCE,
22};
23use common_telemetry::debug;
24use datafusion::error::Result as DfResult;
25#[cfg(feature = "vector_index")]
26use datafusion_common::DataFusionError;
27use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
28#[cfg(feature = "vector_index")]
29use datafusion_expr::Sort;
30use datafusion_expr::{Expr, LogicalPlan, UserDefinedLogicalNode};
31use promql::extension_plan::{
32 EmptyMetric, InstantManipulate, RangeManipulate, SeriesDivide, SeriesNormalize,
33};
34use store_api::metric_engine_consts::DATA_SCHEMA_TSID_COLUMN_NAME;
35
36use crate::dist_plan::MergeScanLogicalPlan;
37use crate::dist_plan::analyzer::AliasMapping;
38use crate::dist_plan::merge_sort::{MergeSortLogicalPlan, merge_sort_transformer};
39
/// Reports whether `sort` orders by exactly one key that is a call to one of the
/// vector-distance scalar functions (`VEC_L2SQ_DISTANCE`, `VEC_COS_DISTANCE`,
/// `VEC_DOT_PRODUCT`). The function name is compared case-insensitively.
#[cfg(feature = "vector_index")]
fn is_vector_sort(sort: &Sort) -> bool {
    // Multi-key (or empty) sorts are never treated as vector sorts.
    let [single_key] = sort.expr.as_slice() else {
        return false;
    };
    match &single_key.expr {
        Expr::ScalarFunction(func) => {
            let lowered = func.name().to_lowercase();
            [VEC_L2SQ_DISTANCE, VEC_COS_DISTANCE, VEC_DOT_PRODUCT].contains(&lowered.as_str())
        }
        _ => false,
    }
}
54
/// Stage transformer for vector-distance sorts.
///
/// Keeps the original `Sort` as the new child plan and adds a `MergeSortLogicalPlan` built
/// from the same input, sort expressions and `fetch` as an extra parent plan.
///
/// # Errors
///
/// Returns `DataFusionError::Internal` when `plan` is not a `LogicalPlan::Sort`.
#[cfg(feature = "vector_index")]
fn vector_sort_transformer(plan: &LogicalPlan) -> DfResult<TransformerAction> {
    match plan {
        LogicalPlan::Sort(sort) => {
            let merge_sort =
                MergeSortLogicalPlan::new(sort.input.clone(), sort.expr.clone(), sort.fetch)
                    .into_logical_plan();
            Ok(TransformerAction {
                extra_parent_plans: vec![merge_sort],
                new_child_plan: Some(LogicalPlan::Sort(sort.clone())),
            })
        }
        _ => Err(DataFusionError::Internal(format!(
            "vector_sort_transformer expects Sort, got {plan}"
        ))),
    }
}
70
71pub struct StepTransformAction {
72 extra_parent_plans: Vec<LogicalPlan>,
73 new_child_plan: Option<LogicalPlan>,
74}
75
76pub fn step_aggr_to_upper_aggr(
86 aggr_plan: &LogicalPlan,
87) -> datafusion_common::Result<StepTransformAction> {
88 let LogicalPlan::Aggregate(input_aggr) = aggr_plan else {
89 return Err(datafusion_common::DataFusionError::Plan(
90 "step_aggr_to_upper_aggr only accepts Aggregate plan".to_string(),
91 ));
92 };
93 if !is_all_aggr_exprs_steppable(&input_aggr.aggr_expr) {
94 return Err(datafusion_common::DataFusionError::NotImplemented(format!(
95 "Some aggregate expressions are not steppable in [{}]",
96 input_aggr
97 .aggr_expr
98 .iter()
99 .map(|e| e.to_string())
100 .collect::<Vec<_>>()
101 .join(", ")
102 )));
103 }
104
105 let step_aggr_plan = StateMergeHelper::split_aggr_node(input_aggr.clone())?;
106
107 let ret = StepTransformAction {
109 extra_parent_plans: vec![step_aggr_plan.upper_merge.clone()],
110 new_child_plan: Some(step_aggr_plan.lower_state.clone()),
111 };
112 Ok(ret)
113}
114
/// Classification of a single logical plan node, as produced by `Categorizer::check_plan`.
///
/// The variant names suggest how the node relates to distributing a query across
/// partitions; the exact contract is defined by the planner that consumes this value
/// (not visible here) — the notes below only record where each variant is returned.
// NOTE(review): the reason for `dead_code` is not visible here; presumably some variant
// or payload is unused in certain build configurations — TODO confirm or remove.
#[allow(dead_code)]
pub enum Commutativity {
    /// Returned e.g. for `TableScan`, `SubqueryAlias`, `Unnest`, and for projections or
    /// filters whose expressions are all commutative.
    Commutative,
    /// Returned for `Limit` without `skip` (when partition columns exist) and for
    /// `Distinct` when partition columns exist.
    PartialCommutative,
    /// Carries an optional [`Transformer`]; e.g. `Sort` carries `merge_sort_transformer`,
    /// while aggregates grouped by the partition columns carry `None`.
    ConditionalCommutative(Option<Transformer>),
    /// Carries an optional [`StageTransformer`] that rewrites the node; used for steppable
    /// aggregates and (with `vector_index`) for vector-distance sorts.
    TransformedCommutative {
        transformer: Option<StageTransformer>,
    },
    /// Returned e.g. for `Join`, `EmptyRelation`, and for aggregates whose grouping does
    /// not cover the partition columns and that cannot be stepped.
    NonCommutative,
    /// Returned for nodes not handled yet, e.g. `Window`, `Union`, `Repartition`,
    /// subqueries, and some extension nodes.
    Unimplemented,
    /// Returned for statement-like nodes (DDL, DML, `EXPLAIN`, ...) and unknown extensions.
    Unsupported,
}
129
130pub struct Categorizer {}
131
132impl Categorizer {
133 pub fn check_plan(
134 plan: &LogicalPlan,
135 partition_cols: Option<AliasMapping>,
136 ) -> DfResult<Commutativity> {
137 if has_subquery(plan)? {
140 return Ok(Commutativity::Unimplemented);
141 }
142
143 let partition_cols = partition_cols.unwrap_or_default();
144
145 let comm = match plan {
146 LogicalPlan::Projection(proj) => {
147 for expr in &proj.expr {
148 let commutativity = Self::check_expr(expr);
149 if !matches!(commutativity, Commutativity::Commutative) {
150 return Ok(commutativity);
151 }
152 }
153 Commutativity::Commutative
154 }
155 LogicalPlan::Filter(filter) => Self::check_expr(&filter.predicate),
157 LogicalPlan::Window(_) => Commutativity::Unimplemented,
158 LogicalPlan::Aggregate(aggr) => {
159 let is_all_steppable = is_all_aggr_exprs_steppable(&aggr.aggr_expr);
160 let matches_partition = Self::check_partition(&aggr.group_expr, &partition_cols);
161 if !matches_partition && is_all_steppable {
162 debug!("Plan is steppable: {plan}");
163 return Ok(Commutativity::TransformedCommutative {
164 transformer: Some(Arc::new(|plan: &LogicalPlan| {
165 debug!("Before Step optimize: {plan}");
166 let ret = step_aggr_to_upper_aggr(plan);
167 ret.inspect_err(|err| {
168 common_telemetry::error!("Failed to step aggregate plan: {err:?}");
169 })
170 .map(|s| TransformerAction {
171 extra_parent_plans: s.extra_parent_plans,
172 new_child_plan: s.new_child_plan,
173 })
174 })),
175 });
176 }
177 if !matches_partition {
178 return Ok(Commutativity::NonCommutative);
179 }
180 for expr in &aggr.aggr_expr {
181 let commutativity = Self::check_expr(expr);
182 if !matches!(commutativity, Commutativity::Commutative) {
183 return Ok(commutativity);
184 }
185 }
186 Commutativity::ConditionalCommutative(None)
191 }
192 LogicalPlan::Sort(_sort) => {
193 if partition_cols.is_empty() {
194 return Ok(Commutativity::Commutative);
195 }
196
197 #[cfg(feature = "vector_index")]
200 if is_vector_sort(_sort) {
201 return Ok(Commutativity::TransformedCommutative {
202 transformer: Some(Arc::new(vector_sort_transformer)),
203 });
204 }
205 Commutativity::ConditionalCommutative(Some(Arc::new(merge_sort_transformer)))
206 }
207 LogicalPlan::Join(_) => Commutativity::NonCommutative,
208 LogicalPlan::Repartition(_) => {
209 Commutativity::Unimplemented
211 }
212 LogicalPlan::Union(_) => Commutativity::Unimplemented,
213 LogicalPlan::TableScan(_) => Commutativity::Commutative,
214 LogicalPlan::EmptyRelation(_) => Commutativity::NonCommutative,
215 LogicalPlan::Subquery(_) => Commutativity::Unimplemented,
216 LogicalPlan::SubqueryAlias(_) => Commutativity::Commutative,
217 LogicalPlan::Limit(limit) => {
218 if partition_cols.is_empty() && limit.fetch.is_some() {
221 Commutativity::Commutative
222 } else if limit.skip.is_none() && limit.fetch.is_some() {
223 Commutativity::PartialCommutative
224 } else {
225 Commutativity::Unimplemented
226 }
227 }
228 LogicalPlan::Extension(extension) => {
229 Self::check_extension_plan(extension.node.as_ref() as _, &partition_cols)
230 }
231 LogicalPlan::Distinct(_) => {
232 if partition_cols.is_empty() {
233 Commutativity::Commutative
234 } else {
235 Commutativity::PartialCommutative
236 }
237 }
238 LogicalPlan::Unnest(_) => Commutativity::Commutative,
239 LogicalPlan::Statement(_) => Commutativity::Unsupported,
240 LogicalPlan::Values(_) => Commutativity::Unsupported,
241 LogicalPlan::Explain(_) => Commutativity::Unsupported,
242 LogicalPlan::Analyze(_) => Commutativity::Unsupported,
243 LogicalPlan::DescribeTable(_) => Commutativity::Unsupported,
244 LogicalPlan::Dml(_) => Commutativity::Unsupported,
245 LogicalPlan::Ddl(_) => Commutativity::Unsupported,
246 LogicalPlan::Copy(_) => Commutativity::Unsupported,
247 LogicalPlan::RecursiveQuery(_) => Commutativity::Unsupported,
248 };
249
250 Ok(comm)
251 }
252
253 pub fn check_extension_plan(
254 plan: &dyn UserDefinedLogicalNode,
255 partition_cols: &AliasMapping,
256 ) -> Commutativity {
257 match plan.name() {
258 name if name == SeriesDivide::name() => {
259 let series_divide = plan.as_any().downcast_ref::<SeriesDivide>().unwrap();
260 if series_divide
263 .tags()
264 .iter()
265 .any(|tag| tag == DATA_SCHEMA_TSID_COLUMN_NAME)
266 {
267 return Commutativity::Commutative;
268 }
269
270 let tags = series_divide.tags().iter().collect::<HashSet<_>>();
271
272 for all_alias in partition_cols.values() {
273 let all_alias = all_alias.iter().map(|c| &c.name).collect::<HashSet<_>>();
274 if tags.intersection(&all_alias).count() == 0 {
275 return Commutativity::NonCommutative;
276 }
277 }
278
279 Commutativity::Commutative
280 }
281 name if name == SeriesNormalize::name()
282 || name == InstantManipulate::name()
283 || name == RangeManipulate::name() =>
284 {
285 Commutativity::Commutative
288 }
289 name if name == EmptyMetric::name()
290 || name == MergeScanLogicalPlan::name()
291 || name == MergeSortLogicalPlan::name() =>
292 {
293 Commutativity::Unimplemented
294 }
295 _ => Commutativity::Unsupported,
296 }
297 }
298
299 pub fn check_expr(expr: &Expr) -> Commutativity {
300 #[allow(deprecated)]
301 match expr {
302 Expr::Column(_)
303 | Expr::ScalarVariable(_, _)
304 | Expr::Literal(_, _)
305 | Expr::BinaryExpr(_)
306 | Expr::Not(_)
307 | Expr::IsNotNull(_)
308 | Expr::IsNull(_)
309 | Expr::IsTrue(_)
310 | Expr::IsFalse(_)
311 | Expr::IsNotTrue(_)
312 | Expr::IsNotFalse(_)
313 | Expr::Negative(_)
314 | Expr::Between(_)
315 | Expr::Exists(_)
316 | Expr::InList(_)
317 | Expr::Case(_) => Commutativity::Commutative,
318 Expr::ScalarFunction(_udf) => Commutativity::Commutative,
319 Expr::AggregateFunction(_udaf) => Commutativity::Commutative,
320
321 Expr::Like(_)
322 | Expr::SimilarTo(_)
323 | Expr::IsUnknown(_)
324 | Expr::IsNotUnknown(_)
325 | Expr::Cast(_)
326 | Expr::TryCast(_)
327 | Expr::WindowFunction(_)
328 | Expr::InSubquery(_)
329 | Expr::ScalarSubquery(_)
330 | Expr::Wildcard { .. } => Commutativity::Unimplemented,
331
332 Expr::Alias(alias) => Self::check_expr(&alias.expr),
333
334 Expr::Unnest(_)
335 | Expr::GroupingSet(_)
336 | Expr::Placeholder(_)
337 | Expr::OuterReferenceColumn(_, _)
338 | Expr::SetComparison(_) => Commutativity::Unimplemented,
339 }
340 }
341
342 fn check_partition(exprs: &[Expr], partition_cols: &AliasMapping) -> bool {
349 let mut ref_cols = HashSet::new();
350 for expr in exprs {
351 expr.add_column_refs(&mut ref_cols);
352 }
353 let ref_cols = ref_cols
354 .into_iter()
355 .map(|c| c.name.clone())
356 .collect::<HashSet<_>>();
357 for all_alias in partition_cols.values() {
358 let all_alias = all_alias
359 .iter()
360 .map(|c| c.name.clone())
361 .collect::<HashSet<_>>();
362 if ref_cols.intersection(&all_alias).count() == 0 {
365 return false;
366 }
367 }
368
369 true
370 }
371}
372
#[cfg(test)]
mod tests {
    use std::collections::{BTreeMap, BTreeSet};

    use datafusion_common::Column;
    use datafusion_expr::LogicalPlanBuilder;

    use super::*;

    /// A `SeriesDivide` tagged by the TSID column must be commutative even though the
    /// (unrelated) partition column is not one of its tags.
    #[test]
    fn series_divide_by_tsid_is_commutative() {
        let partition_cols: AliasMapping = BTreeMap::from([(
            "some_partition_col".to_string(),
            BTreeSet::from([Column::from_name("some_partition_col")]),
        )]);

        let empty_input = LogicalPlanBuilder::empty(false).build().unwrap();
        let divide_node = SeriesDivide::new(
            vec![DATA_SCHEMA_TSID_COLUMN_NAME.to_string()],
            "ts".to_string(),
            empty_input,
        );

        let outcome = Categorizer::check_extension_plan(&divide_node, &partition_cols);
        assert!(matches!(outcome, Commutativity::Commutative));
    }
}
400
/// A plan rewrite returning an optional replacement plan.
///
/// NOTE(review): `None` presumably means "no rewrite applies" — confirm against the
/// consumer of [`Commutativity::ConditionalCommutative`].
pub type Transformer = Arc<dyn Fn(&LogicalPlan) -> Option<LogicalPlan>>;

/// A fallible plan rewrite that describes, via [`TransformerAction`], how to restructure
/// the plan around the node it is applied to.
pub type StageTransformer = Arc<dyn Fn(&LogicalPlan) -> DfResult<TransformerAction>>;

/// Result of applying a [`StageTransformer`].
pub struct TransformerAction {
    /// Extra plans associated with the transformed node, intended (per the field name) to
    /// sit above it; in this file they hold the upper merge aggregate of a stepped aggregate
    /// and the merge-sort of a vector sort. The ordering contract when several are present
    /// is defined by the consumer, not visible here.
    pub extra_parent_plans: Vec<LogicalPlan>,
    /// Plan that takes the place of the transformed node (e.g. the lower state aggregate).
    /// NOTE(review): the meaning of `None` is defined by the consumer — confirm.
    pub new_child_plan: Option<LogicalPlan>,
}
420
421pub fn partial_commutative_transformer(plan: &LogicalPlan) -> Option<LogicalPlan> {
422 Some(plan.clone())
423}
424
425fn has_subquery(plan: &LogicalPlan) -> DfResult<bool> {
426 let mut found = false;
427 plan.apply_expressions(|e| {
428 e.apply(|x| {
429 if matches!(
430 x,
431 Expr::Exists(_) | Expr::InSubquery(_) | Expr::ScalarSubquery(_)
432 ) {
433 found = true;
434 Ok(TreeNodeRecursion::Stop)
435 } else {
436 Ok(TreeNodeRecursion::Continue)
437 }
438 })
439 })?;
440 Ok(found)
441}
442
#[cfg(test)]
mod test {
    use datafusion_expr::{LogicalPlanBuilder, Sort};

    use super::*;

    /// With an empty partition-column mapping, a `Sort` is reported as plainly commutative.
    #[test]
    fn sort_on_empty_partition() {
        let empty_input = LogicalPlanBuilder::empty(false).build().unwrap();
        let sort_plan = LogicalPlan::Sort(Sort {
            expr: vec![],
            input: Arc::new(empty_input),
            fetch: None,
        });

        let outcome = Categorizer::check_plan(&sort_plan, Some(Default::default())).unwrap();
        assert!(matches!(outcome, Commutativity::Commutative));
    }
}