use std::collections::{BTreeSet, HashSet};
use std::sync::Arc;

use catalog::CatalogManagerRef;
use common_error::ext::BoxedError;
use common_telemetry::debug;
use datafusion::error::Result as DfResult;
use datafusion::logical_expr::Expr;
use datafusion::sql::unparser::Unparser;
use datafusion_common::tree_node::{
    Transformed, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor,
};
use datafusion_common::{DFSchema, DataFusionError, ScalarValue};
use datafusion_expr::{Distinct, LogicalPlan, Projection};
use datatypes::schema::SchemaRef;
use query::parser::{PromQuery, QueryLanguageParser, QueryStatement, DEFAULT_LOOKBACK_STRING};
use query::QueryEngineRef;
use session::context::QueryContextRef;
use snafu::{ensure, OptionExt, ResultExt};
use sql::parser::{ParseOptions, ParserContext};
use sql::statements::statement::Statement;
use sql::statements::tql::Tql;
use table::metadata::TableInfo;

use crate::adapter::AUTO_CREATED_PLACEHOLDER_TS_COL;
use crate::df_optimizer::apply_df_optimizer;
use crate::error::{DatafusionSnafu, ExternalSnafu, InvalidQuerySnafu, TableNotFoundSnafu};
use crate::{Error, TableName};

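/// Look up a table by its catalog/schema/table name and return its
/// [`TableInfo`] together with the table's schema converted to a DataFusion
/// [`DFSchema`].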
pub async fn get_table_info_df_schema(
    catalog_mr: CatalogManagerRef,
    table_name: TableName,
) -> Result<(Arc<TableInfo>, Arc<DFSchema>), Error> {
    let full_table_name = table_name.clone().join(".");
    let table = catalog_mr
        .table(&table_name[0], &table_name[1], &table_name[2], None)
        .await
        .map_err(BoxedError::new)
        .context(ExternalSnafu)?
        .context(TableNotFoundSnafu {
            name: &full_table_name,
        })?;
    let table_info = table.table_info().clone();

    let schema = table_info.meta.schema.clone();

    let df_schema: Arc<DFSchema> = Arc::new(
        schema
            .arrow_schema()
            .clone()
            .try_into()
            .with_context(|_| DatafusionSnafu {
                context: format!(
                    "Failed to convert arrow schema to datafusion schema, arrow_schema={:?}",
                    schema.arrow_schema()
                ),
            })?,
    );
    Ok((table_info, df_schema))
}

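/// Parse a single SQL statement (or a TQL `EVAL`) and plan it into a
/// DataFusion [`LogicalPlan`], optionally running `apply_df_optimizer` over
/// the result.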
pub async fn sql_to_df_plan(
    query_ctx: QueryContextRef,
    engine: QueryEngineRef,
    sql: &str,
    optimize: bool,
) -> Result<LogicalPlan, Error> {
    let stmts =
        ParserContext::create_with_dialect(sql, query_ctx.sql_dialect(), ParseOptions::default())
            .map_err(BoxedError::new)
            .context(ExternalSnafu)?;

    ensure!(
        stmts.len() == 1,
        InvalidQuerySnafu {
            reason: format!("Expect only one statement, found {}", stmts.len())
        }
    );
    let stmt = &stmts[0];
    let query_stmt = match stmt {
        Statement::Tql(tql) => match tql {
            Tql::Eval(eval) => {
                let eval = eval.clone();
                let promql = PromQuery {
                    start: eval.start,
                    end: eval.end,
                    step: eval.step,
                    query: eval.query,
                    lookback: eval
                        .lookback
                        .unwrap_or_else(|| DEFAULT_LOOKBACK_STRING.to_string()),
                };

                QueryLanguageParser::parse_promql(&promql, &query_ctx)
                    .map_err(BoxedError::new)
                    .context(ExternalSnafu)?
            }
            _ => InvalidQuerySnafu {
                reason: format!("TQL statement {tql:?} is not supported, expect only TQL EVAL"),
            }
            .fail()?,
        },
        _ => QueryStatement::Sql(stmt.clone()),
    };
    let plan = engine
        .planner()
        .plan(&query_stmt, query_ctx)
        .await
        .map_err(BoxedError::new)
        .context(ExternalSnafu)?;

    let plan = if optimize {
        apply_df_optimizer(plan).await?
    } else {
        plan
    };
    Ok(plan)
}

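/// Unparse a [`LogicalPlan`] back into a SQL string. Identifiers containing
/// uppercase characters are backtick-quoted so they survive the round trip.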
pub fn df_plan_to_sql(plan: &LogicalPlan) -> Result<String, Error> {
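    /// Quote identifiers that contain uppercase characters; the default
    /// unparser would otherwise fold them to lowercase and break references
    /// to case-sensitive table and column names.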
    struct ForceQuoteIdentifiers;

    impl datafusion::sql::unparser::dialect::Dialect for ForceQuoteIdentifiers {
        fn identifier_quote_style(&self, identifier: &str) -> Option<char> {
            if identifier.to_lowercase() != identifier {
                Some('`')
            } else {
                None
            }
        }
    }
    let unparser = Unparser::new(&ForceQuoteIdentifiers);
    let sql = unparser
        .plan_to_sql(plan)
        .with_context(|_e| DatafusionSnafu {
            context: format!("Failed to unparse logical plan {plan:?}"),
        })?;
    Ok(sql.to_string())
}

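/// Visitor that collects the innermost GROUP BY (or DISTINCT) expressions of
/// a plan, following projections upward so aliases are resolved to their
/// final output names.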
#[derive(Debug, Clone, Default)]
pub struct FindGroupByFinalName {
    group_exprs: Option<HashSet<datafusion_expr::Expr>>,
}

impl FindGroupByFinalName {
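    /// Return the final (possibly aliased) output names of the group-by expressions.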
    pub fn get_group_expr_names(&self) -> Option<HashSet<String>> {
        self.group_exprs
            .as_ref()
            .map(|exprs| exprs.iter().map(|expr| expr.qualified_name().1).collect())
    }
}

impl TreeNodeVisitor<'_> for FindGroupByFinalName {
    type Node = LogicalPlan;

    fn f_down(&mut self, node: &Self::Node) -> datafusion_common::Result<TreeNodeRecursion> {
        if let LogicalPlan::Aggregate(aggregate) = node {
            self.group_exprs = Some(aggregate.group_expr.iter().cloned().collect());
            debug!(
                "FindGroupByFinalName: Get Group by exprs from Aggregate: {:?}",
                self.group_exprs
            );
        } else if let LogicalPlan::Distinct(distinct) = node {
            debug!("FindGroupByFinalName: Distinct: {}", node);
            match distinct {
                Distinct::All(input) => {
                    if let LogicalPlan::TableScan(table_scan) = &**input {
                        let len = table_scan.projected_schema.fields().len();
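                        // Every projected column of the scan acts as a group key.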
                        let columns = (0..len)
                            .map(|f| {
                                let (qualifier, field) =
                                    table_scan.projected_schema.qualified_field(f);
                                datafusion_common::Column::new(qualifier.cloned(), field.name())
                            })
                            .map(datafusion_expr::Expr::Column);
                        self.group_exprs = Some(columns.collect());
                    } else {
                        self.group_exprs = Some(input.expressions().iter().cloned().collect())
                    }
                }
                Distinct::On(distinct_on) => {
                    self.group_exprs = Some(distinct_on.on_expr.iter().cloned().collect())
                }
            }
            debug!(
                "FindGroupByFinalName: Get Group by exprs from Distinct: {:?}",
                self.group_exprs
            );
        }

        Ok(TreeNodeRecursion::Continue)
    }

    fn f_up(&mut self, node: &Self::Node) -> datafusion_common::Result<TreeNodeRecursion> {
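        // When a projection aliases a group expression, replace the stored
        // expression with its aliased form so the final output names are kept.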
        if let LogicalPlan::Projection(projection) = node {
            for expr in &projection.expr {
                let Some(group_exprs) = &mut self.group_exprs else {
                    return Ok(TreeNodeRecursion::Continue);
                };
                if let datafusion_expr::Expr::Alias(alias) = expr {
                    let mut new_group_exprs = group_exprs.clone();
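                    // Swap the matching group expression for its aliased version.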
                    for group_expr in group_exprs.iter() {
                        if group_expr.name_for_alias()? == alias.expr.name_for_alias()? {
                            new_group_exprs.remove(group_expr);
                            new_group_exprs.insert(expr.clone());
                            break;
                        }
                    }
                    *group_exprs = new_group_exprs;
                }
            }
        }
        debug!("Aliased group by exprs: {:?}", self.group_exprs);
        Ok(TreeNodeRecursion::Continue)
    }
}

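/// Rewrite a query plan so that its output matches the given table schema:
/// query columns that do not appear in the table are aliased to the table
/// column at the same position, and missing trailing columns are filled in
/// automatically. The table may have at most two more columns than the query:
/// an auto-added timestamp column (filled with `now()`) and/or the
/// auto-created placeholder time index (filled with a zero timestamp).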
#[derive(Debug)]
pub struct AddAutoColumnRewriter {
    pub schema: SchemaRef,
    pub is_rewritten: bool,
}

impl AddAutoColumnRewriter {
    pub fn new(schema: SchemaRef) -> Self {
        Self {
            schema,
            is_rewritten: false,
        }
    }
}

impl TreeNodeRewriter for AddAutoColumnRewriter {
    type Node = LogicalPlan;
    fn f_down(&mut self, mut node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        if self.is_rewritten {
            return Ok(Transformed::no(node));
        }

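        // `SELECT DISTINCT *` exposes no projection to extend, so wrap it in
        // an explicit projection over all of its output columns first.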
        if let LogicalPlan::Distinct(Distinct::All(_)) = &node {
            let mut exprs = vec![];

            for field in node.schema().fields().iter() {
                exprs.push(Expr::Column(datafusion::common::Column::new_unqualified(
                    field.name(),
                )));
            }

            let projection =
                LogicalPlan::Projection(Projection::try_new(exprs, Arc::new(node.clone()))?);

            node = projection;
        }
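        // Likewise, wrap a bare table scan in an explicit projection over all
        // of its projected columns.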
        else if let LogicalPlan::TableScan(table_scan) = node {
            let mut exprs = vec![];

            for field in table_scan.projected_schema.fields().iter() {
                exprs.push(Expr::Column(datafusion::common::Column::new(
                    Some(table_scan.table_name.clone()),
                    field.name(),
                )));
            }

            let projection = LogicalPlan::Projection(Projection::try_new(
                exprs,
                Arc::new(LogicalPlan::TableScan(table_scan)),
            )?);

            node = projection;
        }

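        // Only a projection's expression list can carry the appended columns;
        // any other node is left untouched.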
        let mut exprs = if let LogicalPlan::Projection(project) = &node {
            project.expr.clone()
        } else {
            return Ok(Transformed::no(node));
        };

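        // Query expressions that don't match any table column by name are
        // assumed to map positionally and are aliased to that column's name.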
        let all_names = self
            .schema
            .column_schemas()
            .iter()
            .map(|c| c.name.clone())
            .collect::<BTreeSet<_>>();
        for (idx, expr) in exprs.iter_mut().enumerate() {
            if !all_names.contains(&expr.qualified_name().1) {
                if let Some(col_name) = self
                    .schema
                    .column_schemas()
                    .get(idx)
                    .map(|c| c.name.clone())
                {
                    *expr = expr.clone().alias(col_name);
                }
            }
        }

        let query_col_cnt = exprs.len();
        let table_col_cnt = self.schema.column_schemas().len();
        debug!("query_col_cnt={query_col_cnt}, table_col_cnt={table_col_cnt}");

        let placeholder_ts_expr =
            datafusion::logical_expr::lit(ScalarValue::TimestampMillisecond(Some(0), None))
                .alias(AUTO_CREATED_PLACEHOLDER_TS_COL);

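        // Append the missing trailing columns: the table may have zero, one or
        // two more columns than the query produces.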
        if query_col_cnt == table_col_cnt {
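            // Column counts already match; nothing to append.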
        } else if query_col_cnt + 1 == table_col_cnt {
            let last_col_schema = self.schema.column_schemas().last().unwrap();

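            // A trailing auto-created placeholder time index is filled with
            // the zero placeholder; any other trailing column must be a
            // timestamp column and is filled with `now()`.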
            if last_col_schema.name == AUTO_CREATED_PLACEHOLDER_TS_COL
                && self.schema.timestamp_index() == Some(table_col_cnt - 1)
            {
                exprs.push(placeholder_ts_expr);
            } else if last_col_schema.data_type.is_timestamp() {
                exprs.push(datafusion::prelude::now().alias(&last_col_schema.name));
            } else {
                return Err(DataFusionError::Plan(format!(
                    "Expect the last column in table to be timestamp column, found column {} with type {:?}",
                    last_col_schema.name, last_col_schema.data_type
                )));
            }
        } else if query_col_cnt + 2 == table_col_cnt {
            let mut col_iter = self.schema.column_schemas().iter().rev();
            let last_col_schema = col_iter.next().unwrap();
            let second_last_col_schema = col_iter.next().unwrap();
            if second_last_col_schema.data_type.is_timestamp() {
                exprs.push(datafusion::prelude::now().alias(&second_last_col_schema.name));
            } else {
                return Err(DataFusionError::Plan(format!(
                    "Expect the second last column in the table to be timestamp column, found column {} with type {:?}",
                    second_last_col_schema.name, second_last_col_schema.data_type
                )));
            }

            if last_col_schema.name == AUTO_CREATED_PLACEHOLDER_TS_COL
                && self.schema.timestamp_index() == Some(table_col_cnt - 1)
            {
                exprs.push(placeholder_ts_expr);
            } else {
                return Err(DataFusionError::Plan(format!(
                    "Expect timestamp column {}, found {:?}",
                    AUTO_CREATED_PLACEHOLDER_TS_COL, last_col_schema
                )));
            }
        } else {
            return Err(DataFusionError::Plan(format!(
                "Expect table have 0,1 or 2 columns more than query columns, found {} query columns {:?}, {} table columns {:?}",
                query_col_cnt, exprs, table_col_cnt, self.schema.column_schemas()
            )));
        }

        self.is_rewritten = true;
        let new_plan = node.with_new_exprs(exprs, node.inputs().into_iter().cloned().collect())?;
        Ok(Transformed::yes(new_plan))
    }

    fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
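        // Recompute the schema bottom-up after the expressions were changed.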
        node.recompute_schema().map(Transformed::yes)
    }
}

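/// Rewrite a plan to AND an extra filter expression onto its innermost
/// non-HAVING `Filter` node, or to insert a new `Filter` above the table scan
/// if the plan has none.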
#[derive(Debug)]
pub struct AddFilterRewriter {
    extra_filter: Expr,
    is_rewritten: bool,
}

impl AddFilterRewriter {
    pub fn new(filter: Expr) -> Self {
        Self {
            extra_filter: filter,
            is_rewritten: false,
        }
    }
}

impl TreeNodeRewriter for AddFilterRewriter {
    type Node = LogicalPlan;
    fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        if self.is_rewritten {
            return Ok(Transformed::no(node));
        }
        match node {
            LogicalPlan::Filter(mut filter) if !filter.having => {
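                // Found the innermost non-HAVING filter: AND the extra predicate onto it.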
                filter.predicate = filter.predicate.and(self.extra_filter.clone());
                self.is_rewritten = true;
                Ok(Transformed::yes(LogicalPlan::Filter(filter)))
            }
            LogicalPlan::TableScan(_) => {
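                // No filter in the plan: add one directly above the table scan.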
                let filter =
                    datafusion_expr::Filter::try_new(self.extra_filter.clone(), Arc::new(node))?;
                self.is_rewritten = true;
                Ok(Transformed::yes(LogicalPlan::Filter(filter)))
            }
            _ => Ok(Transformed::no(node)),
        }
    }
}

#[cfg(test)]
mod test {
    use std::sync::Arc;

    use datafusion_common::tree_node::TreeNode as _;
    use datatypes::prelude::ConcreteDataType;
    use datatypes::schema::{ColumnSchema, Schema};
    use pretty_assertions::assert_eq;
    use query::query_engine::DefaultSerializer;
    use session::context::QueryContext;
    use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};

    use super::*;
    use crate::test_utils::create_test_query_engine;

    #[tokio::test]
    async fn test_sql_plan_convert() {
        let query_engine = create_test_query_engine();
        let ctx = QueryContext::arc();
        let old = r#"SELECT "NUMBER" FROM "UPPERCASE_NUMBERS_WITH_TS""#;
        let new = sql_to_df_plan(ctx.clone(), query_engine.clone(), old, false)
            .await
            .unwrap();
        let new_sql = df_plan_to_sql(&new).unwrap();

        assert_eq!(
            r#"SELECT `UPPERCASE_NUMBERS_WITH_TS`.`NUMBER` FROM `UPPERCASE_NUMBERS_WITH_TS`"#,
            new_sql
        );
    }

    #[tokio::test]
    async fn test_add_filter() {
        let testcases = vec![
            (
                "SELECT number FROM numbers_with_ts GROUP BY number",
                "SELECT numbers_with_ts.number FROM numbers_with_ts WHERE (number > 4) GROUP BY numbers_with_ts.number"
            ),
            (
                "SELECT number FROM numbers_with_ts WHERE number < 2 OR number >10",
                "SELECT numbers_with_ts.number FROM numbers_with_ts WHERE ((numbers_with_ts.number < 2) OR (numbers_with_ts.number > 10)) AND (number > 4)"
            ),
            (
                "SELECT date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window",
                "SELECT date_bin('5 minutes', numbers_with_ts.ts) AS time_window FROM numbers_with_ts WHERE (number > 4) GROUP BY date_bin('5 minutes', numbers_with_ts.ts)"
            ),
            (
                "SELECT number, time_window FROM (SELECT number, date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window, number);",
                "SELECT numbers_with_ts.number, time_window FROM (SELECT numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts) AS time_window FROM numbers_with_ts WHERE (number > 4) GROUP BY date_bin('5 minutes', numbers_with_ts.ts), numbers_with_ts.number)"
            ),
            (
                "SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) GROUP BY number, time_window, bucket_name;",
                "SELECT sum(numbers_with_ts.number), numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts) AS time_window, bucket_name FROM (SELECT numbers_with_ts.number, numbers_with_ts.ts, CASE WHEN (numbers_with_ts.number < 5) THEN 'bucket_0_5' WHEN (numbers_with_ts.number >= 5) THEN 'bucket_5_inf' END AS bucket_name FROM numbers_with_ts WHERE (number > 4)) GROUP BY numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts), bucket_name"
            ),
            (
                "SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) as cte WHERE number > 1 GROUP BY number, time_window, bucket_name;",
                "SELECT sum(cte.number), cte.number, date_bin('5 minutes', cte.ts) AS time_window, cte.bucket_name FROM (SELECT numbers_with_ts.number, numbers_with_ts.ts, CASE WHEN (numbers_with_ts.number < 5) THEN 'bucket_0_5' WHEN (numbers_with_ts.number >= 5) THEN 'bucket_5_inf' END AS bucket_name FROM numbers_with_ts WHERE (number > 4)) AS cte WHERE (cte.number > 1) GROUP BY cte.number, date_bin('5 minutes', cte.ts), cte.bucket_name"
            ),
        ];
        use datafusion_expr::{col, lit};
        let query_engine = create_test_query_engine();
        let ctx = QueryContext::arc();

        for (before, after) in testcases {
            let sql = before;
            let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, false)
                .await
                .unwrap();

            let mut add_filter = AddFilterRewriter::new(col("number").gt(lit(4u32)));
            let plan = plan.rewrite(&mut add_filter).unwrap().data;
            let new_sql = df_plan_to_sql(&plan).unwrap();
            assert_eq!(after, new_sql);
        }
    }

    #[tokio::test]
    async fn test_add_auto_column_rewriter() {
        let testcases = vec![
            (
                "SELECT number FROM numbers_with_ts",
                Ok("SELECT numbers_with_ts.number, now() AS ts FROM numbers_with_ts"),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        "ts",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                ],
            ),
            (
                "SELECT number FROM numbers_with_ts",
                Ok("SELECT numbers_with_ts.number, CAST('1970-01-01 00:00:00' AS TIMESTAMP) AS __ts_placeholder FROM numbers_with_ts"),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        AUTO_CREATED_PLACEHOLDER_TS_COL,
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                ],
            ),
            (
                "SELECT number, ts FROM numbers_with_ts",
                Ok("SELECT numbers_with_ts.number, numbers_with_ts.ts FROM numbers_with_ts"),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        "ts",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                ],
            ),
            (
                "SELECT number FROM numbers_with_ts",
                Ok("SELECT numbers_with_ts.number, now() AS update_at, CAST('1970-01-01 00:00:00' AS TIMESTAMP) AS __ts_placeholder FROM numbers_with_ts"),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        "update_at",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    ),
                    ColumnSchema::new(
                        AUTO_CREATED_PLACEHOLDER_TS_COL,
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                ],
            ),
            (
                "SELECT number, ts FROM numbers_with_ts",
                Ok("SELECT numbers_with_ts.number, numbers_with_ts.ts AS update_at, CAST('1970-01-01 00:00:00' AS TIMESTAMP) AS __ts_placeholder FROM numbers_with_ts"),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        "update_at",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    ),
                    ColumnSchema::new(
                        AUTO_CREATED_PLACEHOLDER_TS_COL,
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                ],
            ),
            (
                "SELECT number, ts FROM numbers_with_ts",
                Ok("SELECT numbers_with_ts.number, numbers_with_ts.ts, now() AS update_atat FROM numbers_with_ts"),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        "ts",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                    ColumnSchema::new(
                        "update_atat",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    ),
                ],
            ),
            (
                "SELECT number, ts FROM numbers_with_ts",
                Err("Expect the last column in table to be timestamp column, found column atat with type Int8"),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        "ts",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                    ColumnSchema::new("atat", ConcreteDataType::int8_datatype(), false),
                ],
            ),
            (
                "SELECT number FROM numbers_with_ts",
                Err("Expect the second last column in the table to be timestamp column, found column ts with type Int8"),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new("ts", ConcreteDataType::int8_datatype(), false),
                    ColumnSchema::new(
                        "atat",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                ],
            ),
        ];

        let query_engine = create_test_query_engine();
        let ctx = QueryContext::arc();
        for (before, after, column_schemas) in testcases {
            let schema = Arc::new(Schema::new(column_schemas));
            let mut add_auto_column_rewriter = AddAutoColumnRewriter::new(schema);

            let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), before, false)
                .await
                .unwrap();
            let new_sql = (|| {
                let plan = plan
                    .rewrite(&mut add_auto_column_rewriter)
                    .map_err(|e| e.to_string())?
                    .data;
                df_plan_to_sql(&plan).map_err(|e| e.to_string())
            })();
            match (after, new_sql.clone()) {
                (Ok(after), Ok(new_sql)) => assert_eq!(after, new_sql),
                (Err(expected), Err(real_err_msg)) => assert!(
                    real_err_msg.contains(expected),
                    "expected: {expected}, real: {real_err_msg}"
                ),
                _ => panic!("expected: {:?}, real: {:?}", after, new_sql),
            }
        }
    }

    #[tokio::test]
    async fn test_find_group_by_exprs() {
        let testcases = vec![
            (
                "SELECT arrow_cast(date_bin(INTERVAL '1 MINS', numbers_with_ts.ts), 'Timestamp(Second, None)') AS ts FROM numbers_with_ts GROUP BY ts;",
                vec!["ts"]
            ),
            (
                "SELECT number FROM numbers_with_ts GROUP BY number",
                vec!["number"]
            ),
            (
                "SELECT date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window",
                vec!["time_window"]
            ),
            (
                "SELECT number, time_window FROM (SELECT number, date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window, number);",
                vec!["time_window", "number"]
            ),
            (
                "SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) GROUP BY number, time_window, bucket_name;",
                vec!["number", "time_window", "bucket_name"]
            ),
            (
                "SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) as cte GROUP BY number, time_window, bucket_name;",
                vec!["number", "time_window", "bucket_name"]
            ),
        ];

        let query_engine = create_test_query_engine();
        let ctx = QueryContext::arc();
        for (sql, expected) in testcases {
            let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, false)
                .await
                .unwrap();
            let mut group_by_exprs = FindGroupByFinalName::default();
            plan.visit(&mut group_by_exprs).unwrap();
            let expected: HashSet<String> = expected.into_iter().map(|s| s.to_string()).collect();
            assert_eq!(
                expected,
                group_by_exprs.get_group_expr_names().unwrap_or_default()
            );
        }
    }

    #[tokio::test]
    async fn test_null_cast() {
        let query_engine = create_test_query_engine();
        let ctx = QueryContext::arc();
        let sql = "SELECT NULL::DOUBLE FROM numbers_with_ts";
        let plan = sql_to_df_plan(ctx, query_engine.clone(), sql, false)
            .await
            .unwrap();

        let _sub_plan = DFLogicalSubstraitConvertor {}
            .encode(&plan, DefaultSerializer)
            .unwrap();
    }
}