1use std::collections::HashSet;
16use std::sync::Arc;
17
18use common_telemetry::debug;
19use datafusion::datasource::DefaultTableSource;
20use datafusion::error::Result as DfResult;
21use datafusion_common::config::ConfigOptions;
22use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
23use datafusion_common::Column;
24use datafusion_expr::expr::{Exists, InSubquery};
25use datafusion_expr::utils::expr_to_columns;
26use datafusion_expr::{col as col_fn, Expr, LogicalPlan, LogicalPlanBuilder, Subquery};
27use datafusion_optimizer::analyzer::AnalyzerRule;
28use datafusion_optimizer::simplify_expressions::SimplifyExpressions;
29use datafusion_optimizer::{OptimizerContext, OptimizerRule};
30use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
31use table::metadata::TableType;
32use table::table::adapter::DfTableProviderAdapter;
33
34use crate::dist_plan::commutativity::{
35 partial_commutative_transformer, Categorizer, Commutativity,
36};
37use crate::dist_plan::merge_scan::MergeScanLogicalPlan;
38use crate::plan::ExtractExpr;
39use crate::query_engine::DefaultSerializer;
40
/// Analyzer rule that rewrites a logical plan for distributed execution.
///
/// It walks the plan and replaces the largest subtree that can be executed
/// remotely with a [`MergeScanLogicalPlan`], leaving the remaining,
/// non-pushable steps on top of the merge scan.
#[derive(Debug)]
pub struct DistPlannerAnalyzer;
43
44impl AnalyzerRule for DistPlannerAnalyzer {
45 fn name(&self) -> &str {
46 "DistPlannerAnalyzer"
47 }
48
49 fn analyze(
50 &self,
51 plan: LogicalPlan,
52 _config: &ConfigOptions,
53 ) -> datafusion_common::Result<LogicalPlan> {
54 let optimizer_context = OptimizerContext::new();
56 let plan = SimplifyExpressions::new()
57 .rewrite(plan, &optimizer_context)?
58 .data;
59
60 let plan = plan.transform(&Self::inspect_plan_with_subquery)?;
61 let mut rewriter = PlanRewriter::default();
62 let result = plan.data.rewrite(&mut rewriter)?.data;
63
64 Ok(result)
65 }
66}
67
68impl DistPlannerAnalyzer {
69 fn inspect_plan_with_subquery(plan: LogicalPlan) -> DfResult<Transformed<LogicalPlan>> {
70 if let LogicalPlan::Limit(_) | LogicalPlan::Distinct(_) = &plan {
73 return Ok(Transformed::no(plan));
74 }
75
76 let exprs = plan
77 .expressions_consider_join()
78 .into_iter()
79 .map(|e| e.transform(&Self::transform_subquery).map(|x| x.data))
80 .collect::<DfResult<Vec<_>>>()?;
81
82 if !matches!(plan, LogicalPlan::Unnest(_)) {
84 let inputs = plan.inputs().into_iter().cloned().collect::<Vec<_>>();
85 Ok(Transformed::yes(plan.with_new_exprs(exprs, inputs)?))
86 } else {
87 Ok(Transformed::no(plan))
88 }
89 }
90
91 fn transform_subquery(expr: Expr) -> DfResult<Transformed<Expr>> {
92 match expr {
93 Expr::Exists(exists) => Ok(Transformed::yes(Expr::Exists(Exists {
94 subquery: Self::handle_subquery(exists.subquery)?,
95 negated: exists.negated,
96 }))),
97 Expr::InSubquery(in_subquery) => Ok(Transformed::yes(Expr::InSubquery(InSubquery {
98 expr: in_subquery.expr,
99 subquery: Self::handle_subquery(in_subquery.subquery)?,
100 negated: in_subquery.negated,
101 }))),
102 Expr::ScalarSubquery(scalar_subquery) => Ok(Transformed::yes(Expr::ScalarSubquery(
103 Self::handle_subquery(scalar_subquery)?,
104 ))),
105
106 _ => Ok(Transformed::no(expr)),
107 }
108 }
109
110 fn handle_subquery(subquery: Subquery) -> DfResult<Subquery> {
111 let mut rewriter = PlanRewriter::default();
112 let mut rewrote_subquery = subquery
113 .subquery
114 .as_ref()
115 .clone()
116 .rewrite(&mut rewriter)?
117 .data;
118 if matches!(rewrote_subquery, LogicalPlan::Extension(_)) {
120 let output_schema = rewrote_subquery.schema().clone();
121 let project_exprs = output_schema
122 .fields()
123 .iter()
124 .map(|f| col_fn(f.name()))
125 .collect::<Vec<_>>();
126 rewrote_subquery = LogicalPlanBuilder::from(rewrote_subquery)
127 .project(project_exprs)?
128 .build()?;
129 }
130
131 Ok(Subquery {
132 subquery: Arc::new(rewrote_subquery),
133 outer_ref_columns: subquery.outer_ref_columns,
134 })
135 }
136}
137
/// Tracks whether the subtree currently being rewritten has already been
/// replaced by a merge-scan plan.
#[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
enum RewriterStatus {
    /// No merge scan has been inserted for the current subtree yet.
    #[default]
    Unexpanded,
    /// The current subtree has been expanded into a merge-scan plan.
    Expanded,
}
145
/// Rewriter that expands a logical plan for distributed execution by
/// inserting a [`MergeScanLogicalPlan`] at the point where the plan can no
/// longer be pushed down (see `should_expand` / `expand`).
#[derive(Debug, Default)]
struct PlanRewriter {
    // Current depth in the plan tree; incremented in `f_down`, decremented
    // in `pop_stack`.
    level: usize,
    // Simulated traversal stack: each visited node paired with the level it
    // was visited at. Used by `get_parent` to find the node one level up.
    stack: Vec<(LogicalPlan, usize)>,
    // Parent stages (e.g. partial-commutative steps) to be re-applied on top
    // of the merge scan when expanding.
    stage: Vec<LogicalPlan>,
    // Whether the current subtree has already been expanded.
    status: RewriterStatus,
    // Partition columns of the scanned table, filled lazily from the first
    // `TableScan` by `maybe_set_partitions`.
    partition_cols: Option<Vec<String>>,
    // Columns the staged parent plans require from the child plan; enforced
    // by `EnforceDistRequirementRewriter` during `expand`.
    column_requirements: HashSet<Column>,
    // Forces the next `should_expand` call to return true (set when a
    // `TransformedCommutative` transformer fires).
    expand_on_next_call: bool,
    // Replacement child plan produced by a `TransformedCommutative`
    // transformer; consumed at the start of `expand`.
    new_child_plan: Option<LogicalPlan>,
}
161
impl PlanRewriter {
    /// Returns the node one level above the current one on the simulated
    /// stack, if any.
    fn get_parent(&self) -> Option<&LogicalPlan> {
        // Search from the top of the stack for the entry recorded at
        // `level - 1`, skipping siblings recorded at the same level.
        self.stack
            .iter()
            .rev()
            .find(|(_, level)| *level == self.level - 1)
            .map(|(node, _)| node)
    }

    /// Decides whether expansion (cutting the plan here and inserting a merge
    /// scan below) is required, based on the plan's commutativity with merge.
    ///
    /// Note: this has side effects — it may stage extra parent plans, record
    /// column requirements, or schedule a forced expansion for the next call.
    fn should_expand(&mut self, plan: &LogicalPlan) -> bool {
        // A plan that cannot be encoded to substrait cannot be shipped to
        // remote nodes at all, so expansion must happen below it.
        if DFLogicalSubstraitConvertor
            .encode(plan, DefaultSerializer)
            .is_err()
        {
            return true;
        }
        // A previous `TransformedCommutative` transformation requested that
        // the very next check expands unconditionally; consume the flag.
        if self.expand_on_next_call {
            self.expand_on_next_call = false;
            return true;
        }
        match Categorizer::check_plan(plan, self.partition_cols.clone()) {
            // Fully commutative with merge: keep pushing down, no expansion.
            Commutativity::Commutative => {}
            // Partially commutative: stage a transformed copy of this step to
            // be re-applied above the merge scan.
            Commutativity::PartialCommutative => {
                if let Some(plan) = partial_commutative_transformer(plan) {
                    self.update_column_requirements(&plan);
                    self.stage.push(plan)
                }
            }
            // Commutative only under a condition checked by the transformer.
            Commutativity::ConditionalCommutative(transformer) => {
                if let Some(transformer) = transformer
                    && let Some(plan) = transformer(plan)
                {
                    self.update_column_requirements(&plan);
                    self.stage.push(plan)
                }
            }
            // The transformer rewrites the step into several parent stages
            // (and possibly a replacement child plan).
            Commutativity::TransformedCommutative { transformer } => {
                if let Some(transformer) = transformer
                    && let Some(transformer_actions) = transformer(plan)
                {
                    debug!(
                        "PlanRewriter: transformed plan: {:#?}\n from {plan}",
                        transformer_actions.extra_parent_plans
                    );
                    // Only the last (outermost) stage's columns are recorded
                    // as requirements on the child plan.
                    if let Some(last_stage) = transformer_actions.extra_parent_plans.last() {
                        self.update_column_requirements(last_stage);
                    }
                    // Stages are staged in reverse so they are re-applied in
                    // the right order by `expand`.
                    self.stage
                        .extend(transformer_actions.extra_parent_plans.into_iter().rev());
                    self.expand_on_next_call = true;
                    self.new_child_plan = transformer_actions.new_child_plan;
                }
            }
            // Not commutative with merge: must expand here.
            Commutativity::NonCommutative
            | Commutativity::Unimplemented
            | Commutativity::Unsupported => {
                return true;
            }
        }

        false
    }

    /// Records every column referenced by `plan`'s expressions as a
    /// requirement on the child plan's output.
    fn update_column_requirements(&mut self, plan: &LogicalPlan) {
        let mut container = HashSet::new();
        for expr in plan.expressions() {
            // Ignore failures: expressions without resolvable columns simply
            // contribute no requirements.
            let _ = expr_to_columns(&expr, &mut container);
        }

        for col in container {
            self.column_requirements.insert(col);
        }
    }

    /// True once the current subtree has been expanded into a merge scan.
    fn is_expanded(&self) -> bool {
        self.status == RewriterStatus::Expanded
    }

    fn set_expanded(&mut self) {
        self.status = RewriterStatus::Expanded;
    }

    fn set_unexpanded(&mut self) {
        self.status = RewriterStatus::Unexpanded;
    }

    /// Lazily fills `partition_cols` from the first base-table `TableScan`
    /// encountered; later scans do not overwrite it.
    fn maybe_set_partitions(&mut self, plan: &LogicalPlan) {
        if self.partition_cols.is_some() {
            // only need to set once
            return;
        }

        if let LogicalPlan::TableScan(table_scan) = plan {
            if let Some(source) = table_scan
                .source
                .as_any()
                .downcast_ref::<DefaultTableSource>()
            {
                if let Some(provider) = source
                    .table_provider
                    .as_any()
                    .downcast_ref::<DfTableProviderAdapter>()
                {
                    if provider.table().table_type() == TableType::Base {
                        let info = provider.table().table_info();
                        let partition_key_indices = info.meta.partition_key_indices.clone();
                        let schema = info.meta.schema.clone();
                        // Map partition key indices to column names.
                        let partition_cols = partition_key_indices
                            .into_iter()
                            .map(|index| schema.column_name_by_index(index).to_string())
                            .collect::<Vec<String>>();
                        self.partition_cols = Some(partition_cols);
                    }
                }
            }
        }
    }

    /// Pops the current node off the simulated stack (mirror of the push in
    /// `f_down`).
    fn pop_stack(&mut self) {
        self.level -= 1;
        self.stack.pop();
    }

    /// Cuts the plan at `on_node`: wraps it in a [`MergeScanLogicalPlan`],
    /// re-applies staged parent plans on top, and finally projects back to
    /// the original schema of `on_node`.
    fn expand(&mut self, mut on_node: LogicalPlan) -> DfResult<LogicalPlan> {
        // A `TransformedCommutative` transformer may have provided a
        // replacement child plan; use it instead of the visited node.
        if let Some(new_child_plan) = self.new_child_plan.take() {
            on_node = new_child_plan;
        }
        // Capture the schema before enforcement so the final projection can
        // restore the original output columns.
        let schema = on_node.schema().clone();
        // Make sure the pushed-down plan outputs every column the staged
        // parent stages require.
        let mut rewriter = EnforceDistRequirementRewriter {
            column_requirements: std::mem::take(&mut self.column_requirements),
        };
        on_node = on_node.rewrite(&mut rewriter)?.data;

        let mut node = MergeScanLogicalPlan::new(
            on_node,
            false,
            self.partition_cols.clone().unwrap_or_default(),
        )
        .into_logical_plan();

        // Re-apply the staged parent stages (innermost first) above the
        // merge scan.
        for new_stage in self.stage.drain(..) {
            node = new_stage
                .with_new_exprs(new_stage.expressions_consider_join(), vec![node.clone()])?;
        }
        self.set_expanded();

        // Project back to the original schema so callers see the same
        // columns regardless of what enforcement added.
        let node = LogicalPlanBuilder::from(node)
            .project(schema.iter().map(|(qualifier, field)| {
                Expr::Column(Column::new(qualifier.cloned(), field.name()))
            }))?
            .build()?;

        Ok(node)
    }
}
329
/// Rewriter that ensures the plan pushed below a merge scan outputs every
/// column required by the staged parent plans (collected in
/// `PlanRewriter::column_requirements`).
struct EnforceDistRequirementRewriter {
    // Columns still missing from the plan's output; drained once enforced.
    column_requirements: HashSet<Column>,
}
339
impl TreeNodeRewriter for EnforceDistRequirementRewriter {
    type Node = LogicalPlan;

    /// On the way down, extend the first `Projection` encountered with any
    /// required columns it does not already produce.
    fn f_down(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        if let LogicalPlan::Projection(ref projection) = node {
            // NOTE(review): `take` consumes the requirements at the first
            // Projection seen, so deeper projections are never touched —
            // this appears intentional (requirements are enforced once, at
            // the topmost projection); confirm against callers.
            let mut column_requirements = std::mem::take(&mut self.column_requirements);
            if column_requirements.is_empty() {
                return Ok(Transformed::no(node));
            }

            // Drop every requirement the projection already satisfies.
            for expr in &projection.expr {
                let (qualifier, name) = expr.qualified_name();
                let column = Column::new(qualifier, name);
                column_requirements.remove(&column);
            }
            if column_requirements.is_empty() {
                return Ok(Transformed::no(node));
            }

            // Append the still-missing columns to the projection list.
            let mut new_exprs = projection.expr.clone();
            for col in &column_requirements {
                new_exprs.push(Expr::Column(col.clone()));
            }
            let new_node =
                node.with_new_exprs(new_exprs, node.inputs().into_iter().cloned().collect())?;
            return Ok(Transformed::yes(new_node));
        }

        Ok(Transformed::no(node))
    }

    // Nothing to do on the way up.
    fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        Ok(Transformed::no(node))
    }
}
375
impl TreeNodeRewriter for PlanRewriter {
    type Node = LogicalPlan;

    /// On the way down: push the node onto the simulated stack and reset
    /// per-subtree state. All actual rewriting happens on the way up.
    fn f_down<'a>(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        self.level += 1;
        self.stack.push((node.clone(), self.level));
        // Clear staged plans / status left over from a sibling subtree.
        self.stage.clear();
        self.set_unexpanded();
        self.partition_cols = None;
        Ok(Transformed::no(node))
    }

    /// On the way up: decide, using the parent node, whether the current
    /// node is the cut point where the plan gets expanded into a merge scan.
    fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        // Once the subtree is expanded nothing above it is rewritten again
        // in this pass.
        if self.is_expanded() {
            self.pop_stack();
            return Ok(Transformed::no(node));
        }

        // Leaf nodes other than `TableScan` (e.g. empty relations) have
        // nothing to push down; mark expanded and keep them as-is.
        if node.inputs().is_empty() && !matches!(node, LogicalPlan::TableScan(_)) {
            self.set_expanded();
            self.pop_stack();
            return Ok(Transformed::no(node));
        }

        // Lazily pick up partition columns from the first TableScan seen.
        self.maybe_set_partitions(&node);

        let Some(parent) = self.get_parent() else {
            // Reached the root without expanding: expand the whole plan.
            let node = self.expand(node)?;
            self.pop_stack();
            return Ok(Transformed::yes(node));
        };

        let parent = parent.clone();

        // If the parent cannot be pushed down, this node is the cut point.
        if self.should_expand(&parent) {
            debug!("PlanRewriter: should expand child:\n {node}\n Of Parent: {parent}");
            let node = self.expand(node);
            debug!(
                "PlanRewriter: expanded plan: {}",
                match &node {
                    Ok(n) => n.to_string(),
                    Err(e) => format!("Error expanding plan: {e}"),
                }
            );
            let node = node?;
            self.pop_stack();
            return Ok(Transformed::yes(node));
        }

        self.pop_stack();
        Ok(Transformed::no(node))
    }
}
438
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use datafusion::datasource::DefaultTableSource;
    use datafusion::functions_aggregate::expr_fn::avg;
    use datafusion_common::JoinType;
    use datafusion_expr::{col, lit, Expr, LogicalPlanBuilder};
    use table::table::adapter::DfTableProviderAdapter;
    use table::table::numbers::NumbersTable;

    use super::*;

    // Filter + projection + distinct: the whole pipeline should be pushed
    // below the merge scan, with `Distinct` re-applied on top.
    #[ignore = "Projection is disabled for https://github.com/apache/arrow-datafusion/issues/6489"]
    #[test]
    fn transform_simple_projection_filter() {
        let numbers_table = NumbersTable::table(0);
        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(numbers_table),
        )));

        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
            .unwrap()
            .filter(col("number").lt(lit(10)))
            .unwrap()
            .project(vec![col("number")])
            .unwrap()
            .distinct()
            .unwrap()
            .build()
            .unwrap();

        let config = ConfigOptions::default();
        let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
        let expected = [
            "Distinct:",
            "  MergeScan [is_placeholder=false]",
            "    Distinct:",
            "      Projection: t.number",
            "        Filter: t.number < Int32(10)",
            "          TableScan: t",
        ]
        .join("\n");
        assert_eq!(expected, result.to_string());
    }

    // An aggregation is fully pushed down; only the final projection remains
    // above the merge scan.
    #[test]
    fn transform_aggregator() {
        let numbers_table = NumbersTable::table(0);
        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(numbers_table),
        )));

        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
            .unwrap()
            .aggregate(Vec::<Expr>::new(), vec![avg(col("number"))])
            .unwrap()
            .build()
            .unwrap();

        let config = ConfigOptions::default();
        let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
        let expected = "Projection: avg(t.number)\
        \n  MergeScan [is_placeholder=false]";
        assert_eq!(expected, result.to_string());
    }

    // Distinct followed by sort collapses entirely into the merge scan.
    #[test]
    fn transform_distinct_order() {
        let numbers_table = NumbersTable::table(0);
        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(numbers_table),
        )));

        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
            .unwrap()
            .distinct()
            .unwrap()
            .sort(vec![col("number").sort(true, false)])
            .unwrap()
            .build()
            .unwrap();

        let config = ConfigOptions::default();
        let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
        let expected = ["Projection: t.number", "  MergeScan [is_placeholder=false]"].join("\n");
        assert_eq!(expected, result.to_string());
    }

    // A single limit over a scan is pushed down whole.
    #[test]
    fn transform_single_limit() {
        let numbers_table = NumbersTable::table(0);
        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(numbers_table),
        )));

        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
            .unwrap()
            .limit(0, Some(1))
            .unwrap()
            .build()
            .unwrap();

        let config = ConfigOptions::default();
        let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
        let expected = "Projection: t.number\
        \n  MergeScan [is_placeholder=false]";
        assert_eq!(expected, result.to_string());
    }

    // A join cannot be pushed down: each side gets its own merge scan and the
    // join (plus limit) stays above them.
    #[test]
    fn transform_unalighed_join_with_alias() {
        let left = NumbersTable::table(0);
        let right = NumbersTable::table(1);
        let left_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(left),
        )));
        let right_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(right),
        )));

        let right_plan = LogicalPlanBuilder::scan_with_filters("t", right_source, None, vec![])
            .unwrap()
            .alias("right")
            .unwrap()
            .build()
            .unwrap();

        let plan = LogicalPlanBuilder::scan_with_filters("t", left_source, None, vec![])
            .unwrap()
            .join_on(
                right_plan,
                JoinType::LeftSemi,
                vec![col("t.number").eq(col("right.number"))],
            )
            .unwrap()
            .limit(0, Some(1))
            .unwrap()
            .build()
            .unwrap();

        let config = ConfigOptions::default();
        let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
        let expected = [
            "Limit: skip=0, fetch=1",
            "  LeftSemi Join:  Filter: t.number = right.number",
            "    Projection: t.number",
            "      MergeScan [is_placeholder=false]",
            "    SubqueryAlias: right",
            "      Projection: t.number",
            "        MergeScan [is_placeholder=false]",
        ]
        .join("\n");
        assert_eq!(expected, result.to_string());
    }
}