query/optimizer/
pass_distribution.rs1use std::sync::Arc;
16
17use datafusion::config::ConfigOptions;
18use datafusion::physical_optimizer::PhysicalOptimizerRule;
19use datafusion::physical_plan::ExecutionPlan;
20use datafusion_common::tree_node::{Transformed, TreeNode};
21use datafusion_common::Result as DfResult;
22use datafusion_physical_expr::Distribution;
23
24use crate::dist_plan::MergeScanExec;
25
/// Physical optimizer rule that passes the input distribution required by an
/// ancestor operator down to [`MergeScanExec`], so the scan can report a
/// matching distribution when it is able to.
///
/// The rule is stateless, so it is a unit struct that is cheap to copy and
/// default-construct.
///
/// NOTE(review): presumably this lets later distribution enforcement skip an
/// unnecessary repartition — confirm where the rule sits in the optimizer pipeline.
#[derive(Debug, Default, Clone, Copy)]
pub struct PassDistribution;
35
36impl PhysicalOptimizerRule for PassDistribution {
37 fn optimize(
38 &self,
39 plan: Arc<dyn ExecutionPlan>,
40 config: &ConfigOptions,
41 ) -> DfResult<Arc<dyn ExecutionPlan>> {
42 Self::do_optimize(plan, config)
43 }
44
45 fn name(&self) -> &str {
46 "PassDistributionRule"
47 }
48
49 fn schema_check(&self) -> bool {
50 false
51 }
52}
53
54impl PassDistribution {
55 fn do_optimize(
56 plan: Arc<dyn ExecutionPlan>,
57 _config: &ConfigOptions,
58 ) -> DfResult<Arc<dyn ExecutionPlan>> {
59 let mut distribution_requirement = None;
60 let result = plan.transform_down(|plan| {
61 if let Some(distribution) = plan.required_input_distribution().first()
62 && !matches!(distribution, Distribution::UnspecifiedDistribution)
63 && plan.name() != "HashJoinExec"
65 {
66 distribution_requirement = Some(distribution.clone());
67 }
68
69 if let Some(merge_scan) = plan.as_any().downcast_ref::<MergeScanExec>()
70 && let Some(distribution) = distribution_requirement.as_ref()
71 && let Some(new_plan) = merge_scan.try_with_new_distribution(distribution.clone())
72 {
73 Ok(Transformed::yes(Arc::new(new_plan) as _))
74 } else {
75 Ok(Transformed::no(plan))
76 }
77 })?;
78
79 Ok(result.data)
80 }
81}