1use std::collections::HashSet;
16use std::sync::Arc;
17
18use common_telemetry::debug;
19use datafusion::datasource::DefaultTableSource;
20use datafusion::error::Result as DfResult;
21use datafusion_common::config::ConfigOptions;
22use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
23use datafusion_common::Column;
24use datafusion_expr::expr::{Exists, InSubquery};
25use datafusion_expr::utils::expr_to_columns;
26use datafusion_expr::{col as col_fn, Expr, LogicalPlan, LogicalPlanBuilder, Subquery};
27use datafusion_optimizer::analyzer::AnalyzerRule;
28use datafusion_optimizer::simplify_expressions::SimplifyExpressions;
29use datafusion_optimizer::{OptimizerContext, OptimizerRule};
30use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
31use table::metadata::TableType;
32use table::table::adapter::DfTableProviderAdapter;
33
34use crate::dist_plan::commutativity::{
35 partial_commutative_transformer, Categorizer, Commutativity,
36};
37use crate::dist_plan::merge_scan::MergeScanLogicalPlan;
38use crate::plan::ExtractExpr;
39use crate::query_engine::DefaultSerializer;
40
/// Analyzer rule that rewrites a logical plan for distributed execution:
/// it walks the plan, finds the boundary up to which nodes can be pushed
/// down to remote nodes, and wraps that sub-plan in a `MergeScan`.
#[derive(Debug)]
pub struct DistPlannerAnalyzer;
43
44impl AnalyzerRule for DistPlannerAnalyzer {
45 fn name(&self) -> &str {
46 "DistPlannerAnalyzer"
47 }
48
49 fn analyze(
50 &self,
51 plan: LogicalPlan,
52 _config: &ConfigOptions,
53 ) -> datafusion_common::Result<LogicalPlan> {
54 let optimizer_context = OptimizerContext::new();
56 let plan = SimplifyExpressions::new()
57 .rewrite(plan, &optimizer_context)?
58 .data;
59
60 let plan = plan.transform(&Self::inspect_plan_with_subquery)?;
61 let mut rewriter = PlanRewriter::default();
62 let result = plan.data.rewrite(&mut rewriter)?.data;
63
64 Ok(result)
65 }
66}
67
impl DistPlannerAnalyzer {
    /// Rewrites subquery expressions inside `plan` so each subquery is itself
    /// put through the distributed rewrite (see [`Self::handle_subquery`]).
    fn inspect_plan_with_subquery(plan: LogicalPlan) -> DfResult<Transformed<LogicalPlan>> {
        // Limit and Distinct are skipped here; they are handled by the main
        // rewriter and must not be reconstructed via `with_new_exprs`.
        if let LogicalPlan::Limit(_) | LogicalPlan::Distinct(_) = &plan {
            return Ok(Transformed::no(plan));
        }

        // Transform every expression of this node (including join expressions),
        // rewriting any subquery found within.
        let exprs = plan
            .expressions_consider_join()
            .into_iter()
            .map(|e| e.transform(&Self::transform_subquery).map(|x| x.data))
            .collect::<DfResult<Vec<_>>>()?;

        // Unnest is left untouched: rebuilding it from exprs is not supported
        // here, so only non-Unnest plans get their expressions replaced.
        if !matches!(plan, LogicalPlan::Unnest(_)) {
            let inputs = plan.inputs().into_iter().cloned().collect::<Vec<_>>();
            Ok(Transformed::yes(plan.with_new_exprs(exprs, inputs)?))
        } else {
            Ok(Transformed::no(plan))
        }
    }

    /// Dispatches on the three subquery-bearing expression kinds and rewrites
    /// the embedded subquery plan; all other expressions pass through.
    fn transform_subquery(expr: Expr) -> DfResult<Transformed<Expr>> {
        match expr {
            Expr::Exists(exists) => Ok(Transformed::yes(Expr::Exists(Exists {
                subquery: Self::handle_subquery(exists.subquery)?,
                negated: exists.negated,
            }))),
            Expr::InSubquery(in_subquery) => Ok(Transformed::yes(Expr::InSubquery(InSubquery {
                expr: in_subquery.expr,
                subquery: Self::handle_subquery(in_subquery.subquery)?,
                negated: in_subquery.negated,
            }))),
            Expr::ScalarSubquery(scalar_subquery) => Ok(Transformed::yes(Expr::ScalarSubquery(
                Self::handle_subquery(scalar_subquery)?,
            ))),

            _ => Ok(Transformed::no(expr)),
        }
    }

    /// Runs the distributed [`PlanRewriter`] over a subquery's plan.
    ///
    /// If the rewritten plan is an extension node (e.g. a `MergeScan`), a
    /// projection over its full schema is added on top — presumably so that
    /// downstream optimizer passes see a plain plan node rather than a bare
    /// extension as the subquery root (TODO confirm rationale).
    fn handle_subquery(subquery: Subquery) -> DfResult<Subquery> {
        let mut rewriter = PlanRewriter::default();
        let mut rewrote_subquery = subquery
            .subquery
            .as_ref()
            .clone()
            .rewrite(&mut rewriter)?
            .data;
        if matches!(rewrote_subquery, LogicalPlan::Extension(_)) {
            // Project every output field by name on top of the extension node.
            let output_schema = rewrote_subquery.schema().clone();
            let project_exprs = output_schema
                .fields()
                .iter()
                .map(|f| col_fn(f.name()))
                .collect::<Vec<_>>();
            rewrote_subquery = LogicalPlanBuilder::from(rewrote_subquery)
                .project(project_exprs)?
                .build()?;
        }

        Ok(Subquery {
            subquery: Arc::new(rewrote_subquery),
            outer_ref_columns: subquery.outer_ref_columns,
        })
    }
}
137
/// Tracks whether the current subtree has already been expanded into a
/// `MergeScan` by [`PlanRewriter`].
#[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
enum RewriterStatus {
    /// Not yet expanded; the rewriter is still pushing nodes down.
    #[default]
    Unexpanded,
    /// Already expanded; remaining ancestors bubble up unchanged.
    Expanded,
}
145
/// Stateful bottom-up rewriter that decides where to cut the plan and insert
/// a `MergeScan` node for distributed execution.
#[derive(Debug, Default)]
struct PlanRewriter {
    // Current depth in the plan tree during traversal (incremented in
    // `f_down`, decremented in `pop_stack`).
    level: usize,
    // Stack of visited ancestor nodes paired with the depth they were seen at.
    stack: Vec<(LogicalPlan, usize)>,
    // Parent plans staged to be re-applied on top of the MergeScan when
    // expanding (drained in `expand`).
    stage: Vec<LogicalPlan>,
    // Whether the current subtree has already been expanded.
    status: RewriterStatus,
    // Partition columns discovered from the base-table TableScan leaf;
    // `None` until such a scan has been seen.
    partition_cols: Option<Vec<String>>,
    // Columns that projections under the MergeScan must keep so staged
    // parent plans can still evaluate their expressions.
    column_requirements: HashSet<Column>,
    // When set, the next `should_expand` call returns true unconditionally.
    expand_on_next_call: bool,
    // When set, the next partial/conditional/transformed-commutative node
    // triggers (or schedules) expansion instead of staging further plans.
    expand_on_next_part_cond_trans_commutative: bool,
    // Replacement child plan produced by a TransformedCommutative
    // transformer; consumed at the start of `expand`.
    new_child_plan: Option<LogicalPlan>,
}
177
impl PlanRewriter {
    /// Returns the parent of the node currently being visited: the stack entry
    /// recorded exactly one level above the current level.
    fn get_parent(&self) -> Option<&LogicalPlan> {
        // Scan from the top of the stack for the first node at `level - 1`.
        self.stack
            .iter()
            .rev()
            .find(|(_, level)| *level == self.level - 1)
            .map(|(node, _)| node)
    }

    /// Decides whether the plan must be expanded (cut off into a `MergeScan`)
    /// at the current position. As a side effect it may stage transformed
    /// parent plans in `self.stage` and set the expand-on-next flags.
    fn should_expand(&mut self, plan: &LogicalPlan) -> bool {
        // A plan that cannot be encoded to substrait cannot be shipped to
        // remote nodes, so expansion must happen below it.
        if DFLogicalSubstraitConvertor
            .encode(plan, DefaultSerializer)
            .is_err()
        {
            return true;
        }

        // A previous call requested unconditional expansion at this node.
        if self.expand_on_next_call {
            self.expand_on_next_call = false;
            return true;
        }

        if self.expand_on_next_part_cond_trans_commutative {
            let comm = Categorizer::check_plan(plan, self.partition_cols.clone());
            match comm {
                Commutativity::PartialCommutative => {
                    // Defer expansion by one more node: clear this flag and
                    // arm the unconditional one instead.
                    self.expand_on_next_part_cond_trans_commutative = false;
                    self.expand_on_next_call = true;
                }
                Commutativity::ConditionalCommutative(_)
                | Commutativity::TransformedCommutative { .. } => {
                    // Stop accumulating stages and expand right here.
                    self.expand_on_next_part_cond_trans_commutative = false;
                    return true;
                }
                _ => (),
            }
        }

        match Categorizer::check_plan(plan, self.partition_cols.clone()) {
            // Fully commutative nodes can be pushed down as-is.
            Commutativity::Commutative => {}
            Commutativity::PartialCommutative => {
                // Stage the transformed plan to be re-applied above the
                // MergeScan, and remember to expand soon after.
                if let Some(plan) = partial_commutative_transformer(plan) {
                    self.update_column_requirements(&plan);
                    self.expand_on_next_part_cond_trans_commutative = true;
                    self.stage.push(plan)
                }
            }
            Commutativity::ConditionalCommutative(transformer) => {
                if let Some(transformer) = transformer
                    && let Some(plan) = transformer(plan)
                {
                    self.update_column_requirements(&plan);
                    self.expand_on_next_part_cond_trans_commutative = true;
                    self.stage.push(plan)
                }
            }
            Commutativity::TransformedCommutative { transformer } => {
                if let Some(transformer) = transformer
                    && let Some(transformer_actions) = transformer(plan)
                {
                    debug!(
                        "PlanRewriter: transformed plan: {:?}\n from {plan}",
                        transformer_actions.extra_parent_plans
                    );
                    // Only the last (outermost) staged plan determines the
                    // column requirements for the projection fix-up.
                    if let Some(last_stage) = transformer_actions.extra_parent_plans.last() {
                        self.update_column_requirements(last_stage);
                    }
                    // Stages are applied in drain order in `expand`, hence the
                    // reversal here.
                    self.stage
                        .extend(transformer_actions.extra_parent_plans.into_iter().rev());
                    self.expand_on_next_call = true;
                    self.new_child_plan = transformer_actions.new_child_plan;
                }
            }
            // These can never be pushed down: expand immediately.
            Commutativity::NonCommutative
            | Commutativity::Unimplemented
            | Commutativity::Unsupported => {
                return true;
            }
        }

        false
    }

    /// Records every column referenced by `plan`'s expressions into
    /// `column_requirements` so they survive projection pruning below the
    /// MergeScan.
    fn update_column_requirements(&mut self, plan: &LogicalPlan) {
        debug!(
            "PlanRewriter: update column requirements for plan: {plan}\n withcolumn_requirements: {:?}",
            self.column_requirements
        );
        let mut container = HashSet::new();
        for expr in plan.expressions() {
            // Errors from expr_to_columns are deliberately ignored
            // (best-effort collection).
            let _ = expr_to_columns(&expr, &mut container);
        }

        for col in container {
            self.column_requirements.insert(col);
        }
        debug!(
            "PlanRewriter: updated column requirements: {:?}",
            self.column_requirements
        );
    }

    fn is_expanded(&self) -> bool {
        self.status == RewriterStatus::Expanded
    }

    fn set_expanded(&mut self) {
        self.status = RewriterStatus::Expanded;
    }

    fn set_unexpanded(&mut self) {
        self.status = RewriterStatus::Unexpanded;
    }

    /// Records partition columns the first time a base-table `TableScan` is
    /// encountered; subsequent calls are no-ops once they are set.
    fn maybe_set_partitions(&mut self, plan: &LogicalPlan) {
        if self.partition_cols.is_some() {
            return;
        }

        if let LogicalPlan::TableScan(table_scan) = plan {
            if let Some(source) = table_scan
                .source
                .as_any()
                .downcast_ref::<DefaultTableSource>()
            {
                if let Some(provider) = source
                    .table_provider
                    .as_any()
                    .downcast_ref::<DfTableProviderAdapter>()
                {
                    // Only base tables carry partition metadata (views etc.
                    // are skipped).
                    if provider.table().table_type() == TableType::Base {
                        let info = provider.table().table_info();
                        let partition_key_indices = info.meta.partition_key_indices.clone();
                        let schema = info.meta.schema.clone();
                        let partition_cols = partition_key_indices
                            .into_iter()
                            .map(|index| schema.column_name_by_index(index).to_string())
                            .collect::<Vec<String>>();
                        self.partition_cols = Some(partition_cols);
                    }
                }
            }
        }
    }

    /// Pops one level off the traversal stack (mirror of the push in
    /// `f_down`).
    fn pop_stack(&mut self) {
        self.level -= 1;
        self.stack.pop();
    }

    /// Cuts the plan at `on_node`: wraps it in a `MergeScan`, re-applies the
    /// staged parent plans on top, and restores the original output schema
    /// with a final projection.
    fn expand(&mut self, mut on_node: LogicalPlan) -> DfResult<LogicalPlan> {
        // A TransformedCommutative transformer may have supplied a
        // replacement child; use it instead of the original node.
        if let Some(new_child_plan) = self.new_child_plan.take() {
            on_node = new_child_plan;
        }
        // Capture the schema before rewriting so the final projection can
        // reproduce the original output columns and qualifiers.
        let schema = on_node.schema().clone();
        let mut rewriter = EnforceDistRequirementRewriter {
            column_requirements: std::mem::take(&mut self.column_requirements),
        };
        on_node = on_node.rewrite(&mut rewriter)?.data;

        // Wrap the pushed-down sub-plan in a MergeScan node.
        let mut node = MergeScanLogicalPlan::new(
            on_node,
            false,
            self.partition_cols.clone().unwrap_or_default(),
        )
        .into_logical_plan();

        // Re-apply staged parent plans (e.g. partial aggregations) on top of
        // the MergeScan, in the order they were staged.
        for new_stage in self.stage.drain(..) {
            node = new_stage
                .with_new_exprs(new_stage.expressions_consider_join(), vec![node.clone()])?;
        }
        self.set_expanded();

        // Final projection back to the captured schema, so the expansion is
        // transparent to the rest of the plan.
        let node = LogicalPlanBuilder::from(node)
            .project(schema.iter().map(|(qualifier, field)| {
                Expr::Column(Column::new(qualifier.cloned(), field.name()))
            }))?
            .build()?;

        Ok(node)
    }
}
377
/// Rewriter that makes sure the required columns survive projection pruning:
/// any column in `column_requirements` missing from the first projection
/// encountered is appended to that projection's expression list.
struct EnforceDistRequirementRewriter {
    // Columns that must appear in the plan output (collected by PlanRewriter).
    column_requirements: HashSet<Column>,
}
387
impl TreeNodeRewriter for EnforceDistRequirementRewriter {
    type Node = LogicalPlan;

    fn f_down(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        if let LogicalPlan::Projection(ref projection) = node {
            // Take (not clone) the requirements: only the first projection
            // encountered on the way down is adjusted; any later projection
            // sees an empty set and passes through untouched.
            let mut column_requirements = std::mem::take(&mut self.column_requirements);
            if column_requirements.is_empty() {
                return Ok(Transformed::no(node));
            }

            // Remove requirements the projection already satisfies.
            for expr in &projection.expr {
                let (qualifier, name) = expr.qualified_name();
                let column = Column::new(qualifier, name);
                column_requirements.remove(&column);
            }
            if column_requirements.is_empty() {
                return Ok(Transformed::no(node));
            }

            // Append the still-missing columns to the projection list.
            // NOTE(review): HashSet iteration order is nondeterministic, so
            // the appended column order may vary between runs — confirm that
            // downstream consumers don't depend on it.
            let mut new_exprs = projection.expr.clone();
            for col in &column_requirements {
                new_exprs.push(Expr::Column(col.clone()));
            }
            let new_node =
                node.with_new_exprs(new_exprs, node.inputs().into_iter().cloned().collect())?;
            return Ok(Transformed::yes(new_node));
        }

        Ok(Transformed::no(node))
    }

    fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        // Nothing to do on the way up.
        Ok(Transformed::no(node))
    }
}
423
424impl TreeNodeRewriter for PlanRewriter {
425 type Node = LogicalPlan;
426
427 fn f_down<'a>(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
429 self.level += 1;
430 self.stack.push((node.clone(), self.level));
431 self.stage.clear();
433 self.set_unexpanded();
434 self.partition_cols = None;
435 Ok(Transformed::no(node))
436 }
437
438 fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
442 if self.is_expanded() {
444 self.pop_stack();
445 return Ok(Transformed::no(node));
446 }
447
448 if node.inputs().is_empty() && !matches!(node, LogicalPlan::TableScan(_)) {
450 self.set_expanded();
451 self.pop_stack();
452 return Ok(Transformed::no(node));
453 }
454
455 self.maybe_set_partitions(&node);
456
457 let Some(parent) = self.get_parent() else {
458 let node = self.expand(node)?;
459 self.pop_stack();
460 return Ok(Transformed::yes(node));
461 };
462
463 let parent = parent.clone();
464
465 if self.should_expand(&parent) {
467 debug!("PlanRewriter: should expand child:\n {node}\n Of Parent: {parent}");
469 let node = self.expand(node);
470 debug!(
471 "PlanRewriter: expanded plan: {}",
472 match &node {
473 Ok(n) => n.to_string(),
474 Err(e) => format!("Error expanding plan: {e}"),
475 }
476 );
477 let node = node?;
478 self.pop_stack();
479 return Ok(Transformed::yes(node));
480 }
481
482 self.pop_stack();
483 Ok(Transformed::no(node))
484 }
485}
486
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use datafusion::datasource::DefaultTableSource;
    use datafusion::functions_aggregate::expr_fn::avg;
    use datafusion_common::JoinType;
    use datafusion_expr::{col, lit, Expr, LogicalPlanBuilder};
    use table::table::adapter::DfTableProviderAdapter;
    use table::table::numbers::NumbersTable;

    use super::*;

    #[ignore = "Projection is disabled for https://github.com/apache/arrow-datafusion/issues/6489"]
    #[test]
    fn transform_simple_projection_filter() {
        let numbers_table = NumbersTable::table(0);
        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(numbers_table),
        )));

        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
            .unwrap()
            .filter(col("number").lt(lit(10)))
            .unwrap()
            .project(vec![col("number")])
            .unwrap()
            .distinct()
            .unwrap()
            .build()
            .unwrap();

        let config = ConfigOptions::default();
        let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
        // Filter + projection are pushed below the MergeScan; Distinct is
        // applied both remotely and on the frontend.
        let expected = [
            "Distinct:",
            "  MergeScan [is_placeholder=false]",
            "    Distinct:",
            "      Projection: t.number",
            "        Filter: t.number < Int32(10)",
            "          TableScan: t",
        ]
        .join("\n");
        assert_eq!(expected, result.to_string());
    }

    #[test]
    fn transform_aggregator() {
        let numbers_table = NumbersTable::table(0);
        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(numbers_table),
        )));

        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
            .unwrap()
            .aggregate(Vec::<Expr>::new(), vec![avg(col("number"))])
            .unwrap()
            .build()
            .unwrap();

        let config = ConfigOptions::default();
        let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
        // The whole aggregation is pushed into the MergeScan; only a schema
        // projection remains on the frontend.
        let expected = "Projection: avg(t.number)\
        \n  MergeScan [is_placeholder=false]";
        assert_eq!(expected, result.to_string());
    }

    #[test]
    fn transform_distinct_order() {
        let numbers_table = NumbersTable::table(0);
        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(numbers_table),
        )));

        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
            .unwrap()
            .distinct()
            .unwrap()
            .sort(vec![col("number").sort(true, false)])
            .unwrap()
            .build()
            .unwrap();

        let config = ConfigOptions::default();
        let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
        let expected = ["Projection: t.number", "  MergeScan [is_placeholder=false]"].join("\n");
        assert_eq!(expected, result.to_string());
    }

    #[test]
    fn transform_single_limit() {
        let numbers_table = NumbersTable::table(0);
        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(numbers_table),
        )));

        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
            .unwrap()
            .limit(0, Some(1))
            .unwrap()
            .build()
            .unwrap();

        let config = ConfigOptions::default();
        let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
        let expected = "Projection: t.number\
        \n  MergeScan [is_placeholder=false]";
        assert_eq!(expected, result.to_string());
    }

    // Renamed from `transform_unalighed_join_with_alias` (typo).
    #[test]
    fn transform_unaligned_join_with_alias() {
        let left = NumbersTable::table(0);
        let right = NumbersTable::table(1);
        let left_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(left),
        )));
        let right_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(right),
        )));

        let right_plan = LogicalPlanBuilder::scan_with_filters("t", right_source, None, vec![])
            .unwrap()
            .alias("right")
            .unwrap()
            .build()
            .unwrap();

        let plan = LogicalPlanBuilder::scan_with_filters("t", left_source, None, vec![])
            .unwrap()
            .join_on(
                right_plan,
                JoinType::LeftSemi,
                vec![col("t.number").eq(col("right.number"))],
            )
            .unwrap()
            .limit(0, Some(1))
            .unwrap()
            .build()
            .unwrap();

        let config = ConfigOptions::default();
        let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
        // The join itself cannot be pushed down; each side becomes its own
        // MergeScan.
        let expected = [
            "Limit: skip=0, fetch=1",
            "  LeftSemi Join:  Filter: t.number = right.number",
            "    Projection: t.number",
            "      MergeScan [is_placeholder=false]",
            "    SubqueryAlias: right",
            "      Projection: t.number",
            "        MergeScan [is_placeholder=false]",
        ]
        .join("\n");
        assert_eq!(expected, result.to_string());
    }
}