1use std::collections::HashSet;
16use std::sync::Arc;
17
18use common_function::aggrs::approximate::hll::{HllState, HLL_MERGE_NAME, HLL_NAME};
19use common_function::aggrs::approximate::uddsketch::{
20 UddSketchState, UDDSKETCH_MERGE_NAME, UDDSKETCH_STATE_NAME,
21};
22use common_telemetry::debug;
23use datafusion::functions_aggregate::sum::sum_udaf;
24use datafusion_common::Column;
25use datafusion_expr::{Expr, LogicalPlan, Projection, UserDefinedLogicalNode};
26use promql::extension_plan::{
27 EmptyMetric, InstantManipulate, RangeManipulate, SeriesDivide, SeriesNormalize,
28};
29
30use crate::dist_plan::merge_sort::{merge_sort_transformer, MergeSortLogicalPlan};
31use crate::dist_plan::MergeScanLogicalPlan;
32
33pub fn step_aggr_to_upper_aggr(
43 aggr_plan: &LogicalPlan,
44) -> datafusion_common::Result<[LogicalPlan; 2]> {
45 let LogicalPlan::Aggregate(input_aggr) = aggr_plan else {
46 return Err(datafusion_common::DataFusionError::Plan(
47 "step_aggr_to_upper_aggr only accepts Aggregate plan".to_string(),
48 ));
49 };
50 if !is_all_aggr_exprs_steppable(&input_aggr.aggr_expr) {
51 return Err(datafusion_common::DataFusionError::NotImplemented(
52 "Some aggregate expressions are not steppable".to_string(),
53 ));
54 }
55 let mut upper_aggr_expr = vec![];
56 for aggr_expr in &input_aggr.aggr_expr {
57 let Some(aggr_func) = get_aggr_func(aggr_expr) else {
58 return Err(datafusion_common::DataFusionError::NotImplemented(
59 "Aggregate function not found".to_string(),
60 ));
61 };
62 let col_name = aggr_expr.qualified_name();
63 let input_column = Expr::Column(datafusion_common::Column::new(col_name.0, col_name.1));
64 let upper_func = match aggr_func.func.name() {
65 "sum" | "min" | "max" | "last_value" | "first_value" => {
66 let mut new_aggr_func = aggr_func.clone();
68 new_aggr_func.args = vec![input_column.clone()];
69 new_aggr_func
70 }
71 "count" => {
72 let mut new_aggr_func = aggr_func.clone();
74 new_aggr_func.func = sum_udaf();
75 new_aggr_func.args = vec![input_column.clone()];
76 new_aggr_func
77 }
78 UDDSKETCH_STATE_NAME | UDDSKETCH_MERGE_NAME => {
79 let mut new_aggr_func = aggr_func.clone();
81 new_aggr_func.func = Arc::new(UddSketchState::merge_udf_impl());
82 new_aggr_func.args[2] = input_column.clone();
83 new_aggr_func
84 }
85 HLL_NAME | HLL_MERGE_NAME => {
86 let mut new_aggr_func = aggr_func.clone();
88 new_aggr_func.func = Arc::new(HllState::merge_udf_impl());
89 new_aggr_func.args = vec![input_column.clone()];
90 new_aggr_func
91 }
92 _ => {
93 return Err(datafusion_common::DataFusionError::NotImplemented(format!(
94 "Aggregate function {} is not supported for Step aggregation",
95 aggr_func.func.name()
96 )))
97 }
98 };
99
100 let mut new_aggr_expr = aggr_expr.clone();
102 {
103 let new_aggr_func = get_aggr_func_mut(&mut new_aggr_expr).unwrap();
104 *new_aggr_func = upper_func;
105 }
106
107 upper_aggr_expr.push(new_aggr_expr);
108 }
109 let mut new_aggr = input_aggr.clone();
110 new_aggr.input = Arc::new(LogicalPlan::Aggregate(input_aggr.clone()));
112
113 new_aggr.aggr_expr = upper_aggr_expr;
114
115 let mut new_group_expr = new_aggr.group_expr.clone();
117 for expr in &mut new_group_expr {
118 if let Expr::Column(_) = expr {
119 continue;
121 }
122 let col_name = expr.qualified_name();
123 let input_column = Expr::Column(datafusion_common::Column::new(col_name.0, col_name.1));
124 *expr = input_column;
125 }
126 new_aggr.group_expr = new_group_expr.clone();
127
128 let mut new_projection_exprs = new_group_expr;
129 for (lower_aggr_expr, upper_aggr_expr) in
132 input_aggr.aggr_expr.iter().zip(new_aggr.aggr_expr.iter())
133 {
134 let lower_col_name = lower_aggr_expr.qualified_name();
135 let (table, col_name) = upper_aggr_expr.qualified_name();
136 let aggr_out_column = Column::new(table, col_name);
137 let aliased_output_aggr_expr =
138 Expr::Column(aggr_out_column).alias_qualified(lower_col_name.0, lower_col_name.1);
139 new_projection_exprs.push(aliased_output_aggr_expr);
140 }
141 let upper_aggr_plan = LogicalPlan::Aggregate(new_aggr);
142 debug!("Before recompute schema: {upper_aggr_plan:?}");
143 let upper_aggr_plan = upper_aggr_plan.recompute_schema()?;
144 debug!("After recompute schema: {upper_aggr_plan:?}");
145 let new_projection =
147 Projection::try_new(new_projection_exprs, Arc::new(upper_aggr_plan.clone()))?;
148 let projection = LogicalPlan::Projection(new_projection);
149 Ok([projection, upper_aggr_plan])
151}
152
153pub fn is_all_aggr_exprs_steppable(aggr_exprs: &[Expr]) -> bool {
158 let step_action = HashSet::from([
159 "sum",
160 "count",
161 "min",
162 "max",
163 "first_value",
164 "last_value",
165 UDDSKETCH_STATE_NAME,
166 UDDSKETCH_MERGE_NAME,
167 HLL_NAME,
168 HLL_MERGE_NAME,
169 ]);
170 aggr_exprs.iter().all(|expr| {
171 if let Some(aggr_func) = get_aggr_func(expr) {
172 if aggr_func.distinct {
173 return false;
175 }
176 step_action.contains(aggr_func.func.name())
177 } else {
178 false
179 }
180 })
181}
182
183pub fn get_aggr_func(expr: &Expr) -> Option<&datafusion_expr::expr::AggregateFunction> {
184 let mut expr_ref = expr;
185 while let Expr::Alias(alias) = expr_ref {
186 expr_ref = &alias.expr;
187 }
188 if let Expr::AggregateFunction(aggr_func) = expr_ref {
189 Some(aggr_func)
190 } else {
191 None
192 }
193}
194
195pub fn get_aggr_func_mut(expr: &mut Expr) -> Option<&mut datafusion_expr::expr::AggregateFunction> {
196 let mut expr_ref = expr;
197 while let Expr::Alias(alias) = expr_ref {
198 expr_ref = &mut alias.expr;
199 }
200 if let Expr::AggregateFunction(aggr_func) = expr_ref {
201 Some(aggr_func)
202 } else {
203 None
204 }
205}
206
/// Verdict produced by [`Categorizer`] describing how a plan node (or expression) relates to
/// the table's partition columns.
#[allow(dead_code)]
pub enum Commutativity {
    /// Treated as fully commutative with the partitioning. Examples: `TableScan`, `Unnest`,
    /// and `Sort` / `Distinct` when there are no partition columns.
    Commutative,
    /// Returned for a `Limit` with a `fetch` and no `skip` on a partitioned input.
    /// NOTE(review): presumably the node may be pushed down but must also be kept above
    /// the merge (cf. `partial_commutative_transformer`) — confirm with the consumer.
    PartialCommutative,
    /// Commutative together with the attached transformer. Used for `Sort` on a partitioned
    /// input, with `merge_sort_transformer`.
    ConditionalCommutative(Option<Transformer>),
    /// Commutative after rewriting through the attached stage transformer. Used for
    /// aggregates that do not group by every partition column but whose aggregate
    /// functions are all steppable (see `step_aggr_to_upper_aggr`).
    TransformedCommutative {
        /// Produces the extra plans needed for the rewrite; `None` from the closure means
        /// the rewrite failed.
        transformer: Option<StageTransformer>,
    },
    /// Not commutative. Examples: `Join`, `EmptyRelation`, and a non-steppable aggregate
    /// whose grouping misses a partition column.
    NonCommutative,
    /// The categorizer does not handle this node/expression yet (e.g. `Window`, `Union`,
    /// subqueries, `Cast`).
    Unimplemented,
    /// Node kinds the categorizer does not support at all, e.g. statements, DML/DDL and
    /// unknown extension nodes.
    Unsupported,
}
221
/// Stateless classifier of logical plans and expressions into [`Commutativity`] verdicts.
/// It carries no fields; all of its functionality is exposed as associated functions.
pub struct Categorizer {}
223
224impl Categorizer {
225 pub fn check_plan(plan: &LogicalPlan, partition_cols: Option<Vec<String>>) -> Commutativity {
226 let partition_cols = partition_cols.unwrap_or_default();
227
228 match plan {
229 LogicalPlan::Projection(proj) => {
230 for expr in &proj.expr {
231 let commutativity = Self::check_expr(expr);
232 if !matches!(commutativity, Commutativity::Commutative) {
233 return commutativity;
234 }
235 }
236 Commutativity::Commutative
237 }
238 LogicalPlan::Filter(filter) => Self::check_expr(&filter.predicate),
240 LogicalPlan::Window(_) => Commutativity::Unimplemented,
241 LogicalPlan::Aggregate(aggr) => {
242 let is_all_steppable = is_all_aggr_exprs_steppable(&aggr.aggr_expr);
243 let matches_partition = Self::check_partition(&aggr.group_expr, &partition_cols);
244 if !matches_partition && is_all_steppable {
245 debug!("Plan is steppable: {plan}");
246 return Commutativity::TransformedCommutative {
247 transformer: Some(Arc::new(|plan: &LogicalPlan| {
248 debug!("Before Step optimize: {plan}");
249 let ret = step_aggr_to_upper_aggr(plan);
250 debug!("After Step Optimize: {ret:?}");
251 ret.ok().map(|s| TransformerAction {
252 extra_parent_plans: s.to_vec(),
253 new_child_plan: None,
254 })
255 })),
256 };
257 }
258 if !matches_partition {
259 return Commutativity::NonCommutative;
260 }
261 for expr in &aggr.aggr_expr {
262 let commutativity = Self::check_expr(expr);
263 if !matches!(commutativity, Commutativity::Commutative) {
264 return commutativity;
265 }
266 }
267 Commutativity::Commutative
268 }
269 LogicalPlan::Sort(_) => {
270 if partition_cols.is_empty() {
271 return Commutativity::Commutative;
272 }
273
274 Commutativity::ConditionalCommutative(Some(Arc::new(merge_sort_transformer)))
278 }
279 LogicalPlan::Join(_) => Commutativity::NonCommutative,
280 LogicalPlan::Repartition(_) => {
281 Commutativity::Unimplemented
283 }
284 LogicalPlan::Union(_) => Commutativity::Unimplemented,
285 LogicalPlan::TableScan(_) => Commutativity::Commutative,
286 LogicalPlan::EmptyRelation(_) => Commutativity::NonCommutative,
287 LogicalPlan::Subquery(_) => Commutativity::Unimplemented,
288 LogicalPlan::SubqueryAlias(_) => Commutativity::Unimplemented,
289 LogicalPlan::Limit(limit) => {
290 if partition_cols.is_empty() && limit.fetch.is_some() {
293 Commutativity::Commutative
294 } else if limit.skip.is_none() && limit.fetch.is_some() {
295 Commutativity::PartialCommutative
296 } else {
297 Commutativity::Unimplemented
298 }
299 }
300 LogicalPlan::Extension(extension) => {
301 Self::check_extension_plan(extension.node.as_ref() as _, &partition_cols)
302 }
303 LogicalPlan::Distinct(_) => {
304 if partition_cols.is_empty() {
305 Commutativity::Commutative
306 } else {
307 Commutativity::Unimplemented
308 }
309 }
310 LogicalPlan::Unnest(_) => Commutativity::Commutative,
311 LogicalPlan::Statement(_) => Commutativity::Unsupported,
312 LogicalPlan::Values(_) => Commutativity::Unsupported,
313 LogicalPlan::Explain(_) => Commutativity::Unsupported,
314 LogicalPlan::Analyze(_) => Commutativity::Unsupported,
315 LogicalPlan::DescribeTable(_) => Commutativity::Unsupported,
316 LogicalPlan::Dml(_) => Commutativity::Unsupported,
317 LogicalPlan::Ddl(_) => Commutativity::Unsupported,
318 LogicalPlan::Copy(_) => Commutativity::Unsupported,
319 LogicalPlan::RecursiveQuery(_) => Commutativity::Unsupported,
320 }
321 }
322
323 pub fn check_extension_plan(
324 plan: &dyn UserDefinedLogicalNode,
325 partition_cols: &[String],
326 ) -> Commutativity {
327 match plan.name() {
328 name if name == SeriesDivide::name() => {
329 let series_divide = plan.as_any().downcast_ref::<SeriesDivide>().unwrap();
330 let tags = series_divide.tags().iter().collect::<HashSet<_>>();
331 for partition_col in partition_cols {
332 if !tags.contains(partition_col) {
333 return Commutativity::NonCommutative;
334 }
335 }
336 Commutativity::Commutative
337 }
338 name if name == SeriesNormalize::name()
339 || name == InstantManipulate::name()
340 || name == RangeManipulate::name() =>
341 {
342 Commutativity::Commutative
345 }
346 name if name == EmptyMetric::name()
347 || name == MergeScanLogicalPlan::name()
348 || name == MergeSortLogicalPlan::name() =>
349 {
350 Commutativity::Unimplemented
351 }
352 _ => Commutativity::Unsupported,
353 }
354 }
355
356 pub fn check_expr(expr: &Expr) -> Commutativity {
357 match expr {
358 Expr::Column(_)
359 | Expr::ScalarVariable(_, _)
360 | Expr::Literal(_)
361 | Expr::BinaryExpr(_)
362 | Expr::Not(_)
363 | Expr::IsNotNull(_)
364 | Expr::IsNull(_)
365 | Expr::IsTrue(_)
366 | Expr::IsFalse(_)
367 | Expr::IsNotTrue(_)
368 | Expr::IsNotFalse(_)
369 | Expr::Negative(_)
370 | Expr::Between(_)
371 | Expr::Exists(_)
372 | Expr::InList(_)
373 | Expr::Case(_) => Commutativity::Commutative,
374 Expr::ScalarFunction(_udf) => Commutativity::Commutative,
375 Expr::AggregateFunction(_udaf) => Commutativity::Commutative,
376
377 Expr::Like(_)
378 | Expr::SimilarTo(_)
379 | Expr::IsUnknown(_)
380 | Expr::IsNotUnknown(_)
381 | Expr::Cast(_)
382 | Expr::TryCast(_)
383 | Expr::WindowFunction(_)
384 | Expr::InSubquery(_)
385 | Expr::ScalarSubquery(_)
386 | Expr::Wildcard { .. } => Commutativity::Unimplemented,
387
388 Expr::Alias(alias) => Self::check_expr(&alias.expr),
389
390 Expr::Unnest(_)
391 | Expr::GroupingSet(_)
392 | Expr::Placeholder(_)
393 | Expr::OuterReferenceColumn(_, _) => Commutativity::Unimplemented,
394 }
395 }
396
397 fn check_partition(exprs: &[Expr], partition_cols: &[String]) -> bool {
400 let mut ref_cols = HashSet::new();
401 for expr in exprs {
402 expr.add_column_refs(&mut ref_cols);
403 }
404 let ref_cols = ref_cols
405 .into_iter()
406 .map(|c| c.name.clone())
407 .collect::<HashSet<_>>();
408 for col in partition_cols {
409 if !ref_cols.contains(col) {
410 return false;
411 }
412 }
413
414 true
415 }
416}
417
/// Shared callback that rewrites a plan node into a replacement plan. Returning `None`
/// signals that no replacement was produced.
pub type Transformer = Arc<dyn Fn(&LogicalPlan) -> Option<LogicalPlan>>;
419
/// Shared callback that rewrites a plan node into a [`TransformerAction`]. Returning `None`
/// signals the rewrite failed; e.g. the aggregate stepping transformer returns `None` when
/// `step_aggr_to_upper_aggr` errors.
pub type StageTransformer = Arc<dyn Fn(&LogicalPlan) -> Option<TransformerAction>>;
422
/// Result of a [`StageTransformer`]: the plans to add around the transformed node.
pub struct TransformerAction {
    /// Extra plans to place above the transformed node. The aggregate stepping transformer
    /// fills this with `[projection, upper_aggregate]` from `step_aggr_to_upper_aggr`.
    /// NOTE(review): presumably ordered outermost-first — confirm against the consumer.
    pub extra_parent_plans: Vec<LogicalPlan>,
    /// Optional replacement for the transformed node itself; `None` keeps it unchanged.
    /// NOTE(review): the "keeps it unchanged" reading is inferred from the field name —
    /// verify in the consumer.
    pub new_child_plan: Option<LogicalPlan>,
}
437
438pub fn partial_commutative_transformer(plan: &LogicalPlan) -> Option<LogicalPlan> {
439 Some(plan.clone())
440}
441
#[cfg(test)]
mod test {
    use datafusion_expr::{LogicalPlanBuilder, Sort};

    use super::*;

    /// With no partition columns, a sort can be distributed unchanged.
    #[test]
    fn sort_on_empty_partition() {
        let empty_input = LogicalPlanBuilder::empty(false).build().unwrap();
        let sort_plan = LogicalPlan::Sort(Sort {
            expr: vec![],
            input: Arc::new(empty_input),
            fetch: None,
        });
        let verdict = Categorizer::check_plan(&sort_plan, Some(vec![]));
        assert!(matches!(verdict, Commutativity::Commutative));
    }
}