flow/batching_mode/utils.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Utilities for helping with batching mode.

use std::collections::{BTreeSet, HashSet};
use std::sync::Arc;

use catalog::CatalogManagerRef;
use common_error::ext::BoxedError;
use common_telemetry::debug;
use datafusion::error::Result as DfResult;
use datafusion::logical_expr::Expr;
use datafusion::sql::unparser::Unparser;
use datafusion_common::tree_node::{
    Transformed, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor,
};
use datafusion_common::{DFSchema, DataFusionError, ScalarValue};
use datafusion_expr::{Distinct, LogicalPlan, Projection};
use datatypes::schema::SchemaRef;
use query::parser::{PromQuery, QueryLanguageParser, QueryStatement, DEFAULT_LOOKBACK_STRING};
use query::QueryEngineRef;
use session::context::QueryContextRef;
use snafu::{ensure, OptionExt, ResultExt};
use sql::parser::{ParseOptions, ParserContext};
use sql::statements::statement::Statement;
use sql::statements::tql::Tql;
use table::metadata::TableInfo;

use crate::adapter::AUTO_CREATED_PLACEHOLDER_TS_COL;
use crate::df_optimizer::apply_df_optimizer;
use crate::error::{DatafusionSnafu, ExternalSnafu, InvalidQuerySnafu, TableNotFoundSnafu};
use crate::{Error, TableName};

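/// Fetch the [`TableInfo`] of the given table from the catalog and convert its
/// Arrow schema into a DataFusion [`DFSchema`].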
pub async fn get_table_info_df_schema(
    catalog_mr: CatalogManagerRef,
    table_name: TableName,
) -> Result<(Arc<TableInfo>, Arc<DFSchema>), Error> {
    let full_table_name = table_name.clone().join(".");
    let table = catalog_mr
        .table(&table_name[0], &table_name[1], &table_name[2], None)
        .await
        .map_err(BoxedError::new)
        .context(ExternalSnafu)?
        .context(TableNotFoundSnafu {
            name: &full_table_name,
        })?;
    let table_info = table.table_info().clone();

    let schema = table_info.meta.schema.clone();

    let df_schema: Arc<DFSchema> = Arc::new(
        schema
            .arrow_schema()
            .clone()
            .try_into()
            .with_context(|_| DatafusionSnafu {
                context: format!(
                    "Failed to convert arrow schema to datafusion schema, arrow_schema={:?}",
                    schema.arrow_schema()
                ),
            })?,
    );
    Ok((table_info, df_schema))
}

/// Convert a SQL string to a DataFusion logical plan.
/// Also supports TQL (but only `TQL EVAL`, not `EXPLAIN` or `ANALYZE`).
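///
/// A minimal usage sketch (not compiled here), assuming `ctx` is a
/// [`QueryContextRef`] and `engine` is a [`QueryEngineRef`]:
///
/// ```ignore
/// // plan a query without running the flow's optimizer passes
/// let plan = sql_to_df_plan(ctx.clone(), engine.clone(), "SELECT 1", false).await?;
/// let sql = df_plan_to_sql(&plan)?;
/// ```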
pub async fn sql_to_df_plan(
    query_ctx: QueryContextRef,
    engine: QueryEngineRef,
    sql: &str,
    optimize: bool,
) -> Result<LogicalPlan, Error> {
    let stmts =
        ParserContext::create_with_dialect(sql, query_ctx.sql_dialect(), ParseOptions::default())
            .map_err(BoxedError::new)
            .context(ExternalSnafu)?;

    ensure!(
        stmts.len() == 1,
        InvalidQuerySnafu {
            reason: format!("Expect only one statement, found {}", stmts.len())
        }
    );
    let stmt = &stmts[0];
    let query_stmt = match stmt {
        Statement::Tql(tql) => match tql {
            Tql::Eval(eval) => {
                let eval = eval.clone();
                let promql = PromQuery {
                    start: eval.start,
                    end: eval.end,
                    step: eval.step,
                    query: eval.query,
                    lookback: eval
                        .lookback
                        .unwrap_or_else(|| DEFAULT_LOOKBACK_STRING.to_string()),
                };

                QueryLanguageParser::parse_promql(&promql, &query_ctx)
                    .map_err(BoxedError::new)
                    .context(ExternalSnafu)?
            }
            _ => InvalidQuerySnafu {
                reason: format!("TQL statement {tql:?} is not supported, expect only TQL EVAL"),
            }
            .fail()?,
        },
        _ => QueryStatement::Sql(stmt.clone()),
    };
    let plan = engine
        .planner()
        .plan(&query_stmt, query_ctx)
        .await
        .map_err(BoxedError::new)
        .context(ExternalSnafu)?;

    let plan = if optimize {
        apply_df_optimizer(plan).await?
    } else {
        plan
    };
    Ok(plan)
}

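/// Unparse a DataFusion [`LogicalPlan`] back into a SQL string, backtick-quoting
/// any identifier that contains uppercase characters.
///
/// For example (see `test_sql_plan_convert` below), a plan scanning
/// `UPPERCASE_NUMBERS_WITH_TS` unparses to:
///
/// ```sql
/// SELECT `UPPERCASE_NUMBERS_WITH_TS`.`NUMBER` FROM `UPPERCASE_NUMBERS_WITH_TS`
/// ```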
pub fn df_plan_to_sql(plan: &LogicalPlan) -> Result<String, Error> {
    /// A dialect that forces identifiers to be quoted when they contain uppercase characters
    struct ForceQuoteIdentifiers;
    impl datafusion::sql::unparser::dialect::Dialect for ForceQuoteIdentifiers {
        fn identifier_quote_style(&self, identifier: &str) -> Option<char> {
            if identifier.to_lowercase() != identifier {
                Some('`')
            } else {
                None
            }
        }
    }
    let unparser = Unparser::new(&ForceQuoteIdentifiers);
    let sql = unparser
        .plan_to_sql(plan)
        .with_context(|_e| DatafusionSnafu {
            context: format!("Failed to unparse logical plan {plan:?}"),
        })?;
    Ok(sql.to_string())
}

/// Helper to find the innermost group-by exprs in a plan; returns `None` if there is no group-by expr
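///
/// A minimal usage sketch (not compiled here; mirrors `test_find_group_by_exprs` below):
///
/// ```ignore
/// let mut finder = FindGroupByFinalName::default();
/// plan.visit(&mut finder).unwrap();
/// // e.g. Some({"number", "time_window"}) for `... GROUP BY number, time_window`
/// let names = finder.get_group_expr_names();
/// ```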
#[derive(Debug, Clone, Default)]
pub struct FindGroupByFinalName {
    group_exprs: Option<HashSet<datafusion_expr::Expr>>,
}

impl FindGroupByFinalName {
    pub fn get_group_expr_names(&self) -> Option<HashSet<String>> {
        self.group_exprs
            .as_ref()
            .map(|exprs| exprs.iter().map(|expr| expr.qualified_name().1).collect())
    }
}

impl TreeNodeVisitor<'_> for FindGroupByFinalName {
    type Node = LogicalPlan;

    fn f_down(&mut self, node: &Self::Node) -> datafusion_common::Result<TreeNodeRecursion> {
        if let LogicalPlan::Aggregate(aggregate) = node {
            self.group_exprs = Some(aggregate.group_expr.iter().cloned().collect());
            debug!(
                "FindGroupByFinalName: Get Group by exprs from Aggregate: {:?}",
                self.group_exprs
            );
        } else if let LogicalPlan::Distinct(distinct) = node {
            debug!("FindGroupByFinalName: Distinct: {}", node);
            match distinct {
                Distinct::All(input) => {
                    if let LogicalPlan::TableScan(table_scan) = &**input {
                        // get columns from field_qualifier, projection and projected_schema:
                        let len = table_scan.projected_schema.fields().len();
                        let columns = (0..len)
                            .map(|f| {
                                let (qualifier, field) =
                                    table_scan.projected_schema.qualified_field(f);
                                datafusion_common::Column::new(qualifier.cloned(), field.name())
                            })
                            .map(datafusion_expr::Expr::Column);
                        self.group_exprs = Some(columns.collect());
                    } else {
                        self.group_exprs = Some(input.expressions().iter().cloned().collect())
                    }
                }
                Distinct::On(distinct_on) => {
                    self.group_exprs = Some(distinct_on.on_expr.iter().cloned().collect())
                }
            }
            debug!(
                "FindGroupByFinalName: Get Group by exprs from Distinct: {:?}",
                self.group_exprs
            );
        }

        Ok(TreeNodeRecursion::Continue)
    }

    /// Deal with projections on the way back up, replacing group exprs with their aliased forms
    fn f_up(&mut self, node: &Self::Node) -> datafusion_common::Result<TreeNodeRecursion> {
        if let LogicalPlan::Projection(projection) = node {
            for expr in &projection.expr {
                let Some(group_exprs) = &mut self.group_exprs else {
                    return Ok(TreeNodeRecursion::Continue);
                };
                if let datafusion_expr::Expr::Alias(alias) = expr {
                    // if an alias exists, replace the group expr with the aliased expr
                    let mut new_group_exprs = group_exprs.clone();
                    for group_expr in group_exprs.iter() {
                        if group_expr.name_for_alias()? == alias.expr.name_for_alias()? {
                            new_group_exprs.remove(group_expr);
                            new_group_exprs.insert(expr.clone());
                            break;
                        }
                    }
                    *group_exprs = new_group_exprs;
                }
            }
        }
        debug!("Aliased group by exprs: {:?}", self.group_exprs);
        Ok(TreeNodeRecursion::Continue)
    }
}

/// Add auto columns to the final select, like `update_at`
/// (which doesn't necessarily need that exact name, it just needs to be an extra timestamp column)
/// and `__ts_placeholder` (this column must have exactly this name and be a timestamp),
/// with values like `now()` and `0` respectively.
///
/// It also aliases existing columns to the sink table's columns if needed.
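///
/// A minimal usage sketch (not compiled here; mirrors `test_add_auto_column_rewriter`
/// below), where `sink_table_schema` is the sink table's [`SchemaRef`]:
///
/// ```ignore
/// let mut rewriter = AddAutoColumnRewriter::new(sink_table_schema);
/// let new_plan = plan.rewrite(&mut rewriter)?.data;
/// ```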
#[derive(Debug)]
pub struct AddAutoColumnRewriter {
    pub schema: SchemaRef,
    pub is_rewritten: bool,
}

impl AddAutoColumnRewriter {
    pub fn new(schema: SchemaRef) -> Self {
        Self {
            schema,
            is_rewritten: false,
        }
    }
}

impl TreeNodeRewriter for AddAutoColumnRewriter {
    type Node = LogicalPlan;
    fn f_down(&mut self, mut node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        if self.is_rewritten {
            return Ok(Transformed::no(node));
        }

        // if it's a DISTINCT ALL, wrap it in a projection
        if let LogicalPlan::Distinct(Distinct::All(_)) = &node {
            let mut exprs = vec![];

            for field in node.schema().fields().iter() {
                exprs.push(Expr::Column(datafusion::common::Column::new_unqualified(
                    field.name(),
                )));
            }

            let projection =
                LogicalPlan::Projection(Projection::try_new(exprs, Arc::new(node.clone()))?);

            node = projection;
        }
        // handle a table scan by wrapping it in a projection
        else if let LogicalPlan::TableScan(table_scan) = node {
            let mut exprs = vec![];

            for field in table_scan.projected_schema.fields().iter() {
                exprs.push(Expr::Column(datafusion::common::Column::new(
                    Some(table_scan.table_name.clone()),
                    field.name(),
                )));
            }

            let projection = LogicalPlan::Projection(Projection::try_new(
                exprs,
                Arc::new(LogicalPlan::TableScan(table_scan)),
            )?);

            node = projection;
        }

        // only rewrite once the outermost projection is found
        let mut exprs = if let LogicalPlan::Projection(project) = &node {
            project.expr.clone()
        } else {
            return Ok(Transformed::no(node));
        };

        let all_names = self
            .schema
            .column_schemas()
            .iter()
            .map(|c| c.name.clone())
            .collect::<BTreeSet<_>>();
        // first match by position
        for (idx, expr) in exprs.iter_mut().enumerate() {
            if !all_names.contains(&expr.qualified_name().1) {
                if let Some(col_name) = self
                    .schema
                    .column_schemas()
                    .get(idx)
                    .map(|c| c.name.clone())
                {
                    // if the data types mismatch, a later check_execute will error out,
                    // so there is no need to check here; besides, an optimizer pass might
                    // be able to cast it, making a check here unnecessary
                    *expr = expr.clone().alias(col_name);
                }
            }
        }

        // add columns if the column counts differ
        let query_col_cnt = exprs.len();
        let table_col_cnt = self.schema.column_schemas().len();
        debug!("query_col_cnt={query_col_cnt}, table_col_cnt={table_col_cnt}");

        let placeholder_ts_expr =
            datafusion::logical_expr::lit(ScalarValue::TimestampMillisecond(Some(0), None))
                .alias(AUTO_CREATED_PLACEHOLDER_TS_COL);

        if query_col_cnt == table_col_cnt {
            // still need to add aliases, see below
        } else if query_col_cnt + 1 == table_col_cnt {
            let last_col_schema = self.schema.column_schemas().last().unwrap();

            // if the time index column is auto created, add it
            if last_col_schema.name == AUTO_CREATED_PLACEHOLDER_TS_COL
                && self.schema.timestamp_index() == Some(table_col_cnt - 1)
            {
                exprs.push(placeholder_ts_expr);
            } else if last_col_schema.data_type.is_timestamp() {
                // it is the update_at column
                exprs.push(datafusion::prelude::now().alias(&last_col_schema.name));
            } else {
                // helpful error message
                return Err(DataFusionError::Plan(format!(
                    "Expect the last column in table to be timestamp column, found column {} with type {:?}",
                    last_col_schema.name,
                    last_col_schema.data_type
                )));
            }
        } else if query_col_cnt + 2 == table_col_cnt {
            let mut col_iter = self.schema.column_schemas().iter().rev();
            let last_col_schema = col_iter.next().unwrap();
            let second_last_col_schema = col_iter.next().unwrap();
            if second_last_col_schema.data_type.is_timestamp() {
                exprs.push(datafusion::prelude::now().alias(&second_last_col_schema.name));
            } else {
                return Err(DataFusionError::Plan(format!(
                    "Expect the second last column in the table to be timestamp column, found column {} with type {:?}",
                    second_last_col_schema.name,
                    second_last_col_schema.data_type
                )));
            }

            if last_col_schema.name == AUTO_CREATED_PLACEHOLDER_TS_COL
                && self.schema.timestamp_index() == Some(table_col_cnt - 1)
            {
                exprs.push(placeholder_ts_expr);
            } else {
                return Err(DataFusionError::Plan(format!(
                    "Expect timestamp column {}, found {:?}",
                    AUTO_CREATED_PLACEHOLDER_TS_COL, last_col_schema
                )));
            }
        } else {
            return Err(DataFusionError::Plan(format!(
                "Expect the table to have 0, 1 or 2 more columns than the query, found {} query columns {:?}, {} table columns {:?}",
                query_col_cnt, exprs, table_col_cnt, self.schema.column_schemas()
            )));
        }

        self.is_rewritten = true;
        let new_plan = node.with_new_exprs(exprs, node.inputs().into_iter().cloned().collect())?;
        Ok(Transformed::yes(new_plan))
    }

    /// We might add new columns, so we need to recompute the schema
    fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        node.recompute_schema().map(Transformed::yes)
    }
}

/// Find the `Filter` node corresponding to the innermost (deepest) `WHERE` and AND a new filter expr onto it
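///
/// A minimal usage sketch (not compiled here; mirrors `test_add_filter` below):
///
/// ```ignore
/// use datafusion_expr::{col, lit};
///
/// // AND `number > 4` onto the innermost WHERE, or add a new filter above the scan
/// let mut rewriter = AddFilterRewriter::new(col("number").gt(lit(4u32)));
/// let new_plan = plan.rewrite(&mut rewriter)?.data;
/// ```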
#[derive(Debug)]
pub struct AddFilterRewriter {
    extra_filter: Expr,
    is_rewritten: bool,
}

impl AddFilterRewriter {
    pub fn new(filter: Expr) -> Self {
        Self {
            extra_filter: filter,
            is_rewritten: false,
        }
    }
}

impl TreeNodeRewriter for AddFilterRewriter {
    type Node = LogicalPlan;
    fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        if self.is_rewritten {
            return Ok(Transformed::no(node));
        }
        match node {
            LogicalPlan::Filter(mut filter) if !filter.having => {
                filter.predicate = filter.predicate.and(self.extra_filter.clone());
                self.is_rewritten = true;
                Ok(Transformed::yes(LogicalPlan::Filter(filter)))
            }
            LogicalPlan::TableScan(_) => {
                // add a new filter
                let filter =
                    datafusion_expr::Filter::try_new(self.extra_filter.clone(), Arc::new(node))?;
                self.is_rewritten = true;
                Ok(Transformed::yes(LogicalPlan::Filter(filter)))
            }
            _ => Ok(Transformed::no(node)),
        }
    }
}

#[cfg(test)]
mod test {
    use std::sync::Arc;

    use datafusion_common::tree_node::TreeNode as _;
    use datatypes::prelude::ConcreteDataType;
    use datatypes::schema::{ColumnSchema, Schema};
    use pretty_assertions::assert_eq;
    use query::query_engine::DefaultSerializer;
    use session::context::QueryContext;
    use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};

    use super::*;
    use crate::test_utils::create_test_query_engine;

    /// Test that uppercase identifiers are handled correctly (quoted)
    #[tokio::test]
    async fn test_sql_plan_convert() {
        let query_engine = create_test_query_engine();
        let ctx = QueryContext::arc();
        let old = r#"SELECT "NUMBER" FROM "UPPERCASE_NUMBERS_WITH_TS""#;
        let new = sql_to_df_plan(ctx.clone(), query_engine.clone(), old, false)
            .await
            .unwrap();
        let new_sql = df_plan_to_sql(&new).unwrap();

        assert_eq!(
            r#"SELECT `UPPERCASE_NUMBERS_WITH_TS`.`NUMBER` FROM `UPPERCASE_NUMBERS_WITH_TS`"#,
            new_sql
        );
    }

    #[tokio::test]
    async fn test_add_filter() {
        let testcases = vec![
            (
                "SELECT number FROM numbers_with_ts GROUP BY number",
                "SELECT numbers_with_ts.number FROM numbers_with_ts WHERE (number > 4) GROUP BY numbers_with_ts.number"
            ),

            (
                "SELECT number FROM numbers_with_ts WHERE number < 2 OR number >10",
                "SELECT numbers_with_ts.number FROM numbers_with_ts WHERE ((numbers_with_ts.number < 2) OR (numbers_with_ts.number > 10)) AND (number > 4)"
            ),

            (
                "SELECT date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window",
                "SELECT date_bin('5 minutes', numbers_with_ts.ts) AS time_window FROM numbers_with_ts WHERE (number > 4) GROUP BY date_bin('5 minutes', numbers_with_ts.ts)"
            ),

            // subquery
            (
                "SELECT number, time_window FROM (SELECT number, date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window, number);",
                "SELECT numbers_with_ts.number, time_window FROM (SELECT numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts) AS time_window FROM numbers_with_ts WHERE (number > 4) GROUP BY date_bin('5 minutes', numbers_with_ts.ts), numbers_with_ts.number)"
            ),

            // complex subquery without alias
            (
                "SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) GROUP BY number, time_window, bucket_name;",
                "SELECT sum(numbers_with_ts.number), numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts) AS time_window, bucket_name FROM (SELECT numbers_with_ts.number, numbers_with_ts.ts, CASE WHEN (numbers_with_ts.number < 5) THEN 'bucket_0_5' WHEN (numbers_with_ts.number >= 5) THEN 'bucket_5_inf' END AS bucket_name FROM numbers_with_ts WHERE (number > 4)) GROUP BY numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts), bucket_name"
            ),

            // complex subquery with alias
            (
                "SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) as cte WHERE number > 1 GROUP BY number, time_window, bucket_name;",
                "SELECT sum(cte.number), cte.number, date_bin('5 minutes', cte.ts) AS time_window, cte.bucket_name FROM (SELECT numbers_with_ts.number, numbers_with_ts.ts, CASE WHEN (numbers_with_ts.number < 5) THEN 'bucket_0_5' WHEN (numbers_with_ts.number >= 5) THEN 'bucket_5_inf' END AS bucket_name FROM numbers_with_ts WHERE (number > 4)) AS cte WHERE (cte.number > 1) GROUP BY cte.number, date_bin('5 minutes', cte.ts), cte.bucket_name"
            )
        ];
        use datafusion_expr::{col, lit};
        let query_engine = create_test_query_engine();
        let ctx = QueryContext::arc();

        for (before, after) in testcases {
            let sql = before;
            let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, false)
                .await
                .unwrap();

            let mut add_filter = AddFilterRewriter::new(col("number").gt(lit(4u32)));
            let plan = plan.rewrite(&mut add_filter).unwrap().data;
            let new_sql = df_plan_to_sql(&plan).unwrap();
            assert_eq!(after, new_sql);
        }
    }

    #[tokio::test]
    async fn test_add_auto_column_rewriter() {
        let testcases = vec![
            // add update_at
            (
                "SELECT number FROM numbers_with_ts",
                Ok("SELECT numbers_with_ts.number, now() AS ts FROM numbers_with_ts"),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        "ts",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                ],
            ),
            // add ts placeholder
            (
                "SELECT number FROM numbers_with_ts",
                Ok("SELECT numbers_with_ts.number, CAST('1970-01-01 00:00:00' AS TIMESTAMP) AS __ts_placeholder FROM numbers_with_ts"),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        AUTO_CREATED_PLACEHOLDER_TS_COL,
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                ],
            ),
            // no modification
            (
                "SELECT number, ts FROM numbers_with_ts",
                Ok("SELECT numbers_with_ts.number, numbers_with_ts.ts FROM numbers_with_ts"),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        "ts",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                ],
            ),
            // add update_at and ts placeholder
            (
                "SELECT number FROM numbers_with_ts",
                Ok("SELECT numbers_with_ts.number, now() AS update_at, CAST('1970-01-01 00:00:00' AS TIMESTAMP) AS __ts_placeholder FROM numbers_with_ts"),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        "update_at",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    ),
                    ColumnSchema::new(
                        AUTO_CREATED_PLACEHOLDER_TS_COL,
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                ],
            ),
            // alias ts to update_at and add ts placeholder
            (
                "SELECT number, ts FROM numbers_with_ts",
                Ok("SELECT numbers_with_ts.number, numbers_with_ts.ts AS update_at, CAST('1970-01-01 00:00:00' AS TIMESTAMP) AS __ts_placeholder FROM numbers_with_ts"),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        "update_at",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    ),
                    ColumnSchema::new(
                        AUTO_CREATED_PLACEHOLDER_TS_COL,
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                ],
            ),
            // add update_at after the time index column
            (
                "SELECT number, ts FROM numbers_with_ts",
                Ok("SELECT numbers_with_ts.number, numbers_with_ts.ts, now() AS update_atat FROM numbers_with_ts"),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        "ts",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                    ColumnSchema::new(
                        // the name is irrelevant for the update_at column
                        "update_atat",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    ),
                ],
            ),
            // error: datatype mismatch
            (
                "SELECT number, ts FROM numbers_with_ts",
                Err("Expect the last column in table to be timestamp column, found column atat with type Int8"),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        "ts",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                    ColumnSchema::new(
                        // the name is irrelevant for the update_at column
                        "atat",
                        ConcreteDataType::int8_datatype(),
                        false,
                    ),
                ],
            ),
            // error: datatype mismatch on the second last column
            (
                "SELECT number FROM numbers_with_ts",
                Err("Expect the second last column in the table to be timestamp column, found column ts with type Int8"),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        "ts",
                        ConcreteDataType::int8_datatype(),
                        false,
                    ),
                    ColumnSchema::new(
                        // the name is irrelevant for the update_at column
                        "atat",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                ],
            ),
        ];

        let query_engine = create_test_query_engine();
        let ctx = QueryContext::arc();
        for (before, after, column_schemas) in testcases {
            let schema = Arc::new(Schema::new(column_schemas));
            let mut add_auto_column_rewriter = AddAutoColumnRewriter::new(schema);

            let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), before, false)
                .await
                .unwrap();
            let new_sql = (|| {
                let plan = plan
                    .rewrite(&mut add_auto_column_rewriter)
                    .map_err(|e| e.to_string())?
                    .data;
                df_plan_to_sql(&plan).map_err(|e| e.to_string())
            })();
            match (after, new_sql.clone()) {
                (Ok(after), Ok(new_sql)) => assert_eq!(after, new_sql),
                (Err(expected), Err(real_err_msg)) => assert!(
                    real_err_msg.contains(expected),
                    "expected: {expected}, real: {real_err_msg}"
                ),
                _ => panic!("expected: {:?}, real: {:?}", after, new_sql),
            }
        }
    }

    #[tokio::test]
    async fn test_find_group_by_exprs() {
        let testcases = vec![
            (
                "SELECT arrow_cast(date_bin(INTERVAL '1 MINS', numbers_with_ts.ts), 'Timestamp(Second, None)') AS ts FROM numbers_with_ts GROUP BY ts;",
                vec!["ts"]
            ),
            (
                "SELECT number FROM numbers_with_ts GROUP BY number",
                vec!["number"]
            ),
            (
                "SELECT date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window",
                vec!["time_window"]
            ),
            // subquery
            (
                "SELECT number, time_window FROM (SELECT number, date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window, number);",
                vec!["time_window", "number"]
            ),
            // complex subquery without alias
            (
                "SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) GROUP BY number, time_window, bucket_name;",
                vec!["number", "time_window", "bucket_name"]
            ),
            // complex subquery with alias
            (
                "SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) as cte GROUP BY number, time_window, bucket_name;",
                vec!["number", "time_window", "bucket_name"]
            )
        ];

        let query_engine = create_test_query_engine();
        let ctx = QueryContext::arc();
        for (sql, expected) in testcases {
            // the plan needs to be unoptimized for better readability
            let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, false)
                .await
                .unwrap();
            let mut group_by_exprs = FindGroupByFinalName::default();
            plan.visit(&mut group_by_exprs).unwrap();
            let expected: HashSet<String> = expected.into_iter().map(|s| s.to_string()).collect();
            assert_eq!(
                expected,
                group_by_exprs.get_group_expr_names().unwrap_or_default()
            );
        }
    }

    #[tokio::test]
    async fn test_null_cast() {
        let query_engine = create_test_query_engine();
        let ctx = QueryContext::arc();
        let sql = "SELECT NULL::DOUBLE FROM numbers_with_ts";
        let plan = sql_to_df_plan(ctx, query_engine.clone(), sql, false)
            .await
            .unwrap();

        let _sub_plan = DFLogicalSubstraitConvertor {}
            .encode(&plan, DefaultSerializer)
            .unwrap();
    }
}