flow/batching_mode/utils.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! some utils for helping with batching mode
16
17use std::collections::{BTreeSet, HashMap, HashSet};
18use std::sync::Arc;
19
20use catalog::CatalogManagerRef;
21use common_error::ext::BoxedError;
22use common_telemetry::debug;
23use datafusion::error::Result as DfResult;
24use datafusion::logical_expr::Expr;
25use datafusion::sql::unparser::Unparser;
26use datafusion_common::tree_node::{
27    Transformed, TreeNode as _, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor,
28};
29use datafusion_common::{DFSchema, DataFusionError, ScalarValue};
30use datafusion_expr::{Distinct, LogicalPlan, Projection};
31use datatypes::schema::{ColumnSchema, SchemaRef};
32use query::QueryEngineRef;
33use query::parser::{DEFAULT_LOOKBACK_STRING, PromQuery, QueryLanguageParser, QueryStatement};
34use session::context::QueryContextRef;
35use snafu::{OptionExt, ResultExt, ensure};
36use sql::parser::{ParseOptions, ParserContext};
37use sql::statements::statement::Statement;
38use sql::statements::tql::Tql;
39use table::TableRef;
40
41use crate::adapter::{AUTO_CREATED_PLACEHOLDER_TS_COL, AUTO_CREATED_UPDATE_AT_TS_COL};
42use crate::df_optimizer::apply_df_optimizer;
43use crate::error::{DatafusionSnafu, ExternalSnafu, InvalidQuerySnafu, TableNotFoundSnafu};
44use crate::{Error, TableName};
45
46pub async fn get_table_info_df_schema(
47    catalog_mr: CatalogManagerRef,
48    table_name: TableName,
49) -> Result<(TableRef, Arc<DFSchema>), Error> {
50    let full_table_name = table_name.clone().join(".");
51    let table = catalog_mr
52        .table(&table_name[0], &table_name[1], &table_name[2], None)
53        .await
54        .map_err(BoxedError::new)
55        .context(ExternalSnafu)?
56        .context(TableNotFoundSnafu {
57            name: &full_table_name,
58        })?;
59    let table_info = table.table_info();
60
61    let schema = table_info.meta.schema.clone();
62
63    let df_schema: Arc<DFSchema> = Arc::new(
64        schema
65            .arrow_schema()
66            .clone()
67            .try_into()
68            .with_context(|_| DatafusionSnafu {
69                context: format!(
70                    "Failed to convert arrow schema to datafusion schema, arrow_schema={:?}",
71                    schema.arrow_schema()
72                ),
73            })?,
74    );
75    Ok((table, df_schema))
76}
77
78/// Convert sql to datafusion logical plan
79/// Also support TQL (but only Eval not Explain or Analyze)
80pub async fn sql_to_df_plan(
81    query_ctx: QueryContextRef,
82    engine: QueryEngineRef,
83    sql: &str,
84    optimize: bool,
85) -> Result<LogicalPlan, Error> {
86    let stmts =
87        ParserContext::create_with_dialect(sql, query_ctx.sql_dialect(), ParseOptions::default())
88            .map_err(BoxedError::new)
89            .context(ExternalSnafu)?;
90
91    ensure!(
92        stmts.len() == 1,
93        InvalidQuerySnafu {
94            reason: format!("Expect only one statement, found {}", stmts.len())
95        }
96    );
97    let stmt = &stmts[0];
98    let query_stmt = match stmt {
99        Statement::Tql(tql) => match tql {
100            Tql::Eval(eval) => {
101                let eval = eval.clone();
102                let promql = PromQuery {
103                    start: eval.start,
104                    end: eval.end,
105                    step: eval.step,
106                    query: eval.query,
107                    lookback: eval
108                        .lookback
109                        .unwrap_or_else(|| DEFAULT_LOOKBACK_STRING.to_string()),
110                    alias: eval.alias.clone(),
111                };
112
113                QueryLanguageParser::parse_promql(&promql, &query_ctx)
114                    .map_err(BoxedError::new)
115                    .context(ExternalSnafu)?
116            }
117            _ => InvalidQuerySnafu {
118                reason: format!("TQL statement {tql:?} is not supported, expect only TQL EVAL"),
119            }
120            .fail()?,
121        },
122        _ => QueryStatement::Sql(stmt.clone()),
123    };
124    let plan = engine
125        .planner()
126        .plan(&query_stmt, query_ctx.clone())
127        .await
128        .map_err(BoxedError::new)
129        .context(ExternalSnafu)?;
130
131    let plan = if optimize {
132        apply_df_optimizer(plan, &query_ctx).await?
133    } else {
134        plan
135    };
136    Ok(plan)
137}
138
139/// Generate a plan that matches the schema of the sink table
140/// from given sql by alias and adding auto columns
141pub(crate) async fn gen_plan_with_matching_schema(
142    sql: &str,
143    query_ctx: QueryContextRef,
144    engine: QueryEngineRef,
145    sink_table_schema: SchemaRef,
146    primary_key_indices: &[usize],
147    allow_partial: bool,
148) -> Result<LogicalPlan, Error> {
149    let plan = sql_to_df_plan(query_ctx.clone(), engine.clone(), sql, false).await?;
150
151    let mut add_auto_column = ColumnMatcherRewriter::new(
152        sink_table_schema,
153        primary_key_indices.to_vec(),
154        allow_partial,
155    );
156    let plan = plan
157        .clone()
158        .rewrite(&mut add_auto_column)
159        .with_context(|_| DatafusionSnafu {
160            context: format!("Failed to rewrite plan:\n {}\n", plan),
161        })?
162        .data;
163    Ok(plan)
164}
165
166pub fn df_plan_to_sql(plan: &LogicalPlan) -> Result<String, Error> {
167    /// A dialect that forces identifiers to be quoted when have uppercase
168    struct ForceQuoteIdentifiers;
169    impl datafusion::sql::unparser::dialect::Dialect for ForceQuoteIdentifiers {
170        fn identifier_quote_style(&self, identifier: &str) -> Option<char> {
171            if identifier.to_lowercase() != identifier {
172                Some('`')
173            } else {
174                None
175            }
176        }
177    }
178    let unparser = Unparser::new(&ForceQuoteIdentifiers);
179    // first make all column qualified
180    let sql = unparser
181        .plan_to_sql(plan)
182        .with_context(|_e| DatafusionSnafu {
183            context: format!("Failed to unparse logical plan {plan:?}"),
184        })?;
185    Ok(sql.to_string())
186}
187
188/// Helper to find the innermost group by expr in schema, return None if no group by expr
189#[derive(Debug, Clone, Default)]
190pub struct FindGroupByFinalName {
191    group_exprs: Option<HashSet<datafusion_expr::Expr>>,
192}
193
194impl FindGroupByFinalName {
195    pub fn get_group_expr_names(&self) -> Option<HashSet<String>> {
196        self.group_exprs
197            .as_ref()
198            .map(|exprs| exprs.iter().map(|expr| expr.qualified_name().1).collect())
199    }
200}
201
impl TreeNodeVisitor<'_> for FindGroupByFinalName {
    type Node = LogicalPlan;

    /// Top-down pass: record group-by expressions from `Aggregate` and
    /// `Distinct` nodes. Deeper nodes are visited later and overwrite
    /// earlier matches, so after the walk `group_exprs` holds the
    /// innermost grouping.
    fn f_down(&mut self, node: &Self::Node) -> datafusion_common::Result<TreeNodeRecursion> {
        if let LogicalPlan::Aggregate(aggregate) = node {
            self.group_exprs = Some(aggregate.group_expr.iter().cloned().collect());
            debug!(
                "FindGroupByFinalName: Get Group by exprs from Aggregate: {:?}",
                self.group_exprs
            );
        } else if let LogicalPlan::Distinct(distinct) = node {
            debug!("FindGroupByFinalName: Distinct: {}", node);
            match distinct {
                // DISTINCT over all columns behaves like GROUP BY every output column
                Distinct::All(input) => {
                    if let LogicalPlan::TableScan(table_scan) = &**input {
                        // get column from field_qualifier, projection and projected_schema:
                        let len = table_scan.projected_schema.fields().len();
                        let columns = (0..len)
                            .map(|f| {
                                let (qualifier, field) =
                                    table_scan.projected_schema.qualified_field(f);
                                datafusion_common::Column::new(qualifier.cloned(), field.name())
                            })
                            .map(datafusion_expr::Expr::Column);
                        self.group_exprs = Some(columns.collect());
                    } else {
                        // non-scan input: its output expressions are the grouping keys
                        self.group_exprs = Some(input.expressions().iter().cloned().collect())
                    }
                }
                // DISTINCT ON (...) lists its grouping keys explicitly
                Distinct::On(distinct_on) => {
                    self.group_exprs = Some(distinct_on.on_expr.iter().cloned().collect())
                }
            }
            debug!(
                "FindGroupByFinalName: Get Group by exprs from Distinct: {:?}",
                self.group_exprs
            );
        }

        Ok(TreeNodeRecursion::Continue)
    }

    /// deal with projection when going up with group exprs
    ///
    /// Bottom-up pass: when an outer `Projection` aliases a recorded group
    /// expression, replace it with the aliased form so the final names
    /// reported by `get_group_expr_names` match the plan's output columns.
    fn f_up(&mut self, node: &Self::Node) -> datafusion_common::Result<TreeNodeRecursion> {
        if let LogicalPlan::Projection(projection) = node {
            for expr in &projection.expr {
                // nothing recorded yet (no Aggregate/Distinct below) — nothing to alias
                let Some(group_exprs) = &mut self.group_exprs else {
                    return Ok(TreeNodeRecursion::Continue);
                };
                if let datafusion_expr::Expr::Alias(alias) = expr {
                    // if a alias exist, replace with the new alias
                    let mut new_group_exprs = group_exprs.clone();
                    for group_expr in group_exprs.iter() {
                        if group_expr.name_for_alias()? == alias.expr.name_for_alias()? {
                            new_group_exprs.remove(group_expr);
                            new_group_exprs.insert(expr.clone());
                            break;
                        }
                    }
                    *group_exprs = new_group_exprs;
                }
            }
        }
        debug!("Aliased group by exprs: {:?}", self.group_exprs);
        Ok(TreeNodeRecursion::Continue)
    }
}
269
270/// Optionally add to the final select columns like `update_at` if the sink table has such column
271/// (which doesn't necessary need to have exact name just need to be a extra timestamp column)
272/// and `__ts_placeholder`(this column need to have exact this name and be a timestamp)
273/// with values like `now()` and `0`
274///
275/// it also give existing columns alias to column in sink table if needed
#[derive(Debug)]
pub struct ColumnMatcherRewriter {
    /// Schema of the sink table the query output must be matched against.
    pub schema: SchemaRef,
    /// Set after the outermost projection has been rewritten; prevents the
    /// rewriter from touching any inner projection.
    pub is_rewritten: bool,
    /// Indices of the sink table's primary key columns within `schema`.
    pub primary_key_indices: Vec<usize>,
    /// When true, the query may emit only a subset of the sink columns;
    /// missing non-required columns are filled with NULL.
    pub allow_partial: bool,
}
283
impl ColumnMatcherRewriter {
    /// Create a rewriter targeting the given sink-table `schema`.
    ///
    /// `primary_key_indices` index into `schema`'s columns; `allow_partial`
    /// selects the partial (name-based) matching mode instead of the strict
    /// positional one.
    pub fn new(schema: SchemaRef, primary_key_indices: Vec<usize>, allow_partial: bool) -> Self {
        Self {
            schema,
            is_rewritten: false,
            primary_key_indices,
            allow_partial,
        }
    }

    /// modify the exprs in place so that it matches the schema and some auto columns are added
    ///
    /// Strict (non-partial) mode: query columns are matched to sink columns
    /// by position (aliasing when the name is not already a sink column
    /// name), then the table may have 0, 1 or 2 extra trailing columns,
    /// which are auto-filled (`now()` for an update-at style timestamp,
    /// a zero literal for the placeholder time index). Any other
    /// column-count mismatch is a plan error.
    fn modify_project_exprs(&mut self, mut exprs: Vec<Expr>) -> DfResult<Vec<Expr>> {
        if self.allow_partial {
            return self.modify_project_exprs_with_partial(exprs);
        }

        // all sink-table column names, for a quick "already matches" check
        let all_names = self
            .schema
            .column_schemas()
            .iter()
            .map(|c| c.name.clone())
            .collect::<BTreeSet<_>>();
        // first match by position
        for (idx, expr) in exprs.iter_mut().enumerate() {
            if !all_names.contains(&expr.qualified_name().1)
                && let Some(col_name) = self
                    .schema
                    .column_schemas()
                    .get(idx)
                    .map(|c| c.name.clone())
            {
                // if the data type mismatched, later check_execute will error out
                // hence no need to check it here, beside, optimize pass might be able to cast it
                // so checking here is not necessary
                *expr = expr.clone().alias(col_name);
            }
        }

        // add columns if have different column count
        let query_col_cnt = exprs.len();
        let table_col_cnt = self.schema.column_schemas().len();
        debug!("query_col_cnt={query_col_cnt}, table_col_cnt={table_col_cnt}");

        // literal `0` timestamp used to fill the auto-created time index column
        let placeholder_ts_expr =
            datafusion::logical_expr::lit(ScalarValue::TimestampMillisecond(Some(0), None))
                .alias(AUTO_CREATED_PLACEHOLDER_TS_COL);

        if query_col_cnt == table_col_cnt {
            // still need to add alias, see below
        } else if query_col_cnt + 1 == table_col_cnt {
            // exactly one auto column expected at the end of the table
            let last_col_schema = self.schema.column_schemas().last().unwrap();

            // if time index column is auto created add it
            if last_col_schema.name == AUTO_CREATED_PLACEHOLDER_TS_COL
                && self.schema.timestamp_index() == Some(table_col_cnt - 1)
            {
                exprs.push(placeholder_ts_expr);
            } else if last_col_schema.data_type.is_timestamp() {
                // is the update at column
                exprs.push(datafusion::prelude::now().alias(&last_col_schema.name));
            } else {
                // helpful error message
                return Err(DataFusionError::Plan(format!(
                    "Expect the last column in table to be timestamp column, found column {} with type {:?}",
                    last_col_schema.name, last_col_schema.data_type
                )));
            }
        } else if query_col_cnt + 2 == table_col_cnt {
            // two auto columns expected: update-at timestamp then placeholder time index
            let mut col_iter = self.schema.column_schemas().iter().rev();
            let last_col_schema = col_iter.next().unwrap();
            let second_last_col_schema = col_iter.next().unwrap();
            if second_last_col_schema.data_type.is_timestamp() {
                exprs.push(datafusion::prelude::now().alias(&second_last_col_schema.name));
            } else {
                return Err(DataFusionError::Plan(format!(
                    "Expect the second last column in the table to be timestamp column, found column {} with type {:?}",
                    second_last_col_schema.name, second_last_col_schema.data_type
                )));
            }

            if last_col_schema.name == AUTO_CREATED_PLACEHOLDER_TS_COL
                && self.schema.timestamp_index() == Some(table_col_cnt - 1)
            {
                exprs.push(placeholder_ts_expr);
            } else {
                return Err(DataFusionError::Plan(format!(
                    "Expect timestamp column {}, found {:?}",
                    AUTO_CREATED_PLACEHOLDER_TS_COL, last_col_schema
                )));
            }
        } else {
            return Err(DataFusionError::Plan(format!(
                "Expect table have 0,1 or 2 columns more than query columns, found {} query columns {:?}, {} table columns {:?}",
                query_col_cnt,
                exprs,
                table_col_cnt,
                self.schema.column_schemas()
            )));
        }
        Ok(exprs)
    }

    /// Partial mode (used with merge_mode=last_non_null): match query columns
    /// to sink columns by NAME, require all primary key / time index columns
    /// to be present, auto-fill the placeholder time index and `update_at`
    /// columns, and NULL-fill every other absent column. Query columns that
    /// do not exist in the sink schema are an error.
    fn modify_project_exprs_with_partial(&mut self, exprs: Vec<Expr>) -> DfResult<Vec<Expr>> {
        let table_col_cnt = self.schema.column_schemas().len();
        let query_col_cnt = exprs.len();

        if query_col_cnt > table_col_cnt {
            return Err(DataFusionError::Plan(format!(
                "Expect query column count <= table column count, found {} query columns {:?}, {} table columns {:?}",
                query_col_cnt,
                exprs,
                table_col_cnt,
                self.schema.column_schemas()
            )));
        }

        // index the query's output expressions by their final column name
        let name_to_expr: HashMap<String, Expr> = exprs
            .clone()
            .into_iter()
            .map(|e| (e.qualified_name().1, e))
            .collect();

        // primary key + (non-placeholder) time index columns must be provided
        let required_columns = self.required_columns_for_partial();
        let missing: Vec<_> = required_columns
            .iter()
            .filter(|name| !name_to_expr.contains_key(*name))
            .cloned()
            .collect();
        if !missing.is_empty() {
            return Err(DataFusionError::Plan(format!(
                "Column(s) {:?} required by sink table are missing from flow output when merge_mode=last_non_null",
                missing
            )));
        }

        let placeholder_ts_expr =
            datafusion::logical_expr::lit(ScalarValue::TimestampMillisecond(Some(0), None))
                .alias(AUTO_CREATED_PLACEHOLDER_TS_COL);

        let timestamp_index = self.schema.timestamp_index();
        // entries are consumed as they are matched; leftovers are extra columns
        let mut remap = name_to_expr;
        let mut new_exprs = Vec::with_capacity(table_col_cnt);

        // walk the sink schema in order, producing one expression per column
        for (idx, col_schema) in self.schema.column_schemas().iter().enumerate() {
            let col_name = col_schema.name.clone();
            if let Some(expr) = remap.remove(&col_name) {
                // supplied by the query; alias only if the name differs
                let expr = if expr.qualified_name().1 == col_name {
                    expr
                } else {
                    expr.alias(col_name.clone())
                };
                new_exprs.push(expr);
                continue;
            }

            // auto-created time index: fill with the zero placeholder
            if col_name == AUTO_CREATED_PLACEHOLDER_TS_COL && timestamp_index == Some(idx) {
                new_exprs.push(placeholder_ts_expr.clone());
                continue;
            }

            // auto-created update-at column: fill with now()
            if col_name == AUTO_CREATED_UPDATE_AT_TS_COL && col_schema.data_type.is_timestamp() {
                new_exprs.push(datafusion::prelude::now().alias(&col_name));
                continue;
            }

            // any other absent column becomes NULL
            new_exprs.push(Self::null_expr(col_schema));
        }

        // anything left in `remap` has no matching sink column
        if !remap.is_empty() {
            let extra: Vec<_> = remap.keys().cloned().collect();
            return Err(DataFusionError::Plan(format!(
                "Flow output has extra column(s) {:?} not found in sink schema when merge_mode=last_non_null",
                extra
            )));
        }

        Ok(new_exprs)
    }

    /// NULL literal aliased to the given sink column's name.
    fn null_expr(col_schema: &ColumnSchema) -> Expr {
        Expr::Literal(ScalarValue::Null, None).alias(col_schema.name.clone())
    }

    /// Column names the query MUST provide in partial mode: all primary key
    /// columns plus the time index column (unless the time index is the
    /// auto-created placeholder, which is filled automatically).
    fn required_columns_for_partial(&self) -> HashSet<String> {
        let mut required = HashSet::new();
        for idx in &self.primary_key_indices {
            if let Some(col) = self.schema.column_schemas().get(*idx) {
                required.insert(col.name.clone());
            }
        }

        if let Some(ts_idx) = self.schema.timestamp_index()
            && let Some(col) = self.schema.column_schemas().get(ts_idx)
            && col.name != AUTO_CREATED_PLACEHOLDER_TS_COL
        {
            required.insert(col.name.clone());
        }

        required
    }
}
485
impl TreeNodeRewriter for ColumnMatcherRewriter {
    type Node = LogicalPlan;

    /// Top-down: rewrite only the OUTERMOST projection (the first one met
    /// going down). `Distinct::All` and bare `TableScan` nodes are first
    /// wrapped in an identity projection so their expressions can be
    /// adjusted; any other non-projection root gets the same wrapping in
    /// the `else` branch below.
    fn f_down(&mut self, mut node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        if self.is_rewritten {
            return Ok(Transformed::no(node));
        }

        // if is distinct all, wrap it in a projection
        if let LogicalPlan::Distinct(Distinct::All(_)) = &node {
            let mut exprs = vec![];

            for field in node.schema().fields().iter() {
                exprs.push(Expr::Column(datafusion::common::Column::new_unqualified(
                    field.name(),
                )));
            }

            let projection =
                LogicalPlan::Projection(Projection::try_new(exprs, Arc::new(node.clone()))?);

            node = projection;
        }
        // handle table_scan by wrap it in a projection
        else if let LogicalPlan::TableScan(table_scan) = node {
            // qualify each column with the scanned table's name
            let mut exprs = vec![];

            for field in table_scan.projected_schema.fields().iter() {
                exprs.push(Expr::Column(datafusion::common::Column::new(
                    Some(table_scan.table_name.clone()),
                    field.name(),
                )));
            }

            let projection = LogicalPlan::Projection(Projection::try_new(
                exprs,
                Arc::new(LogicalPlan::TableScan(table_scan)),
            )?);

            node = projection;
        }

        // only do rewrite if found the outermost projection
        // if the outermost node is projection, can rewrite the exprs
        // if not, wrap it in a projection
        if let LogicalPlan::Projection(project) = &node {
            let exprs = project.expr.clone();
            let exprs = self.modify_project_exprs(exprs)?;

            self.is_rewritten = true;
            let new_plan =
                node.with_new_exprs(exprs, node.inputs().into_iter().cloned().collect())?;
            Ok(Transformed::yes(new_plan))
        } else {
            // wrap the logical plan in a projection
            let mut exprs = vec![];
            for field in node.schema().fields().iter() {
                exprs.push(Expr::Column(datafusion::common::Column::new_unqualified(
                    field.name(),
                )));
            }
            let exprs = self.modify_project_exprs(exprs)?;
            self.is_rewritten = true;
            let new_plan =
                LogicalPlan::Projection(Projection::try_new(exprs, Arc::new(node.clone()))?);
            Ok(Transformed::yes(new_plan))
        }
    }

    /// We might add new columns, so we need to recompute the schema
    fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
        node.recompute_schema().map(Transformed::yes)
    }
}
559
560/// Find out the `Filter` Node corresponding to innermost(deepest) `WHERE` and add a new filter expr to it
561#[derive(Debug)]
562pub struct AddFilterRewriter {
563    extra_filter: Expr,
564    is_rewritten: bool,
565}
566
567impl AddFilterRewriter {
568    pub fn new(filter: Expr) -> Self {
569        Self {
570            extra_filter: filter,
571            is_rewritten: false,
572        }
573    }
574}
575
576impl TreeNodeRewriter for AddFilterRewriter {
577    type Node = LogicalPlan;
578    fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
579        if self.is_rewritten {
580            return Ok(Transformed::no(node));
581        }
582        match node {
583            LogicalPlan::Filter(mut filter) => {
584                filter.predicate = filter.predicate.and(self.extra_filter.clone());
585                self.is_rewritten = true;
586                Ok(Transformed::yes(LogicalPlan::Filter(filter)))
587            }
588            LogicalPlan::TableScan(_) => {
589                // add a new filter
590                let filter =
591                    datafusion_expr::Filter::try_new(self.extra_filter.clone(), Arc::new(node))?;
592                self.is_rewritten = true;
593                Ok(Transformed::yes(LogicalPlan::Filter(filter)))
594            }
595            _ => Ok(Transformed::no(node)),
596        }
597    }
598}
599
600#[cfg(test)]
601mod test {
602    use std::sync::Arc;
603
604    use datafusion_common::tree_node::TreeNode as _;
605    use datatypes::prelude::ConcreteDataType;
606    use datatypes::schema::{ColumnSchema, Schema};
607    use pretty_assertions::assert_eq;
608    use query::query_engine::DefaultSerializer;
609    use session::context::QueryContext;
610    use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
611
612    use super::*;
613    use crate::test_utils::create_test_query_engine;
614
615    /// test if uppercase are handled correctly(with quote)
616    #[tokio::test]
617    async fn test_sql_plan_convert() {
618        let query_engine = create_test_query_engine();
619        let ctx = QueryContext::arc();
620        let old = r#"SELECT "NUMBER" FROM "UPPERCASE_NUMBERS_WITH_TS""#;
621        let new = sql_to_df_plan(ctx.clone(), query_engine.clone(), old, false)
622            .await
623            .unwrap();
624        let new_sql = df_plan_to_sql(&new).unwrap();
625
626        assert_eq!(
627            r#"SELECT `UPPERCASE_NUMBERS_WITH_TS`.`NUMBER` FROM `UPPERCASE_NUMBERS_WITH_TS`"#,
628            new_sql
629        );
630    }
631
632    #[tokio::test]
633    async fn test_add_filter() {
634        let testcases = vec![
635            (
636                "SELECT number FROM numbers_with_ts GROUP BY number",
637                "SELECT numbers_with_ts.number FROM numbers_with_ts WHERE (number > 4) GROUP BY numbers_with_ts.number",
638            ),
639            (
640                "SELECT number FROM numbers_with_ts WHERE number < 2 OR number >10",
641                "SELECT numbers_with_ts.number FROM numbers_with_ts WHERE ((numbers_with_ts.number < 2) OR (numbers_with_ts.number > 10)) AND (number > 4)",
642            ),
643            (
644                "SELECT date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window",
645                "SELECT date_bin('5 minutes', numbers_with_ts.ts) AS time_window FROM numbers_with_ts WHERE (number > 4) GROUP BY date_bin('5 minutes', numbers_with_ts.ts)",
646            ),
647            // subquery
648            (
649                "SELECT number, time_window FROM (SELECT number, date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window, number);",
650                "SELECT numbers_with_ts.number, time_window FROM (SELECT numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts) AS time_window FROM numbers_with_ts WHERE (number > 4) GROUP BY date_bin('5 minutes', numbers_with_ts.ts), numbers_with_ts.number)",
651            ),
652            // complex subquery without alias
653            (
654                "SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) GROUP BY number, time_window, bucket_name;",
655                "SELECT sum(numbers_with_ts.number), numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts) AS time_window, bucket_name FROM (SELECT numbers_with_ts.number, numbers_with_ts.ts, CASE WHEN (numbers_with_ts.number < 5) THEN 'bucket_0_5' WHEN (numbers_with_ts.number >= 5) THEN 'bucket_5_inf' END AS bucket_name FROM numbers_with_ts WHERE (number > 4)) GROUP BY numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts), bucket_name",
656            ),
657            // complex subquery alias
658            (
659                "SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) as cte WHERE number > 1 GROUP BY number, time_window, bucket_name;",
660                "SELECT sum(cte.number), cte.number, date_bin('5 minutes', cte.ts) AS time_window, cte.bucket_name FROM (SELECT numbers_with_ts.number, numbers_with_ts.ts, CASE WHEN (numbers_with_ts.number < 5) THEN 'bucket_0_5' WHEN (numbers_with_ts.number >= 5) THEN 'bucket_5_inf' END AS bucket_name FROM numbers_with_ts WHERE (number > 4)) AS cte WHERE (cte.number > 1) GROUP BY cte.number, date_bin('5 minutes', cte.ts), cte.bucket_name",
661            ),
662        ];
663        use datafusion_expr::{col, lit};
664        let query_engine = create_test_query_engine();
665        let ctx = QueryContext::arc();
666
667        for (before, after) in testcases {
668            let sql = before;
669            let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, false)
670                .await
671                .unwrap();
672
673            let mut add_filter = AddFilterRewriter::new(col("number").gt(lit(4u32)));
674            let plan = plan.rewrite(&mut add_filter).unwrap().data;
675            let new_sql = df_plan_to_sql(&plan).unwrap();
676            assert_eq!(after, new_sql);
677        }
678    }
679
680    #[tokio::test]
681    async fn test_add_auto_column_rewriter() {
682        let testcases = vec![
683            // add update_at
684            (
685                "SELECT number FROM numbers_with_ts",
686                Ok("SELECT numbers_with_ts.number, now() AS ts FROM numbers_with_ts"),
687                vec![
688                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
689                    ColumnSchema::new(
690                        "ts",
691                        ConcreteDataType::timestamp_millisecond_datatype(),
692                        false,
693                    )
694                    .with_time_index(true),
695                ],
696            ),
697            // add ts placeholder
698            (
699                "SELECT number FROM numbers_with_ts",
700                Ok(
701                    "SELECT numbers_with_ts.number, CAST('1970-01-01 00:00:00' AS TIMESTAMP) AS __ts_placeholder FROM numbers_with_ts",
702                ),
703                vec![
704                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
705                    ColumnSchema::new(
706                        AUTO_CREATED_PLACEHOLDER_TS_COL,
707                        ConcreteDataType::timestamp_millisecond_datatype(),
708                        false,
709                    )
710                    .with_time_index(true),
711                ],
712            ),
713            // no modify
714            (
715                "SELECT number, ts FROM numbers_with_ts",
716                Ok("SELECT numbers_with_ts.number, numbers_with_ts.ts FROM numbers_with_ts"),
717                vec![
718                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
719                    ColumnSchema::new(
720                        "ts",
721                        ConcreteDataType::timestamp_millisecond_datatype(),
722                        false,
723                    )
724                    .with_time_index(true),
725                ],
726            ),
727            // add update_at and ts placeholder
728            (
729                "SELECT number FROM numbers_with_ts",
730                Ok(
731                    "SELECT numbers_with_ts.number, now() AS update_at, CAST('1970-01-01 00:00:00' AS TIMESTAMP) AS __ts_placeholder FROM numbers_with_ts",
732                ),
733                vec![
734                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
735                    ColumnSchema::new(
736                        "update_at",
737                        ConcreteDataType::timestamp_millisecond_datatype(),
738                        false,
739                    ),
740                    ColumnSchema::new(
741                        AUTO_CREATED_PLACEHOLDER_TS_COL,
742                        ConcreteDataType::timestamp_millisecond_datatype(),
743                        false,
744                    )
745                    .with_time_index(true),
746                ],
747            ),
748            // add ts placeholder
749            (
750                "SELECT number, ts FROM numbers_with_ts",
751                Ok(
752                    "SELECT numbers_with_ts.number, numbers_with_ts.ts AS update_at, CAST('1970-01-01 00:00:00' AS TIMESTAMP) AS __ts_placeholder FROM numbers_with_ts",
753                ),
754                vec![
755                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
756                    ColumnSchema::new(
757                        "update_at",
758                        ConcreteDataType::timestamp_millisecond_datatype(),
759                        false,
760                    ),
761                    ColumnSchema::new(
762                        AUTO_CREATED_PLACEHOLDER_TS_COL,
763                        ConcreteDataType::timestamp_millisecond_datatype(),
764                        false,
765                    )
766                    .with_time_index(true),
767                ],
768            ),
769            // add update_at after time index column
770            (
771                "SELECT number, ts FROM numbers_with_ts",
772                Ok(
773                    "SELECT numbers_with_ts.number, numbers_with_ts.ts, now() AS update_atat FROM numbers_with_ts",
774                ),
775                vec![
776                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
777                    ColumnSchema::new(
778                        "ts",
779                        ConcreteDataType::timestamp_millisecond_datatype(),
780                        false,
781                    )
782                    .with_time_index(true),
783                    ColumnSchema::new(
784                        // name is irrelevant for update_at column
785                        "update_atat",
786                        ConcreteDataType::timestamp_millisecond_datatype(),
787                        false,
788                    ),
789                ],
790            ),
791            // error datatype mismatch
792            (
793                "SELECT number, ts FROM numbers_with_ts",
794                Err(
795                    "Expect the last column in table to be timestamp column, found column atat with type Int8",
796                ),
797                vec![
798                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
799                    ColumnSchema::new(
800                        "ts",
801                        ConcreteDataType::timestamp_millisecond_datatype(),
802                        false,
803                    )
804                    .with_time_index(true),
805                    ColumnSchema::new(
806                        // name is irrelevant for update_at column
807                        "atat",
808                        ConcreteDataType::int8_datatype(),
809                        false,
810                    ),
811                ],
812            ),
813            // error datatype mismatch on second last column
814            (
815                "SELECT number FROM numbers_with_ts",
816                Err(
817                    "Expect the second last column in the table to be timestamp column, found column ts with type Int8",
818                ),
819                vec![
820                    ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true),
821                    ColumnSchema::new("ts", ConcreteDataType::int8_datatype(), false),
822                    ColumnSchema::new(
823                        // name is irrelevant for update_at column
824                        "atat",
825                        ConcreteDataType::timestamp_millisecond_datatype(),
826                        false,
827                    )
828                    .with_time_index(true),
829                ],
830            ),
831        ];
832
833        let query_engine = create_test_query_engine();
834        let ctx = QueryContext::arc();
835        for (before, after, column_schemas) in testcases {
836            let schema = Arc::new(Schema::new(column_schemas));
837            let mut add_auto_column_rewriter =
838                ColumnMatcherRewriter::new(schema, Vec::new(), false);
839
840            let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), before, false)
841                .await
842                .unwrap();
843            let new_sql = (|| {
844                let plan = plan
845                    .rewrite(&mut add_auto_column_rewriter)
846                    .map_err(|e| e.to_string())?
847                    .data;
848                df_plan_to_sql(&plan).map_err(|e| e.to_string())
849            })();
850            match (after, new_sql.clone()) {
851                (Ok(after), Ok(new_sql)) => assert_eq!(after, new_sql),
852                (Err(expected), Err(real_err_msg)) => assert!(
853                    real_err_msg.contains(expected),
854                    "expected: {expected}, real: {real_err_msg}"
855                ),
856                _ => panic!("expected: {:?}, real: {:?}", after, new_sql),
857            }
858        }
859    }
860
861    #[tokio::test]
862    async fn test_find_group_by_exprs() {
863        let testcases = vec![
864            (
865                "SELECT arrow_cast(date_bin(INTERVAL '1 MINS', numbers_with_ts.ts), 'Timestamp(Second, None)') AS ts FROM numbers_with_ts GROUP BY ts;",
866                vec!["ts"],
867            ),
868            (
869                "SELECT number FROM numbers_with_ts GROUP BY number",
870                vec!["number"],
871            ),
872            (
873                "SELECT date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window",
874                vec!["time_window"],
875            ),
876            // subquery
877            (
878                "SELECT number, time_window FROM (SELECT number, date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window, number);",
879                vec!["time_window", "number"],
880            ),
881            // complex subquery without alias
882            (
883                "SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) GROUP BY number, time_window, bucket_name;",
884                vec!["number", "time_window", "bucket_name"],
885            ),
886            // complex subquery alias
887            (
888                "SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) as cte GROUP BY number, time_window, bucket_name;",
889                vec!["number", "time_window", "bucket_name"],
890            ),
891        ];
892
893        let query_engine = create_test_query_engine();
894        let ctx = QueryContext::arc();
895        for (sql, expected) in testcases {
896            // need to be unoptimize for better readiability
897            let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, false)
898                .await
899                .unwrap();
900            let mut group_by_exprs = FindGroupByFinalName::default();
901            plan.visit(&mut group_by_exprs).unwrap();
902            let expected: HashSet<String> = expected.into_iter().map(|s| s.to_string()).collect();
903            assert_eq!(
904                expected,
905                group_by_exprs.get_group_expr_names().unwrap_or_default()
906            );
907        }
908    }
909
910    #[tokio::test]
911    async fn test_null_cast() {
912        let query_engine = create_test_query_engine();
913        let ctx = QueryContext::arc();
914        let sql = "SELECT NULL::DOUBLE FROM numbers_with_ts";
915        let plan = sql_to_df_plan(ctx, query_engine.clone(), sql, false)
916            .await
917            .unwrap();
918
919        let _sub_plan = DFLogicalSubstraitConvertor {}
920            .encode(&plan, DefaultSerializer)
921            .unwrap();
922    }
923}