1use std::collections::{HashMap, HashSet};
16use std::sync::Arc;
17
18use common_telemetry::debug;
19use datafusion::config::{ConfigExtension, ExtensionOptions};
20use datafusion::datasource::DefaultTableSource;
21use datafusion::error::Result as DfResult;
22use datafusion_common::config::ConfigOptions;
23use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
24use datafusion_common::Column;
25use datafusion_expr::expr::{Exists, InSubquery};
26use datafusion_expr::utils::expr_to_columns;
27use datafusion_expr::{col as col_fn, Expr, LogicalPlan, LogicalPlanBuilder, Subquery};
28use datafusion_optimizer::analyzer::AnalyzerRule;
29use datafusion_optimizer::simplify_expressions::SimplifyExpressions;
30use datafusion_optimizer::{OptimizerContext, OptimizerRule};
31use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
32use table::metadata::TableType;
33use table::table::adapter::DfTableProviderAdapter;
34
35use crate::dist_plan::commutativity::{
36 partial_commutative_transformer, Categorizer, Commutativity,
37};
38use crate::dist_plan::merge_scan::MergeScanLogicalPlan;
39use crate::metrics::PUSH_DOWN_FALLBACK_ERRORS_TOTAL;
40use crate::plan::ExtractExpr;
41use crate::query_engine::DefaultSerializer;
42
43#[cfg(test)]
44mod test;
45
46mod fallback;
47mod utils;
48
49pub(crate) use utils::{AliasMapping, AliasTracker};
50
51#[derive(Debug, Clone)]
52pub struct DistPlannerOptions {
53 pub allow_query_fallback: bool,
54}
55
56impl ConfigExtension for DistPlannerOptions {
57 const PREFIX: &'static str = "dist_planner";
58}
59
60impl ExtensionOptions for DistPlannerOptions {
61 fn as_any(&self) -> &dyn std::any::Any {
62 self
63 }
64
65 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
66 self
67 }
68
69 fn cloned(&self) -> Box<dyn ExtensionOptions> {
70 Box::new(self.clone())
71 }
72
73 fn set(&mut self, key: &str, value: &str) -> DfResult<()> {
74 Err(datafusion_common::DataFusionError::NotImplemented(format!(
75 "DistPlannerOptions does not support set key: {key} with value: {value}"
76 )))
77 }
78
79 fn entries(&self) -> Vec<datafusion::config::ConfigEntry> {
80 vec![datafusion::config::ConfigEntry {
81 key: "allow_query_fallback".to_string(),
82 value: Some(self.allow_query_fallback.to_string()),
83 description: "Allow query fallback to fallback plan rewriter",
84 }]
85 }
86}
87
88#[derive(Debug)]
89pub struct DistPlannerAnalyzer;
90
91impl AnalyzerRule for DistPlannerAnalyzer {
92 fn name(&self) -> &str {
93 "DistPlannerAnalyzer"
94 }
95
96 fn analyze(
97 &self,
98 plan: LogicalPlan,
99 config: &ConfigOptions,
100 ) -> datafusion_common::Result<LogicalPlan> {
101 let optimizer_context = OptimizerContext::new();
103 let plan = SimplifyExpressions::new()
104 .rewrite(plan, &optimizer_context)?
105 .data;
106
107 let opt = config.extensions.get::<DistPlannerOptions>();
108 let allow_fallback = opt.map(|o| o.allow_query_fallback).unwrap_or(false);
109
110 let result = match self.try_push_down(plan.clone()) {
111 Ok(plan) => plan,
112 Err(err) => {
113 if allow_fallback {
114 common_telemetry::warn!(err; "Failed to push down plan, using fallback plan rewriter for plan: {plan}");
115 PUSH_DOWN_FALLBACK_ERRORS_TOTAL.inc();
117 self.use_fallback(plan)?
118 } else {
119 return Err(err);
120 }
121 }
122 };
123
124 Ok(result)
125 }
126}
127
128impl DistPlannerAnalyzer {
129 fn try_push_down(&self, plan: LogicalPlan) -> DfResult<LogicalPlan> {
131 let plan = plan.transform(&Self::inspect_plan_with_subquery)?;
132 let mut rewriter = PlanRewriter::default();
133 let result = plan.data.rewrite(&mut rewriter)?.data;
134 Ok(result)
135 }
136
137 fn use_fallback(&self, plan: LogicalPlan) -> DfResult<LogicalPlan> {
139 let mut rewriter = fallback::FallbackPlanRewriter;
140 let result = plan.rewrite(&mut rewriter)?.data;
141 Ok(result)
142 }
143
144 fn inspect_plan_with_subquery(plan: LogicalPlan) -> DfResult<Transformed<LogicalPlan>> {
145 if let LogicalPlan::Limit(_) | LogicalPlan::Distinct(_) = &plan {
148 return Ok(Transformed::no(plan));
149 }
150
151 let exprs = plan
152 .expressions_consider_join()
153 .into_iter()
154 .map(|e| e.transform(&Self::transform_subquery).map(|x| x.data))
155 .collect::<DfResult<Vec<_>>>()?;
156
157 if !matches!(plan, LogicalPlan::Unnest(_)) {
159 let inputs = plan.inputs().into_iter().cloned().collect::<Vec<_>>();
160 Ok(Transformed::yes(plan.with_new_exprs(exprs, inputs)?))
161 } else {
162 Ok(Transformed::no(plan))
163 }
164 }
165
166 fn transform_subquery(expr: Expr) -> DfResult<Transformed<Expr>> {
167 match expr {
168 Expr::Exists(exists) => Ok(Transformed::yes(Expr::Exists(Exists {
169 subquery: Self::handle_subquery(exists.subquery)?,
170 negated: exists.negated,
171 }))),
172 Expr::InSubquery(in_subquery) => Ok(Transformed::yes(Expr::InSubquery(InSubquery {
173 expr: in_subquery.expr,
174 subquery: Self::handle_subquery(in_subquery.subquery)?,
175 negated: in_subquery.negated,
176 }))),
177 Expr::ScalarSubquery(scalar_subquery) => Ok(Transformed::yes(Expr::ScalarSubquery(
178 Self::handle_subquery(scalar_subquery)?,
179 ))),
180
181 _ => Ok(Transformed::no(expr)),
182 }
183 }
184
185 fn handle_subquery(subquery: Subquery) -> DfResult<Subquery> {
186 let mut rewriter = PlanRewriter::default();
187 let mut rewrote_subquery = subquery
188 .subquery
189 .as_ref()
190 .clone()
191 .rewrite(&mut rewriter)?
192 .data;
193 if matches!(rewrote_subquery, LogicalPlan::Extension(_)) {
195 let output_schema = rewrote_subquery.schema().clone();
196 let project_exprs = output_schema
197 .fields()
198 .iter()
199 .map(|f| col_fn(f.name()))
200 .collect::<Vec<_>>();
201 rewrote_subquery = LogicalPlanBuilder::from(rewrote_subquery)
202 .project(project_exprs)?
203 .build()?;
204 }
205
206 Ok(Subquery {
207 subquery: Arc::new(rewrote_subquery),
208 outer_ref_columns: subquery.outer_ref_columns,
209 spans: Default::default(),
210 })
211 }
212}
213
/// Whether the [`PlanRewriter`] has already inserted a merge scan for the
/// subtree it is currently unwinding.
#[derive(Default, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum RewriterStatus {
    /// No merge scan has been inserted yet (the initial state).
    #[default]
    Unexpanded,
    /// A merge scan has been inserted; nodes above are left untouched.
    Expanded,
}
221
222#[derive(Debug, Default)]
223struct PlanRewriter {
224 level: usize,
226 stack: Vec<(LogicalPlan, usize)>,
228 stage: Vec<LogicalPlan>,
230 status: RewriterStatus,
231 partition_cols: Option<Vec<String>>,
233 alias_tracker: Option<AliasTracker>,
234 column_requirements: Vec<(HashSet<Column>, usize)>,
260 expand_on_next_call: bool,
265 expand_on_next_part_cond_trans_commutative: bool,
277 new_child_plan: Option<LogicalPlan>,
278}
279
280impl PlanRewriter {
281 fn get_parent(&self) -> Option<&LogicalPlan> {
282 self.stack
284 .iter()
285 .rev()
286 .find(|(_, level)| *level == self.level - 1)
287 .map(|(node, _)| node)
288 }
289
290 fn should_expand(&mut self, plan: &LogicalPlan) -> bool {
292 debug!(
293 "Check should_expand at level: {} with Stack:\n{}, ",
294 self.level,
295 self.stack
296 .iter()
297 .map(|(p, l)| format!("{l}:{}{}", " ".repeat(l - 1), p.display()))
298 .collect::<Vec<String>>()
299 .join("\n"),
300 );
301 if DFLogicalSubstraitConvertor
302 .encode(plan, DefaultSerializer)
303 .is_err()
304 {
305 return true;
306 }
307
308 if self.expand_on_next_call {
309 self.expand_on_next_call = false;
310 return true;
311 }
312
313 if self.expand_on_next_part_cond_trans_commutative {
314 let comm = Categorizer::check_plan(plan, self.get_aliased_partition_columns());
315 match comm {
316 Commutativity::PartialCommutative => {
317 self.expand_on_next_part_cond_trans_commutative = false;
323 self.expand_on_next_call = true;
324 }
325 Commutativity::ConditionalCommutative(_)
326 | Commutativity::TransformedCommutative { .. } => {
327 self.expand_on_next_part_cond_trans_commutative = false;
330 return true;
331 }
332 _ => (),
333 }
334 }
335
336 match Categorizer::check_plan(plan, self.get_aliased_partition_columns()) {
337 Commutativity::Commutative => {}
338 Commutativity::PartialCommutative => {
339 if let Some(plan) = partial_commutative_transformer(plan) {
340 self.update_column_requirements(&plan, self.level - 1);
342 self.expand_on_next_part_cond_trans_commutative = true;
343 self.stage.push(plan)
344 }
345 }
346 Commutativity::ConditionalCommutative(transformer) => {
347 if let Some(transformer) = transformer
348 && let Some(plan) = transformer(plan)
349 {
350 self.update_column_requirements(&plan, self.level - 1);
352 self.expand_on_next_part_cond_trans_commutative = true;
353 self.stage.push(plan)
354 }
355 }
356 Commutativity::TransformedCommutative { transformer } => {
357 if let Some(transformer) = transformer
358 && let Some(transformer_actions) = transformer(plan)
359 {
360 debug!(
361 "PlanRewriter: transformed plan: {}\n from {plan}",
362 transformer_actions
363 .extra_parent_plans
364 .iter()
365 .enumerate()
366 .map(|(i, p)| format!(
367 "Extra {i}-th parent plan from parent to child = {}",
368 p.display()
369 ))
370 .collect::<Vec<_>>()
371 .join("\n")
372 );
373 if let Some(new_child_plan) = &transformer_actions.new_child_plan {
374 debug!("PlanRewriter: new child plan: {}", new_child_plan);
375 }
376 if let Some(last_stage) = transformer_actions.extra_parent_plans.last() {
377 self.update_column_requirements(last_stage, self.level - 1);
380 }
381 self.stage
382 .extend(transformer_actions.extra_parent_plans.into_iter().rev());
383 self.expand_on_next_call = true;
384 self.new_child_plan = transformer_actions.new_child_plan;
385 }
386 }
387 Commutativity::NonCommutative
388 | Commutativity::Unimplemented
389 | Commutativity::Unsupported => {
390 return true;
391 }
392 }
393
394 false
395 }
396
397 fn update_column_requirements(&mut self, plan: &LogicalPlan, plan_level: usize) {
401 debug!(
402 "PlanRewriter: update column requirements for plan: {plan}\n with old column_requirements: {:?}",
403 self.column_requirements
404 );
405 let mut container = HashSet::new();
406 for expr in plan.expressions() {
407 let _ = expr_to_columns(&expr, &mut container);
409 }
410
411 self.column_requirements.push((container, plan_level));
412 debug!(
413 "PlanRewriter: updated column requirements: {:?}",
414 self.column_requirements
415 );
416 }
417
418 fn is_expanded(&self) -> bool {
419 self.status == RewriterStatus::Expanded
420 }
421
422 fn set_expanded(&mut self) {
423 self.status = RewriterStatus::Expanded;
424 }
425
426 fn set_unexpanded(&mut self) {
427 self.status = RewriterStatus::Unexpanded;
428 }
429
430 fn maybe_update_alias(&mut self, node: &LogicalPlan) {
432 if let Some(alias_tracker) = &mut self.alias_tracker {
433 alias_tracker.update_alias(node);
434 debug!(
435 "Current partition columns are: {:?}",
436 self.get_aliased_partition_columns()
437 );
438 } else if let LogicalPlan::TableScan(table_scan) = node {
439 self.alias_tracker = AliasTracker::new(table_scan);
440 debug!(
441 "Initialize partition columns: {:?} with table={}",
442 self.get_aliased_partition_columns(),
443 table_scan.table_name
444 );
445 }
446 }
447
448 fn get_aliased_partition_columns(&self) -> Option<AliasMapping> {
449 if let Some(part_cols) = self.partition_cols.as_ref() {
450 let Some(alias_tracker) = &self.alias_tracker else {
451 return None;
453 };
454 let mut aliased = HashMap::new();
455 for part_col in part_cols {
456 let all_alias = alias_tracker
457 .get_all_alias_for_col(part_col)
458 .cloned()
459 .unwrap_or_default();
460
461 aliased.insert(part_col.clone(), all_alias);
462 }
463 Some(aliased)
464 } else {
465 None
466 }
467 }
468
469 fn maybe_set_partitions(&mut self, plan: &LogicalPlan) {
470 if self.partition_cols.is_some() {
471 return;
473 }
474
475 if let LogicalPlan::TableScan(table_scan) = plan {
476 if let Some(source) = table_scan
477 .source
478 .as_any()
479 .downcast_ref::<DefaultTableSource>()
480 {
481 if let Some(provider) = source
482 .table_provider
483 .as_any()
484 .downcast_ref::<DfTableProviderAdapter>()
485 {
486 let table = provider.table();
487 if table.table_type() == TableType::Base {
488 let info = table.table_info();
489 let partition_key_indices = info.meta.partition_key_indices.clone();
490 let schema = info.meta.schema.clone();
491 let mut partition_cols = partition_key_indices
492 .into_iter()
493 .map(|index| schema.column_name_by_index(index).to_string())
494 .collect::<Vec<String>>();
495
496 let partition_rules = table.partition_rules();
497 let exist_phy_part_cols_not_in_logical_table = partition_rules
498 .map(|r| !r.extra_phy_cols_not_in_logical_table.is_empty())
499 .unwrap_or(false);
500
501 if exist_phy_part_cols_not_in_logical_table && partition_cols.is_empty() {
502 partition_cols
510 .push("__OTHER_PHYSICAL_PART_COLS_PLACEHOLDER__".to_string());
511 }
512 self.partition_cols = Some(partition_cols);
513 }
514 }
515 }
516 }
517 }
518
519 fn pop_stack(&mut self) {
521 self.level -= 1;
522 self.stack.pop();
523 }
524
525 fn expand(&mut self, mut on_node: LogicalPlan) -> DfResult<LogicalPlan> {
526 let schema = on_node.schema().clone();
528 if let Some(new_child_plan) = self.new_child_plan.take() {
529 on_node = new_child_plan;
531 }
532 let mut rewriter = EnforceDistRequirementRewriter::new(
533 std::mem::take(&mut self.column_requirements),
534 self.level,
535 );
536 debug!("PlanRewriter: enforce column requirements for node: {on_node} with rewriter: {rewriter:?}");
537 on_node = on_node.rewrite(&mut rewriter)?.data;
538 debug!(
539 "PlanRewriter: after enforced column requirements with rewriter: {rewriter:?} for node:\n{on_node}"
540 );
541
542 let mut node = MergeScanLogicalPlan::new(
544 on_node,
545 false,
546 self.partition_cols.clone().unwrap_or_default(),
549 )
550 .into_logical_plan();
551
552 for new_stage in self.stage.drain(..) {
554 node = new_stage
555 .with_new_exprs(new_stage.expressions_consider_join(), vec![node.clone()])?;
556 }
557 self.set_expanded();
558
559 let node = LogicalPlanBuilder::from(node)
562 .project(schema.iter().map(|(qualifier, field)| {
563 Expr::Column(Column::new(qualifier.cloned(), field.name()))
564 }))?
565 .build()?;
566
567 Ok(node)
568 }
569}
570
571#[derive(Debug)]
579struct EnforceDistRequirementRewriter {
580 column_requirements: Vec<(HashSet<Column>, usize)>,
584 cur_level: usize,
595}
596
597impl EnforceDistRequirementRewriter {
598 fn new(column_requirements: Vec<(HashSet<Column>, usize)>, cur_level: usize) -> Self {
599 Self {
600 column_requirements,
601 cur_level,
602 }
603 }
604}
605
606impl TreeNodeRewriter for EnforceDistRequirementRewriter {
607 type Node = LogicalPlan;
608
609 fn f_down(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
610 if node.inputs().len() > 1 {
612 return Err(datafusion_common::DataFusionError::Internal(
613 "EnforceDistRequirementRewriter: node with multiple inputs is not supported"
614 .to_string(),
615 ));
616 }
617 self.cur_level += 1;
618 Ok(Transformed::no(node))
619 }
620
621 fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
622 self.cur_level -= 1;
623 let mut applicable_column_requirements = self
625 .column_requirements
626 .iter()
627 .filter(|(_, level)| *level >= self.cur_level)
628 .map(|(cols, _)| cols.clone())
629 .reduce(|mut acc, cols| {
630 acc.extend(cols);
631 acc
632 })
633 .unwrap_or_default();
634
635 debug!(
636 "EnforceDistRequirementRewriter: applicable column requirements at level {} = {:?} for node {}",
637 self.cur_level,
638 applicable_column_requirements,
639 node.display()
640 );
641
642 if let LogicalPlan::Projection(ref projection) = node {
644 for expr in &projection.expr {
645 let (qualifier, name) = expr.qualified_name();
646 let column = Column::new(qualifier, name);
647 applicable_column_requirements.remove(&column);
648 }
649 if applicable_column_requirements.is_empty() {
650 return Ok(Transformed::no(node));
651 }
652
653 let mut new_exprs = projection.expr.clone();
654 for col in &applicable_column_requirements {
655 new_exprs.push(Expr::Column(col.clone()));
656 }
657 let new_node =
658 node.with_new_exprs(new_exprs, node.inputs().into_iter().cloned().collect())?;
659 debug!(
660 "EnforceDistRequirementRewriter: added missing columns {:?} to projection node from old node: \n{node}\n Making new node: \n{new_node}",
661 applicable_column_requirements
662 );
663
664 return Ok(Transformed::yes(new_node));
666 }
667 Ok(Transformed::no(node))
668 }
669}
670
671impl TreeNodeRewriter for PlanRewriter {
672 type Node = LogicalPlan;
673
674 fn f_down<'a>(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
676 self.level += 1;
677 self.stack.push((node.clone(), self.level));
678 self.stage.clear();
680 self.set_unexpanded();
681 self.partition_cols = None;
682 self.alias_tracker = None;
683 Ok(Transformed::no(node))
684 }
685
686 fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
690 if self.is_expanded() {
692 self.pop_stack();
693 return Ok(Transformed::no(node));
694 }
695
696 if node.inputs().is_empty() && !matches!(node, LogicalPlan::TableScan(_)) {
698 self.set_expanded();
699 self.pop_stack();
700 return Ok(Transformed::no(node));
701 }
702
703 self.maybe_set_partitions(&node);
704
705 self.maybe_update_alias(&node);
706
707 let Some(parent) = self.get_parent() else {
708 debug!("Plan Rewriter: expand now for no parent found for node: {node}");
709 let node = self.expand(node);
710 debug!(
711 "PlanRewriter: expanded plan: {}",
712 match &node {
713 Ok(n) => n.to_string(),
714 Err(e) => format!("Error expanding plan: {e}"),
715 }
716 );
717 let node = node?;
718 self.pop_stack();
719 return Ok(Transformed::yes(node));
720 };
721
722 let parent = parent.clone();
723
724 if self.should_expand(&parent) {
726 debug!(
728 "PlanRewriter: should expand child:\n {node}\n Of Parent: {}",
729 parent.display()
730 );
731 let node = self.expand(node);
732 debug!(
733 "PlanRewriter: expanded plan: {}",
734 match &node {
735 Ok(n) => n.to_string(),
736 Err(e) => format!("Error expanding plan: {e}"),
737 }
738 );
739 let node = node?;
740 self.pop_stack();
741 return Ok(Transformed::yes(node));
742 }
743
744 self.pop_stack();
745 Ok(Transformed::no(node))
746 }
747}