query/datafusion.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Planner, QueryEngine implementations based on DataFusion.
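//!
//! A minimal usage sketch, mirroring the tests at the bottom of this file (the in-memory catalog
//! manager and the registered `numbers` table are assumed to be set up beforehand):
//!
//! ```ignore
//! let engine = QueryEngineFactory::new(
//!     catalog_manager, None, None, None, None, false, QueryOptions::default(),
//! )
//! .query_engine();
//! let stmt = QueryLanguageParser::parse_sql("select sum(number) from numbers", &QueryContext::arc())?;
//! let plan = engine.planner().plan(&stmt, QueryContext::arc()).await?;
//! let output = engine.execute(plan, QueryContext::arc()).await?;
//! ```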

mod error;
mod planner;

use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;

use async_trait::async_trait;
use common_base::Plugins;
use common_catalog::consts::is_readonly_schema;
use common_error::ext::BoxedError;
use common_function::function::FunctionContext;
use common_function::function_factory::ScalarFunctionFactory;
use common_query::{Output, OutputData, OutputMeta};
use common_recordbatch::adapter::RecordBatchStreamAdapter;
use common_recordbatch::{EmptyRecordBatchStream, SendableRecordBatchStream};
use common_telemetry::tracing;
use datafusion::catalog::TableFunction;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_plan::analyze::AnalyzeExec;
use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion_common::ResolvedTableReference;
use datafusion_expr::{
    AggregateUDF, DmlStatement, LogicalPlan as DfLogicalPlan, LogicalPlan, WriteOp,
};
use datatypes::prelude::VectorRef;
use datatypes::schema::Schema;
use futures_util::StreamExt;
use session::context::QueryContextRef;
use snafu::{OptionExt, ResultExt, ensure};
use sqlparser::ast::AnalyzeFormat;
use table::TableRef;
use table::requests::{DeleteRequest, InsertRequest};

use crate::analyze::DistAnalyzeExec;
use crate::dataframe::DataFrame;
pub use crate::datafusion::planner::DfContextProviderAdapter;
use crate::dist_plan::{DistPlannerOptions, MergeScanLogicalPlan};
use crate::error::{
    CatalogSnafu, ConvertSchemaSnafu, CreateRecordBatchSnafu, MissingTableMutationHandlerSnafu,
    MissingTimestampColumnSnafu, QueryExecutionSnafu, Result, TableMutationSnafu,
    TableNotFoundSnafu, TableReadOnlySnafu, UnsupportedExprSnafu,
};
use crate::executor::QueryExecutor;
use crate::metrics::{OnDone, QUERY_STAGE_ELAPSED};
use crate::physical_wrapper::PhysicalPlanWrapperRef;
use crate::planner::{DfLogicalPlanner, LogicalPlanner};
use crate::query_engine::{DescribeResult, QueryEngineContext, QueryEngineState};
use crate::{QueryEngine, metrics};

/// Query parallelism hint key.
/// This hint can be set in the query context to control the parallelism of the query execution.
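///
/// For example, assuming hints sent via the `x-greptime-hints` request header are copied into the
/// query context's extensions as `key=value` pairs (see `engine_context` below), a client could
/// hint a parallelism of 8 with something like:
///
/// ```text
/// x-greptime-hints: query_parallelism=8
/// ```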
pub const QUERY_PARALLELISM_HINT: &str = "query_parallelism";

/// Whether to fall back to the original plan when push-down fails.
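/// The hint value is parsed case-insensitively as a boolean (e.g. `true`); values that fail to
/// parse are treated as `false` (see `engine_context` below).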
pub const QUERY_FALLBACK_HINT: &str = "query_fallback";

pub struct DatafusionQueryEngine {
    state: Arc<QueryEngineState>,
    plugins: Plugins,
}

impl DatafusionQueryEngine {
    pub fn new(state: Arc<QueryEngineState>, plugins: Plugins) -> Self {
        Self { state, plugins }
    }

    #[tracing::instrument(skip_all)]
    async fn exec_query_plan(
        &self,
        plan: LogicalPlan,
        query_ctx: QueryContextRef,
    ) -> Result<Output> {
        let mut ctx = self.engine_context(query_ctx.clone());

        // `create_physical_plan` will optimize the logical plan internally
        let physical_plan = self.create_physical_plan(&mut ctx, &plan).await?;
        let optimized_physical_plan = self.optimize_physical_plan(&mut ctx, physical_plan)?;

        let physical_plan = if let Some(wrapper) = self.plugins.get::<PhysicalPlanWrapperRef>() {
            wrapper.wrap(optimized_physical_plan, query_ctx)
        } else {
            optimized_physical_plan
        };

        Ok(Output::new(
            OutputData::Stream(self.execute_stream(&ctx, &physical_plan)?),
            OutputMeta::new_with_plan(physical_plan),
        ))
    }

    #[tracing::instrument(skip_all)]
    async fn exec_dml_statement(
        &self,
        dml: DmlStatement,
        query_ctx: QueryContextRef,
    ) -> Result<Output> {
        ensure!(
            matches!(dml.op, WriteOp::Insert(_) | WriteOp::Delete),
            UnsupportedExprSnafu {
                name: format!("DML op {}", dml.op),
            }
        );

        let _timer = QUERY_STAGE_ELAPSED
            .with_label_values(&[dml.op.name()])
            .start_timer();

        let default_catalog = &query_ctx.current_catalog().to_owned();
        let default_schema = &query_ctx.current_schema();
        let table_name = dml.table_name.resolve(default_catalog, default_schema);
        let table = self.find_table(&table_name, &query_ctx).await?;

        let output = self
            .exec_query_plan((*dml.input).clone(), query_ctx.clone())
            .await?;
        let mut stream = match output.data {
            OutputData::RecordBatches(batches) => batches.as_stream(),
            OutputData::Stream(stream) => stream,
            _ => unreachable!(),
        };

        let mut affected_rows = 0;
        let mut insert_cost = 0;

        while let Some(batch) = stream.next().await {
            let batch = batch.context(CreateRecordBatchSnafu)?;
            let column_vectors = batch
                .column_vectors(&table_name.to_string(), table.schema())
                .map_err(BoxedError::new)
                .context(QueryExecutionSnafu)?;

            match dml.op {
                WriteOp::Insert(_) => {
                    // The concrete insert op kind is ignored here.
                    let output = self
                        .insert(&table_name, column_vectors, query_ctx.clone())
                        .await?;
                    let (rows, cost) = output.extract_rows_and_cost();
                    affected_rows += rows;
                    insert_cost += cost;
                }
                WriteOp::Delete => {
                    affected_rows += self
                        .delete(&table_name, &table, column_vectors, query_ctx.clone())
                        .await?;
                }
                _ => unreachable!("guarded by the 'ensure!' at the beginning"),
            }
        }
        Ok(Output::new(
            OutputData::AffectedRows(affected_rows),
            OutputMeta::new_with_cost(insert_cost),
        ))
    }

    #[tracing::instrument(skip_all)]
    async fn delete(
        &self,
        table_name: &ResolvedTableReference,
        table: &TableRef,
        column_vectors: HashMap<String, VectorRef>,
        query_ctx: QueryContextRef,
    ) -> Result<usize> {
        let catalog_name = table_name.catalog.to_string();
        let schema_name = table_name.schema.to_string();
        let table_name = table_name.table.to_string();
        let table_schema = table.schema();

        ensure!(
            !is_readonly_schema(&schema_name),
            TableReadOnlySnafu { table: table_name }
        );

        let ts_column = table_schema
            .timestamp_column()
            .map(|x| &x.name)
            .with_context(|| MissingTimestampColumnSnafu {
                table_name: table_name.clone(),
            })?;

        let table_info = table.table_info();
        let rowkey_columns = table_info
            .meta
            .row_key_column_names()
            .collect::<Vec<&String>>();
        let column_vectors = column_vectors
            .into_iter()
            .filter(|x| &x.0 == ts_column || rowkey_columns.contains(&&x.0))
            .collect::<HashMap<_, _>>();

        let request = DeleteRequest {
            catalog_name,
            schema_name,
            table_name,
            key_column_values: column_vectors,
        };

        self.state
            .table_mutation_handler()
            .context(MissingTableMutationHandlerSnafu)?
            .delete(request, query_ctx)
            .await
            .context(TableMutationSnafu)
    }

    #[tracing::instrument(skip_all)]
    async fn insert(
        &self,
        table_name: &ResolvedTableReference,
        column_vectors: HashMap<String, VectorRef>,
        query_ctx: QueryContextRef,
    ) -> Result<Output> {
        let catalog_name = table_name.catalog.to_string();
        let schema_name = table_name.schema.to_string();
        let table_name = table_name.table.to_string();

        ensure!(
            !is_readonly_schema(&schema_name),
            TableReadOnlySnafu { table: table_name }
        );

        let request = InsertRequest {
            catalog_name,
            schema_name,
            table_name,
            columns_values: column_vectors,
        };

        self.state
            .table_mutation_handler()
            .context(MissingTableMutationHandlerSnafu)?
            .insert(request, query_ctx)
            .await
            .context(TableMutationSnafu)
    }

    async fn find_table(
        &self,
        table_name: &ResolvedTableReference,
        query_context: &QueryContextRef,
    ) -> Result<TableRef> {
        let catalog_name = table_name.catalog.as_ref();
        let schema_name = table_name.schema.as_ref();
        let table_name = table_name.table.as_ref();

        self.state
            .catalog_manager()
            .table(catalog_name, schema_name, table_name, Some(query_context))
            .await
            .context(CatalogSnafu)?
            .with_context(|| TableNotFoundSnafu { table: table_name })
    }

    #[tracing::instrument(skip_all)]
    async fn create_physical_plan(
        &self,
        ctx: &mut QueryEngineContext,
        logical_plan: &LogicalPlan,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        /// Only print context on panic, to avoid cluttering logs.
        ///
        /// TODO(discord9): remove this once we catch the bug
        #[derive(Debug)]
        struct PanicLogger<'a> {
            input_logical_plan: &'a LogicalPlan,
            after_analyze: Option<LogicalPlan>,
            after_optimize: Option<LogicalPlan>,
            phy_plan: Option<Arc<dyn ExecutionPlan>>,
        }
        impl Drop for PanicLogger<'_> {
            fn drop(&mut self) {
                if std::thread::panicking() {
                    common_telemetry::error!(
                        "Panic while creating physical plan, input logical plan: {:?}, after analyze: {:?}, after optimize: {:?}, final physical plan: {:?}",
                        self.input_logical_plan,
                        self.after_analyze,
                        self.after_optimize,
                        self.phy_plan
                    );
                }
            }
        }

        let mut logger = PanicLogger {
            input_logical_plan: logical_plan,
            after_analyze: None,
            after_optimize: None,
            phy_plan: None,
        };

        let _timer = metrics::CREATE_PHYSICAL_ELAPSED.start_timer();
        let state = ctx.state();

        common_telemetry::debug!("Create physical plan, input plan: {logical_plan}");

        // Special-case the EXPLAIN plan
        if matches!(logical_plan, DfLogicalPlan::Explain(_)) {
            return state
                .create_physical_plan(logical_plan)
                .await
                .context(error::DatafusionSnafu)
                .map_err(BoxedError::new)
                .context(QueryExecutionSnafu);
        }

        // analyze first
        let analyzed_plan = state
            .analyzer()
            .execute_and_check(logical_plan.clone(), state.config_options(), |_, _| {})
            .context(error::DatafusionSnafu)
            .map_err(BoxedError::new)
            .context(QueryExecutionSnafu)?;

        logger.after_analyze = Some(analyzed_plan.clone());

        common_telemetry::debug!("Create physical plan, analyzed plan: {analyzed_plan}");

        // Skip optimization for MergeScan
        let optimized_plan = if let DfLogicalPlan::Extension(ext) = &analyzed_plan
            && ext.node.name() == MergeScanLogicalPlan::name()
        {
            analyzed_plan.clone()
        } else {
            state
                .optimizer()
                .optimize(analyzed_plan, state, |_, _| {})
                .context(error::DatafusionSnafu)
                .map_err(BoxedError::new)
                .context(QueryExecutionSnafu)?
        };

        common_telemetry::debug!("Create physical plan, optimized plan: {optimized_plan}");
        logger.after_optimize = Some(optimized_plan.clone());

        let physical_plan = state
            .query_planner()
            .create_physical_plan(&optimized_plan, state)
            .await?;

        logger.phy_plan = Some(physical_plan.clone());
        drop(logger);
        Ok(physical_plan)
    }

    #[tracing::instrument(skip_all)]
    pub fn optimize(
        &self,
        context: &QueryEngineContext,
        plan: &LogicalPlan,
    ) -> Result<LogicalPlan> {
        let _timer = metrics::OPTIMIZE_LOGICAL_ELAPSED.start_timer();

        // Optimized by extension rules
        let optimized_plan = self
            .state
            .optimize_by_extension_rules(plan.clone(), context)
            .context(error::DatafusionSnafu)
            .map_err(BoxedError::new)
            .context(QueryExecutionSnafu)?;

        // Optimized by datafusion optimizer
        let optimized_plan = self
            .state
            .session_state()
            .optimize(&optimized_plan)
            .context(error::DatafusionSnafu)
            .map_err(BoxedError::new)
            .context(QueryExecutionSnafu)?;

        Ok(optimized_plan)
    }

    #[tracing::instrument(skip_all)]
    fn optimize_physical_plan(
        &self,
        ctx: &mut QueryEngineContext,
        plan: Arc<dyn ExecutionPlan>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let _timer = metrics::OPTIMIZE_PHYSICAL_ELAPSED.start_timer();

        // TODO(ruihang): `self.create_physical_plan()` already optimizes the plan; check
        // if we need to optimize it again here.
        // let state = ctx.state();
        // let config = state.config_options();

        // Skip optimizing the AnalyzeExec plan
        let optimized_plan = if let Some(analyze_plan) = plan.as_any().downcast_ref::<AnalyzeExec>()
        {
            let format = if let Some(format) = ctx.query_ctx().explain_format()
                && format.to_lowercase() == "json"
            {
                AnalyzeFormat::JSON
            } else {
                AnalyzeFormat::TEXT
            };
            // Sets the verbose flag of the query context.
            // The MergeScanExec plan uses the verbose flag to determine whether to print the plan in verbose mode.
            ctx.query_ctx().set_explain_verbose(analyze_plan.verbose());

            Arc::new(DistAnalyzeExec::new(
                analyze_plan.input().clone(),
                analyze_plan.verbose(),
                format,
            ))
            // let mut new_plan = analyze_plan.input().clone();
            // for optimizer in state.physical_optimizers() {
            //     new_plan = optimizer
            //         .optimize(new_plan, config)
            //         .context(DataFusionSnafu)?;
            // }
            // Arc::new(DistAnalyzeExec::new(new_plan))
        } else {
            plan
            // let mut new_plan = plan;
            // for optimizer in state.physical_optimizers() {
            //     new_plan = optimizer
            //         .optimize(new_plan, config)
            //         .context(DataFusionSnafu)?;
            // }
            // new_plan
        };

        Ok(optimized_plan)
    }
}

#[async_trait]
impl QueryEngine for DatafusionQueryEngine {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn planner(&self) -> Arc<dyn LogicalPlanner> {
        Arc::new(DfLogicalPlanner::new(self.state.clone()))
    }

    fn name(&self) -> &str {
        "datafusion"
    }

    async fn describe(
        &self,
        plan: LogicalPlan,
        query_ctx: QueryContextRef,
    ) -> Result<DescribeResult> {
        let ctx = self.engine_context(query_ctx);
        if let Ok(optimised_plan) = self.optimize(&ctx, &plan) {
            let schema = optimised_plan
                .schema()
                .clone()
                .try_into()
                .context(ConvertSchemaSnafu)?;
            Ok(DescribeResult {
                schema,
                logical_plan: optimised_plan,
            })
        } else {
            // Tables like those in information_schema cannot be optimized when the plan
            // contains parameters, so we fall back to the original plan.
            let schema = plan
                .schema()
                .clone()
                .try_into()
                .context(ConvertSchemaSnafu)?;
            Ok(DescribeResult {
                schema,
                logical_plan: plan,
            })
        }
    }

    async fn execute(&self, plan: LogicalPlan, query_ctx: QueryContextRef) -> Result<Output> {
        match plan {
            LogicalPlan::Dml(dml) => self.exec_dml_statement(dml, query_ctx).await,
            _ => self.exec_query_plan(plan, query_ctx).await,
        }
    }

    /// Note that in SQL queries, aggregate names are looked up using
    /// lowercase unless the query uses quotes. For example,
    ///
    /// `SELECT MY_UDAF(x)...` will look for an aggregate named `"my_udaf"`,
    /// `SELECT "my_UDAF"(x)` will look for an aggregate named `"my_UDAF"`.
    ///
    /// So it's better to use a lowercase name when creating a UDAF.
    fn register_aggregate_function(&self, func: AggregateUDF) {
        self.state.register_aggr_function(func);
    }

    /// Register a scalar function.
    /// Overrides any function already registered with the same name.
    fn register_scalar_function(&self, func: ScalarFunctionFactory) {
        self.state.register_scalar_function(func);
    }

    fn register_table_function(&self, func: Arc<TableFunction>) {
        self.state.register_table_function(func);
    }

    fn read_table(&self, table: TableRef) -> Result<DataFrame> {
        Ok(DataFrame::DataFusion(
            self.state
                .read_table(table)
                .context(error::DatafusionSnafu)
                .map_err(BoxedError::new)
                .context(QueryExecutionSnafu)?,
        ))
    }

    fn engine_context(&self, query_ctx: QueryContextRef) -> QueryEngineContext {
        let mut state = self.state.session_state();
        state.config_mut().set_extension(query_ctx.clone());
        // Note that hints in the "x-greptime-hints" header are automatically parsed
        // and stored in the query context's extensions, so we can read them from the query context.
        if let Some(parallelism) = query_ctx.extension(QUERY_PARALLELISM_HINT) {
            if let Ok(n) = parallelism.parse::<u64>() {
                if n > 0 {
                    let new_cfg = state.config().clone().with_target_partitions(n as usize);
                    *state.config_mut() = new_cfg;
                }
            } else {
                common_telemetry::warn!(
                    "Failed to parse query_parallelism: {}, using default value",
                    parallelism
                );
            }
        }

        // configure execution options
        state.config_mut().options_mut().execution.time_zone = query_ctx.timezone().to_string();

        // It's usually impossible to have both the `SET variable` from a SQL client and the
        // header hint from a gRPC client, so we only need to handle them separately.
        if query_ctx.configuration_parameter().allow_query_fallback() {
            state
                .config_mut()
                .options_mut()
                .extensions
                .insert(DistPlannerOptions {
                    allow_query_fallback: true,
                });
        } else if let Some(fallback) = query_ctx.extension(QUERY_FALLBACK_HINT) {
            // Also check the query context for the fallback hint;
            // if it is set to true, enable the fallback.
            if fallback.to_lowercase().parse::<bool>().unwrap_or(false) {
                state
                    .config_mut()
                    .options_mut()
                    .extensions
                    .insert(DistPlannerOptions {
                        allow_query_fallback: true,
                    });
            }
        }

        state
            .config_mut()
            .options_mut()
            .extensions
            .insert(FunctionContext {
                query_ctx: query_ctx.clone(),
                state: self.engine_state().function_state(),
            });

        let config_options = state.config_options().clone();
        let _ = state
            .execution_props_mut()
            .config_options
            .insert(config_options);

        QueryEngineContext::new(state, query_ctx)
    }

    fn engine_state(&self) -> &QueryEngineState {
        &self.state
    }
}
impl QueryExecutor for DatafusionQueryEngine {
    #[tracing::instrument(skip_all)]
    fn execute_stream(
        &self,
        ctx: &QueryEngineContext,
        plan: &Arc<dyn ExecutionPlan>,
    ) -> Result<SendableRecordBatchStream> {
        let explain_verbose = ctx.query_ctx().explain_verbose();
        let output_partitions = plan.properties().output_partitioning().partition_count();
        if explain_verbose {
            common_telemetry::info!("Executing query plan, output_partitions: {output_partitions}");
        }

        let exec_timer = metrics::EXEC_PLAN_ELAPSED.start_timer();
        let task_ctx = ctx.build_task_ctx();

        match plan.properties().output_partitioning().partition_count() {
            0 => {
                let schema = Arc::new(
                    Schema::try_from(plan.schema())
                        .map_err(BoxedError::new)
                        .context(QueryExecutionSnafu)?,
                );
                Ok(Box::pin(EmptyRecordBatchStream::new(schema)))
            }
            1 => {
                let df_stream = plan
                    .execute(0, task_ctx)
                    .context(error::DatafusionSnafu)
                    .map_err(BoxedError::new)
                    .context(QueryExecutionSnafu)?;
                let mut stream = RecordBatchStreamAdapter::try_new(df_stream)
                    .context(error::ConvertDfRecordBatchStreamSnafu)
                    .map_err(BoxedError::new)
                    .context(QueryExecutionSnafu)?;
                stream.set_metrics2(plan.clone());
                stream.set_explain_verbose(explain_verbose);
                let stream = OnDone::new(Box::pin(stream), move || {
                    let exec_cost = exec_timer.stop_and_record();
                    if explain_verbose {
                        common_telemetry::info!(
                            "DatafusionQueryEngine execute 1 stream, cost: {:?}s",
                            exec_cost,
                        );
                    }
                });
                Ok(Box::pin(stream))
            }
            _ => {
                // merge into a single partition
                let merged_plan = CoalescePartitionsExec::new(plan.clone());
                // CoalescePartitionsExec must produce a single partition
                assert_eq!(
                    1,
                    merged_plan
                        .properties()
                        .output_partitioning()
                        .partition_count()
                );
                let df_stream = merged_plan
                    .execute(0, task_ctx)
                    .context(error::DatafusionSnafu)
                    .map_err(BoxedError::new)
                    .context(QueryExecutionSnafu)?;
                let mut stream = RecordBatchStreamAdapter::try_new(df_stream)
                    .context(error::ConvertDfRecordBatchStreamSnafu)
                    .map_err(BoxedError::new)
                    .context(QueryExecutionSnafu)?;
                stream.set_metrics2(plan.clone());
                stream.set_explain_verbose(ctx.query_ctx().explain_verbose());
                let stream = OnDone::new(Box::pin(stream), move || {
                    let exec_cost = exec_timer.stop_and_record();
                    if explain_verbose {
                        common_telemetry::info!(
                            "DatafusionQueryEngine execute {output_partitions} stream, cost: {:?}s",
                            exec_cost
                        );
                    }
                });
                Ok(Box::pin(stream))
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use catalog::RegisterTableRequest;
    use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, NUMBERS_TABLE_ID};
    use common_recordbatch::util;
    use datafusion::prelude::{col, lit};
    use datatypes::prelude::ConcreteDataType;
    use datatypes::schema::ColumnSchema;
    use datatypes::vectors::{Helper, UInt32Vector, UInt64Vector, VectorRef};
    use session::context::{QueryContext, QueryContextBuilder};
    use table::table::numbers::{NUMBERS_TABLE_NAME, NumbersTable};

    use super::*;
    use crate::options::QueryOptions;
    use crate::parser::QueryLanguageParser;
    use crate::query_engine::{QueryEngineFactory, QueryEngineRef};

    async fn create_test_engine() -> QueryEngineRef {
        let catalog_manager = catalog::memory::new_memory_catalog_manager().unwrap();
        let req = RegisterTableRequest {
            catalog: DEFAULT_CATALOG_NAME.to_string(),
            schema: DEFAULT_SCHEMA_NAME.to_string(),
            table_name: NUMBERS_TABLE_NAME.to_string(),
            table_id: NUMBERS_TABLE_ID,
            table: NumbersTable::table(NUMBERS_TABLE_ID),
        };
        catalog_manager.register_table_sync(req).unwrap();

        QueryEngineFactory::new(
            catalog_manager,
            None,
            None,
            None,
            None,
            false,
            QueryOptions::default(),
        )
        .query_engine()
    }

    #[tokio::test]
    async fn test_sql_to_plan() {
        let engine = create_test_engine().await;
        let sql = "select sum(number) from numbers limit 20";

        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
        let plan = engine
            .planner()
            .plan(&stmt, QueryContext::arc())
            .await
            .unwrap();

        assert_eq!(
            plan.to_string(),
            r#"Limit: skip=0, fetch=20
  Projection: sum(numbers.number)
    Aggregate: groupBy=[[]], aggr=[[sum(numbers.number)]]
      TableScan: numbers"#
        );
    }

    #[tokio::test]
    async fn test_execute() {
        let engine = create_test_engine().await;
        let sql = "select sum(number) from numbers limit 20";

        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
        let plan = engine
            .planner()
            .plan(&stmt, QueryContext::arc())
            .await
            .unwrap();

        let output = engine.execute(plan, QueryContext::arc()).await.unwrap();

        match output.data {
            OutputData::Stream(recordbatch) => {
                let numbers = util::collect(recordbatch).await.unwrap();
                assert_eq!(1, numbers.len());
                assert_eq!(numbers[0].num_columns(), 1);
                assert_eq!(1, numbers[0].schema.num_columns());
                assert_eq!(
                    "sum(numbers.number)",
                    numbers[0].schema.column_schemas()[0].name
                );

                let batch = &numbers[0];
                assert_eq!(1, batch.num_columns());
                assert_eq!(batch.column(0).len(), 1);

                assert_eq!(
                    *batch.column(0),
                    Arc::new(UInt64Vector::from_slice([4950])) as VectorRef
                );
            }
            _ => unreachable!(),
        }
    }

    #[tokio::test]
    async fn test_read_table() {
        let engine = create_test_engine().await;

        let engine = engine
            .as_any()
            .downcast_ref::<DatafusionQueryEngine>()
            .unwrap();
        let query_ctx = Arc::new(QueryContextBuilder::default().build());
        let table = engine
            .find_table(
                &ResolvedTableReference {
                    catalog: "greptime".into(),
                    schema: "public".into(),
                    table: "numbers".into(),
                },
                &query_ctx,
            )
            .await
            .unwrap();

        let DataFrame::DataFusion(df) = engine.read_table(table).unwrap();
        let df = df
            .select_columns(&["number"])
            .unwrap()
            .filter(col("number").lt(lit(10)))
            .unwrap();
        let batches = df.collect().await.unwrap();
        assert_eq!(1, batches.len());
        let batch = &batches[0];

        assert_eq!(1, batch.num_columns());
        assert_eq!(batch.column(0).len(), 10);

        assert_eq!(
            Helper::try_into_vector(batch.column(0)).unwrap(),
            Arc::new(UInt32Vector::from_slice([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])) as VectorRef
        );
    }

    #[tokio::test]
    async fn test_describe() {
        let engine = create_test_engine().await;
        let sql = "select sum(number) from numbers limit 20";

        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();

        let plan = engine
            .planner()
            .plan(&stmt, QueryContext::arc())
            .await
            .unwrap();

        let DescribeResult {
            schema,
            logical_plan,
        } = engine.describe(plan, QueryContext::arc()).await.unwrap();

        assert_eq!(
            schema.column_schemas()[0],
            ColumnSchema::new(
                "sum(numbers.number)",
                ConcreteDataType::uint64_datatype(),
                true
            )
        );
        assert_eq!(
            "Limit: skip=0, fetch=20\n  Aggregate: groupBy=[[]], aggr=[[sum(CAST(numbers.number AS UInt64))]]\n    TableScan: numbers projection=[number]",
            format!("{}", logical_plan.display_indent())
        );
    }
}