use std::collections::{BTreeSet, HashSet};
use std::sync::Arc;

use catalog::CatalogManagerRef;
use common_error::ext::BoxedError;
use common_telemetry::debug;
use datafusion::error::Result as DfResult;
use datafusion::logical_expr::Expr;
use datafusion::sql::unparser::Unparser;
use datafusion_common::tree_node::{
    Transformed, TreeNode as _, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor,
};
use datafusion_common::{DFSchema, DataFusionError, ScalarValue};
use datafusion_expr::{Distinct, LogicalPlan, Projection};
use datatypes::schema::SchemaRef;
use query::QueryEngineRef;
use query::parser::{DEFAULT_LOOKBACK_STRING, PromQuery, QueryLanguageParser, QueryStatement};
use session::context::QueryContextRef;
use snafu::{OptionExt, ResultExt, ensure};
use sql::parser::{ParseOptions, ParserContext};
use sql::statements::statement::Statement;
use sql::statements::tql::Tql;
use table::TableRef;

use crate::adapter::AUTO_CREATED_PLACEHOLDER_TS_COL;
use crate::df_optimizer::apply_df_optimizer;
use crate::error::{DatafusionSnafu, ExternalSnafu, InvalidQuerySnafu, TableNotFoundSnafu};
use crate::{Error, TableName};

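/// Look up `table_name` in the catalog and return the [`TableRef`] together
/// with its schema converted into a DataFusion [`DFSchema`].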
pub async fn get_table_info_df_schema(
    catalog_mr: CatalogManagerRef,
    table_name: TableName,
) -> Result<(TableRef, Arc<DFSchema>), Error> {
    let full_table_name = table_name.clone().join(".");
    let table = catalog_mr
        .table(&table_name[0], &table_name[1], &table_name[2], None)
        .await
        .map_err(BoxedError::new)
        .context(ExternalSnafu)?
        .context(TableNotFoundSnafu {
            name: &full_table_name,
        })?;
    let table_info = table.table_info();

    let schema = table_info.meta.schema.clone();

    let df_schema: Arc<DFSchema> = Arc::new(
        schema
            .arrow_schema()
            .clone()
            .try_into()
            .with_context(|_| DatafusionSnafu {
                context: format!(
                    "Failed to convert arrow schema to datafusion schema, arrow_schema={:?}",
                    schema.arrow_schema()
                ),
            })?,
    );
    Ok((table, df_schema))
}

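/// Parse `sql`, which must be a single SQL statement or a `TQL EVAL`
/// statement, and plan it into a DataFusion [`LogicalPlan`]. When `optimize`
/// is true, the flow optimizer is applied to the resulting plan.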
pub async fn sql_to_df_plan(
    query_ctx: QueryContextRef,
    engine: QueryEngineRef,
    sql: &str,
    optimize: bool,
) -> Result<LogicalPlan, Error> {
    let stmts =
        ParserContext::create_with_dialect(sql, query_ctx.sql_dialect(), ParseOptions::default())
            .map_err(BoxedError::new)
            .context(ExternalSnafu)?;

    ensure!(
        stmts.len() == 1,
        InvalidQuerySnafu {
            reason: format!("Expect only one statement, found {}", stmts.len())
        }
    );
    let stmt = &stmts[0];
    let query_stmt = match stmt {
        Statement::Tql(tql) => match tql {
            Tql::Eval(eval) => {
                let eval = eval.clone();
                let promql = PromQuery {
                    start: eval.start,
                    end: eval.end,
                    step: eval.step,
                    query: eval.query,
                    lookback: eval
                        .lookback
                        .unwrap_or_else(|| DEFAULT_LOOKBACK_STRING.to_string()),
                };

                QueryLanguageParser::parse_promql(&promql, &query_ctx)
                    .map_err(BoxedError::new)
                    .context(ExternalSnafu)?
            }
            _ => InvalidQuerySnafu {
                reason: format!("TQL statement {tql:?} is not supported, expect only TQL EVAL"),
            }
            .fail()?,
        },
        _ => QueryStatement::Sql(stmt.clone()),
    };
    let plan = engine
        .planner()
        .plan(&query_stmt, query_ctx.clone())
        .await
        .map_err(BoxedError::new)
        .context(ExternalSnafu)?;

    let plan = if optimize {
        apply_df_optimizer(plan, &query_ctx).await?
    } else {
        plan
    };
    Ok(plan)
}

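/// Plan `sql` without optimization, then rewrite the plan's output columns so
/// they match `sink_table_schema` (see [`ColumnMatcherRewriter`]).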
pub(crate) async fn gen_plan_with_matching_schema(
    sql: &str,
    query_ctx: QueryContextRef,
    engine: QueryEngineRef,
    sink_table_schema: SchemaRef,
) -> Result<LogicalPlan, Error> {
    let plan = sql_to_df_plan(query_ctx.clone(), engine.clone(), sql, false).await?;

    let mut add_auto_column = ColumnMatcherRewriter::new(sink_table_schema);
    let plan = plan
        .clone()
        .rewrite(&mut add_auto_column)
        .with_context(|_| DatafusionSnafu {
            context: format!("Failed to rewrite plan:\n {}\n", plan),
        })?
        .data;
    Ok(plan)
}

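/// Unparse a [`LogicalPlan`] back into a SQL string, back-quoting identifiers
/// that contain uppercase characters so they survive a round trip through the
/// parser. A minimal sketch of the round trip (`ctx`, `engine`, and table `T`
/// are assumed to exist; this mirrors `test_sql_plan_convert` below):
///
/// ```ignore
/// let plan = sql_to_df_plan(ctx, engine, r#"SELECT "NUMBER" FROM "T""#, false).await?;
/// let sql = df_plan_to_sql(&plan)?;
/// ```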
pub fn df_plan_to_sql(plan: &LogicalPlan) -> Result<String, Error> {
    struct ForceQuoteIdentifiers;
    impl datafusion::sql::unparser::dialect::Dialect for ForceQuoteIdentifiers {
        fn identifier_quote_style(&self, identifier: &str) -> Option<char> {
            if identifier.to_lowercase() != identifier {
                Some('`')
            } else {
                None
            }
        }
    }
    let unparser = Unparser::new(&ForceQuoteIdentifiers);
    let sql = unparser
        .plan_to_sql(plan)
        .with_context(|_e| DatafusionSnafu {
            context: format!("Failed to unparse logical plan {plan:?}"),
        })?;
    Ok(sql.to_string())
}

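/// Visitor that collects the GROUP BY expressions of a plan. `DISTINCT` is
/// treated as a group-by over all of its input columns, and aliases in
/// enclosing projections are followed so the reported names match the plan's
/// output columns.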
#[derive(Debug, Clone, Default)]
pub struct FindGroupByFinalName {
    group_exprs: Option<HashSet<datafusion_expr::Expr>>,
}

impl FindGroupByFinalName {
    pub fn get_group_expr_names(&self) -> Option<HashSet<String>> {
        self.group_exprs
            .as_ref()
            .map(|exprs| exprs.iter().map(|expr| expr.qualified_name().1).collect())
    }
}

impl TreeNodeVisitor<'_> for FindGroupByFinalName {
    type Node = LogicalPlan;

    fn f_down(&mut self, node: &Self::Node) -> datafusion_common::Result<TreeNodeRecursion> {
        if let LogicalPlan::Aggregate(aggregate) = node {
            self.group_exprs = Some(aggregate.group_expr.iter().cloned().collect());
            debug!(
                "FindGroupByFinalName: Get Group by exprs from Aggregate: {:?}",
                self.group_exprs
            );
        } else if let LogicalPlan::Distinct(distinct) = node {
            debug!("FindGroupByFinalName: Distinct: {}", node);
            match distinct {
                Distinct::All(input) => {
                    // `DISTINCT *` groups by every column of its input.
                    if let LogicalPlan::TableScan(table_scan) = &**input {
                        let len = table_scan.projected_schema.fields().len();
                        let columns = (0..len)
                            .map(|f| {
                                let (qualifier, field) =
                                    table_scan.projected_schema.qualified_field(f);
                                datafusion_common::Column::new(qualifier.cloned(), field.name())
                            })
                            .map(datafusion_expr::Expr::Column);
                        self.group_exprs = Some(columns.collect());
                    } else {
                        self.group_exprs = Some(input.expressions().iter().cloned().collect())
                    }
                }
                Distinct::On(distinct_on) => {
                    self.group_exprs = Some(distinct_on.on_expr.iter().cloned().collect())
                }
            }
            debug!(
                "FindGroupByFinalName: Get Group by exprs from Distinct: {:?}",
                self.group_exprs
            );
        }

        Ok(TreeNodeRecursion::Continue)
    }

    fn f_up(&mut self, node: &Self::Node) -> datafusion_common::Result<TreeNodeRecursion> {
        // On the way back up, replace collected group-by expressions with
        // their aliased forms from enclosing projections, so the final names
        // match the plan's output.
        if let LogicalPlan::Projection(projection) = node {
            for expr in &projection.expr {
                let Some(group_exprs) = &mut self.group_exprs else {
                    return Ok(TreeNodeRecursion::Continue);
                };
                if let datafusion_expr::Expr::Alias(alias) = expr {
                    let mut new_group_exprs = group_exprs.clone();
                    for group_expr in group_exprs.iter() {
                        if group_expr.name_for_alias()? == alias.expr.name_for_alias()? {
                            new_group_exprs.remove(group_expr);
                            new_group_exprs.insert(expr.clone());
                            break;
                        }
                    }
                    *group_exprs = new_group_exprs;
                }
            }
        }
        debug!("Aliased group by exprs: {:?}", self.group_exprs);
        Ok(TreeNodeRecursion::Continue)
    }
}

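/// Rewrites a plan's output projection so that its columns match the sink
/// table's schema:
/// - query columns whose names are not table columns are aliased positionally
///   to the table's column names;
/// - a trailing timestamp column the query does not produce is filled with
///   `now()`;
/// - the auto-created placeholder time index column
///   ([`AUTO_CREATED_PLACEHOLDER_TS_COL`]) is filled with a zero-timestamp
///   literal.
///
/// A minimal sketch of the intended use, assuming `plan` writes into a sink
/// table described by `sink_table_schema`:
///
/// ```ignore
/// let mut rewriter = ColumnMatcherRewriter::new(sink_table_schema);
/// let matched_plan = plan.rewrite(&mut rewriter)?.data;
/// ```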
#[derive(Debug)]
pub struct ColumnMatcherRewriter {
    pub schema: SchemaRef,
    pub is_rewritten: bool,
}

impl ColumnMatcherRewriter {
    pub fn new(schema: SchemaRef) -> Self {
        Self {
            schema,
            is_rewritten: false,
        }
    }

    /// Align `exprs` (the query's output projection) with the sink table
    /// schema, aliasing mismatched names positionally and appending the
    /// automatic timestamp columns the table expects.
    fn modify_project_exprs(&mut self, mut exprs: Vec<Expr>) -> DfResult<Vec<Expr>> {
        let all_names = self
            .schema
            .column_schemas()
            .iter()
            .map(|c| c.name.clone())
            .collect::<BTreeSet<_>>();
        // Alias any expression whose name is not a table column to the table
        // column at the same position.
        for (idx, expr) in exprs.iter_mut().enumerate() {
            if !all_names.contains(&expr.qualified_name().1)
                && let Some(col_name) = self
                    .schema
                    .column_schemas()
                    .get(idx)
                    .map(|c| c.name.clone())
            {
                *expr = expr.clone().alias(col_name);
            }
        }

        let query_col_cnt = exprs.len();
        let table_col_cnt = self.schema.column_schemas().len();
        debug!("query_col_cnt={query_col_cnt}, table_col_cnt={table_col_cnt}");

        let placeholder_ts_expr =
            datafusion::logical_expr::lit(ScalarValue::TimestampMillisecond(Some(0), None))
                .alias(AUTO_CREATED_PLACEHOLDER_TS_COL);

        if query_col_cnt == table_col_cnt {
            // Column counts already match; nothing to append.
        } else if query_col_cnt + 1 == table_col_cnt {
            // One missing column: it must be the trailing timestamp, either
            // the auto-created placeholder time index or a `now()` column.
            let last_col_schema = self.schema.column_schemas().last().unwrap();

            if last_col_schema.name == AUTO_CREATED_PLACEHOLDER_TS_COL
                && self.schema.timestamp_index() == Some(table_col_cnt - 1)
            {
                exprs.push(placeholder_ts_expr);
            } else if last_col_schema.data_type.is_timestamp() {
                exprs.push(datafusion::prelude::now().alias(&last_col_schema.name));
            } else {
                return Err(DataFusionError::Plan(format!(
                    "Expect the last column in table to be timestamp column, found column {} with type {:?}",
                    last_col_schema.name, last_col_schema.data_type
                )));
            }
        } else if query_col_cnt + 2 == table_col_cnt {
            // Two missing columns: a `now()` update-at column followed by the
            // auto-created placeholder time index.
            let mut col_iter = self.schema.column_schemas().iter().rev();
            let last_col_schema = col_iter.next().unwrap();
            let second_last_col_schema = col_iter.next().unwrap();
            if second_last_col_schema.data_type.is_timestamp() {
                exprs.push(datafusion::prelude::now().alias(&second_last_col_schema.name));
            } else {
                return Err(DataFusionError::Plan(format!(
                    "Expect the second last column in the table to be timestamp column, found column {} with type {:?}",
                    second_last_col_schema.name, second_last_col_schema.data_type
                )));
            }

            if last_col_schema.name == AUTO_CREATED_PLACEHOLDER_TS_COL
                && self.schema.timestamp_index() == Some(table_col_cnt - 1)
            {
                exprs.push(placeholder_ts_expr);
            } else {
                return Err(DataFusionError::Plan(format!(
                    "Expect timestamp column {}, found {:?}",
                    AUTO_CREATED_PLACEHOLDER_TS_COL, last_col_schema
                )));
            }
        } else {
            return Err(DataFusionError::Plan(format!(
                "Expect the table to have 0, 1 or 2 more columns than the query, found {} query columns {:?}, {} table columns {:?}",
                query_col_cnt,
                exprs,
                table_col_cnt,
                self.schema.column_schemas()
            )));
        }
        Ok(exprs)
    }
}

impl TreeNodeRewriter for ColumnMatcherRewriter {
    type Node = LogicalPlan;

    fn f_down(&mut self, mut node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        if self.is_rewritten {
            return Ok(Transformed::no(node));
        }

        // `DISTINCT` and bare table scans carry no projection to adjust, so
        // wrap them in an explicit projection of all their columns first.
        if let LogicalPlan::Distinct(Distinct::All(_)) = &node {
            let mut exprs = vec![];

            for field in node.schema().fields().iter() {
                exprs.push(Expr::Column(datafusion::common::Column::new_unqualified(
                    field.name(),
                )));
            }

            let projection =
                LogicalPlan::Projection(Projection::try_new(exprs, Arc::new(node.clone()))?);

            node = projection;
        } else if let LogicalPlan::TableScan(table_scan) = node {
            let mut exprs = vec![];

            for field in table_scan.projected_schema.fields().iter() {
                exprs.push(Expr::Column(datafusion::common::Column::new(
                    Some(table_scan.table_name.clone()),
                    field.name(),
                )));
            }

            let projection = LogicalPlan::Projection(Projection::try_new(
                exprs,
                Arc::new(LogicalPlan::TableScan(table_scan)),
            )?);

            node = projection;
        }

        if let LogicalPlan::Projection(project) = &node {
            let exprs = project.expr.clone();
            let exprs = self.modify_project_exprs(exprs)?;

            self.is_rewritten = true;
            let new_plan =
                node.with_new_exprs(exprs, node.inputs().into_iter().cloned().collect())?;
            Ok(Transformed::yes(new_plan))
        } else {
            // Not a projection: project all of the node's output columns,
            // then align them with the sink table schema.
            let mut exprs = vec![];
            for field in node.schema().fields().iter() {
                exprs.push(Expr::Column(datafusion::common::Column::new_unqualified(
                    field.name(),
                )));
            }
            let exprs = self.modify_project_exprs(exprs)?;
            self.is_rewritten = true;
            let new_plan =
                LogicalPlan::Projection(Projection::try_new(exprs, Arc::new(node.clone()))?);
            Ok(Transformed::yes(new_plan))
        }
    }

    fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        // Appended or aliased expressions change the output schema, so
        // recompute it for every ancestor on the way back up.
        node.recompute_schema().map(Transformed::yes)
    }
}

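/// Rewrites a plan to AND an extra predicate into it, at most once: the
/// predicate is attached to the first `Filter` or `TableScan` reached during
/// bottom-up traversal. An existing filter gets the predicate AND-ed in; a
/// bare scan gets wrapped in a new `Filter`.
///
/// A minimal sketch of the intended use (mirroring `test_add_filter` below):
///
/// ```ignore
/// use datafusion_expr::{col, lit};
/// let mut rewriter = AddFilterRewriter::new(col("number").gt(lit(4u32)));
/// let filtered_plan = plan.rewrite(&mut rewriter)?.data;
/// ```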
#[derive(Debug)]
pub struct AddFilterRewriter {
    extra_filter: Expr,
    is_rewritten: bool,
}

impl AddFilterRewriter {
    pub fn new(filter: Expr) -> Self {
        Self {
            extra_filter: filter,
            is_rewritten: false,
        }
    }
}

impl TreeNodeRewriter for AddFilterRewriter {
    type Node = LogicalPlan;

    fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        if self.is_rewritten {
            return Ok(Transformed::no(node));
        }
        match node {
            LogicalPlan::Filter(mut filter) => {
                // AND the extra predicate into the existing filter.
                filter.predicate = filter.predicate.and(self.extra_filter.clone());
                self.is_rewritten = true;
                Ok(Transformed::yes(LogicalPlan::Filter(filter)))
            }
            LogicalPlan::TableScan(_) => {
                // No filter yet: wrap the scan in a new one.
                let filter =
                    datafusion_expr::Filter::try_new(self.extra_filter.clone(), Arc::new(node))?;
                self.is_rewritten = true;
                Ok(Transformed::yes(LogicalPlan::Filter(filter)))
            }
            _ => Ok(Transformed::no(node)),
        }
    }
}

#[cfg(test)]
mod test {
    use std::sync::Arc;

    use datafusion_common::tree_node::TreeNode as _;
    use datatypes::prelude::ConcreteDataType;
    use datatypes::schema::{ColumnSchema, Schema};
    use pretty_assertions::assert_eq;
    use query::query_engine::DefaultSerializer;
    use session::context::QueryContext;
    use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};

    use super::*;
    use crate::test_utils::create_test_query_engine;

    #[tokio::test]
    async fn test_sql_plan_convert() {
        let query_engine = create_test_query_engine();
        let ctx = QueryContext::arc();
        let old = r#"SELECT "NUMBER" FROM "UPPERCASE_NUMBERS_WITH_TS""#;
        let new = sql_to_df_plan(ctx.clone(), query_engine.clone(), old, false)
            .await
            .unwrap();
        let new_sql = df_plan_to_sql(&new).unwrap();

        assert_eq!(
            r#"SELECT `UPPERCASE_NUMBERS_WITH_TS`.`NUMBER` FROM `UPPERCASE_NUMBERS_WITH_TS`"#,
            new_sql
        );
    }

    #[tokio::test]
    async fn test_add_filter() {
        let testcases = vec![
            (
                "SELECT number FROM numbers_with_ts GROUP BY number",
                "SELECT numbers_with_ts.number FROM numbers_with_ts WHERE (number > 4) GROUP BY numbers_with_ts.number",
            ),
            (
                "SELECT number FROM numbers_with_ts WHERE number < 2 OR number >10",
                "SELECT numbers_with_ts.number FROM numbers_with_ts WHERE ((numbers_with_ts.number < 2) OR (numbers_with_ts.number > 10)) AND (number > 4)",
            ),
            (
                "SELECT date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window",
                "SELECT date_bin('5 minutes', numbers_with_ts.ts) AS time_window FROM numbers_with_ts WHERE (number > 4) GROUP BY date_bin('5 minutes', numbers_with_ts.ts)",
            ),
            (
                "SELECT number, time_window FROM (SELECT number, date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window, number);",
                "SELECT numbers_with_ts.number, time_window FROM (SELECT numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts) AS time_window FROM numbers_with_ts WHERE (number > 4) GROUP BY date_bin('5 minutes', numbers_with_ts.ts), numbers_with_ts.number)",
            ),
            (
                "SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) GROUP BY number, time_window, bucket_name;",
                "SELECT sum(numbers_with_ts.number), numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts) AS time_window, bucket_name FROM (SELECT numbers_with_ts.number, numbers_with_ts.ts, CASE WHEN (numbers_with_ts.number < 5) THEN 'bucket_0_5' WHEN (numbers_with_ts.number >= 5) THEN 'bucket_5_inf' END AS bucket_name FROM numbers_with_ts WHERE (number > 4)) GROUP BY numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts), bucket_name",
            ),
            (
                "SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) as cte WHERE number > 1 GROUP BY number, time_window, bucket_name;",
                "SELECT sum(cte.number), cte.number, date_bin('5 minutes', cte.ts) AS time_window, cte.bucket_name FROM (SELECT numbers_with_ts.number, numbers_with_ts.ts, CASE WHEN (numbers_with_ts.number < 5) THEN 'bucket_0_5' WHEN (numbers_with_ts.number >= 5) THEN 'bucket_5_inf' END AS bucket_name FROM numbers_with_ts WHERE (number > 4)) AS cte WHERE (cte.number > 1) GROUP BY cte.number, date_bin('5 minutes', cte.ts), cte.bucket_name",
            ),
        ];
        use datafusion_expr::{col, lit};
        let query_engine = create_test_query_engine();
        let ctx = QueryContext::arc();

        for (before, after) in testcases {
            let sql = before;
            let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, false)
                .await
                .unwrap();

            let mut add_filter = AddFilterRewriter::new(col("number").gt(lit(4u32)));
            let plan = plan.rewrite(&mut add_filter).unwrap().data;
            let new_sql = df_plan_to_sql(&plan).unwrap();
            assert_eq!(after, new_sql);
        }
    }

    #[tokio::test]
    async fn test_add_auto_column_rewriter() {
        let testcases = vec![
            (
                "SELECT number FROM numbers_with_ts",
                Ok("SELECT numbers_with_ts.number, now() AS ts FROM numbers_with_ts"),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        "ts",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                ],
            ),
            (
                "SELECT number FROM numbers_with_ts",
                Ok(
                    "SELECT numbers_with_ts.number, CAST('1970-01-01 00:00:00' AS TIMESTAMP) AS __ts_placeholder FROM numbers_with_ts",
                ),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        AUTO_CREATED_PLACEHOLDER_TS_COL,
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                ],
            ),
            (
                "SELECT number, ts FROM numbers_with_ts",
                Ok("SELECT numbers_with_ts.number, numbers_with_ts.ts FROM numbers_with_ts"),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        "ts",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                ],
            ),
            (
                "SELECT number FROM numbers_with_ts",
                Ok(
                    "SELECT numbers_with_ts.number, now() AS update_at, CAST('1970-01-01 00:00:00' AS TIMESTAMP) AS __ts_placeholder FROM numbers_with_ts",
                ),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        "update_at",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    ),
                    ColumnSchema::new(
                        AUTO_CREATED_PLACEHOLDER_TS_COL,
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                ],
            ),
            (
                "SELECT number, ts FROM numbers_with_ts",
                Ok(
                    "SELECT numbers_with_ts.number, numbers_with_ts.ts AS update_at, CAST('1970-01-01 00:00:00' AS TIMESTAMP) AS __ts_placeholder FROM numbers_with_ts",
                ),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        "update_at",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    ),
                    ColumnSchema::new(
                        AUTO_CREATED_PLACEHOLDER_TS_COL,
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                ],
            ),
            (
                "SELECT number, ts FROM numbers_with_ts",
                Ok(
                    "SELECT numbers_with_ts.number, numbers_with_ts.ts, now() AS update_atat FROM numbers_with_ts",
                ),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        "ts",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                    ColumnSchema::new(
                        "update_atat",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    ),
                ],
            ),
            (
                "SELECT number, ts FROM numbers_with_ts",
                Err(
                    "Expect the last column in table to be timestamp column, found column atat with type Int8",
                ),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        "ts",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                    ColumnSchema::new("atat", ConcreteDataType::int8_datatype(), false),
                ],
            ),
            (
                "SELECT number FROM numbers_with_ts",
                Err(
                    "Expect the second last column in the table to be timestamp column, found column ts with type Int8",
                ),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new("ts", ConcreteDataType::int8_datatype(), false),
                    ColumnSchema::new(
                        "atat",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                ],
            ),
        ];

        let query_engine = create_test_query_engine();
        let ctx = QueryContext::arc();
        for (before, after, column_schemas) in testcases {
            let schema = Arc::new(Schema::new(column_schemas));
            let mut add_auto_column_rewriter = ColumnMatcherRewriter::new(schema);

            let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), before, false)
                .await
                .unwrap();
            let new_sql = (|| {
                let plan = plan
                    .rewrite(&mut add_auto_column_rewriter)
                    .map_err(|e| e.to_string())?
                    .data;
                df_plan_to_sql(&plan).map_err(|e| e.to_string())
            })();
            match (after, new_sql.clone()) {
                (Ok(after), Ok(new_sql)) => assert_eq!(after, new_sql),
                (Err(expected), Err(real_err_msg)) => assert!(
                    real_err_msg.contains(expected),
                    "expected: {expected}, real: {real_err_msg}"
                ),
                _ => panic!("expected: {:?}, real: {:?}", after, new_sql),
            }
        }
    }

    #[tokio::test]
    async fn test_find_group_by_exprs() {
        let testcases = vec![
            (
                "SELECT arrow_cast(date_bin(INTERVAL '1 MINS', numbers_with_ts.ts), 'Timestamp(Second, None)') AS ts FROM numbers_with_ts GROUP BY ts;",
                vec!["ts"],
            ),
            (
                "SELECT number FROM numbers_with_ts GROUP BY number",
                vec!["number"],
            ),
            (
                "SELECT date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window",
                vec!["time_window"],
            ),
            (
                "SELECT number, time_window FROM (SELECT number, date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window, number);",
                vec!["time_window", "number"],
            ),
            (
                "SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) GROUP BY number, time_window, bucket_name;",
                vec!["number", "time_window", "bucket_name"],
            ),
            (
                "SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) as cte GROUP BY number, time_window, bucket_name;",
                vec!["number", "time_window", "bucket_name"],
            ),
        ];

        let query_engine = create_test_query_engine();
        let ctx = QueryContext::arc();
        for (sql, expected) in testcases {
            let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, false)
                .await
                .unwrap();
            let mut group_by_exprs = FindGroupByFinalName::default();
            plan.visit(&mut group_by_exprs).unwrap();
            let expected: HashSet<String> = expected.into_iter().map(|s| s.to_string()).collect();
            assert_eq!(
                expected,
                group_by_exprs.get_group_expr_names().unwrap_or_default()
            );
        }
    }

    #[tokio::test]
    async fn test_null_cast() {
        let query_engine = create_test_query_engine();
        let ctx = QueryContext::arc();
        let sql = "SELECT NULL::DOUBLE FROM numbers_with_ts";
        let plan = sql_to_df_plan(ctx, query_engine.clone(), sql, false)
            .await
            .unwrap();

        let _sub_plan = DFLogicalSubstraitConvertor {}
            .encode(&plan, DefaultSerializer)
            .unwrap();
    }
}