flow/batching_mode/utils.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Utilities for helping with batching mode.

use std::collections::{BTreeSet, HashSet};
use std::sync::Arc;

use catalog::CatalogManagerRef;
use common_error::ext::BoxedError;
use common_telemetry::debug;
use datafusion::error::Result as DfResult;
use datafusion::logical_expr::Expr;
use datafusion::sql::unparser::Unparser;
use datafusion_common::tree_node::{
    Transformed, TreeNode as _, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor,
};
use datafusion_common::{DFSchema, DataFusionError, ScalarValue};
use datafusion_expr::{Distinct, LogicalPlan, Projection};
use datatypes::schema::SchemaRef;
use query::QueryEngineRef;
use query::parser::{DEFAULT_LOOKBACK_STRING, PromQuery, QueryLanguageParser, QueryStatement};
use session::context::QueryContextRef;
use snafu::{OptionExt, ResultExt, ensure};
use sql::parser::{ParseOptions, ParserContext};
use sql::statements::statement::Statement;
use sql::statements::tql::Tql;
use table::TableRef;

use crate::adapter::AUTO_CREATED_PLACEHOLDER_TS_COL;
use crate::df_optimizer::apply_df_optimizer;
use crate::error::{DatafusionSnafu, ExternalSnafu, InvalidQuerySnafu, TableNotFoundSnafu};
use crate::{Error, TableName};

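/// Look up the table named by `table_name` in the catalog and return it together
/// with its schema converted to a DataFusion [`DFSchema`].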
pub async fn get_table_info_df_schema(
    catalog_mr: CatalogManagerRef,
    table_name: TableName,
) -> Result<(TableRef, Arc<DFSchema>), Error> {
    let full_table_name = table_name.clone().join(".");
    let table = catalog_mr
        .table(&table_name[0], &table_name[1], &table_name[2], None)
        .await
        .map_err(BoxedError::new)
        .context(ExternalSnafu)?
        .context(TableNotFoundSnafu {
            name: &full_table_name,
        })?;
    let table_info = table.table_info();

    let schema = table_info.meta.schema.clone();

    let df_schema: Arc<DFSchema> = Arc::new(
        schema
            .arrow_schema()
            .clone()
            .try_into()
            .with_context(|_| DatafusionSnafu {
                context: format!(
                    "Failed to convert arrow schema to datafusion schema, arrow_schema={:?}",
                    schema.arrow_schema()
                ),
            })?,
    );
    Ok((table, df_schema))
}

/// Convert SQL to a DataFusion logical plan.
/// Also supports TQL (but only `TQL EVAL`, not `EXPLAIN` or `ANALYZE`).
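///
/// A minimal call sketch (assuming `engine` and `ctx` are available, e.g. from
/// `create_test_query_engine()` and `QueryContext::arc()` as in the tests below):
///
/// ```ignore
/// let plan = sql_to_df_plan(ctx.clone(), engine.clone(), "SELECT number FROM numbers_with_ts", false).await?;
/// ```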
pub async fn sql_to_df_plan(
    query_ctx: QueryContextRef,
    engine: QueryEngineRef,
    sql: &str,
    optimize: bool,
) -> Result<LogicalPlan, Error> {
    let stmts =
        ParserContext::create_with_dialect(sql, query_ctx.sql_dialect(), ParseOptions::default())
            .map_err(BoxedError::new)
            .context(ExternalSnafu)?;

    ensure!(
        stmts.len() == 1,
        InvalidQuerySnafu {
            reason: format!("Expect only one statement, found {}", stmts.len())
        }
    );
    let stmt = &stmts[0];
    let query_stmt = match stmt {
        Statement::Tql(tql) => match tql {
            Tql::Eval(eval) => {
                let eval = eval.clone();
                let promql = PromQuery {
                    start: eval.start,
                    end: eval.end,
                    step: eval.step,
                    query: eval.query,
                    lookback: eval
                        .lookback
                        .unwrap_or_else(|| DEFAULT_LOOKBACK_STRING.to_string()),
                };

                QueryLanguageParser::parse_promql(&promql, &query_ctx)
                    .map_err(BoxedError::new)
                    .context(ExternalSnafu)?
            }
            _ => InvalidQuerySnafu {
                reason: format!("TQL statement {tql:?} is not supported, expect only TQL EVAL"),
            }
            .fail()?,
        },
        _ => QueryStatement::Sql(stmt.clone()),
    };
    let plan = engine
        .planner()
        .plan(&query_stmt, query_ctx.clone())
        .await
        .map_err(BoxedError::new)
        .context(ExternalSnafu)?;

    let plan = if optimize {
        apply_df_optimizer(plan, &query_ctx).await?
    } else {
        plan
    };
    Ok(plan)
}

/// Generate a plan that matches the schema of the sink table
/// from the given SQL, by aliasing columns and adding auto columns.
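///
/// A call sketch (assuming the sink table's [`SchemaRef`] was fetched beforehand,
/// e.g. via [`get_table_info_df_schema`]):
///
/// ```ignore
/// let plan = gen_plan_with_matching_schema(sql, query_ctx, engine, sink_table_schema).await?;
/// ```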
pub(crate) async fn gen_plan_with_matching_schema(
    sql: &str,
    query_ctx: QueryContextRef,
    engine: QueryEngineRef,
    sink_table_schema: SchemaRef,
) -> Result<LogicalPlan, Error> {
    let plan = sql_to_df_plan(query_ctx.clone(), engine.clone(), sql, false).await?;

    let mut add_auto_column = ColumnMatcherRewriter::new(sink_table_schema);
    let plan = plan
        .clone()
        .rewrite(&mut add_auto_column)
        .with_context(|_| DatafusionSnafu {
            context: format!("Failed to rewrite plan:\n {}\n", plan),
        })?
        .data;
    Ok(plan)
}

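/// Unparse a [`LogicalPlan`] back into SQL text, quoting identifiers that contain
/// uppercase characters with backticks.
///
/// A round-trip sketch (assuming a plan produced by [`sql_to_df_plan`]):
///
/// ```ignore
/// let plan = sql_to_df_plan(ctx.clone(), engine.clone(), sql, false).await?;
/// let unparsed = df_plan_to_sql(&plan)?;
/// ```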
pub fn df_plan_to_sql(plan: &LogicalPlan) -> Result<String, Error> {
    /// A dialect that forces identifiers to be quoted when they contain uppercase characters
    struct ForceQuoteIdentifiers;
    impl datafusion::sql::unparser::dialect::Dialect for ForceQuoteIdentifiers {
        fn identifier_quote_style(&self, identifier: &str) -> Option<char> {
            if identifier.to_lowercase() != identifier {
                Some('`')
            } else {
                None
            }
        }
    }
    let unparser = Unparser::new(&ForceQuoteIdentifiers);
    // first make all columns qualified
    let sql = unparser
        .plan_to_sql(plan)
        .with_context(|_e| DatafusionSnafu {
            context: format!("Failed to unparse logical plan {plan:?}"),
        })?;
    Ok(sql.to_string())
}

/// Helper to find the innermost group-by exprs in a plan, returning `None` if there is no group-by expr.
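///
/// A visit sketch (assuming `plan` is a [`LogicalPlan`], mirroring the tests below):
///
/// ```ignore
/// let mut finder = FindGroupByFinalName::default();
/// plan.visit(&mut finder)?;
/// let group_by_names = finder.get_group_expr_names();
/// ```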
#[derive(Debug, Clone, Default)]
pub struct FindGroupByFinalName {
    group_exprs: Option<HashSet<datafusion_expr::Expr>>,
}

impl FindGroupByFinalName {
    pub fn get_group_expr_names(&self) -> Option<HashSet<String>> {
        self.group_exprs
            .as_ref()
            .map(|exprs| exprs.iter().map(|expr| expr.qualified_name().1).collect())
    }
}

impl TreeNodeVisitor<'_> for FindGroupByFinalName {
    type Node = LogicalPlan;

    fn f_down(&mut self, node: &Self::Node) -> datafusion_common::Result<TreeNodeRecursion> {
        if let LogicalPlan::Aggregate(aggregate) = node {
            self.group_exprs = Some(aggregate.group_expr.iter().cloned().collect());
            debug!(
                "FindGroupByFinalName: Get Group by exprs from Aggregate: {:?}",
                self.group_exprs
            );
        } else if let LogicalPlan::Distinct(distinct) = node {
            debug!("FindGroupByFinalName: Distinct: {}", node);
            match distinct {
                Distinct::All(input) => {
                    if let LogicalPlan::TableScan(table_scan) = &**input {
                        // get columns from field_qualifier, projection and projected_schema:
                        let len = table_scan.projected_schema.fields().len();
                        let columns = (0..len)
                            .map(|f| {
                                let (qualifier, field) =
                                    table_scan.projected_schema.qualified_field(f);
                                datafusion_common::Column::new(qualifier.cloned(), field.name())
                            })
                            .map(datafusion_expr::Expr::Column);
                        self.group_exprs = Some(columns.collect());
                    } else {
                        self.group_exprs = Some(input.expressions().iter().cloned().collect())
                    }
                }
                Distinct::On(distinct_on) => {
                    self.group_exprs = Some(distinct_on.on_expr.iter().cloned().collect())
                }
            }
            debug!(
                "FindGroupByFinalName: Get Group by exprs from Distinct: {:?}",
                self.group_exprs
            );
        }

        Ok(TreeNodeRecursion::Continue)
    }

    /// Deal with projections while going back up, replacing group exprs with their aliases
    fn f_up(&mut self, node: &Self::Node) -> datafusion_common::Result<TreeNodeRecursion> {
        if let LogicalPlan::Projection(projection) = node {
            for expr in &projection.expr {
                let Some(group_exprs) = &mut self.group_exprs else {
                    return Ok(TreeNodeRecursion::Continue);
                };
                if let datafusion_expr::Expr::Alias(alias) = expr {
                    // if an alias exists, replace the group expr with the aliased expr
                    let mut new_group_exprs = group_exprs.clone();
                    for group_expr in group_exprs.iter() {
                        if group_expr.name_for_alias()? == alias.expr.name_for_alias()? {
                            new_group_exprs.remove(group_expr);
                            new_group_exprs.insert(expr.clone());
                            break;
                        }
                    }
                    *group_exprs = new_group_exprs;
                }
            }
        }
        debug!("Aliased group by exprs: {:?}", self.group_exprs);
        Ok(TreeNodeRecursion::Continue)
    }
}

/// Optionally add auto columns to the final select, like `update_at` if the sink table
/// has such a column (it doesn't need to have that exact name, it just needs to be an
/// extra timestamp column) and `__ts_placeholder` (this column must have exactly this
/// name and be a timestamp), filled with values like `now()` and `0` respectively.
///
/// It also aliases existing columns to the corresponding columns in the sink table if needed.
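///
/// A rewrite sketch (assuming `sink_table_schema` is the sink table's [`SchemaRef`],
/// as driven by [`gen_plan_with_matching_schema`] above):
///
/// ```ignore
/// let mut rewriter = ColumnMatcherRewriter::new(sink_table_schema);
/// let plan = plan.rewrite(&mut rewriter)?.data;
/// ```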
#[derive(Debug)]
pub struct ColumnMatcherRewriter {
    pub schema: SchemaRef,
    pub is_rewritten: bool,
}

impl ColumnMatcherRewriter {
    pub fn new(schema: SchemaRef) -> Self {
        Self {
            schema,
            is_rewritten: false,
        }
    }

    /// Modify the exprs in place so that they match the schema, adding some auto columns
    fn modify_project_exprs(&mut self, mut exprs: Vec<Expr>) -> DfResult<Vec<Expr>> {
        let all_names = self
            .schema
            .column_schemas()
            .iter()
            .map(|c| c.name.clone())
            .collect::<BTreeSet<_>>();
        // first match by position
        for (idx, expr) in exprs.iter_mut().enumerate() {
            if !all_names.contains(&expr.qualified_name().1)
                && let Some(col_name) = self
                    .schema
                    .column_schemas()
                    .get(idx)
                    .map(|c| c.name.clone())
            {
                // if the data type is mismatched, a later check_execute will error out,
                // hence no need to check it here; besides, the optimizer pass might be
                // able to cast it, so checking here is unnecessary
                *expr = expr.clone().alias(col_name);
            }
        }

        // add columns if the query and table have different column counts
        let query_col_cnt = exprs.len();
        let table_col_cnt = self.schema.column_schemas().len();
        debug!("query_col_cnt={query_col_cnt}, table_col_cnt={table_col_cnt}");

        let placeholder_ts_expr =
            datafusion::logical_expr::lit(ScalarValue::TimestampMillisecond(Some(0), None))
                .alias(AUTO_CREATED_PLACEHOLDER_TS_COL);

        if query_col_cnt == table_col_cnt {
            // still need to add aliases, see the position matching above
        } else if query_col_cnt + 1 == table_col_cnt {
            let last_col_schema = self.schema.column_schemas().last().unwrap();

            // if the time index column is auto-created, add it
            if last_col_schema.name == AUTO_CREATED_PLACEHOLDER_TS_COL
                && self.schema.timestamp_index() == Some(table_col_cnt - 1)
            {
                exprs.push(placeholder_ts_expr);
            } else if last_col_schema.data_type.is_timestamp() {
                // this is the update-at column
                exprs.push(datafusion::prelude::now().alias(&last_col_schema.name));
            } else {
                // helpful error message
                return Err(DataFusionError::Plan(format!(
                    "Expect the last column in table to be timestamp column, found column {} with type {:?}",
                    last_col_schema.name, last_col_schema.data_type
                )));
            }
        } else if query_col_cnt + 2 == table_col_cnt {
            let mut col_iter = self.schema.column_schemas().iter().rev();
            let last_col_schema = col_iter.next().unwrap();
            let second_last_col_schema = col_iter.next().unwrap();
            if second_last_col_schema.data_type.is_timestamp() {
                exprs.push(datafusion::prelude::now().alias(&second_last_col_schema.name));
            } else {
                return Err(DataFusionError::Plan(format!(
                    "Expect the second last column in the table to be timestamp column, found column {} with type {:?}",
                    second_last_col_schema.name, second_last_col_schema.data_type
                )));
            }

            if last_col_schema.name == AUTO_CREATED_PLACEHOLDER_TS_COL
                && self.schema.timestamp_index() == Some(table_col_cnt - 1)
            {
                exprs.push(placeholder_ts_expr);
            } else {
                return Err(DataFusionError::Plan(format!(
                    "Expect timestamp column {}, found {:?}",
                    AUTO_CREATED_PLACEHOLDER_TS_COL, last_col_schema
                )));
            }
        } else {
            return Err(DataFusionError::Plan(format!(
                "Expect the table to have 0, 1 or 2 more columns than the query, found {} query columns {:?}, {} table columns {:?}",
                query_col_cnt,
                exprs,
                table_col_cnt,
                self.schema.column_schemas()
            )));
        }
        Ok(exprs)
    }
}

impl TreeNodeRewriter for ColumnMatcherRewriter {
    type Node = LogicalPlan;
    fn f_down(&mut self, mut node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        if self.is_rewritten {
            return Ok(Transformed::no(node));
        }

        // if it is a DISTINCT ALL, wrap it in a projection
        if let LogicalPlan::Distinct(Distinct::All(_)) = &node {
            let mut exprs = vec![];

            for field in node.schema().fields().iter() {
                exprs.push(Expr::Column(datafusion::common::Column::new_unqualified(
                    field.name(),
                )));
            }

            let projection =
                LogicalPlan::Projection(Projection::try_new(exprs, Arc::new(node.clone()))?);

            node = projection;
        }
        // handle a table scan by wrapping it in a projection
        else if let LogicalPlan::TableScan(table_scan) = node {
            let mut exprs = vec![];

            for field in table_scan.projected_schema.fields().iter() {
                exprs.push(Expr::Column(datafusion::common::Column::new(
                    Some(table_scan.table_name.clone()),
                    field.name(),
                )));
            }

            let projection = LogicalPlan::Projection(Projection::try_new(
                exprs,
                Arc::new(LogicalPlan::TableScan(table_scan)),
            )?);

            node = projection;
        }

        // only rewrite once the outermost projection is found:
        // if the outermost node is a projection, its exprs can be rewritten;
        // if not, wrap the plan in a projection
        if let LogicalPlan::Projection(project) = &node {
            let exprs = project.expr.clone();
            let exprs = self.modify_project_exprs(exprs)?;

            self.is_rewritten = true;
            let new_plan =
                node.with_new_exprs(exprs, node.inputs().into_iter().cloned().collect())?;
            Ok(Transformed::yes(new_plan))
        } else {
            // wrap the logical plan in a projection
            let mut exprs = vec![];
            for field in node.schema().fields().iter() {
                exprs.push(Expr::Column(datafusion::common::Column::new_unqualified(
                    field.name(),
                )));
            }
            let exprs = self.modify_project_exprs(exprs)?;
            self.is_rewritten = true;
            let new_plan =
                LogicalPlan::Projection(Projection::try_new(exprs, Arc::new(node.clone()))?);
            Ok(Transformed::yes(new_plan))
        }
    }

    /// We might add new columns, so we need to recompute the schema
    fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        node.recompute_schema().map(Transformed::yes)
    }
}

/// Find the `Filter` node corresponding to the innermost (deepest) `WHERE` and add a new filter expr to it.
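///
/// A usage sketch (with `col` and `lit` from `datafusion_expr`, mirroring the tests below):
///
/// ```ignore
/// let mut rewriter = AddFilterRewriter::new(col("number").gt(lit(4u32)));
/// let plan = plan.rewrite(&mut rewriter)?.data;
/// ```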
#[derive(Debug)]
pub struct AddFilterRewriter {
    extra_filter: Expr,
    is_rewritten: bool,
}

impl AddFilterRewriter {
    pub fn new(filter: Expr) -> Self {
        Self {
            extra_filter: filter,
            is_rewritten: false,
        }
    }
}

impl TreeNodeRewriter for AddFilterRewriter {
    type Node = LogicalPlan;
    fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        if self.is_rewritten {
            return Ok(Transformed::no(node));
        }
        match node {
            LogicalPlan::Filter(mut filter) => {
                filter.predicate = filter.predicate.and(self.extra_filter.clone());
                self.is_rewritten = true;
                Ok(Transformed::yes(LogicalPlan::Filter(filter)))
            }
            LogicalPlan::TableScan(_) => {
                // add a new filter
                let filter =
                    datafusion_expr::Filter::try_new(self.extra_filter.clone(), Arc::new(node))?;
                self.is_rewritten = true;
                Ok(Transformed::yes(LogicalPlan::Filter(filter)))
            }
            _ => Ok(Transformed::no(node)),
        }
    }
}

#[cfg(test)]
mod test {
    use std::sync::Arc;

    use datafusion_common::tree_node::TreeNode as _;
    use datatypes::prelude::ConcreteDataType;
    use datatypes::schema::{ColumnSchema, Schema};
    use pretty_assertions::assert_eq;
    use query::query_engine::DefaultSerializer;
    use session::context::QueryContext;
    use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};

    use super::*;
    use crate::test_utils::create_test_query_engine;

    /// test that uppercase identifiers are handled correctly (with quoting)
    #[tokio::test]
    async fn test_sql_plan_convert() {
        let query_engine = create_test_query_engine();
        let ctx = QueryContext::arc();
        let old = r#"SELECT "NUMBER" FROM "UPPERCASE_NUMBERS_WITH_TS""#;
        let new = sql_to_df_plan(ctx.clone(), query_engine.clone(), old, false)
            .await
            .unwrap();
        let new_sql = df_plan_to_sql(&new).unwrap();

        assert_eq!(
            r#"SELECT `UPPERCASE_NUMBERS_WITH_TS`.`NUMBER` FROM `UPPERCASE_NUMBERS_WITH_TS`"#,
            new_sql
        );
    }

    #[tokio::test]
    async fn test_add_filter() {
        let testcases = vec![
            (
                "SELECT number FROM numbers_with_ts GROUP BY number",
                "SELECT numbers_with_ts.number FROM numbers_with_ts WHERE (number > 4) GROUP BY numbers_with_ts.number",
            ),
            (
                "SELECT number FROM numbers_with_ts WHERE number < 2 OR number >10",
                "SELECT numbers_with_ts.number FROM numbers_with_ts WHERE ((numbers_with_ts.number < 2) OR (numbers_with_ts.number > 10)) AND (number > 4)",
            ),
            (
                "SELECT date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window",
                "SELECT date_bin('5 minutes', numbers_with_ts.ts) AS time_window FROM numbers_with_ts WHERE (number > 4) GROUP BY date_bin('5 minutes', numbers_with_ts.ts)",
            ),
            // subquery
            (
                "SELECT number, time_window FROM (SELECT number, date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window, number);",
                "SELECT numbers_with_ts.number, time_window FROM (SELECT numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts) AS time_window FROM numbers_with_ts WHERE (number > 4) GROUP BY date_bin('5 minutes', numbers_with_ts.ts), numbers_with_ts.number)",
            ),
            // complex subquery without alias
            (
                "SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) GROUP BY number, time_window, bucket_name;",
                "SELECT sum(numbers_with_ts.number), numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts) AS time_window, bucket_name FROM (SELECT numbers_with_ts.number, numbers_with_ts.ts, CASE WHEN (numbers_with_ts.number < 5) THEN 'bucket_0_5' WHEN (numbers_with_ts.number >= 5) THEN 'bucket_5_inf' END AS bucket_name FROM numbers_with_ts WHERE (number > 4)) GROUP BY numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts), bucket_name",
            ),
            // complex subquery with alias
            (
                "SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) as cte WHERE number > 1 GROUP BY number, time_window, bucket_name;",
                "SELECT sum(cte.number), cte.number, date_bin('5 minutes', cte.ts) AS time_window, cte.bucket_name FROM (SELECT numbers_with_ts.number, numbers_with_ts.ts, CASE WHEN (numbers_with_ts.number < 5) THEN 'bucket_0_5' WHEN (numbers_with_ts.number >= 5) THEN 'bucket_5_inf' END AS bucket_name FROM numbers_with_ts WHERE (number > 4)) AS cte WHERE (cte.number > 1) GROUP BY cte.number, date_bin('5 minutes', cte.ts), cte.bucket_name",
            ),
        ];
        use datafusion_expr::{col, lit};
        let query_engine = create_test_query_engine();
        let ctx = QueryContext::arc();

        for (before, after) in testcases {
            let sql = before;
            let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, false)
                .await
                .unwrap();

            let mut add_filter = AddFilterRewriter::new(col("number").gt(lit(4u32)));
            let plan = plan.rewrite(&mut add_filter).unwrap().data;
            let new_sql = df_plan_to_sql(&plan).unwrap();
            assert_eq!(after, new_sql);
        }
    }

    #[tokio::test]
    async fn test_add_auto_column_rewriter() {
        let testcases = vec![
            // add update_at
            (
                "SELECT number FROM numbers_with_ts",
                Ok("SELECT numbers_with_ts.number, now() AS ts FROM numbers_with_ts"),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        "ts",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                ],
            ),
            // add ts placeholder
            (
                "SELECT number FROM numbers_with_ts",
                Ok(
                    "SELECT numbers_with_ts.number, CAST('1970-01-01 00:00:00' AS TIMESTAMP) AS __ts_placeholder FROM numbers_with_ts",
                ),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        AUTO_CREATED_PLACEHOLDER_TS_COL,
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                ],
            ),
            // no modification
            (
                "SELECT number, ts FROM numbers_with_ts",
                Ok("SELECT numbers_with_ts.number, numbers_with_ts.ts FROM numbers_with_ts"),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        "ts",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                ],
            ),
            // add update_at and ts placeholder
            (
                "SELECT number FROM numbers_with_ts",
                Ok(
                    "SELECT numbers_with_ts.number, now() AS update_at, CAST('1970-01-01 00:00:00' AS TIMESTAMP) AS __ts_placeholder FROM numbers_with_ts",
                ),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        "update_at",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    ),
                    ColumnSchema::new(
                        AUTO_CREATED_PLACEHOLDER_TS_COL,
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                ],
            ),
            // alias ts to update_at and add ts placeholder
            (
                "SELECT number, ts FROM numbers_with_ts",
                Ok(
                    "SELECT numbers_with_ts.number, numbers_with_ts.ts AS update_at, CAST('1970-01-01 00:00:00' AS TIMESTAMP) AS __ts_placeholder FROM numbers_with_ts",
                ),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        "update_at",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    ),
                    ColumnSchema::new(
                        AUTO_CREATED_PLACEHOLDER_TS_COL,
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                ],
            ),
            // add update_at after the time index column
            (
                "SELECT number, ts FROM numbers_with_ts",
                Ok(
                    "SELECT numbers_with_ts.number, numbers_with_ts.ts, now() AS update_atat FROM numbers_with_ts",
                ),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        "ts",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                    ColumnSchema::new(
                        // the name is irrelevant for the update-at column
                        "update_atat",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    ),
                ],
            ),
            // error: datatype mismatch
            (
                "SELECT number, ts FROM numbers_with_ts",
                Err(
                    "Expect the last column in table to be timestamp column, found column atat with type Int8",
                ),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new(
                        "ts",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                    ColumnSchema::new(
                        // the name is irrelevant for the update-at column
                        "atat",
                        ConcreteDataType::int8_datatype(),
                        false,
                    ),
                ],
            ),
            // error: datatype mismatch on the second last column
            (
                "SELECT number FROM numbers_with_ts",
                Err(
                    "Expect the second last column in the table to be timestamp column, found column ts with type Int8",
                ),
                vec![
                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
                    ColumnSchema::new("ts", ConcreteDataType::int8_datatype(), false),
                    ColumnSchema::new(
                        // the name is irrelevant for the update-at column
                        "atat",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                        false,
                    )
                    .with_time_index(true),
                ],
            ),
        ];

        let query_engine = create_test_query_engine();
        let ctx = QueryContext::arc();
        for (before, after, column_schemas) in testcases {
            let schema = Arc::new(Schema::new(column_schemas));
            let mut add_auto_column_rewriter = ColumnMatcherRewriter::new(schema);

            let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), before, false)
                .await
                .unwrap();
            let new_sql = (|| {
                let plan = plan
                    .rewrite(&mut add_auto_column_rewriter)
                    .map_err(|e| e.to_string())?
                    .data;
                df_plan_to_sql(&plan).map_err(|e| e.to_string())
            })();
            match (after, new_sql.clone()) {
                (Ok(after), Ok(new_sql)) => assert_eq!(after, new_sql),
                (Err(expected), Err(real_err_msg)) => assert!(
                    real_err_msg.contains(expected),
                    "expected: {expected}, real: {real_err_msg}"
                ),
                _ => panic!("expected: {:?}, real: {:?}", after, new_sql),
            }
        }
    }

    #[tokio::test]
    async fn test_find_group_by_exprs() {
        let testcases = vec![
            (
                "SELECT arrow_cast(date_bin(INTERVAL '1 MINS', numbers_with_ts.ts), 'Timestamp(Second, None)') AS ts FROM numbers_with_ts GROUP BY ts;",
                vec!["ts"],
            ),
            (
                "SELECT number FROM numbers_with_ts GROUP BY number",
                vec!["number"],
            ),
            (
                "SELECT date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window",
                vec!["time_window"],
            ),
            // subquery
            (
                "SELECT number, time_window FROM (SELECT number, date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window, number);",
                vec!["time_window", "number"],
            ),
            // complex subquery without alias
            (
                "SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) GROUP BY number, time_window, bucket_name;",
                vec!["number", "time_window", "bucket_name"],
            ),
            // complex subquery with alias
            (
                "SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) as cte GROUP BY number, time_window, bucket_name;",
                vec!["number", "time_window", "bucket_name"],
            ),
        ];

        let query_engine = create_test_query_engine();
        let ctx = QueryContext::arc();
        for (sql, expected) in testcases {
            // needs to be unoptimized for better readability
            let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, false)
                .await
                .unwrap();
            let mut group_by_exprs = FindGroupByFinalName::default();
            plan.visit(&mut group_by_exprs).unwrap();
            let expected: HashSet<String> = expected.into_iter().map(|s| s.to_string()).collect();
            assert_eq!(
                expected,
                group_by_exprs.get_group_expr_names().unwrap_or_default()
            );
        }
    }

    #[tokio::test]
    async fn test_null_cast() {
        let query_engine = create_test_query_engine();
        let ctx = QueryContext::arc();
        let sql = "SELECT NULL::DOUBLE FROM numbers_with_ts";
        let plan = sql_to_df_plan(ctx, query_engine.clone(), sql, false)
            .await
            .unwrap();

        let _sub_plan = DFLogicalSubstraitConvertor {}
            .encode(&plan, DefaultSerializer)
            .unwrap();
    }
}