query/dist_plan/
commutativity.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::collections::HashSet;
16use std::sync::Arc;
17
18use datafusion_expr::{Expr, LogicalPlan, UserDefinedLogicalNode};
19use promql::extension_plan::{
20    EmptyMetric, InstantManipulate, RangeManipulate, SeriesDivide, SeriesNormalize,
21};
22
23use crate::dist_plan::merge_sort::{merge_sort_transformer, MergeSortLogicalPlan};
24use crate::dist_plan::MergeScanLogicalPlan;
25
/// Verdict returned by [`Categorizer`] for a single logical plan node or expression.
///
/// NOTE(review): the exact contract of each variant is decided by the planner that
/// consumes this value (not visible here); the notes below only describe where this
/// file produces each variant.
// NOTE(review): presumably `dead_code` is allowed because not every variant is
// constructed yet — `TransformedCommutative` is never produced in this file.
#[allow(dead_code)]
pub enum Commutativity {
    /// Produced for e.g. `TableScan` and `Unnest`, for `Sort`/`Distinct` when there are
    /// no partition columns, and by `check_expr` for the simple expression kinds.
    Commutative,
    /// Produced for a `Limit` that has a `fetch` but no `skip` when partition columns
    /// exist (per the comment in `check_plan`, only `fetch` is executed on remote nodes).
    PartialCommutative,
    /// Commutative subject to the attached optional [`Transformer`]; produced for `Sort`
    /// when partition columns exist, paired with the merge-sort transformer.
    ConditionalCommutative(Option<Transformer>),
    /// Carries an optional [`Transformer`]; not produced anywhere in this file.
    TransformedCommutative(Option<Transformer>),
    /// Produced for e.g. `Join`, `EmptyRelation`, an `Aggregate` whose group-by does not
    /// reference every partition column, and a `SeriesDivide` whose tags miss a
    /// partition column.
    NonCommutative,
    /// Produced for plan nodes and expressions this categorizer does not handle yet
    /// (e.g. `Window`, `Union`, subqueries, `Cast`, `Like`).
    Unimplemented,
    /// For unrelated plans like DDL
    Unsupported,
}
37
/// Stateless namespace for the commutativity checks; every check is an associated
/// function that takes no `self`.
pub struct Categorizer {}
39
40impl Categorizer {
41    pub fn check_plan(plan: &LogicalPlan, partition_cols: Option<Vec<String>>) -> Commutativity {
42        let partition_cols = partition_cols.unwrap_or_default();
43
44        match plan {
45            LogicalPlan::Projection(proj) => {
46                for expr in &proj.expr {
47                    let commutativity = Self::check_expr(expr);
48                    if !matches!(commutativity, Commutativity::Commutative) {
49                        return commutativity;
50                    }
51                }
52                Commutativity::Commutative
53            }
54            // TODO(ruihang): Change this to Commutative once Like is supported in substrait
55            LogicalPlan::Filter(filter) => Self::check_expr(&filter.predicate),
56            LogicalPlan::Window(_) => Commutativity::Unimplemented,
57            LogicalPlan::Aggregate(aggr) => {
58                if !Self::check_partition(&aggr.group_expr, &partition_cols) {
59                    return Commutativity::NonCommutative;
60                }
61                for expr in &aggr.aggr_expr {
62                    let commutativity = Self::check_expr(expr);
63                    if !matches!(commutativity, Commutativity::Commutative) {
64                        return commutativity;
65                    }
66                }
67                Commutativity::Commutative
68            }
69            LogicalPlan::Sort(_) => {
70                if partition_cols.is_empty() {
71                    return Commutativity::Commutative;
72                }
73
74                // sort plan needs to consider column priority
75                // Change Sort to MergeSort which assumes the input streams are already sorted hence can be more efficient
76                // We should ensure the number of partition is not smaller than the number of region at present. Otherwise this would result in incorrect output.
77                Commutativity::ConditionalCommutative(Some(Arc::new(merge_sort_transformer)))
78            }
79            LogicalPlan::Join(_) => Commutativity::NonCommutative,
80            LogicalPlan::Repartition(_) => {
81                // unsupported? or non-commutative
82                Commutativity::Unimplemented
83            }
84            LogicalPlan::Union(_) => Commutativity::Unimplemented,
85            LogicalPlan::TableScan(_) => Commutativity::Commutative,
86            LogicalPlan::EmptyRelation(_) => Commutativity::NonCommutative,
87            LogicalPlan::Subquery(_) => Commutativity::Unimplemented,
88            LogicalPlan::SubqueryAlias(_) => Commutativity::Unimplemented,
89            LogicalPlan::Limit(limit) => {
90                // Only execute `fetch` on remote nodes.
91                // wait for https://github.com/apache/arrow-datafusion/pull/7669
92                if partition_cols.is_empty() && limit.fetch.is_some() {
93                    Commutativity::Commutative
94                } else if limit.skip.is_none() && limit.fetch.is_some() {
95                    Commutativity::PartialCommutative
96                } else {
97                    Commutativity::Unimplemented
98                }
99            }
100            LogicalPlan::Extension(extension) => {
101                Self::check_extension_plan(extension.node.as_ref() as _, &partition_cols)
102            }
103            LogicalPlan::Distinct(_) => {
104                if partition_cols.is_empty() {
105                    Commutativity::Commutative
106                } else {
107                    Commutativity::Unimplemented
108                }
109            }
110            LogicalPlan::Unnest(_) => Commutativity::Commutative,
111            LogicalPlan::Statement(_) => Commutativity::Unsupported,
112            LogicalPlan::Values(_) => Commutativity::Unsupported,
113            LogicalPlan::Explain(_) => Commutativity::Unsupported,
114            LogicalPlan::Analyze(_) => Commutativity::Unsupported,
115            LogicalPlan::DescribeTable(_) => Commutativity::Unsupported,
116            LogicalPlan::Dml(_) => Commutativity::Unsupported,
117            LogicalPlan::Ddl(_) => Commutativity::Unsupported,
118            LogicalPlan::Copy(_) => Commutativity::Unsupported,
119            LogicalPlan::RecursiveQuery(_) => Commutativity::Unsupported,
120        }
121    }
122
123    pub fn check_extension_plan(
124        plan: &dyn UserDefinedLogicalNode,
125        partition_cols: &[String],
126    ) -> Commutativity {
127        match plan.name() {
128            name if name == SeriesDivide::name() => {
129                let series_divide = plan.as_any().downcast_ref::<SeriesDivide>().unwrap();
130                let tags = series_divide.tags().iter().collect::<HashSet<_>>();
131                for partition_col in partition_cols {
132                    if !tags.contains(partition_col) {
133                        return Commutativity::NonCommutative;
134                    }
135                }
136                Commutativity::Commutative
137            }
138            name if name == SeriesNormalize::name()
139                || name == InstantManipulate::name()
140                || name == RangeManipulate::name() =>
141            {
142                // They should always follows Series Divide.
143                // Either all commutative or all non-commutative (which will be blocked by SeriesDivide).
144                Commutativity::Commutative
145            }
146            name if name == EmptyMetric::name()
147                || name == MergeScanLogicalPlan::name()
148                || name == MergeSortLogicalPlan::name() =>
149            {
150                Commutativity::Unimplemented
151            }
152            _ => Commutativity::Unsupported,
153        }
154    }
155
156    pub fn check_expr(expr: &Expr) -> Commutativity {
157        match expr {
158            Expr::Column(_)
159            | Expr::ScalarVariable(_, _)
160            | Expr::Literal(_)
161            | Expr::BinaryExpr(_)
162            | Expr::Not(_)
163            | Expr::IsNotNull(_)
164            | Expr::IsNull(_)
165            | Expr::IsTrue(_)
166            | Expr::IsFalse(_)
167            | Expr::IsNotTrue(_)
168            | Expr::IsNotFalse(_)
169            | Expr::Negative(_)
170            | Expr::Between(_)
171            | Expr::Exists(_)
172            | Expr::InList(_) => Commutativity::Commutative,
173            Expr::ScalarFunction(_udf) => Commutativity::Commutative,
174            Expr::AggregateFunction(_udaf) => Commutativity::Commutative,
175
176            Expr::Like(_)
177            | Expr::SimilarTo(_)
178            | Expr::IsUnknown(_)
179            | Expr::IsNotUnknown(_)
180            | Expr::Case(_)
181            | Expr::Cast(_)
182            | Expr::TryCast(_)
183            | Expr::WindowFunction(_)
184            | Expr::InSubquery(_)
185            | Expr::ScalarSubquery(_)
186            | Expr::Wildcard { .. } => Commutativity::Unimplemented,
187
188            Expr::Alias(alias) => Self::check_expr(&alias.expr),
189
190            Expr::Unnest(_)
191            | Expr::GroupingSet(_)
192            | Expr::Placeholder(_)
193            | Expr::OuterReferenceColumn(_, _) => Commutativity::Unimplemented,
194        }
195    }
196
197    /// Return true if the given expr and partition cols satisfied the rule.
198    /// In this case the plan can be treated as fully commutative.
199    fn check_partition(exprs: &[Expr], partition_cols: &[String]) -> bool {
200        let mut ref_cols = HashSet::new();
201        for expr in exprs {
202            expr.add_column_refs(&mut ref_cols);
203        }
204        let ref_cols = ref_cols
205            .into_iter()
206            .map(|c| c.name.clone())
207            .collect::<HashSet<_>>();
208        for col in partition_cols {
209            if !ref_cols.contains(col) {
210                return false;
211            }
212        }
213
214        true
215    }
216}
217
/// Optional plan rewrite carried by [`Commutativity::ConditionalCommutative`] and
/// [`Commutativity::TransformedCommutative`]. It takes a plan node by reference and
/// returns an optional replacement plan.
/// NOTE(review): how the consumer interprets a `None` result is not visible here — TODO confirm.
pub type Transformer = Arc<dyn Fn(&LogicalPlan) -> Option<LogicalPlan>>;
219
220pub fn partial_commutative_transformer(plan: &LogicalPlan) -> Option<LogicalPlan> {
221    Some(plan.clone())
222}
223
#[cfg(test)]
mod test {
    use datafusion_expr::{LogicalPlanBuilder, Sort};

    use super::*;

    /// Builds a `Sort` with no sort keys on top of an empty relation.
    fn keyless_sort_over_empty_relation() -> LogicalPlan {
        let empty_input = LogicalPlanBuilder::empty(false).build().unwrap();
        LogicalPlan::Sort(Sort {
            expr: Vec::new(),
            input: Arc::new(empty_input),
            fetch: None,
        })
    }

    #[test]
    fn sort_on_empty_partition() {
        // With an empty partition column list, `Sort` is classified as `Commutative`.
        let plan = keyless_sort_over_empty_relation();
        let verdict = Categorizer::check_plan(&plan, Some(Vec::new()));
        assert!(matches!(verdict, Commutativity::Commutative));
    }
}