flow/batching_mode/utils.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Some utilities to help with batching mode.

use std::collections::{BTreeSet, HashSet};
use std::sync::Arc;

use catalog::CatalogManagerRef;
use common_error::ext::BoxedError;
use common_telemetry::debug;
use datafusion::error::Result as DfResult;
use datafusion::logical_expr::Expr;
use datafusion::sql::unparser::Unparser;
use datafusion_common::tree_node::{
    Transformed, TreeNode as _, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor,
};
use datafusion_common::{DFSchema, DataFusionError, ScalarValue};
use datafusion_expr::{Distinct, LogicalPlan, Projection};
use datatypes::schema::SchemaRef;
use query::QueryEngineRef;
use query::parser::{DEFAULT_LOOKBACK_STRING, PromQuery, QueryLanguageParser, QueryStatement};
use session::context::QueryContextRef;
use snafu::{OptionExt, ResultExt, ensure};
use sql::parser::{ParseOptions, ParserContext};
use sql::statements::statement::Statement;
use sql::statements::tql::Tql;
use table::TableRef;

use crate::adapter::AUTO_CREATED_PLACEHOLDER_TS_COL;
use crate::df_optimizer::apply_df_optimizer;
use crate::error::{DatafusionSnafu, ExternalSnafu, InvalidQuerySnafu, TableNotFoundSnafu};
use crate::{Error, TableName};

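/// Fetch the table handle and its DataFusion schema for the given table name.
///
/// A minimal usage sketch (illustrative, not a doctest; assumes a
/// `CatalogManagerRef` is already at hand, and that `TableName` is the
/// `[catalog, schema, table]` triple this function indexes into):
///
/// ```ignore
/// let table_name: TableName =
///     ["greptime".to_string(), "public".to_string(), "numbers_with_ts".to_string()];
/// let (table, df_schema) = get_table_info_df_schema(catalog_manager, table_name).await?;
/// ```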
pub async fn get_table_info_df_schema(
    catalog_mr: CatalogManagerRef,
    table_name: TableName,
) -> Result<(TableRef, Arc<DFSchema>), Error> {
    let full_table_name = table_name.clone().join(".");
    let table = catalog_mr
        .table(&table_name[0], &table_name[1], &table_name[2], None)
        .await
        .map_err(BoxedError::new)
        .context(ExternalSnafu)?
        .context(TableNotFoundSnafu {
            name: &full_table_name,
        })?;
    let table_info = table.table_info();

    let schema = table_info.meta.schema.clone();

    let df_schema: Arc<DFSchema> = Arc::new(
        schema
            .arrow_schema()
            .clone()
            .try_into()
            .with_context(|_| DatafusionSnafu {
                context: format!(
                    "Failed to convert arrow schema to datafusion schema, arrow_schema={:?}",
                    schema.arrow_schema()
                ),
            })?,
    );
    Ok((table, df_schema))
}

/// Convert a SQL string into a DataFusion logical plan.
/// Also supports TQL (but only EVAL, not EXPLAIN or ANALYZE).
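///
/// A minimal usage sketch (illustrative, not a doctest; `create_test_query_engine`
/// is the helper from this crate's test utils):
///
/// ```ignore
/// let query_engine = create_test_query_engine();
/// let ctx = QueryContext::arc();
/// let plan =
///     sql_to_df_plan(ctx, query_engine, "SELECT number FROM numbers_with_ts", false).await?;
/// ```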
pub async fn sql_to_df_plan(
    query_ctx: QueryContextRef,
    engine: QueryEngineRef,
    sql: &str,
    optimize: bool,
) -> Result<LogicalPlan, Error> {
    let stmts =
        ParserContext::create_with_dialect(sql, query_ctx.sql_dialect(), ParseOptions::default())
            .map_err(BoxedError::new)
            .context(ExternalSnafu)?;

    ensure!(
        stmts.len() == 1,
        InvalidQuerySnafu {
            reason: format!("Expect only one statement, found {}", stmts.len())
        }
    );
    let stmt = &stmts[0];
    let query_stmt = match stmt {
        Statement::Tql(tql) => match tql {
            Tql::Eval(eval) => {
                let eval = eval.clone();
                let promql = PromQuery {
                    start: eval.start,
                    end: eval.end,
                    step: eval.step,
                    query: eval.query,
                    lookback: eval
                        .lookback
                        .unwrap_or_else(|| DEFAULT_LOOKBACK_STRING.to_string()),
                    alias: eval.alias.clone(),
                };

                QueryLanguageParser::parse_promql(&promql, &query_ctx)
                    .map_err(BoxedError::new)
                    .context(ExternalSnafu)?
            }
            _ => InvalidQuerySnafu {
                reason: format!("TQL statement {tql:?} is not supported, expect only TQL EVAL"),
            }
            .fail()?,
        },
        _ => QueryStatement::Sql(stmt.clone()),
    };
    let plan = engine
        .planner()
        .plan(&query_stmt, query_ctx.clone())
        .await
        .map_err(BoxedError::new)
        .context(ExternalSnafu)?;

    let plan = if optimize {
        apply_df_optimizer(plan, &query_ctx).await?
    } else {
        plan
    };
    Ok(plan)
}

/// Generate a plan that matches the schema of the sink table
/// from the given SQL, by aliasing columns and adding auto columns.
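///
/// A usage sketch (illustrative; `sink_table_schema` would come from the sink
/// table's info, e.g. `table_info.meta.schema`):
///
/// ```ignore
/// let plan = gen_plan_with_matching_schema(
///     "SELECT number FROM numbers_with_ts",
///     query_ctx,
///     engine,
///     sink_table_schema,
/// )
/// .await?;
/// ```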
pub(crate) async fn gen_plan_with_matching_schema(
    sql: &str,
    query_ctx: QueryContextRef,
    engine: QueryEngineRef,
    sink_table_schema: SchemaRef,
) -> Result<LogicalPlan, Error> {
    let plan = sql_to_df_plan(query_ctx.clone(), engine.clone(), sql, false).await?;

    let mut add_auto_column = ColumnMatcherRewriter::new(sink_table_schema);
    let plan = plan
        .clone()
        .rewrite(&mut add_auto_column)
        .with_context(|_| DatafusionSnafu {
            context: format!("Failed to rewrite plan:\n {}\n", plan),
        })?
        .data;
    Ok(plan)
}

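/// Unparse a logical plan back into a SQL string.
///
/// A round-trip sketch (illustrative, not a doctest; mirrors `test_sql_plan_convert` below):
///
/// ```ignore
/// let plan = sql_to_df_plan(ctx, engine, r#"SELECT "NUMBER" FROM "UPPERCASE_NUMBERS_WITH_TS""#, false).await?;
/// assert_eq!(
///     r#"SELECT `UPPERCASE_NUMBERS_WITH_TS`.`NUMBER` FROM `UPPERCASE_NUMBERS_WITH_TS`"#,
///     df_plan_to_sql(&plan)?
/// );
/// ```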
pub fn df_plan_to_sql(plan: &LogicalPlan) -> Result<String, Error> {
    /// A dialect that forces identifiers to be quoted when they contain uppercase characters
    struct ForceQuoteIdentifiers;
    impl datafusion::sql::unparser::dialect::Dialect for ForceQuoteIdentifiers {
        fn identifier_quote_style(&self, identifier: &str) -> Option<char> {
            if identifier.to_lowercase() != identifier {
                Some('`')
            } else {
                None
            }
        }
    }
    let unparser = Unparser::new(&ForceQuoteIdentifiers);
    // first make all columns qualified
    let sql = unparser
        .plan_to_sql(plan)
        .with_context(|_e| DatafusionSnafu {
            context: format!("Failed to unparse logical plan {plan:?}"),
        })?;
    Ok(sql.to_string())
}

/// Helper to find the innermost GROUP BY exprs in a plan; returns `None` if there is no GROUP BY expr.
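///
/// A usage sketch (illustrative, mirroring `test_find_group_by_exprs` below):
///
/// ```ignore
/// let mut finder = FindGroupByFinalName::default();
/// plan.visit(&mut finder)?;
/// let group_by_names = finder.get_group_expr_names();
/// ```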
#[derive(Debug, Clone, Default)]
pub struct FindGroupByFinalName {
    group_exprs: Option<HashSet<datafusion_expr::Expr>>,
}

impl FindGroupByFinalName {
    pub fn get_group_expr_names(&self) -> Option<HashSet<String>> {
        self.group_exprs
            .as_ref()
            .map(|exprs| exprs.iter().map(|expr| expr.qualified_name().1).collect())
    }
}

impl TreeNodeVisitor<'_> for FindGroupByFinalName {
    type Node = LogicalPlan;

    fn f_down(&mut self, node: &Self::Node) -> datafusion_common::Result<TreeNodeRecursion> {
        if let LogicalPlan::Aggregate(aggregate) = node {
            self.group_exprs = Some(aggregate.group_expr.iter().cloned().collect());
            debug!(
                "FindGroupByFinalName: Get Group by exprs from Aggregate: {:?}",
                self.group_exprs
            );
        } else if let LogicalPlan::Distinct(distinct) = node {
            debug!("FindGroupByFinalName: Distinct: {}", node);
            match distinct {
                Distinct::All(input) => {
                    if let LogicalPlan::TableScan(table_scan) = &**input {
                        // get the columns from the field qualifiers and the projected schema:
                        let len = table_scan.projected_schema.fields().len();
                        let columns = (0..len)
                            .map(|f| {
                                let (qualifier, field) =
                                    table_scan.projected_schema.qualified_field(f);
                                datafusion_common::Column::new(qualifier.cloned(), field.name())
                            })
                            .map(datafusion_expr::Expr::Column);
                        self.group_exprs = Some(columns.collect());
                    } else {
                        self.group_exprs = Some(input.expressions().iter().cloned().collect())
                    }
                }
                Distinct::On(distinct_on) => {
                    self.group_exprs = Some(distinct_on.on_expr.iter().cloned().collect())
                }
            }
            debug!(
                "FindGroupByFinalName: Get Group by exprs from Distinct: {:?}",
                self.group_exprs
            );
        }

        Ok(TreeNodeRecursion::Continue)
    }

    /// Deal with projections on the way up, replacing group exprs with their aliases
    fn f_up(&mut self, node: &Self::Node) -> datafusion_common::Result<TreeNodeRecursion> {
        if let LogicalPlan::Projection(projection) = node {
            for expr in &projection.expr {
                let Some(group_exprs) = &mut self.group_exprs else {
                    return Ok(TreeNodeRecursion::Continue);
                };
                if let datafusion_expr::Expr::Alias(alias) = expr {
                    // if an alias exists, replace the group expr with the aliased expr
                    let mut new_group_exprs = group_exprs.clone();
                    for group_expr in group_exprs.iter() {
                        if group_expr.name_for_alias()? == alias.expr.name_for_alias()? {
                            new_group_exprs.remove(group_expr);
                            new_group_exprs.insert(expr.clone());
                            break;
                        }
                    }
                    *group_exprs = new_group_exprs;
                }
            }
        }
        debug!("Aliased group by exprs: {:?}", self.group_exprs);
        Ok(TreeNodeRecursion::Continue)
    }
}

/// Optionally add to the final SELECT columns like `update_at`, if the sink table has such a
/// column (it doesn't need to have that exact name, it just needs to be an extra timestamp
/// column), and `__ts_placeholder` (this column needs to have exactly this name and be a
/// timestamp), with values like `now()` and `0` respectively.
///
/// It also aliases existing columns to the corresponding columns in the sink table if needed.
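///
/// A rewrite sketch (illustrative; `sink_table_schema` is the sink table's `SchemaRef`):
///
/// ```ignore
/// let mut rewriter = ColumnMatcherRewriter::new(sink_table_schema);
/// let plan = plan.rewrite(&mut rewriter)?.data;
/// let new_sql = df_plan_to_sql(&plan)?;
/// ```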
#[derive(Debug)]
pub struct ColumnMatcherRewriter {
    pub schema: SchemaRef,
    pub is_rewritten: bool,
}

impl ColumnMatcherRewriter {
    pub fn new(schema: SchemaRef) -> Self {
        Self {
            schema,
            is_rewritten: false,
        }
    }

    /// Modify the exprs in place so that they match the schema, adding auto columns as needed.
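    ///
    /// Illustrative cases (mirroring `test_add_auto_column_rewriter` below):
    /// - query `SELECT number` + sink table `(number, ts)`
    ///   => `SELECT number, now() AS ts`
    /// - query `SELECT number` + sink table `(number, update_at, __ts_placeholder)`
    ///   => `SELECT number, now() AS update_at, <timestamp 0> AS __ts_placeholder`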
    fn modify_project_exprs(&mut self, mut exprs: Vec<Expr>) -> DfResult<Vec<Expr>> {
        let all_names = self
            .schema
            .column_schemas()
            .iter()
            .map(|c| c.name.clone())
            .collect::<BTreeSet<_>>();
        // first match by position
        for (idx, expr) in exprs.iter_mut().enumerate() {
            if !all_names.contains(&expr.qualified_name().1)
                && let Some(col_name) = self
                    .schema
                    .column_schemas()
                    .get(idx)
                    .map(|c| c.name.clone())
            {
                // if the data type is mismatched, a later check_execute will error out,
                // so there is no need to check it here; besides, an optimizer pass might
                // be able to cast it, so checking here is unnecessary
                *expr = expr.clone().alias(col_name);
            }
        }

        // add columns if the column counts differ
        let query_col_cnt = exprs.len();
        let table_col_cnt = self.schema.column_schemas().len();
        debug!("query_col_cnt={query_col_cnt}, table_col_cnt={table_col_cnt}");

        let placeholder_ts_expr =
            datafusion::logical_expr::lit(ScalarValue::TimestampMillisecond(Some(0), None))
                .alias(AUTO_CREATED_PLACEHOLDER_TS_COL);

        if query_col_cnt == table_col_cnt {
            // nothing more to add; aliases were already applied above
        } else if query_col_cnt + 1 == table_col_cnt {
            let last_col_schema = self.schema.column_schemas().last().unwrap();

            // if the time index column is auto-created, add it
            if last_col_schema.name == AUTO_CREATED_PLACEHOLDER_TS_COL
                && self.schema.timestamp_index() == Some(table_col_cnt - 1)
            {
                exprs.push(placeholder_ts_expr);
            } else if last_col_schema.data_type.is_timestamp() {
                // it is the `update_at` column
                exprs.push(datafusion::prelude::now().alias(&last_col_schema.name));
            } else {
                // give a helpful error message
                return Err(DataFusionError::Plan(format!(
                    "Expect the last column in table to be timestamp column, found column {} with type {:?}",
                    last_col_schema.name, last_col_schema.data_type
                )));
            }
        } else if query_col_cnt + 2 == table_col_cnt {
            let mut col_iter = self.schema.column_schemas().iter().rev();
            let last_col_schema = col_iter.next().unwrap();
            let second_last_col_schema = col_iter.next().unwrap();
            if second_last_col_schema.data_type.is_timestamp() {
                exprs.push(datafusion::prelude::now().alias(&second_last_col_schema.name));
            } else {
                return Err(DataFusionError::Plan(format!(
                    "Expect the second last column in the table to be timestamp column, found column {} with type {:?}",
                    second_last_col_schema.name, second_last_col_schema.data_type
                )));
            }

            if last_col_schema.name == AUTO_CREATED_PLACEHOLDER_TS_COL
                && self.schema.timestamp_index() == Some(table_col_cnt - 1)
            {
                exprs.push(placeholder_ts_expr);
            } else {
                return Err(DataFusionError::Plan(format!(
                    "Expect timestamp column {}, found {:?}",
                    AUTO_CREATED_PLACEHOLDER_TS_COL, last_col_schema
                )));
            }
        } else {
            return Err(DataFusionError::Plan(format!(
                "Expect the table to have 0, 1 or 2 more columns than the query, found {} query columns {:?}, {} table columns {:?}",
                query_col_cnt,
                exprs,
                table_col_cnt,
                self.schema.column_schemas()
            )));
        }
        Ok(exprs)
    }
}

impl TreeNodeRewriter for ColumnMatcherRewriter {
    type Node = LogicalPlan;
    fn f_down(&mut self, mut node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        if self.is_rewritten {
            return Ok(Transformed::no(node));
        }

        // if it is a DISTINCT ALL, wrap it in a projection
        if let LogicalPlan::Distinct(Distinct::All(_)) = &node {
            let mut exprs = vec![];

            for field in node.schema().fields().iter() {
                exprs.push(Expr::Column(datafusion::common::Column::new_unqualified(
                    field.name(),
                )));
            }

            let projection =
                LogicalPlan::Projection(Projection::try_new(exprs, Arc::new(node.clone()))?);

            node = projection;
        }
        // handle a table scan by wrapping it in a projection
        else if let LogicalPlan::TableScan(table_scan) = node {
            let mut exprs = vec![];

            for field in table_scan.projected_schema.fields().iter() {
                exprs.push(Expr::Column(datafusion::common::Column::new(
                    Some(table_scan.table_name.clone()),
                    field.name(),
                )));
            }

            let projection = LogicalPlan::Projection(Projection::try_new(
                exprs,
                Arc::new(LogicalPlan::TableScan(table_scan)),
            )?);

            node = projection;
        }

        // only rewrite once the outermost projection is found:
        // if the outermost node is a projection, rewrite its exprs directly;
        // if not, wrap it in a projection first
        if let LogicalPlan::Projection(project) = &node {
            let exprs = project.expr.clone();
            let exprs = self.modify_project_exprs(exprs)?;

            self.is_rewritten = true;
            let new_plan =
                node.with_new_exprs(exprs, node.inputs().into_iter().cloned().collect())?;
            Ok(Transformed::yes(new_plan))
        } else {
            // wrap the logical plan in a projection
            let mut exprs = vec![];
            for field in node.schema().fields().iter() {
                exprs.push(Expr::Column(datafusion::common::Column::new_unqualified(
                    field.name(),
                )));
            }
            let exprs = self.modify_project_exprs(exprs)?;
            self.is_rewritten = true;
            let new_plan =
                LogicalPlan::Projection(Projection::try_new(exprs, Arc::new(node.clone()))?);
            Ok(Transformed::yes(new_plan))
        }
    }

    /// We might add new columns, so we need to recompute the schema
    fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        node.recompute_schema().map(Transformed::yes)
    }
}

/// Find the `Filter` node corresponding to the innermost (deepest) `WHERE` and add a new filter expr to it
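///
/// A usage sketch (illustrative, mirroring `test_add_filter` below):
///
/// ```ignore
/// use datafusion_expr::{col, lit};
///
/// let mut add_filter = AddFilterRewriter::new(col("number").gt(lit(4u32)));
/// let plan = plan.rewrite(&mut add_filter)?.data;
/// let new_sql = df_plan_to_sql(&plan)?;
/// ```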
#[derive(Debug)]
pub struct AddFilterRewriter {
    extra_filter: Expr,
    is_rewritten: bool,
}

impl AddFilterRewriter {
    pub fn new(filter: Expr) -> Self {
        Self {
            extra_filter: filter,
            is_rewritten: false,
        }
    }
}

impl TreeNodeRewriter for AddFilterRewriter {
    type Node = LogicalPlan;
    fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        if self.is_rewritten {
            return Ok(Transformed::no(node));
        }
        match node {
            LogicalPlan::Filter(mut filter) => {
                filter.predicate = filter.predicate.and(self.extra_filter.clone());
                self.is_rewritten = true;
                Ok(Transformed::yes(LogicalPlan::Filter(filter)))
            }
            LogicalPlan::TableScan(_) => {
                // add a new filter
                let filter =
                    datafusion_expr::Filter::try_new(self.extra_filter.clone(), Arc::new(node))?;
                self.is_rewritten = true;
                Ok(Transformed::yes(LogicalPlan::Filter(filter)))
            }
            _ => Ok(Transformed::no(node)),
        }
    }
}

#[cfg(test)]
mod test {
    use std::sync::Arc;

    use datafusion_common::tree_node::TreeNode as _;
    use datatypes::prelude::ConcreteDataType;
    use datatypes::schema::{ColumnSchema, Schema};
    use pretty_assertions::assert_eq;
    use query::query_engine::DefaultSerializer;
    use session::context::QueryContext;
    use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};

    use super::*;
    use crate::test_utils::create_test_query_engine;

    /// Test that uppercase identifiers are handled correctly (with quotes)
    #[tokio::test]
    async fn test_sql_plan_convert() {
        let query_engine = create_test_query_engine();
        let ctx = QueryContext::arc();
        let old = r#"SELECT "NUMBER" FROM "UPPERCASE_NUMBERS_WITH_TS""#;
        let new = sql_to_df_plan(ctx.clone(), query_engine.clone(), old, false)
            .await
            .unwrap();
        let new_sql = df_plan_to_sql(&new).unwrap();

        assert_eq!(
            r#"SELECT `UPPERCASE_NUMBERS_WITH_TS`.`NUMBER` FROM `UPPERCASE_NUMBERS_WITH_TS`"#,
            new_sql
        );
    }

    #[tokio::test]
    async fn test_add_filter() {
        let testcases = vec![
            (
                "SELECT number FROM numbers_with_ts GROUP BY number",
                "SELECT numbers_with_ts.number FROM numbers_with_ts WHERE (number > 4) GROUP BY numbers_with_ts.number",
            ),
            (
                "SELECT number FROM numbers_with_ts WHERE number < 2 OR number >10",
                "SELECT numbers_with_ts.number FROM numbers_with_ts WHERE ((numbers_with_ts.number < 2) OR (numbers_with_ts.number > 10)) AND (number > 4)",
            ),
            (
                "SELECT date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window",
                "SELECT date_bin('5 minutes', numbers_with_ts.ts) AS time_window FROM numbers_with_ts WHERE (number > 4) GROUP BY date_bin('5 minutes', numbers_with_ts.ts)",
            ),
            // subquery
            (
                "SELECT number, time_window FROM (SELECT number, date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window, number);",
                "SELECT numbers_with_ts.number, time_window FROM (SELECT numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts) AS time_window FROM numbers_with_ts WHERE (number > 4) GROUP BY date_bin('5 minutes', numbers_with_ts.ts), numbers_with_ts.number)",
            ),
            // complex subquery without alias
            (
                "SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) GROUP BY number, time_window, bucket_name;",
                "SELECT sum(numbers_with_ts.number), numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts) AS time_window, bucket_name FROM (SELECT numbers_with_ts.number, numbers_with_ts.ts, CASE WHEN (numbers_with_ts.number < 5) THEN 'bucket_0_5' WHEN (numbers_with_ts.number >= 5) THEN 'bucket_5_inf' END AS bucket_name FROM numbers_with_ts WHERE (number > 4)) GROUP BY numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts), bucket_name",
            ),
            // complex subquery with alias
            (
                "SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) as cte WHERE number > 1 GROUP BY number, time_window, bucket_name;",
                "SELECT sum(cte.number), cte.number, date_bin('5 minutes', cte.ts) AS time_window, cte.bucket_name FROM (SELECT numbers_with_ts.number, numbers_with_ts.ts, CASE WHEN (numbers_with_ts.number < 5) THEN 'bucket_0_5' WHEN (numbers_with_ts.number >= 5) THEN 'bucket_5_inf' END AS bucket_name FROM numbers_with_ts WHERE (number > 4)) AS cte WHERE (cte.number > 1) GROUP BY cte.number, date_bin('5 minutes', cte.ts), cte.bucket_name",
            ),
        ];
        use datafusion_expr::{col, lit};
        let query_engine = create_test_query_engine();
        let ctx = QueryContext::arc();

        for (before, after) in testcases {
            let sql = before;
            let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, false)
                .await
                .unwrap();

            let mut add_filter = AddFilterRewriter::new(col("number").gt(lit(4u32)));
            let plan = plan.rewrite(&mut add_filter).unwrap().data;
            let new_sql = df_plan_to_sql(&plan).unwrap();
            assert_eq!(after, new_sql);
        }
    }

    #[tokio::test]
    async fn test_add_auto_column_rewriter() {
        let testcases = vec![
            // add update_at
            (
                "SELECT number FROM numbers_with_ts",
                Ok("SELECT numbers_with_ts.number, now() AS ts FROM numbers_with_ts"),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        "ts",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                ],
            ),
            // add ts placeholder
            (
                "SELECT number FROM numbers_with_ts",
                Ok(
                    "SELECT numbers_with_ts.number, CAST('1970-01-01 00:00:00' AS TIMESTAMP) AS __ts_placeholder FROM numbers_with_ts",
                ),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        AUTO_CREATED_PLACEHOLDER_TS_COL,
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                ],
            ),
            // no modification needed
            (
                "SELECT number, ts FROM numbers_with_ts",
                Ok("SELECT numbers_with_ts.number, numbers_with_ts.ts FROM numbers_with_ts"),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        "ts",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                ],
            ),
            // add update_at and ts placeholder
            (
                "SELECT number FROM numbers_with_ts",
                Ok(
                    "SELECT numbers_with_ts.number, now() AS update_at, CAST('1970-01-01 00:00:00' AS TIMESTAMP) AS __ts_placeholder FROM numbers_with_ts",
                ),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        "update_at",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    ),
                    ColumnSchema::new(
                        AUTO_CREATED_PLACEHOLDER_TS_COL,
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                ],
            ),
            // alias ts to update_at and add ts placeholder
            (
                "SELECT number, ts FROM numbers_with_ts",
                Ok(
                    "SELECT numbers_with_ts.number, numbers_with_ts.ts AS update_at, CAST('1970-01-01 00:00:00' AS TIMESTAMP) AS __ts_placeholder FROM numbers_with_ts",
                ),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        "update_at",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    ),
                    ColumnSchema::new(
                        AUTO_CREATED_PLACEHOLDER_TS_COL,
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                ],
            ),
            // add update_at after the time index column
            (
                "SELECT number, ts FROM numbers_with_ts",
                Ok(
                    "SELECT numbers_with_ts.number, numbers_with_ts.ts, now() AS update_atat FROM numbers_with_ts",
                ),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        "ts",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                    ColumnSchema::new(
                        // the name is irrelevant for the update_at column
                        "update_atat",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    ),
                ],
            ),
            // error: datatype mismatch on the last column
            (
                "SELECT number, ts FROM numbers_with_ts",
                Err(
                    "Expect the last column in table to be timestamp column, found column atat with type Int8",
                ),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        "ts",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                    ColumnSchema::new(
                        // the name is irrelevant for the update_at column
                        "atat",
                        ConcreteDataType::int8_datatype(),
                        false,
                    ),
                ],
            ),
            // error: datatype mismatch on the second last column
            (
                "SELECT number FROM numbers_with_ts",
                Err(
                    "Expect the second last column in the table to be timestamp column, found column ts with type Int8",
                ),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new("ts", ConcreteDataType::int8_datatype(), false),
                    ColumnSchema::new(
                        // the name is irrelevant for the update_at column
                        "atat",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                ],
            ),
        ];

        let query_engine = create_test_query_engine();
        let ctx = QueryContext::arc();
        for (before, after, column_schemas) in testcases {
            let schema = Arc::new(Schema::new(column_schemas));
            let mut add_auto_column_rewriter = ColumnMatcherRewriter::new(schema);

            let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), before, false)
                .await
                .unwrap();
            let new_sql = (|| {
                let plan = plan
                    .rewrite(&mut add_auto_column_rewriter)
                    .map_err(|e| e.to_string())?
                    .data;
                df_plan_to_sql(&plan).map_err(|e| e.to_string())
            })();
            match (after, new_sql.clone()) {
                (Ok(after), Ok(new_sql)) => assert_eq!(after, new_sql),
                (Err(expected), Err(real_err_msg)) => assert!(
                    real_err_msg.contains(expected),
                    "expected: {expected}, real: {real_err_msg}"
                ),
                _ => panic!("expected: {:?}, real: {:?}", after, new_sql),
            }
        }
    }

    #[tokio::test]
    async fn test_find_group_by_exprs() {
        let testcases = vec![
            (
                "SELECT arrow_cast(date_bin(INTERVAL '1 MINS', numbers_with_ts.ts), 'Timestamp(Second, None)') AS ts FROM numbers_with_ts GROUP BY ts;",
                vec!["ts"],
            ),
            (
                "SELECT number FROM numbers_with_ts GROUP BY number",
                vec!["number"],
            ),
            (
                "SELECT date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window",
                vec!["time_window"],
            ),
            // subquery
            (
                "SELECT number, time_window FROM (SELECT number, date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window, number);",
                vec!["time_window", "number"],
            ),
            // complex subquery without alias
            (
                "SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) GROUP BY number, time_window, bucket_name;",
                vec!["number", "time_window", "bucket_name"],
            ),
            // complex subquery with alias
            (
                "SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) as cte GROUP BY number, time_window, bucket_name;",
                vec!["number", "time_window", "bucket_name"],
            ),
        ];

        let query_engine = create_test_query_engine();
        let ctx = QueryContext::arc();
        for (sql, expected) in testcases {
            // keep the plan unoptimized for better readability
            let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, false)
                .await
                .unwrap();
            let mut group_by_exprs = FindGroupByFinalName::default();
            plan.visit(&mut group_by_exprs).unwrap();
            let expected: HashSet<String> = expected.into_iter().map(|s| s.to_string()).collect();
            assert_eq!(
                expected,
                group_by_exprs.get_group_expr_names().unwrap_or_default()
            );
        }
    }

    #[tokio::test]
    async fn test_null_cast() {
        let query_engine = create_test_query_engine();
        let ctx = QueryContext::arc();
        let sql = "SELECT NULL::DOUBLE FROM numbers_with_ts";
        let plan = sql_to_df_plan(ctx, query_engine.clone(), sql, false)
            .await
            .unwrap();

        let _sub_plan = DFLogicalSubstraitConvertor {}
            .encode(&plan, DefaultSerializer)
            .unwrap();
    }
}