query/dist_plan/
commutativity.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashSet;
16use std::sync::Arc;
17
18use common_function::aggrs::approximate::hll::{HllState, HLL_MERGE_NAME, HLL_NAME};
19use common_function::aggrs::approximate::uddsketch::{
20    UddSketchState, UDDSKETCH_MERGE_NAME, UDDSKETCH_STATE_NAME,
21};
22use common_telemetry::debug;
23use datafusion::functions_aggregate::sum::sum_udaf;
24use datafusion_common::Column;
25use datafusion_expr::{Expr, LogicalPlan, Projection, UserDefinedLogicalNode};
26use promql::extension_plan::{
27    EmptyMetric, InstantManipulate, RangeManipulate, SeriesDivide, SeriesNormalize,
28};
29
30use crate::dist_plan::analyzer::AliasMapping;
31use crate::dist_plan::merge_sort::{merge_sort_transformer, MergeSortLogicalPlan};
32use crate::dist_plan::MergeScanLogicalPlan;
33
34/// generate the upper aggregation plan that will execute on the frontend.
35/// Basically a logical plan resembling the following:
36/// Projection:
37///     Aggregate:
38///
39/// from Aggregate
40///
41/// The upper Projection exists sole to make sure parent plan can recognize the output
42/// of the upper aggregation plan.
43pub fn step_aggr_to_upper_aggr(
44    aggr_plan: &LogicalPlan,
45) -> datafusion_common::Result<[LogicalPlan; 2]> {
46    let LogicalPlan::Aggregate(input_aggr) = aggr_plan else {
47        return Err(datafusion_common::DataFusionError::Plan(
48            "step_aggr_to_upper_aggr only accepts Aggregate plan".to_string(),
49        ));
50    };
51    if !is_all_aggr_exprs_steppable(&input_aggr.aggr_expr) {
52        return Err(datafusion_common::DataFusionError::NotImplemented(
53            "Some aggregate expressions are not steppable".to_string(),
54        ));
55    }
56    let mut upper_aggr_expr = vec![];
57    for aggr_expr in &input_aggr.aggr_expr {
58        let Some(aggr_func) = get_aggr_func(aggr_expr) else {
59            return Err(datafusion_common::DataFusionError::NotImplemented(
60                "Aggregate function not found".to_string(),
61            ));
62        };
63        let col_name = aggr_expr.qualified_name();
64        let input_column = Expr::Column(datafusion_common::Column::new(col_name.0, col_name.1));
65        let upper_func = match aggr_func.func.name() {
66            "sum" | "min" | "max" | "last_value" | "first_value" => {
67                // aggr_calc(aggr_merge(input_column))) as col_name
68                let mut new_aggr_func = aggr_func.clone();
69                new_aggr_func.args = vec![input_column.clone()];
70                new_aggr_func
71            }
72            "count" => {
73                // sum(input_column) as col_name
74                let mut new_aggr_func = aggr_func.clone();
75                new_aggr_func.func = sum_udaf();
76                new_aggr_func.args = vec![input_column.clone()];
77                new_aggr_func
78            }
79            UDDSKETCH_STATE_NAME | UDDSKETCH_MERGE_NAME => {
80                // udd_merge(bucket_size, error_rate input_column) as col_name
81                let mut new_aggr_func = aggr_func.clone();
82                new_aggr_func.func = Arc::new(UddSketchState::merge_udf_impl());
83                new_aggr_func.args[2] = input_column.clone();
84                new_aggr_func
85            }
86            HLL_NAME | HLL_MERGE_NAME => {
87                // hll_merge(input_column) as col_name
88                let mut new_aggr_func = aggr_func.clone();
89                new_aggr_func.func = Arc::new(HllState::merge_udf_impl());
90                new_aggr_func.args = vec![input_column.clone()];
91                new_aggr_func
92            }
93            _ => {
94                return Err(datafusion_common::DataFusionError::NotImplemented(format!(
95                    "Aggregate function {} is not supported for Step aggregation",
96                    aggr_func.func.name()
97                )))
98            }
99        };
100
101        // deal with nested alias case
102        let mut new_aggr_expr = aggr_expr.clone();
103        {
104            let new_aggr_func = get_aggr_func_mut(&mut new_aggr_expr).unwrap();
105            *new_aggr_func = upper_func;
106        }
107
108        upper_aggr_expr.push(new_aggr_expr);
109    }
110    let mut new_aggr = input_aggr.clone();
111    // use lower aggregate plan as input, this will be replace by merge scan plan later
112    new_aggr.input = Arc::new(LogicalPlan::Aggregate(input_aggr.clone()));
113
114    new_aggr.aggr_expr = upper_aggr_expr;
115
116    // group by expr also need to be all ref by column to avoid duplicated computing
117    let mut new_group_expr = new_aggr.group_expr.clone();
118    for expr in &mut new_group_expr {
119        if let Expr::Column(_) = expr {
120            // already a column, no need to change
121            continue;
122        }
123        let col_name = expr.qualified_name();
124        let input_column = Expr::Column(datafusion_common::Column::new(col_name.0, col_name.1));
125        *expr = input_column;
126    }
127    new_aggr.group_expr = new_group_expr.clone();
128
129    let mut new_projection_exprs = new_group_expr;
130    // the upper aggr expr need to be aliased to the input aggr expr's name,
131    // so that the parent plan can recognize it.
132    for (lower_aggr_expr, upper_aggr_expr) in
133        input_aggr.aggr_expr.iter().zip(new_aggr.aggr_expr.iter())
134    {
135        let lower_col_name = lower_aggr_expr.qualified_name();
136        let (table, col_name) = upper_aggr_expr.qualified_name();
137        let aggr_out_column = Column::new(table, col_name);
138        let aliased_output_aggr_expr =
139            Expr::Column(aggr_out_column).alias_qualified(lower_col_name.0, lower_col_name.1);
140        new_projection_exprs.push(aliased_output_aggr_expr);
141    }
142    let upper_aggr_plan = LogicalPlan::Aggregate(new_aggr);
143    let upper_aggr_plan = upper_aggr_plan.recompute_schema()?;
144    // create a projection on top of the new aggregate plan
145    let new_projection =
146        Projection::try_new(new_projection_exprs, Arc::new(upper_aggr_plan.clone()))?;
147    let projection = LogicalPlan::Projection(new_projection);
148    // return the new logical plan
149    Ok([projection, upper_aggr_plan])
150}
151
152/// Check if the given aggregate expression is steppable.
153/// As in if it can be split into multiple steps:
154/// i.e. on datanode first call `state(input)` then
155/// on frontend call `calc(merge(state))` to get the final result.
156pub fn is_all_aggr_exprs_steppable(aggr_exprs: &[Expr]) -> bool {
157    let step_action = HashSet::from([
158        "sum",
159        "count",
160        "min",
161        "max",
162        "first_value",
163        "last_value",
164        UDDSKETCH_STATE_NAME,
165        UDDSKETCH_MERGE_NAME,
166        HLL_NAME,
167        HLL_MERGE_NAME,
168    ]);
169    aggr_exprs.iter().all(|expr| {
170        if let Some(aggr_func) = get_aggr_func(expr) {
171            if aggr_func.distinct {
172                // Distinct aggregate functions are not steppable(yet).
173                return false;
174            }
175            step_action.contains(aggr_func.func.name())
176        } else {
177            false
178        }
179    })
180}
181
182pub fn get_aggr_func(expr: &Expr) -> Option<&datafusion_expr::expr::AggregateFunction> {
183    let mut expr_ref = expr;
184    while let Expr::Alias(alias) = expr_ref {
185        expr_ref = &alias.expr;
186    }
187    if let Expr::AggregateFunction(aggr_func) = expr_ref {
188        Some(aggr_func)
189    } else {
190        None
191    }
192}
193
194pub fn get_aggr_func_mut(expr: &mut Expr) -> Option<&mut datafusion_expr::expr::AggregateFunction> {
195    let mut expr_ref = expr;
196    while let Expr::Alias(alias) = expr_ref {
197        expr_ref = &mut alias.expr;
198    }
199    if let Expr::AggregateFunction(aggr_func) = expr_ref {
200        Some(aggr_func)
201    } else {
202        None
203    }
204}
205
/// Classification of a plan node or expression, as produced by
/// `Categorizer::check_plan`, `Categorizer::check_extension_plan` and
/// `Categorizer::check_expr`.
#[allow(dead_code)]
pub enum Commutativity {
    /// Returned e.g. for `TableScan`, `Unnest`, and projections whose expressions are
    /// all commutative.
    Commutative,
    /// Returned for a `Limit` that has a `fetch` but no `skip` when the partition
    /// column mapping is not empty.
    PartialCommutative,
    /// Carries an optional [`Transformer`]. Returned for `Sort` with a non-empty
    /// partition column mapping (with the merge sort transformer) and for aggregates
    /// whose group-by expressions cover the partition columns (without a transformer).
    ConditionalCommutative(Option<Transformer>),
    /// Carries an optional [`StageTransformer`]. Returned for aggregates that do not
    /// group by the partition columns but whose aggregate expressions are all steppable.
    TransformedCommutative {
        /// Return plans from parent to child order
        transformer: Option<StageTransformer>,
    },
    /// Returned e.g. for `Join` and `EmptyRelation`.
    NonCommutative,
    /// Returned for plans and expressions that are not handled yet, e.g. `Window`,
    /// `Union`, or `Expr::Like`.
    Unimplemented,
    /// For unrelated plans like DDL
    Unsupported,
}
220
221pub struct Categorizer {}
222
223impl Categorizer {
224    pub fn check_plan(plan: &LogicalPlan, partition_cols: Option<AliasMapping>) -> Commutativity {
225        let partition_cols = partition_cols.unwrap_or_default();
226
227        match plan {
228            LogicalPlan::Projection(proj) => {
229                for expr in &proj.expr {
230                    let commutativity = Self::check_expr(expr);
231                    if !matches!(commutativity, Commutativity::Commutative) {
232                        return commutativity;
233                    }
234                }
235                Commutativity::Commutative
236            }
237            // TODO(ruihang): Change this to Commutative once Like is supported in substrait
238            LogicalPlan::Filter(filter) => Self::check_expr(&filter.predicate),
239            LogicalPlan::Window(_) => Commutativity::Unimplemented,
240            LogicalPlan::Aggregate(aggr) => {
241                let is_all_steppable = is_all_aggr_exprs_steppable(&aggr.aggr_expr);
242                let matches_partition = Self::check_partition(&aggr.group_expr, &partition_cols);
243                if !matches_partition && is_all_steppable {
244                    debug!("Plan is steppable: {plan}");
245                    return Commutativity::TransformedCommutative {
246                        transformer: Some(Arc::new(|plan: &LogicalPlan| {
247                            debug!("Before Step optimize: {plan}");
248                            let ret = step_aggr_to_upper_aggr(plan);
249                            ret.ok().map(|s| TransformerAction {
250                                extra_parent_plans: s.to_vec(),
251                                new_child_plan: None,
252                            })
253                        })),
254                    };
255                }
256                if !matches_partition {
257                    return Commutativity::NonCommutative;
258                }
259                for expr in &aggr.aggr_expr {
260                    let commutativity = Self::check_expr(expr);
261                    if !matches!(commutativity, Commutativity::Commutative) {
262                        return commutativity;
263                    }
264                }
265                // all group by expressions are partition columns can push down, unless
266                // another push down(including `Limit` or `Sort`) is already in progress(which will then prvent next cond commutative node from being push down).
267                // TODO(discord9): This is a temporary solution(that works), a better description of
268                // commutativity is needed under this situation.
269                Commutativity::ConditionalCommutative(None)
270            }
271            LogicalPlan::Sort(_) => {
272                if partition_cols.is_empty() {
273                    return Commutativity::Commutative;
274                }
275
276                // sort plan needs to consider column priority
277                // Change Sort to MergeSort which assumes the input streams are already sorted hence can be more efficient
278                // We should ensure the number of partition is not smaller than the number of region at present. Otherwise this would result in incorrect output.
279                Commutativity::ConditionalCommutative(Some(Arc::new(merge_sort_transformer)))
280            }
281            LogicalPlan::Join(_) => Commutativity::NonCommutative,
282            LogicalPlan::Repartition(_) => {
283                // unsupported? or non-commutative
284                Commutativity::Unimplemented
285            }
286            LogicalPlan::Union(_) => Commutativity::Unimplemented,
287            LogicalPlan::TableScan(_) => Commutativity::Commutative,
288            LogicalPlan::EmptyRelation(_) => Commutativity::NonCommutative,
289            LogicalPlan::Subquery(_) => Commutativity::Unimplemented,
290            LogicalPlan::SubqueryAlias(_) => Commutativity::Unimplemented,
291            LogicalPlan::Limit(limit) => {
292                // Only execute `fetch` on remote nodes.
293                // wait for https://github.com/apache/arrow-datafusion/pull/7669
294                if partition_cols.is_empty() && limit.fetch.is_some() {
295                    Commutativity::Commutative
296                } else if limit.skip.is_none() && limit.fetch.is_some() {
297                    Commutativity::PartialCommutative
298                } else {
299                    Commutativity::Unimplemented
300                }
301            }
302            LogicalPlan::Extension(extension) => {
303                Self::check_extension_plan(extension.node.as_ref() as _, &partition_cols)
304            }
305            LogicalPlan::Distinct(_) => {
306                if partition_cols.is_empty() {
307                    Commutativity::Commutative
308                } else {
309                    Commutativity::Unimplemented
310                }
311            }
312            LogicalPlan::Unnest(_) => Commutativity::Commutative,
313            LogicalPlan::Statement(_) => Commutativity::Unsupported,
314            LogicalPlan::Values(_) => Commutativity::Unsupported,
315            LogicalPlan::Explain(_) => Commutativity::Unsupported,
316            LogicalPlan::Analyze(_) => Commutativity::Unsupported,
317            LogicalPlan::DescribeTable(_) => Commutativity::Unsupported,
318            LogicalPlan::Dml(_) => Commutativity::Unsupported,
319            LogicalPlan::Ddl(_) => Commutativity::Unsupported,
320            LogicalPlan::Copy(_) => Commutativity::Unsupported,
321            LogicalPlan::RecursiveQuery(_) => Commutativity::Unsupported,
322        }
323    }
324
325    pub fn check_extension_plan(
326        plan: &dyn UserDefinedLogicalNode,
327        partition_cols: &AliasMapping,
328    ) -> Commutativity {
329        match plan.name() {
330            name if name == SeriesDivide::name() => {
331                let series_divide = plan.as_any().downcast_ref::<SeriesDivide>().unwrap();
332                let tags = series_divide.tags().iter().collect::<HashSet<_>>();
333
334                for all_alias in partition_cols.values() {
335                    let all_alias = all_alias.iter().map(|c| &c.name).collect::<HashSet<_>>();
336                    if tags.intersection(&all_alias).count() == 0 {
337                        return Commutativity::NonCommutative;
338                    }
339                }
340
341                Commutativity::Commutative
342            }
343            name if name == SeriesNormalize::name()
344                || name == InstantManipulate::name()
345                || name == RangeManipulate::name() =>
346            {
347                // They should always follows Series Divide.
348                // Either all commutative or all non-commutative (which will be blocked by SeriesDivide).
349                Commutativity::Commutative
350            }
351            name if name == EmptyMetric::name()
352                || name == MergeScanLogicalPlan::name()
353                || name == MergeSortLogicalPlan::name() =>
354            {
355                Commutativity::Unimplemented
356            }
357            _ => Commutativity::Unsupported,
358        }
359    }
360
361    pub fn check_expr(expr: &Expr) -> Commutativity {
362        match expr {
363            Expr::Column(_)
364            | Expr::ScalarVariable(_, _)
365            | Expr::Literal(_)
366            | Expr::BinaryExpr(_)
367            | Expr::Not(_)
368            | Expr::IsNotNull(_)
369            | Expr::IsNull(_)
370            | Expr::IsTrue(_)
371            | Expr::IsFalse(_)
372            | Expr::IsNotTrue(_)
373            | Expr::IsNotFalse(_)
374            | Expr::Negative(_)
375            | Expr::Between(_)
376            | Expr::Exists(_)
377            | Expr::InList(_)
378            | Expr::Case(_) => Commutativity::Commutative,
379            Expr::ScalarFunction(_udf) => Commutativity::Commutative,
380            Expr::AggregateFunction(_udaf) => Commutativity::Commutative,
381
382            Expr::Like(_)
383            | Expr::SimilarTo(_)
384            | Expr::IsUnknown(_)
385            | Expr::IsNotUnknown(_)
386            | Expr::Cast(_)
387            | Expr::TryCast(_)
388            | Expr::WindowFunction(_)
389            | Expr::InSubquery(_)
390            | Expr::ScalarSubquery(_)
391            | Expr::Wildcard { .. } => Commutativity::Unimplemented,
392
393            Expr::Alias(alias) => Self::check_expr(&alias.expr),
394
395            Expr::Unnest(_)
396            | Expr::GroupingSet(_)
397            | Expr::Placeholder(_)
398            | Expr::OuterReferenceColumn(_, _) => Commutativity::Unimplemented,
399        }
400    }
401
402    /// Return true if the given expr and partition cols satisfied the rule.
403    /// In this case the plan can be treated as fully commutative.
404    fn check_partition(exprs: &[Expr], partition_cols: &AliasMapping) -> bool {
405        let mut ref_cols = HashSet::new();
406        for expr in exprs {
407            expr.add_column_refs(&mut ref_cols);
408        }
409        let ref_cols = ref_cols
410            .into_iter()
411            .map(|c| c.name.clone())
412            .collect::<HashSet<_>>();
413        for all_alias in partition_cols.values() {
414            let all_alias = all_alias
415                .iter()
416                .map(|c| c.name.clone())
417                .collect::<HashSet<_>>();
418            // check if ref columns intersect with all alias of partition columns
419            // is empty, if it's empty, not all partition columns show up in `exprs`
420            if ref_cols.intersection(&all_alias).count() == 0 {
421                return false;
422            }
423        }
424
425        true
426    }
427}
428
/// A shared function that takes a plan by reference and optionally returns a plan.
/// Carried by [`Commutativity::ConditionalCommutative`].
pub type Transformer = Arc<dyn Fn(&LogicalPlan) -> Option<LogicalPlan>>;
430
/// A shared function that takes a plan by reference and optionally returns the
/// [`TransformerAction`] that needs to be applied.
/// Carried by [`Commutativity::TransformedCommutative`].
pub type StageTransformer = Arc<dyn Fn(&LogicalPlan) -> Option<TransformerAction>>;
433
/// The action that a transformer should take on the plan.
pub struct TransformerAction {
    /// List of plans that need to be applied as parent plans, in parent-to-child order.
    /// I.e. if this is `[Projection, Aggregate]`, then the parent plan should be transformed to
    /// ```ignore
    /// Original Parent Plan:
    ///     Projection:
    ///         Aggregate:
    ///             MergeScan: ...
    /// ```
    pub extra_parent_plans: Vec<LogicalPlan>,
    /// New child plan; if `None`, use the original plan.
    pub new_child_plan: Option<LogicalPlan>,
}
448
449pub fn partial_commutative_transformer(plan: &LogicalPlan) -> Option<LogicalPlan> {
450    Some(plan.clone())
451}
452
#[cfg(test)]
mod test {
    use datafusion_expr::{LogicalPlanBuilder, Sort};

    use super::*;

    /// A `Sort` checked against an empty partition column mapping is fully commutative.
    #[test]
    fn sort_on_empty_partition() {
        let empty_input = LogicalPlanBuilder::empty(false).build().unwrap();
        let sort = Sort {
            expr: vec![],
            input: Arc::new(empty_input),
            fetch: None,
        };
        let plan = LogicalPlan::Sort(sort);

        let commutativity = Categorizer::check_plan(&plan, Some(Default::default()));

        assert!(matches!(commutativity, Commutativity::Commutative));
    }
}
471}