query/optimizer/
pass_distribution.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::sync::Arc;
16
17use datafusion::config::ConfigOptions;
18use datafusion::physical_optimizer::PhysicalOptimizerRule;
19use datafusion::physical_plan::ExecutionPlan;
20use datafusion_common::Result as DfResult;
21use datafusion_physical_expr::Distribution;
22
23use crate::dist_plan::MergeScanExec;
24
/// A [`PhysicalOptimizerRule`] that passes distribution requirements down to
/// [`MergeScanExec`] nodes to avoid unnecessary shuffling.
///
/// The plan is walked top-down. Each child is handed the distribution its
/// parent declares for it through `required_input_distribution`. When a
/// [`MergeScanExec`] is reached while a requirement is in effect, it is replaced
/// by the plan returned from `MergeScanExec::try_with_new_distribution`, if that
/// call returns one.
///
/// This rule is expected to be run before [`EnforceDistribution`].
///
/// [`EnforceDistribution`]: datafusion::physical_optimizer::enforce_distribution::EnforceDistribution
/// [`MergeScanExec`]: crate::dist_plan::MergeScanExec
#[derive(Debug)]
pub struct PassDistribution;
34
35impl PhysicalOptimizerRule for PassDistribution {
36    fn optimize(
37        &self,
38        plan: Arc<dyn ExecutionPlan>,
39        config: &ConfigOptions,
40    ) -> DfResult<Arc<dyn ExecutionPlan>> {
41        Self::do_optimize(plan, config)
42    }
43
44    fn name(&self) -> &str {
45        "PassDistributionRule"
46    }
47
48    fn schema_check(&self) -> bool {
49        false
50    }
51}
52
53impl PassDistribution {
54    fn do_optimize(
55        plan: Arc<dyn ExecutionPlan>,
56        _config: &ConfigOptions,
57    ) -> DfResult<Arc<dyn ExecutionPlan>> {
58        // Start from root with no requirement
59        Self::rewrite_with_distribution(plan, None)
60    }
61
62    /// Top-down rewrite that propagates distribution requirements to children.
63    fn rewrite_with_distribution(
64        plan: Arc<dyn ExecutionPlan>,
65        current_req: Option<Distribution>,
66    ) -> DfResult<Arc<dyn ExecutionPlan>> {
67        // If this is a MergeScanExec, try to apply the current requirement.
68        if let Some(merge_scan) = plan.as_any().downcast_ref::<MergeScanExec>()
69            && let Some(distribution) = current_req.as_ref()
70            && let Some(new_plan) = merge_scan.try_with_new_distribution(distribution.clone())
71        {
72            // Leaf node; no children to process
73            return Ok(Arc::new(new_plan) as _);
74        }
75
76        // Compute per-child requirements from the current node.
77        let children = plan.children();
78        if children.is_empty() {
79            return Ok(plan);
80        }
81
82        let required = plan.required_input_distribution();
83        let mut new_children = Vec::with_capacity(children.len());
84        for (idx, child) in children.into_iter().enumerate() {
85            let child_req = match required.get(idx) {
86                Some(Distribution::UnspecifiedDistribution) => None,
87                None => current_req.clone(),
88                Some(req) => Some(req.clone()),
89            };
90            let new_child = Self::rewrite_with_distribution(child.clone(), child_req)?;
91            new_children.push(new_child);
92        }
93
94        // Rebuild the node only if any child changed (pointer inequality)
95        let unchanged = plan
96            .children()
97            .into_iter()
98            .zip(new_children.iter())
99            .all(|(old, new)| Arc::ptr_eq(old, new));
100        if unchanged {
101            Ok(plan)
102        } else {
103            plan.with_new_children(new_children)
104        }
105    }
106}