query/optimizer/
pass_distribution.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use datafusion::config::ConfigOptions;
18use datafusion::physical_optimizer::PhysicalOptimizerRule;
19use datafusion::physical_plan::ExecutionPlan;
20use datafusion_common::tree_node::{Transformed, TreeNode};
21use datafusion_common::Result as DfResult;
22use datafusion_physical_expr::Distribution;
23
24use crate::dist_plan::MergeScanExec;
25
/// [`PhysicalOptimizerRule`] that forwards distribution requirements found in
/// the physical plan to [`MergeScanExec`], so that unnecessary shuffling can
/// be avoided.
///
/// The rule is expected to run before [`EnforceDistribution`].
///
/// [`EnforceDistribution`]: datafusion::physical_optimizer::enforce_distribution::EnforceDistribution
/// [`MergeScanExec`]: crate::dist_plan::MergeScanExec
#[derive(Debug)]
pub struct PassDistribution;
35
36impl PhysicalOptimizerRule for PassDistribution {
37    fn optimize(
38        &self,
39        plan: Arc<dyn ExecutionPlan>,
40        config: &ConfigOptions,
41    ) -> DfResult<Arc<dyn ExecutionPlan>> {
42        Self::do_optimize(plan, config)
43    }
44
45    fn name(&self) -> &str {
46        "PassDistributionRule"
47    }
48
49    fn schema_check(&self) -> bool {
50        false
51    }
52}
53
54impl PassDistribution {
55    fn do_optimize(
56        plan: Arc<dyn ExecutionPlan>,
57        _config: &ConfigOptions,
58    ) -> DfResult<Arc<dyn ExecutionPlan>> {
59        let mut distribution_requirement = None;
60        let result = plan.transform_down(|plan| {
61            if let Some(distribution) = plan.required_input_distribution().first()
62                && !matches!(distribution, Distribution::UnspecifiedDistribution)
63                // incorrect workaround, doesn't fix the actual issue
64                && plan.name() != "HashJoinExec"
65            {
66                distribution_requirement = Some(distribution.clone());
67            }
68
69            if let Some(merge_scan) = plan.as_any().downcast_ref::<MergeScanExec>()
70                && let Some(distribution) = distribution_requirement.as_ref()
71                && let Some(new_plan) = merge_scan.try_with_new_distribution(distribution.clone())
72            {
73                Ok(Transformed::yes(Arc::new(new_plan) as _))
74            } else {
75                Ok(Transformed::no(plan))
76            }
77        })?;
78
79        Ok(result.data)
80    }
81}