query/dist_plan/
commutativity.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashSet;
16use std::sync::Arc;
17
18use common_function::aggrs::approximate::hll::{HllState, HLL_MERGE_NAME, HLL_NAME};
19use common_function::aggrs::approximate::uddsketch::{
20    UddSketchState, UDDSKETCH_MERGE_NAME, UDDSKETCH_STATE_NAME,
21};
22use common_telemetry::debug;
23use datafusion::functions_aggregate::sum::sum_udaf;
24use datafusion_common::Column;
25use datafusion_expr::{Expr, LogicalPlan, Projection, UserDefinedLogicalNode};
26use promql::extension_plan::{
27    EmptyMetric, InstantManipulate, RangeManipulate, SeriesDivide, SeriesNormalize,
28};
29
30use crate::dist_plan::merge_sort::{merge_sort_transformer, MergeSortLogicalPlan};
31use crate::dist_plan::MergeScanLogicalPlan;
32
33/// generate the upper aggregation plan that will execute on the frontend.
34/// Basically a logical plan resembling the following:
35/// Projection:
36///     Aggregate:
37///
38/// from Aggregate
39///
40/// The upper Projection exists sole to make sure parent plan can recognize the output
41/// of the upper aggregation plan.
42pub fn step_aggr_to_upper_aggr(
43    aggr_plan: &LogicalPlan,
44) -> datafusion_common::Result<[LogicalPlan; 2]> {
45    let LogicalPlan::Aggregate(input_aggr) = aggr_plan else {
46        return Err(datafusion_common::DataFusionError::Plan(
47            "step_aggr_to_upper_aggr only accepts Aggregate plan".to_string(),
48        ));
49    };
50    if !is_all_aggr_exprs_steppable(&input_aggr.aggr_expr) {
51        return Err(datafusion_common::DataFusionError::NotImplemented(
52            "Some aggregate expressions are not steppable".to_string(),
53        ));
54    }
55    let mut upper_aggr_expr = vec![];
56    for aggr_expr in &input_aggr.aggr_expr {
57        let Some(aggr_func) = get_aggr_func(aggr_expr) else {
58            return Err(datafusion_common::DataFusionError::NotImplemented(
59                "Aggregate function not found".to_string(),
60            ));
61        };
62        let col_name = aggr_expr.qualified_name();
63        let input_column = Expr::Column(datafusion_common::Column::new(col_name.0, col_name.1));
64        let upper_func = match aggr_func.func.name() {
65            "sum" | "min" | "max" | "last_value" | "first_value" => {
66                // aggr_calc(aggr_merge(input_column))) as col_name
67                let mut new_aggr_func = aggr_func.clone();
68                new_aggr_func.args = vec![input_column.clone()];
69                new_aggr_func
70            }
71            "count" => {
72                // sum(input_column) as col_name
73                let mut new_aggr_func = aggr_func.clone();
74                new_aggr_func.func = sum_udaf();
75                new_aggr_func.args = vec![input_column.clone()];
76                new_aggr_func
77            }
78            UDDSKETCH_STATE_NAME | UDDSKETCH_MERGE_NAME => {
79                // udd_merge(bucket_size, error_rate input_column) as col_name
80                let mut new_aggr_func = aggr_func.clone();
81                new_aggr_func.func = Arc::new(UddSketchState::merge_udf_impl());
82                new_aggr_func.args[2] = input_column.clone();
83                new_aggr_func
84            }
85            HLL_NAME | HLL_MERGE_NAME => {
86                // hll_merge(input_column) as col_name
87                let mut new_aggr_func = aggr_func.clone();
88                new_aggr_func.func = Arc::new(HllState::merge_udf_impl());
89                new_aggr_func.args = vec![input_column.clone()];
90                new_aggr_func
91            }
92            _ => {
93                return Err(datafusion_common::DataFusionError::NotImplemented(format!(
94                    "Aggregate function {} is not supported for Step aggregation",
95                    aggr_func.func.name()
96                )))
97            }
98        };
99
100        // deal with nested alias case
101        let mut new_aggr_expr = aggr_expr.clone();
102        {
103            let new_aggr_func = get_aggr_func_mut(&mut new_aggr_expr).unwrap();
104            *new_aggr_func = upper_func;
105        }
106
107        upper_aggr_expr.push(new_aggr_expr);
108    }
109    let mut new_aggr = input_aggr.clone();
110    // use lower aggregate plan as input, this will be replace by merge scan plan later
111    new_aggr.input = Arc::new(LogicalPlan::Aggregate(input_aggr.clone()));
112
113    new_aggr.aggr_expr = upper_aggr_expr;
114
115    // group by expr also need to be all ref by column to avoid duplicated computing
116    let mut new_group_expr = new_aggr.group_expr.clone();
117    for expr in &mut new_group_expr {
118        if let Expr::Column(_) = expr {
119            // already a column, no need to change
120            continue;
121        }
122        let col_name = expr.qualified_name();
123        let input_column = Expr::Column(datafusion_common::Column::new(col_name.0, col_name.1));
124        *expr = input_column;
125    }
126    new_aggr.group_expr = new_group_expr.clone();
127
128    let mut new_projection_exprs = new_group_expr;
129    // the upper aggr expr need to be aliased to the input aggr expr's name,
130    // so that the parent plan can recognize it.
131    for (lower_aggr_expr, upper_aggr_expr) in
132        input_aggr.aggr_expr.iter().zip(new_aggr.aggr_expr.iter())
133    {
134        let lower_col_name = lower_aggr_expr.qualified_name();
135        let (table, col_name) = upper_aggr_expr.qualified_name();
136        let aggr_out_column = Column::new(table, col_name);
137        let aliased_output_aggr_expr =
138            Expr::Column(aggr_out_column).alias_qualified(lower_col_name.0, lower_col_name.1);
139        new_projection_exprs.push(aliased_output_aggr_expr);
140    }
141    let upper_aggr_plan = LogicalPlan::Aggregate(new_aggr);
142    debug!("Before recompute schema: {upper_aggr_plan:?}");
143    let upper_aggr_plan = upper_aggr_plan.recompute_schema()?;
144    debug!("After recompute schema: {upper_aggr_plan:?}");
145    // create a projection on top of the new aggregate plan
146    let new_projection =
147        Projection::try_new(new_projection_exprs, Arc::new(upper_aggr_plan.clone()))?;
148    let projection = LogicalPlan::Projection(new_projection);
149    // return the new logical plan
150    Ok([projection, upper_aggr_plan])
151}
152
153/// Check if the given aggregate expression is steppable.
154/// As in if it can be split into multiple steps:
155/// i.e. on datanode first call `state(input)` then
156/// on frontend call `calc(merge(state))` to get the final result.
157pub fn is_all_aggr_exprs_steppable(aggr_exprs: &[Expr]) -> bool {
158    let step_action = HashSet::from([
159        "sum",
160        "count",
161        "min",
162        "max",
163        "first_value",
164        "last_value",
165        UDDSKETCH_STATE_NAME,
166        UDDSKETCH_MERGE_NAME,
167        HLL_NAME,
168        HLL_MERGE_NAME,
169    ]);
170    aggr_exprs.iter().all(|expr| {
171        if let Some(aggr_func) = get_aggr_func(expr) {
172            if aggr_func.distinct {
173                // Distinct aggregate functions are not steppable(yet).
174                return false;
175            }
176            step_action.contains(aggr_func.func.name())
177        } else {
178            false
179        }
180    })
181}
182
183pub fn get_aggr_func(expr: &Expr) -> Option<&datafusion_expr::expr::AggregateFunction> {
184    let mut expr_ref = expr;
185    while let Expr::Alias(alias) = expr_ref {
186        expr_ref = &alias.expr;
187    }
188    if let Expr::AggregateFunction(aggr_func) = expr_ref {
189        Some(aggr_func)
190    } else {
191        None
192    }
193}
194
195pub fn get_aggr_func_mut(expr: &mut Expr) -> Option<&mut datafusion_expr::expr::AggregateFunction> {
196    let mut expr_ref = expr;
197    while let Expr::Alias(alias) = expr_ref {
198        expr_ref = &mut alias.expr;
199    }
200    if let Expr::AggregateFunction(aggr_func) = expr_ref {
201        Some(aggr_func)
202    } else {
203        None
204    }
205}
206
207#[allow(dead_code)]
208pub enum Commutativity {
209    Commutative,
210    PartialCommutative,
211    ConditionalCommutative(Option<Transformer>),
212    TransformedCommutative {
213        /// Return plans from parent to child order
214        transformer: Option<StageTransformer>,
215    },
216    NonCommutative,
217    Unimplemented,
218    /// For unrelated plans like DDL
219    Unsupported,
220}
221
222pub struct Categorizer {}
223
224impl Categorizer {
225    pub fn check_plan(plan: &LogicalPlan, partition_cols: Option<Vec<String>>) -> Commutativity {
226        let partition_cols = partition_cols.unwrap_or_default();
227
228        match plan {
229            LogicalPlan::Projection(proj) => {
230                for expr in &proj.expr {
231                    let commutativity = Self::check_expr(expr);
232                    if !matches!(commutativity, Commutativity::Commutative) {
233                        return commutativity;
234                    }
235                }
236                Commutativity::Commutative
237            }
238            // TODO(ruihang): Change this to Commutative once Like is supported in substrait
239            LogicalPlan::Filter(filter) => Self::check_expr(&filter.predicate),
240            LogicalPlan::Window(_) => Commutativity::Unimplemented,
241            LogicalPlan::Aggregate(aggr) => {
242                let is_all_steppable = is_all_aggr_exprs_steppable(&aggr.aggr_expr);
243                let matches_partition = Self::check_partition(&aggr.group_expr, &partition_cols);
244                if !matches_partition && is_all_steppable {
245                    debug!("Plan is steppable: {plan}");
246                    return Commutativity::TransformedCommutative {
247                        transformer: Some(Arc::new(|plan: &LogicalPlan| {
248                            debug!("Before Step optimize: {plan}");
249                            let ret = step_aggr_to_upper_aggr(plan);
250                            debug!("After Step Optimize: {ret:?}");
251                            ret.ok().map(|s| TransformerAction {
252                                extra_parent_plans: s.to_vec(),
253                                new_child_plan: None,
254                            })
255                        })),
256                    };
257                }
258                if !matches_partition {
259                    return Commutativity::NonCommutative;
260                }
261                for expr in &aggr.aggr_expr {
262                    let commutativity = Self::check_expr(expr);
263                    if !matches!(commutativity, Commutativity::Commutative) {
264                        return commutativity;
265                    }
266                }
267                Commutativity::Commutative
268            }
269            LogicalPlan::Sort(_) => {
270                if partition_cols.is_empty() {
271                    return Commutativity::Commutative;
272                }
273
274                // sort plan needs to consider column priority
275                // Change Sort to MergeSort which assumes the input streams are already sorted hence can be more efficient
276                // We should ensure the number of partition is not smaller than the number of region at present. Otherwise this would result in incorrect output.
277                Commutativity::ConditionalCommutative(Some(Arc::new(merge_sort_transformer)))
278            }
279            LogicalPlan::Join(_) => Commutativity::NonCommutative,
280            LogicalPlan::Repartition(_) => {
281                // unsupported? or non-commutative
282                Commutativity::Unimplemented
283            }
284            LogicalPlan::Union(_) => Commutativity::Unimplemented,
285            LogicalPlan::TableScan(_) => Commutativity::Commutative,
286            LogicalPlan::EmptyRelation(_) => Commutativity::NonCommutative,
287            LogicalPlan::Subquery(_) => Commutativity::Unimplemented,
288            LogicalPlan::SubqueryAlias(_) => Commutativity::Unimplemented,
289            LogicalPlan::Limit(limit) => {
290                // Only execute `fetch` on remote nodes.
291                // wait for https://github.com/apache/arrow-datafusion/pull/7669
292                if partition_cols.is_empty() && limit.fetch.is_some() {
293                    Commutativity::Commutative
294                } else if limit.skip.is_none() && limit.fetch.is_some() {
295                    Commutativity::PartialCommutative
296                } else {
297                    Commutativity::Unimplemented
298                }
299            }
300            LogicalPlan::Extension(extension) => {
301                Self::check_extension_plan(extension.node.as_ref() as _, &partition_cols)
302            }
303            LogicalPlan::Distinct(_) => {
304                if partition_cols.is_empty() {
305                    Commutativity::Commutative
306                } else {
307                    Commutativity::Unimplemented
308                }
309            }
310            LogicalPlan::Unnest(_) => Commutativity::Commutative,
311            LogicalPlan::Statement(_) => Commutativity::Unsupported,
312            LogicalPlan::Values(_) => Commutativity::Unsupported,
313            LogicalPlan::Explain(_) => Commutativity::Unsupported,
314            LogicalPlan::Analyze(_) => Commutativity::Unsupported,
315            LogicalPlan::DescribeTable(_) => Commutativity::Unsupported,
316            LogicalPlan::Dml(_) => Commutativity::Unsupported,
317            LogicalPlan::Ddl(_) => Commutativity::Unsupported,
318            LogicalPlan::Copy(_) => Commutativity::Unsupported,
319            LogicalPlan::RecursiveQuery(_) => Commutativity::Unsupported,
320        }
321    }
322
323    pub fn check_extension_plan(
324        plan: &dyn UserDefinedLogicalNode,
325        partition_cols: &[String],
326    ) -> Commutativity {
327        match plan.name() {
328            name if name == SeriesDivide::name() => {
329                let series_divide = plan.as_any().downcast_ref::<SeriesDivide>().unwrap();
330                let tags = series_divide.tags().iter().collect::<HashSet<_>>();
331                for partition_col in partition_cols {
332                    if !tags.contains(partition_col) {
333                        return Commutativity::NonCommutative;
334                    }
335                }
336                Commutativity::Commutative
337            }
338            name if name == SeriesNormalize::name()
339                || name == InstantManipulate::name()
340                || name == RangeManipulate::name() =>
341            {
342                // They should always follows Series Divide.
343                // Either all commutative or all non-commutative (which will be blocked by SeriesDivide).
344                Commutativity::Commutative
345            }
346            name if name == EmptyMetric::name()
347                || name == MergeScanLogicalPlan::name()
348                || name == MergeSortLogicalPlan::name() =>
349            {
350                Commutativity::Unimplemented
351            }
352            _ => Commutativity::Unsupported,
353        }
354    }
355
356    pub fn check_expr(expr: &Expr) -> Commutativity {
357        match expr {
358            Expr::Column(_)
359            | Expr::ScalarVariable(_, _)
360            | Expr::Literal(_)
361            | Expr::BinaryExpr(_)
362            | Expr::Not(_)
363            | Expr::IsNotNull(_)
364            | Expr::IsNull(_)
365            | Expr::IsTrue(_)
366            | Expr::IsFalse(_)
367            | Expr::IsNotTrue(_)
368            | Expr::IsNotFalse(_)
369            | Expr::Negative(_)
370            | Expr::Between(_)
371            | Expr::Exists(_)
372            | Expr::InList(_)
373            | Expr::Case(_) => Commutativity::Commutative,
374            Expr::ScalarFunction(_udf) => Commutativity::Commutative,
375            Expr::AggregateFunction(_udaf) => Commutativity::Commutative,
376
377            Expr::Like(_)
378            | Expr::SimilarTo(_)
379            | Expr::IsUnknown(_)
380            | Expr::IsNotUnknown(_)
381            | Expr::Cast(_)
382            | Expr::TryCast(_)
383            | Expr::WindowFunction(_)
384            | Expr::InSubquery(_)
385            | Expr::ScalarSubquery(_)
386            | Expr::Wildcard { .. } => Commutativity::Unimplemented,
387
388            Expr::Alias(alias) => Self::check_expr(&alias.expr),
389
390            Expr::Unnest(_)
391            | Expr::GroupingSet(_)
392            | Expr::Placeholder(_)
393            | Expr::OuterReferenceColumn(_, _) => Commutativity::Unimplemented,
394        }
395    }
396
397    /// Return true if the given expr and partition cols satisfied the rule.
398    /// In this case the plan can be treated as fully commutative.
399    fn check_partition(exprs: &[Expr], partition_cols: &[String]) -> bool {
400        let mut ref_cols = HashSet::new();
401        for expr in exprs {
402            expr.add_column_refs(&mut ref_cols);
403        }
404        let ref_cols = ref_cols
405            .into_iter()
406            .map(|c| c.name.clone())
407            .collect::<HashSet<_>>();
408        for col in partition_cols {
409            if !ref_cols.contains(col) {
410                return false;
411            }
412        }
413
414        true
415    }
416}
417
418pub type Transformer = Arc<dyn Fn(&LogicalPlan) -> Option<LogicalPlan>>;
419
420/// Returns transformer action that need to be applied
421pub type StageTransformer = Arc<dyn Fn(&LogicalPlan) -> Option<TransformerAction>>;
422
423/// The Action that a transformer should take on the plan.
424pub struct TransformerAction {
425    /// list of plans that need to be applied to parent plans, in the order of parent to child.
426    /// i.e. if this returns `[Projection, Aggregate]`, then the parent plan should be transformed to
427    /// ```
428    /// Original Parent Plan:
429    ///     Projection:
430    ///         Aggregate:
431    ///             MergeScan: ...
432    /// ```
433    pub extra_parent_plans: Vec<LogicalPlan>,
434    /// new child plan, if None, use the original plan.
435    pub new_child_plan: Option<LogicalPlan>,
436}
437
438pub fn partial_commutative_transformer(plan: &LogicalPlan) -> Option<LogicalPlan> {
439    Some(plan.clone())
440}
441
#[cfg(test)]
mod test {
    use datafusion_expr::{LogicalPlanBuilder, Sort};

    use super::*;

    /// A `Sort` is classified as `Commutative` when the partition column list is empty.
    #[test]
    fn sort_on_empty_partition() {
        let empty_input = LogicalPlanBuilder::empty(false).build().unwrap();
        let sort_plan = LogicalPlan::Sort(Sort {
            expr: Vec::new(),
            input: Arc::new(empty_input),
            fetch: None,
        });

        let verdict = Categorizer::check_plan(&sort_plan, Some(Vec::new()));
        assert!(matches!(verdict, Commutativity::Commutative));
    }
}