use std::collections::{BTreeSet, HashSet};
use std::sync::Arc;

use common_telemetry::debug;
use datafusion::config::{ConfigExtension, ExtensionOptions};
use datafusion::datasource::DefaultTableSource;
use datafusion::error::Result as DfResult;
use datafusion_common::Column;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
use datafusion_expr::expr::{Exists, InSubquery};
use datafusion_expr::utils::expr_to_columns;
use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder, Subquery, col as col_fn};
use datafusion_optimizer::analyzer::AnalyzerRule;
use datafusion_optimizer::simplify_expressions::SimplifyExpressions;
use datafusion_optimizer::{OptimizerContext, OptimizerRule};
use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
use table::metadata::TableType;
use table::table::adapter::DfTableProviderAdapter;

use crate::dist_plan::analyzer::utils::aliased_columns_for;
use crate::dist_plan::commutativity::{
    Categorizer, Commutativity, partial_commutative_transformer,
};
use crate::dist_plan::merge_scan::MergeScanLogicalPlan;
use crate::metrics::PUSH_DOWN_FALLBACK_ERRORS_TOTAL;
use crate::plan::ExtractExpr;
use crate::query_engine::DefaultSerializer;

#[cfg(test)]
mod test;

mod fallback;
mod utils;

pub(crate) use utils::AliasMapping;

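/// Options for the distributed planner, exposed to DataFusion as a config
/// extension under the `dist_planner` prefix.
///
/// A minimal sketch of registering the extension and reading it back through
/// DataFusion's config API:
///
/// ```ignore
/// use datafusion_common::config::ConfigOptions;
///
/// let mut config = ConfigOptions::new();
/// config
///     .extensions
///     .insert(DistPlannerOptions { allow_query_fallback: true });
///
/// // Later, e.g. inside `AnalyzerRule::analyze`:
/// let allow_fallback = config
///     .extensions
///     .get::<DistPlannerOptions>()
///     .map(|o| o.allow_query_fallback)
///     .unwrap_or(false);
/// assert!(allow_fallback);
/// ```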
#[derive(Debug, Clone)]
pub struct DistPlannerOptions {
    pub allow_query_fallback: bool,
}

impl ConfigExtension for DistPlannerOptions {
    const PREFIX: &'static str = "dist_planner";
}

impl ExtensionOptions for DistPlannerOptions {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

    fn cloned(&self) -> Box<dyn ExtensionOptions> {
        Box::new(self.clone())
    }

    fn set(&mut self, key: &str, value: &str) -> DfResult<()> {
        Err(datafusion_common::DataFusionError::NotImplemented(format!(
            "DistPlannerOptions does not support set key: {key} with value: {value}"
        )))
    }

    fn entries(&self) -> Vec<datafusion::config::ConfigEntry> {
        vec![datafusion::config::ConfigEntry {
            key: "allow_query_fallback".to_string(),
            value: Some(self.allow_query_fallback.to_string()),
            description: "Allow falling back to the fallback plan rewriter when push-down fails",
        }]
    }
}

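/// Analyzer rule that rewrites a logical plan for distributed execution: the
/// plan is simplified first, then [`PlanRewriter`] pushes as much of it as
/// possible below a merge scan. If push-down fails and `allow_query_fallback`
/// is enabled, the fallback rewriter is used instead of returning the error.
///
/// A sketch of wiring the rule into a session, assuming DataFusion's
/// `SessionStateBuilder` API (not code from this crate):
///
/// ```ignore
/// use std::sync::Arc;
/// use datafusion::execution::session_state::SessionStateBuilder;
///
/// let state = SessionStateBuilder::new()
///     .with_analyzer_rules(vec![Arc::new(DistPlannerAnalyzer)])
///     .build();
/// ```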
#[derive(Debug)]
pub struct DistPlannerAnalyzer;

impl AnalyzerRule for DistPlannerAnalyzer {
    fn name(&self) -> &str {
        "DistPlannerAnalyzer"
    }

    fn analyze(
        &self,
        plan: LogicalPlan,
        config: &ConfigOptions,
    ) -> datafusion_common::Result<LogicalPlan> {
        // Simplify expressions first so the rewriter works on a normalized plan.
        let optimizer_context = OptimizerContext::new();
        let plan = SimplifyExpressions::new()
            .rewrite(plan, &optimizer_context)?
            .data;

        let opt = config.extensions.get::<DistPlannerOptions>();
        let allow_fallback = opt.map(|o| o.allow_query_fallback).unwrap_or(false);

        let result = match self.try_push_down(plan.clone()) {
            Ok(plan) => plan,
            Err(err) => {
                if allow_fallback {
                    common_telemetry::warn!(err; "Failed to push down plan, using fallback plan rewriter for plan: {plan}");
                    PUSH_DOWN_FALLBACK_ERRORS_TOTAL.inc();
                    self.use_fallback(plan)?
                } else {
                    return Err(err);
                }
            }
        };

        Ok(result)
    }
}

impl DistPlannerAnalyzer {
    fn try_push_down(&self, plan: LogicalPlan) -> DfResult<LogicalPlan> {
        // Rewrite subqueries first, then run the main plan rewriter.
        let plan = plan.transform(&Self::inspect_plan_with_subquery)?;
        let mut rewriter = PlanRewriter::default();
        let result = plan.data.rewrite(&mut rewriter)?.data;
        Ok(result)
    }

    fn use_fallback(&self, plan: LogicalPlan) -> DfResult<LogicalPlan> {
        let mut rewriter = fallback::FallbackPlanRewriter;
        let result = plan.rewrite(&mut rewriter)?.data;
        Ok(result)
    }

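    /// Rewrites subqueries inside this plan's expressions so each subquery is
    /// planned for distributed execution on its own. `Limit` and `Distinct`
    /// plans are passed through unchanged, and `Unnest` keeps its original
    /// expressions.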
    fn inspect_plan_with_subquery(plan: LogicalPlan) -> DfResult<Transformed<LogicalPlan>> {
        // `Limit` and `Distinct` are returned unchanged; only other plans get
        // their expressions rebuilt below.
        if let LogicalPlan::Limit(_) | LogicalPlan::Distinct(_) = &plan {
            return Ok(Transformed::no(plan));
        }

        let exprs = plan
            .expressions_consider_join()
            .into_iter()
            .map(|e| e.transform(&Self::transform_subquery).map(|x| x.data))
            .collect::<DfResult<Vec<_>>>()?;

        // `Unnest` must keep its original expressions, so don't rebuild it.
        if !matches!(plan, LogicalPlan::Unnest(_)) {
            let inputs = plan.inputs().into_iter().cloned().collect::<Vec<_>>();
            Ok(Transformed::yes(plan.with_new_exprs(exprs, inputs)?))
        } else {
            Ok(Transformed::no(plan))
        }
    }

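    /// Dispatches `EXISTS`, `IN (subquery)` and scalar subquery expressions to
    /// [`Self::handle_subquery`]; all other expressions are left untouched.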
    fn transform_subquery(expr: Expr) -> DfResult<Transformed<Expr>> {
        match expr {
            Expr::Exists(exists) => Ok(Transformed::yes(Expr::Exists(Exists {
                subquery: Self::handle_subquery(exists.subquery)?,
                negated: exists.negated,
            }))),
            Expr::InSubquery(in_subquery) => Ok(Transformed::yes(Expr::InSubquery(InSubquery {
                expr: in_subquery.expr,
                subquery: Self::handle_subquery(in_subquery.subquery)?,
                negated: in_subquery.negated,
            }))),
            Expr::ScalarSubquery(scalar_subquery) => Ok(Transformed::yes(Expr::ScalarSubquery(
                Self::handle_subquery(scalar_subquery)?,
            ))),

            _ => Ok(Transformed::no(expr)),
        }
    }

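    /// Runs the [`PlanRewriter`] over a subquery's inner plan. If the rewrite
    /// leaves an extension node (i.e. a merge scan) at the root, a projection
    /// is added on top so the subquery still exposes plain column references.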
    fn handle_subquery(subquery: Subquery) -> DfResult<Subquery> {
        let mut rewriter = PlanRewriter::default();
        let mut rewrote_subquery = subquery
            .subquery
            .as_ref()
            .clone()
            .rewrite(&mut rewriter)?
            .data;
        // If the rewritten subquery's root is an extension node (a merge
        // scan), wrap it in a projection so its output columns stay plain.
        if matches!(rewrote_subquery, LogicalPlan::Extension(_)) {
            let output_schema = rewrote_subquery.schema().clone();
            let project_exprs = output_schema
                .fields()
                .iter()
                .map(|f| col_fn(f.name()))
                .collect::<Vec<_>>();
            rewrote_subquery = LogicalPlanBuilder::from(rewrote_subquery)
                .project(project_exprs)?
                .build()?;
        }

        Ok(Subquery {
            subquery: Arc::new(rewrote_subquery),
            outer_ref_columns: subquery.outer_ref_columns,
            spans: Default::default(),
        })
    }
}

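/// Marks whether the current plan tree has already been expanded, i.e.
/// whether a merge scan has been inserted on this path.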
#[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
enum RewriterStatus {
    #[default]
    Unexpanded,
    Expanded,
}

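/// Bottom-up rewriter that pushes the largest push-down-safe prefix of the
/// plan below a [`MergeScanLogicalPlan`] and re-applies the remaining
/// (staged) plans above it.
///
/// As a rough illustration (a sketch, not verbatim planner output), an
/// aggregation over a partitioned table could be split like this:
///
/// ```text
/// Aggregate: min(t.v)            Aggregate: min(min_partial)  <- frontend
///   TableScan: t        ==>        MergeScan
///                                    Aggregate: min(t.v)      <- datanodes
///                                      TableScan: t
/// ```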
#[derive(Debug, Default)]
struct PlanRewriter {
    /// Current depth in the plan tree; the root is at level 1.
    level: usize,
    /// Simulated stack of ancestors built during `f_down`, as `(node, level)`.
    stack: Vec<(LogicalPlan, usize)>,
    /// Parent plans staged to be re-applied on top of the merge scan.
    stage: Vec<LogicalPlan>,
    /// Whether a merge scan has already been inserted on this path.
    status: RewriterStatus,
    /// Partition columns of the table, tracked through their aliases.
    partition_cols: Option<AliasMapping>,
    /// Columns required by staged plans, together with the level that
    /// introduced each requirement.
    column_requirements: Vec<(HashSet<Column>, usize)>,
    /// Expand on the next call to `should_expand`.
    expand_on_next_call: bool,
    /// Deferred expansion, set after a partially/conditionally/transformed
    /// commutative plan has been staged.
    expand_on_next_part_cond_trans_commutative: bool,
    /// Replacement child plan produced by a `TransformedCommutative`
    /// transformer, consumed in `expand`.
    new_child_plan: Option<LogicalPlan>,
}

impl PlanRewriter {
    fn get_parent(&self) -> Option<&LogicalPlan> {
        // The node at `level - 1` closest to the top of the stack is the parent.
        self.stack
            .iter()
            .rev()
            .find(|(_, level)| *level == self.level - 1)
            .map(|(node, _)| node)
    }

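    /// Decides whether traversal must stop here and expand at the current
    /// node, based on `plan` (the parent of the node being visited).
    ///
    /// Parents that cannot be encoded to substrait, or that are categorized as
    /// non-commutative, force an immediate expansion. Partially or
    /// conditionally commutative parents are staged in a transformed form to
    /// be re-applied above the merge scan, deferring expansion by one level.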
    fn should_expand(&mut self, plan: &LogicalPlan) -> bool {
        debug!(
            "Check should_expand at level: {} with Stack:\n{}, ",
            self.level,
            self.stack
                .iter()
                .map(|(p, l)| format!("{l}:{}{}", " ".repeat(l - 1), p.display()))
                .collect::<Vec<String>>()
                .join("\n"),
        );
        // A plan that can't be encoded to substrait can't be sent to
        // datanodes, so expansion has to happen here.
        if DFLogicalSubstraitConvertor
            .encode(plan, DefaultSerializer)
            .is_err()
        {
            return true;
        }

        if self.expand_on_next_call {
            self.expand_on_next_call = false;
            return true;
        }

        if self.expand_on_next_part_cond_trans_commutative {
            let comm = Categorizer::check_plan(plan, self.partition_cols.clone());
            match comm {
                Commutativity::PartialCommutative => {
                    // Another partially commutative plan: let the match below
                    // stage it, and defer expansion by one more level.
                    self.expand_on_next_part_cond_trans_commutative = false;
                    self.expand_on_next_call = true;
                }
                Commutativity::ConditionalCommutative(_)
                | Commutativity::TransformedCommutative { .. } => {
                    // Can't stack another transformed plan; expand now.
                    self.expand_on_next_part_cond_trans_commutative = false;
                    return true;
                }
                _ => (),
            }
        }

        match Categorizer::check_plan(plan, self.partition_cols.clone()) {
            Commutativity::Commutative => {}
            Commutativity::PartialCommutative => {
                // Stage a partial copy to re-apply above the merge scan.
                if let Some(plan) = partial_commutative_transformer(plan) {
                    self.update_column_requirements(&plan, self.level - 1);
                    self.expand_on_next_part_cond_trans_commutative = true;
                    self.stage.push(plan)
                }
            }
            Commutativity::ConditionalCommutative(transformer) => {
                if let Some(transformer) = transformer
                    && let Some(plan) = transformer(plan)
                {
                    self.update_column_requirements(&plan, self.level - 1);
                    self.expand_on_next_part_cond_trans_commutative = true;
                    self.stage.push(plan)
                }
            }
            Commutativity::TransformedCommutative { transformer } => {
                if let Some(transformer) = transformer
                    && let Some(transformer_actions) = transformer(plan)
                {
                    debug!(
                        "PlanRewriter: transformed plan: {}\n from {plan}",
                        transformer_actions
                            .extra_parent_plans
                            .iter()
                            .enumerate()
                            .map(|(i, p)| format!(
                                "Extra {i}-th parent plan from parent to child = {}",
                                p.display()
                            ))
                            .collect::<Vec<_>>()
                            .join("\n")
                    );
                    if let Some(new_child_plan) = &transformer_actions.new_child_plan {
                        debug!("PlanRewriter: new child plan: {}", new_child_plan);
                    }
                    // Record requirements from the innermost staged plan,
                    // whose expressions reference the child's columns.
                    if let Some(last_stage) = transformer_actions.extra_parent_plans.last() {
                        self.update_column_requirements(last_stage, self.level - 1);
                    }
                    self.stage
                        .extend(transformer_actions.extra_parent_plans.into_iter().rev());
                    self.expand_on_next_call = true;
                    self.new_child_plan = transformer_actions.new_child_plan;
                }
            }
            Commutativity::NonCommutative
            | Commutativity::Unimplemented
            | Commutativity::Unsupported => {
                return true;
            }
        }

        false
    }

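    /// Records every column referenced by `plan`'s expressions as a
    /// requirement at `plan_level`, so later projections don't prune columns
    /// that a staged parent plan still needs.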
    fn update_column_requirements(&mut self, plan: &LogicalPlan, plan_level: usize) {
        debug!(
            "PlanRewriter: update column requirements for plan: {plan}\n with old column_requirements: {:?}",
            self.column_requirements
        );
        let mut container = HashSet::new();
        for expr in plan.expressions() {
            // Expressions that can't be resolved to columns are simply skipped.
            let _ = expr_to_columns(&expr, &mut container);
        }

        self.column_requirements.push((container, plan_level));
        debug!(
            "PlanRewriter: updated column requirements: {:?}",
            self.column_requirements
        );
    }

    fn is_expanded(&self) -> bool {
        self.status == RewriterStatus::Expanded
    }

    fn set_expanded(&mut self) {
        self.status = RewriterStatus::Expanded;
    }

    fn set_unexpanded(&mut self) {
        self.status = RewriterStatus::Unexpanded;
    }

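    /// Keeps `partition_cols` up to date while walking up the plan tree: on
    /// the first base-table scan it initializes the mapping from each
    /// partition key to its column; on every node above it follows renames so
    /// the alias sets stay valid.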
    fn maybe_set_partitions(&mut self, plan: &LogicalPlan) -> DfResult<()> {
        if let Some(part_cols) = &mut self.partition_cols {
            // Already initialized: follow renames through this node so the
            // alias sets stay valid further up the tree.
            let child = plan.inputs().first().cloned().ok_or_else(|| {
                datafusion_common::DataFusionError::Internal(format!(
                    "PlanRewriter: maybe_set_partitions: plan has no child: {plan}"
                ))
            })?;

            for (_col_name, alias_set) in part_cols.iter_mut() {
                let aliased_cols = aliased_columns_for(
                    &alias_set.clone().into_iter().collect(),
                    plan,
                    Some(child),
                )?;
                *alias_set = aliased_cols.into_values().flatten().collect();
            }

            debug!(
                "PlanRewriter: maybe_set_partitions: updated partition columns: {:?} at plan: {}",
                part_cols,
                plan.display()
            );

            return Ok(());
        }

        if let LogicalPlan::TableScan(table_scan) = plan
            && let Some(source) = table_scan
                .source
                .as_any()
                .downcast_ref::<DefaultTableSource>()
            && let Some(provider) = source
                .table_provider
                .as_any()
                .downcast_ref::<DfTableProviderAdapter>()
        {
            let table = provider.table();
            if table.table_type() == TableType::Base {
                let info = table.table_info();
                let partition_key_indices = info.meta.partition_key_indices.clone();
                let schema = info.meta.schema.clone();
                let mut partition_cols = partition_key_indices
                    .into_iter()
                    .map(|index| schema.column_name_by_index(index).to_string())
                    .collect::<Vec<String>>();

                let partition_rules = table.partition_rules();
                let exist_phy_part_cols_not_in_logical_table = partition_rules
                    .map(|r| !r.extra_phy_cols_not_in_logical_table.is_empty())
                    .unwrap_or(false);

                // The physical table may be partitioned on columns that don't
                // appear in this logical table at all; insert a placeholder so
                // the scan is still treated as partitioned.
                if exist_phy_part_cols_not_in_logical_table && partition_cols.is_empty() {
                    partition_cols.push("__OTHER_PHYSICAL_PART_COLS_PLACEHOLDER__".to_string());
                }
                self.partition_cols = Some(
                    partition_cols
                        .into_iter()
                        .map(|c| {
                            let index = plan
                                .schema()
                                .index_of_column_by_name(None, &c)
                                .ok_or_else(|| {
                                    datafusion_common::DataFusionError::Internal(format!(
                                        "PlanRewriter: maybe_set_partitions: column {c} not found in schema of plan: {plan}"
                                    ))
                                })?;
                            let column =
                                plan.schema().columns().get(index).cloned().ok_or_else(|| {
                                    datafusion_common::DataFusionError::Internal(format!(
                                        "PlanRewriter: maybe_set_partitions: column index {index} out of bounds in schema of plan: {plan}"
                                    ))
                                })?;
                            Ok((c.clone(), BTreeSet::from([column])))
                        })
                        .collect::<DfResult<AliasMapping>>()?,
                );
            }
        }

        Ok(())
    }

    fn pop_stack(&mut self) {
        self.level -= 1;
        self.stack.pop();
    }

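    /// Expands `on_node`: enforces pending column requirements, wraps the
    /// node in a [`MergeScanLogicalPlan`], re-applies the staged parent plans
    /// on top, and projects back to the node's original schema. The resulting
    /// shape is roughly:
    ///
    /// ```text
    /// Projection: <original schema>
    ///   <staged plans, innermost first above the scan>
    ///     MergeScan
    ///       on_node
    /// ```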
    fn expand(&mut self, mut on_node: LogicalPlan) -> DfResult<LogicalPlan> {
        // Remember the original output schema so it can be restored at the end.
        let schema = on_node.schema().clone();
        if let Some(new_child_plan) = self.new_child_plan.take() {
            // A `TransformedCommutative` transformer replaced the child plan.
            on_node = new_child_plan;
        }
        // Enforce the column requirements collected so far.
        let mut rewriter = EnforceDistRequirementRewriter::new(
            std::mem::take(&mut self.column_requirements),
            self.level,
        );
        debug!(
            "PlanRewriter: enforce column requirements for node: {on_node} with rewriter: {rewriter:?}"
        );
        on_node = on_node.rewrite(&mut rewriter)?.data;
        debug!(
            "PlanRewriter: after enforced column requirements with rewriter: {rewriter:?} for node:\n{on_node}"
        );

        debug!(
            "PlanRewriter: expand on node: {on_node} with partition col alias mapping: {:?}",
            self.partition_cols
        );

        // Wrap the node in a merge scan; this is the part the datanodes run.
        let mut node = MergeScanLogicalPlan::new(
            on_node,
            false,
            self.partition_cols.clone().unwrap_or_default(),
        )
        .into_logical_plan();

        // Re-apply the staged parent plans on top of the merge scan.
        for new_stage in self.stage.drain(..) {
            node = new_stage
                .with_new_exprs(new_stage.expressions_consider_join(), vec![node.clone()])?;
        }
        self.set_expanded();

        // Project back to the original schema (column names and order).
        let node = LogicalPlanBuilder::from(node)
            .project(schema.iter().map(|(qualifier, field)| {
                Expr::Column(Column::new(qualifier.cloned(), field.name()))
            }))?
            .build()?;

        Ok(node)
    }
}

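/// Rewriter that enforces the column requirements collected by
/// [`PlanRewriter`]: when a projection drops a column that a staged parent
/// plan still needs, the column is appended back to the projection.
///
/// For example (a sketch), with a pending requirement on `t.host`:
///
/// ```text
/// Projection: t.v      ==>  Projection: t.v, t.host
///   TableScan: t              TableScan: t
/// ```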
#[derive(Debug)]
struct EnforceDistRequirementRewriter {
    /// Column requirements paired with the level that introduced them; only
    /// requirements at or above a node's level apply to that node.
    column_requirements: Vec<(HashSet<Column>, usize)>,
    /// Level of the node currently being visited, updated during traversal.
    cur_level: usize,
}

impl EnforceDistRequirementRewriter {
    fn new(column_requirements: Vec<(HashSet<Column>, usize)>, cur_level: usize) -> Self {
        Self {
            column_requirements,
            cur_level,
        }
    }
}

impl TreeNodeRewriter for EnforceDistRequirementRewriter {
    type Node = LogicalPlan;

    fn f_down(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        // The level bookkeeping below assumes a single path down the tree.
        if node.inputs().len() > 1 {
            return Err(datafusion_common::DataFusionError::Internal(
                "EnforceDistRequirementRewriter: node with multiple inputs is not supported"
                    .to_string(),
            ));
        }
        self.cur_level += 1;
        Ok(Transformed::no(node))
    }

    fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        self.cur_level -= 1;
        // Merge all requirements that apply at or above the current level.
        let mut applicable_column_requirements = self
            .column_requirements
            .iter()
            .filter(|(_, level)| *level >= self.cur_level)
            .map(|(cols, _)| cols.clone())
            .reduce(|mut acc, cols| {
                acc.extend(cols);
                acc
            })
            .unwrap_or_default();

        debug!(
            "EnforceDistRequirementRewriter: applicable column requirements at level {} = {:?} for node {}",
            self.cur_level,
            applicable_column_requirements,
            node.display()
        );

        if let LogicalPlan::Projection(ref projection) = node {
            // Drop requirements the projection already satisfies, then append
            // any columns that are still missing.
            for expr in &projection.expr {
                let (qualifier, name) = expr.qualified_name();
                let column = Column::new(qualifier, name);
                applicable_column_requirements.remove(&column);
            }
            if applicable_column_requirements.is_empty() {
                return Ok(Transformed::no(node));
            }

            let mut new_exprs = projection.expr.clone();
            for col in &applicable_column_requirements {
                new_exprs.push(Expr::Column(col.clone()));
            }
            let new_node =
                node.with_new_exprs(new_exprs, node.inputs().into_iter().cloned().collect())?;
            debug!(
                "EnforceDistRequirementRewriter: added missing columns {:?} to projection node from old node: \n{node}\n Making new node: \n{new_node}",
                applicable_column_requirements
            );

            return Ok(Transformed::yes(new_node));
        }
        Ok(Transformed::no(node))
    }
}

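/// The driving traversal for [`PlanRewriter`]: `f_down` records each node and
/// its depth on the stack, and `f_up` decides on the way back up whether the
/// current subtree must be expanded into a merge scan (see
/// [`PlanRewriter::should_expand`]).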
impl TreeNodeRewriter for PlanRewriter {
    type Node = LogicalPlan;

    fn f_down(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        // Record the node and its depth on the way down.
        self.level += 1;
        self.stack.push((node.clone(), self.level));
        // Reset per-node state; each subtree starts clean.
        self.stage.clear();
        self.set_unexpanded();
        self.partition_cols = None;
        Ok(Transformed::no(node))
    }

    fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        // A merge scan was already inserted somewhere below; just unwind.
        if self.is_expanded() {
            self.pop_stack();
            return Ok(Transformed::no(node));
        }

        // Leaf nodes other than table scans (e.g. `EmptyRelation`) have
        // nothing to push down.
        if node.inputs().is_empty() && !matches!(node, LogicalPlan::TableScan(_)) {
            self.set_expanded();
            self.pop_stack();
            return Ok(Transformed::no(node));
        }

        self.maybe_set_partitions(&node)?;

        let Some(parent) = self.get_parent() else {
            debug!("PlanRewriter: no parent found, expanding now at node: {node}");
            let node = self.expand(node);
            debug!(
                "PlanRewriter: expanded plan: {}",
                match &node {
                    Ok(n) => n.to_string(),
                    Err(e) => format!("Error expanding plan: {e}"),
                }
            );
            let node = node?;
            self.pop_stack();
            return Ok(Transformed::yes(node));
        };

        let parent = parent.clone();

        // If the parent can't be pushed down, expand at this node.
        if self.should_expand(&parent) {
            debug!(
                "PlanRewriter: should expand child:\n {node}\n Of Parent: {}",
                parent.display()
            );
            let node = self.expand(node);
            debug!(
                "PlanRewriter: expanded plan: {}",
                match &node {
                    Ok(n) => n.to_string(),
                    Err(e) => format!("Error expanding plan: {e}"),
                }
            );
            let node = node?;
            self.pop_stack();
            return Ok(Transformed::yes(node));
        }

        self.pop_stack();
        Ok(Transformed::no(node))
    }
}