query/dist_plan/
commutativity.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashSet;
16use std::sync::Arc;
17
18use common_function::aggrs::aggr_wrapper::{StateMergeHelper, is_all_aggr_exprs_steppable};
19#[cfg(feature = "vector_index")]
20use common_function::scalars::vector::distance::{
21    VEC_COS_DISTANCE, VEC_DOT_PRODUCT, VEC_L2SQ_DISTANCE,
22};
23use common_telemetry::debug;
24use datafusion::error::Result as DfResult;
25#[cfg(feature = "vector_index")]
26use datafusion_common::DataFusionError;
27use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
28#[cfg(feature = "vector_index")]
29use datafusion_expr::Sort;
30use datafusion_expr::{Expr, LogicalPlan, UserDefinedLogicalNode};
31use promql::extension_plan::{
32    EmptyMetric, InstantManipulate, RangeManipulate, SeriesDivide, SeriesNormalize,
33};
34use store_api::metric_engine_consts::DATA_SCHEMA_TSID_COLUMN_NAME;
35
36use crate::dist_plan::MergeScanLogicalPlan;
37use crate::dist_plan::analyzer::AliasMapping;
38use crate::dist_plan::merge_sort::{MergeSortLogicalPlan, merge_sort_transformer};
39
/// Tells whether `sort` orders by exactly one expression and that expression is a call
/// to one of the vector distance functions (`VEC_L2SQ_DISTANCE`, `VEC_COS_DISTANCE`,
/// `VEC_DOT_PRODUCT`). The function name is lower-cased before it is compared.
#[cfg(feature = "vector_index")]
fn is_vector_sort(sort: &Sort) -> bool {
    match sort.expr.as_slice() {
        // Exactly one sort key: inspect the expression it sorts on.
        [only_key] => match &only_key.expr {
            Expr::ScalarFunction(func) => {
                let lowered = func.name().to_lowercase();
                matches!(
                    lowered.as_str(),
                    VEC_L2SQ_DISTANCE | VEC_COS_DISTANCE | VEC_DOT_PRODUCT
                )
            }
            _ => false,
        },
        // Zero or multiple sort keys never qualify.
        _ => false,
    }
}
54
/// Stage transformer used for vector-distance sorts.
///
/// For a `Sort` plan it keeps a copy of the sort as the new child plan and adds a
/// [`MergeSortLogicalPlan`] (same input, sort expressions and fetch) as the single
/// extra parent plan.
///
/// # Errors
/// Returns `DataFusionError::Internal` when `plan` is not a `LogicalPlan::Sort`.
#[cfg(feature = "vector_index")]
fn vector_sort_transformer(plan: &LogicalPlan) -> DfResult<TransformerAction> {
    match plan {
        LogicalPlan::Sort(sort) => {
            let merge_sort =
                MergeSortLogicalPlan::new(sort.input.clone(), sort.expr.clone(), sort.fetch)
                    .into_logical_plan();
            let action = TransformerAction {
                extra_parent_plans: vec![merge_sort],
                new_child_plan: Some(LogicalPlan::Sort(sort.clone())),
            };
            Ok(action)
        }
        _ => Err(DataFusionError::Internal(format!(
            "vector_sort_transformer expects Sort, got {plan}"
        ))),
    }
}
70
/// Result of splitting an aggregate plan into two steps, as produced by
/// `step_aggr_to_upper_aggr`.
///
/// The two fields have the same names and types as those of `TransformerAction`.
pub struct StepTransformAction {
    /// Plans to place above the child plan; `step_aggr_to_upper_aggr` stores the
    /// upper (merge) aggregation here.
    extra_parent_plans: Vec<LogicalPlan>,
    /// Replacement for the original plan; `step_aggr_to_upper_aggr` stores the
    /// lower (state) aggregation here.
    new_child_plan: Option<LogicalPlan>,
}
75
76/// generate the upper aggregation plan that will execute on the frontend.
77/// Basically a logical plan resembling the following:
78/// Projection:
79///     Aggregate:
80///
81/// from Aggregate
82///
83/// The upper Projection exists sole to make sure parent plan can recognize the output
84/// of the upper aggregation plan.
85pub fn step_aggr_to_upper_aggr(
86    aggr_plan: &LogicalPlan,
87) -> datafusion_common::Result<StepTransformAction> {
88    let LogicalPlan::Aggregate(input_aggr) = aggr_plan else {
89        return Err(datafusion_common::DataFusionError::Plan(
90            "step_aggr_to_upper_aggr only accepts Aggregate plan".to_string(),
91        ));
92    };
93    if !is_all_aggr_exprs_steppable(&input_aggr.aggr_expr) {
94        return Err(datafusion_common::DataFusionError::NotImplemented(format!(
95            "Some aggregate expressions are not steppable in [{}]",
96            input_aggr
97                .aggr_expr
98                .iter()
99                .map(|e| e.to_string())
100                .collect::<Vec<_>>()
101                .join(", ")
102        )));
103    }
104
105    let step_aggr_plan = StateMergeHelper::split_aggr_node(input_aggr.clone())?;
106
107    // TODO(discord9): remove duplication
108    let ret = StepTransformAction {
109        extra_parent_plans: vec![step_aggr_plan.upper_merge.clone()],
110        new_child_plan: Some(step_aggr_plan.lower_state.clone()),
111    };
112    Ok(ret)
113}
114
/// Classification produced by `Categorizer` for a plan node or an expression.
///
/// The per-variant notes below describe where this file returns each value. How a
/// variant is acted upon is decided by the consumer of this type, which is not part
/// of this file.
#[allow(dead_code)]
pub enum Commutativity {
    /// Returned unconditionally for e.g. `TableScan`, `SubqueryAlias` and `Unnest`, and
    /// for `Sort`/`Limit`/`Distinct` when there are no partition columns.
    Commutative,
    /// Returned for `Limit` with a `fetch` but no `skip`, and for `Distinct`, when
    /// partition columns are present.
    PartialCommutative,
    /// Returned for `Sort` (carrying `merge_sort_transformer`) and for aggregates whose
    /// group-by expressions cover every partition column (carrying no transformer).
    ConditionalCommutative(Option<Transformer>),
    /// Returned when the plan has to be rewritten by a stage transformer, e.g. a
    /// steppable aggregate that does not group by every partition column.
    TransformedCommutative {
        /// Return plans from parent to child order
        transformer: Option<StageTransformer>,
    },
    /// Returned for e.g. `Join`, `EmptyRelation`, and aggregates that neither cover the
    /// partition columns nor are steppable.
    NonCommutative,
    /// Returned for nodes and expressions without dedicated handling here (e.g.
    /// `Window`, `Union`, `Like`, `Cast`) and whenever the plan contains a subquery
    /// expression.
    Unimplemented,
    /// For unrelated plans like DDL
    Unsupported,
}
129
130pub struct Categorizer {}
131
132impl Categorizer {
133    pub fn check_plan(
134        plan: &LogicalPlan,
135        partition_cols: Option<AliasMapping>,
136    ) -> DfResult<Commutativity> {
137        // Subquery is treated separately in `inspect_plan_with_subquery`. To avoid rewrite the
138        // "maybe rewritten" plan, stop the check here.
139        if has_subquery(plan)? {
140            return Ok(Commutativity::Unimplemented);
141        }
142
143        let partition_cols = partition_cols.unwrap_or_default();
144
145        let comm = match plan {
146            LogicalPlan::Projection(proj) => {
147                for expr in &proj.expr {
148                    let commutativity = Self::check_expr(expr);
149                    if !matches!(commutativity, Commutativity::Commutative) {
150                        return Ok(commutativity);
151                    }
152                }
153                Commutativity::Commutative
154            }
155            // TODO(ruihang): Change this to Commutative once Like is supported in substrait
156            LogicalPlan::Filter(filter) => Self::check_expr(&filter.predicate),
157            LogicalPlan::Window(_) => Commutativity::Unimplemented,
158            LogicalPlan::Aggregate(aggr) => {
159                let is_all_steppable = is_all_aggr_exprs_steppable(&aggr.aggr_expr);
160                let matches_partition = Self::check_partition(&aggr.group_expr, &partition_cols);
161                if !matches_partition && is_all_steppable {
162                    debug!("Plan is steppable: {plan}");
163                    return Ok(Commutativity::TransformedCommutative {
164                        transformer: Some(Arc::new(|plan: &LogicalPlan| {
165                            debug!("Before Step optimize: {plan}");
166                            let ret = step_aggr_to_upper_aggr(plan);
167                            ret.inspect_err(|err| {
168                                common_telemetry::error!("Failed to step aggregate plan: {err:?}");
169                            })
170                            .map(|s| TransformerAction {
171                                extra_parent_plans: s.extra_parent_plans,
172                                new_child_plan: s.new_child_plan,
173                            })
174                        })),
175                    });
176                }
177                if !matches_partition {
178                    return Ok(Commutativity::NonCommutative);
179                }
180                for expr in &aggr.aggr_expr {
181                    let commutativity = Self::check_expr(expr);
182                    if !matches!(commutativity, Commutativity::Commutative) {
183                        return Ok(commutativity);
184                    }
185                }
186                // all group by expressions are partition columns can push down, unless
187                // another push down(including `Limit` or `Sort`) is already in progress(which will then prevent next cond commutative node from being push down).
188                // TODO(discord9): This is a temporary solution(that works), a better description of
189                // commutativity is needed under this situation.
190                Commutativity::ConditionalCommutative(None)
191            }
192            LogicalPlan::Sort(_sort) => {
193                if partition_cols.is_empty() {
194                    return Ok(Commutativity::Commutative);
195                }
196
197                // sort plan needs to consider column priority
198                // Change Sort to MergeSort which assumes the input streams are already sorted hence can be more efficient.
199                #[cfg(feature = "vector_index")]
200                if is_vector_sort(_sort) {
201                    return Ok(Commutativity::TransformedCommutative {
202                        transformer: Some(Arc::new(vector_sort_transformer)),
203                    });
204                }
205                Commutativity::ConditionalCommutative(Some(Arc::new(merge_sort_transformer)))
206            }
207            LogicalPlan::Join(_) => Commutativity::NonCommutative,
208            LogicalPlan::Repartition(_) => {
209                // unsupported? or non-commutative
210                Commutativity::Unimplemented
211            }
212            LogicalPlan::Union(_) => Commutativity::Unimplemented,
213            LogicalPlan::TableScan(_) => Commutativity::Commutative,
214            LogicalPlan::EmptyRelation(_) => Commutativity::NonCommutative,
215            LogicalPlan::Subquery(_) => Commutativity::Unimplemented,
216            LogicalPlan::SubqueryAlias(_) => Commutativity::Commutative,
217            LogicalPlan::Limit(limit) => {
218                // Only execute `fetch` on remote nodes.
219                // wait for https://github.com/apache/arrow-datafusion/pull/7669
220                if partition_cols.is_empty() && limit.fetch.is_some() {
221                    Commutativity::Commutative
222                } else if limit.skip.is_none() && limit.fetch.is_some() {
223                    Commutativity::PartialCommutative
224                } else {
225                    Commutativity::Unimplemented
226                }
227            }
228            LogicalPlan::Extension(extension) => {
229                Self::check_extension_plan(extension.node.as_ref() as _, &partition_cols)
230            }
231            LogicalPlan::Distinct(_) => {
232                if partition_cols.is_empty() {
233                    Commutativity::Commutative
234                } else {
235                    Commutativity::PartialCommutative
236                }
237            }
238            LogicalPlan::Unnest(_) => Commutativity::Commutative,
239            LogicalPlan::Statement(_) => Commutativity::Unsupported,
240            LogicalPlan::Values(_) => Commutativity::Unsupported,
241            LogicalPlan::Explain(_) => Commutativity::Unsupported,
242            LogicalPlan::Analyze(_) => Commutativity::Unsupported,
243            LogicalPlan::DescribeTable(_) => Commutativity::Unsupported,
244            LogicalPlan::Dml(_) => Commutativity::Unsupported,
245            LogicalPlan::Ddl(_) => Commutativity::Unsupported,
246            LogicalPlan::Copy(_) => Commutativity::Unsupported,
247            LogicalPlan::RecursiveQuery(_) => Commutativity::Unsupported,
248        };
249
250        Ok(comm)
251    }
252
253    pub fn check_extension_plan(
254        plan: &dyn UserDefinedLogicalNode,
255        partition_cols: &AliasMapping,
256    ) -> Commutativity {
257        match plan.name() {
258            name if name == SeriesDivide::name() => {
259                let series_divide = plan.as_any().downcast_ref::<SeriesDivide>().unwrap();
260                // Metric engine `__tsid` uniquely identifies a time-series. Treat a series divide
261                // that keys by `__tsid` as commutative across regions so it can be pushed down.
262                if series_divide
263                    .tags()
264                    .iter()
265                    .any(|tag| tag == DATA_SCHEMA_TSID_COLUMN_NAME)
266                {
267                    return Commutativity::Commutative;
268                }
269
270                let tags = series_divide.tags().iter().collect::<HashSet<_>>();
271
272                for all_alias in partition_cols.values() {
273                    let all_alias = all_alias.iter().map(|c| &c.name).collect::<HashSet<_>>();
274                    if tags.intersection(&all_alias).count() == 0 {
275                        return Commutativity::NonCommutative;
276                    }
277                }
278
279                Commutativity::Commutative
280            }
281            name if name == SeriesNormalize::name()
282                || name == InstantManipulate::name()
283                || name == RangeManipulate::name() =>
284            {
285                // They should always follows Series Divide.
286                // Either all commutative or all non-commutative (which will be blocked by SeriesDivide).
287                Commutativity::Commutative
288            }
289            name if name == EmptyMetric::name()
290                || name == MergeScanLogicalPlan::name()
291                || name == MergeSortLogicalPlan::name() =>
292            {
293                Commutativity::Unimplemented
294            }
295            _ => Commutativity::Unsupported,
296        }
297    }
298
299    pub fn check_expr(expr: &Expr) -> Commutativity {
300        #[allow(deprecated)]
301        match expr {
302            Expr::Column(_)
303            | Expr::ScalarVariable(_, _)
304            | Expr::Literal(_, _)
305            | Expr::BinaryExpr(_)
306            | Expr::Not(_)
307            | Expr::IsNotNull(_)
308            | Expr::IsNull(_)
309            | Expr::IsTrue(_)
310            | Expr::IsFalse(_)
311            | Expr::IsNotTrue(_)
312            | Expr::IsNotFalse(_)
313            | Expr::Negative(_)
314            | Expr::Between(_)
315            | Expr::Exists(_)
316            | Expr::InList(_)
317            | Expr::Case(_) => Commutativity::Commutative,
318            Expr::ScalarFunction(_udf) => Commutativity::Commutative,
319            Expr::AggregateFunction(_udaf) => Commutativity::Commutative,
320
321            Expr::Like(_)
322            | Expr::SimilarTo(_)
323            | Expr::IsUnknown(_)
324            | Expr::IsNotUnknown(_)
325            | Expr::Cast(_)
326            | Expr::TryCast(_)
327            | Expr::WindowFunction(_)
328            | Expr::InSubquery(_)
329            | Expr::ScalarSubquery(_)
330            | Expr::Wildcard { .. } => Commutativity::Unimplemented,
331
332            Expr::Alias(alias) => Self::check_expr(&alias.expr),
333
334            Expr::Unnest(_)
335            | Expr::GroupingSet(_)
336            | Expr::Placeholder(_)
337            | Expr::OuterReferenceColumn(_, _) => Commutativity::Unimplemented,
338        }
339    }
340
341    /// Return true if the given expr and partition cols satisfied the rule.
342    /// In this case the plan can be treated as fully commutative.
343    ///
344    /// So only if all partition columns show up in `exprs`, return true.
345    /// Otherwise return false.
346    ///
347    fn check_partition(exprs: &[Expr], partition_cols: &AliasMapping) -> bool {
348        let mut ref_cols = HashSet::new();
349        for expr in exprs {
350            expr.add_column_refs(&mut ref_cols);
351        }
352        let ref_cols = ref_cols
353            .into_iter()
354            .map(|c| c.name.clone())
355            .collect::<HashSet<_>>();
356        for all_alias in partition_cols.values() {
357            let all_alias = all_alias
358                .iter()
359                .map(|c| c.name.clone())
360                .collect::<HashSet<_>>();
361            // check if ref columns intersect with all alias of partition columns
362            // is empty, if it's empty, not all partition columns show up in `exprs`
363            if ref_cols.intersection(&all_alias).count() == 0 {
364                return false;
365            }
366        }
367
368        true
369    }
370}
371
#[cfg(test)]
mod tests {
    use std::collections::{BTreeMap, BTreeSet};

    use datafusion_common::Column;
    use datafusion_expr::LogicalPlanBuilder;

    use super::*;

    /// A `SeriesDivide` keyed by `__tsid` is commutative even when the partition
    /// column is not among its tags.
    #[test]
    fn series_divide_by_tsid_is_commutative() {
        let empty_input = LogicalPlanBuilder::empty(false).build().unwrap();
        let tag_columns = vec![DATA_SCHEMA_TSID_COLUMN_NAME.to_string()];
        let series_divide = SeriesDivide::new(tag_columns, "ts".to_string(), empty_input);

        let mut partition_cols: AliasMapping = BTreeMap::new();
        partition_cols.insert(
            "some_partition_col".to_string(),
            BTreeSet::from([Column::from_name("some_partition_col")]),
        );

        let verdict = Categorizer::check_extension_plan(&series_divide, &partition_cols);
        let is_commutative = matches!(verdict, Commutativity::Commutative);
        assert!(is_commutative);
    }
}
399
/// Shared callback that maps a plan to an optional replacement plan; this is the
/// transformer type carried by `Commutativity::ConditionalCommutative`.
pub type Transformer = Arc<dyn Fn(&LogicalPlan) -> Option<LogicalPlan>>;
401
/// Shared callback that returns the `TransformerAction` to apply for a given plan,
/// or an error if the plan cannot be transformed.
pub type StageTransformer = Arc<dyn Fn(&LogicalPlan) -> DfResult<TransformerAction>>;
404
/// The action that a transformer asks to be taken on the plan.
pub struct TransformerAction {
    /// List of plans that need to be applied as parent plans, ordered from parent to child.
    /// E.g. if this is `[Projection, Aggregate]`, then the parent plan should be
    /// transformed to
    /// ```ignore
    /// Original Parent Plan:
    ///     Projection:
    ///         Aggregate:
    ///             MergeScan: ...
    /// ```
    pub extra_parent_plans: Vec<LogicalPlan>,
    /// New child plan; if `None`, the original plan is used.
    pub new_child_plan: Option<LogicalPlan>,
}
419
420pub fn partial_commutative_transformer(plan: &LogicalPlan) -> Option<LogicalPlan> {
421    Some(plan.clone())
422}
423
424fn has_subquery(plan: &LogicalPlan) -> DfResult<bool> {
425    let mut found = false;
426    plan.apply_expressions(|e| {
427        e.apply(|x| {
428            if matches!(
429                x,
430                Expr::Exists(_) | Expr::InSubquery(_) | Expr::ScalarSubquery(_)
431            ) {
432                found = true;
433                Ok(TreeNodeRecursion::Stop)
434            } else {
435                Ok(TreeNodeRecursion::Continue)
436            }
437        })
438    })?;
439    Ok(found)
440}
441
#[cfg(test)]
mod test {
    use datafusion_expr::{LogicalPlanBuilder, Sort};

    use super::*;

    /// A `Sort` checked against an empty partition mapping is fully commutative.
    #[test]
    fn sort_on_empty_partition() {
        let empty_input = LogicalPlanBuilder::empty(false).build().unwrap();
        let sort_plan = LogicalPlan::Sort(Sort {
            expr: vec![],
            input: Arc::new(empty_input),
            fetch: None,
        });

        let verdict = Categorizer::check_plan(&sort_plan, Some(Default::default())).unwrap();
        assert!(matches!(verdict, Commutativity::Commutative));
    }
}