1use std::collections::{BTreeSet, HashMap, HashSet};
18use std::sync::Arc;
19
20use catalog::CatalogManagerRef;
21use common_error::ext::BoxedError;
22use common_function::aggrs::aggr_wrapper::get_aggr_func;
23use common_telemetry::debug;
24use datafusion::datasource::DefaultTableSource;
25use datafusion::error::Result as DfResult;
26use datafusion::logical_expr::Expr;
27use datafusion::sql::unparser::Unparser;
28use datafusion_common::tree_node::{
29 Transformed, TreeNode as _, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor,
30};
31use datafusion_common::{
32 Column, DFSchema, DataFusionError, NullEquality, ScalarValue, TableReference,
33};
34use datafusion_expr::logical_plan::{Aggregate, TableScan};
35use datafusion_expr::{
36 Distinct, ExprSchemable, JoinType, LogicalPlan, LogicalPlanBuilder, Operator, Projection, and,
37 binary_expr, bitwise_and, bitwise_or, bitwise_xor, is_null, or, when,
38};
39use datatypes::prelude::ConcreteDataType;
40use datatypes::schema::{ColumnSchema, SchemaRef};
41use query::QueryEngineRef;
42use query::parser::{DEFAULT_LOOKBACK_STRING, PromQuery, QueryLanguageParser, QueryStatement};
43use session::context::QueryContextRef;
44use snafu::{OptionExt, ResultExt, ensure};
45use sql::parser::{ParseOptions, ParserContext};
46use sql::statements::statement::Statement;
47use sql::statements::tql::Tql;
48use table::TableRef;
49use table::table::adapter::DfTableProviderAdapter;
50
51use crate::adapter::{AUTO_CREATED_PLACEHOLDER_TS_COL, AUTO_CREATED_UPDATE_AT_TS_COL};
52use crate::df_optimizer::apply_df_optimizer;
53use crate::error::{DatafusionSnafu, ExternalSnafu, InvalidQuerySnafu, TableNotFoundSnafu};
54use crate::{Error, TableName};
55
56#[cfg(test)]
57mod test;
58
/// One aggregate output column together with the operator used to merge a
/// newly computed value with the value already stored for the same group.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IncrementalAggregateMergeColumn {
    /// Name of the output field; the same name is looked up on both sides
    /// of the merge.
    pub output_field_name: String,
    /// Operator used to combine the two values.
    pub merge_op: IncrementalAggregateMergeOp,
}
74
75impl IncrementalAggregateMergeColumn {
76 pub fn new(output_field_name: String, merge_op: IncrementalAggregateMergeOp) -> Self {
78 Self {
79 output_field_name,
80 merge_op,
81 }
82 }
83}
84
/// Operator used to combine a freshly computed aggregate value with the
/// value previously stored for the same group.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IncrementalAggregateMergeOp {
    /// Add both values; chosen for `sum` and `count` aggregates.
    Sum,
    /// Keep the smaller value; chosen for `min`.
    Min,
    /// Keep the larger value; chosen for `max`.
    Max,
    /// Logical AND of both values; chosen for `bool_and`.
    BoolAnd,
    /// Logical OR of both values; chosen for `bool_or`.
    BoolOr,
    /// Bitwise AND of both values; chosen for `bit_and`.
    BitAnd,
    /// Bitwise OR of both values; chosen for `bit_or`.
    BitOr,
    /// Bitwise XOR of both values; chosen for `bit_xor`.
    BitXor,
}
96
/// Result of inspecting a logical plan for the incremental aggregate rewrite.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IncrementalAggregateAnalysis {
    /// Output names of the group-by expressions, sorted.
    pub group_key_names: Vec<String>,
    /// One entry per supported aggregate output; left empty whenever
    /// `unsupported_exprs` is non-empty.
    pub merge_columns: Vec<IncrementalAggregateMergeColumn>,
    /// Names of output columns recognised as literal/passthrough values.
    pub literal_columns: Vec<String>,
    /// Field names of the plan's output schema, in schema order.
    pub output_field_names: Vec<String>,
    /// Human-readable reasons why the plan cannot be rewritten; the rewrite
    /// is rejected when this list is not empty.
    pub unsupported_exprs: Vec<String>,
}
115
116fn find_column_names(expr: &Expr, names: &mut Vec<String>) {
125 match expr {
126 Expr::Column(col) => {
127 names.push(col.name.clone());
128 }
129 Expr::Alias(alias) => find_column_names(&alias.expr, names),
130 _ => {}
131 }
132}
133
134fn unqualified_col(name: impl Into<String>) -> Expr {
135 Expr::Column(Column::from_name(name.into()))
136}
137
138fn qualified_col(qualifier: &str, name: impl Into<String>) -> Expr {
139 Expr::Column(Column::new(Some(qualifier), name.into()))
140}
141
142fn qualified_column(qualifier: &str, name: impl Into<String>) -> Column {
143 Column::new(Some(qualifier), name.into())
144}
145
146fn find_group_key_names(plan: &LogicalPlan) -> Result<Vec<String>, Error> {
147 let mut group_finder = FindGroupByFinalName::default();
148 plan.visit(&mut group_finder)
149 .with_context(|_| DatafusionSnafu {
150 context: format!("Failed to inspect group-by columns from logical plan: {plan:?}"),
151 })?;
152
153 let mut group_key_names = group_finder
154 .get_group_expr_names()
155 .unwrap_or_default()
156 .into_iter()
157 .collect::<Vec<_>>();
158 group_key_names.sort();
159 Ok(group_key_names)
160}
161
162fn has_grouping_set(plan: &LogicalPlan) -> bool {
163 match plan {
164 LogicalPlan::Aggregate(aggregate) => aggregate
165 .group_expr
166 .iter()
167 .any(|expr| matches!(expr, Expr::GroupingSet(_))),
168 _ => plan.inputs().into_iter().any(has_grouping_set),
169 }
170}
171
172fn has_aggregate(plan: &LogicalPlan) -> bool {
173 match plan {
174 LogicalPlan::Aggregate(_) => true,
175 _ => plan.inputs().into_iter().any(has_aggregate),
176 }
177}
178
179fn peel_subquery_aliases(mut plan: &LogicalPlan) -> &LogicalPlan {
180 while let LogicalPlan::SubqueryAlias(alias) = plan {
181 plan = alias.input.as_ref();
182 }
183 plan
184}
185
186fn extract_incremental_aggregate(plan: &LogicalPlan) -> Result<Option<&Aggregate>, String> {
187 let plan = match plan {
191 LogicalPlan::Projection(projection) => projection.input.as_ref(),
192 _ => plan,
193 };
194
195 match plan {
196 LogicalPlan::Aggregate(aggregate) => {
197 check_input_plan_shape(aggregate.input.as_ref())?;
198 Ok(Some(aggregate))
199 }
200 LogicalPlan::Filter(filter) if has_aggregate(filter.input.as_ref()) => Err(
201 "unsupported post-aggregate filter (HAVING) in incremental aggregate rewrite"
202 .to_string(),
203 ),
204 _ if has_aggregate(plan) => Err(
205 "unsupported post-aggregate plan shape in incremental aggregate rewrite".to_string(),
206 ),
207 _ => Ok(None),
208 }
209}
210
211fn check_input_plan_shape(plan: &LogicalPlan) -> Result<(), String> {
212 let plan = peel_subquery_aliases(plan);
213 match plan {
214 LogicalPlan::TableScan(_) => Ok(()),
217 LogicalPlan::Filter(filter) => match peel_subquery_aliases(filter.input.as_ref()) {
218 LogicalPlan::TableScan(_) => Ok(()),
219 _ => Err(
220 "unsupported aggregate input plan shape in incremental aggregate rewrite"
221 .to_string(),
222 ),
223 },
224 _ => Err(
225 "unsupported aggregate input plan shape in incremental aggregate rewrite".to_string(),
226 ),
227 }
228}
229
/// Facts gathered about the output columns of a plan.
#[derive(Debug, Default)]
struct OutputProjectionInfo {
    /// Whether the plan root is a projection.
    has_top_level_projection: bool,
    /// Maps a source column name to the name it is exposed under.
    output_aliases: HashMap<String, String>,
    /// Messages describing source columns exposed under more than one alias.
    duplicate_aggregate_aliases: BTreeSet<String>,
    /// Names of output columns recognised as literal/passthrough values.
    literal_columns: HashSet<String>,
    /// Field names of the output schema, in schema order.
    output_field_names: Vec<String>,
}

impl OutputProjectionInfo {
    /// Returns the output field names as a set.
    fn output_field_name_set(&self) -> HashSet<String> {
        let mut names = HashSet::new();
        names.extend(self.output_field_names.iter().cloned());
        names
    }

    /// Returns every output field name that occurs more than once, sorted
    /// and without repetition.
    fn duplicate_output_names(&self) -> Vec<String> {
        let mut first_seen: HashSet<&str> = HashSet::new();
        let repeated: BTreeSet<String> = self
            .output_field_names
            .iter()
            // `insert` returns false for a name that was already recorded.
            .filter(|name| !first_seen.insert(name.as_str()))
            .cloned()
            .collect();
        repeated.into_iter().collect()
    }
}
255
256fn collect_output_projection_info(plan: &LogicalPlan) -> OutputProjectionInfo {
257 let mut projection_info = OutputProjectionInfo {
258 has_top_level_projection: matches!(plan, LogicalPlan::Projection(_)),
259 output_field_names: plan
260 .schema()
261 .fields()
262 .iter()
263 .map(|field| field.name().clone())
264 .collect(),
265 ..Default::default()
266 };
267
268 let mut output_aliases = HashMap::new();
269 if let LogicalPlan::Projection(projection) = plan {
270 for expr in &projection.expr {
271 match expr {
272 Expr::Alias(alias) => {
273 let alias_name = alias.name.clone();
279 let mut col_names = Vec::new();
280 find_column_names(&alias.expr, &mut col_names);
281 match col_names.len() {
282 0 if is_passthrough_output_column(&alias_name, alias.expr.as_ref()) => {
283 projection_info.literal_columns.insert(alias_name);
284 }
285 1 => {
286 if let Some(col_name) = col_names.into_iter().next() {
287 if let Some(existing_alias) = output_aliases.get(&col_name) {
288 if existing_alias != &alias_name {
289 projection_info.duplicate_aggregate_aliases.insert(format!(
290 "same aggregate output {col_name} is used by multiple aliases: {existing_alias}, {alias_name}"
291 ));
292 }
293 } else {
294 output_aliases.insert(col_name, alias_name);
295 }
296 }
297 }
298 _ => {}
299 }
300
301 }
304 Expr::Column(col) => {
305 output_aliases
306 .entry(col.name.clone())
307 .or_insert(col.name.clone());
308 }
309 Expr::Literal(_, _) => {
310 projection_info
311 .literal_columns
312 .insert(expr.qualified_name().1);
313 }
314 _ => {}
315 }
316 }
317 }
318
319 if projection_info
320 .output_field_names
321 .iter()
322 .any(|name| name == AUTO_CREATED_PLACEHOLDER_TS_COL)
323 {
324 projection_info
325 .literal_columns
326 .insert(AUTO_CREATED_PLACEHOLDER_TS_COL.to_string());
327 }
328
329 projection_info.output_aliases = output_aliases;
330 projection_info
331}
332
333fn is_passthrough_output_column(alias_name: &str, expr: &Expr) -> bool {
334 matches!(expr, Expr::Literal(_, _))
335 || match alias_name {
336 AUTO_CREATED_UPDATE_AT_TS_COL => expr == &datafusion::prelude::now(),
337 AUTO_CREATED_PLACEHOLDER_TS_COL => is_literal_or_cast_literal(expr),
338 _ => false,
339 }
340}
341
342fn is_literal_or_cast_literal(expr: &Expr) -> bool {
343 match expr {
344 Expr::Literal(_, _) => true,
345 Expr::Cast(cast) => is_literal_or_cast_literal(cast.expr.as_ref()),
346 Expr::TryCast(cast) => is_literal_or_cast_literal(cast.expr.as_ref()),
347 _ => false,
348 }
349}
350
351fn merge_op_for_aggregate_expr(aggr_expr: &Expr) -> Result<IncrementalAggregateMergeOp, String> {
352 let Some(aggr_func) = get_aggr_func(aggr_expr) else {
353 return Err(aggr_expr.to_string());
354 };
355 if aggr_func.params.distinct {
356 return Err(format!("unsupported DISTINCT aggregate: {aggr_expr}"));
357 }
358 if !aggr_func.params.order_by.is_empty() {
359 return Err(format!("unsupported aggregate ORDER BY: {aggr_expr}"));
360 }
361 if aggr_func.params.null_treatment.is_some() {
362 return Err(format!("unsupported aggregate NULL treatment: {aggr_expr}"));
363 }
364
365 match aggr_func.func.name().to_ascii_lowercase().as_str() {
366 "sum" | "count" => Ok(IncrementalAggregateMergeOp::Sum),
367 "min" => Ok(IncrementalAggregateMergeOp::Min),
368 "max" => Ok(IncrementalAggregateMergeOp::Max),
369 "bool_and" => Ok(IncrementalAggregateMergeOp::BoolAnd),
370 "bool_or" => Ok(IncrementalAggregateMergeOp::BoolOr),
371 "bit_and" => Ok(IncrementalAggregateMergeOp::BitAnd),
372 "bit_or" => Ok(IncrementalAggregateMergeOp::BitOr),
373 "bit_xor" => Ok(IncrementalAggregateMergeOp::BitXor),
374 _ => Err(aggr_expr.to_string()),
375 }
376}
377
378fn resolve_aggregate_output_field_name(
379 aggr_expr: &Expr,
380 projection_info: &OutputProjectionInfo,
381 output_field_name_set: &HashSet<String>,
382) -> Option<String> {
383 let raw_name = aggr_expr.qualified_name().1;
389 if let Some(alias) = projection_info.output_aliases.get(&raw_name) {
390 Some(alias.clone())
391 } else if !projection_info.has_top_level_projection && output_field_name_set.contains(&raw_name)
392 {
393 Some(raw_name)
394 } else {
395 None
396 }
397}
398
399fn find_uncovered_output_fields(
400 projection_info: &OutputProjectionInfo,
401 group_key_names: &[String],
402 merge_columns: &[IncrementalAggregateMergeColumn],
403) -> Vec<String> {
404 let group_key_names = group_key_names.iter().cloned().collect::<HashSet<_>>();
405 let merge_column_names = merge_columns
406 .iter()
407 .map(|c| c.output_field_name.clone())
408 .collect::<HashSet<_>>();
409
410 projection_info
411 .output_field_names
412 .iter()
413 .filter(|name| {
414 !group_key_names.contains(*name)
415 && !merge_column_names.contains(*name)
416 && !projection_info.literal_columns.contains(*name)
417 && name.as_str() != AUTO_CREATED_UPDATE_AT_TS_COL
421 && name.as_str() != AUTO_CREATED_PLACEHOLDER_TS_COL
422 })
423 .cloned()
424 .collect()
425}
426
427fn find_unsupported_group_key_projection_outputs(
428 plan: &LogicalPlan,
429 aggregate: &Aggregate,
430 group_key_names: &[String],
431) -> Vec<String> {
432 let LogicalPlan::Projection(projection) = plan else {
433 return vec![];
434 };
435
436 let group_key_names = group_key_names.iter().cloned().collect::<HashSet<_>>();
437 let group_expr_names = aggregate
438 .group_expr
439 .iter()
440 .filter_map(|expr| expr.name_for_alias().ok())
441 .collect::<HashSet<_>>();
442 projection
443 .expr
444 .iter()
445 .filter_map(|expr| {
446 let output_name = expr.qualified_name().1;
447 if !group_key_names.contains(&output_name) {
448 return None;
449 }
450
451 let source_name = match expr {
452 Expr::Alias(alias) => alias.expr.name_for_alias().ok(),
453 _ => expr.name_for_alias().ok(),
454 };
455 if source_name.is_some_and(|name| group_expr_names.contains(&name)) {
456 None
457 } else {
458 Some(format!(
459 "unsupported group key output field is not a transparent group expression: {output_name}"
460 ))
461 }
462 })
463 .collect()
464}
465
466pub fn analyze_incremental_aggregate_plan(
467 plan: &LogicalPlan,
468) -> Result<Option<IncrementalAggregateAnalysis>, Error> {
469 let group_key_names = find_group_key_names(plan)?;
470 let aggregate = match extract_incremental_aggregate(plan) {
471 Ok(Some(aggregate)) => aggregate,
472 Ok(None) => return Ok(None),
473 Err(reason) => {
474 let projection_info = collect_output_projection_info(plan);
475 let mut unsupported_exprs = projection_info
476 .duplicate_output_names()
477 .into_iter()
478 .map(|name| format!("duplicate output field name: {name}"))
479 .collect::<Vec<_>>();
480 unsupported_exprs.push(reason);
481 unsupported_exprs.extend(projection_info.duplicate_aggregate_aliases.iter().cloned());
482 return Ok(Some(IncrementalAggregateAnalysis {
483 group_key_names,
484 merge_columns: vec![],
485 literal_columns: vec![],
486 output_field_names: projection_info.output_field_names,
487 unsupported_exprs,
488 }));
489 }
490 };
491 let aggr_exprs = aggregate.aggr_expr.clone();
492 let projection_info = collect_output_projection_info(plan);
493 let output_field_name_set = projection_info.output_field_name_set();
494
495 let mut merge_columns = Vec::with_capacity(aggr_exprs.len());
496 let mut unsupported_exprs = projection_info
497 .duplicate_output_names()
498 .into_iter()
499 .map(|name| format!("duplicate output field name: {name}"))
500 .collect::<Vec<_>>();
501 if has_grouping_set(plan) {
502 unsupported_exprs.push(
503 "unsupported GROUPING SETS/CUBE/ROLLUP in incremental aggregate rewrite".to_string(),
504 );
505 }
506 if group_key_names.is_empty() {
507 unsupported_exprs
508 .push("unsupported global aggregate in incremental aggregate rewrite".to_string());
509 }
510 unsupported_exprs.extend(find_unsupported_group_key_projection_outputs(
511 plan,
512 aggregate,
513 &group_key_names,
514 ));
515 unsupported_exprs.extend(projection_info.duplicate_aggregate_aliases.iter().cloned());
516 for aggr_expr in aggr_exprs {
517 let merge_op = match merge_op_for_aggregate_expr(&aggr_expr) {
518 Ok(merge_op) => merge_op,
519 Err(reason) => {
520 unsupported_exprs.push(reason);
521 continue;
522 }
523 };
524 let Some(output_field_name) = resolve_aggregate_output_field_name(
525 &aggr_expr,
526 &projection_info,
527 &output_field_name_set,
528 ) else {
529 unsupported_exprs.push(aggr_expr.to_string());
530 continue;
531 };
532 merge_columns.push(IncrementalAggregateMergeColumn::new(
533 output_field_name,
534 merge_op,
535 ));
536 }
537 unsupported_exprs.extend(
538 find_uncovered_output_fields(&projection_info, &group_key_names, &merge_columns)
539 .into_iter()
540 .map(|name| format!("unsupported output field: {name}")),
541 );
542 if !unsupported_exprs.is_empty() {
543 merge_columns.clear();
544 }
545 let mut literal_columns = projection_info
546 .literal_columns
547 .into_iter()
548 .collect::<Vec<_>>();
549 literal_columns.sort();
550
551 Ok(Some(IncrementalAggregateAnalysis {
552 group_key_names,
553 merge_columns,
554 literal_columns,
555 output_field_names: projection_info.output_field_names,
556 unsupported_exprs,
557 }))
558}
559
560pub async fn rewrite_incremental_aggregate_with_sink_merge(
592 delta_plan: &LogicalPlan,
593 analysis: &IncrementalAggregateAnalysis,
594 sink_table: TableRef,
595 sink_table_name: &TableName,
596 sink_dirty_filter: Option<Expr>,
597) -> Result<LogicalPlan, Error> {
598 ensure!(
599 analysis.unsupported_exprs.is_empty(),
600 InvalidQuerySnafu {
601 reason: format!(
602 "UNSUPPORTED_INCREMENTAL_AGG: unsupported aggregate expressions {:?}",
603 analysis.unsupported_exprs
604 )
605 }
606 );
607
608 ensure!(
609 !analysis.merge_columns.is_empty(),
610 InvalidQuerySnafu {
611 reason:
612 "UNSUPPORTED_INCREMENTAL_AGG: aggregate query has no mergeable aggregate columns"
613 .to_string()
614 }
615 );
616
617 ensure!(
618 !analysis.group_key_names.is_empty(),
619 InvalidQuerySnafu {
620 reason: "UNSUPPORTED_INCREMENTAL_AGG: global aggregate query is not supported"
621 .to_string()
622 }
623 );
624
625 let delta_alias = "__flow_delta";
626 let sink_alias = "__flow_sink";
627
628 let mut selected_columns = analysis.group_key_names.clone();
629 selected_columns.extend(
630 analysis
631 .merge_columns
632 .iter()
633 .map(|c| c.output_field_name.clone()),
634 );
635 let mut delta_selected_columns = selected_columns.clone();
636 delta_selected_columns.extend(analysis.literal_columns.iter().cloned());
637
638 let delta_selected_exprs = delta_selected_columns
639 .iter()
640 .cloned()
641 .map(unqualified_col)
642 .collect::<Vec<_>>();
643 let delta_selected = LogicalPlanBuilder::from(delta_plan.clone())
644 .project(delta_selected_exprs)
645 .with_context(|_| DatafusionSnafu {
646 context: "Failed to project delta plan for incremental sink merge".to_string(),
647 })?
648 .alias(delta_alias)
649 .with_context(|_| DatafusionSnafu {
650 context: "Failed to alias delta plan for incremental sink merge".to_string(),
651 })?
652 .build()
653 .with_context(|_| DatafusionSnafu {
654 context: "Failed to build projected delta plan for incremental sink merge".to_string(),
655 })?;
656
657 let table_provider = Arc::new(DfTableProviderAdapter::new(sink_table));
658 let table_source = Arc::new(DefaultTableSource::new(table_provider));
659 let sink_scan = LogicalPlan::TableScan(
660 TableScan::try_new(
661 TableReference::Full {
662 catalog: sink_table_name[0].clone().into(),
663 schema: sink_table_name[1].clone().into(),
664 table: sink_table_name[2].clone().into(),
665 },
666 table_source,
667 None,
668 vec![],
669 None,
670 )
671 .with_context(|_| DatafusionSnafu {
672 context: "Failed to build sink table scan for incremental sink merge".to_string(),
673 })?,
674 );
675
676 let sink_selected_exprs = selected_columns
677 .iter()
678 .cloned()
679 .map(unqualified_col)
680 .collect::<Vec<_>>();
681 let sink_input = if let Some(predicate) = sink_dirty_filter {
682 LogicalPlanBuilder::from(sink_scan)
683 .filter(predicate)
684 .with_context(|_| DatafusionSnafu {
685 context: "Failed to filter sink table scan for incremental sink merge".to_string(),
686 })?
687 .build()
688 .with_context(|_| DatafusionSnafu {
689 context: "Failed to build filtered sink plan for incremental sink merge"
690 .to_string(),
691 })?
692 } else {
693 sink_scan
694 };
695
696 let sink_selected = LogicalPlanBuilder::from(sink_input)
697 .project(sink_selected_exprs)
698 .with_context(|_| DatafusionSnafu {
699 context: "Failed to project sink table scan for incremental sink merge".to_string(),
700 })?
701 .alias(sink_alias)
702 .with_context(|_| DatafusionSnafu {
703 context: "Failed to alias sink plan for incremental sink merge".to_string(),
704 })?
705 .build()
706 .with_context(|_| DatafusionSnafu {
707 context: "Failed to build projected sink plan for incremental sink merge".to_string(),
708 })?;
709
710 let join_keys = (
711 analysis
712 .group_key_names
713 .iter()
714 .cloned()
715 .map(|c| qualified_column(delta_alias, c))
716 .collect::<Vec<_>>(),
717 analysis
718 .group_key_names
719 .iter()
720 .cloned()
721 .map(|c| qualified_column(sink_alias, c))
722 .collect::<Vec<_>>(),
723 );
724
725 let joined = LogicalPlanBuilder::from(delta_selected)
726 .join_detailed(
727 sink_selected,
728 JoinType::Left,
729 join_keys,
730 None,
731 NullEquality::NullEqualsNull,
732 )
733 .with_context(|_| DatafusionSnafu {
734 context: "Failed to left join delta and sink plans for incremental sink merge"
735 .to_string(),
736 })?
737 .build()
738 .with_context(|_| DatafusionSnafu {
739 context: "Failed to build left join plan for incremental sink merge".to_string(),
740 })?;
741
742 let group_key_names = analysis.group_key_names.iter().collect::<HashSet<_>>();
743 let literal_columns = analysis.literal_columns.iter().collect::<HashSet<_>>();
744 let merge_columns = analysis
745 .merge_columns
746 .iter()
747 .map(|c| (&c.output_field_name, c))
748 .collect::<HashMap<_, _>>();
749
750 let mut projection_exprs = Vec::with_capacity(analysis.output_field_names.len());
751 for output_field_name in &analysis.output_field_names {
752 if group_key_names.contains(output_field_name)
753 || literal_columns.contains(output_field_name)
754 {
755 projection_exprs.push(
756 qualified_col(delta_alias, output_field_name.clone()).alias(output_field_name),
757 );
758 } else if let Some(merge_col) = merge_columns.get(output_field_name) {
759 projection_exprs.push(build_left_join_merge_expr(
760 delta_alias,
761 sink_alias,
762 merge_col,
763 )?);
764 } else {
765 return InvalidQuerySnafu {
766 reason: format!(
767 "UNSUPPORTED_INCREMENTAL_AGG: output field {output_field_name} is not covered by group keys, literals, or merge columns"
768 ),
769 }
770 .fail();
771 }
772 }
773
774 LogicalPlanBuilder::from(joined)
775 .project(projection_exprs)
776 .with_context(|_| DatafusionSnafu {
777 context: "Failed to build projection merge plan for incremental sink merge".to_string(),
778 })?
779 .build()
780 .with_context(|_| DatafusionSnafu {
781 context: "Failed to finalize incremental aggregate sink merge plan".to_string(),
782 })
783}
784
785fn build_left_join_merge_expr(
786 delta_alias: &str,
787 sink_alias: &str,
788 merge_col: &IncrementalAggregateMergeColumn,
789) -> Result<Expr, Error> {
790 let left = qualified_col(delta_alias, merge_col.output_field_name.clone());
791 let right = qualified_col(sink_alias, merge_col.output_field_name.clone());
792 let merged = match merge_col.merge_op {
793 IncrementalAggregateMergeOp::Sum => when(is_null(left.clone()), right.clone())
794 .when(is_null(right.clone()), left.clone())
795 .otherwise(binary_expr(left.clone(), Operator::Plus, right.clone()))
796 .with_context(|_| DatafusionSnafu {
797 context: "Failed to build SUM merge expression".to_string(),
798 })?,
799 IncrementalAggregateMergeOp::Min => when(is_null(right.clone()), left.clone())
800 .when(left.clone().lt_eq(right.clone()), left.clone())
801 .otherwise(right.clone())
802 .with_context(|_| DatafusionSnafu {
803 context: "Failed to build MIN merge expression".to_string(),
804 })?,
805 IncrementalAggregateMergeOp::Max => when(is_null(right.clone()), left.clone())
806 .when(left.clone().gt_eq(right.clone()), left.clone())
807 .otherwise(right.clone())
808 .with_context(|_| DatafusionSnafu {
809 context: "Failed to build MAX merge expression".to_string(),
810 })?,
811 IncrementalAggregateMergeOp::BoolAnd => when(is_null(left.clone()), right.clone())
812 .when(is_null(right.clone()), left.clone())
813 .otherwise(and(left.clone(), right.clone()))
814 .with_context(|_| DatafusionSnafu {
815 context: "Failed to build BOOL_AND merge expression".to_string(),
816 })?,
817 IncrementalAggregateMergeOp::BoolOr => when(is_null(left.clone()), right.clone())
818 .when(is_null(right.clone()), left.clone())
819 .otherwise(or(left.clone(), right.clone()))
820 .with_context(|_| DatafusionSnafu {
821 context: "Failed to build BOOL_OR merge expression".to_string(),
822 })?,
823 IncrementalAggregateMergeOp::BitAnd => when(is_null(left.clone()), right.clone())
824 .when(is_null(right.clone()), left.clone())
825 .otherwise(bitwise_and(left.clone(), right.clone()))
826 .with_context(|_| DatafusionSnafu {
827 context: "Failed to build BIT_AND merge expression".to_string(),
828 })?,
829 IncrementalAggregateMergeOp::BitOr => when(is_null(left.clone()), right.clone())
830 .when(is_null(right.clone()), left.clone())
831 .otherwise(bitwise_or(left.clone(), right.clone()))
832 .with_context(|_| DatafusionSnafu {
833 context: "Failed to build BIT_OR merge expression".to_string(),
834 })?,
835 IncrementalAggregateMergeOp::BitXor => when(is_null(left.clone()), right.clone())
836 .when(is_null(right.clone()), left.clone())
837 .otherwise(bitwise_xor(left.clone(), right.clone()))
838 .with_context(|_| DatafusionSnafu {
839 context: "Failed to build BIT_XOR merge expression".to_string(),
840 })?,
841 };
842 Ok(merged.alias(merge_col.output_field_name.clone()))
843}
844
845pub async fn get_table_info_df_schema(
846 catalog_mr: CatalogManagerRef,
847 table_name: TableName,
848) -> Result<(TableRef, Arc<DFSchema>), Error> {
849 let full_table_name = table_name.clone().join(".");
850 let table = catalog_mr
851 .table(&table_name[0], &table_name[1], &table_name[2], None)
852 .await
853 .map_err(BoxedError::new)
854 .context(ExternalSnafu)?
855 .context(TableNotFoundSnafu {
856 name: &full_table_name,
857 })?;
858 let table_info = table.table_info();
859
860 let schema = table_info.meta.schema.clone();
861
862 let df_schema: Arc<DFSchema> = Arc::new(
863 schema
864 .arrow_schema()
865 .clone()
866 .try_into()
867 .with_context(|_| DatafusionSnafu {
868 context: format!(
869 "Failed to convert arrow schema to datafusion schema, arrow_schema={:?}",
870 schema.arrow_schema()
871 ),
872 })?,
873 );
874 Ok((table, df_schema))
875}
876
877pub async fn sql_to_df_plan(
880 query_ctx: QueryContextRef,
881 engine: QueryEngineRef,
882 sql: &str,
883 optimize: bool,
884) -> Result<LogicalPlan, Error> {
885 let scheduled_time = query::options::parse_scheduled_time_datetime(&query_ctx.extensions())
886 .map_err(BoxedError::new)
887 .context(ExternalSnafu)?;
888 let stmts = ParserContext::create_with_dialect(
889 sql,
890 query_ctx.sql_dialect(),
891 ParseOptions { scheduled_time },
892 )
893 .map_err(BoxedError::new)
894 .context(ExternalSnafu)?;
895
896 ensure!(
897 stmts.len() == 1,
898 InvalidQuerySnafu {
899 reason: format!("Expect only one statement, found {}", stmts.len())
900 }
901 );
902 let stmt = &stmts[0];
903 let query_stmt = match stmt {
904 Statement::Tql(tql) => match tql {
905 Tql::Eval(eval) => {
906 let eval = eval.clone();
907 let promql = PromQuery {
908 start: eval.start,
909 end: eval.end,
910 step: eval.step,
911 query: eval.query,
912 lookback: eval
913 .lookback
914 .unwrap_or_else(|| DEFAULT_LOOKBACK_STRING.to_string()),
915 alias: eval.alias.clone(),
916 };
917
918 QueryLanguageParser::parse_promql(&promql, &query_ctx)
919 .map_err(BoxedError::new)
920 .context(ExternalSnafu)?
921 }
922 _ => InvalidQuerySnafu {
923 reason: format!("TQL statement {tql:?} is not supported, expect only TQL EVAL"),
924 }
925 .fail()?,
926 },
927 _ => QueryStatement::Sql(stmt.clone()),
928 };
929 let plan = engine
930 .planner()
931 .plan(&query_stmt, query_ctx.clone())
932 .await
933 .map_err(BoxedError::new)
934 .context(ExternalSnafu)?;
935
936 let plan = if optimize {
937 apply_df_optimizer(plan, &query_ctx).await?
938 } else {
939 plan
940 };
941 Ok(plan)
942}
943
944pub(crate) async fn gen_plan_with_matching_schema(
947 sql: &str,
948 query_ctx: QueryContextRef,
949 engine: QueryEngineRef,
950 sink_table_schema: SchemaRef,
951 primary_key_indices: &[usize],
952 allow_partial: bool,
953) -> Result<LogicalPlan, Error> {
954 let plan = sql_to_df_plan(query_ctx.clone(), engine.clone(), sql, false).await?;
955
956 let mut add_auto_column = ColumnMatcherRewriter::new(
957 sink_table_schema,
958 primary_key_indices.to_vec(),
959 allow_partial,
960 );
961 let plan = plan
962 .clone()
963 .rewrite(&mut add_auto_column)
964 .with_context(|_| DatafusionSnafu {
965 context: "Failed to rewrite plan".to_string(),
966 })?
967 .data;
968 Ok(plan)
969}
970
971pub fn df_plan_to_sql(plan: &LogicalPlan) -> Result<String, Error> {
972 struct ForceQuoteIdentifiers;
974 impl datafusion::sql::unparser::dialect::Dialect for ForceQuoteIdentifiers {
975 fn identifier_quote_style(&self, identifier: &str) -> Option<char> {
976 if identifier.to_lowercase() != identifier {
977 Some('`')
978 } else {
979 None
980 }
981 }
982 }
983 let unparser = Unparser::new(&ForceQuoteIdentifiers);
984 let sql = unparser
986 .plan_to_sql(plan)
987 .with_context(|_e| DatafusionSnafu {
988 context: format!("Failed to unparse logical plan {plan:?}"),
989 })?;
990 Ok(sql.to_string())
991}
992
/// Plan visitor that records group-by expressions and follows them through
/// projection aliases, so that their final output names can be queried.
#[derive(Debug, Clone, Default)]
pub struct FindGroupByFinalName {
    /// Group expressions recorded so far; `None` until an `Aggregate` or
    /// `Distinct` node has been visited.
    group_exprs: Option<HashSet<datafusion_expr::Expr>>,
}
998
999impl FindGroupByFinalName {
1000 pub fn get_group_expr_names(&self) -> Option<HashSet<String>> {
1001 self.group_exprs
1002 .as_ref()
1003 .map(|exprs| exprs.iter().map(|expr| expr.qualified_name().1).collect())
1004 }
1005}
1006
1007impl TreeNodeVisitor<'_> for FindGroupByFinalName {
1008 type Node = LogicalPlan;
1009
1010 fn f_down(&mut self, node: &Self::Node) -> datafusion_common::Result<TreeNodeRecursion> {
1011 if let LogicalPlan::Aggregate(aggregate) = node {
1012 self.group_exprs = Some(aggregate.group_expr.iter().cloned().collect());
1013 debug!(
1014 "FindGroupByFinalName: Get Group by exprs from Aggregate: {:?}",
1015 self.group_exprs
1016 );
1017 } else if let LogicalPlan::Distinct(distinct) = node {
1018 debug!("FindGroupByFinalName: Distinct: {}", node);
1019 match distinct {
1020 Distinct::All(input) => {
1021 if let LogicalPlan::TableScan(table_scan) = &**input {
1022 let len = table_scan.projected_schema.fields().len();
1024 let columns = (0..len)
1025 .map(|f| {
1026 let (qualifier, field) =
1027 table_scan.projected_schema.qualified_field(f);
1028 datafusion_common::Column::new(qualifier.cloned(), field.name())
1029 })
1030 .map(datafusion_expr::Expr::Column);
1031 self.group_exprs = Some(columns.collect());
1032 } else {
1033 self.group_exprs = Some(input.expressions().iter().cloned().collect())
1034 }
1035 }
1036 Distinct::On(distinct_on) => {
1037 self.group_exprs = Some(distinct_on.on_expr.iter().cloned().collect())
1038 }
1039 }
1040 debug!(
1041 "FindGroupByFinalName: Get Group by exprs from Distinct: {:?}",
1042 self.group_exprs
1043 );
1044 }
1045
1046 Ok(TreeNodeRecursion::Continue)
1047 }
1048
1049 fn f_up(&mut self, node: &Self::Node) -> datafusion_common::Result<TreeNodeRecursion> {
1051 if let LogicalPlan::Projection(projection) = node {
1052 for expr in &projection.expr {
1053 let Some(group_exprs) = &mut self.group_exprs else {
1054 return Ok(TreeNodeRecursion::Continue);
1055 };
1056 if let datafusion_expr::Expr::Alias(alias) = expr {
1057 let mut new_group_exprs = group_exprs.clone();
1059 for group_expr in group_exprs.iter() {
1060 if group_expr.name_for_alias()? == alias.expr.name_for_alias()? {
1061 new_group_exprs.remove(group_expr);
1062 new_group_exprs.insert(expr.clone());
1063 break;
1064 }
1065 }
1066 *group_exprs = new_group_exprs;
1067 }
1068 }
1069 }
1070 debug!("Aliased group by exprs: {:?}", self.group_exprs);
1071 Ok(TreeNodeRecursion::Continue)
1072 }
1073}
1074
/// Rewriter that adjusts the projected expressions of a query so that they
/// line up with the columns of a target table schema.
#[derive(Debug)]
pub struct ColumnMatcherRewriter {
    /// Schema of the table the query output is matched against.
    pub schema: SchemaRef,
    /// Initialised to `false` by `new`; presumably set once a rewrite has
    /// happened — TODO confirm against the rewriter implementation.
    pub is_rewritten: bool,
    /// Primary key column indices supplied by the caller; how they are used
    /// is not visible in this part of the file.
    pub primary_key_indices: Vec<usize>,
    /// When true, projection matching is delegated to the partial-matching
    /// path instead of the strict column-count rules.
    pub allow_partial: bool,
}
1088
1089impl ColumnMatcherRewriter {
1090 pub fn new(schema: SchemaRef, primary_key_indices: Vec<usize>, allow_partial: bool) -> Self {
1091 Self {
1092 schema,
1093 is_rewritten: false,
1094 primary_key_indices,
1095 allow_partial,
1096 }
1097 }
1098
    /// Extends the projected expressions so that their count matches the
    /// number of columns in the target table schema.
    ///
    /// With `allow_partial` set, the work is delegated to
    /// `modify_project_exprs_with_partial`. Otherwise the table may have at
    /// most two columns more than the query, and the missing trailing
    /// columns are filled with `now()` and/or a constant placeholder
    /// timestamp. The resulting list is finally handed to
    /// `match_extra_output_columns`.
    ///
    /// # Errors
    /// Returns a plan error when the column counts differ by more than two
    /// or when the trailing table columns are not of the expected kind.
    fn modify_project_exprs(
        &mut self,
        mut exprs: Vec<Expr>,
        input_schema: &DFSchema,
    ) -> DfResult<Vec<Expr>> {
        if self.allow_partial {
            return self.modify_project_exprs_with_partial(exprs);
        }

        // Kept untouched for error messages and the final matching step.
        let original_exprs = exprs.clone();

        let all_names = self
            .schema
            .column_schemas()
            .iter()
            .map(|c| c.name.clone())
            .collect::<BTreeSet<_>>();
        let query_col_cnt = exprs.len();
        let table_col_cnt = self.schema.column_schemas().len();
        debug!("query_col_cnt={query_col_cnt}, table_col_cnt={table_col_cnt}");

        // Constant timestamp (0 ms) aliased to the auto-created placeholder column.
        let placeholder_ts_expr =
            datafusion::logical_expr::lit(ScalarValue::TimestampMillisecond(Some(0), None))
                .alias(AUTO_CREATED_PLACEHOLDER_TS_COL);

        if query_col_cnt == table_col_cnt {
            // Same number of columns: nothing needs to be appended.
        } else if query_col_cnt + 1 == table_col_cnt {
            // One column is missing. The table has at least one column in
            // this branch, so `last()` cannot be `None`.
            let last_col_schema = self.schema.column_schemas().last().unwrap();

            if last_col_schema.name == AUTO_CREATED_PLACEHOLDER_TS_COL
                && self.schema.timestamp_index() == Some(table_col_cnt - 1)
            {
                // The last column is the placeholder time index.
                exprs.push(placeholder_ts_expr);
            } else if last_col_schema.data_type.is_timestamp() {
                // Any other trailing timestamp column is filled with `now()`.
                exprs.push(datafusion::prelude::now().alias(&last_col_schema.name));
            } else {
                return Err(DataFusionError::Plan(format_flow_sink_schema_mismatch(
                    &original_exprs,
                    self.schema.as_ref(),
                )));
            }
        } else if query_col_cnt + 2 == table_col_cnt {
            // Two columns are missing: a timestamp column filled with `now()`
            // followed by the placeholder time index. The table has at least
            // two columns in this branch, so both `next()` calls succeed.
            let mut col_iter = self.schema.column_schemas().iter().rev();
            let last_col_schema = col_iter.next().unwrap();
            let second_last_col_schema = col_iter.next().unwrap();
            if second_last_col_schema.data_type.is_timestamp() {
                exprs.push(datafusion::prelude::now().alias(&second_last_col_schema.name));
            } else {
                return Err(DataFusionError::Plan(format!(
                    "Expect the second last column in the table to be timestamp column, found column {} with type {:?}",
                    second_last_col_schema.name, second_last_col_schema.data_type
                )));
            }

            if last_col_schema.name == AUTO_CREATED_PLACEHOLDER_TS_COL
                && self.schema.timestamp_index() == Some(table_col_cnt - 1)
            {
                exprs.push(placeholder_ts_expr);
            } else {
                return Err(DataFusionError::Plan(format!(
                    "Expect timestamp column {}, found {:?}",
                    AUTO_CREATED_PLACEHOLDER_TS_COL, last_col_schema
                )));
            }
        } else {
            // Any other difference in column count cannot be reconciled.
            return Err(DataFusionError::Plan(format_flow_sink_schema_mismatch(
                &original_exprs,
                self.schema.as_ref(),
            )));
        }

        self.match_extra_output_columns(exprs, input_schema, &original_exprs, &all_names)
    }
1177
1178 fn match_extra_output_columns(
1189 &self,
1190 mut exprs: Vec<Expr>,
1191 input_schema: &DFSchema,
1192 original_exprs: &[Expr],
1193 all_names: &BTreeSet<String>,
1194 ) -> DfResult<Vec<Expr>> {
1195 let mut output_names = exprs
1196 .iter()
1197 .map(|expr| expr.qualified_name().1)
1198 .collect::<Vec<_>>();
1199 let output_name_set = output_names.iter().cloned().collect::<BTreeSet<_>>();
1200 let extra_expr_indices = output_names
1201 .iter()
1202 .enumerate()
1203 .filter_map(|(idx, name)| (!all_names.contains(name)).then_some(idx))
1204 .collect::<Vec<_>>();
1205 let missing_sink_indices = self
1206 .schema
1207 .column_schemas()
1208 .iter()
1209 .enumerate()
1210 .filter_map(|(idx, column)| (!output_name_set.contains(&column.name)).then_some(idx))
1211 .collect::<Vec<_>>();
1212
1213 if extra_expr_indices.is_empty() && missing_sink_indices.is_empty() {
1214 return Ok(exprs);
1215 }
1216
1217 if extra_expr_indices.len() != missing_sink_indices.len() {
1218 return Err(DataFusionError::Plan(format_flow_sink_schema_mismatch(
1219 original_exprs,
1220 self.schema.as_ref(),
1221 )));
1222 }
1223
1224 let mut positional_matches = Vec::new();
1225 for expr_idx in extra_expr_indices {
1226 if !missing_sink_indices.contains(&expr_idx) {
1227 return Err(DataFusionError::Plan(format_flow_sink_schema_mismatch(
1228 original_exprs,
1229 self.schema.as_ref(),
1230 )));
1231 }
1232
1233 let target_col_schema = &self.schema.column_schemas()[expr_idx];
1234 let expr_type =
1235 ConcreteDataType::from_arrow_type(&exprs[expr_idx].get_type(input_schema)?);
1236 if is_obviously_incompatible_positional_match(&expr_type, &target_col_schema.data_type)
1237 {
1238 return Err(DataFusionError::Plan(format!(
1239 "Cannot match flow output column '{}' to sink column '{}' by position: incompatible data types, flow output type is {:?}, sink column type is {:?}. {}",
1240 output_names[expr_idx],
1241 target_col_schema.name,
1242 expr_type,
1243 target_col_schema.data_type,
1244 format_flow_sink_schema_mismatch(original_exprs, self.schema.as_ref())
1245 )));
1246 }
1247
1248 let target_name = target_col_schema.name.clone();
1249 positional_matches.push(format!(
1250 "{} -> {} (flow output type: {:?}, sink column type: {:?})",
1251 output_names[expr_idx], target_name, expr_type, target_col_schema.data_type
1252 ));
1253 exprs[expr_idx] = exprs[expr_idx].clone().alias(target_name.clone());
1254 output_names[expr_idx] = target_name;
1255 }
1256
1257 if !positional_matches.is_empty() {
1258 debug!(
1259 "Matched flow output columns to sink columns by position: {:?}",
1260 positional_matches
1261 );
1262 }
1263
1264 let duplicated_output_names = duplicate_names(&output_names);
1265 if !duplicated_output_names.is_empty() {
1266 return Err(DataFusionError::Plan(format!(
1267 "Flow output schema contains duplicate column(s) after schema matching {:?}. {}",
1268 duplicated_output_names,
1269 format_flow_sink_schema_mismatch(&exprs, self.schema.as_ref())
1270 )));
1271 }
1272
1273 Ok(exprs)
1274 }
1275
1276 fn modify_project_exprs_with_partial(&mut self, exprs: Vec<Expr>) -> DfResult<Vec<Expr>> {
1277 let table_col_cnt = self.schema.column_schemas().len();
1278 let query_col_cnt = exprs.len();
1279
1280 if query_col_cnt > table_col_cnt {
1281 return Err(DataFusionError::Plan(format_flow_sink_schema_mismatch(
1282 &exprs,
1283 self.schema.as_ref(),
1284 )));
1285 }
1286
1287 let name_to_expr: HashMap<String, Expr> = exprs
1288 .clone()
1289 .into_iter()
1290 .map(|e| (e.qualified_name().1, e))
1291 .collect();
1292
1293 let required_columns = self.required_columns_for_partial();
1294 let missing: Vec<_> = required_columns
1295 .iter()
1296 .filter(|name| !name_to_expr.contains_key(*name))
1297 .cloned()
1298 .collect();
1299 if !missing.is_empty() {
1300 return Err(DataFusionError::Plan(format!(
1301 "Column(s) {:?} required by sink table are missing from flow output when merge_mode=last_non_null. {}",
1302 missing,
1303 format_flow_sink_schema_mismatch(&exprs, self.schema.as_ref())
1304 )));
1305 }
1306
1307 let placeholder_ts_expr =
1308 datafusion::logical_expr::lit(ScalarValue::TimestampMillisecond(Some(0), None))
1309 .alias(AUTO_CREATED_PLACEHOLDER_TS_COL);
1310
1311 let timestamp_index = self.schema.timestamp_index();
1312 let mut remap = name_to_expr;
1313 let mut new_exprs = Vec::with_capacity(table_col_cnt);
1314
1315 for (idx, col_schema) in self.schema.column_schemas().iter().enumerate() {
1316 let col_name = col_schema.name.clone();
1317 if let Some(expr) = remap.remove(&col_name) {
1318 let expr = if expr.qualified_name().1 == col_name {
1319 expr
1320 } else {
1321 expr.alias(col_name.clone())
1322 };
1323 new_exprs.push(expr);
1324 continue;
1325 }
1326
1327 if col_name == AUTO_CREATED_PLACEHOLDER_TS_COL && timestamp_index == Some(idx) {
1328 new_exprs.push(placeholder_ts_expr.clone());
1329 continue;
1330 }
1331
1332 if col_name == AUTO_CREATED_UPDATE_AT_TS_COL && col_schema.data_type.is_timestamp() {
1333 new_exprs.push(datafusion::prelude::now().alias(&col_name));
1334 continue;
1335 }
1336
1337 new_exprs.push(Self::null_expr(col_schema));
1338 }
1339
1340 if !remap.is_empty() {
1341 let extra: Vec<_> = remap.keys().cloned().collect();
1342 return Err(DataFusionError::Plan(format!(
1343 "Flow output has extra column(s) {:?} not found in sink schema when merge_mode=last_non_null. {}",
1344 extra,
1345 format_flow_sink_schema_mismatch(&exprs, self.schema.as_ref())
1346 )));
1347 }
1348
1349 Ok(new_exprs)
1350 }
1351
1352 fn null_expr(col_schema: &ColumnSchema) -> Expr {
1353 Expr::Literal(ScalarValue::Null, None).alias(col_schema.name.clone())
1354 }
1355
1356 fn required_columns_for_partial(&self) -> HashSet<String> {
1357 let mut required = HashSet::new();
1358 for idx in &self.primary_key_indices {
1359 if let Some(col) = self.schema.column_schemas().get(*idx) {
1360 required.insert(col.name.clone());
1361 }
1362 }
1363
1364 if let Some(ts_idx) = self.schema.timestamp_index()
1365 && let Some(col) = self.schema.column_schemas().get(ts_idx)
1366 && col.name != AUTO_CREATED_PLACEHOLDER_TS_COL
1367 {
1368 required.insert(col.name.clone());
1369 }
1370
1371 required
1372 }
1373}
1374
1375fn is_obviously_incompatible_positional_match(
1376 expr_type: &ConcreteDataType,
1377 sink_type: &ConcreteDataType,
1378) -> bool {
1379 if expr_type.is_null() || expr_type == sink_type {
1384 return false;
1385 }
1386
1387 expr_type.is_timestamp() != sink_type.is_timestamp()
1388 || expr_type.is_string() != sink_type.is_string()
1389 || expr_type.is_boolean() != sink_type.is_boolean()
1390 || expr_type.is_json() != sink_type.is_json()
1391 || expr_type.is_vector() != sink_type.is_vector()
1392}
1393
/// Returns every name that occurs more than once in `names`.
///
/// Each repeated name is reported once and the result is sorted
/// lexicographically; an input without repeats yields an empty vector.
fn duplicate_names(names: &[String]) -> Vec<String> {
    let mut first_occurrences = HashSet::with_capacity(names.len());

    // `insert` returns `false` for a value that is already present, i.e. for
    // the second and every later occurrence of a name.
    let repeated: BTreeSet<&str> = names
        .iter()
        .map(String::as_str)
        .filter(|name| !first_occurrences.insert(*name))
        .collect();

    repeated.into_iter().map(String::from).collect()
}
1404
1405fn format_flow_sink_schema_mismatch(
1406 query_exprs: &[Expr],
1407 sink_schema: &datatypes::schema::Schema,
1408) -> String {
1409 let flow_output_columns = query_exprs
1410 .iter()
1411 .map(|expr| expr.qualified_name().1)
1412 .collect::<Vec<_>>();
1413 let sink_table_columns = sink_schema
1414 .column_schemas()
1415 .iter()
1416 .map(|col| col.name.clone())
1417 .collect::<Vec<_>>();
1418
1419 let flow_output_set = flow_output_columns.iter().cloned().collect::<HashSet<_>>();
1420 let sink_table_set = sink_table_columns.iter().cloned().collect::<HashSet<_>>();
1421
1422 let mut extra_flow_columns = flow_output_columns
1423 .iter()
1424 .filter(|name| !sink_table_set.contains(*name))
1425 .cloned()
1426 .collect::<Vec<_>>();
1427 extra_flow_columns.sort();
1428 extra_flow_columns.dedup();
1429
1430 let mut missing_sink_columns = sink_table_columns
1431 .iter()
1432 .filter(|name| !flow_output_set.contains(*name))
1433 .cloned()
1434 .collect::<Vec<_>>();
1435 missing_sink_columns.sort();
1436 missing_sink_columns.dedup();
1437
1438 format!(
1439 "Flow output schema does not match sink table schema: found {} flow output columns and {} sink table columns. flow output columns: {:?}, sink table columns: {:?}, extra flow columns not in sink: {:?}, missing sink columns from flow output: {:?}",
1440 flow_output_columns.len(),
1441 sink_table_columns.len(),
1442 flow_output_columns,
1443 sink_table_columns,
1444 extra_flow_columns,
1445 missing_sink_columns
1446 )
1447}
1448
1449impl TreeNodeRewriter for ColumnMatcherRewriter {
1450 type Node = LogicalPlan;
1451 fn f_down(&mut self, mut node: Self::Node) -> DfResult<Transformed<Self::Node>> {
1452 if self.is_rewritten {
1453 return Ok(Transformed::no(node));
1454 }
1455
1456 if let LogicalPlan::Distinct(Distinct::All(_)) = &node {
1458 let mut exprs = vec![];
1459
1460 for field in node.schema().fields().iter() {
1461 exprs.push(Expr::Column(datafusion::common::Column::new_unqualified(
1462 field.name(),
1463 )));
1464 }
1465
1466 let projection =
1467 LogicalPlan::Projection(Projection::try_new(exprs, Arc::new(node.clone()))?);
1468
1469 node = projection;
1470 }
1471 else if let LogicalPlan::TableScan(table_scan) = node {
1473 let mut exprs = vec![];
1474
1475 for field in table_scan.projected_schema.fields().iter() {
1476 exprs.push(Expr::Column(datafusion::common::Column::new(
1477 Some(table_scan.table_name.clone()),
1478 field.name(),
1479 )));
1480 }
1481
1482 let projection = LogicalPlan::Projection(Projection::try_new(
1483 exprs,
1484 Arc::new(LogicalPlan::TableScan(table_scan)),
1485 )?);
1486
1487 node = projection;
1488 }
1489
1490 if let LogicalPlan::Projection(project) = &node {
1494 let exprs = project.expr.clone();
1495 let exprs = self.modify_project_exprs(exprs, project.input.schema())?;
1496
1497 self.is_rewritten = true;
1498 let new_plan =
1499 node.with_new_exprs(exprs, node.inputs().into_iter().cloned().collect())?;
1500 Ok(Transformed::yes(new_plan))
1501 } else {
1502 let mut exprs = vec![];
1504 for field in node.schema().fields().iter() {
1505 exprs.push(Expr::Column(datafusion::common::Column::new_unqualified(
1506 field.name(),
1507 )));
1508 }
1509 let exprs = self.modify_project_exprs(exprs, node.schema())?;
1510 self.is_rewritten = true;
1511 let new_plan =
1512 LogicalPlan::Projection(Projection::try_new(exprs, Arc::new(node.clone()))?);
1513 Ok(Transformed::yes(new_plan))
1514 }
1515 }
1516
1517 fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
1519 node.recompute_schema().map(Transformed::yes)
1520 }
1521}
1522
1523#[derive(Debug)]
1525pub struct AddFilterRewriter {
1526 extra_filter: Expr,
1527 is_rewritten: bool,
1528}
1529
1530impl AddFilterRewriter {
1531 pub fn new(filter: Expr) -> Self {
1532 Self {
1533 extra_filter: filter,
1534 is_rewritten: false,
1535 }
1536 }
1537}
1538
1539impl TreeNodeRewriter for AddFilterRewriter {
1540 type Node = LogicalPlan;
1541 fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
1542 if self.is_rewritten {
1543 return Ok(Transformed::no(node));
1544 }
1545 match node {
1546 LogicalPlan::Filter(mut filter) => {
1547 filter.predicate = filter.predicate.and(self.extra_filter.clone());
1548 self.is_rewritten = true;
1549 Ok(Transformed::yes(LogicalPlan::Filter(filter)))
1550 }
1551 LogicalPlan::TableScan(_) => {
1552 let filter =
1554 datafusion_expr::Filter::try_new(self.extra_filter.clone(), Arc::new(node))?;
1555 self.is_rewritten = true;
1556 Ok(Transformed::yes(LogicalPlan::Filter(filter)))
1557 }
1558 _ => Ok(Transformed::no(node)),
1559 }
1560 }
1561}