query/dist_plan/
commutativity.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashSet;
16use std::sync::Arc;
17
18use common_function::aggrs::aggr_wrapper::{StateMergeHelper, is_all_aggr_exprs_steppable};
19use common_telemetry::debug;
20use datafusion::error::Result as DfResult;
21use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
22use datafusion_expr::{Expr, LogicalPlan, UserDefinedLogicalNode};
23use promql::extension_plan::{
24    EmptyMetric, InstantManipulate, RangeManipulate, SeriesDivide, SeriesNormalize,
25};
26
27use crate::dist_plan::MergeScanLogicalPlan;
28use crate::dist_plan::analyzer::AliasMapping;
29use crate::dist_plan::merge_sort::{MergeSortLogicalPlan, merge_sort_transformer};
30
/// Outcome of splitting an aggregate plan into two stages, as built by
/// [`step_aggr_to_upper_aggr`].
///
/// It carries the same two fields as [`TransformerAction`]; `Categorizer::check_plan`
/// copies them across field by field.
pub struct StepTransformAction {
    /// Plans to place above the rewritten child. `step_aggr_to_upper_aggr`
    /// stores exactly one entry here: the upper merge plan.
    extra_parent_plans: Vec<LogicalPlan>,
    /// Replacement for the original plan. `step_aggr_to_upper_aggr` always
    /// sets it to the lower state plan.
    new_child_plan: Option<LogicalPlan>,
}
35
36/// generate the upper aggregation plan that will execute on the frontend.
37/// Basically a logical plan resembling the following:
38/// Projection:
39///     Aggregate:
40///
41/// from Aggregate
42///
43/// The upper Projection exists sole to make sure parent plan can recognize the output
44/// of the upper aggregation plan.
45pub fn step_aggr_to_upper_aggr(
46    aggr_plan: &LogicalPlan,
47) -> datafusion_common::Result<StepTransformAction> {
48    let LogicalPlan::Aggregate(input_aggr) = aggr_plan else {
49        return Err(datafusion_common::DataFusionError::Plan(
50            "step_aggr_to_upper_aggr only accepts Aggregate plan".to_string(),
51        ));
52    };
53    if !is_all_aggr_exprs_steppable(&input_aggr.aggr_expr) {
54        return Err(datafusion_common::DataFusionError::NotImplemented(format!(
55            "Some aggregate expressions are not steppable in [{}]",
56            input_aggr
57                .aggr_expr
58                .iter()
59                .map(|e| e.to_string())
60                .collect::<Vec<_>>()
61                .join(", ")
62        )));
63    }
64
65    let step_aggr_plan = StateMergeHelper::split_aggr_node(input_aggr.clone())?;
66
67    // TODO(discord9): remove duplication
68    let ret = StepTransformAction {
69        extra_parent_plans: vec![step_aggr_plan.upper_merge.clone()],
70        new_child_plan: Some(step_aggr_plan.lower_state.clone()),
71    };
72    Ok(ret)
73}
74
/// Classification that [`Categorizer`] assigns to a plan node or expression.
// NOTE(review): the reason for `allow(dead_code)` is not visible in this file.
#[allow(dead_code)]
pub enum Commutativity {
    /// Returned for nodes such as `TableScan` and `SubqueryAlias`, and for
    /// expressions such as columns and literals.
    Commutative,
    /// Returned by `Categorizer::check_plan` for a `Limit` that has a `fetch`
    /// and no `skip` while partition columns are present.
    PartialCommutative,
    /// Carries an optional [`Transformer`]. `Categorizer::check_plan` returns it
    /// with `merge_sort_transformer` for a `Sort` when partition columns are
    /// present, and with `None` for an `Aggregate` whose group-by references
    /// every partition column and whose aggregate expressions are all
    /// `Commutative`.
    ConditionalCommutative(Option<Transformer>),
    /// Carries an optional [`StageTransformer`]. `Categorizer::check_plan`
    /// returns it for an `Aggregate` whose expressions are all steppable but
    /// whose group-by does not reference every partition column.
    TransformedCommutative {
        /// Return plans from parent to child order
        transformer: Option<StageTransformer>,
    },
    /// Returned for nodes such as `Join` and `EmptyRelation`.
    NonCommutative,
    /// Returned for nodes and expressions such as `Window`, `Union` and `Like`.
    Unimplemented,
    /// For unrelated plans like DDL
    Unsupported,
}
89
/// Field-less type that groups the commutativity checks; everything on it is
/// an associated function, so no instance is needed.
pub struct Categorizer {}
91
92impl Categorizer {
93    pub fn check_plan(
94        plan: &LogicalPlan,
95        partition_cols: Option<AliasMapping>,
96    ) -> DfResult<Commutativity> {
97        // Subquery is treated separately in `inspect_plan_with_subquery`. To avoid rewrite the
98        // "maybe rewritten" plan, stop the check here.
99        if has_subquery(plan)? {
100            return Ok(Commutativity::Unimplemented);
101        }
102
103        let partition_cols = partition_cols.unwrap_or_default();
104
105        let comm = match plan {
106            LogicalPlan::Projection(proj) => {
107                for expr in &proj.expr {
108                    let commutativity = Self::check_expr(expr);
109                    if !matches!(commutativity, Commutativity::Commutative) {
110                        return Ok(commutativity);
111                    }
112                }
113                Commutativity::Commutative
114            }
115            // TODO(ruihang): Change this to Commutative once Like is supported in substrait
116            LogicalPlan::Filter(filter) => Self::check_expr(&filter.predicate),
117            LogicalPlan::Window(_) => Commutativity::Unimplemented,
118            LogicalPlan::Aggregate(aggr) => {
119                let is_all_steppable = is_all_aggr_exprs_steppable(&aggr.aggr_expr);
120                let matches_partition = Self::check_partition(&aggr.group_expr, &partition_cols);
121                if !matches_partition && is_all_steppable {
122                    debug!("Plan is steppable: {plan}");
123                    return Ok(Commutativity::TransformedCommutative {
124                        transformer: Some(Arc::new(|plan: &LogicalPlan| {
125                            debug!("Before Step optimize: {plan}");
126                            let ret = step_aggr_to_upper_aggr(plan);
127                            ret.inspect_err(|err| {
128                                common_telemetry::error!("Failed to step aggregate plan: {err:?}");
129                            })
130                            .map(|s| TransformerAction {
131                                extra_parent_plans: s.extra_parent_plans,
132                                new_child_plan: s.new_child_plan,
133                            })
134                        })),
135                    });
136                }
137                if !matches_partition {
138                    return Ok(Commutativity::NonCommutative);
139                }
140                for expr in &aggr.aggr_expr {
141                    let commutativity = Self::check_expr(expr);
142                    if !matches!(commutativity, Commutativity::Commutative) {
143                        return Ok(commutativity);
144                    }
145                }
146                // all group by expressions are partition columns can push down, unless
147                // another push down(including `Limit` or `Sort`) is already in progress(which will then prvent next cond commutative node from being push down).
148                // TODO(discord9): This is a temporary solution(that works), a better description of
149                // commutativity is needed under this situation.
150                Commutativity::ConditionalCommutative(None)
151            }
152            LogicalPlan::Sort(_) => {
153                if partition_cols.is_empty() {
154                    return Ok(Commutativity::Commutative);
155                }
156
157                // sort plan needs to consider column priority
158                // Change Sort to MergeSort which assumes the input streams are already sorted hence can be more efficient
159                // We should ensure the number of partition is not smaller than the number of region at present. Otherwise this would result in incorrect output.
160                Commutativity::ConditionalCommutative(Some(Arc::new(merge_sort_transformer)))
161            }
162            LogicalPlan::Join(_) => Commutativity::NonCommutative,
163            LogicalPlan::Repartition(_) => {
164                // unsupported? or non-commutative
165                Commutativity::Unimplemented
166            }
167            LogicalPlan::Union(_) => Commutativity::Unimplemented,
168            LogicalPlan::TableScan(_) => Commutativity::Commutative,
169            LogicalPlan::EmptyRelation(_) => Commutativity::NonCommutative,
170            LogicalPlan::Subquery(_) => Commutativity::Unimplemented,
171            LogicalPlan::SubqueryAlias(_) => Commutativity::Commutative,
172            LogicalPlan::Limit(limit) => {
173                // Only execute `fetch` on remote nodes.
174                // wait for https://github.com/apache/arrow-datafusion/pull/7669
175                if partition_cols.is_empty() && limit.fetch.is_some() {
176                    Commutativity::Commutative
177                } else if limit.skip.is_none() && limit.fetch.is_some() {
178                    Commutativity::PartialCommutative
179                } else {
180                    Commutativity::Unimplemented
181                }
182            }
183            LogicalPlan::Extension(extension) => {
184                Self::check_extension_plan(extension.node.as_ref() as _, &partition_cols)
185            }
186            LogicalPlan::Distinct(_) => {
187                if partition_cols.is_empty() {
188                    Commutativity::Commutative
189                } else {
190                    Commutativity::Unimplemented
191                }
192            }
193            LogicalPlan::Unnest(_) => Commutativity::Commutative,
194            LogicalPlan::Statement(_) => Commutativity::Unsupported,
195            LogicalPlan::Values(_) => Commutativity::Unsupported,
196            LogicalPlan::Explain(_) => Commutativity::Unsupported,
197            LogicalPlan::Analyze(_) => Commutativity::Unsupported,
198            LogicalPlan::DescribeTable(_) => Commutativity::Unsupported,
199            LogicalPlan::Dml(_) => Commutativity::Unsupported,
200            LogicalPlan::Ddl(_) => Commutativity::Unsupported,
201            LogicalPlan::Copy(_) => Commutativity::Unsupported,
202            LogicalPlan::RecursiveQuery(_) => Commutativity::Unsupported,
203        };
204
205        Ok(comm)
206    }
207
208    pub fn check_extension_plan(
209        plan: &dyn UserDefinedLogicalNode,
210        partition_cols: &AliasMapping,
211    ) -> Commutativity {
212        match plan.name() {
213            name if name == SeriesDivide::name() => {
214                let series_divide = plan.as_any().downcast_ref::<SeriesDivide>().unwrap();
215                let tags = series_divide.tags().iter().collect::<HashSet<_>>();
216
217                for all_alias in partition_cols.values() {
218                    let all_alias = all_alias.iter().map(|c| &c.name).collect::<HashSet<_>>();
219                    if tags.intersection(&all_alias).count() == 0 {
220                        return Commutativity::NonCommutative;
221                    }
222                }
223
224                Commutativity::Commutative
225            }
226            name if name == SeriesNormalize::name()
227                || name == InstantManipulate::name()
228                || name == RangeManipulate::name() =>
229            {
230                // They should always follows Series Divide.
231                // Either all commutative or all non-commutative (which will be blocked by SeriesDivide).
232                Commutativity::Commutative
233            }
234            name if name == EmptyMetric::name()
235                || name == MergeScanLogicalPlan::name()
236                || name == MergeSortLogicalPlan::name() =>
237            {
238                Commutativity::Unimplemented
239            }
240            _ => Commutativity::Unsupported,
241        }
242    }
243
244    pub fn check_expr(expr: &Expr) -> Commutativity {
245        #[allow(deprecated)]
246        match expr {
247            Expr::Column(_)
248            | Expr::ScalarVariable(_, _)
249            | Expr::Literal(_, _)
250            | Expr::BinaryExpr(_)
251            | Expr::Not(_)
252            | Expr::IsNotNull(_)
253            | Expr::IsNull(_)
254            | Expr::IsTrue(_)
255            | Expr::IsFalse(_)
256            | Expr::IsNotTrue(_)
257            | Expr::IsNotFalse(_)
258            | Expr::Negative(_)
259            | Expr::Between(_)
260            | Expr::Exists(_)
261            | Expr::InList(_)
262            | Expr::Case(_) => Commutativity::Commutative,
263            Expr::ScalarFunction(_udf) => Commutativity::Commutative,
264            Expr::AggregateFunction(_udaf) => Commutativity::Commutative,
265
266            Expr::Like(_)
267            | Expr::SimilarTo(_)
268            | Expr::IsUnknown(_)
269            | Expr::IsNotUnknown(_)
270            | Expr::Cast(_)
271            | Expr::TryCast(_)
272            | Expr::WindowFunction(_)
273            | Expr::InSubquery(_)
274            | Expr::ScalarSubquery(_)
275            | Expr::Wildcard { .. } => Commutativity::Unimplemented,
276
277            Expr::Alias(alias) => Self::check_expr(&alias.expr),
278
279            Expr::Unnest(_)
280            | Expr::GroupingSet(_)
281            | Expr::Placeholder(_)
282            | Expr::OuterReferenceColumn(_, _) => Commutativity::Unimplemented,
283        }
284    }
285
286    /// Return true if the given expr and partition cols satisfied the rule.
287    /// In this case the plan can be treated as fully commutative.
288    ///
289    /// So only if all partition columns show up in `exprs`, return true.
290    /// Otherwise return false.
291    ///
292    fn check_partition(exprs: &[Expr], partition_cols: &AliasMapping) -> bool {
293        let mut ref_cols = HashSet::new();
294        for expr in exprs {
295            expr.add_column_refs(&mut ref_cols);
296        }
297        let ref_cols = ref_cols
298            .into_iter()
299            .map(|c| c.name.clone())
300            .collect::<HashSet<_>>();
301        for all_alias in partition_cols.values() {
302            let all_alias = all_alias
303                .iter()
304                .map(|c| c.name.clone())
305                .collect::<HashSet<_>>();
306            // check if ref columns intersect with all alias of partition columns
307            // is empty, if it's empty, not all partition columns show up in `exprs`
308            if ref_cols.intersection(&all_alias).count() == 0 {
309                return false;
310            }
311        }
312
313        true
314    }
315}
316
/// Shared callback that takes a plan node and optionally returns a plan.
/// Carried by [`Commutativity::ConditionalCommutative`].
pub type Transformer = Arc<dyn Fn(&LogicalPlan) -> Option<LogicalPlan>>;
318
/// Shared callback that takes a plan node and returns the [`TransformerAction`]
/// that needs to be applied, or an error.
/// Carried by [`Commutativity::TransformedCommutative`].
pub type StageTransformer = Arc<dyn Fn(&LogicalPlan) -> DfResult<TransformerAction>>;
321
/// The action that a transformer should take on the plan.
pub struct TransformerAction {
    /// List of plans that need to be applied as parent plans, in parent-to-child order.
    /// For example, if this is `[Projection, Aggregate]`, the parent plan should be
    /// transformed to
    /// ```ignore
    /// Original Parent Plan:
    ///     Projection:
    ///         Aggregate:
    ///             MergeScan: ...
    /// ```
    pub extra_parent_plans: Vec<LogicalPlan>,
    /// New child plan; if `None`, use the original plan.
    pub new_child_plan: Option<LogicalPlan>,
}
336
337pub fn partial_commutative_transformer(plan: &LogicalPlan) -> Option<LogicalPlan> {
338    Some(plan.clone())
339}
340
341fn has_subquery(plan: &LogicalPlan) -> DfResult<bool> {
342    let mut found = false;
343    plan.apply_expressions(|e| {
344        e.apply(|x| {
345            if matches!(
346                x,
347                Expr::Exists(_) | Expr::InSubquery(_) | Expr::ScalarSubquery(_)
348            ) {
349                found = true;
350                Ok(TreeNodeRecursion::Stop)
351            } else {
352                Ok(TreeNodeRecursion::Continue)
353            }
354        })
355    })?;
356    Ok(found)
357}
358
#[cfg(test)]
mod test {
    use datafusion_expr::{LogicalPlanBuilder, Sort};

    use super::*;

    /// A `Sort` checked against an empty partition mapping is classified as
    /// `Commutative`.
    #[test]
    fn sort_on_empty_partition() {
        let empty_input = LogicalPlanBuilder::empty(false).build().unwrap();
        let sort_plan = LogicalPlan::Sort(Sort {
            expr: vec![],
            input: Arc::new(empty_input),
            fetch: None,
        });

        let verdict = Categorizer::check_plan(&sort_plan, Some(Default::default())).unwrap();

        assert!(matches!(verdict, Commutativity::Commutative));
    }
}