query/dist_plan/
commutativity.rs1use std::collections::HashSet;
16use std::sync::Arc;
17
18use common_function::aggrs::aggr_wrapper::{StateMergeHelper, is_all_aggr_exprs_steppable};
19use common_telemetry::debug;
20use datafusion::error::Result as DfResult;
21use datafusion_expr::{Expr, LogicalPlan, UserDefinedLogicalNode};
22use promql::extension_plan::{
23 EmptyMetric, InstantManipulate, RangeManipulate, SeriesDivide, SeriesNormalize,
24};
25
26use crate::dist_plan::MergeScanLogicalPlan;
27use crate::dist_plan::analyzer::AliasMapping;
28use crate::dist_plan::merge_sort::{MergeSortLogicalPlan, merge_sort_transformer};
29
30pub struct StepTransformAction {
31 extra_parent_plans: Vec<LogicalPlan>,
32 new_child_plan: Option<LogicalPlan>,
33}
34
35pub fn step_aggr_to_upper_aggr(
45 aggr_plan: &LogicalPlan,
46) -> datafusion_common::Result<StepTransformAction> {
47 let LogicalPlan::Aggregate(input_aggr) = aggr_plan else {
48 return Err(datafusion_common::DataFusionError::Plan(
49 "step_aggr_to_upper_aggr only accepts Aggregate plan".to_string(),
50 ));
51 };
52 if !is_all_aggr_exprs_steppable(&input_aggr.aggr_expr) {
53 return Err(datafusion_common::DataFusionError::NotImplemented(format!(
54 "Some aggregate expressions are not steppable in [{}]",
55 input_aggr
56 .aggr_expr
57 .iter()
58 .map(|e| e.to_string())
59 .collect::<Vec<_>>()
60 .join(", ")
61 )));
62 }
63
64 let step_aggr_plan = StateMergeHelper::split_aggr_node(input_aggr.clone())?;
65
66 let ret = StepTransformAction {
68 extra_parent_plans: vec![step_aggr_plan.upper_merge.clone()],
69 new_child_plan: Some(step_aggr_plan.lower_state.clone()),
70 };
71 Ok(ret)
72}
73
74#[allow(dead_code)]
75pub enum Commutativity {
76 Commutative,
77 PartialCommutative,
78 ConditionalCommutative(Option<Transformer>),
79 TransformedCommutative {
80 transformer: Option<StageTransformer>,
82 },
83 NonCommutative,
84 Unimplemented,
85 Unsupported,
87}
88
89pub struct Categorizer {}
90
91impl Categorizer {
92 pub fn check_plan(
93 plan: &LogicalPlan,
94 partition_cols: Option<AliasMapping>,
95 ) -> DfResult<Commutativity> {
96 let partition_cols = partition_cols.unwrap_or_default();
97
98 let comm = match plan {
99 LogicalPlan::Projection(proj) => {
100 for expr in &proj.expr {
101 let commutativity = Self::check_expr(expr);
102 if !matches!(commutativity, Commutativity::Commutative) {
103 return Ok(commutativity);
104 }
105 }
106 Commutativity::Commutative
107 }
108 LogicalPlan::Filter(filter) => Self::check_expr(&filter.predicate),
110 LogicalPlan::Window(_) => Commutativity::Unimplemented,
111 LogicalPlan::Aggregate(aggr) => {
112 let is_all_steppable = is_all_aggr_exprs_steppable(&aggr.aggr_expr);
113 let matches_partition = Self::check_partition(&aggr.group_expr, &partition_cols);
114 if !matches_partition && is_all_steppable {
115 debug!("Plan is steppable: {plan}");
116 return Ok(Commutativity::TransformedCommutative {
117 transformer: Some(Arc::new(|plan: &LogicalPlan| {
118 debug!("Before Step optimize: {plan}");
119 let ret = step_aggr_to_upper_aggr(plan);
120 ret.inspect_err(|err| {
121 common_telemetry::error!("Failed to step aggregate plan: {err:?}");
122 })
123 .map(|s| TransformerAction {
124 extra_parent_plans: s.extra_parent_plans,
125 new_child_plan: s.new_child_plan,
126 })
127 })),
128 });
129 }
130 if !matches_partition {
131 return Ok(Commutativity::NonCommutative);
132 }
133 for expr in &aggr.aggr_expr {
134 let commutativity = Self::check_expr(expr);
135 if !matches!(commutativity, Commutativity::Commutative) {
136 return Ok(commutativity);
137 }
138 }
139 Commutativity::ConditionalCommutative(None)
144 }
145 LogicalPlan::Sort(_) => {
146 if partition_cols.is_empty() {
147 return Ok(Commutativity::Commutative);
148 }
149
150 Commutativity::ConditionalCommutative(Some(Arc::new(merge_sort_transformer)))
154 }
155 LogicalPlan::Join(_) => Commutativity::NonCommutative,
156 LogicalPlan::Repartition(_) => {
157 Commutativity::Unimplemented
159 }
160 LogicalPlan::Union(_) => Commutativity::Unimplemented,
161 LogicalPlan::TableScan(_) => Commutativity::Commutative,
162 LogicalPlan::EmptyRelation(_) => Commutativity::NonCommutative,
163 LogicalPlan::Subquery(_) => Commutativity::Unimplemented,
164 LogicalPlan::SubqueryAlias(_) => Commutativity::Commutative,
165 LogicalPlan::Limit(limit) => {
166 if partition_cols.is_empty() && limit.fetch.is_some() {
169 Commutativity::Commutative
170 } else if limit.skip.is_none() && limit.fetch.is_some() {
171 Commutativity::PartialCommutative
172 } else {
173 Commutativity::Unimplemented
174 }
175 }
176 LogicalPlan::Extension(extension) => {
177 Self::check_extension_plan(extension.node.as_ref() as _, &partition_cols)
178 }
179 LogicalPlan::Distinct(_) => {
180 if partition_cols.is_empty() {
181 Commutativity::Commutative
182 } else {
183 Commutativity::Unimplemented
184 }
185 }
186 LogicalPlan::Unnest(_) => Commutativity::Commutative,
187 LogicalPlan::Statement(_) => Commutativity::Unsupported,
188 LogicalPlan::Values(_) => Commutativity::Unsupported,
189 LogicalPlan::Explain(_) => Commutativity::Unsupported,
190 LogicalPlan::Analyze(_) => Commutativity::Unsupported,
191 LogicalPlan::DescribeTable(_) => Commutativity::Unsupported,
192 LogicalPlan::Dml(_) => Commutativity::Unsupported,
193 LogicalPlan::Ddl(_) => Commutativity::Unsupported,
194 LogicalPlan::Copy(_) => Commutativity::Unsupported,
195 LogicalPlan::RecursiveQuery(_) => Commutativity::Unsupported,
196 };
197
198 Ok(comm)
199 }
200
201 pub fn check_extension_plan(
202 plan: &dyn UserDefinedLogicalNode,
203 partition_cols: &AliasMapping,
204 ) -> Commutativity {
205 match plan.name() {
206 name if name == SeriesDivide::name() => {
207 let series_divide = plan.as_any().downcast_ref::<SeriesDivide>().unwrap();
208 let tags = series_divide.tags().iter().collect::<HashSet<_>>();
209
210 for all_alias in partition_cols.values() {
211 let all_alias = all_alias.iter().map(|c| &c.name).collect::<HashSet<_>>();
212 if tags.intersection(&all_alias).count() == 0 {
213 return Commutativity::NonCommutative;
214 }
215 }
216
217 Commutativity::Commutative
218 }
219 name if name == SeriesNormalize::name()
220 || name == InstantManipulate::name()
221 || name == RangeManipulate::name() =>
222 {
223 Commutativity::Commutative
226 }
227 name if name == EmptyMetric::name()
228 || name == MergeScanLogicalPlan::name()
229 || name == MergeSortLogicalPlan::name() =>
230 {
231 Commutativity::Unimplemented
232 }
233 _ => Commutativity::Unsupported,
234 }
235 }
236
237 pub fn check_expr(expr: &Expr) -> Commutativity {
238 #[allow(deprecated)]
239 match expr {
240 Expr::Column(_)
241 | Expr::ScalarVariable(_, _)
242 | Expr::Literal(_, _)
243 | Expr::BinaryExpr(_)
244 | Expr::Not(_)
245 | Expr::IsNotNull(_)
246 | Expr::IsNull(_)
247 | Expr::IsTrue(_)
248 | Expr::IsFalse(_)
249 | Expr::IsNotTrue(_)
250 | Expr::IsNotFalse(_)
251 | Expr::Negative(_)
252 | Expr::Between(_)
253 | Expr::Exists(_)
254 | Expr::InList(_)
255 | Expr::Case(_) => Commutativity::Commutative,
256 Expr::ScalarFunction(_udf) => Commutativity::Commutative,
257 Expr::AggregateFunction(_udaf) => Commutativity::Commutative,
258
259 Expr::Like(_)
260 | Expr::SimilarTo(_)
261 | Expr::IsUnknown(_)
262 | Expr::IsNotUnknown(_)
263 | Expr::Cast(_)
264 | Expr::TryCast(_)
265 | Expr::WindowFunction(_)
266 | Expr::InSubquery(_)
267 | Expr::ScalarSubquery(_)
268 | Expr::Wildcard { .. } => Commutativity::Unimplemented,
269
270 Expr::Alias(alias) => Self::check_expr(&alias.expr),
271
272 Expr::Unnest(_)
273 | Expr::GroupingSet(_)
274 | Expr::Placeholder(_)
275 | Expr::OuterReferenceColumn(_, _) => Commutativity::Unimplemented,
276 }
277 }
278
279 fn check_partition(exprs: &[Expr], partition_cols: &AliasMapping) -> bool {
286 let mut ref_cols = HashSet::new();
287 for expr in exprs {
288 expr.add_column_refs(&mut ref_cols);
289 }
290 let ref_cols = ref_cols
291 .into_iter()
292 .map(|c| c.name.clone())
293 .collect::<HashSet<_>>();
294 for all_alias in partition_cols.values() {
295 let all_alias = all_alias
296 .iter()
297 .map(|c| c.name.clone())
298 .collect::<HashSet<_>>();
299 if ref_cols.intersection(&all_alias).count() == 0 {
302 return false;
303 }
304 }
305
306 true
307 }
308}
309
310pub type Transformer = Arc<dyn Fn(&LogicalPlan) -> Option<LogicalPlan>>;
311
312pub type StageTransformer = Arc<dyn Fn(&LogicalPlan) -> DfResult<TransformerAction>>;
314
315pub struct TransformerAction {
317 pub extra_parent_plans: Vec<LogicalPlan>,
326 pub new_child_plan: Option<LogicalPlan>,
328}
329
330pub fn partial_commutative_transformer(plan: &LogicalPlan) -> Option<LogicalPlan> {
331 Some(plan.clone())
332}
333
#[cfg(test)]
mod test {
    use datafusion_expr::{LogicalPlanBuilder, Sort};

    use super::*;

    /// With an empty partition-column mapping, a `Sort` node is reported as fully
    /// commutative, so no merge-sort transformer is attached.
    #[test]
    fn sort_on_empty_partition() {
        let empty_input = LogicalPlanBuilder::empty(false).build().unwrap();
        let sort = LogicalPlan::Sort(Sort {
            expr: vec![],
            input: Arc::new(empty_input),
            fetch: None,
        });

        let commutativity = Categorizer::check_plan(&sort, Some(Default::default())).unwrap();
        assert!(matches!(commutativity, Commutativity::Commutative));
    }
}