query/dist_plan/
commutativity.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15use std::collections::HashSet;
16use std::sync::Arc;
17
18use common_function::aggrs::aggr_wrapper::{aggr_state_func_name, StateMergeHelper};
19use common_function::function_registry::FUNCTION_REGISTRY;
20use common_telemetry::debug;
21use datafusion_expr::{Expr, LogicalPlan, UserDefinedLogicalNode};
22use promql::extension_plan::{
23    EmptyMetric, InstantManipulate, RangeManipulate, SeriesDivide, SeriesNormalize,
24};
25
26use crate::dist_plan::analyzer::AliasMapping;
27use crate::dist_plan::merge_sort::{merge_sort_transformer, MergeSortLogicalPlan};
28use crate::dist_plan::MergeScanLogicalPlan;
29
30pub struct StepTransformAction {
31    extra_parent_plans: Vec<LogicalPlan>,
32    new_child_plan: Option<LogicalPlan>,
33}
34
35/// generate the upper aggregation plan that will execute on the frontend.
36/// Basically a logical plan resembling the following:
37/// Projection:
38///     Aggregate:
39///
40/// from Aggregate
41///
42/// The upper Projection exists sole to make sure parent plan can recognize the output
43/// of the upper aggregation plan.
44pub fn step_aggr_to_upper_aggr(
45    aggr_plan: &LogicalPlan,
46) -> datafusion_common::Result<StepTransformAction> {
47    let LogicalPlan::Aggregate(input_aggr) = aggr_plan else {
48        return Err(datafusion_common::DataFusionError::Plan(
49            "step_aggr_to_upper_aggr only accepts Aggregate plan".to_string(),
50        ));
51    };
52    if !is_all_aggr_exprs_steppable(&input_aggr.aggr_expr) {
53        return Err(datafusion_common::DataFusionError::NotImplemented(format!(
54            "Some aggregate expressions are not steppable in [{}]",
55            input_aggr
56                .aggr_expr
57                .iter()
58                .map(|e| e.to_string())
59                .collect::<Vec<_>>()
60                .join(", ")
61        )));
62    }
63
64    let step_aggr_plan = StateMergeHelper::split_aggr_node(input_aggr.clone())?;
65
66    // TODO(discord9): remove duplication
67    let ret = StepTransformAction {
68        extra_parent_plans: vec![step_aggr_plan.upper_merge.clone()],
69        new_child_plan: Some(step_aggr_plan.lower_state.clone()),
70    };
71    Ok(ret)
72}
73
74/// Check if the given aggregate expression is steppable.
75/// As in if it can be split into multiple steps:
76/// i.e. on datanode first call `state(input)` then
77/// on frontend call `calc(merge(state))` to get the final result.
78pub fn is_all_aggr_exprs_steppable(aggr_exprs: &[Expr]) -> bool {
79    aggr_exprs.iter().all(|expr| {
80        if let Some(aggr_func) = get_aggr_func(expr) {
81            if aggr_func.params.distinct {
82                // Distinct aggregate functions are not steppable(yet).
83                return false;
84            }
85
86            // whether the corresponding state function exists in the registry
87            FUNCTION_REGISTRY.is_aggr_func_exist(&aggr_state_func_name(aggr_func.func.name()))
88        } else {
89            false
90        }
91    })
92}
93
94pub fn get_aggr_func(expr: &Expr) -> Option<&datafusion_expr::expr::AggregateFunction> {
95    let mut expr_ref = expr;
96    while let Expr::Alias(alias) = expr_ref {
97        expr_ref = &alias.expr;
98    }
99    if let Expr::AggregateFunction(aggr_func) = expr_ref {
100        Some(aggr_func)
101    } else {
102        None
103    }
104}
105
106#[allow(dead_code)]
107pub enum Commutativity {
108    Commutative,
109    PartialCommutative,
110    ConditionalCommutative(Option<Transformer>),
111    TransformedCommutative {
112        /// Return plans from parent to child order
113        transformer: Option<StageTransformer>,
114    },
115    NonCommutative,
116    Unimplemented,
117    /// For unrelated plans like DDL
118    Unsupported,
119}
120
/// Stateless classifier that decides the [`Commutativity`] of logical plans and expressions.
pub struct Categorizer {}
122
123impl Categorizer {
124    pub fn check_plan(plan: &LogicalPlan, partition_cols: Option<AliasMapping>) -> Commutativity {
125        let partition_cols = partition_cols.unwrap_or_default();
126
127        match plan {
128            LogicalPlan::Projection(proj) => {
129                for expr in &proj.expr {
130                    let commutativity = Self::check_expr(expr);
131                    if !matches!(commutativity, Commutativity::Commutative) {
132                        return commutativity;
133                    }
134                }
135                Commutativity::Commutative
136            }
137            // TODO(ruihang): Change this to Commutative once Like is supported in substrait
138            LogicalPlan::Filter(filter) => Self::check_expr(&filter.predicate),
139            LogicalPlan::Window(_) => Commutativity::Unimplemented,
140            LogicalPlan::Aggregate(aggr) => {
141                let is_all_steppable = is_all_aggr_exprs_steppable(&aggr.aggr_expr);
142                let matches_partition = Self::check_partition(&aggr.group_expr, &partition_cols);
143                if !matches_partition && is_all_steppable {
144                    debug!("Plan is steppable: {plan}");
145                    return Commutativity::TransformedCommutative {
146                        transformer: Some(Arc::new(|plan: &LogicalPlan| {
147                            debug!("Before Step optimize: {plan}");
148                            let ret = step_aggr_to_upper_aggr(plan);
149                            ret.ok().map(|s| TransformerAction {
150                                extra_parent_plans: s.extra_parent_plans,
151                                new_child_plan: s.new_child_plan,
152                            })
153                        })),
154                    };
155                }
156                if !matches_partition {
157                    return Commutativity::NonCommutative;
158                }
159                for expr in &aggr.aggr_expr {
160                    let commutativity = Self::check_expr(expr);
161                    if !matches!(commutativity, Commutativity::Commutative) {
162                        return commutativity;
163                    }
164                }
165                // all group by expressions are partition columns can push down, unless
166                // another push down(including `Limit` or `Sort`) is already in progress(which will then prvent next cond commutative node from being push down).
167                // TODO(discord9): This is a temporary solution(that works), a better description of
168                // commutativity is needed under this situation.
169                Commutativity::ConditionalCommutative(None)
170            }
171            LogicalPlan::Sort(_) => {
172                if partition_cols.is_empty() {
173                    return Commutativity::Commutative;
174                }
175
176                // sort plan needs to consider column priority
177                // Change Sort to MergeSort which assumes the input streams are already sorted hence can be more efficient
178                // We should ensure the number of partition is not smaller than the number of region at present. Otherwise this would result in incorrect output.
179                Commutativity::ConditionalCommutative(Some(Arc::new(merge_sort_transformer)))
180            }
181            LogicalPlan::Join(_) => Commutativity::NonCommutative,
182            LogicalPlan::Repartition(_) => {
183                // unsupported? or non-commutative
184                Commutativity::Unimplemented
185            }
186            LogicalPlan::Union(_) => Commutativity::Unimplemented,
187            LogicalPlan::TableScan(_) => Commutativity::Commutative,
188            LogicalPlan::EmptyRelation(_) => Commutativity::NonCommutative,
189            LogicalPlan::Subquery(_) => Commutativity::Unimplemented,
190            LogicalPlan::SubqueryAlias(_) => Commutativity::Unimplemented,
191            LogicalPlan::Limit(limit) => {
192                // Only execute `fetch` on remote nodes.
193                // wait for https://github.com/apache/arrow-datafusion/pull/7669
194                if partition_cols.is_empty() && limit.fetch.is_some() {
195                    Commutativity::Commutative
196                } else if limit.skip.is_none() && limit.fetch.is_some() {
197                    Commutativity::PartialCommutative
198                } else {
199                    Commutativity::Unimplemented
200                }
201            }
202            LogicalPlan::Extension(extension) => {
203                Self::check_extension_plan(extension.node.as_ref() as _, &partition_cols)
204            }
205            LogicalPlan::Distinct(_) => {
206                if partition_cols.is_empty() {
207                    Commutativity::Commutative
208                } else {
209                    Commutativity::Unimplemented
210                }
211            }
212            LogicalPlan::Unnest(_) => Commutativity::Commutative,
213            LogicalPlan::Statement(_) => Commutativity::Unsupported,
214            LogicalPlan::Values(_) => Commutativity::Unsupported,
215            LogicalPlan::Explain(_) => Commutativity::Unsupported,
216            LogicalPlan::Analyze(_) => Commutativity::Unsupported,
217            LogicalPlan::DescribeTable(_) => Commutativity::Unsupported,
218            LogicalPlan::Dml(_) => Commutativity::Unsupported,
219            LogicalPlan::Ddl(_) => Commutativity::Unsupported,
220            LogicalPlan::Copy(_) => Commutativity::Unsupported,
221            LogicalPlan::RecursiveQuery(_) => Commutativity::Unsupported,
222        }
223    }
224
225    pub fn check_extension_plan(
226        plan: &dyn UserDefinedLogicalNode,
227        partition_cols: &AliasMapping,
228    ) -> Commutativity {
229        match plan.name() {
230            name if name == SeriesDivide::name() => {
231                let series_divide = plan.as_any().downcast_ref::<SeriesDivide>().unwrap();
232                let tags = series_divide.tags().iter().collect::<HashSet<_>>();
233
234                for all_alias in partition_cols.values() {
235                    let all_alias = all_alias.iter().map(|c| &c.name).collect::<HashSet<_>>();
236                    if tags.intersection(&all_alias).count() == 0 {
237                        return Commutativity::NonCommutative;
238                    }
239                }
240
241                Commutativity::Commutative
242            }
243            name if name == SeriesNormalize::name()
244                || name == InstantManipulate::name()
245                || name == RangeManipulate::name() =>
246            {
247                // They should always follows Series Divide.
248                // Either all commutative or all non-commutative (which will be blocked by SeriesDivide).
249                Commutativity::Commutative
250            }
251            name if name == EmptyMetric::name()
252                || name == MergeScanLogicalPlan::name()
253                || name == MergeSortLogicalPlan::name() =>
254            {
255                Commutativity::Unimplemented
256            }
257            _ => Commutativity::Unsupported,
258        }
259    }
260
261    pub fn check_expr(expr: &Expr) -> Commutativity {
262        #[allow(deprecated)]
263        match expr {
264            Expr::Column(_)
265            | Expr::ScalarVariable(_, _)
266            | Expr::Literal(_, _)
267            | Expr::BinaryExpr(_)
268            | Expr::Not(_)
269            | Expr::IsNotNull(_)
270            | Expr::IsNull(_)
271            | Expr::IsTrue(_)
272            | Expr::IsFalse(_)
273            | Expr::IsNotTrue(_)
274            | Expr::IsNotFalse(_)
275            | Expr::Negative(_)
276            | Expr::Between(_)
277            | Expr::Exists(_)
278            | Expr::InList(_)
279            | Expr::Case(_) => Commutativity::Commutative,
280            Expr::ScalarFunction(_udf) => Commutativity::Commutative,
281            Expr::AggregateFunction(_udaf) => Commutativity::Commutative,
282
283            Expr::Like(_)
284            | Expr::SimilarTo(_)
285            | Expr::IsUnknown(_)
286            | Expr::IsNotUnknown(_)
287            | Expr::Cast(_)
288            | Expr::TryCast(_)
289            | Expr::WindowFunction(_)
290            | Expr::InSubquery(_)
291            | Expr::ScalarSubquery(_)
292            | Expr::Wildcard { .. } => Commutativity::Unimplemented,
293
294            Expr::Alias(alias) => Self::check_expr(&alias.expr),
295
296            Expr::Unnest(_)
297            | Expr::GroupingSet(_)
298            | Expr::Placeholder(_)
299            | Expr::OuterReferenceColumn(_, _) => Commutativity::Unimplemented,
300        }
301    }
302
303    /// Return true if the given expr and partition cols satisfied the rule.
304    /// In this case the plan can be treated as fully commutative.
305    fn check_partition(exprs: &[Expr], partition_cols: &AliasMapping) -> bool {
306        let mut ref_cols = HashSet::new();
307        for expr in exprs {
308            expr.add_column_refs(&mut ref_cols);
309        }
310        let ref_cols = ref_cols
311            .into_iter()
312            .map(|c| c.name.clone())
313            .collect::<HashSet<_>>();
314        for all_alias in partition_cols.values() {
315            let all_alias = all_alias
316                .iter()
317                .map(|c| c.name.clone())
318                .collect::<HashSet<_>>();
319            // check if ref columns intersect with all alias of partition columns
320            // is empty, if it's empty, not all partition columns show up in `exprs`
321            if ref_cols.intersection(&all_alias).count() == 0 {
322                return false;
323            }
324        }
325
326        true
327    }
328}
329
330pub type Transformer = Arc<dyn Fn(&LogicalPlan) -> Option<LogicalPlan>>;
331
332/// Returns transformer action that need to be applied
333pub type StageTransformer = Arc<dyn Fn(&LogicalPlan) -> Option<TransformerAction>>;
334
335/// The Action that a transformer should take on the plan.
336pub struct TransformerAction {
337    /// list of plans that need to be applied to parent plans, in the order of parent to child.
338    /// i.e. if this returns `[Projection, Aggregate]`, then the parent plan should be transformed to
339    /// ```ignore
340    /// Original Parent Plan:
341    ///     Projection:
342    ///         Aggregate:
343    ///             MergeScan: ...
344    /// ```
345    pub extra_parent_plans: Vec<LogicalPlan>,
346    /// new child plan, if None, use the original plan.
347    pub new_child_plan: Option<LogicalPlan>,
348}
349
350pub fn partial_commutative_transformer(plan: &LogicalPlan) -> Option<LogicalPlan> {
351    Some(plan.clone())
352}
353
#[cfg(test)]
mod test {
    use datafusion_expr::{col, lit, LogicalPlanBuilder, Sort};

    use super::*;

    /// With an empty partition mapping, a sort can be pushed down as-is.
    #[test]
    fn sort_on_empty_partition() {
        let plan = LogicalPlan::Sort(Sort {
            expr: vec![],
            input: Arc::new(LogicalPlanBuilder::empty(false).build().unwrap()),
            fetch: None,
        });
        assert!(matches!(
            Categorizer::check_plan(&plan, Some(Default::default())),
            Commutativity::Commutative
        ));
    }

    /// Without partition columns (`None` mapping), a `Limit` with `fetch` is commutative.
    #[test]
    fn limit_on_empty_partition() {
        let plan = LogicalPlanBuilder::empty(false)
            .limit(0, Some(10))
            .unwrap()
            .build()
            .unwrap();
        assert!(matches!(
            Categorizer::check_plan(&plan, None),
            Commutativity::Commutative
        ));
    }

    /// Aliases are looked through, while `LIKE` is not supported yet.
    #[test]
    fn check_expr_alias_and_like() {
        assert!(matches!(
            Categorizer::check_expr(&col("a").alias("b")),
            Commutativity::Commutative
        ));
        assert!(matches!(
            Categorizer::check_expr(&col("a").like(lit("x%"))),
            Commutativity::Unimplemented
        ));
    }
}