query/dist_plan/
commutativity.rs1use std::collections::HashSet;
16use std::sync::Arc;
17
18use common_function::aggrs::aggr_wrapper::{aggr_state_func_name, StateMergeHelper};
19use common_function::function_registry::FUNCTION_REGISTRY;
20use common_telemetry::debug;
21use datafusion_expr::{Expr, LogicalPlan, UserDefinedLogicalNode};
22use promql::extension_plan::{
23 EmptyMetric, InstantManipulate, RangeManipulate, SeriesDivide, SeriesNormalize,
24};
25
26use crate::dist_plan::analyzer::AliasMapping;
27use crate::dist_plan::merge_sort::{merge_sort_transformer, MergeSortLogicalPlan};
28use crate::dist_plan::MergeScanLogicalPlan;
29
/// Result of splitting an aggregate plan into two steps, as built by
/// `step_aggr_to_upper_aggr`.
pub struct StepTransformAction {
    /// Extra plans for the parent side; `step_aggr_to_upper_aggr` stores the upper "merge"
    /// aggregate here as the only element.
    extra_parent_plans: Vec<LogicalPlan>,
    /// Replacement child plan; `step_aggr_to_upper_aggr` stores the lower "state" aggregate
    /// here.
    new_child_plan: Option<LogicalPlan>,
}
34
35pub fn step_aggr_to_upper_aggr(
45 aggr_plan: &LogicalPlan,
46) -> datafusion_common::Result<StepTransformAction> {
47 let LogicalPlan::Aggregate(input_aggr) = aggr_plan else {
48 return Err(datafusion_common::DataFusionError::Plan(
49 "step_aggr_to_upper_aggr only accepts Aggregate plan".to_string(),
50 ));
51 };
52 if !is_all_aggr_exprs_steppable(&input_aggr.aggr_expr) {
53 return Err(datafusion_common::DataFusionError::NotImplemented(format!(
54 "Some aggregate expressions are not steppable in [{}]",
55 input_aggr
56 .aggr_expr
57 .iter()
58 .map(|e| e.to_string())
59 .collect::<Vec<_>>()
60 .join(", ")
61 )));
62 }
63
64 let step_aggr_plan = StateMergeHelper::split_aggr_node(input_aggr.clone())?;
65
66 let ret = StepTransformAction {
68 extra_parent_plans: vec![step_aggr_plan.upper_merge.clone()],
69 new_child_plan: Some(step_aggr_plan.lower_state.clone()),
70 };
71 Ok(ret)
72}
73
74pub fn is_all_aggr_exprs_steppable(aggr_exprs: &[Expr]) -> bool {
79 aggr_exprs.iter().all(|expr| {
80 if let Some(aggr_func) = get_aggr_func(expr) {
81 if aggr_func.params.distinct {
82 return false;
84 }
85
86 FUNCTION_REGISTRY.is_aggr_func_exist(&aggr_state_func_name(aggr_func.func.name()))
88 } else {
89 false
90 }
91 })
92}
93
94pub fn get_aggr_func(expr: &Expr) -> Option<&datafusion_expr::expr::AggregateFunction> {
95 let mut expr_ref = expr;
96 while let Expr::Alias(alias) = expr_ref {
97 expr_ref = &alias.expr;
98 }
99 if let Expr::AggregateFunction(aggr_func) = expr_ref {
100 Some(aggr_func)
101 } else {
102 None
103 }
104}
105
/// Classification assigned to a plan node or expression by [`Categorizer`].
///
/// NOTE(review): the variant notes below describe where this file produces each value; how
/// the consumer of this enum reacts to each variant is not visible here.
#[allow(dead_code)]
pub enum Commutativity {
    /// Returned e.g. for table scans, and for projections/filters whose expressions are all
    /// classified as commutative.
    Commutative,
    /// Returned for a `Limit` that has a `fetch` but no `skip` when partition columns exist.
    PartialCommutative,
    /// Carries an optional [`Transformer`]: `None` for aggregates whose group-by covers the
    /// partition columns, `merge_sort_transformer` for sorts over partitioned input.
    ConditionalCommutative(Option<Transformer>),
    /// Returned for aggregates that do not cover the partition columns but can be split into
    /// steps.
    TransformedCommutative {
        /// Callback producing the [`TransformerAction`] for the plan it is given.
        transformer: Option<StageTransformer>,
    },
    /// Returned e.g. for joins, empty relations, and non-steppable aggregates that do not
    /// cover the partition columns.
    NonCommutative,
    /// Returned for plan nodes and expressions that have no dedicated handling yet.
    Unimplemented,
    /// Returned for statement-like plans (DDL, DML, `EXPLAIN`, ...) and unknown extension
    /// nodes.
    Unsupported,
}
120
/// Stateless namespace for the functions that classify plans and expressions into a
/// [`Commutativity`] category.
pub struct Categorizer {}
122
123impl Categorizer {
124 pub fn check_plan(plan: &LogicalPlan, partition_cols: Option<AliasMapping>) -> Commutativity {
125 let partition_cols = partition_cols.unwrap_or_default();
126
127 match plan {
128 LogicalPlan::Projection(proj) => {
129 for expr in &proj.expr {
130 let commutativity = Self::check_expr(expr);
131 if !matches!(commutativity, Commutativity::Commutative) {
132 return commutativity;
133 }
134 }
135 Commutativity::Commutative
136 }
137 LogicalPlan::Filter(filter) => Self::check_expr(&filter.predicate),
139 LogicalPlan::Window(_) => Commutativity::Unimplemented,
140 LogicalPlan::Aggregate(aggr) => {
141 let is_all_steppable = is_all_aggr_exprs_steppable(&aggr.aggr_expr);
142 let matches_partition = Self::check_partition(&aggr.group_expr, &partition_cols);
143 if !matches_partition && is_all_steppable {
144 debug!("Plan is steppable: {plan}");
145 return Commutativity::TransformedCommutative {
146 transformer: Some(Arc::new(|plan: &LogicalPlan| {
147 debug!("Before Step optimize: {plan}");
148 let ret = step_aggr_to_upper_aggr(plan);
149 ret.ok().map(|s| TransformerAction {
150 extra_parent_plans: s.extra_parent_plans,
151 new_child_plan: s.new_child_plan,
152 })
153 })),
154 };
155 }
156 if !matches_partition {
157 return Commutativity::NonCommutative;
158 }
159 for expr in &aggr.aggr_expr {
160 let commutativity = Self::check_expr(expr);
161 if !matches!(commutativity, Commutativity::Commutative) {
162 return commutativity;
163 }
164 }
165 Commutativity::ConditionalCommutative(None)
170 }
171 LogicalPlan::Sort(_) => {
172 if partition_cols.is_empty() {
173 return Commutativity::Commutative;
174 }
175
176 Commutativity::ConditionalCommutative(Some(Arc::new(merge_sort_transformer)))
180 }
181 LogicalPlan::Join(_) => Commutativity::NonCommutative,
182 LogicalPlan::Repartition(_) => {
183 Commutativity::Unimplemented
185 }
186 LogicalPlan::Union(_) => Commutativity::Unimplemented,
187 LogicalPlan::TableScan(_) => Commutativity::Commutative,
188 LogicalPlan::EmptyRelation(_) => Commutativity::NonCommutative,
189 LogicalPlan::Subquery(_) => Commutativity::Unimplemented,
190 LogicalPlan::SubqueryAlias(_) => Commutativity::Unimplemented,
191 LogicalPlan::Limit(limit) => {
192 if partition_cols.is_empty() && limit.fetch.is_some() {
195 Commutativity::Commutative
196 } else if limit.skip.is_none() && limit.fetch.is_some() {
197 Commutativity::PartialCommutative
198 } else {
199 Commutativity::Unimplemented
200 }
201 }
202 LogicalPlan::Extension(extension) => {
203 Self::check_extension_plan(extension.node.as_ref() as _, &partition_cols)
204 }
205 LogicalPlan::Distinct(_) => {
206 if partition_cols.is_empty() {
207 Commutativity::Commutative
208 } else {
209 Commutativity::Unimplemented
210 }
211 }
212 LogicalPlan::Unnest(_) => Commutativity::Commutative,
213 LogicalPlan::Statement(_) => Commutativity::Unsupported,
214 LogicalPlan::Values(_) => Commutativity::Unsupported,
215 LogicalPlan::Explain(_) => Commutativity::Unsupported,
216 LogicalPlan::Analyze(_) => Commutativity::Unsupported,
217 LogicalPlan::DescribeTable(_) => Commutativity::Unsupported,
218 LogicalPlan::Dml(_) => Commutativity::Unsupported,
219 LogicalPlan::Ddl(_) => Commutativity::Unsupported,
220 LogicalPlan::Copy(_) => Commutativity::Unsupported,
221 LogicalPlan::RecursiveQuery(_) => Commutativity::Unsupported,
222 }
223 }
224
225 pub fn check_extension_plan(
226 plan: &dyn UserDefinedLogicalNode,
227 partition_cols: &AliasMapping,
228 ) -> Commutativity {
229 match plan.name() {
230 name if name == SeriesDivide::name() => {
231 let series_divide = plan.as_any().downcast_ref::<SeriesDivide>().unwrap();
232 let tags = series_divide.tags().iter().collect::<HashSet<_>>();
233
234 for all_alias in partition_cols.values() {
235 let all_alias = all_alias.iter().map(|c| &c.name).collect::<HashSet<_>>();
236 if tags.intersection(&all_alias).count() == 0 {
237 return Commutativity::NonCommutative;
238 }
239 }
240
241 Commutativity::Commutative
242 }
243 name if name == SeriesNormalize::name()
244 || name == InstantManipulate::name()
245 || name == RangeManipulate::name() =>
246 {
247 Commutativity::Commutative
250 }
251 name if name == EmptyMetric::name()
252 || name == MergeScanLogicalPlan::name()
253 || name == MergeSortLogicalPlan::name() =>
254 {
255 Commutativity::Unimplemented
256 }
257 _ => Commutativity::Unsupported,
258 }
259 }
260
261 pub fn check_expr(expr: &Expr) -> Commutativity {
262 #[allow(deprecated)]
263 match expr {
264 Expr::Column(_)
265 | Expr::ScalarVariable(_, _)
266 | Expr::Literal(_, _)
267 | Expr::BinaryExpr(_)
268 | Expr::Not(_)
269 | Expr::IsNotNull(_)
270 | Expr::IsNull(_)
271 | Expr::IsTrue(_)
272 | Expr::IsFalse(_)
273 | Expr::IsNotTrue(_)
274 | Expr::IsNotFalse(_)
275 | Expr::Negative(_)
276 | Expr::Between(_)
277 | Expr::Exists(_)
278 | Expr::InList(_)
279 | Expr::Case(_) => Commutativity::Commutative,
280 Expr::ScalarFunction(_udf) => Commutativity::Commutative,
281 Expr::AggregateFunction(_udaf) => Commutativity::Commutative,
282
283 Expr::Like(_)
284 | Expr::SimilarTo(_)
285 | Expr::IsUnknown(_)
286 | Expr::IsNotUnknown(_)
287 | Expr::Cast(_)
288 | Expr::TryCast(_)
289 | Expr::WindowFunction(_)
290 | Expr::InSubquery(_)
291 | Expr::ScalarSubquery(_)
292 | Expr::Wildcard { .. } => Commutativity::Unimplemented,
293
294 Expr::Alias(alias) => Self::check_expr(&alias.expr),
295
296 Expr::Unnest(_)
297 | Expr::GroupingSet(_)
298 | Expr::Placeholder(_)
299 | Expr::OuterReferenceColumn(_, _) => Commutativity::Unimplemented,
300 }
301 }
302
303 fn check_partition(exprs: &[Expr], partition_cols: &AliasMapping) -> bool {
306 let mut ref_cols = HashSet::new();
307 for expr in exprs {
308 expr.add_column_refs(&mut ref_cols);
309 }
310 let ref_cols = ref_cols
311 .into_iter()
312 .map(|c| c.name.clone())
313 .collect::<HashSet<_>>();
314 for all_alias in partition_cols.values() {
315 let all_alias = all_alias
316 .iter()
317 .map(|c| c.name.clone())
318 .collect::<HashSet<_>>();
319 if ref_cols.intersection(&all_alias).count() == 0 {
322 return false;
323 }
324 }
325
326 true
327 }
328}
329
/// Shared callback that maps a logical plan to an optional replacement plan; carried by
/// `Commutativity::ConditionalCommutative`.
pub type Transformer = Arc<dyn Fn(&LogicalPlan) -> Option<LogicalPlan>>;
331
/// Shared callback that maps a logical plan to an optional [`TransformerAction`]; carried by
/// `Commutativity::TransformedCommutative`.
pub type StageTransformer = Arc<dyn Fn(&LogicalPlan) -> Option<TransformerAction>>;
334
/// Action produced by a [`StageTransformer`].
///
/// In this file it is only built from a `StepTransformAction` when an aggregate is split
/// into steps.
pub struct TransformerAction {
    /// Extra plans for the parent side of the rewritten plan.
    /// NOTE(review): presumably listed in parent-to-child order — confirm against the consumer.
    pub extra_parent_plans: Vec<LogicalPlan>,
    /// Replacement child plan, if any.
    /// NOTE(review): presumably `None` means "keep the original plan" — confirm against the
    /// consumer.
    pub new_child_plan: Option<LogicalPlan>,
}
349
350pub fn partial_commutative_transformer(plan: &LogicalPlan) -> Option<LogicalPlan> {
351 Some(plan.clone())
352}
353
#[cfg(test)]
mod test {
    use datafusion_expr::{LogicalPlanBuilder, Sort};

    use super::*;

    /// A sort checked against an empty partition-column mapping is classified as commutative.
    #[test]
    fn sort_on_empty_partition() {
        let empty_input = LogicalPlanBuilder::empty(false).build().unwrap();
        let sort_plan = LogicalPlan::Sort(Sort {
            expr: Vec::new(),
            input: Arc::new(empty_input),
            fetch: None,
        });

        let commutativity = Categorizer::check_plan(&sort_plan, Some(Default::default()));

        assert!(matches!(commutativity, Commutativity::Commutative));
    }
}