query/dist_plan/
commutativity.rs1use std::collections::HashSet;
16use std::sync::Arc;
17
18use common_function::aggrs::aggr_wrapper::{StateMergeHelper, is_all_aggr_exprs_steppable};
19use common_telemetry::debug;
20use datafusion::error::Result as DfResult;
21use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
22use datafusion_expr::{Expr, LogicalPlan, UserDefinedLogicalNode};
23use promql::extension_plan::{
24 EmptyMetric, InstantManipulate, RangeManipulate, SeriesDivide, SeriesNormalize,
25};
26
27use crate::dist_plan::MergeScanLogicalPlan;
28use crate::dist_plan::analyzer::AliasMapping;
29use crate::dist_plan::merge_sort::{MergeSortLogicalPlan, merge_sort_transformer};
30
31pub struct StepTransformAction {
32 extra_parent_plans: Vec<LogicalPlan>,
33 new_child_plan: Option<LogicalPlan>,
34}
35
36pub fn step_aggr_to_upper_aggr(
46 aggr_plan: &LogicalPlan,
47) -> datafusion_common::Result<StepTransformAction> {
48 let LogicalPlan::Aggregate(input_aggr) = aggr_plan else {
49 return Err(datafusion_common::DataFusionError::Plan(
50 "step_aggr_to_upper_aggr only accepts Aggregate plan".to_string(),
51 ));
52 };
53 if !is_all_aggr_exprs_steppable(&input_aggr.aggr_expr) {
54 return Err(datafusion_common::DataFusionError::NotImplemented(format!(
55 "Some aggregate expressions are not steppable in [{}]",
56 input_aggr
57 .aggr_expr
58 .iter()
59 .map(|e| e.to_string())
60 .collect::<Vec<_>>()
61 .join(", ")
62 )));
63 }
64
65 let step_aggr_plan = StateMergeHelper::split_aggr_node(input_aggr.clone())?;
66
67 let ret = StepTransformAction {
69 extra_parent_plans: vec![step_aggr_plan.upper_merge.clone()],
70 new_child_plan: Some(step_aggr_plan.lower_state.clone()),
71 };
72 Ok(ret)
73}
74
75#[allow(dead_code)]
76pub enum Commutativity {
77 Commutative,
78 PartialCommutative,
79 ConditionalCommutative(Option<Transformer>),
80 TransformedCommutative {
81 transformer: Option<StageTransformer>,
83 },
84 NonCommutative,
85 Unimplemented,
86 Unsupported,
88}
89
/// Stateless namespace for the plan/expression classification routines in its `impl`.
pub struct Categorizer {}
91
92impl Categorizer {
93 pub fn check_plan(
94 plan: &LogicalPlan,
95 partition_cols: Option<AliasMapping>,
96 ) -> DfResult<Commutativity> {
97 if has_subquery(plan)? {
100 return Ok(Commutativity::Unimplemented);
101 }
102
103 let partition_cols = partition_cols.unwrap_or_default();
104
105 let comm = match plan {
106 LogicalPlan::Projection(proj) => {
107 for expr in &proj.expr {
108 let commutativity = Self::check_expr(expr);
109 if !matches!(commutativity, Commutativity::Commutative) {
110 return Ok(commutativity);
111 }
112 }
113 Commutativity::Commutative
114 }
115 LogicalPlan::Filter(filter) => Self::check_expr(&filter.predicate),
117 LogicalPlan::Window(_) => Commutativity::Unimplemented,
118 LogicalPlan::Aggregate(aggr) => {
119 let is_all_steppable = is_all_aggr_exprs_steppable(&aggr.aggr_expr);
120 let matches_partition = Self::check_partition(&aggr.group_expr, &partition_cols);
121 if !matches_partition && is_all_steppable {
122 debug!("Plan is steppable: {plan}");
123 return Ok(Commutativity::TransformedCommutative {
124 transformer: Some(Arc::new(|plan: &LogicalPlan| {
125 debug!("Before Step optimize: {plan}");
126 let ret = step_aggr_to_upper_aggr(plan);
127 ret.inspect_err(|err| {
128 common_telemetry::error!("Failed to step aggregate plan: {err:?}");
129 })
130 .map(|s| TransformerAction {
131 extra_parent_plans: s.extra_parent_plans,
132 new_child_plan: s.new_child_plan,
133 })
134 })),
135 });
136 }
137 if !matches_partition {
138 return Ok(Commutativity::NonCommutative);
139 }
140 for expr in &aggr.aggr_expr {
141 let commutativity = Self::check_expr(expr);
142 if !matches!(commutativity, Commutativity::Commutative) {
143 return Ok(commutativity);
144 }
145 }
146 Commutativity::ConditionalCommutative(None)
151 }
152 LogicalPlan::Sort(_) => {
153 if partition_cols.is_empty() {
154 return Ok(Commutativity::Commutative);
155 }
156
157 Commutativity::ConditionalCommutative(Some(Arc::new(merge_sort_transformer)))
161 }
162 LogicalPlan::Join(_) => Commutativity::NonCommutative,
163 LogicalPlan::Repartition(_) => {
164 Commutativity::Unimplemented
166 }
167 LogicalPlan::Union(_) => Commutativity::Unimplemented,
168 LogicalPlan::TableScan(_) => Commutativity::Commutative,
169 LogicalPlan::EmptyRelation(_) => Commutativity::NonCommutative,
170 LogicalPlan::Subquery(_) => Commutativity::Unimplemented,
171 LogicalPlan::SubqueryAlias(_) => Commutativity::Commutative,
172 LogicalPlan::Limit(limit) => {
173 if partition_cols.is_empty() && limit.fetch.is_some() {
176 Commutativity::Commutative
177 } else if limit.skip.is_none() && limit.fetch.is_some() {
178 Commutativity::PartialCommutative
179 } else {
180 Commutativity::Unimplemented
181 }
182 }
183 LogicalPlan::Extension(extension) => {
184 Self::check_extension_plan(extension.node.as_ref() as _, &partition_cols)
185 }
186 LogicalPlan::Distinct(_) => {
187 if partition_cols.is_empty() {
188 Commutativity::Commutative
189 } else {
190 Commutativity::Unimplemented
191 }
192 }
193 LogicalPlan::Unnest(_) => Commutativity::Commutative,
194 LogicalPlan::Statement(_) => Commutativity::Unsupported,
195 LogicalPlan::Values(_) => Commutativity::Unsupported,
196 LogicalPlan::Explain(_) => Commutativity::Unsupported,
197 LogicalPlan::Analyze(_) => Commutativity::Unsupported,
198 LogicalPlan::DescribeTable(_) => Commutativity::Unsupported,
199 LogicalPlan::Dml(_) => Commutativity::Unsupported,
200 LogicalPlan::Ddl(_) => Commutativity::Unsupported,
201 LogicalPlan::Copy(_) => Commutativity::Unsupported,
202 LogicalPlan::RecursiveQuery(_) => Commutativity::Unsupported,
203 };
204
205 Ok(comm)
206 }
207
208 pub fn check_extension_plan(
209 plan: &dyn UserDefinedLogicalNode,
210 partition_cols: &AliasMapping,
211 ) -> Commutativity {
212 match plan.name() {
213 name if name == SeriesDivide::name() => {
214 let series_divide = plan.as_any().downcast_ref::<SeriesDivide>().unwrap();
215 let tags = series_divide.tags().iter().collect::<HashSet<_>>();
216
217 for all_alias in partition_cols.values() {
218 let all_alias = all_alias.iter().map(|c| &c.name).collect::<HashSet<_>>();
219 if tags.intersection(&all_alias).count() == 0 {
220 return Commutativity::NonCommutative;
221 }
222 }
223
224 Commutativity::Commutative
225 }
226 name if name == SeriesNormalize::name()
227 || name == InstantManipulate::name()
228 || name == RangeManipulate::name() =>
229 {
230 Commutativity::Commutative
233 }
234 name if name == EmptyMetric::name()
235 || name == MergeScanLogicalPlan::name()
236 || name == MergeSortLogicalPlan::name() =>
237 {
238 Commutativity::Unimplemented
239 }
240 _ => Commutativity::Unsupported,
241 }
242 }
243
244 pub fn check_expr(expr: &Expr) -> Commutativity {
245 #[allow(deprecated)]
246 match expr {
247 Expr::Column(_)
248 | Expr::ScalarVariable(_, _)
249 | Expr::Literal(_, _)
250 | Expr::BinaryExpr(_)
251 | Expr::Not(_)
252 | Expr::IsNotNull(_)
253 | Expr::IsNull(_)
254 | Expr::IsTrue(_)
255 | Expr::IsFalse(_)
256 | Expr::IsNotTrue(_)
257 | Expr::IsNotFalse(_)
258 | Expr::Negative(_)
259 | Expr::Between(_)
260 | Expr::Exists(_)
261 | Expr::InList(_)
262 | Expr::Case(_) => Commutativity::Commutative,
263 Expr::ScalarFunction(_udf) => Commutativity::Commutative,
264 Expr::AggregateFunction(_udaf) => Commutativity::Commutative,
265
266 Expr::Like(_)
267 | Expr::SimilarTo(_)
268 | Expr::IsUnknown(_)
269 | Expr::IsNotUnknown(_)
270 | Expr::Cast(_)
271 | Expr::TryCast(_)
272 | Expr::WindowFunction(_)
273 | Expr::InSubquery(_)
274 | Expr::ScalarSubquery(_)
275 | Expr::Wildcard { .. } => Commutativity::Unimplemented,
276
277 Expr::Alias(alias) => Self::check_expr(&alias.expr),
278
279 Expr::Unnest(_)
280 | Expr::GroupingSet(_)
281 | Expr::Placeholder(_)
282 | Expr::OuterReferenceColumn(_, _) => Commutativity::Unimplemented,
283 }
284 }
285
286 fn check_partition(exprs: &[Expr], partition_cols: &AliasMapping) -> bool {
293 let mut ref_cols = HashSet::new();
294 for expr in exprs {
295 expr.add_column_refs(&mut ref_cols);
296 }
297 let ref_cols = ref_cols
298 .into_iter()
299 .map(|c| c.name.clone())
300 .collect::<HashSet<_>>();
301 for all_alias in partition_cols.values() {
302 let all_alias = all_alias
303 .iter()
304 .map(|c| c.name.clone())
305 .collect::<HashSet<_>>();
306 if ref_cols.intersection(&all_alias).count() == 0 {
309 return false;
310 }
311 }
312
313 true
314 }
315}
316
317pub type Transformer = Arc<dyn Fn(&LogicalPlan) -> Option<LogicalPlan>>;
318
319pub type StageTransformer = Arc<dyn Fn(&LogicalPlan) -> DfResult<TransformerAction>>;
321
322pub struct TransformerAction {
324 pub extra_parent_plans: Vec<LogicalPlan>,
333 pub new_child_plan: Option<LogicalPlan>,
335}
336
337pub fn partial_commutative_transformer(plan: &LogicalPlan) -> Option<LogicalPlan> {
338 Some(plan.clone())
339}
340
341fn has_subquery(plan: &LogicalPlan) -> DfResult<bool> {
342 let mut found = false;
343 plan.apply_expressions(|e| {
344 e.apply(|x| {
345 if matches!(
346 x,
347 Expr::Exists(_) | Expr::InSubquery(_) | Expr::ScalarSubquery(_)
348 ) {
349 found = true;
350 Ok(TreeNodeRecursion::Stop)
351 } else {
352 Ok(TreeNodeRecursion::Continue)
353 }
354 })
355 })?;
356 Ok(found)
357}
358
#[cfg(test)]
mod test {
    use datafusion_expr::{LogicalPlanBuilder, Sort};

    use super::*;

    /// A `Sort` checked against an empty partition-column mapping is `Commutative`.
    #[test]
    fn sort_on_empty_partition() {
        let empty_input = LogicalPlanBuilder::empty(false).build().unwrap();
        let sort_plan = LogicalPlan::Sort(Sort {
            expr: vec![],
            input: Arc::new(empty_input),
            fetch: None,
        });

        let verdict = Categorizer::check_plan(&sort_plan, Some(Default::default())).unwrap();
        assert!(matches!(verdict, Commutativity::Commutative));
    }
}
377}