1use std::collections::{BTreeSet, HashMap, HashSet};
18use std::sync::Arc;
19
20use catalog::CatalogManagerRef;
21use common_error::ext::BoxedError;
22use common_telemetry::debug;
23use datafusion::error::Result as DfResult;
24use datafusion::logical_expr::Expr;
25use datafusion::sql::unparser::Unparser;
26use datafusion_common::tree_node::{
27 Transformed, TreeNode as _, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor,
28};
29use datafusion_common::{DFSchema, DataFusionError, ScalarValue};
30use datafusion_expr::{Distinct, LogicalPlan, Projection};
31use datatypes::schema::{ColumnSchema, SchemaRef};
32use query::QueryEngineRef;
33use query::parser::{DEFAULT_LOOKBACK_STRING, PromQuery, QueryLanguageParser, QueryStatement};
34use session::context::QueryContextRef;
35use snafu::{OptionExt, ResultExt, ensure};
36use sql::parser::{ParseOptions, ParserContext};
37use sql::statements::statement::Statement;
38use sql::statements::tql::Tql;
39use table::TableRef;
40
41use crate::adapter::{AUTO_CREATED_PLACEHOLDER_TS_COL, AUTO_CREATED_UPDATE_AT_TS_COL};
42use crate::df_optimizer::apply_df_optimizer;
43use crate::error::{DatafusionSnafu, ExternalSnafu, InvalidQuerySnafu, TableNotFoundSnafu};
44use crate::{Error, TableName};
45
46pub async fn get_table_info_df_schema(
47 catalog_mr: CatalogManagerRef,
48 table_name: TableName,
49) -> Result<(TableRef, Arc<DFSchema>), Error> {
50 let full_table_name = table_name.clone().join(".");
51 let table = catalog_mr
52 .table(&table_name[0], &table_name[1], &table_name[2], None)
53 .await
54 .map_err(BoxedError::new)
55 .context(ExternalSnafu)?
56 .context(TableNotFoundSnafu {
57 name: &full_table_name,
58 })?;
59 let table_info = table.table_info();
60
61 let schema = table_info.meta.schema.clone();
62
63 let df_schema: Arc<DFSchema> = Arc::new(
64 schema
65 .arrow_schema()
66 .clone()
67 .try_into()
68 .with_context(|_| DatafusionSnafu {
69 context: format!(
70 "Failed to convert arrow schema to datafusion schema, arrow_schema={:?}",
71 schema.arrow_schema()
72 ),
73 })?,
74 );
75 Ok((table, df_schema))
76}
77
78pub async fn sql_to_df_plan(
81 query_ctx: QueryContextRef,
82 engine: QueryEngineRef,
83 sql: &str,
84 optimize: bool,
85) -> Result<LogicalPlan, Error> {
86 let stmts =
87 ParserContext::create_with_dialect(sql, query_ctx.sql_dialect(), ParseOptions::default())
88 .map_err(BoxedError::new)
89 .context(ExternalSnafu)?;
90
91 ensure!(
92 stmts.len() == 1,
93 InvalidQuerySnafu {
94 reason: format!("Expect only one statement, found {}", stmts.len())
95 }
96 );
97 let stmt = &stmts[0];
98 let query_stmt = match stmt {
99 Statement::Tql(tql) => match tql {
100 Tql::Eval(eval) => {
101 let eval = eval.clone();
102 let promql = PromQuery {
103 start: eval.start,
104 end: eval.end,
105 step: eval.step,
106 query: eval.query,
107 lookback: eval
108 .lookback
109 .unwrap_or_else(|| DEFAULT_LOOKBACK_STRING.to_string()),
110 alias: eval.alias.clone(),
111 };
112
113 QueryLanguageParser::parse_promql(&promql, &query_ctx)
114 .map_err(BoxedError::new)
115 .context(ExternalSnafu)?
116 }
117 _ => InvalidQuerySnafu {
118 reason: format!("TQL statement {tql:?} is not supported, expect only TQL EVAL"),
119 }
120 .fail()?,
121 },
122 _ => QueryStatement::Sql(stmt.clone()),
123 };
124 let plan = engine
125 .planner()
126 .plan(&query_stmt, query_ctx.clone())
127 .await
128 .map_err(BoxedError::new)
129 .context(ExternalSnafu)?;
130
131 let plan = if optimize {
132 apply_df_optimizer(plan, &query_ctx).await?
133 } else {
134 plan
135 };
136 Ok(plan)
137}
138
139pub(crate) async fn gen_plan_with_matching_schema(
142 sql: &str,
143 query_ctx: QueryContextRef,
144 engine: QueryEngineRef,
145 sink_table_schema: SchemaRef,
146 primary_key_indices: &[usize],
147 allow_partial: bool,
148) -> Result<LogicalPlan, Error> {
149 let plan = sql_to_df_plan(query_ctx.clone(), engine.clone(), sql, false).await?;
150
151 let mut add_auto_column = ColumnMatcherRewriter::new(
152 sink_table_schema,
153 primary_key_indices.to_vec(),
154 allow_partial,
155 );
156 let plan = plan
157 .clone()
158 .rewrite(&mut add_auto_column)
159 .with_context(|_| DatafusionSnafu {
160 context: format!("Failed to rewrite plan:\n {}\n", plan),
161 })?
162 .data;
163 Ok(plan)
164}
165
166pub fn df_plan_to_sql(plan: &LogicalPlan) -> Result<String, Error> {
167 struct ForceQuoteIdentifiers;
169 impl datafusion::sql::unparser::dialect::Dialect for ForceQuoteIdentifiers {
170 fn identifier_quote_style(&self, identifier: &str) -> Option<char> {
171 if identifier.to_lowercase() != identifier {
172 Some('`')
173 } else {
174 None
175 }
176 }
177 }
178 let unparser = Unparser::new(&ForceQuoteIdentifiers);
179 let sql = unparser
181 .plan_to_sql(plan)
182 .with_context(|_e| DatafusionSnafu {
183 context: format!("Failed to unparse logical plan {plan:?}"),
184 })?;
185 Ok(sql.to_string())
186}
187
188#[derive(Debug, Clone, Default)]
190pub struct FindGroupByFinalName {
191 group_exprs: Option<HashSet<datafusion_expr::Expr>>,
192}
193
194impl FindGroupByFinalName {
195 pub fn get_group_expr_names(&self) -> Option<HashSet<String>> {
196 self.group_exprs
197 .as_ref()
198 .map(|exprs| exprs.iter().map(|expr| expr.qualified_name().1).collect())
199 }
200}
201
impl TreeNodeVisitor<'_> for FindGroupByFinalName {
    type Node = LogicalPlan;

    /// Top-down: record the group-by expression set from `Aggregate` or
    /// `Distinct` nodes (a later/lower node overwrites an earlier one).
    fn f_down(&mut self, node: &Self::Node) -> datafusion_common::Result<TreeNodeRecursion> {
        if let LogicalPlan::Aggregate(aggregate) = node {
            self.group_exprs = Some(aggregate.group_expr.iter().cloned().collect());
            debug!(
                "FindGroupByFinalName: Get Group by exprs from Aggregate: {:?}",
                self.group_exprs
            );
        } else if let LogicalPlan::Distinct(distinct) = node {
            debug!("FindGroupByFinalName: Distinct: {}", node);
            match distinct {
                Distinct::All(input) => {
                    // `DISTINCT` directly over a table scan: every projected
                    // column acts as a group-by key.
                    if let LogicalPlan::TableScan(table_scan) = &**input {
                        let len = table_scan.projected_schema.fields().len();
                        let columns = (0..len)
                            .map(|f| {
                                let (qualifier, field) =
                                    table_scan.projected_schema.qualified_field(f);
                                datafusion_common::Column::new(qualifier.cloned(), field.name())
                            })
                            .map(datafusion_expr::Expr::Column);
                        self.group_exprs = Some(columns.collect());
                    } else {
                        // Otherwise use the input node's own expressions.
                        self.group_exprs = Some(input.expressions().iter().cloned().collect())
                    }
                }
                Distinct::On(distinct_on) => {
                    self.group_exprs = Some(distinct_on.on_expr.iter().cloned().collect())
                }
            }
            debug!(
                "FindGroupByFinalName: Get Group by exprs from Distinct: {:?}",
                self.group_exprs
            );
        }

        Ok(TreeNodeRecursion::Continue)
    }

    /// Bottom-up: when an enclosing projection aliases a recorded group-by
    /// expression, substitute the aliased form so final output names are kept.
    fn f_up(&mut self, node: &Self::Node) -> datafusion_common::Result<TreeNodeRecursion> {
        if let LogicalPlan::Projection(projection) = node {
            for expr in &projection.expr {
                // Nothing recorded yet — nothing to rename.
                let Some(group_exprs) = &mut self.group_exprs else {
                    return Ok(TreeNodeRecursion::Continue);
                };
                if let datafusion_expr::Expr::Alias(alias) = expr {
                    // Rebuild the set with the aliased expression swapped in
                    // for the matching bare expression (at most one match).
                    let mut new_group_exprs = group_exprs.clone();
                    for group_expr in group_exprs.iter() {
                        if group_expr.name_for_alias()? == alias.expr.name_for_alias()? {
                            new_group_exprs.remove(group_expr);
                            new_group_exprs.insert(expr.clone());
                            break;
                        }
                    }
                    *group_exprs = new_group_exprs;
                }
            }
        }
        debug!("Aliased group by exprs: {:?}", self.group_exprs);
        Ok(TreeNodeRecursion::Continue)
    }
}
269
/// Plan rewriter that aligns a query's output columns with a sink table's
/// schema — appending auto columns (`now()` / placeholder timestamp) or, in
/// partial mode, filling unmatched columns. See the `TreeNodeRewriter` impl.
#[derive(Debug)]
pub struct ColumnMatcherRewriter {
    /// Schema of the sink table the plan output must match.
    pub schema: SchemaRef,
    /// Set after the first projection rewrite; prevents rewriting twice.
    pub is_rewritten: bool,
    /// Indices of the sink table's primary-key columns (used in partial mode).
    pub primary_key_indices: Vec<usize>,
    /// When true, the query may produce only a subset of the sink's columns.
    pub allow_partial: bool,
}
283
impl ColumnMatcherRewriter {
    /// Create a rewriter targeting `schema` (the sink table's schema).
    pub fn new(schema: SchemaRef, primary_key_indices: Vec<usize>, allow_partial: bool) -> Self {
        Self {
            schema,
            is_rewritten: false,
            primary_key_indices,
            allow_partial,
        }
    }

    /// Adjust projection expressions to match the sink schema.
    ///
    /// Non-partial mode: positionally aliases unmatched expressions to the
    /// sink's column names, then appends up to two trailing timestamp
    /// columns (`now()` update-at and/or the placeholder ts). Errors when the
    /// table has more than two extra columns or the trailing columns are not
    /// timestamps as expected.
    fn modify_project_exprs(&mut self, mut exprs: Vec<Expr>) -> DfResult<Vec<Expr>> {
        if self.allow_partial {
            return self.modify_project_exprs_with_partial(exprs);
        }

        let all_names = self
            .schema
            .column_schemas()
            .iter()
            .map(|c| c.name.clone())
            .collect::<BTreeSet<_>>();
        // Alias by position any expression whose name is not a sink column.
        for (idx, expr) in exprs.iter_mut().enumerate() {
            if !all_names.contains(&expr.qualified_name().1)
                && let Some(col_name) = self
                    .schema
                    .column_schemas()
                    .get(idx)
                    .map(|c| c.name.clone())
            {
                *expr = expr.clone().alias(col_name);
            }
        }

        let query_col_cnt = exprs.len();
        let table_col_cnt = self.schema.column_schemas().len();
        debug!("query_col_cnt={query_col_cnt}, table_col_cnt={table_col_cnt}");

        // Constant epoch-zero timestamp used for the placeholder time index.
        let placeholder_ts_expr =
            datafusion::logical_expr::lit(ScalarValue::TimestampMillisecond(Some(0), None))
                .alias(AUTO_CREATED_PLACEHOLDER_TS_COL);

        if query_col_cnt == table_col_cnt {
            // Counts already match — nothing to append.
        } else if query_col_cnt + 1 == table_col_cnt {
            // One missing column: either the placeholder ts or a now() column.
            let last_col_schema = self.schema.column_schemas().last().unwrap();

            if last_col_schema.name == AUTO_CREATED_PLACEHOLDER_TS_COL
                && self.schema.timestamp_index() == Some(table_col_cnt - 1)
            {
                exprs.push(placeholder_ts_expr);
            } else if last_col_schema.data_type.is_timestamp() {
                exprs.push(datafusion::prelude::now().alias(&last_col_schema.name));
            } else {
                return Err(DataFusionError::Plan(format!(
                    "Expect the last column in table to be timestamp column, found column {} with type {:?}",
                    last_col_schema.name, last_col_schema.data_type
                )));
            }
        } else if query_col_cnt + 2 == table_col_cnt {
            // Two missing columns: a now() column followed by the placeholder.
            let mut col_iter = self.schema.column_schemas().iter().rev();
            let last_col_schema = col_iter.next().unwrap();
            let second_last_col_schema = col_iter.next().unwrap();
            if second_last_col_schema.data_type.is_timestamp() {
                exprs.push(datafusion::prelude::now().alias(&second_last_col_schema.name));
            } else {
                return Err(DataFusionError::Plan(format!(
                    "Expect the second last column in the table to be timestamp column, found column {} with type {:?}",
                    second_last_col_schema.name, second_last_col_schema.data_type
                )));
            }

            if last_col_schema.name == AUTO_CREATED_PLACEHOLDER_TS_COL
                && self.schema.timestamp_index() == Some(table_col_cnt - 1)
            {
                exprs.push(placeholder_ts_expr);
            } else {
                return Err(DataFusionError::Plan(format!(
                    "Expect timestamp column {}, found {:?}",
                    AUTO_CREATED_PLACEHOLDER_TS_COL, last_col_schema
                )));
            }
        } else {
            return Err(DataFusionError::Plan(format!(
                "Expect table have 0,1 or 2 columns more than query columns, found {} query columns {:?}, {} table columns {:?}",
                query_col_cnt,
                exprs,
                table_col_cnt,
                self.schema.column_schemas()
            )));
        }
        Ok(exprs)
    }

    /// Partial-mode variant: match query expressions to sink columns by name,
    /// fill the placeholder/update-at columns automatically, and NULL-fill
    /// everything else. Primary-key and time-index columns must be present in
    /// the query output (see `required_columns_for_partial`).
    fn modify_project_exprs_with_partial(&mut self, exprs: Vec<Expr>) -> DfResult<Vec<Expr>> {
        let table_col_cnt = self.schema.column_schemas().len();
        let query_col_cnt = exprs.len();

        if query_col_cnt > table_col_cnt {
            return Err(DataFusionError::Plan(format!(
                "Expect query column count <= table column count, found {} query columns {:?}, {} table columns {:?}",
                query_col_cnt,
                exprs,
                table_col_cnt,
                self.schema.column_schemas()
            )));
        }

        // Index the query's output expressions by their unqualified names.
        let name_to_expr: HashMap<String, Expr> = exprs
            .clone()
            .into_iter()
            .map(|e| (e.qualified_name().1, e))
            .collect();

        let required_columns = self.required_columns_for_partial();
        let missing: Vec<_> = required_columns
            .iter()
            .filter(|name| !name_to_expr.contains_key(*name))
            .cloned()
            .collect();
        if !missing.is_empty() {
            return Err(DataFusionError::Plan(format!(
                "Column(s) {:?} required by sink table are missing from flow output when merge_mode=last_non_null",
                missing
            )));
        }

        let placeholder_ts_expr =
            datafusion::logical_expr::lit(ScalarValue::TimestampMillisecond(Some(0), None))
                .alias(AUTO_CREATED_PLACEHOLDER_TS_COL);

        let timestamp_index = self.schema.timestamp_index();
        let mut remap = name_to_expr;
        let mut new_exprs = Vec::with_capacity(table_col_cnt);

        // Walk the sink schema in order and pick a source for each column.
        for (idx, col_schema) in self.schema.column_schemas().iter().enumerate() {
            let col_name = col_schema.name.clone();
            if let Some(expr) = remap.remove(&col_name) {
                // Matched by name; alias only if the qualified name differs.
                let expr = if expr.qualified_name().1 == col_name {
                    expr
                } else {
                    expr.alias(col_name.clone())
                };
                new_exprs.push(expr);
                continue;
            }

            if col_name == AUTO_CREATED_PLACEHOLDER_TS_COL && timestamp_index == Some(idx) {
                new_exprs.push(placeholder_ts_expr.clone());
                continue;
            }

            if col_name == AUTO_CREATED_UPDATE_AT_TS_COL && col_schema.data_type.is_timestamp() {
                new_exprs.push(datafusion::prelude::now().alias(&col_name));
                continue;
            }

            // No match and not an auto column: fill with NULL.
            new_exprs.push(Self::null_expr(col_schema));
        }

        // Anything left over never matched a sink column — reject it.
        if !remap.is_empty() {
            let extra: Vec<_> = remap.keys().cloned().collect();
            return Err(DataFusionError::Plan(format!(
                "Flow output has extra column(s) {:?} not found in sink schema when merge_mode=last_non_null",
                extra
            )));
        }

        Ok(new_exprs)
    }

    /// A NULL literal aliased to the given column's name.
    fn null_expr(col_schema: &ColumnSchema) -> Expr {
        Expr::Literal(ScalarValue::Null, None).alias(col_schema.name.clone())
    }

    /// Columns the query must output in partial mode: all primary-key columns
    /// plus the time-index column (unless it is the auto placeholder).
    fn required_columns_for_partial(&self) -> HashSet<String> {
        let mut required = HashSet::new();
        for idx in &self.primary_key_indices {
            if let Some(col) = self.schema.column_schemas().get(*idx) {
                required.insert(col.name.clone());
            }
        }

        if let Some(ts_idx) = self.schema.timestamp_index()
            && let Some(col) = self.schema.column_schemas().get(ts_idx)
            && col.name != AUTO_CREATED_PLACEHOLDER_TS_COL
        {
            required.insert(col.name.clone());
        }

        required
    }
}
485
impl TreeNodeRewriter for ColumnMatcherRewriter {
    type Node = LogicalPlan;
    /// Top-down: normalize the root into a projection (wrapping `DISTINCT` or
    /// a bare table scan if needed), then rewrite that projection's
    /// expressions via `modify_project_exprs`. Runs at most once.
    fn f_down(&mut self, mut node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        if self.is_rewritten {
            return Ok(Transformed::no(node));
        }

        if let LogicalPlan::Distinct(Distinct::All(_)) = &node {
            // Wrap DISTINCT in a projection selecting each output field.
            let mut exprs = vec![];

            for field in node.schema().fields().iter() {
                exprs.push(Expr::Column(datafusion::common::Column::new_unqualified(
                    field.name(),
                )));
            }

            let projection =
                LogicalPlan::Projection(Projection::try_new(exprs, Arc::new(node.clone()))?);

            node = projection;
        } else if let LogicalPlan::TableScan(table_scan) = node {
            // Wrap a bare table scan in a projection of its projected fields,
            // qualified by the table name.
            let mut exprs = vec![];

            for field in table_scan.projected_schema.fields().iter() {
                exprs.push(Expr::Column(datafusion::common::Column::new(
                    Some(table_scan.table_name.clone()),
                    field.name(),
                )));
            }

            let projection = LogicalPlan::Projection(Projection::try_new(
                exprs,
                Arc::new(LogicalPlan::TableScan(table_scan)),
            )?);

            node = projection;
        }

        if let LogicalPlan::Projection(project) = &node {
            // Rewrite the projection's expressions in place.
            let exprs = project.expr.clone();
            let exprs = self.modify_project_exprs(exprs)?;

            self.is_rewritten = true;
            let new_plan =
                node.with_new_exprs(exprs, node.inputs().into_iter().cloned().collect())?;
            Ok(Transformed::yes(new_plan))
        } else {
            // Any other node: project its output fields, then rewrite those.
            let mut exprs = vec![];
            for field in node.schema().fields().iter() {
                exprs.push(Expr::Column(datafusion::common::Column::new_unqualified(
                    field.name(),
                )));
            }
            let exprs = self.modify_project_exprs(exprs)?;
            self.is_rewritten = true;
            let new_plan =
                LogicalPlan::Projection(Projection::try_new(exprs, Arc::new(node.clone()))?);
            Ok(Transformed::yes(new_plan))
        }
    }

    /// Bottom-up: recompute each node's schema after the rewrite above.
    fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        node.recompute_schema().map(Transformed::yes)
    }
}
559
/// Rewriter that AND-combines an extra predicate into the first `Filter`
/// node found bottom-up, or inserts a new `Filter` above the `TableScan`
/// when the plan has no filter. See the `TreeNodeRewriter` impl below.
#[derive(Debug)]
pub struct AddFilterRewriter {
    // Predicate to inject into the plan.
    extra_filter: Expr,
    // Ensures the predicate is only applied once.
    is_rewritten: bool,
}

impl AddFilterRewriter {
    /// Create a rewriter that will inject `filter` into a plan.
    pub fn new(filter: Expr) -> Self {
        Self {
            extra_filter: filter,
            is_rewritten: false,
        }
    }
}
575
576impl TreeNodeRewriter for AddFilterRewriter {
577 type Node = LogicalPlan;
578 fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
579 if self.is_rewritten {
580 return Ok(Transformed::no(node));
581 }
582 match node {
583 LogicalPlan::Filter(mut filter) => {
584 filter.predicate = filter.predicate.and(self.extra_filter.clone());
585 self.is_rewritten = true;
586 Ok(Transformed::yes(LogicalPlan::Filter(filter)))
587 }
588 LogicalPlan::TableScan(_) => {
589 let filter =
591 datafusion_expr::Filter::try_new(self.extra_filter.clone(), Arc::new(node))?;
592 self.is_rewritten = true;
593 Ok(Transformed::yes(LogicalPlan::Filter(filter)))
594 }
595 _ => Ok(Transformed::no(node)),
596 }
597 }
598}
599
600#[cfg(test)]
601mod test {
602 use std::sync::Arc;
603
604 use datafusion_common::tree_node::TreeNode as _;
605 use datatypes::prelude::ConcreteDataType;
606 use datatypes::schema::{ColumnSchema, Schema};
607 use pretty_assertions::assert_eq;
608 use query::query_engine::DefaultSerializer;
609 use session::context::QueryContext;
610 use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
611
612 use super::*;
613 use crate::test_utils::create_test_query_engine;
614
    #[tokio::test]
    async fn test_sql_plan_convert() {
        // Round trip SQL -> LogicalPlan -> SQL: upper-case identifiers must
        // come back backtick-quoted (see `df_plan_to_sql`).
        let query_engine = create_test_query_engine();
        let ctx = QueryContext::arc();
        let old = r#"SELECT "NUMBER" FROM "UPPERCASE_NUMBERS_WITH_TS""#;
        let new = sql_to_df_plan(ctx.clone(), query_engine.clone(), old, false)
            .await
            .unwrap();
        let new_sql = df_plan_to_sql(&new).unwrap();

        assert_eq!(
            r#"SELECT `UPPERCASE_NUMBERS_WITH_TS`.`NUMBER` FROM `UPPERCASE_NUMBERS_WITH_TS`"#,
            new_sql
        );
    }
631
    #[tokio::test]
    async fn test_add_filter() {
        // Each case: (input SQL, expected SQL after AddFilterRewriter injects
        // `number > 4`).
        let testcases = vec![
            (
                "SELECT number FROM numbers_with_ts GROUP BY number",
                "SELECT numbers_with_ts.number FROM numbers_with_ts WHERE (number > 4) GROUP BY numbers_with_ts.number",
            ),
            (
                "SELECT number FROM numbers_with_ts WHERE number < 2 OR number >10",
                "SELECT numbers_with_ts.number FROM numbers_with_ts WHERE ((numbers_with_ts.number < 2) OR (numbers_with_ts.number > 10)) AND (number > 4)",
            ),
            (
                "SELECT date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window",
                "SELECT date_bin('5 minutes', numbers_with_ts.ts) AS time_window FROM numbers_with_ts WHERE (number > 4) GROUP BY date_bin('5 minutes', numbers_with_ts.ts)",
            ),
            (
                "SELECT number, time_window FROM (SELECT number, date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window, number);",
                "SELECT numbers_with_ts.number, time_window FROM (SELECT numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts) AS time_window FROM numbers_with_ts WHERE (number > 4) GROUP BY date_bin('5 minutes', numbers_with_ts.ts), numbers_with_ts.number)",
            ),
            (
                "SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) GROUP BY number, time_window, bucket_name;",
                "SELECT sum(numbers_with_ts.number), numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts) AS time_window, bucket_name FROM (SELECT numbers_with_ts.number, numbers_with_ts.ts, CASE WHEN (numbers_with_ts.number < 5) THEN 'bucket_0_5' WHEN (numbers_with_ts.number >= 5) THEN 'bucket_5_inf' END AS bucket_name FROM numbers_with_ts WHERE (number > 4)) GROUP BY numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts), bucket_name",
            ),
            (
                "SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) as cte WHERE number > 1 GROUP BY number, time_window, bucket_name;",
                "SELECT sum(cte.number), cte.number, date_bin('5 minutes', cte.ts) AS time_window, cte.bucket_name FROM (SELECT numbers_with_ts.number, numbers_with_ts.ts, CASE WHEN (numbers_with_ts.number < 5) THEN 'bucket_0_5' WHEN (numbers_with_ts.number >= 5) THEN 'bucket_5_inf' END AS bucket_name FROM numbers_with_ts WHERE (number > 4)) AS cte WHERE (cte.number > 1) GROUP BY cte.number, date_bin('5 minutes', cte.ts), cte.bucket_name",
            ),
        ];
        use datafusion_expr::{col, lit};
        let query_engine = create_test_query_engine();
        let ctx = QueryContext::arc();

        for (before, after) in testcases {
            let sql = before;
            let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, false)
                .await
                .unwrap();

            let mut add_filter = AddFilterRewriter::new(col("number").gt(lit(4u32)));
            let plan = plan.rewrite(&mut add_filter).unwrap().data;
            let new_sql = df_plan_to_sql(&plan).unwrap();
            assert_eq!(after, new_sql);
        }
    }
679
    #[tokio::test]
    async fn test_add_auto_column_rewriter() {
        // Each case: (input SQL, expected rewritten SQL or expected error
        // substring, sink table column schemas fed to ColumnMatcherRewriter).
        let testcases = vec![
            (
                "SELECT number FROM numbers_with_ts",
                Ok("SELECT numbers_with_ts.number, now() AS ts FROM numbers_with_ts"),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        "ts",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                ],
            ),
            (
                "SELECT number FROM numbers_with_ts",
                Ok(
                    "SELECT numbers_with_ts.number, CAST('1970-01-01 00:00:00' AS TIMESTAMP) AS __ts_placeholder FROM numbers_with_ts",
                ),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        AUTO_CREATED_PLACEHOLDER_TS_COL,
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                ],
            ),
            (
                "SELECT number, ts FROM numbers_with_ts",
                Ok("SELECT numbers_with_ts.number, numbers_with_ts.ts FROM numbers_with_ts"),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        "ts",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                ],
            ),
            (
                "SELECT number FROM numbers_with_ts",
                Ok(
                    "SELECT numbers_with_ts.number, now() AS update_at, CAST('1970-01-01 00:00:00' AS TIMESTAMP) AS __ts_placeholder FROM numbers_with_ts",
                ),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        "update_at",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    ),
                    ColumnSchema::new(
                        AUTO_CREATED_PLACEHOLDER_TS_COL,
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                ],
            ),
            (
                "SELECT number, ts FROM numbers_with_ts",
                Ok(
                    "SELECT numbers_with_ts.number, numbers_with_ts.ts AS update_at, CAST('1970-01-01 00:00:00' AS TIMESTAMP) AS __ts_placeholder FROM numbers_with_ts",
                ),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        "update_at",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    ),
                    ColumnSchema::new(
                        AUTO_CREATED_PLACEHOLDER_TS_COL,
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                ],
            ),
            (
                "SELECT number, ts FROM numbers_with_ts",
                Ok(
                    "SELECT numbers_with_ts.number, numbers_with_ts.ts, now() AS update_atat FROM numbers_with_ts",
                ),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        "ts",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                    ColumnSchema::new(
                        "update_atat",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    ),
                ],
            ),
            (
                "SELECT number, ts FROM numbers_with_ts",
                Err(
                    "Expect the last column in table to be timestamp column, found column atat with type Int8",
                ),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        "ts",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                    ColumnSchema::new(
                        "atat",
                        ConcreteDataType::int8_datatype(),
                        false,
                    ),
                ],
            ),
            (
                "SELECT number FROM numbers_with_ts",
                Err(
                    "Expect the second last column in the table to be timestamp column, found column ts with type Int8",
                ),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new("ts", ConcreteDataType::int8_datatype(), false),
                    ColumnSchema::new(
                        "atat",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                ],
            ),
        ];

        let query_engine = create_test_query_engine();
        let ctx = QueryContext::arc();
        for (before, after, column_schemas) in testcases {
            let schema = Arc::new(Schema::new(column_schemas));
            let mut add_auto_column_rewriter =
                ColumnMatcherRewriter::new(schema, Vec::new(), false);

            let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), before, false)
                .await
                .unwrap();
            // Closure so rewrite and unparse errors funnel into one Result.
            let new_sql = (|| {
                let plan = plan
                    .rewrite(&mut add_auto_column_rewriter)
                    .map_err(|e| e.to_string())?
                    .data;
                df_plan_to_sql(&plan).map_err(|e| e.to_string())
            })();
            match (after, new_sql.clone()) {
                (Ok(after), Ok(new_sql)) => assert_eq!(after, new_sql),
                (Err(expected), Err(real_err_msg)) => assert!(
                    real_err_msg.contains(expected),
                    "expected: {expected}, real: {real_err_msg}"
                ),
                _ => panic!("expected: {:?}, real: {:?}", after, new_sql),
            }
        }
    }
860
    #[tokio::test]
    async fn test_find_group_by_exprs() {
        // Each case: (SQL, expected final group-by expression names as
        // reported by FindGroupByFinalName).
        let testcases = vec![
            (
                "SELECT arrow_cast(date_bin(INTERVAL '1 MINS', numbers_with_ts.ts), 'Timestamp(Second, None)') AS ts FROM numbers_with_ts GROUP BY ts;",
                vec!["ts"],
            ),
            (
                "SELECT number FROM numbers_with_ts GROUP BY number",
                vec!["number"],
            ),
            (
                "SELECT date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window",
                vec!["time_window"],
            ),
            (
                "SELECT number, time_window FROM (SELECT number, date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window, number);",
                vec!["time_window", "number"],
            ),
            (
                "SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) GROUP BY number, time_window, bucket_name;",
                vec!["number", "time_window", "bucket_name"],
            ),
            (
                "SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) as cte GROUP BY number, time_window, bucket_name;",
                vec!["number", "time_window", "bucket_name"],
            ),
        ];

        let query_engine = create_test_query_engine();
        let ctx = QueryContext::arc();
        for (sql, expected) in testcases {
            let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, false)
                .await
                .unwrap();
            let mut group_by_exprs = FindGroupByFinalName::default();
            plan.visit(&mut group_by_exprs).unwrap();
            let expected: HashSet<String> = expected.into_iter().map(|s| s.to_string()).collect();
            assert_eq!(
                expected,
                group_by_exprs.get_group_expr_names().unwrap_or_default()
            );
        }
    }
909
    #[tokio::test]
    async fn test_null_cast() {
        // A typed NULL (`NULL::DOUBLE`) must plan and substrait-encode
        // without error.
        let query_engine = create_test_query_engine();
        let ctx = QueryContext::arc();
        let sql = "SELECT NULL::DOUBLE FROM numbers_with_ts";
        let plan = sql_to_df_plan(ctx, query_engine.clone(), sql, false)
            .await
            .unwrap();

        let _sub_plan = DFLogicalSubstraitConvertor {}
            .encode(&plan, DefaultSerializer)
            .unwrap();
    }
923}