query/dist_plan/
commutativity.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashSet;
16use std::sync::Arc;
17
18use common_function::aggrs::aggr_wrapper::{StateMergeHelper, is_all_aggr_exprs_steppable};
19use common_telemetry::debug;
20use datafusion::error::Result as DfResult;
21use datafusion_expr::{Expr, LogicalPlan, UserDefinedLogicalNode};
22use promql::extension_plan::{
23    EmptyMetric, InstantManipulate, RangeManipulate, SeriesDivide, SeriesNormalize,
24};
25
26use crate::dist_plan::MergeScanLogicalPlan;
27use crate::dist_plan::analyzer::AliasMapping;
28use crate::dist_plan::merge_sort::{MergeSortLogicalPlan, merge_sort_transformer};
29
30pub struct StepTransformAction {
31    extra_parent_plans: Vec<LogicalPlan>,
32    new_child_plan: Option<LogicalPlan>,
33}
34
35/// generate the upper aggregation plan that will execute on the frontend.
36/// Basically a logical plan resembling the following:
37/// Projection:
38///     Aggregate:
39///
40/// from Aggregate
41///
42/// The upper Projection exists sole to make sure parent plan can recognize the output
43/// of the upper aggregation plan.
44pub fn step_aggr_to_upper_aggr(
45    aggr_plan: &LogicalPlan,
46) -> datafusion_common::Result<StepTransformAction> {
47    let LogicalPlan::Aggregate(input_aggr) = aggr_plan else {
48        return Err(datafusion_common::DataFusionError::Plan(
49            "step_aggr_to_upper_aggr only accepts Aggregate plan".to_string(),
50        ));
51    };
52    if !is_all_aggr_exprs_steppable(&input_aggr.aggr_expr) {
53        return Err(datafusion_common::DataFusionError::NotImplemented(format!(
54            "Some aggregate expressions are not steppable in [{}]",
55            input_aggr
56                .aggr_expr
57                .iter()
58                .map(|e| e.to_string())
59                .collect::<Vec<_>>()
60                .join(", ")
61        )));
62    }
63
64    let step_aggr_plan = StateMergeHelper::split_aggr_node(input_aggr.clone())?;
65
66    // TODO(discord9): remove duplication
67    let ret = StepTransformAction {
68        extra_parent_plans: vec![step_aggr_plan.upper_merge.clone()],
69        new_child_plan: Some(step_aggr_plan.lower_state.clone()),
70    };
71    Ok(ret)
72}
73
/// Classification produced by [`Categorizer`] for a single logical plan node: whether,
/// and under which conditions, the node may be moved across the distribution boundary.
/// NOTE(review): presumably this means evaluated on the remote side instead of above
/// `MergeScan`; the consumer of this enum is not in this file.
// NOTE(review): the reason for `allow(dead_code)` is not visible in this file — confirm
// it is still needed.
#[allow(dead_code)]
pub enum Commutativity {
    /// The node can be moved unchanged. Returned e.g. for `TableScan`, `SubqueryAlias`,
    /// `Unnest`, and for `Sort` / `Limit` (with `fetch`) / `Distinct` when there are no
    /// partition columns.
    Commutative,
    /// Only part of the node may be evaluated remotely. Returned for a `Limit` with a
    /// `fetch` and no `skip` when partition columns are present.
    PartialCommutative,
    /// Commutative under a condition checked by the caller, with an optional
    /// [`Transformer`] to apply to the node. Returned for `Sort` (with
    /// `merge_sort_transformer`) when partition columns are present, and (with `None`)
    /// for an `Aggregate` whose group-by covers all partition columns.
    ConditionalCommutative(Option<Transformer>),
    /// Commutative only after the node is rewritten by a [`StageTransformer`]; used to
    /// split steppable `Aggregate`s via `step_aggr_to_upper_aggr`.
    TransformedCommutative {
        /// Produces the rewrite; the returned plans are in parent-to-child order.
        transformer: Option<StageTransformer>,
    },
    /// The node must not be moved, e.g. `Join`, `EmptyRelation`, or an `Aggregate` that
    /// neither covers the partition columns nor is steppable.
    NonCommutative,
    /// Commutativity has not been worked out for this node yet (e.g. `Window`, `Union`,
    /// `Subquery`, some expressions such as `Like`).
    Unimplemented,
    /// For unrelated plans like DDL (and for unrecognized extension nodes).
    Unsupported,
}
88
89pub struct Categorizer {}
90
91impl Categorizer {
92    pub fn check_plan(
93        plan: &LogicalPlan,
94        partition_cols: Option<AliasMapping>,
95    ) -> DfResult<Commutativity> {
96        let partition_cols = partition_cols.unwrap_or_default();
97
98        let comm = match plan {
99            LogicalPlan::Projection(proj) => {
100                for expr in &proj.expr {
101                    let commutativity = Self::check_expr(expr);
102                    if !matches!(commutativity, Commutativity::Commutative) {
103                        return Ok(commutativity);
104                    }
105                }
106                Commutativity::Commutative
107            }
108            // TODO(ruihang): Change this to Commutative once Like is supported in substrait
109            LogicalPlan::Filter(filter) => Self::check_expr(&filter.predicate),
110            LogicalPlan::Window(_) => Commutativity::Unimplemented,
111            LogicalPlan::Aggregate(aggr) => {
112                let is_all_steppable = is_all_aggr_exprs_steppable(&aggr.aggr_expr);
113                let matches_partition = Self::check_partition(&aggr.group_expr, &partition_cols);
114                if !matches_partition && is_all_steppable {
115                    debug!("Plan is steppable: {plan}");
116                    return Ok(Commutativity::TransformedCommutative {
117                        transformer: Some(Arc::new(|plan: &LogicalPlan| {
118                            debug!("Before Step optimize: {plan}");
119                            let ret = step_aggr_to_upper_aggr(plan);
120                            ret.inspect_err(|err| {
121                                common_telemetry::error!("Failed to step aggregate plan: {err:?}");
122                            })
123                            .map(|s| TransformerAction {
124                                extra_parent_plans: s.extra_parent_plans,
125                                new_child_plan: s.new_child_plan,
126                            })
127                        })),
128                    });
129                }
130                if !matches_partition {
131                    return Ok(Commutativity::NonCommutative);
132                }
133                for expr in &aggr.aggr_expr {
134                    let commutativity = Self::check_expr(expr);
135                    if !matches!(commutativity, Commutativity::Commutative) {
136                        return Ok(commutativity);
137                    }
138                }
139                // all group by expressions are partition columns can push down, unless
140                // another push down(including `Limit` or `Sort`) is already in progress(which will then prvent next cond commutative node from being push down).
141                // TODO(discord9): This is a temporary solution(that works), a better description of
142                // commutativity is needed under this situation.
143                Commutativity::ConditionalCommutative(None)
144            }
145            LogicalPlan::Sort(_) => {
146                if partition_cols.is_empty() {
147                    return Ok(Commutativity::Commutative);
148                }
149
150                // sort plan needs to consider column priority
151                // Change Sort to MergeSort which assumes the input streams are already sorted hence can be more efficient
152                // We should ensure the number of partition is not smaller than the number of region at present. Otherwise this would result in incorrect output.
153                Commutativity::ConditionalCommutative(Some(Arc::new(merge_sort_transformer)))
154            }
155            LogicalPlan::Join(_) => Commutativity::NonCommutative,
156            LogicalPlan::Repartition(_) => {
157                // unsupported? or non-commutative
158                Commutativity::Unimplemented
159            }
160            LogicalPlan::Union(_) => Commutativity::Unimplemented,
161            LogicalPlan::TableScan(_) => Commutativity::Commutative,
162            LogicalPlan::EmptyRelation(_) => Commutativity::NonCommutative,
163            LogicalPlan::Subquery(_) => Commutativity::Unimplemented,
164            LogicalPlan::SubqueryAlias(_) => Commutativity::Commutative,
165            LogicalPlan::Limit(limit) => {
166                // Only execute `fetch` on remote nodes.
167                // wait for https://github.com/apache/arrow-datafusion/pull/7669
168                if partition_cols.is_empty() && limit.fetch.is_some() {
169                    Commutativity::Commutative
170                } else if limit.skip.is_none() && limit.fetch.is_some() {
171                    Commutativity::PartialCommutative
172                } else {
173                    Commutativity::Unimplemented
174                }
175            }
176            LogicalPlan::Extension(extension) => {
177                Self::check_extension_plan(extension.node.as_ref() as _, &partition_cols)
178            }
179            LogicalPlan::Distinct(_) => {
180                if partition_cols.is_empty() {
181                    Commutativity::Commutative
182                } else {
183                    Commutativity::Unimplemented
184                }
185            }
186            LogicalPlan::Unnest(_) => Commutativity::Commutative,
187            LogicalPlan::Statement(_) => Commutativity::Unsupported,
188            LogicalPlan::Values(_) => Commutativity::Unsupported,
189            LogicalPlan::Explain(_) => Commutativity::Unsupported,
190            LogicalPlan::Analyze(_) => Commutativity::Unsupported,
191            LogicalPlan::DescribeTable(_) => Commutativity::Unsupported,
192            LogicalPlan::Dml(_) => Commutativity::Unsupported,
193            LogicalPlan::Ddl(_) => Commutativity::Unsupported,
194            LogicalPlan::Copy(_) => Commutativity::Unsupported,
195            LogicalPlan::RecursiveQuery(_) => Commutativity::Unsupported,
196        };
197
198        Ok(comm)
199    }
200
201    pub fn check_extension_plan(
202        plan: &dyn UserDefinedLogicalNode,
203        partition_cols: &AliasMapping,
204    ) -> Commutativity {
205        match plan.name() {
206            name if name == SeriesDivide::name() => {
207                let series_divide = plan.as_any().downcast_ref::<SeriesDivide>().unwrap();
208                let tags = series_divide.tags().iter().collect::<HashSet<_>>();
209
210                for all_alias in partition_cols.values() {
211                    let all_alias = all_alias.iter().map(|c| &c.name).collect::<HashSet<_>>();
212                    if tags.intersection(&all_alias).count() == 0 {
213                        return Commutativity::NonCommutative;
214                    }
215                }
216
217                Commutativity::Commutative
218            }
219            name if name == SeriesNormalize::name()
220                || name == InstantManipulate::name()
221                || name == RangeManipulate::name() =>
222            {
223                // They should always follows Series Divide.
224                // Either all commutative or all non-commutative (which will be blocked by SeriesDivide).
225                Commutativity::Commutative
226            }
227            name if name == EmptyMetric::name()
228                || name == MergeScanLogicalPlan::name()
229                || name == MergeSortLogicalPlan::name() =>
230            {
231                Commutativity::Unimplemented
232            }
233            _ => Commutativity::Unsupported,
234        }
235    }
236
237    pub fn check_expr(expr: &Expr) -> Commutativity {
238        #[allow(deprecated)]
239        match expr {
240            Expr::Column(_)
241            | Expr::ScalarVariable(_, _)
242            | Expr::Literal(_, _)
243            | Expr::BinaryExpr(_)
244            | Expr::Not(_)
245            | Expr::IsNotNull(_)
246            | Expr::IsNull(_)
247            | Expr::IsTrue(_)
248            | Expr::IsFalse(_)
249            | Expr::IsNotTrue(_)
250            | Expr::IsNotFalse(_)
251            | Expr::Negative(_)
252            | Expr::Between(_)
253            | Expr::Exists(_)
254            | Expr::InList(_)
255            | Expr::Case(_) => Commutativity::Commutative,
256            Expr::ScalarFunction(_udf) => Commutativity::Commutative,
257            Expr::AggregateFunction(_udaf) => Commutativity::Commutative,
258
259            Expr::Like(_)
260            | Expr::SimilarTo(_)
261            | Expr::IsUnknown(_)
262            | Expr::IsNotUnknown(_)
263            | Expr::Cast(_)
264            | Expr::TryCast(_)
265            | Expr::WindowFunction(_)
266            | Expr::InSubquery(_)
267            | Expr::ScalarSubquery(_)
268            | Expr::Wildcard { .. } => Commutativity::Unimplemented,
269
270            Expr::Alias(alias) => Self::check_expr(&alias.expr),
271
272            Expr::Unnest(_)
273            | Expr::GroupingSet(_)
274            | Expr::Placeholder(_)
275            | Expr::OuterReferenceColumn(_, _) => Commutativity::Unimplemented,
276        }
277    }
278
279    /// Return true if the given expr and partition cols satisfied the rule.
280    /// In this case the plan can be treated as fully commutative.
281    ///
282    /// So only if all partition columns show up in `exprs`, return true.
283    /// Otherwise return false.
284    ///
285    fn check_partition(exprs: &[Expr], partition_cols: &AliasMapping) -> bool {
286        let mut ref_cols = HashSet::new();
287        for expr in exprs {
288            expr.add_column_refs(&mut ref_cols);
289        }
290        let ref_cols = ref_cols
291            .into_iter()
292            .map(|c| c.name.clone())
293            .collect::<HashSet<_>>();
294        for all_alias in partition_cols.values() {
295            let all_alias = all_alias
296                .iter()
297                .map(|c| c.name.clone())
298                .collect::<HashSet<_>>();
299            // check if ref columns intersect with all alias of partition columns
300            // is empty, if it's empty, not all partition columns show up in `exprs`
301            if ref_cols.intersection(&all_alias).count() == 0 {
302                return false;
303            }
304        }
305
306        true
307    }
308}
309
310pub type Transformer = Arc<dyn Fn(&LogicalPlan) -> Option<LogicalPlan>>;
311
312/// Returns transformer action that need to be applied
313pub type StageTransformer = Arc<dyn Fn(&LogicalPlan) -> DfResult<TransformerAction>>;
314
315/// The Action that a transformer should take on the plan.
316pub struct TransformerAction {
317    /// list of plans that need to be applied to parent plans, in the order of parent to child.
318    /// i.e. if this returns `[Projection, Aggregate]`, then the parent plan should be transformed to
319    /// ```ignore
320    /// Original Parent Plan:
321    ///     Projection:
322    ///         Aggregate:
323    ///             MergeScan: ...
324    /// ```
325    pub extra_parent_plans: Vec<LogicalPlan>,
326    /// new child plan, if None, use the original plan.
327    pub new_child_plan: Option<LogicalPlan>,
328}
329
330pub fn partial_commutative_transformer(plan: &LogicalPlan) -> Option<LogicalPlan> {
331    Some(plan.clone())
332}
333
#[cfg(test)]
mod test {
    use datafusion_expr::{LogicalPlanBuilder, Sort};

    use super::*;

    /// With an empty partition-column mapping, a `Sort` is expected to be fully commutative.
    #[test]
    fn sort_on_empty_partition() {
        let empty_input = LogicalPlanBuilder::empty(false)
            .build()
            .expect("building an empty relation should succeed");
        let sort_plan = LogicalPlan::Sort(Sort {
            expr: Vec::new(),
            input: Arc::new(empty_input),
            fetch: None,
        });

        let commutativity = Categorizer::check_plan(&sort_plan, Some(AliasMapping::default()))
            .expect("categorizing a sort plan should succeed");
        assert!(matches!(commutativity, Commutativity::Commutative));
    }
}