1use std::collections::HashSet;
16use std::sync::Arc;
17
18use datafusion::datasource::DefaultTableSource;
19use datafusion::error::Result as DfResult;
20use datafusion_common::config::ConfigOptions;
21use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
22use datafusion_common::Column;
23use datafusion_expr::expr::{Exists, InSubquery};
24use datafusion_expr::utils::expr_to_columns;
25use datafusion_expr::{col as col_fn, Expr, LogicalPlan, LogicalPlanBuilder, Subquery};
26use datafusion_optimizer::analyzer::AnalyzerRule;
27use datafusion_optimizer::simplify_expressions::SimplifyExpressions;
28use datafusion_optimizer::{OptimizerContext, OptimizerRule};
29use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
30use table::metadata::TableType;
31use table::table::adapter::DfTableProviderAdapter;
32
33use crate::dist_plan::commutativity::{
34 partial_commutative_transformer, Categorizer, Commutativity,
35};
36use crate::dist_plan::merge_scan::MergeScanLogicalPlan;
37use crate::plan::ExtractExpr;
38use crate::query_engine::DefaultSerializer;
39
/// Analyzer rule that rewrites a logical plan for distributed execution:
/// the part that can run on datanodes is wrapped in a `MergeScan`, and any
/// remaining stages are re-applied on the frontend on top of it.
#[derive(Debug)]
pub struct DistPlannerAnalyzer;
42
43impl AnalyzerRule for DistPlannerAnalyzer {
44 fn name(&self) -> &str {
45 "DistPlannerAnalyzer"
46 }
47
48 fn analyze(
49 &self,
50 plan: LogicalPlan,
51 _config: &ConfigOptions,
52 ) -> datafusion_common::Result<LogicalPlan> {
53 let optimizer_context = OptimizerContext::new();
55 let plan = SimplifyExpressions::new()
56 .rewrite(plan, &optimizer_context)?
57 .data;
58
59 let plan = plan.transform(&Self::inspect_plan_with_subquery)?;
60 let mut rewriter = PlanRewriter::default();
61 let result = plan.data.rewrite(&mut rewriter)?.data;
62
63 Ok(result)
64 }
65}
66
67impl DistPlannerAnalyzer {
68 fn inspect_plan_with_subquery(plan: LogicalPlan) -> DfResult<Transformed<LogicalPlan>> {
69 if let LogicalPlan::Limit(_) | LogicalPlan::Distinct(_) = &plan {
72 return Ok(Transformed::no(plan));
73 }
74
75 let exprs = plan
76 .expressions_consider_join()
77 .into_iter()
78 .map(|e| e.transform(&Self::transform_subquery).map(|x| x.data))
79 .collect::<DfResult<Vec<_>>>()?;
80
81 if !matches!(plan, LogicalPlan::Unnest(_)) {
83 let inputs = plan.inputs().into_iter().cloned().collect::<Vec<_>>();
84 Ok(Transformed::yes(plan.with_new_exprs(exprs, inputs)?))
85 } else {
86 Ok(Transformed::no(plan))
87 }
88 }
89
90 fn transform_subquery(expr: Expr) -> DfResult<Transformed<Expr>> {
91 match expr {
92 Expr::Exists(exists) => Ok(Transformed::yes(Expr::Exists(Exists {
93 subquery: Self::handle_subquery(exists.subquery)?,
94 negated: exists.negated,
95 }))),
96 Expr::InSubquery(in_subquery) => Ok(Transformed::yes(Expr::InSubquery(InSubquery {
97 expr: in_subquery.expr,
98 subquery: Self::handle_subquery(in_subquery.subquery)?,
99 negated: in_subquery.negated,
100 }))),
101 Expr::ScalarSubquery(scalar_subquery) => Ok(Transformed::yes(Expr::ScalarSubquery(
102 Self::handle_subquery(scalar_subquery)?,
103 ))),
104
105 _ => Ok(Transformed::no(expr)),
106 }
107 }
108
109 fn handle_subquery(subquery: Subquery) -> DfResult<Subquery> {
110 let mut rewriter = PlanRewriter::default();
111 let mut rewrote_subquery = subquery
112 .subquery
113 .as_ref()
114 .clone()
115 .rewrite(&mut rewriter)?
116 .data;
117 if matches!(rewrote_subquery, LogicalPlan::Extension(_)) {
119 let output_schema = rewrote_subquery.schema().clone();
120 let project_exprs = output_schema
121 .fields()
122 .iter()
123 .map(|f| col_fn(f.name()))
124 .collect::<Vec<_>>();
125 rewrote_subquery = LogicalPlanBuilder::from(rewrote_subquery)
126 .project(project_exprs)?
127 .build()?;
128 }
129
130 Ok(Subquery {
131 subquery: Arc::new(rewrote_subquery),
132 outer_ref_columns: subquery.outer_ref_columns,
133 })
134 }
135}
136
/// Tracks whether the current plan branch has already been expanded into a
/// `MergeScan`; once expanded, ancestor nodes are left untouched.
#[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
enum RewriterStatus {
    /// No `MergeScan` boundary has been inserted on this branch yet.
    #[default]
    Unexpanded,
    /// A `MergeScan` boundary has been placed; stop rewriting above it.
    Expanded,
}
144
#[derive(Debug, Default)]
struct PlanRewriter {
    // Current depth while walking the plan tree (incremented in `f_down`).
    level: usize,
    // Simulated stack of visited nodes, each paired with the level it was
    // pushed at; used to look up the parent of the current node.
    stack: Vec<(LogicalPlan, usize)>,
    // Stages to be re-applied on the frontend, on top of the `MergeScan`.
    stage: Vec<LogicalPlan>,
    // Whether the current branch has already been expanded.
    status: RewriterStatus,
    // Partition columns of the scanned table, discovered from the first
    // `TableScan` seen on this branch; `None` until then.
    partition_cols: Option<Vec<String>>,
    // Columns the staged plans require from below the `MergeScan` boundary.
    column_requirements: HashSet<Column>,
}
158
impl PlanRewriter {
    /// Returns the node one level above the current one, if any.
    /// NOTE(review): assumes `level >= 1` when called (i.e. after at least one
    /// `f_down`), otherwise `self.level - 1` would underflow — confirm callers.
    fn get_parent(&self) -> Option<&LogicalPlan> {
        self.stack
            .iter()
            .rev()
            .find(|(_, level)| *level == self.level - 1)
            .map(|(node, _)| node)
    }

    /// Decides whether the plan must be expanded (split) at the current node.
    ///
    /// Returns `true` when `plan` cannot be pushed down to datanodes — either
    /// because it cannot be encoded to substrait or because it is not
    /// commutative with `MergeScan`. Partially/conditionally commutative plans
    /// are staged (possibly transformed) for later frontend-side application
    /// and do not force expansion.
    fn should_expand(&mut self, plan: &LogicalPlan) -> bool {
        // A plan that can't be serialized can't be shipped to datanodes at all.
        if DFLogicalSubstraitConvertor
            .encode(plan, DefaultSerializer)
            .is_err()
        {
            return true;
        }
        match Categorizer::check_plan(plan, self.partition_cols.clone()) {
            Commutativity::Commutative => {}
            Commutativity::PartialCommutative => {
                // Keep a transformed copy to re-apply on the frontend.
                if let Some(plan) = partial_commutative_transformer(plan) {
                    self.update_column_requirements(&plan);
                    self.stage.push(plan)
                }
            }
            Commutativity::ConditionalCommutative(transformer) => {
                if let Some(transformer) = transformer
                    && let Some(plan) = transformer(plan)
                {
                    self.update_column_requirements(&plan);
                    self.stage.push(plan)
                }
            }
            Commutativity::TransformedCommutative(transformer) => {
                if let Some(transformer) = transformer
                    && let Some(plan) = transformer(plan)
                {
                    self.update_column_requirements(&plan);
                    self.stage.push(plan)
                }
            }
            Commutativity::NonCommutative
            | Commutativity::Unimplemented
            | Commutativity::Unsupported => {
                return true;
            }
        }

        false
    }

    /// Records every column referenced by `plan`'s expressions as a
    /// requirement that must survive below the `MergeScan` boundary.
    fn update_column_requirements(&mut self, plan: &LogicalPlan) {
        let mut container = HashSet::new();
        for expr in plan.expressions() {
            // Best-effort: errors from `expr_to_columns` are deliberately ignored.
            let _ = expr_to_columns(&expr, &mut container);
        }

        for col in container {
            self.column_requirements.insert(col);
        }
    }

    fn is_expanded(&self) -> bool {
        self.status == RewriterStatus::Expanded
    }

    fn set_expanded(&mut self) {
        self.status = RewriterStatus::Expanded;
    }

    fn set_unexpanded(&mut self) {
        self.status = RewriterStatus::Unexpanded;
    }

    /// Captures the partition columns from the first base-table `TableScan`
    /// encountered; later scans on the same branch are ignored.
    fn maybe_set_partitions(&mut self, plan: &LogicalPlan) {
        if self.partition_cols.is_some() {
            // Already set by a previous scan on this branch.
            return;
        }

        if let LogicalPlan::TableScan(table_scan) = plan {
            if let Some(source) = table_scan
                .source
                .as_any()
                .downcast_ref::<DefaultTableSource>()
            {
                if let Some(provider) = source
                    .table_provider
                    .as_any()
                    .downcast_ref::<DfTableProviderAdapter>()
                {
                    // Only base tables carry partition metadata worth reading.
                    if provider.table().table_type() == TableType::Base {
                        let info = provider.table().table_info();
                        let partition_key_indices = info.meta.partition_key_indices.clone();
                        let schema = info.meta.schema.clone();
                        // Map partition key indices to column names.
                        let partition_cols = partition_key_indices
                            .into_iter()
                            .map(|index| schema.column_name_by_index(index).to_string())
                            .collect::<Vec<String>>();
                        self.partition_cols = Some(partition_cols);
                    }
                }
            }
        }
    }

    /// Pops the node recorded by the matching `f_down` and steps back a level.
    fn pop_stack(&mut self) {
        self.level -= 1;
        self.stack.pop();
    }

    /// Wraps `on_node` in a `MergeScan`, re-applies the staged plans on top of
    /// it, and finally projects back to the original schema so the node's
    /// output columns are unchanged from its parent's point of view.
    fn expand(&mut self, mut on_node: LogicalPlan) -> DfResult<LogicalPlan> {
        // Remember the pre-expansion schema: the final projection restores it.
        let schema = on_node.schema().clone();
        // Make sure columns required by the staged plans survive any
        // projections inside the expanded subtree.
        let mut rewriter = EnforceDistRequirementRewriter {
            column_requirements: std::mem::take(&mut self.column_requirements),
        };
        on_node = on_node.rewrite(&mut rewriter)?.data;

        let mut node = MergeScanLogicalPlan::new(
            on_node,
            false,
            self.partition_cols.clone().unwrap_or_default(),
        )
        .into_logical_plan();

        // Stack the staged frontend-side plans on top of the MergeScan, in
        // the order they were staged (innermost first).
        for new_stage in self.stage.drain(..) {
            node = new_stage
                .with_new_exprs(new_stage.expressions_consider_join(), vec![node.clone()])?;
        }
        self.set_expanded();

        // Restore the original output schema (qualifier + name per field).
        let node = LogicalPlanBuilder::from(node)
            .project(schema.iter().map(|(qualifier, field)| {
                Expr::Column(Column::new(qualifier.cloned(), field.name()))
            }))?
            .build()?;

        Ok(node)
    }
}
308
/// Rewriter that ensures the columns in `column_requirements` stay available
/// in a plan's output by appending any missing ones to the first projection
/// encountered while descending.
struct EnforceDistRequirementRewriter {
    // Columns that must be present in the plan's output.
    column_requirements: HashSet<Column>,
}
318
impl TreeNodeRewriter for EnforceDistRequirementRewriter {
    type Node = LogicalPlan;

    fn f_down(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        if let LogicalPlan::Projection(ref projection) = node {
            // Take the requirements out of `self` so they are enforced at most
            // once — on the topmost projection seen during the descent.
            let mut column_requirements = std::mem::take(&mut self.column_requirements);
            if column_requirements.is_empty() {
                return Ok(Transformed::no(node));
            }

            // Discard requirements already satisfied by this projection.
            for expr in &projection.expr {
                let (qualifier, name) = expr.qualified_name();
                let column = Column::new(qualifier, name);
                column_requirements.remove(&column);
            }
            if column_requirements.is_empty() {
                return Ok(Transformed::no(node));
            }

            // Append the still-missing required columns to the projection.
            let mut new_exprs = projection.expr.clone();
            for col in &column_requirements {
                new_exprs.push(Expr::Column(col.clone()));
            }
            let new_node =
                node.with_new_exprs(new_exprs, node.inputs().into_iter().cloned().collect())?;
            return Ok(Transformed::yes(new_node));
        }

        Ok(Transformed::no(node))
    }

    /// Nothing to do on the way back up.
    fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        Ok(Transformed::no(node))
    }
}
354
355impl TreeNodeRewriter for PlanRewriter {
356 type Node = LogicalPlan;
357
358 fn f_down<'a>(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
360 self.level += 1;
361 self.stack.push((node.clone(), self.level));
362 self.stage.clear();
364 self.set_unexpanded();
365 self.partition_cols = None;
366 Ok(Transformed::no(node))
367 }
368
369 fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
373 if self.is_expanded() {
375 self.pop_stack();
376 return Ok(Transformed::no(node));
377 }
378
379 if node.inputs().is_empty() && !matches!(node, LogicalPlan::TableScan(_)) {
381 self.set_expanded();
382 self.pop_stack();
383 return Ok(Transformed::no(node));
384 }
385
386 self.maybe_set_partitions(&node);
387
388 let Some(parent) = self.get_parent() else {
389 let node = self.expand(node)?;
390 self.pop_stack();
391 return Ok(Transformed::yes(node));
392 };
393
394 if self.should_expand(&parent.clone()) {
396 let node = self.expand(node)?;
398 self.pop_stack();
399 return Ok(Transformed::yes(node));
400 }
401
402 self.pop_stack();
403 Ok(Transformed::no(node))
404 }
405}
406
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use datafusion::datasource::DefaultTableSource;
    use datafusion::functions_aggregate::expr_fn::avg;
    use datafusion_common::JoinType;
    use datafusion_expr::{col, lit, Expr, LogicalPlanBuilder};
    use table::table::adapter::DfTableProviderAdapter;
    use table::table::numbers::NumbersTable;

    use super::*;

    // Filter + projection + distinct: the whole pipeline should be pushed
    // below the MergeScan, with Distinct repeated on the frontend.
    #[ignore = "Projection is disabled for https://github.com/apache/arrow-datafusion/issues/6489"]
    #[test]
    fn transform_simple_projection_filter() {
        let numbers_table = NumbersTable::table(0);
        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(numbers_table),
        )));

        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
            .unwrap()
            .filter(col("number").lt(lit(10)))
            .unwrap()
            .project(vec![col("number")])
            .unwrap()
            .distinct()
            .unwrap()
            .build()
            .unwrap();

        let config = ConfigOptions::default();
        let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
        let expected = [
            "Distinct:",
            "  MergeScan [is_placeholder=false]",
            "    Distinct:",
            "      Projection: t.number",
            "        Filter: t.number < Int32(10)",
            "          TableScan: t",
        ]
        .join("\n");
        assert_eq!(expected, result.to_string());
    }

    // A whole-plan aggregate collapses into a single MergeScan with a
    // schema-restoring projection on top.
    #[test]
    fn transform_aggregator() {
        let numbers_table = NumbersTable::table(0);
        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(numbers_table),
        )));

        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
            .unwrap()
            .aggregate(Vec::<Expr>::new(), vec![avg(col("number"))])
            .unwrap()
            .build()
            .unwrap();

        let config = ConfigOptions::default();
        let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
        let expected = "Projection: avg(t.number)\
        \n  MergeScan [is_placeholder=false]";
        assert_eq!(expected, result.to_string());
    }

    // Distinct followed by Sort is fully pushed down; only the projection
    // restoring the schema remains on the frontend.
    #[test]
    fn transform_distinct_order() {
        let numbers_table = NumbersTable::table(0);
        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(numbers_table),
        )));

        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
            .unwrap()
            .distinct()
            .unwrap()
            .sort(vec![col("number").sort(true, false)])
            .unwrap()
            .build()
            .unwrap();

        let config = ConfigOptions::default();
        let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
        let expected = ["Projection: t.number", "  MergeScan [is_placeholder=false]"].join("\n");
        assert_eq!(expected, result.to_string());
    }

    // A single Limit over a scan is fully absorbed into the MergeScan.
    #[test]
    fn transform_single_limit() {
        let numbers_table = NumbersTable::table(0);
        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(numbers_table),
        )));

        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
            .unwrap()
            .limit(0, Some(1))
            .unwrap()
            .build()
            .unwrap();

        let config = ConfigOptions::default();
        let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
        let expected = "Projection: t.number\
        \n  MergeScan [is_placeholder=false]";
        assert_eq!(expected, result.to_string());
    }

    // A join is not commutative with MergeScan, so each join input gets its
    // own MergeScan while the join (and the limit above it) stays on the
    // frontend.
    #[test]
    fn transform_unalighed_join_with_alias() {
        let left = NumbersTable::table(0);
        let right = NumbersTable::table(1);
        let left_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(left),
        )));
        let right_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(right),
        )));

        let right_plan = LogicalPlanBuilder::scan_with_filters("t", right_source, None, vec![])
            .unwrap()
            .alias("right")
            .unwrap()
            .build()
            .unwrap();

        let plan = LogicalPlanBuilder::scan_with_filters("t", left_source, None, vec![])
            .unwrap()
            .join_on(
                right_plan,
                JoinType::LeftSemi,
                vec![col("t.number").eq(col("right.number"))],
            )
            .unwrap()
            .limit(0, Some(1))
            .unwrap()
            .build()
            .unwrap();

        let config = ConfigOptions::default();
        let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
        let expected = [
            "Limit: skip=0, fetch=1",
            "  LeftSemi Join:  Filter: t.number = right.number",
            "    Projection: t.number",
            "      MergeScan [is_placeholder=false]",
            "    SubqueryAlias: right",
            "      Projection: t.number",
            "        MergeScan [is_placeholder=false]",
        ]
        .join("\n");
        assert_eq!(expected, result.to_string());
    }
}