flow/batching_mode/
task.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeSet, HashSet};
16use std::sync::{Arc, RwLock};
17use std::time::{Duration, SystemTime, UNIX_EPOCH};
18
19use api::v1::CreateTableExpr;
20use arrow_schema::Fields;
21use catalog::CatalogManagerRef;
22use common_error::ext::BoxedError;
23use common_query::logical_plan::breakup_insert_plan;
24use common_telemetry::tracing::warn;
25use common_telemetry::{debug, info};
26use common_time::Timestamp;
27use datafusion::optimizer::analyzer::count_wildcard_rule::CountWildcardRule;
28use datafusion::optimizer::AnalyzerRule;
29use datafusion::sql::unparser::expr_to_sql;
30use datafusion_common::tree_node::{Transformed, TreeNode};
31use datafusion_common::DFSchemaRef;
32use datafusion_expr::{DmlStatement, LogicalPlan, WriteOp};
33use datatypes::prelude::ConcreteDataType;
34use datatypes::schema::{ColumnSchema, Schema};
35use operator::expr_helper::column_schemas_to_defs;
36use query::query_engine::DefaultSerializer;
37use query::QueryEngineRef;
38use session::context::QueryContextRef;
39use snafu::{ensure, OptionExt, ResultExt};
40use sql::parser::{ParseOptions, ParserContext};
41use sql::statements::statement::Statement;
42use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
43use tokio::sync::oneshot;
44use tokio::sync::oneshot::error::TryRecvError;
45use tokio::time::Instant;
46
47use crate::adapter::{AUTO_CREATED_PLACEHOLDER_TS_COL, AUTO_CREATED_UPDATE_AT_TS_COL};
48use crate::batching_mode::frontend_client::FrontendClient;
49use crate::batching_mode::state::TaskState;
50use crate::batching_mode::time_window::TimeWindowExpr;
51use crate::batching_mode::utils::{
52    get_table_info_df_schema, sql_to_df_plan, AddAutoColumnRewriter, AddFilterRewriter,
53    FindGroupByFinalName,
54};
55use crate::batching_mode::BatchingModeOptions;
56use crate::df_optimizer::apply_df_optimizer;
57use crate::error::{
58    ConvertColumnSchemaSnafu, DatafusionSnafu, ExternalSnafu, InvalidQuerySnafu,
59    SubstraitEncodeLogicalPlanSnafu, UnexpectedSnafu,
60};
61use crate::metrics::{
62    METRIC_FLOW_BATCHING_ENGINE_ERROR_CNT, METRIC_FLOW_BATCHING_ENGINE_QUERY_TIME,
63    METRIC_FLOW_BATCHING_ENGINE_SLOW_QUERY, METRIC_FLOW_BATCHING_ENGINE_START_QUERY_CNT,
64    METRIC_FLOW_ROWS,
65};
66use crate::{Error, FlowId};
67
68/// The task's config, immutable once created
69#[derive(Clone)]
70pub struct TaskConfig {
71    pub flow_id: FlowId,
72    pub query: String,
73    /// output schema of the query
74    pub output_schema: DFSchemaRef,
75    pub time_window_expr: Option<TimeWindowExpr>,
76    /// in seconds
77    pub expire_after: Option<i64>,
78    sink_table_name: [String; 3],
79    pub source_table_names: HashSet<[String; 3]>,
80    catalog_manager: CatalogManagerRef,
81    query_type: QueryType,
82    batch_opts: Arc<BatchingModeOptions>,
83}
84
85fn determine_query_type(query: &str, query_ctx: &QueryContextRef) -> Result<QueryType, Error> {
86    let stmts =
87        ParserContext::create_with_dialect(query, query_ctx.sql_dialect(), ParseOptions::default())
88            .map_err(BoxedError::new)
89            .context(ExternalSnafu)?;
90
91    ensure!(
92        stmts.len() == 1,
93        InvalidQuerySnafu {
94            reason: format!("Expect only one statement, found {}", stmts.len())
95        }
96    );
97    let stmt = &stmts[0];
98    match stmt {
99        Statement::Tql(_) => Ok(QueryType::Tql),
100        _ => Ok(QueryType::Sql),
101    }
102}
103
/// Kind of statement a flow's query consists of.
///
/// A fieldless enum, so it is cheap to copy and can be compared directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum QueryType {
    /// query is a tql query
    Tql,
    /// query is a sql query
    Sql,
}
111
112#[derive(Clone)]
113pub struct BatchingTask {
114    pub config: Arc<TaskConfig>,
115    pub state: Arc<RwLock<TaskState>>,
116}
117
118/// Arguments for creating batching task
119pub struct TaskArgs<'a> {
120    pub flow_id: FlowId,
121    pub query: &'a str,
122    pub plan: LogicalPlan,
123    pub time_window_expr: Option<TimeWindowExpr>,
124    pub expire_after: Option<i64>,
125    pub sink_table_name: [String; 3],
126    pub source_table_names: Vec<[String; 3]>,
127    pub query_ctx: QueryContextRef,
128    pub catalog_manager: CatalogManagerRef,
129    pub shutdown_rx: oneshot::Receiver<()>,
130    pub batch_opts: Arc<BatchingModeOptions>,
131}
132
133impl BatchingTask {
134    #[allow(clippy::too_many_arguments)]
135    pub fn try_new(
136        TaskArgs {
137            flow_id,
138            query,
139            plan,
140            time_window_expr,
141            expire_after,
142            sink_table_name,
143            source_table_names,
144            query_ctx,
145            catalog_manager,
146            shutdown_rx,
147            batch_opts,
148        }: TaskArgs<'_>,
149    ) -> Result<Self, Error> {
150        Ok(Self {
151            config: Arc::new(TaskConfig {
152                flow_id,
153                query: query.to_string(),
154                time_window_expr,
155                expire_after,
156                sink_table_name,
157                source_table_names: source_table_names.into_iter().collect(),
158                catalog_manager,
159                output_schema: plan.schema().clone(),
160                query_type: determine_query_type(query, &query_ctx)?,
161                batch_opts,
162            }),
163            state: Arc::new(RwLock::new(TaskState::new(query_ctx, shutdown_rx))),
164        })
165    }
166
167    /// mark time window range (now - expire_after, now) as dirty (or (0, now) if expire_after not set)
168    ///
169    /// useful for flush_flow to flush dirty time windows range
170    pub fn mark_all_windows_as_dirty(&self) -> Result<(), Error> {
171        let now = SystemTime::now();
172        let now = Timestamp::new_second(
173            now.duration_since(UNIX_EPOCH)
174                .expect("Time went backwards")
175                .as_secs() as _,
176        );
177        let lower_bound = self
178            .config
179            .expire_after
180            .map(|e| now.sub_duration(Duration::from_secs(e as _)))
181            .transpose()
182            .map_err(BoxedError::new)
183            .context(ExternalSnafu)?
184            .unwrap_or(Timestamp::new_second(0));
185        debug!(
186            "Flow {} mark range ({:?}, {:?}) as dirty",
187            self.config.flow_id, lower_bound, now
188        );
189        self.state
190            .write()
191            .unwrap()
192            .dirty_time_windows
193            .add_window(lower_bound, Some(now));
194        Ok(())
195    }
196
197    /// Create sink table if not exists
198    pub async fn check_or_create_sink_table(
199        &self,
200        engine: &QueryEngineRef,
201        frontend_client: &Arc<FrontendClient>,
202    ) -> Result<Option<(u32, Duration)>, Error> {
203        if !self.is_table_exist(&self.config.sink_table_name).await? {
204            let create_table = self.gen_create_table_expr(engine.clone()).await?;
205            info!(
206                "Try creating sink table(if not exists) with expr: {:?}",
207                create_table
208            );
209            self.create_table(frontend_client, create_table).await?;
210            info!(
211                "Sink table {}(if not exists) created",
212                self.config.sink_table_name.join(".")
213            );
214        }
215
216        Ok(None)
217    }
218
219    async fn is_table_exist(&self, table_name: &[String; 3]) -> Result<bool, Error> {
220        self.config
221            .catalog_manager
222            .table_exists(&table_name[0], &table_name[1], &table_name[2], None)
223            .await
224            .map_err(BoxedError::new)
225            .context(ExternalSnafu)
226    }
227
228    pub async fn gen_exec_once(
229        &self,
230        engine: &QueryEngineRef,
231        frontend_client: &Arc<FrontendClient>,
232        max_window_cnt: Option<usize>,
233    ) -> Result<Option<(u32, Duration)>, Error> {
234        if let Some(new_query) = self.gen_insert_plan(engine, max_window_cnt).await? {
235            debug!("Generate new query: {}", new_query);
236            self.execute_logical_plan(frontend_client, &new_query).await
237        } else {
238            debug!("Generate no query");
239            Ok(None)
240        }
241    }
242
243    pub async fn gen_insert_plan(
244        &self,
245        engine: &QueryEngineRef,
246        max_window_cnt: Option<usize>,
247    ) -> Result<Option<LogicalPlan>, Error> {
248        let (table, df_schema) = get_table_info_df_schema(
249            self.config.catalog_manager.clone(),
250            self.config.sink_table_name.clone(),
251        )
252        .await?;
253
254        let new_query = self
255            .gen_query_with_time_window(engine.clone(), &table.meta.schema, max_window_cnt)
256            .await?;
257
258        let insert_into = if let Some((new_query, _column_cnt)) = new_query {
259            // first check if all columns in input query exists in sink table
260            // since insert into ref to names in record batch generate by given query
261            let table_columns = df_schema
262                .columns()
263                .into_iter()
264                .map(|c| c.name)
265                .collect::<BTreeSet<_>>();
266            for column in new_query.schema().columns() {
267                ensure!(
268                    table_columns.contains(column.name()),
269                    InvalidQuerySnafu {
270                        reason: format!(
271                            "Column {} not found in sink table with columns {:?}",
272                            column, table_columns
273                        ),
274                    }
275                );
276            }
277            // update_at& time index placeholder (if exists) should have default value
278            LogicalPlan::Dml(DmlStatement::new(
279                datafusion_common::TableReference::Full {
280                    catalog: self.config.sink_table_name[0].clone().into(),
281                    schema: self.config.sink_table_name[1].clone().into(),
282                    table: self.config.sink_table_name[2].clone().into(),
283                },
284                df_schema,
285                WriteOp::Insert(datafusion_expr::dml::InsertOp::Append),
286                Arc::new(new_query),
287            ))
288        } else {
289            return Ok(None);
290        };
291        let insert_into = insert_into.recompute_schema().context(DatafusionSnafu {
292            context: "Failed to recompute schema",
293        })?;
294        Ok(Some(insert_into))
295    }
296
297    pub async fn create_table(
298        &self,
299        frontend_client: &Arc<FrontendClient>,
300        expr: CreateTableExpr,
301    ) -> Result<(), Error> {
302        let catalog = &self.config.sink_table_name[0];
303        let schema = &self.config.sink_table_name[1];
304        frontend_client
305            .create(expr.clone(), catalog, schema)
306            .await?;
307        Ok(())
308    }
309
310    pub async fn execute_logical_plan(
311        &self,
312        frontend_client: &Arc<FrontendClient>,
313        plan: &LogicalPlan,
314    ) -> Result<Option<(u32, Duration)>, Error> {
315        let instant = Instant::now();
316        let flow_id = self.config.flow_id;
317
318        debug!(
319            "Executing flow {flow_id}(expire_after={:?} secs) with query {}",
320            self.config.expire_after, &plan
321        );
322
323        let catalog = &self.config.sink_table_name[0];
324        let schema = &self.config.sink_table_name[1];
325
326        // fix all table ref by make it fully qualified, i.e. "table_name" => "catalog_name.schema_name.table_name"
327        let fixed_plan = plan
328            .clone()
329            .transform_down_with_subqueries(|p| {
330                if let LogicalPlan::TableScan(mut table_scan) = p {
331                    let resolved = table_scan.table_name.resolve(catalog, schema);
332                    table_scan.table_name = resolved.into();
333                    Ok(Transformed::yes(LogicalPlan::TableScan(table_scan)))
334                } else {
335                    Ok(Transformed::no(p))
336                }
337            })
338            .with_context(|_| DatafusionSnafu {
339                context: format!("Failed to fix table ref in logical plan, plan={:?}", plan),
340            })?
341            .data;
342
343        let expanded_plan = CountWildcardRule::new()
344            .analyze(fixed_plan.clone(), &Default::default())
345            .with_context(|_| DatafusionSnafu {
346                context: format!(
347                    "Failed to expand wildcard in logical plan, plan={:?}",
348                    fixed_plan
349                ),
350            })?;
351
352        let plan = expanded_plan;
353        let mut peer_desc = None;
354
355        let res = {
356            let _timer = METRIC_FLOW_BATCHING_ENGINE_QUERY_TIME
357                .with_label_values(&[flow_id.to_string().as_str()])
358                .start_timer();
359
360            // hack and special handling the insert logical plan
361            let req = if let Some((insert_to, insert_plan)) =
362                breakup_insert_plan(&plan, catalog, schema)
363            {
364                let message = DFLogicalSubstraitConvertor {}
365                    .encode(&insert_plan, DefaultSerializer)
366                    .context(SubstraitEncodeLogicalPlanSnafu)?;
367                api::v1::greptime_request::Request::Query(api::v1::QueryRequest {
368                    query: Some(api::v1::query_request::Query::InsertIntoPlan(
369                        api::v1::InsertIntoPlan {
370                            table_name: Some(insert_to),
371                            logical_plan: message.to_vec(),
372                        },
373                    )),
374                })
375            } else {
376                let message = DFLogicalSubstraitConvertor {}
377                    .encode(&plan, DefaultSerializer)
378                    .context(SubstraitEncodeLogicalPlanSnafu)?;
379
380                api::v1::greptime_request::Request::Query(api::v1::QueryRequest {
381                    query: Some(api::v1::query_request::Query::LogicalPlan(message.to_vec())),
382                })
383            };
384
385            frontend_client
386                .handle(req, catalog, schema, &mut peer_desc)
387                .await
388        };
389
390        let elapsed = instant.elapsed();
391        if let Ok(affected_rows) = &res {
392            debug!(
393                "Flow {flow_id} executed, affected_rows: {affected_rows:?}, elapsed: {:?}",
394                elapsed
395            );
396            METRIC_FLOW_ROWS
397                .with_label_values(&[format!("{}-out-batching", flow_id).as_str()])
398                .inc_by(*affected_rows as _);
399        } else if let Err(err) = &res {
400            warn!(
401                "Failed to execute Flow {flow_id} on frontend {:?}, result: {err:?}, elapsed: {:?} with query: {}",
402                peer_desc, elapsed, &plan
403            );
404        }
405
406        // record slow query
407        if elapsed >= self.config.batch_opts.slow_query_threshold {
408            warn!(
409                "Flow {flow_id} on frontend {:?} executed for {:?} before complete, query: {}",
410                peer_desc, elapsed, &plan
411            );
412            METRIC_FLOW_BATCHING_ENGINE_SLOW_QUERY
413                .with_label_values(&[
414                    flow_id.to_string().as_str(),
415                    &peer_desc.unwrap_or_default().to_string(),
416                ])
417                .observe(elapsed.as_secs_f64());
418        }
419
420        self.state
421            .write()
422            .unwrap()
423            .after_query_exec(elapsed, res.is_ok());
424
425        let res = res?;
426
427        Ok(Some((res, elapsed)))
428    }
429
430    /// start executing query in a loop, break when receive shutdown signal
431    ///
432    /// any error will be logged when executing query
433    pub async fn start_executing_loop(
434        &self,
435        engine: QueryEngineRef,
436        frontend_client: Arc<FrontendClient>,
437    ) {
438        let flow_id_str = self.config.flow_id.to_string();
439        loop {
440            // first check if shutdown signal is received
441            // if so, break the loop
442            {
443                let mut state = self.state.write().unwrap();
444                match state.shutdown_rx.try_recv() {
445                    Ok(()) => break,
446                    Err(TryRecvError::Closed) => {
447                        warn!(
448                            "Unexpected shutdown flow {}, shutdown anyway",
449                            self.config.flow_id
450                        );
451                        break;
452                    }
453                    Err(TryRecvError::Empty) => (),
454                }
455            }
456            METRIC_FLOW_BATCHING_ENGINE_START_QUERY_CNT
457                .with_label_values(&[&flow_id_str])
458                .inc();
459
460            let min_refresh = self.config.batch_opts.experimental_min_refresh_duration;
461
462            let new_query = match self.gen_insert_plan(&engine, None).await {
463                Ok(new_query) => new_query,
464                Err(err) => {
465                    common_telemetry::error!(err; "Failed to generate query for flow={}", self.config.flow_id);
466                    // also sleep for a little while before try again to prevent flooding logs
467                    tokio::time::sleep(min_refresh).await;
468                    continue;
469                }
470            };
471
472            let res = if let Some(new_query) = &new_query {
473                self.execute_logical_plan(&frontend_client, new_query).await
474            } else {
475                Ok(None)
476            };
477
478            match res {
479                // normal execute, sleep for some time before doing next query
480                Ok(Some(_)) => {
481                    let sleep_until = {
482                        let state = self.state.write().unwrap();
483
484                        let time_window_size = self
485                            .config
486                            .time_window_expr
487                            .as_ref()
488                            .and_then(|t| *t.time_window_size());
489
490                        state.get_next_start_query_time(
491                            self.config.flow_id,
492                            &time_window_size,
493                            min_refresh,
494                            Some(self.config.batch_opts.query_timeout),
495                            self.config.batch_opts.experimental_max_filter_num_per_query,
496                        )
497                    };
498                    tokio::time::sleep_until(sleep_until).await;
499                }
500                // no new data, sleep for some time before checking for new data
501                Ok(None) => {
502                    debug!(
503                        "Flow id = {:?} found no new data, sleep for {:?} then continue",
504                        self.config.flow_id, min_refresh
505                    );
506                    tokio::time::sleep(min_refresh).await;
507                    continue;
508                }
509                // TODO(discord9): this error should have better place to go, but for now just print error, also more context is needed
510                Err(err) => {
511                    METRIC_FLOW_BATCHING_ENGINE_ERROR_CNT
512                        .with_label_values(&[&flow_id_str])
513                        .inc();
514                    match new_query {
515                        Some(query) => {
516                            common_telemetry::error!(err; "Failed to execute query for flow={} with query: {query}", self.config.flow_id)
517                        }
518                        None => {
519                            common_telemetry::error!(err; "Failed to generate query for flow={}", self.config.flow_id)
520                        }
521                    }
522                    // also sleep for a little while before try again to prevent flooding logs
523                    tokio::time::sleep(min_refresh).await;
524                }
525            }
526        }
527    }
528
529    /// Generate the create table SQL
530    ///
531    /// the auto created table will automatically added a `update_at` Milliseconds DEFAULT now() column in the end
532    /// (for compatibility with flow streaming mode)
533    ///
534    /// and it will use first timestamp column as time index, all other columns will be added as normal columns and nullable
535    async fn gen_create_table_expr(
536        &self,
537        engine: QueryEngineRef,
538    ) -> Result<CreateTableExpr, Error> {
539        let query_ctx = self.state.read().unwrap().query_ctx.clone();
540        let plan =
541            sql_to_df_plan(query_ctx.clone(), engine.clone(), &self.config.query, true).await?;
542        create_table_with_expr(&plan, &self.config.sink_table_name)
543    }
544
545    /// will merge and use the first ten time window in query
546    async fn gen_query_with_time_window(
547        &self,
548        engine: QueryEngineRef,
549        sink_table_schema: &Arc<Schema>,
550        max_window_cnt: Option<usize>,
551    ) -> Result<Option<(LogicalPlan, usize)>, Error> {
552        let query_ctx = self.state.read().unwrap().query_ctx.clone();
553        let start = SystemTime::now();
554        let since_the_epoch = start
555            .duration_since(UNIX_EPOCH)
556            .expect("Time went backwards");
557        let low_bound = self
558            .config
559            .expire_after
560            .map(|e| since_the_epoch.as_secs() - e as u64)
561            .unwrap_or(u64::MIN);
562
563        let low_bound = Timestamp::new_second(low_bound as i64);
564        let schema_len = self.config.output_schema.fields().len();
565
566        let expire_time_window_bound = self
567            .config
568            .time_window_expr
569            .as_ref()
570            .map(|expr| expr.eval(low_bound))
571            .transpose()?;
572
573        let (Some((Some(l), Some(u))), QueryType::Sql) =
574            (expire_time_window_bound, &self.config.query_type)
575        else {
576            // either no time window or not a sql query, then just use the original query
577            // use sink_table_meta to add to query the `update_at` and `__ts_placeholder` column's value too for compatibility reason
578            debug!(
579                "Flow id = {:?}, can't get window size: precise_lower_bound={expire_time_window_bound:?}, using the same query", self.config.flow_id
580            );
581            // clean dirty time window too, this could be from create flow's check_execute
582            self.state.write().unwrap().dirty_time_windows.clean();
583
584            // TODO(discord9): not add auto column for tql query?
585            let mut add_auto_column = AddAutoColumnRewriter::new(sink_table_schema.clone());
586
587            let plan = sql_to_df_plan(query_ctx.clone(), engine.clone(), &self.config.query, false)
588                .await?;
589
590            let plan = plan
591                .clone()
592                .rewrite(&mut add_auto_column)
593                .with_context(|_| DatafusionSnafu {
594                    context: format!("Failed to rewrite plan:\n {}\n", plan),
595                })?
596                .data;
597            let schema_len = plan.schema().fields().len();
598
599            // since no time window lower/upper bound is found, just return the original query(with auto columns)
600            return Ok(Some((plan, schema_len)));
601        };
602
603        debug!(
604            "Flow id = {:?}, found time window: precise_lower_bound={:?}, precise_upper_bound={:?} with dirty time windows: {:?}",
605            self.config.flow_id, l, u, self.state.read().unwrap().dirty_time_windows
606        );
607        let window_size = u.sub(&l).with_context(|| UnexpectedSnafu {
608            reason: format!("Can't get window size from {u:?} - {l:?}"),
609        })?;
610        let col_name = self
611            .config
612            .time_window_expr
613            .as_ref()
614            .map(|expr| expr.column_name.clone())
615            .with_context(|| UnexpectedSnafu {
616                reason: format!(
617                    "Flow id={:?}, Failed to get column name from time window expr",
618                    self.config.flow_id
619                ),
620            })?;
621
622        let expr = self
623            .state
624            .write()
625            .unwrap()
626            .dirty_time_windows
627            .gen_filter_exprs(
628                &col_name,
629                Some(l),
630                window_size,
631                max_window_cnt
632                    .unwrap_or(self.config.batch_opts.experimental_max_filter_num_per_query),
633                self.config.flow_id,
634                Some(self),
635            )?;
636
637        debug!(
638            "Flow id={:?}, Generated filter expr: {:?}",
639            self.config.flow_id,
640            expr.as_ref()
641                .map(|expr| expr_to_sql(expr).with_context(|_| DatafusionSnafu {
642                    context: format!("Failed to generate filter expr from {expr:?}"),
643                }))
644                .transpose()?
645                .map(|s| s.to_string())
646        );
647
648        let Some(expr) = expr else {
649            // no new data, hence no need to update
650            debug!("Flow id={:?}, no new data, not update", self.config.flow_id);
651            return Ok(None);
652        };
653
654        let mut add_filter = AddFilterRewriter::new(expr);
655        let mut add_auto_column = AddAutoColumnRewriter::new(sink_table_schema.clone());
656
657        let plan =
658            sql_to_df_plan(query_ctx.clone(), engine.clone(), &self.config.query, false).await?;
659        let rewrite = plan
660            .clone()
661            .rewrite(&mut add_filter)
662            .and_then(|p| p.data.rewrite(&mut add_auto_column))
663            .with_context(|_| DatafusionSnafu {
664                context: format!("Failed to rewrite plan:\n {}\n", plan),
665            })?
666            .data;
667        // only apply optimize after complex rewrite is done
668        let new_plan = apply_df_optimizer(rewrite).await?;
669
670        Ok(Some((new_plan, schema_len)))
671    }
672}
673
674// auto created table have a auto added column `update_at`, and optional have a `AUTO_CREATED_PLACEHOLDER_TS_COL` column for time index placeholder if no timestamp column is specified
675// TODO(discord9): for now no default value is set for auto added column for compatibility reason with streaming mode, but this might change in favor of simpler code?
676fn create_table_with_expr(
677    plan: &LogicalPlan,
678    sink_table_name: &[String; 3],
679) -> Result<CreateTableExpr, Error> {
680    let fields = plan.schema().fields();
681    let (first_time_stamp, primary_keys) = build_primary_key_constraint(plan, fields)?;
682
683    let mut column_schemas = Vec::new();
684    for field in fields {
685        let name = field.name();
686        let ty = ConcreteDataType::from_arrow_type(field.data_type());
687        let col_schema = if first_time_stamp == Some(name.clone()) {
688            ColumnSchema::new(name, ty, false).with_time_index(true)
689        } else {
690            ColumnSchema::new(name, ty, true)
691        };
692        column_schemas.push(col_schema);
693    }
694
695    let update_at_schema = ColumnSchema::new(
696        AUTO_CREATED_UPDATE_AT_TS_COL,
697        ConcreteDataType::timestamp_millisecond_datatype(),
698        true,
699    );
700    column_schemas.push(update_at_schema);
701
702    let time_index = if let Some(time_index) = first_time_stamp {
703        time_index
704    } else {
705        column_schemas.push(
706            ColumnSchema::new(
707                AUTO_CREATED_PLACEHOLDER_TS_COL,
708                ConcreteDataType::timestamp_millisecond_datatype(),
709                false,
710            )
711            .with_time_index(true),
712        );
713        AUTO_CREATED_PLACEHOLDER_TS_COL.to_string()
714    };
715
716    let column_defs =
717        column_schemas_to_defs(column_schemas, &primary_keys).context(ConvertColumnSchemaSnafu)?;
718    Ok(CreateTableExpr {
719        catalog_name: sink_table_name[0].clone(),
720        schema_name: sink_table_name[1].clone(),
721        table_name: sink_table_name[2].clone(),
722        desc: "Auto created table by flow engine".to_string(),
723        column_defs,
724        time_index,
725        primary_keys,
726        create_if_not_exists: true,
727        table_options: Default::default(),
728        table_id: None,
729        engine: "mito".to_string(),
730    })
731}
732
733/// Return first timestamp column which is in group by clause and other columns which are also in group by clause
734///
735/// # Returns
736///
737/// * `Option<String>` - first timestamp column which is in group by clause
738/// * `Vec<String>` - other columns which are also in group by clause
739fn build_primary_key_constraint(
740    plan: &LogicalPlan,
741    schema: &Fields,
742) -> Result<(Option<String>, Vec<String>), Error> {
743    let mut pk_names = FindGroupByFinalName::default();
744
745    plan.visit(&mut pk_names)
746        .with_context(|_| DatafusionSnafu {
747            context: format!("Can't find aggr expr in plan {plan:?}"),
748        })?;
749
750    // if no group by clause, return empty
751    let pk_final_names = pk_names.get_group_expr_names().unwrap_or_default();
752    if pk_final_names.is_empty() {
753        return Ok((None, Vec::new()));
754    }
755
756    let all_pk_cols: Vec<_> = schema
757        .iter()
758        .filter(|f| pk_final_names.contains(f.name()))
759        .map(|f| f.name().clone())
760        .collect();
761    // auto create table use first timestamp column in group by clause as time index
762    let first_time_stamp = schema
763        .iter()
764        .find(|f| {
765            all_pk_cols.contains(&f.name().clone())
766                && ConcreteDataType::from_arrow_type(f.data_type()).is_timestamp()
767        })
768        .map(|f| f.name().clone());
769
770    let all_pk_cols: Vec<_> = all_pk_cols
771        .into_iter()
772        .filter(|col| first_time_stamp != Some(col.to_string()))
773        .collect();
774
775    Ok((first_time_stamp, all_pk_cols))
776}
777
#[cfg(test)]
mod test {
    use api::v1::column_def::try_as_column_schema;
    use pretty_assertions::assert_eq;
    use session::context::QueryContext;

    use super::*;
    use crate::test_utils::create_test_query_engine;

    /// One table-driven scenario: a flow query together with the sink-table
    /// definition that `create_table_with_expr` is expected to derive from it.
    struct TestCase {
        /// Query whose logical plan is fed to `create_table_with_expr`.
        sql: String,
        /// Last component of the sink table name passed to `create_table_with_expr`.
        sink_table_name: String,
        /// Expected column definitions, in order.
        column_schemas: Vec<ColumnSchema>,
        /// Expected primary key column names, in order.
        primary_keys: Vec<String>,
        /// Expected time index column name.
        time_index: String,
    }

    /// Builds a nullable column schema with the given name and type.
    fn nullable_col(name: &str, data_type: ConcreteDataType) -> ColumnSchema {
        ColumnSchema::new(name, data_type, true)
    }

    /// Builds a non-nullable millisecond-timestamp column flagged as the time index.
    fn time_index_col(name: &str) -> ColumnSchema {
        ColumnSchema::new(
            name,
            ConcreteDataType::timestamp_millisecond_datatype(),
            false,
        )
        .with_time_index(true)
    }

    /// Checks the columns, primary keys and time index that
    /// `create_table_with_expr` produces for several query shapes.
    #[tokio::test]
    async fn test_gen_create_table_sql() {
        let query_engine = create_test_query_engine();
        let ctx = QueryContext::arc();

        let update_at_schema = nullable_col(
            AUTO_CREATED_UPDATE_AT_TS_COL,
            ConcreteDataType::timestamp_millisecond_datatype(),
        );
        let ts_placeholder_schema = time_index_col(AUTO_CREATED_PLACEHOLDER_TS_COL);

        let testcases = vec![
            // No GROUP BY: no primary keys, and the placeholder column is the time index.
            TestCase {
                sql: "SELECT number, ts FROM numbers_with_ts".to_string(),
                sink_table_name: "new_table".to_string(),
                column_schemas: vec![
                    nullable_col("number", ConcreteDataType::uint32_datatype()),
                    nullable_col("ts", ConcreteDataType::timestamp_millisecond_datatype()),
                    update_at_schema.clone(),
                    ts_placeholder_schema.clone(),
                ],
                primary_keys: vec![],
                time_index: AUTO_CREATED_PLACEHOLDER_TS_COL.to_string(),
            },
            // GROUP BY a non-timestamp column: it becomes the primary key, and the
            // placeholder column is still the time index.
            TestCase {
                sql: "SELECT number, max(ts) FROM numbers_with_ts GROUP BY number".to_string(),
                sink_table_name: "new_table".to_string(),
                column_schemas: vec![
                    nullable_col("number", ConcreteDataType::uint32_datatype()),
                    nullable_col(
                        "max(numbers_with_ts.ts)",
                        ConcreteDataType::timestamp_millisecond_datatype(),
                    ),
                    update_at_schema.clone(),
                    ts_placeholder_schema.clone(),
                ],
                primary_keys: vec!["number".to_string()],
                time_index: AUTO_CREATED_PLACEHOLDER_TS_COL.to_string(),
            },
            // GROUP BY only the timestamp column: it becomes the time index and
            // there are no primary keys.
            TestCase {
                sql: "SELECT max(number), ts FROM numbers_with_ts GROUP BY ts".to_string(),
                sink_table_name: "new_table".to_string(),
                column_schemas: vec![
                    nullable_col(
                        "max(numbers_with_ts.number)",
                        ConcreteDataType::uint32_datatype(),
                    ),
                    time_index_col("ts"),
                    update_at_schema.clone(),
                ],
                primary_keys: vec![],
                time_index: "ts".to_string(),
            },
            // GROUP BY timestamp and another column: timestamp is the time index,
            // the other column is the primary key.
            TestCase {
                sql: "SELECT number, ts FROM numbers_with_ts GROUP BY ts, number".to_string(),
                sink_table_name: "new_table".to_string(),
                column_schemas: vec![
                    nullable_col("number", ConcreteDataType::uint32_datatype()),
                    time_index_col("ts"),
                    update_at_schema.clone(),
                ],
                primary_keys: vec!["number".to_string()],
                time_index: "ts".to_string(),
            },
        ];

        for tc in testcases {
            let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), &tc.sql, true)
                .await
                .unwrap_or_else(|e| panic!("failed to plan `{}`: {e:?}", tc.sql));

            let sink_table = [
                "greptime".to_string(),
                "public".to_string(),
                tc.sink_table_name.clone(),
            ];
            let expr = create_table_with_expr(&plan, &sink_table).unwrap_or_else(|e| {
                panic!("failed to build create table expr for `{}`: {e:?}", tc.sql)
            });

            // TODO(discord9): assert expr
            let column_schemas = expr
                .column_defs
                .iter()
                .map(|c| try_as_column_schema(c).unwrap())
                .collect::<Vec<_>>();

            // Include the SQL in every message so a failure names the case that broke.
            assert_eq!(
                tc.column_schemas, column_schemas,
                "column schemas mismatch for `{}`",
                tc.sql
            );
            assert_eq!(
                tc.primary_keys, expr.primary_keys,
                "primary keys mismatch for `{}`",
                tc.sql
            );
            assert_eq!(
                tc.time_index, expr.time_index,
                "time index mismatch for `{}`",
                tc.sql
            );
        }
    }
}