1use std::sync::Arc;
16
17use datafusion::config::ConfigOptions;
18use datafusion::physical_optimizer::PhysicalOptimizerRule;
19use datafusion::physical_plan::ExecutionPlan;
20use datafusion::physical_plan::projection::ProjectionExec;
21use datafusion_common::Result as DfResult;
22use datafusion_physical_expr::Distribution;
23use datafusion_physical_expr::utils::map_columns_before_projection;
24
25use crate::dist_plan::MergeScanExec;
26
/// Physical optimizer rule that walks a plan top-down and hands the distribution
/// required by each parent node to the [`MergeScanExec`] nodes beneath it.
///
/// The type carries no state; all of the work is done by its
/// [`PhysicalOptimizerRule`] implementation.
#[derive(Debug)]
pub struct PassDistribution;
36
37impl PhysicalOptimizerRule for PassDistribution {
38 fn optimize(
39 &self,
40 plan: Arc<dyn ExecutionPlan>,
41 config: &ConfigOptions,
42 ) -> DfResult<Arc<dyn ExecutionPlan>> {
43 Self::do_optimize(plan, config)
44 }
45
46 fn name(&self) -> &str {
47 "PassDistributionRule"
48 }
49
50 fn schema_check(&self) -> bool {
51 false
52 }
53}
54
55impl PassDistribution {
56 fn do_optimize(
57 plan: Arc<dyn ExecutionPlan>,
58 _config: &ConfigOptions,
59 ) -> DfResult<Arc<dyn ExecutionPlan>> {
60 Self::rewrite_with_distribution(plan, None)
62 }
63
64 fn rewrite_with_distribution(
66 plan: Arc<dyn ExecutionPlan>,
67 current_req: Option<Distribution>,
68 ) -> DfResult<Arc<dyn ExecutionPlan>> {
69 if let Some(merge_scan) = plan.as_any().downcast_ref::<MergeScanExec>()
71 && let Some(distribution) = current_req.as_ref()
72 && let Some(new_plan) = merge_scan.try_with_new_distribution(distribution.clone())
73 {
74 return Ok(Arc::new(new_plan) as _);
76 }
77
78 let children = plan.children();
80 if children.is_empty() {
81 return Ok(plan);
82 }
83
84 let required = plan.required_input_distribution();
85 let mut new_children = Vec::with_capacity(children.len());
86 for (idx, child) in children.into_iter().enumerate() {
87 let child_req = match required.get(idx) {
88 Some(Distribution::UnspecifiedDistribution) if idx == 0 => {
89 Self::map_hash_requirement_through_projection(plan.as_ref(), ¤t_req)
90 }
91 Some(Distribution::UnspecifiedDistribution) => None,
92 None => current_req.clone(),
93 Some(req) => Some(req.clone()),
94 };
95 let new_child = Self::rewrite_with_distribution(child.clone(), child_req)?;
96 new_children.push(new_child);
97 }
98
99 let unchanged = plan
101 .children()
102 .into_iter()
103 .zip(new_children.iter())
104 .all(|(old, new)| Arc::ptr_eq(old, new));
105 if unchanged {
106 Ok(plan)
107 } else {
108 plan.with_new_children(new_children)
109 }
110 }
111
112 fn map_hash_requirement_through_projection(
113 plan: &dyn ExecutionPlan,
114 current_req: &Option<Distribution>,
115 ) -> Option<Distribution> {
116 let Some(Distribution::HashPartitioned(required_exprs)) = current_req else {
117 return None;
118 };
119
120 let projection = plan.as_any().downcast_ref::<ProjectionExec>()?;
121 let proj_exprs = projection
122 .expr()
123 .iter()
124 .map(|expr| (Arc::clone(&expr.expr), expr.alias.clone()))
125 .collect::<Vec<_>>();
126 let mapped = map_columns_before_projection(required_exprs, &proj_exprs);
127
128 (mapped.len() == required_exprs.len()).then_some(Distribution::HashPartitioned(mapped))
129 }
130}
131
#[cfg(test)]
mod tests {
    use std::collections::{BTreeMap, BTreeSet};

    use api::v1::region::{RemoteDynFilterUnregister, RemoteDynFilterUpdate};
    use arrow_schema::{DataType, Field, Schema, SchemaRef, TimeUnit};
    use async_trait::async_trait;
    use common_query::request::QueryRequest;
    use common_recordbatch::SendableRecordBatchStream;
    use datafusion::common::NullEquality;
    use datafusion::execution::SessionStateBuilder;
    use datafusion::physical_optimizer::PhysicalOptimizerRule;
    use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode};
    use datafusion::physical_plan::projection::{ProjectionExec, ProjectionExpr};
    use datafusion::physical_plan::{ExecutionPlanProperties, Partitioning};
    use datafusion_expr::{JoinType, LogicalPlanBuilder};
    use datafusion_physical_expr::PhysicalExpr;
    use datafusion_physical_expr::expressions::Column as PhysicalColumn;
    use session::ReadPreference;
    use session::context::QueryContext;
    use store_api::metric_engine_consts::DATA_SCHEMA_TSID_COLUMN_NAME;
    use store_api::storage::RegionId;
    use table::table_name::TableName;

    use super::*;
    use crate::dist_plan::RemoteDynFilterProducerId;
    use crate::error::Result as QueryResult;
    use crate::region_query::RegionQueryHandler;

    /// Region query handler whose every method panics; the tests only build and
    /// rewrite plans and never execute them.
    struct NoopRegionQueryHandler;

    #[async_trait]
    impl RegionQueryHandler for NoopRegionQueryHandler {
        async fn do_get(
            &self,
            _read_preference: ReadPreference,
            _request: QueryRequest,
        ) -> QueryResult<SendableRecordBatchStream> {
            unreachable!("pass distribution tests should not execute remote queries")
        }

        async fn handle_remote_dyn_filter_update(
            &self,
            _region_id: RegionId,
            _query_id: String,
            _update: RemoteDynFilterUpdate,
        ) -> QueryResult<()> {
            unreachable!("pass distribution tests should not send remote dyn filter updates")
        }

        async fn handle_remote_dyn_filter_unregister(
            &self,
            _region_id: RegionId,
            _query_id: String,
            _unregister: RemoteDynFilterUnregister,
        ) -> QueryResult<()> {
            unreachable!("pass distribution tests should not send remote dyn filter unregisters")
        }
    }

    /// Physical column expression named `name` at position `index`.
    fn partition_column(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
        let column = PhysicalColumn::new(name, index);
        Arc::new(column)
    }

    /// Names of `exprs`, each of which must be a plain column expression.
    fn column_names(exprs: &[Arc<dyn PhysicalExpr>]) -> Vec<&str> {
        let mut names = Vec::with_capacity(exprs.len());
        for expr in exprs {
            let column = expr.as_any().downcast_ref::<PhysicalColumn>().unwrap();
            names.push(column.name());
        }
        names
    }

    /// Four-column schema shared by both join inputs: `host`, the TSID column,
    /// `greptime_timestamp` and `greptime_value`.
    fn test_schema() -> SchemaRef {
        let timestamp_type = DataType::Timestamp(TimeUnit::Millisecond, None);
        let fields = vec![
            Field::new("host", DataType::Utf8, true),
            Field::new(DATA_SCHEMA_TSID_COLUMN_NAME, DataType::UInt64, false),
            Field::new("greptime_timestamp", timestamp_type, false),
            Field::new("greptime_value", DataType::Float64, true),
        ];
        Arc::new(Schema::new(fields))
    }

    /// Merge scan over two regions of `greptime.public.test`, with the TSID and
    /// timestamp columns registered as partition columns.
    fn test_merge_scan_exec(schema: SchemaRef) -> MergeScanExec {
        let session_state = SessionStateBuilder::new().with_default_features().build();
        // Each partition column maps to a set holding just the column of the same name.
        let partition_cols = [DATA_SCHEMA_TSID_COLUMN_NAME, "greptime_timestamp"]
            .into_iter()
            .map(|name| {
                (
                    name.to_string(),
                    BTreeSet::from([datafusion_common::Column::from_name(name)]),
                )
            })
            .collect::<BTreeMap<_, _>>();
        let plan = LogicalPlanBuilder::empty(false).build().unwrap();
        let regions = vec![RegionId::new(1, 0), RegionId::new(1, 1)];

        MergeScanExec::new(
            &session_state,
            TableName::new("greptime", "public", "test"),
            regions,
            plan,
            schema.as_ref(),
            Arc::new(NoopRegionQueryHandler),
            QueryContext::arc(),
            32,
            partition_cols,
            Some(RemoteDynFilterProducerId::new(1)),
        )
        .unwrap()
    }

    /// Projection over `input` that emits `greptime_value`, the TSID column and
    /// `greptime_timestamp`, in that order.
    fn reordered_projection(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
        let exprs = vec![
            ProjectionExpr::new(partition_column("greptime_value", 3), "greptime_value"),
            ProjectionExpr::new(
                partition_column(DATA_SCHEMA_TSID_COLUMN_NAME, 1),
                DATA_SCHEMA_TSID_COLUMN_NAME,
            ),
            ProjectionExpr::new(
                partition_column("greptime_timestamp", 2),
                "greptime_timestamp",
            ),
        ];
        Arc::new(ProjectionExec::try_new(exprs, input).unwrap())
    }

    /// Equi-join key using the same column name and index on both sides.
    fn join_key(name: &str, index: usize) -> (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>) {
        (partition_column(name, index), partition_column(name, index))
    }

    #[test]
    fn passes_hash_requirement_through_projection_to_merge_scan() {
        let schema = test_schema();
        let left_scan: Arc<dyn ExecutionPlan> =
            Arc::new(test_merge_scan_exec(Arc::clone(&schema)));
        let right_scan: Arc<dyn ExecutionPlan> = Arc::new(test_merge_scan_exec(schema));

        // The left input sits behind a projection, so the join's requirement has to
        // be translated through it before it can reach the merge scan.
        let join: Arc<dyn ExecutionPlan> = Arc::new(
            HashJoinExec::try_new(
                reordered_projection(left_scan),
                right_scan,
                vec![
                    join_key(DATA_SCHEMA_TSID_COLUMN_NAME, 1),
                    join_key("greptime_timestamp", 2),
                ],
                None,
                &JoinType::Inner,
                None,
                PartitionMode::Partitioned,
                NullEquality::NullEqualsNull,
                false,
            )
            .unwrap(),
        );

        let optimized = PassDistribution
            .optimize(join, &ConfigOptions::default())
            .unwrap();

        let hash_join = optimized.as_any().downcast_ref::<HashJoinExec>().unwrap();
        let projection = hash_join
            .left()
            .as_any()
            .downcast_ref::<ProjectionExec>()
            .unwrap();

        let (left_exprs, left_count) = match projection.input().output_partitioning() {
            Partitioning::Hash(exprs, count) => (exprs, count),
            _ => panic!("expected left merge scan hash partitioning"),
        };
        let (right_exprs, right_count) = match hash_join.right().output_partitioning() {
            Partitioning::Hash(exprs, count) => (exprs, count),
            _ => panic!("expected right merge scan hash partitioning"),
        };

        assert_eq!(*left_count, 32);
        assert_eq!(*right_count, 32);
        assert_eq!(
            column_names(left_exprs),
            vec![DATA_SCHEMA_TSID_COLUMN_NAME, "greptime_timestamp"]
        );
        assert_eq!(
            column_names(right_exprs),
            vec![DATA_SCHEMA_TSID_COLUMN_NAME, "greptime_timestamp"]
        );
    }
}