Skip to main content

query/optimizer/
pass_distribution.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use datafusion::config::ConfigOptions;
18use datafusion::physical_optimizer::PhysicalOptimizerRule;
19use datafusion::physical_plan::ExecutionPlan;
20use datafusion::physical_plan::projection::ProjectionExec;
21use datafusion_common::Result as DfResult;
22use datafusion_physical_expr::Distribution;
23use datafusion_physical_expr::utils::map_columns_before_projection;
24
25use crate::dist_plan::MergeScanExec;
26
/// A [`PhysicalOptimizerRule`] that passes distribution requirements down to
/// [`MergeScanExec`] to avoid unnecessary shuffling.
///
/// The plan is walked from the root towards the leaves. Every node hands the
/// input distribution it requires to the matching child; a hash requirement
/// is additionally carried through a [`ProjectionExec`] when every required
/// column can be mapped back onto the projection's input. A
/// [`MergeScanExec`] that is reached with a requirement is asked to adopt it
/// and is replaced by the re-distributed scan if it accepts.
///
/// This rule is expected to be run before [`EnforceDistribution`].
///
/// [`EnforceDistribution`]: datafusion::physical_optimizer::enforce_distribution::EnforceDistribution
/// [`MergeScanExec`]: crate::dist_plan::MergeScanExec
#[derive(Debug)]
pub struct PassDistribution;
36
37impl PhysicalOptimizerRule for PassDistribution {
38    fn optimize(
39        &self,
40        plan: Arc<dyn ExecutionPlan>,
41        config: &ConfigOptions,
42    ) -> DfResult<Arc<dyn ExecutionPlan>> {
43        Self::do_optimize(plan, config)
44    }
45
46    fn name(&self) -> &str {
47        "PassDistributionRule"
48    }
49
50    fn schema_check(&self) -> bool {
51        false
52    }
53}
54
55impl PassDistribution {
56    fn do_optimize(
57        plan: Arc<dyn ExecutionPlan>,
58        _config: &ConfigOptions,
59    ) -> DfResult<Arc<dyn ExecutionPlan>> {
60        // Start from root with no requirement
61        Self::rewrite_with_distribution(plan, None)
62    }
63
64    /// Top-down rewrite that propagates distribution requirements to children.
65    fn rewrite_with_distribution(
66        plan: Arc<dyn ExecutionPlan>,
67        current_req: Option<Distribution>,
68    ) -> DfResult<Arc<dyn ExecutionPlan>> {
69        // If this is a MergeScanExec, try to apply the current requirement.
70        if let Some(merge_scan) = plan.as_any().downcast_ref::<MergeScanExec>()
71            && let Some(distribution) = current_req.as_ref()
72            && let Some(new_plan) = merge_scan.try_with_new_distribution(distribution.clone())
73        {
74            // Leaf node; no children to process
75            return Ok(Arc::new(new_plan) as _);
76        }
77
78        // Compute per-child requirements from the current node.
79        let children = plan.children();
80        if children.is_empty() {
81            return Ok(plan);
82        }
83
84        let required = plan.required_input_distribution();
85        let mut new_children = Vec::with_capacity(children.len());
86        for (idx, child) in children.into_iter().enumerate() {
87            let child_req = match required.get(idx) {
88                Some(Distribution::UnspecifiedDistribution) if idx == 0 => {
89                    Self::map_hash_requirement_through_projection(plan.as_ref(), &current_req)
90                }
91                Some(Distribution::UnspecifiedDistribution) => None,
92                None => current_req.clone(),
93                Some(req) => Some(req.clone()),
94            };
95            let new_child = Self::rewrite_with_distribution(child.clone(), child_req)?;
96            new_children.push(new_child);
97        }
98
99        // Rebuild the node only if any child changed (pointer inequality)
100        let unchanged = plan
101            .children()
102            .into_iter()
103            .zip(new_children.iter())
104            .all(|(old, new)| Arc::ptr_eq(old, new));
105        if unchanged {
106            Ok(plan)
107        } else {
108            plan.with_new_children(new_children)
109        }
110    }
111
112    fn map_hash_requirement_through_projection(
113        plan: &dyn ExecutionPlan,
114        current_req: &Option<Distribution>,
115    ) -> Option<Distribution> {
116        let Some(Distribution::HashPartitioned(required_exprs)) = current_req else {
117            return None;
118        };
119
120        let projection = plan.as_any().downcast_ref::<ProjectionExec>()?;
121        let proj_exprs = projection
122            .expr()
123            .iter()
124            .map(|expr| (Arc::clone(&expr.expr), expr.alias.clone()))
125            .collect::<Vec<_>>();
126        let mapped = map_columns_before_projection(required_exprs, &proj_exprs);
127
128        (mapped.len() == required_exprs.len()).then_some(Distribution::HashPartitioned(mapped))
129    }
130}
131
#[cfg(test)]
mod tests {
    use std::collections::{BTreeMap, BTreeSet};

    use api::v1::region::{RemoteDynFilterUnregister, RemoteDynFilterUpdate};
    use arrow_schema::{DataType, Field, Schema, SchemaRef, TimeUnit};
    use async_trait::async_trait;
    use common_query::request::QueryRequest;
    use common_recordbatch::SendableRecordBatchStream;
    use datafusion::common::NullEquality;
    use datafusion::execution::SessionStateBuilder;
    use datafusion::physical_optimizer::PhysicalOptimizerRule;
    use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode};
    use datafusion::physical_plan::projection::{ProjectionExec, ProjectionExpr};
    use datafusion::physical_plan::{ExecutionPlanProperties, Partitioning};
    use datafusion_expr::{JoinType, LogicalPlanBuilder};
    use datafusion_physical_expr::PhysicalExpr;
    use datafusion_physical_expr::expressions::Column as PhysicalColumn;
    use session::ReadPreference;
    use session::context::QueryContext;
    use store_api::metric_engine_consts::DATA_SCHEMA_TSID_COLUMN_NAME;
    use store_api::storage::RegionId;
    use table::table_name::TableName;

    use super::*;
    use crate::dist_plan::RemoteDynFilterProducerId;
    use crate::error::Result as QueryResult;
    use crate::region_query::RegionQueryHandler;

    /// Timestamp column of the test schema.
    const TIMESTAMP_COLUMN: &str = "greptime_timestamp";
    /// Value column of the test schema.
    const VALUE_COLUMN: &str = "greptime_value";

    /// Region query handler whose methods must never be reached: the tests
    /// only rewrite plans and never execute them.
    struct NoopRegionQueryHandler;

    #[async_trait]
    impl RegionQueryHandler for NoopRegionQueryHandler {
        async fn do_get(
            &self,
            _read_preference: ReadPreference,
            _request: QueryRequest,
        ) -> QueryResult<SendableRecordBatchStream> {
            unreachable!("pass distribution tests should not execute remote queries")
        }

        async fn handle_remote_dyn_filter_update(
            &self,
            _region_id: RegionId,
            _query_id: String,
            _update: RemoteDynFilterUpdate,
        ) -> QueryResult<()> {
            unreachable!("pass distribution tests should not send remote dyn filter updates")
        }

        async fn handle_remote_dyn_filter_unregister(
            &self,
            _region_id: RegionId,
            _query_id: String,
            _unregister: RemoteDynFilterUnregister,
        ) -> QueryResult<()> {
            unreachable!("pass distribution tests should not send remote dyn filter unregisters")
        }
    }

    /// Schema shared by both merge scans: `host`, tsid, timestamp, value.
    fn test_schema() -> SchemaRef {
        let fields = vec![
            Field::new("host", DataType::Utf8, true),
            Field::new(DATA_SCHEMA_TSID_COLUMN_NAME, DataType::UInt64, false),
            Field::new(
                TIMESTAMP_COLUMN,
                DataType::Timestamp(TimeUnit::Millisecond, None),
                false,
            ),
            Field::new(VALUE_COLUMN, DataType::Float64, true),
        ];
        Arc::new(Schema::new(fields))
    }

    /// Physical column reference to `name` at position `index`.
    fn partition_column(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
        Arc::new(PhysicalColumn::new(name, index)) as Arc<dyn PhysicalExpr>
    }

    /// Projection entry that forwards input column `name` at `index` under
    /// the same name.
    fn projected_column(name: &'static str, index: usize) -> ProjectionExpr {
        ProjectionExpr::new(partition_column(name, index), name)
    }

    /// Names of `exprs`, all of which must be plain column references.
    fn column_names(exprs: &[Arc<dyn PhysicalExpr>]) -> Vec<&str> {
        let mut names = Vec::with_capacity(exprs.len());
        for expr in exprs {
            let column = expr.as_any().downcast_ref::<PhysicalColumn>().unwrap();
            names.push(column.name());
        }
        names
    }

    /// Merge scan over two regions, with 32 target partitions and the tsid
    /// and timestamp columns registered as partition columns.
    fn test_merge_scan_exec(schema: SchemaRef) -> MergeScanExec {
        let session_state = SessionStateBuilder::new().with_default_features().build();
        let plan = LogicalPlanBuilder::empty(false).build().unwrap();

        // Each partition column is registered under its own name.
        let partition_cols: BTreeMap<_, _> = [DATA_SCHEMA_TSID_COLUMN_NAME, TIMESTAMP_COLUMN]
            .into_iter()
            .map(|name| {
                (
                    name.to_string(),
                    BTreeSet::from([datafusion_common::Column::from_name(name)]),
                )
            })
            .collect();

        MergeScanExec::new(
            &session_state,
            TableName::new("greptime", "public", "test"),
            vec![RegionId::new(1, 0), RegionId::new(1, 1)],
            plan,
            schema.as_ref(),
            Arc::new(NoopRegionQueryHandler),
            QueryContext::arc(),
            32,
            partition_cols,
            Some(RemoteDynFilterProducerId::new(1)),
        )
        .unwrap()
    }

    /// A partitioned hash join must push its hash requirement into both merge
    /// scans, including the one hidden below a projection on the left side.
    #[test]
    fn passes_hash_requirement_through_projection_to_merge_scan() {
        let schema = test_schema();
        let left_scan = Arc::new(test_merge_scan_exec(schema.clone()));
        let right_scan = Arc::new(test_merge_scan_exec(schema.clone()));

        // Left input: a projection of (value, tsid, timestamp) over a scan.
        let left_input = Arc::new(
            ProjectionExec::try_new(
                vec![
                    projected_column(VALUE_COLUMN, 3),
                    projected_column(DATA_SCHEMA_TSID_COLUMN_NAME, 1),
                    projected_column(TIMESTAMP_COLUMN, 2),
                ],
                left_scan,
            )
            .unwrap(),
        ) as Arc<dyn datafusion::physical_plan::ExecutionPlan>;

        // Join on (tsid, timestamp); both sides reference the same indices.
        let join_keys = [(DATA_SCHEMA_TSID_COLUMN_NAME, 1), (TIMESTAMP_COLUMN, 2)]
            .into_iter()
            .map(|(name, index)| (partition_column(name, index), partition_column(name, index)))
            .collect::<Vec<_>>();
        let join = Arc::new(
            HashJoinExec::try_new(
                left_input,
                right_scan,
                join_keys,
                None,
                &JoinType::Inner,
                None,
                PartitionMode::Partitioned,
                NullEquality::NullEqualsNull,
                false,
            )
            .unwrap(),
        ) as Arc<dyn datafusion::physical_plan::ExecutionPlan>;

        let optimized = PassDistribution
            .optimize(join, &ConfigOptions::default())
            .unwrap();

        let hash_join = optimized.as_any().downcast_ref::<HashJoinExec>().unwrap();
        let left_projection = hash_join
            .left()
            .as_any()
            .downcast_ref::<ProjectionExec>()
            .unwrap();

        let (left_exprs, left_count) = match left_projection.input().output_partitioning() {
            Partitioning::Hash(exprs, count) => (exprs, *count),
            _ => panic!("expected left merge scan hash partitioning"),
        };
        let (right_exprs, right_count) = match hash_join.right().output_partitioning() {
            Partitioning::Hash(exprs, count) => (exprs, *count),
            _ => panic!("expected right merge scan hash partitioning"),
        };

        assert_eq!(left_count, 32);
        assert_eq!(right_count, 32);

        let expected_keys = vec![DATA_SCHEMA_TSID_COLUMN_NAME, TIMESTAMP_COLUMN];
        assert_eq!(column_names(left_exprs), expected_keys);
        assert_eq!(column_names(right_exprs), expected_keys);
    }
}