use std::collections::{BTreeSet, HashSet};
use std::sync::Arc;

use catalog::CatalogManagerRef;
use common_error::ext::BoxedError;
use common_telemetry::debug;
use datafusion::error::Result as DfResult;
use datafusion::logical_expr::Expr;
use datafusion::sql::unparser::Unparser;
use datafusion_common::tree_node::{
    Transformed, TreeNode as _, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor,
};
use datafusion_common::{DFSchema, DataFusionError, ScalarValue};
use datafusion_expr::{Distinct, LogicalPlan, Projection};
use datatypes::schema::SchemaRef;
use query::QueryEngineRef;
use query::parser::{DEFAULT_LOOKBACK_STRING, PromQuery, QueryLanguageParser, QueryStatement};
use session::context::QueryContextRef;
use snafu::{OptionExt, ResultExt, ensure};
use sql::parser::{ParseOptions, ParserContext};
use sql::statements::statement::Statement;
use sql::statements::tql::Tql;
use table::TableRef;

use crate::adapter::AUTO_CREATED_PLACEHOLDER_TS_COL;
use crate::df_optimizer::apply_df_optimizer;
use crate::error::{DatafusionSnafu, ExternalSnafu, InvalidQuerySnafu, TableNotFoundSnafu};
use crate::{Error, TableName};

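/// Look up the table referenced by `table_name` (`[catalog, schema, table]`)
/// and return the table handle together with its schema converted to a
/// DataFusion [`DFSchema`].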
pub async fn get_table_info_df_schema(
    catalog_mr: CatalogManagerRef,
    table_name: TableName,
) -> Result<(TableRef, Arc<DFSchema>), Error> {
    let full_table_name = table_name.clone().join(".");
    let table = catalog_mr
        .table(&table_name[0], &table_name[1], &table_name[2], None)
        .await
        .map_err(BoxedError::new)
        .context(ExternalSnafu)?
        .context(TableNotFoundSnafu {
            name: &full_table_name,
        })?;
    let table_info = table.table_info();

    let schema = table_info.meta.schema.clone();

    let df_schema: Arc<DFSchema> = Arc::new(
        schema
            .arrow_schema()
            .clone()
            .try_into()
            .with_context(|_| DatafusionSnafu {
                context: format!(
                    "Failed to convert arrow schema to datafusion schema, arrow_schema={:?}",
                    schema.arrow_schema()
                ),
            })?,
    );
    Ok((table, df_schema))
}

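/// Parse a SQL string (or a `TQL EVAL` statement) into a DataFusion
/// [`LogicalPlan`], optionally running `apply_df_optimizer` on the result.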
pub async fn sql_to_df_plan(
    query_ctx: QueryContextRef,
    engine: QueryEngineRef,
    sql: &str,
    optimize: bool,
) -> Result<LogicalPlan, Error> {
    let stmts =
        ParserContext::create_with_dialect(sql, query_ctx.sql_dialect(), ParseOptions::default())
            .map_err(BoxedError::new)
            .context(ExternalSnafu)?;

    ensure!(
        stmts.len() == 1,
        InvalidQuerySnafu {
            reason: format!("Expect only one statement, found {}", stmts.len())
        }
    );
    let stmt = &stmts[0];
    let query_stmt = match stmt {
        Statement::Tql(tql) => match tql {
            Tql::Eval(eval) => {
                let eval = eval.clone();
                let promql = PromQuery {
                    start: eval.start,
                    end: eval.end,
                    step: eval.step,
                    query: eval.query,
                    lookback: eval
                        .lookback
                        .unwrap_or_else(|| DEFAULT_LOOKBACK_STRING.to_string()),
                    alias: eval.alias.clone(),
                };

                QueryLanguageParser::parse_promql(&promql, &query_ctx)
                    .map_err(BoxedError::new)
                    .context(ExternalSnafu)?
            }
            _ => InvalidQuerySnafu {
                reason: format!("TQL statement {tql:?} is not supported, expect only TQL EVAL"),
            }
            .fail()?,
        },
        _ => QueryStatement::Sql(stmt.clone()),
    };
    let plan = engine
        .planner()
        .plan(&query_stmt, query_ctx.clone())
        .await
        .map_err(BoxedError::new)
        .context(ExternalSnafu)?;

    let plan = if optimize {
        apply_df_optimizer(plan, &query_ctx).await?
    } else {
        plan
    };
    Ok(plan)
}

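/// Plan `sql` without optimization, then rewrite the plan with
/// [`ColumnMatcherRewriter`] so that its output columns line up with the
/// sink table's schema.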
pub(crate) async fn gen_plan_with_matching_schema(
    sql: &str,
    query_ctx: QueryContextRef,
    engine: QueryEngineRef,
    sink_table_schema: SchemaRef,
) -> Result<LogicalPlan, Error> {
    let plan = sql_to_df_plan(query_ctx.clone(), engine.clone(), sql, false).await?;

    let mut add_auto_column = ColumnMatcherRewriter::new(sink_table_schema);
    let plan = plan
        .clone()
        .rewrite(&mut add_auto_column)
        .with_context(|_| DatafusionSnafu {
            context: format!("Failed to rewrite plan:\n {}\n", plan),
        })?
        .data;
    Ok(plan)
}

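/// Unparse a [`LogicalPlan`] back into a SQL string, backtick-quoting every
/// identifier that is not already all-lowercase so that mixed-case names
/// survive re-parsing.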
pub fn df_plan_to_sql(plan: &LogicalPlan) -> Result<String, Error> {
    struct ForceQuoteIdentifiers;
    impl datafusion::sql::unparser::dialect::Dialect for ForceQuoteIdentifiers {
        fn identifier_quote_style(&self, identifier: &str) -> Option<char> {
            if identifier.to_lowercase() != identifier {
                Some('`')
            } else {
                None
            }
        }
    }
    let unparser = Unparser::new(&ForceQuoteIdentifiers);
    let sql = unparser
        .plan_to_sql(plan)
        .with_context(|_e| DatafusionSnafu {
            context: format!("Failed to unparse logical plan {plan:?}"),
        })?;
    Ok(sql.to_string())
}

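/// A plan visitor that collects the final (outermost, alias-aware) names of
/// the `GROUP BY` expressions; `DISTINCT` is treated as grouping by all of
/// its columns.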
#[derive(Debug, Clone, Default)]
pub struct FindGroupByFinalName {
    group_exprs: Option<HashSet<datafusion_expr::Expr>>,
}

impl FindGroupByFinalName {
    /// The final names of the group-by expressions, if any aggregate or
    /// distinct node was seen during the visit.
    pub fn get_group_expr_names(&self) -> Option<HashSet<String>> {
        self.group_exprs
            .as_ref()
            .map(|exprs| exprs.iter().map(|expr| expr.qualified_name().1).collect())
    }
}

impl TreeNodeVisitor<'_> for FindGroupByFinalName {
    type Node = LogicalPlan;

    fn f_down(&mut self, node: &Self::Node) -> datafusion_common::Result<TreeNodeRecursion> {
        if let LogicalPlan::Aggregate(aggregate) = node {
            self.group_exprs = Some(aggregate.group_expr.iter().cloned().collect());
            debug!(
                "FindGroupByFinalName: Get Group by exprs from Aggregate: {:?}",
                self.group_exprs
            );
        } else if let LogicalPlan::Distinct(distinct) = node {
            debug!("FindGroupByFinalName: Distinct: {}", node);
            match distinct {
                Distinct::All(input) => {
                    if let LogicalPlan::TableScan(table_scan) = &**input {
                        // Group by every column projected from the scan.
                        let len = table_scan.projected_schema.fields().len();
                        let columns = (0..len)
                            .map(|f| {
                                let (qualifier, field) =
                                    table_scan.projected_schema.qualified_field(f);
                                datafusion_common::Column::new(qualifier.cloned(), field.name())
                            })
                            .map(datafusion_expr::Expr::Column);
                        self.group_exprs = Some(columns.collect());
                    } else {
                        self.group_exprs = Some(input.expressions().iter().cloned().collect())
                    }
                }
                Distinct::On(distinct_on) => {
                    self.group_exprs = Some(distinct_on.on_expr.iter().cloned().collect())
                }
            }
            debug!(
                "FindGroupByFinalName: Get Group by exprs from Distinct: {:?}",
                self.group_exprs
            );
        }

        Ok(TreeNodeRecursion::Continue)
    }

    fn f_up(&mut self, node: &Self::Node) -> datafusion_common::Result<TreeNodeRecursion> {
        // On the way back up, swap each group expr for the alias a projection
        // assigns to it, so the final output names are recorded.
        if let LogicalPlan::Projection(projection) = node {
            for expr in &projection.expr {
                let Some(group_exprs) = &mut self.group_exprs else {
                    return Ok(TreeNodeRecursion::Continue);
                };
                if let datafusion_expr::Expr::Alias(alias) = expr {
                    let mut new_group_exprs = group_exprs.clone();
                    for group_expr in group_exprs.iter() {
                        if group_expr.name_for_alias()? == alias.expr.name_for_alias()? {
                            new_group_exprs.remove(group_expr);
                            new_group_exprs.insert(expr.clone());
                            break;
                        }
                    }
                    *group_exprs = new_group_exprs;
                }
            }
        }
        debug!("Aliased group by exprs: {:?}", self.group_exprs);
        Ok(TreeNodeRecursion::Continue)
    }
}

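/// Rewrites a plan so that its top-level projection matches the sink table's
/// schema:
///
/// - an output column whose name is unknown to the table is aliased to the
///   table column at the same position;
/// - if the table has one or two more trailing columns than the query, an
///   automatic `now()` timestamp and/or a constant placeholder time index
///   column (see [`AUTO_CREATED_PLACEHOLDER_TS_COL`]) is appended.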
#[derive(Debug)]
pub struct ColumnMatcherRewriter {
    pub schema: SchemaRef,
    pub is_rewritten: bool,
}

impl ColumnMatcherRewriter {
    pub fn new(schema: SchemaRef) -> Self {
        Self {
            schema,
            is_rewritten: false,
        }
    }

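    /// Alias and pad `exprs` so the projection matches `self.schema`.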
    fn modify_project_exprs(&mut self, mut exprs: Vec<Expr>) -> DfResult<Vec<Expr>> {
        let all_names = self
            .schema
            .column_schemas()
            .iter()
            .map(|c| c.name.clone())
            .collect::<BTreeSet<_>>();
        // An output column whose name is unknown to the sink table is aliased
        // to the table column at the same position.
        for (idx, expr) in exprs.iter_mut().enumerate() {
            if !all_names.contains(&expr.qualified_name().1)
                && let Some(col_name) = self
                    .schema
                    .column_schemas()
                    .get(idx)
                    .map(|c| c.name.clone())
            {
                *expr = expr.clone().alias(col_name);
            }
        }

        let query_col_cnt = exprs.len();
        let table_col_cnt = self.schema.column_schemas().len();
        debug!("query_col_cnt={query_col_cnt}, table_col_cnt={table_col_cnt}");

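        // Zero-valued millisecond timestamp used to fill the auto-created
        // placeholder time index column.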
        let placeholder_ts_expr =
            datafusion::logical_expr::lit(ScalarValue::TimestampMillisecond(Some(0), None))
                .alias(AUTO_CREATED_PLACEHOLDER_TS_COL);

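        // The query may produce exactly as many columns as the table, or one
        // or two fewer; anything else is a plan error.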
        if query_col_cnt == table_col_cnt {
            // Already aligned; nothing to append.
        } else if query_col_cnt + 1 == table_col_cnt {
            // One missing column: the trailing time index, filled with either
            // the placeholder literal or `now()`.
            let last_col_schema = self.schema.column_schemas().last().unwrap();

            if last_col_schema.name == AUTO_CREATED_PLACEHOLDER_TS_COL
                && self.schema.timestamp_index() == Some(table_col_cnt - 1)
            {
                exprs.push(placeholder_ts_expr);
            } else if last_col_schema.data_type.is_timestamp() {
                exprs.push(datafusion::prelude::now().alias(&last_col_schema.name));
            } else {
                return Err(DataFusionError::Plan(format!(
                    "Expect the last column in table to be timestamp column, found column {} with type {:?}",
                    last_col_schema.name, last_col_schema.data_type
                )));
            }
        } else if query_col_cnt + 2 == table_col_cnt {
            // Two missing columns: a `now()` update-time column followed by
            // the placeholder time index.
            let mut col_iter = self.schema.column_schemas().iter().rev();
            let last_col_schema = col_iter.next().unwrap();
            let second_last_col_schema = col_iter.next().unwrap();
            if second_last_col_schema.data_type.is_timestamp() {
                exprs.push(datafusion::prelude::now().alias(&second_last_col_schema.name));
            } else {
                return Err(DataFusionError::Plan(format!(
                    "Expect the second last column in the table to be timestamp column, found column {} with type {:?}",
                    second_last_col_schema.name, second_last_col_schema.data_type
                )));
            }

            if last_col_schema.name == AUTO_CREATED_PLACEHOLDER_TS_COL
                && self.schema.timestamp_index() == Some(table_col_cnt - 1)
            {
                exprs.push(placeholder_ts_expr);
            } else {
                return Err(DataFusionError::Plan(format!(
                    "Expect timestamp column {}, found {:?}",
                    AUTO_CREATED_PLACEHOLDER_TS_COL, last_col_schema
                )));
            }
        } else {
            return Err(DataFusionError::Plan(format!(
                "Expect the table to have 0, 1 or 2 more columns than the query, found {} query columns {:?}, {} table columns {:?}",
                query_col_cnt,
                exprs,
                table_col_cnt,
                self.schema.column_schemas()
            )));
        }
        Ok(exprs)
    }
}

impl TreeNodeRewriter for ColumnMatcherRewriter {
    type Node = LogicalPlan;

    fn f_down(&mut self, mut node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        if self.is_rewritten {
            return Ok(Transformed::no(node));
        }

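        // A bare `DISTINCT` or a raw table scan carries no projection to
        // rewrite, so first wrap it in an explicit projection over all of its
        // columns.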
        if let LogicalPlan::Distinct(Distinct::All(_)) = &node {
            let mut exprs = vec![];

            for field in node.schema().fields().iter() {
                exprs.push(Expr::Column(datafusion::common::Column::new_unqualified(
                    field.name(),
                )));
            }

            let projection =
                LogicalPlan::Projection(Projection::try_new(exprs, Arc::new(node.clone()))?);

            node = projection;
        } else if let LogicalPlan::TableScan(table_scan) = node {
            let mut exprs = vec![];

            for field in table_scan.projected_schema.fields().iter() {
                exprs.push(Expr::Column(datafusion::common::Column::new(
                    Some(table_scan.table_name.clone()),
                    field.name(),
                )));
            }

            let projection = LogicalPlan::Projection(Projection::try_new(
                exprs,
                Arc::new(LogicalPlan::TableScan(table_scan)),
            )?);

            node = projection;
        }

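        // Fix up the (possibly just-created) top-level projection; for any
        // other node, project all of its output columns and fix those up.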
        if let LogicalPlan::Projection(project) = &node {
            let exprs = project.expr.clone();
            let exprs = self.modify_project_exprs(exprs)?;

            self.is_rewritten = true;
            let new_plan =
                node.with_new_exprs(exprs, node.inputs().into_iter().cloned().collect())?;
            Ok(Transformed::yes(new_plan))
        } else {
            let mut exprs = vec![];
            for field in node.schema().fields().iter() {
                exprs.push(Expr::Column(datafusion::common::Column::new_unqualified(
                    field.name(),
                )));
            }
            let exprs = self.modify_project_exprs(exprs)?;
            self.is_rewritten = true;
            let new_plan =
                LogicalPlan::Projection(Projection::try_new(exprs, Arc::new(node.clone()))?);
            Ok(Transformed::yes(new_plan))
        }
    }

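    /// Recompute schemas on the way back up, since the projections inserted
    /// in `f_down` change the output schema of every ancestor node.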
    fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        node.recompute_schema().map(Transformed::yes)
    }
}

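/// Rewrites a plan (bottom-up) to apply an extra filter predicate: the first
/// filter or table scan encountered absorbs the predicate, either by `AND`ing
/// it into the existing filter or by adding a new filter node above the scan.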
#[derive(Debug)]
pub struct AddFilterRewriter {
    extra_filter: Expr,
    is_rewritten: bool,
}

impl AddFilterRewriter {
    pub fn new(filter: Expr) -> Self {
        Self {
            extra_filter: filter,
            is_rewritten: false,
        }
    }
}

impl TreeNodeRewriter for AddFilterRewriter {
    type Node = LogicalPlan;

    fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        if self.is_rewritten {
            return Ok(Transformed::no(node));
        }
        match node {
            LogicalPlan::Filter(mut filter) => {
                // Fold the extra predicate into the existing filter with `AND`.
                filter.predicate = filter.predicate.and(self.extra_filter.clone());
                self.is_rewritten = true;
                Ok(Transformed::yes(LogicalPlan::Filter(filter)))
            }
            LogicalPlan::TableScan(_) => {
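                // No filter below this point: place one directly above the scan.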
                let filter =
                    datafusion_expr::Filter::try_new(self.extra_filter.clone(), Arc::new(node))?;
                self.is_rewritten = true;
                Ok(Transformed::yes(LogicalPlan::Filter(filter)))
            }
            _ => Ok(Transformed::no(node)),
        }
    }
}

#[cfg(test)]
mod test {
    use std::sync::Arc;

    use datafusion_common::tree_node::TreeNode as _;
    use datatypes::prelude::ConcreteDataType;
    use datatypes::schema::{ColumnSchema, Schema};
    use pretty_assertions::assert_eq;
    use query::query_engine::DefaultSerializer;
    use session::context::QueryContext;
    use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};

    use super::*;
    use crate::test_utils::create_test_query_engine;

    #[tokio::test]
    async fn test_sql_plan_convert() {
        let query_engine = create_test_query_engine();
        let ctx = QueryContext::arc();
        let old = r#"SELECT "NUMBER" FROM "UPPERCASE_NUMBERS_WITH_TS""#;
        let new = sql_to_df_plan(ctx.clone(), query_engine.clone(), old, false)
            .await
            .unwrap();
        let new_sql = df_plan_to_sql(&new).unwrap();

        assert_eq!(
            r#"SELECT `UPPERCASE_NUMBERS_WITH_TS`.`NUMBER` FROM `UPPERCASE_NUMBERS_WITH_TS`"#,
            new_sql
        );
    }

    #[tokio::test]
    async fn test_add_filter() {
        let testcases = vec![
            (
                "SELECT number FROM numbers_with_ts GROUP BY number",
                "SELECT numbers_with_ts.number FROM numbers_with_ts WHERE (number > 4) GROUP BY numbers_with_ts.number",
            ),
            (
                "SELECT number FROM numbers_with_ts WHERE number < 2 OR number >10",
                "SELECT numbers_with_ts.number FROM numbers_with_ts WHERE ((numbers_with_ts.number < 2) OR (numbers_with_ts.number > 10)) AND (number > 4)",
            ),
            (
                "SELECT date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window",
                "SELECT date_bin('5 minutes', numbers_with_ts.ts) AS time_window FROM numbers_with_ts WHERE (number > 4) GROUP BY date_bin('5 minutes', numbers_with_ts.ts)",
            ),
            // The filter lands inside the subquery, just above the scan.
            (
                "SELECT number, time_window FROM (SELECT number, date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window, number);",
                "SELECT numbers_with_ts.number, time_window FROM (SELECT numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts) AS time_window FROM numbers_with_ts WHERE (number > 4) GROUP BY date_bin('5 minutes', numbers_with_ts.ts), numbers_with_ts.number)",
            ),
            // CASE expression inside a subquery.
            (
                "SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) GROUP BY number, time_window, bucket_name;",
                "SELECT sum(numbers_with_ts.number), numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts) AS time_window, bucket_name FROM (SELECT numbers_with_ts.number, numbers_with_ts.ts, CASE WHEN (numbers_with_ts.number < 5) THEN 'bucket_0_5' WHEN (numbers_with_ts.number >= 5) THEN 'bucket_5_inf' END AS bucket_name FROM numbers_with_ts WHERE (number > 4)) GROUP BY numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts), bucket_name",
            ),
            // Aliased subquery with its own outer WHERE clause.
            (
                "SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) as cte WHERE number > 1 GROUP BY number, time_window, bucket_name;",
                "SELECT sum(cte.number), cte.number, date_bin('5 minutes', cte.ts) AS time_window, cte.bucket_name FROM (SELECT numbers_with_ts.number, numbers_with_ts.ts, CASE WHEN (numbers_with_ts.number < 5) THEN 'bucket_0_5' WHEN (numbers_with_ts.number >= 5) THEN 'bucket_5_inf' END AS bucket_name FROM numbers_with_ts WHERE (number > 4)) AS cte WHERE (cte.number > 1) GROUP BY cte.number, date_bin('5 minutes', cte.ts), cte.bucket_name",
            ),
        ];
        use datafusion_expr::{col, lit};
        let query_engine = create_test_query_engine();
        let ctx = QueryContext::arc();

        for (before, after) in testcases {
            let sql = before;
            let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, false)
                .await
                .unwrap();

            let mut add_filter = AddFilterRewriter::new(col("number").gt(lit(4u32)));
            let plan = plan.rewrite(&mut add_filter).unwrap().data;
            let new_sql = df_plan_to_sql(&plan).unwrap();
            assert_eq!(after, new_sql);
        }
    }

    #[tokio::test]
    async fn test_add_auto_column_rewriter() {
        let testcases = vec![
            // Missing trailing time index column: filled with `now()`.
            (
                "SELECT number FROM numbers_with_ts",
                Ok("SELECT numbers_with_ts.number, now() AS ts FROM numbers_with_ts"),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        "ts",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                ],
            ),
            // Missing placeholder time index: filled with the zero timestamp.
            (
                "SELECT number FROM numbers_with_ts",
                Ok(
                    "SELECT numbers_with_ts.number, CAST('1970-01-01 00:00:00' AS TIMESTAMP) AS __ts_placeholder FROM numbers_with_ts",
                ),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        AUTO_CREATED_PLACEHOLDER_TS_COL,
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                ],
            ),
            // Column counts already match: plan is unchanged.
            (
                "SELECT number, ts FROM numbers_with_ts",
                Ok("SELECT numbers_with_ts.number, numbers_with_ts.ts FROM numbers_with_ts"),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        "ts",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                ],
            ),
            // Two missing columns: `now() AS update_at` plus the placeholder time index.
            (
                "SELECT number FROM numbers_with_ts",
                Ok(
                    "SELECT numbers_with_ts.number, now() AS update_at, CAST('1970-01-01 00:00:00' AS TIMESTAMP) AS __ts_placeholder FROM numbers_with_ts",
                ),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        "update_at",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    ),
                    ColumnSchema::new(
                        AUTO_CREATED_PLACEHOLDER_TS_COL,
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                ],
            ),
            // `ts` is unknown to the table, so it is aliased to `update_at` by
            // position; the placeholder time index is appended.
            (
                "SELECT number, ts FROM numbers_with_ts",
                Ok(
                    "SELECT numbers_with_ts.number, numbers_with_ts.ts AS update_at, CAST('1970-01-01 00:00:00' AS TIMESTAMP) AS __ts_placeholder FROM numbers_with_ts",
                ),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        "update_at",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    ),
                    ColumnSchema::new(
                        AUTO_CREATED_PLACEHOLDER_TS_COL,
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                ],
            ),
            // The trailing `update_atat` is a plain timestamp column, filled with `now()`.
            (
                "SELECT number, ts FROM numbers_with_ts",
                Ok(
                    "SELECT numbers_with_ts.number, numbers_with_ts.ts, now() AS update_atat FROM numbers_with_ts",
                ),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        "ts",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                    ColumnSchema::new(
                        "update_atat",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    ),
                ],
            ),
            // Error: the trailing column is not a timestamp.
            (
                "SELECT number, ts FROM numbers_with_ts",
                Err(
                    "Expect the last column in table to be timestamp column, found column atat with type Int8",
                ),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        "ts",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                    ColumnSchema::new("atat", ConcreteDataType::int8_datatype(), false),
                ],
            ),
            // Error: the second-last column is not a timestamp.
            (
                "SELECT number FROM numbers_with_ts",
                Err(
                    "Expect the second last column in the table to be timestamp column, found column ts with type Int8",
                ),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new("ts", ConcreteDataType::int8_datatype(), false),
                    ColumnSchema::new(
                        "atat",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                ],
            ),
        ];

        let query_engine = create_test_query_engine();
        let ctx = QueryContext::arc();
        for (before, after, column_schemas) in testcases {
            let schema = Arc::new(Schema::new(column_schemas));
            let mut add_auto_column_rewriter = ColumnMatcherRewriter::new(schema);

            let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), before, false)
                .await
                .unwrap();
            let new_sql = (|| {
                let plan = plan
                    .rewrite(&mut add_auto_column_rewriter)
                    .map_err(|e| e.to_string())?
                    .data;
                df_plan_to_sql(&plan).map_err(|e| e.to_string())
            })();
            match (after, new_sql.clone()) {
                (Ok(after), Ok(new_sql)) => assert_eq!(after, new_sql),
                (Err(expected), Err(real_err_msg)) => assert!(
                    real_err_msg.contains(expected),
                    "expected: {expected}, real: {real_err_msg}"
                ),
                _ => panic!("expected: {:?}, real: {:?}", after, new_sql),
            }
        }
    }

    #[tokio::test]
    async fn test_find_group_by_exprs() {
        let testcases = vec![
            (
                "SELECT arrow_cast(date_bin(INTERVAL '1 MINS', numbers_with_ts.ts), 'Timestamp(Second, None)') AS ts FROM numbers_with_ts GROUP BY ts;",
                vec!["ts"],
            ),
            (
                "SELECT number FROM numbers_with_ts GROUP BY number",
                vec!["number"],
            ),
            (
                "SELECT date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window",
                vec!["time_window"],
            ),
            (
                "SELECT number, time_window FROM (SELECT number, date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window, number);",
                vec!["time_window", "number"],
            ),
            (
                "SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) GROUP BY number, time_window, bucket_name;",
                vec!["number", "time_window", "bucket_name"],
            ),
            (
                "SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) as cte GROUP BY number, time_window, bucket_name;",
                vec!["number", "time_window", "bucket_name"],
            ),
        ];

        let query_engine = create_test_query_engine();
        let ctx = QueryContext::arc();
        for (sql, expected) in testcases {
            let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, false)
                .await
                .unwrap();
            let mut group_by_exprs = FindGroupByFinalName::default();
            plan.visit(&mut group_by_exprs).unwrap();
            let expected: HashSet<String> = expected.into_iter().map(|s| s.to_string()).collect();
            assert_eq!(
                expected,
                group_by_exprs.get_group_expr_names().unwrap_or_default()
            );
        }
    }

    #[tokio::test]
    async fn test_null_cast() {
        let query_engine = create_test_query_engine();
        let ctx = QueryContext::arc();
        let sql = "SELECT NULL::DOUBLE FROM numbers_with_ts";
        let plan = sql_to_df_plan(ctx, query_engine.clone(), sql, false)
            .await
            .unwrap();

        let _sub_plan = DFLogicalSubstraitConvertor {}
            .encode(&plan, DefaultSerializer)
            .unwrap();
    }
}