1use std::collections::HashSet;
16use std::sync::Arc;
17
18use common_function::aggrs::approximate::hll::{HllState, HLL_MERGE_NAME, HLL_NAME};
19use common_function::aggrs::approximate::uddsketch::{
20 UddSketchState, UDDSKETCH_MERGE_NAME, UDDSKETCH_STATE_NAME,
21};
22use common_telemetry::debug;
23use datafusion::functions_aggregate::sum::sum_udaf;
24use datafusion_common::Column;
25use datafusion_expr::{Expr, LogicalPlan, Projection, UserDefinedLogicalNode};
26use promql::extension_plan::{
27 EmptyMetric, InstantManipulate, RangeManipulate, SeriesDivide, SeriesNormalize,
28};
29
30use crate::dist_plan::analyzer::AliasMapping;
31use crate::dist_plan::merge_sort::{merge_sort_transformer, MergeSortLogicalPlan};
32use crate::dist_plan::MergeScanLogicalPlan;
33
34pub fn step_aggr_to_upper_aggr(
44 aggr_plan: &LogicalPlan,
45) -> datafusion_common::Result<[LogicalPlan; 2]> {
46 let LogicalPlan::Aggregate(input_aggr) = aggr_plan else {
47 return Err(datafusion_common::DataFusionError::Plan(
48 "step_aggr_to_upper_aggr only accepts Aggregate plan".to_string(),
49 ));
50 };
51 if !is_all_aggr_exprs_steppable(&input_aggr.aggr_expr) {
52 return Err(datafusion_common::DataFusionError::NotImplemented(
53 "Some aggregate expressions are not steppable".to_string(),
54 ));
55 }
56 let mut upper_aggr_expr = vec![];
57 for aggr_expr in &input_aggr.aggr_expr {
58 let Some(aggr_func) = get_aggr_func(aggr_expr) else {
59 return Err(datafusion_common::DataFusionError::NotImplemented(
60 "Aggregate function not found".to_string(),
61 ));
62 };
63 let col_name = aggr_expr.qualified_name();
64 let input_column = Expr::Column(datafusion_common::Column::new(col_name.0, col_name.1));
65 let upper_func = match aggr_func.func.name() {
66 "sum" | "min" | "max" | "last_value" | "first_value" => {
67 let mut new_aggr_func = aggr_func.clone();
69 new_aggr_func.args = vec![input_column.clone()];
70 new_aggr_func
71 }
72 "count" => {
73 let mut new_aggr_func = aggr_func.clone();
75 new_aggr_func.func = sum_udaf();
76 new_aggr_func.args = vec![input_column.clone()];
77 new_aggr_func
78 }
79 UDDSKETCH_STATE_NAME | UDDSKETCH_MERGE_NAME => {
80 let mut new_aggr_func = aggr_func.clone();
82 new_aggr_func.func = Arc::new(UddSketchState::merge_udf_impl());
83 new_aggr_func.args[2] = input_column.clone();
84 new_aggr_func
85 }
86 HLL_NAME | HLL_MERGE_NAME => {
87 let mut new_aggr_func = aggr_func.clone();
89 new_aggr_func.func = Arc::new(HllState::merge_udf_impl());
90 new_aggr_func.args = vec![input_column.clone()];
91 new_aggr_func
92 }
93 _ => {
94 return Err(datafusion_common::DataFusionError::NotImplemented(format!(
95 "Aggregate function {} is not supported for Step aggregation",
96 aggr_func.func.name()
97 )))
98 }
99 };
100
101 let mut new_aggr_expr = aggr_expr.clone();
103 {
104 let new_aggr_func = get_aggr_func_mut(&mut new_aggr_expr).unwrap();
105 *new_aggr_func = upper_func;
106 }
107
108 upper_aggr_expr.push(new_aggr_expr);
109 }
110 let mut new_aggr = input_aggr.clone();
111 new_aggr.input = Arc::new(LogicalPlan::Aggregate(input_aggr.clone()));
113
114 new_aggr.aggr_expr = upper_aggr_expr;
115
116 let mut new_group_expr = new_aggr.group_expr.clone();
118 for expr in &mut new_group_expr {
119 if let Expr::Column(_) = expr {
120 continue;
122 }
123 let col_name = expr.qualified_name();
124 let input_column = Expr::Column(datafusion_common::Column::new(col_name.0, col_name.1));
125 *expr = input_column;
126 }
127 new_aggr.group_expr = new_group_expr.clone();
128
129 let mut new_projection_exprs = new_group_expr;
130 for (lower_aggr_expr, upper_aggr_expr) in
133 input_aggr.aggr_expr.iter().zip(new_aggr.aggr_expr.iter())
134 {
135 let lower_col_name = lower_aggr_expr.qualified_name();
136 let (table, col_name) = upper_aggr_expr.qualified_name();
137 let aggr_out_column = Column::new(table, col_name);
138 let aliased_output_aggr_expr =
139 Expr::Column(aggr_out_column).alias_qualified(lower_col_name.0, lower_col_name.1);
140 new_projection_exprs.push(aliased_output_aggr_expr);
141 }
142 let upper_aggr_plan = LogicalPlan::Aggregate(new_aggr);
143 let upper_aggr_plan = upper_aggr_plan.recompute_schema()?;
144 let new_projection =
146 Projection::try_new(new_projection_exprs, Arc::new(upper_aggr_plan.clone()))?;
147 let projection = LogicalPlan::Projection(new_projection);
148 Ok([projection, upper_aggr_plan])
150}
151
152pub fn is_all_aggr_exprs_steppable(aggr_exprs: &[Expr]) -> bool {
157 let step_action = HashSet::from([
158 "sum",
159 "count",
160 "min",
161 "max",
162 "first_value",
163 "last_value",
164 UDDSKETCH_STATE_NAME,
165 UDDSKETCH_MERGE_NAME,
166 HLL_NAME,
167 HLL_MERGE_NAME,
168 ]);
169 aggr_exprs.iter().all(|expr| {
170 if let Some(aggr_func) = get_aggr_func(expr) {
171 if aggr_func.distinct {
172 return false;
174 }
175 step_action.contains(aggr_func.func.name())
176 } else {
177 false
178 }
179 })
180}
181
182pub fn get_aggr_func(expr: &Expr) -> Option<&datafusion_expr::expr::AggregateFunction> {
183 let mut expr_ref = expr;
184 while let Expr::Alias(alias) = expr_ref {
185 expr_ref = &alias.expr;
186 }
187 if let Expr::AggregateFunction(aggr_func) = expr_ref {
188 Some(aggr_func)
189 } else {
190 None
191 }
192}
193
194pub fn get_aggr_func_mut(expr: &mut Expr) -> Option<&mut datafusion_expr::expr::AggregateFunction> {
195 let mut expr_ref = expr;
196 while let Expr::Alias(alias) = expr_ref {
197 expr_ref = &mut alias.expr;
198 }
199 if let Expr::AggregateFunction(aggr_func) = expr_ref {
200 Some(aggr_func)
201 } else {
202 None
203 }
204}
205
/// How a logical plan node (or expression) relates to partitioned execution, as
/// reported by `Categorizer`.
///
/// NOTE(review): the consumer that acts on these variants is outside this file. The
/// per-variant notes describe where this file produces each one; any stronger
/// semantics should be confirmed with that consumer.
#[allow(dead_code)]
pub enum Commutativity {
    /// Reported for nodes needing no extra handling, e.g. `TableScan`, or a
    /// `Projection` whose expressions are all commutative.
    Commutative,
    /// Reported for a `Limit` that has a fetch and no skip while partition columns are
    /// present.
    PartialCommutative,
    /// Commutative subject to an optional plan `Transformer`. A `Sort` with partition
    /// columns carries `merge_sort_transformer`; a partition-matching `Aggregate`
    /// carries `None`.
    ConditionalCommutative(Option<Transformer>),
    /// Commutative after a stage transformation that returns a `TransformerAction`.
    /// Reported for steppable `Aggregate`s whose group-by does not cover the
    /// partition columns.
    TransformedCommutative {
        transformer: Option<StageTransformer>,
    },
    /// Reported e.g. for `Join`, `EmptyRelation`, and non-steppable aggregates that do
    /// not match the partitioning (presumably: must not be pushed down).
    NonCommutative,
    /// Plans and expressions this analysis does not handle yet.
    Unimplemented,
    /// Plans outside the scope of this analysis (statements, DML/DDL, `Explain`, ...).
    Unsupported,
}
220
/// Stateless entry point (no fields, only associated functions) for classifying
/// logical plan nodes and expressions into `Commutativity` categories.
pub struct Categorizer {}
222
223impl Categorizer {
224 pub fn check_plan(plan: &LogicalPlan, partition_cols: Option<AliasMapping>) -> Commutativity {
225 let partition_cols = partition_cols.unwrap_or_default();
226
227 match plan {
228 LogicalPlan::Projection(proj) => {
229 for expr in &proj.expr {
230 let commutativity = Self::check_expr(expr);
231 if !matches!(commutativity, Commutativity::Commutative) {
232 return commutativity;
233 }
234 }
235 Commutativity::Commutative
236 }
237 LogicalPlan::Filter(filter) => Self::check_expr(&filter.predicate),
239 LogicalPlan::Window(_) => Commutativity::Unimplemented,
240 LogicalPlan::Aggregate(aggr) => {
241 let is_all_steppable = is_all_aggr_exprs_steppable(&aggr.aggr_expr);
242 let matches_partition = Self::check_partition(&aggr.group_expr, &partition_cols);
243 if !matches_partition && is_all_steppable {
244 debug!("Plan is steppable: {plan}");
245 return Commutativity::TransformedCommutative {
246 transformer: Some(Arc::new(|plan: &LogicalPlan| {
247 debug!("Before Step optimize: {plan}");
248 let ret = step_aggr_to_upper_aggr(plan);
249 ret.ok().map(|s| TransformerAction {
250 extra_parent_plans: s.to_vec(),
251 new_child_plan: None,
252 })
253 })),
254 };
255 }
256 if !matches_partition {
257 return Commutativity::NonCommutative;
258 }
259 for expr in &aggr.aggr_expr {
260 let commutativity = Self::check_expr(expr);
261 if !matches!(commutativity, Commutativity::Commutative) {
262 return commutativity;
263 }
264 }
265 Commutativity::ConditionalCommutative(None)
270 }
271 LogicalPlan::Sort(_) => {
272 if partition_cols.is_empty() {
273 return Commutativity::Commutative;
274 }
275
276 Commutativity::ConditionalCommutative(Some(Arc::new(merge_sort_transformer)))
280 }
281 LogicalPlan::Join(_) => Commutativity::NonCommutative,
282 LogicalPlan::Repartition(_) => {
283 Commutativity::Unimplemented
285 }
286 LogicalPlan::Union(_) => Commutativity::Unimplemented,
287 LogicalPlan::TableScan(_) => Commutativity::Commutative,
288 LogicalPlan::EmptyRelation(_) => Commutativity::NonCommutative,
289 LogicalPlan::Subquery(_) => Commutativity::Unimplemented,
290 LogicalPlan::SubqueryAlias(_) => Commutativity::Unimplemented,
291 LogicalPlan::Limit(limit) => {
292 if partition_cols.is_empty() && limit.fetch.is_some() {
295 Commutativity::Commutative
296 } else if limit.skip.is_none() && limit.fetch.is_some() {
297 Commutativity::PartialCommutative
298 } else {
299 Commutativity::Unimplemented
300 }
301 }
302 LogicalPlan::Extension(extension) => {
303 Self::check_extension_plan(extension.node.as_ref() as _, &partition_cols)
304 }
305 LogicalPlan::Distinct(_) => {
306 if partition_cols.is_empty() {
307 Commutativity::Commutative
308 } else {
309 Commutativity::Unimplemented
310 }
311 }
312 LogicalPlan::Unnest(_) => Commutativity::Commutative,
313 LogicalPlan::Statement(_) => Commutativity::Unsupported,
314 LogicalPlan::Values(_) => Commutativity::Unsupported,
315 LogicalPlan::Explain(_) => Commutativity::Unsupported,
316 LogicalPlan::Analyze(_) => Commutativity::Unsupported,
317 LogicalPlan::DescribeTable(_) => Commutativity::Unsupported,
318 LogicalPlan::Dml(_) => Commutativity::Unsupported,
319 LogicalPlan::Ddl(_) => Commutativity::Unsupported,
320 LogicalPlan::Copy(_) => Commutativity::Unsupported,
321 LogicalPlan::RecursiveQuery(_) => Commutativity::Unsupported,
322 }
323 }
324
325 pub fn check_extension_plan(
326 plan: &dyn UserDefinedLogicalNode,
327 partition_cols: &AliasMapping,
328 ) -> Commutativity {
329 match plan.name() {
330 name if name == SeriesDivide::name() => {
331 let series_divide = plan.as_any().downcast_ref::<SeriesDivide>().unwrap();
332 let tags = series_divide.tags().iter().collect::<HashSet<_>>();
333
334 for all_alias in partition_cols.values() {
335 let all_alias = all_alias.iter().map(|c| &c.name).collect::<HashSet<_>>();
336 if tags.intersection(&all_alias).count() == 0 {
337 return Commutativity::NonCommutative;
338 }
339 }
340
341 Commutativity::Commutative
342 }
343 name if name == SeriesNormalize::name()
344 || name == InstantManipulate::name()
345 || name == RangeManipulate::name() =>
346 {
347 Commutativity::Commutative
350 }
351 name if name == EmptyMetric::name()
352 || name == MergeScanLogicalPlan::name()
353 || name == MergeSortLogicalPlan::name() =>
354 {
355 Commutativity::Unimplemented
356 }
357 _ => Commutativity::Unsupported,
358 }
359 }
360
361 pub fn check_expr(expr: &Expr) -> Commutativity {
362 match expr {
363 Expr::Column(_)
364 | Expr::ScalarVariable(_, _)
365 | Expr::Literal(_)
366 | Expr::BinaryExpr(_)
367 | Expr::Not(_)
368 | Expr::IsNotNull(_)
369 | Expr::IsNull(_)
370 | Expr::IsTrue(_)
371 | Expr::IsFalse(_)
372 | Expr::IsNotTrue(_)
373 | Expr::IsNotFalse(_)
374 | Expr::Negative(_)
375 | Expr::Between(_)
376 | Expr::Exists(_)
377 | Expr::InList(_)
378 | Expr::Case(_) => Commutativity::Commutative,
379 Expr::ScalarFunction(_udf) => Commutativity::Commutative,
380 Expr::AggregateFunction(_udaf) => Commutativity::Commutative,
381
382 Expr::Like(_)
383 | Expr::SimilarTo(_)
384 | Expr::IsUnknown(_)
385 | Expr::IsNotUnknown(_)
386 | Expr::Cast(_)
387 | Expr::TryCast(_)
388 | Expr::WindowFunction(_)
389 | Expr::InSubquery(_)
390 | Expr::ScalarSubquery(_)
391 | Expr::Wildcard { .. } => Commutativity::Unimplemented,
392
393 Expr::Alias(alias) => Self::check_expr(&alias.expr),
394
395 Expr::Unnest(_)
396 | Expr::GroupingSet(_)
397 | Expr::Placeholder(_)
398 | Expr::OuterReferenceColumn(_, _) => Commutativity::Unimplemented,
399 }
400 }
401
402 fn check_partition(exprs: &[Expr], partition_cols: &AliasMapping) -> bool {
405 let mut ref_cols = HashSet::new();
406 for expr in exprs {
407 expr.add_column_refs(&mut ref_cols);
408 }
409 let ref_cols = ref_cols
410 .into_iter()
411 .map(|c| c.name.clone())
412 .collect::<HashSet<_>>();
413 for all_alias in partition_cols.values() {
414 let all_alias = all_alias
415 .iter()
416 .map(|c| c.name.clone())
417 .collect::<HashSet<_>>();
418 if ref_cols.intersection(&all_alias).count() == 0 {
421 return false;
422 }
423 }
424
425 true
426 }
427}
428
/// Shared plan-rewriting callback attached to `Commutativity::ConditionalCommutative`
/// (e.g. `merge_sort_transformer`). Returns the rewritten plan; `None` presumably
/// means no rewrite applies — confirm with the consumer.
pub type Transformer = Arc<dyn Fn(&LogicalPlan) -> Option<LogicalPlan>>;

/// Shared stage-transformation callback attached to
/// `Commutativity::TransformedCommutative`. It returns a `TransformerAction`
/// instead of a single plan. The step-aggregate transformer in this file returns
/// `None` when the split fails.
pub type StageTransformer = Arc<dyn Fn(&LogicalPlan) -> Option<TransformerAction>>;
433
/// Output of a `StageTransformer`.
pub struct TransformerAction {
    /// Plans to place above the transformed node. For the step-aggregate transformer
    /// this is `[projection, upper_aggregate]`, where the projection reads from the
    /// upper aggregate, so the first element is the outermost plan.
    /// NOTE(review): confirm the consumer relies on this outermost-first order.
    pub extra_parent_plans: Vec<LogicalPlan>,
    /// Presumably a replacement for the transformed node's child plan; the
    /// step-aggregate transformer always leaves it `None`.
    pub new_child_plan: Option<LogicalPlan>,
}
448
449pub fn partial_commutative_transformer(plan: &LogicalPlan) -> Option<LogicalPlan> {
450 Some(plan.clone())
451}
452
#[cfg(test)]
mod test {
    use datafusion_expr::{LogicalPlanBuilder, Sort};

    use super::*;

    /// With an empty partition-column mapping, a `Sort` is expected to be reported as
    /// plainly `Commutative` (no merge-sort transformer attached).
    #[test]
    fn sort_on_empty_partition() {
        let empty_input = LogicalPlanBuilder::empty(false).build().unwrap();
        let plan = LogicalPlan::Sort(Sort {
            expr: vec![],
            input: Arc::new(empty_input),
            fetch: None,
        });

        let commutativity = Categorizer::check_plan(&plan, Some(Default::default()));
        assert!(matches!(commutativity, Commutativity::Commutative));
    }
}