query/
datafusion.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Planner, QueryEngine implementations based on DataFusion.
16
17mod error;
18mod planner;
19
20use std::any::Any;
21use std::collections::HashMap;
22use std::sync::Arc;
23
24use async_trait::async_trait;
25use common_base::Plugins;
26use common_catalog::consts::is_readonly_schema;
27use common_error::ext::BoxedError;
28use common_function::function::FunctionContext;
29use common_function::function_factory::ScalarFunctionFactory;
30use common_query::{Output, OutputData, OutputMeta};
31use common_recordbatch::adapter::RecordBatchStreamAdapter;
32use common_recordbatch::{EmptyRecordBatchStream, SendableRecordBatchStream};
33use common_telemetry::tracing;
34use datafusion::catalog::TableFunction;
35use datafusion::dataframe::DataFrame;
36use datafusion::physical_plan::ExecutionPlan;
37use datafusion::physical_plan::analyze::AnalyzeExec;
38use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
39use datafusion_common::ResolvedTableReference;
40use datafusion_expr::{
41    AggregateUDF, DmlStatement, LogicalPlan as DfLogicalPlan, LogicalPlan, WindowUDF, WriteOp,
42};
43use datatypes::prelude::VectorRef;
44use datatypes::schema::Schema;
45use futures_util::StreamExt;
46use session::context::QueryContextRef;
47use snafu::{OptionExt, ResultExt, ensure};
48use sqlparser::ast::AnalyzeFormat;
49use table::TableRef;
50use table::requests::{DeleteRequest, InsertRequest};
51use tracing::Span;
52
53use crate::analyze::DistAnalyzeExec;
54pub use crate::datafusion::planner::DfContextProviderAdapter;
55use crate::dist_plan::{DistPlannerOptions, MergeScanLogicalPlan};
56use crate::error::{
57    CatalogSnafu, ConvertSchemaSnafu, CreateRecordBatchSnafu, MissingTableMutationHandlerSnafu,
58    MissingTimestampColumnSnafu, QueryExecutionSnafu, Result, TableMutationSnafu,
59    TableNotFoundSnafu, TableReadOnlySnafu, UnsupportedExprSnafu,
60};
61use crate::executor::QueryExecutor;
62use crate::metrics::{OnDone, QUERY_STAGE_ELAPSED};
63use crate::physical_wrapper::PhysicalPlanWrapperRef;
64use crate::planner::{DfLogicalPlanner, LogicalPlanner};
65use crate::query_engine::{DescribeResult, QueryEngineContext, QueryEngineState};
66use crate::{QueryEngine, metrics};
67
/// Query-context extension key carrying a parallelism hint.
///
/// When the value parses as a positive integer it overrides DataFusion's
/// `target_partitions` setting for that single query.
pub const QUERY_PARALLELISM_HINT: &str = "query_parallelism";

/// Query-context extension key that, when set to a (case-insensitive) `true`,
/// lets the distributed planner fall back to the original plan if push-down fails.
pub const QUERY_FALLBACK_HINT: &str = "query_fallback";
74
75pub struct DatafusionQueryEngine {
76    state: Arc<QueryEngineState>,
77    plugins: Plugins,
78}
79
80impl DatafusionQueryEngine {
81    pub fn new(state: Arc<QueryEngineState>, plugins: Plugins) -> Self {
82        Self { state, plugins }
83    }
84
85    #[tracing::instrument(skip_all)]
86    async fn exec_query_plan(
87        &self,
88        plan: LogicalPlan,
89        query_ctx: QueryContextRef,
90    ) -> Result<Output> {
91        let mut ctx = self.engine_context(query_ctx.clone());
92
93        // `create_physical_plan` will optimize logical plan internally
94        let physical_plan = self.create_physical_plan(&mut ctx, &plan).await?;
95        let optimized_physical_plan = self.optimize_physical_plan(&mut ctx, physical_plan)?;
96
97        let physical_plan = if let Some(wrapper) = self.plugins.get::<PhysicalPlanWrapperRef>() {
98            wrapper.wrap(optimized_physical_plan, query_ctx)
99        } else {
100            optimized_physical_plan
101        };
102
103        Ok(Output::new(
104            OutputData::Stream(self.execute_stream(&ctx, &physical_plan)?),
105            OutputMeta::new_with_plan(physical_plan),
106        ))
107    }
108
109    #[tracing::instrument(skip_all)]
110    async fn exec_dml_statement(
111        &self,
112        dml: DmlStatement,
113        query_ctx: QueryContextRef,
114    ) -> Result<Output> {
115        ensure!(
116            matches!(dml.op, WriteOp::Insert(_) | WriteOp::Delete),
117            UnsupportedExprSnafu {
118                name: format!("DML op {}", dml.op),
119            }
120        );
121
122        let _timer = QUERY_STAGE_ELAPSED
123            .with_label_values(&[dml.op.name()])
124            .start_timer();
125
126        let default_catalog = &query_ctx.current_catalog().to_owned();
127        let default_schema = &query_ctx.current_schema();
128        let table_name = dml.table_name.resolve(default_catalog, default_schema);
129        let table = self.find_table(&table_name, &query_ctx).await?;
130
131        let output = self
132            .exec_query_plan((*dml.input).clone(), query_ctx.clone())
133            .await?;
134        let mut stream = match output.data {
135            OutputData::RecordBatches(batches) => batches.as_stream(),
136            OutputData::Stream(stream) => stream,
137            _ => unreachable!(),
138        };
139
140        let mut affected_rows = 0;
141        let mut insert_cost = 0;
142
143        while let Some(batch) = stream.next().await {
144            let batch = batch.context(CreateRecordBatchSnafu)?;
145            let column_vectors = batch
146                .column_vectors(&table_name.to_string(), table.schema())
147                .map_err(BoxedError::new)
148                .context(QueryExecutionSnafu)?;
149
150            match dml.op {
151                WriteOp::Insert(_) => {
152                    // We ignore the insert op.
153                    let output = self
154                        .insert(&table_name, column_vectors, query_ctx.clone())
155                        .await?;
156                    let (rows, cost) = output.extract_rows_and_cost();
157                    affected_rows += rows;
158                    insert_cost += cost;
159                }
160                WriteOp::Delete => {
161                    affected_rows += self
162                        .delete(&table_name, &table, column_vectors, query_ctx.clone())
163                        .await?;
164                }
165                _ => unreachable!("guarded by the 'ensure!' at the beginning"),
166            }
167        }
168        Ok(Output::new(
169            OutputData::AffectedRows(affected_rows),
170            OutputMeta::new_with_cost(insert_cost),
171        ))
172    }
173
174    #[tracing::instrument(skip_all)]
175    async fn delete(
176        &self,
177        table_name: &ResolvedTableReference,
178        table: &TableRef,
179        column_vectors: HashMap<String, VectorRef>,
180        query_ctx: QueryContextRef,
181    ) -> Result<usize> {
182        let catalog_name = table_name.catalog.to_string();
183        let schema_name = table_name.schema.to_string();
184        let table_name = table_name.table.to_string();
185        let table_schema = table.schema();
186
187        ensure!(
188            !is_readonly_schema(&schema_name),
189            TableReadOnlySnafu { table: table_name }
190        );
191
192        let ts_column = table_schema
193            .timestamp_column()
194            .map(|x| &x.name)
195            .with_context(|| MissingTimestampColumnSnafu {
196                table_name: table_name.clone(),
197            })?;
198
199        let table_info = table.table_info();
200        let rowkey_columns = table_info
201            .meta
202            .row_key_column_names()
203            .collect::<Vec<&String>>();
204        let column_vectors = column_vectors
205            .into_iter()
206            .filter(|x| &x.0 == ts_column || rowkey_columns.contains(&&x.0))
207            .collect::<HashMap<_, _>>();
208
209        let request = DeleteRequest {
210            catalog_name,
211            schema_name,
212            table_name,
213            key_column_values: column_vectors,
214        };
215
216        self.state
217            .table_mutation_handler()
218            .context(MissingTableMutationHandlerSnafu)?
219            .delete(request, query_ctx)
220            .await
221            .context(TableMutationSnafu)
222    }
223
224    #[tracing::instrument(skip_all)]
225    async fn insert(
226        &self,
227        table_name: &ResolvedTableReference,
228        column_vectors: HashMap<String, VectorRef>,
229        query_ctx: QueryContextRef,
230    ) -> Result<Output> {
231        let catalog_name = table_name.catalog.to_string();
232        let schema_name = table_name.schema.to_string();
233        let table_name = table_name.table.to_string();
234
235        ensure!(
236            !is_readonly_schema(&schema_name),
237            TableReadOnlySnafu { table: table_name }
238        );
239
240        let request = InsertRequest {
241            catalog_name,
242            schema_name,
243            table_name,
244            columns_values: column_vectors,
245        };
246
247        self.state
248            .table_mutation_handler()
249            .context(MissingTableMutationHandlerSnafu)?
250            .insert(request, query_ctx)
251            .await
252            .context(TableMutationSnafu)
253    }
254
255    async fn find_table(
256        &self,
257        table_name: &ResolvedTableReference,
258        query_context: &QueryContextRef,
259    ) -> Result<TableRef> {
260        let catalog_name = table_name.catalog.as_ref();
261        let schema_name = table_name.schema.as_ref();
262        let table_name = table_name.table.as_ref();
263
264        self.state
265            .catalog_manager()
266            .table(catalog_name, schema_name, table_name, Some(query_context))
267            .await
268            .context(CatalogSnafu)?
269            .with_context(|| TableNotFoundSnafu { table: table_name })
270    }
271
    /// Turns a logical plan into a DataFusion physical plan.
    ///
    /// `EXPLAIN` plans are handed directly to the session state's
    /// `create_physical_plan`. Every other plan goes through the session
    /// analyzer, then the logical optimizer (skipped when the analyzed plan is a
    /// top-level `MergeScan` extension node), and finally the query planner.
    ///
    /// If anything panics along the way, the `PanicLogger` guard logs the input
    /// plan together with whichever intermediate plans were produced so far.
    #[tracing::instrument(skip_all)]
    async fn create_physical_plan(
        &self,
        ctx: &mut QueryEngineContext,
        logical_plan: &LogicalPlan,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        /// Only print context on panic, to avoid cluttering logs.
        ///
        /// TODO(discord9): remove this once we catch the bug
        #[derive(Debug)]
        struct PanicLogger<'a> {
            input_logical_plan: &'a LogicalPlan,
            after_analyze: Option<LogicalPlan>,
            after_optimize: Option<LogicalPlan>,
            phy_plan: Option<Arc<dyn ExecutionPlan>>,
        }
        impl Drop for PanicLogger<'_> {
            // Runs on every drop, but only logs while the thread is unwinding.
            fn drop(&mut self) {
                if std::thread::panicking() {
                    common_telemetry::error!(
                        "Panic while creating physical plan, input logical plan: {:?}, after analyze: {:?}, after optimize: {:?}, final physical plan: {:?}",
                        self.input_logical_plan,
                        self.after_analyze,
                        self.after_optimize,
                        self.phy_plan
                    );
                }
            }
        }

        // The guard must stay alive for the whole planning pipeline below; each
        // stage records its output into it so a panic can report how far we got.
        let mut logger = PanicLogger {
            input_logical_plan: logical_plan,
            after_analyze: None,
            after_optimize: None,
            phy_plan: None,
        };

        let _timer = metrics::CREATE_PHYSICAL_ELAPSED.start_timer();
        let state = ctx.state();

        common_telemetry::debug!("Create physical plan, input plan: {logical_plan}");

        // special handle EXPLAIN plan
        if matches!(logical_plan, DfLogicalPlan::Explain(_)) {
            return state
                .create_physical_plan(logical_plan)
                .await
                .map_err(Into::into);
        }

        // analyze first
        let analyzed_plan = state.analyzer().execute_and_check(
            logical_plan.clone(),
            state.config_options(),
            |_, _| {},
        )?;

        logger.after_analyze = Some(analyzed_plan.clone());

        common_telemetry::debug!("Create physical plan, analyzed plan: {analyzed_plan}");

        // skip optimize for MergeScan
        let optimized_plan = if let DfLogicalPlan::Extension(ext) = &analyzed_plan
            && ext.node.name() == MergeScanLogicalPlan::name()
        {
            analyzed_plan.clone()
        } else {
            state
                .optimizer()
                .optimize(analyzed_plan, state, |_, _| {})?
        };

        common_telemetry::debug!("Create physical plan, optimized plan: {optimized_plan}");
        logger.after_optimize = Some(optimized_plan.clone());

        let physical_plan = state
            .query_planner()
            .create_physical_plan(&optimized_plan, state)
            .await?;

        logger.phy_plan = Some(physical_plan.clone());
        // Planning finished without panicking; end the guard's scope explicitly.
        drop(logger);
        Ok(physical_plan)
    }
356
357    #[tracing::instrument(skip_all)]
358    fn optimize_physical_plan(
359        &self,
360        ctx: &mut QueryEngineContext,
361        plan: Arc<dyn ExecutionPlan>,
362    ) -> Result<Arc<dyn ExecutionPlan>> {
363        let _timer = metrics::OPTIMIZE_PHYSICAL_ELAPSED.start_timer();
364
365        // TODO(ruihang): `self.create_physical_plan()` already optimize the plan, check
366        // if we need to optimize it again here.
367        // let state = ctx.state();
368        // let config = state.config_options();
369
370        // skip optimize AnalyzeExec plan
371        let optimized_plan = if let Some(analyze_plan) = plan.as_any().downcast_ref::<AnalyzeExec>()
372        {
373            let format = if let Some(format) = ctx.query_ctx().explain_format()
374                && format.to_lowercase() == "json"
375            {
376                AnalyzeFormat::JSON
377            } else {
378                AnalyzeFormat::TEXT
379            };
380            // Sets the verbose flag of the query context.
381            // The MergeScanExec plan uses the verbose flag to determine whether to print the plan in verbose mode.
382            ctx.query_ctx().set_explain_verbose(analyze_plan.verbose());
383
384            Arc::new(DistAnalyzeExec::new(
385                analyze_plan.input().clone(),
386                analyze_plan.verbose(),
387                format,
388            ))
389            // let mut new_plan = analyze_plan.input().clone();
390            // for optimizer in state.physical_optimizers() {
391            //     new_plan = optimizer
392            //         .optimize(new_plan, config)
393            //         .context(DataFusionSnafu)?;
394            // }
395            // Arc::new(DistAnalyzeExec::new(new_plan))
396        } else {
397            plan
398            // let mut new_plan = plan;
399            // for optimizer in state.physical_optimizers() {
400            //     new_plan = optimizer
401            //         .optimize(new_plan, config)
402            //         .context(DataFusionSnafu)?;
403            // }
404            // new_plan
405        };
406
407        Ok(optimized_plan)
408    }
409}
410
411#[async_trait]
412impl QueryEngine for DatafusionQueryEngine {
413    fn as_any(&self) -> &dyn Any {
414        self
415    }
416
417    fn planner(&self) -> Arc<dyn LogicalPlanner> {
418        Arc::new(DfLogicalPlanner::new(self.state.clone()))
419    }
420
421    fn name(&self) -> &str {
422        "datafusion"
423    }
424
425    async fn describe(
426        &self,
427        plan: LogicalPlan,
428        _query_ctx: QueryContextRef,
429    ) -> Result<DescribeResult> {
430        let schema = plan
431            .schema()
432            .clone()
433            .try_into()
434            .context(ConvertSchemaSnafu)?;
435        Ok(DescribeResult {
436            schema,
437            logical_plan: plan,
438        })
439    }
440
441    async fn execute(&self, plan: LogicalPlan, query_ctx: QueryContextRef) -> Result<Output> {
442        match plan {
443            LogicalPlan::Dml(dml) => self.exec_dml_statement(dml, query_ctx).await,
444            _ => self.exec_query_plan(plan, query_ctx).await,
445        }
446    }
447
448    /// Note in SQL queries, aggregate names are looked up using
449    /// lowercase unless the query uses quotes. For example,
450    ///
451    /// `SELECT MY_UDAF(x)...` will look for an aggregate named `"my_udaf"`
452    /// `SELECT "my_UDAF"(x)` will look for an aggregate named `"my_UDAF"`
453    ///
454    /// So it's better to make UDAF name lowercase when creating one.
455    fn register_aggregate_function(&self, func: AggregateUDF) {
456        self.state.register_aggr_function(func);
457    }
458
459    /// Register an scalar function.
460    /// Will override if the function with same name is already registered.
461    fn register_scalar_function(&self, func: ScalarFunctionFactory) {
462        self.state.register_scalar_function(func);
463    }
464
465    fn register_table_function(&self, func: Arc<TableFunction>) {
466        self.state.register_table_function(func);
467    }
468
469    fn register_window_function(&self, func: WindowUDF) {
470        self.state.register_window_function(func);
471    }
472
473    fn read_table(&self, table: TableRef) -> Result<DataFrame> {
474        self.state.read_table(table).map_err(Into::into)
475    }
476
477    fn engine_context(&self, query_ctx: QueryContextRef) -> QueryEngineContext {
478        let mut state = self.state.session_state();
479        state.config_mut().set_extension(query_ctx.clone());
480        // note that hints in "x-greptime-hints" is automatically parsed
481        // and set to query context's extension, so we can get it from query context.
482        if let Some(parallelism) = query_ctx.extension(QUERY_PARALLELISM_HINT) {
483            if let Ok(n) = parallelism.parse::<u64>() {
484                if n > 0 {
485                    let new_cfg = state.config().clone().with_target_partitions(n as usize);
486                    *state.config_mut() = new_cfg;
487                }
488            } else {
489                common_telemetry::warn!(
490                    "Failed to parse query_parallelism: {}, using default value",
491                    parallelism
492                );
493            }
494        }
495
496        // configure execution options
497        state.config_mut().options_mut().execution.time_zone =
498            Some(query_ctx.timezone().to_string());
499
500        // usually it's impossible to have both `set variable` set by sql client and
501        // hint in header by grpc client, so only need to deal with them separately
502        if query_ctx.configuration_parameter().allow_query_fallback() {
503            state
504                .config_mut()
505                .options_mut()
506                .extensions
507                .insert(DistPlannerOptions {
508                    allow_query_fallback: true,
509                });
510        } else if let Some(fallback) = query_ctx.extension(QUERY_FALLBACK_HINT) {
511            // also check the query context for fallback hint
512            // if it is set, we will enable the fallback
513            if fallback.to_lowercase().parse::<bool>().unwrap_or(false) {
514                state
515                    .config_mut()
516                    .options_mut()
517                    .extensions
518                    .insert(DistPlannerOptions {
519                        allow_query_fallback: true,
520                    });
521            }
522        }
523
524        state
525            .config_mut()
526            .options_mut()
527            .extensions
528            .insert(FunctionContext {
529                query_ctx: query_ctx.clone(),
530                state: self.engine_state().function_state(),
531            });
532
533        let config_options = state.config_options().clone();
534        let _ = state
535            .execution_props_mut()
536            .config_options
537            .insert(config_options);
538
539        QueryEngineContext::new(state, query_ctx)
540    }
541
542    fn engine_state(&self) -> &QueryEngineState {
543        &self.state
544    }
545}
546
547impl QueryExecutor for DatafusionQueryEngine {
548    #[tracing::instrument(skip_all)]
549    fn execute_stream(
550        &self,
551        ctx: &QueryEngineContext,
552        plan: &Arc<dyn ExecutionPlan>,
553    ) -> Result<SendableRecordBatchStream> {
554        let explain_verbose = ctx.query_ctx().explain_verbose();
555        let output_partitions = plan.properties().output_partitioning().partition_count();
556        if explain_verbose {
557            common_telemetry::info!("Executing query plan, output_partitions: {output_partitions}");
558        }
559
560        let exec_timer = metrics::EXEC_PLAN_ELAPSED.start_timer();
561        let task_ctx = ctx.build_task_ctx();
562        let span = Span::current();
563
564        match plan.properties().output_partitioning().partition_count() {
565            0 => {
566                let schema = Arc::new(
567                    Schema::try_from(plan.schema())
568                        .map_err(BoxedError::new)
569                        .context(QueryExecutionSnafu)?,
570                );
571                Ok(Box::pin(EmptyRecordBatchStream::new(schema)))
572            }
573            1 => {
574                let df_stream = plan.execute(0, task_ctx)?;
575                let mut stream = RecordBatchStreamAdapter::try_new_with_span(df_stream, span)
576                    .context(error::ConvertDfRecordBatchStreamSnafu)
577                    .map_err(BoxedError::new)
578                    .context(QueryExecutionSnafu)?;
579                stream.set_metrics2(plan.clone());
580                stream.set_explain_verbose(explain_verbose);
581                let stream = OnDone::new(Box::pin(stream), move || {
582                    let exec_cost = exec_timer.stop_and_record();
583                    if explain_verbose {
584                        common_telemetry::info!(
585                            "DatafusionQueryEngine execute 1 stream, cost: {:?}s",
586                            exec_cost,
587                        );
588                    }
589                });
590                Ok(Box::pin(stream))
591            }
592            _ => {
593                // merge into a single partition
594                let merged_plan = CoalescePartitionsExec::new(plan.clone());
595                // CoalescePartitionsExec must produce a single partition
596                assert_eq!(
597                    1,
598                    merged_plan
599                        .properties()
600                        .output_partitioning()
601                        .partition_count()
602                );
603                let df_stream = merged_plan.execute(0, task_ctx)?;
604                let mut stream = RecordBatchStreamAdapter::try_new_with_span(df_stream, span)
605                    .context(error::ConvertDfRecordBatchStreamSnafu)
606                    .map_err(BoxedError::new)
607                    .context(QueryExecutionSnafu)?;
608                stream.set_metrics2(plan.clone());
609                stream.set_explain_verbose(ctx.query_ctx().explain_verbose());
610                let stream = OnDone::new(Box::pin(stream), move || {
611                    let exec_cost = exec_timer.stop_and_record();
612                    if explain_verbose {
613                        common_telemetry::info!(
614                            "DatafusionQueryEngine execute {output_partitions} stream, cost: {:?}s",
615                            exec_cost
616                        );
617                    }
618                });
619                Ok(Box::pin(stream))
620            }
621        }
622    }
623}
624
625#[cfg(test)]
626mod tests {
627    use std::fmt;
628    use std::sync::Arc;
629    use std::sync::atomic::{AtomicUsize, Ordering};
630
631    use api::v1::SemanticType;
632    use arrow::array::{ArrayRef, UInt64Array};
633    use arrow_schema::SortOptions;
634    use catalog::RegisterTableRequest;
635    use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, NUMBERS_TABLE_ID};
636    use common_error::ext::BoxedError;
637    use common_recordbatch::{EmptyRecordBatchStream, SendableRecordBatchStream, util};
638    use datafusion::physical_plan::display::{DisplayAs, DisplayFormatType};
639    use datafusion::physical_plan::expressions::PhysicalSortExpr;
640    use datafusion::physical_plan::joins::{HashJoinExec, JoinOn, PartitionMode};
641    use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet;
642    use datafusion::physical_plan::{ExecutionPlan, PhysicalExpr};
643    use datafusion::prelude::{col, lit};
644    use datafusion_common::{JoinType, NullEquality};
645    use datafusion_physical_expr::expressions::Column;
646    use datatypes::prelude::ConcreteDataType;
647    use datatypes::schema::{ColumnSchema, SchemaRef};
648    use datatypes::vectors::{Helper, UInt32Vector, VectorRef};
649    use session::context::{QueryContext, QueryContextBuilder};
650    use store_api::metadata::{ColumnMetadata, RegionMetadataBuilder, RegionMetadataRef};
651    use store_api::region_engine::{
652        PartitionRange, PrepareRequest, QueryScanContext, RegionScanner, ScannerProperties,
653    };
654    use store_api::storage::{RegionId, ScanRequest};
655    use table::table::numbers::{NUMBERS_TABLE_NAME, NumbersTable};
656    use table::table::scan::RegionScanExec;
657
658    use super::*;
659    use crate::options::QueryOptions;
660    use crate::parser::QueryLanguageParser;
661    use crate::part_sort::PartSortExec;
662    use crate::query_engine::{QueryEngineFactory, QueryEngineRef};
663
664    #[derive(Debug)]
665    struct RecordingScanner {
666        schema: SchemaRef,
667        metadata: RegionMetadataRef,
668        properties: ScannerProperties,
669        update_calls: Arc<AtomicUsize>,
670        last_filter_len: Arc<AtomicUsize>,
671    }
672
673    impl RecordingScanner {
674        fn new(
675            schema: SchemaRef,
676            metadata: RegionMetadataRef,
677            update_calls: Arc<AtomicUsize>,
678            last_filter_len: Arc<AtomicUsize>,
679        ) -> Self {
680            Self {
681                schema,
682                metadata,
683                properties: ScannerProperties::default(),
684                update_calls,
685                last_filter_len,
686            }
687        }
688    }
689
690    impl RegionScanner for RecordingScanner {
691        fn name(&self) -> &str {
692            "RecordingScanner"
693        }
694
695        fn properties(&self) -> &ScannerProperties {
696            &self.properties
697        }
698
699        fn schema(&self) -> SchemaRef {
700            self.schema.clone()
701        }
702
703        fn metadata(&self) -> RegionMetadataRef {
704            self.metadata.clone()
705        }
706
707        fn prepare(&mut self, request: PrepareRequest) -> std::result::Result<(), BoxedError> {
708            self.properties.prepare(request);
709            Ok(())
710        }
711
712        fn scan_partition(
713            &self,
714            _ctx: &QueryScanContext,
715            _metrics_set: &ExecutionPlanMetricsSet,
716            _partition: usize,
717        ) -> std::result::Result<SendableRecordBatchStream, BoxedError> {
718            Ok(Box::pin(EmptyRecordBatchStream::new(self.schema.clone())))
719        }
720
721        fn has_predicate_without_region(&self) -> bool {
722            true
723        }
724
725        fn add_dyn_filter_to_predicate(
726            &mut self,
727            filter_exprs: Vec<Arc<dyn PhysicalExpr>>,
728        ) -> Vec<bool> {
729            self.update_calls.fetch_add(1, Ordering::Relaxed);
730            self.last_filter_len
731                .store(filter_exprs.len(), Ordering::Relaxed);
732            vec![true; filter_exprs.len()]
733        }
734
735        fn set_logical_region(&mut self, logical_region: bool) {
736            self.properties.set_logical_region(logical_region);
737        }
738    }
739
740    impl DisplayAs for RecordingScanner {
741        fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
742            write!(f, "RecordingScanner")
743        }
744    }
745
746    async fn create_test_engine() -> QueryEngineRef {
747        let catalog_manager = catalog::memory::new_memory_catalog_manager().unwrap();
748        let req = RegisterTableRequest {
749            catalog: DEFAULT_CATALOG_NAME.to_string(),
750            schema: DEFAULT_SCHEMA_NAME.to_string(),
751            table_name: NUMBERS_TABLE_NAME.to_string(),
752            table_id: NUMBERS_TABLE_ID,
753            table: NumbersTable::table(NUMBERS_TABLE_ID),
754        };
755        catalog_manager.register_table_sync(req).unwrap();
756
757        QueryEngineFactory::new(
758            catalog_manager,
759            None,
760            None,
761            None,
762            None,
763            false,
764            QueryOptions::default(),
765        )
766        .query_engine()
767    }
768
    /// Checks the unoptimized logical plan produced for a simple aggregate query.
    #[tokio::test]
    async fn test_sql_to_plan() {
        let engine = create_test_engine().await;
        let sql = "select sum(number) from numbers limit 20";

        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
        let plan = engine
            .planner()
            .plan(&stmt, QueryContext::arc())
            .await
            .unwrap();

        // Exact, indentation-sensitive rendering of the plan tree.
        assert_eq!(
            plan.to_string(),
            r#"Limit: skip=0, fetch=20
  Projection: sum(numbers.number)
    Aggregate: groupBy=[[]], aggr=[[sum(numbers.number)]]
      TableScan: numbers"#
        );
    }
789
790    #[tokio::test]
791    async fn test_execute() {
792        let engine = create_test_engine().await;
793        let sql = "select sum(number) from numbers limit 20";
794
795        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
796        let plan = engine
797            .planner()
798            .plan(&stmt, QueryContext::arc())
799            .await
800            .unwrap();
801
802        let output = engine.execute(plan, QueryContext::arc()).await.unwrap();
803
804        match output.data {
805            OutputData::Stream(recordbatch) => {
806                let numbers = util::collect(recordbatch).await.unwrap();
807                assert_eq!(1, numbers.len());
808                assert_eq!(numbers[0].num_columns(), 1);
809                assert_eq!(1, numbers[0].schema.num_columns());
810                assert_eq!(
811                    "sum(numbers.number)",
812                    numbers[0].schema.column_schemas()[0].name
813                );
814
815                let batch = &numbers[0];
816                assert_eq!(1, batch.num_columns());
817                assert_eq!(batch.column(0).len(), 1);
818
819                let expected = Arc::new(UInt64Array::from_iter_values([4950])) as ArrayRef;
820                assert_eq!(batch.column(0), &expected);
821            }
822            _ => unreachable!(),
823        }
824    }
825
826    #[tokio::test]
827    async fn test_read_table() {
828        let engine = create_test_engine().await;
829
830        let engine = engine
831            .as_any()
832            .downcast_ref::<DatafusionQueryEngine>()
833            .unwrap();
834        let query_ctx = Arc::new(QueryContextBuilder::default().build());
835        let table = engine
836            .find_table(
837                &ResolvedTableReference {
838                    catalog: "greptime".into(),
839                    schema: "public".into(),
840                    table: "numbers".into(),
841                },
842                &query_ctx,
843            )
844            .await
845            .unwrap();
846
847        let df = engine.read_table(table).unwrap();
848        let df = df
849            .select_columns(&["number"])
850            .unwrap()
851            .filter(col("number").lt(lit(10)))
852            .unwrap();
853        let batches = df.collect().await.unwrap();
854        assert_eq!(1, batches.len());
855        let batch = &batches[0];
856
857        assert_eq!(1, batch.num_columns());
858        assert_eq!(batch.column(0).len(), 10);
859
860        assert_eq!(
861            Helper::try_into_vector(batch.column(0)).unwrap(),
862            Arc::new(UInt32Vector::from_slice([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])) as VectorRef
863        );
864    }
865
866    #[tokio::test]
867    async fn test_describe() {
868        let engine = create_test_engine().await;
869        let sql = "select sum(number) from numbers limit 20";
870
871        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
872
873        let plan = engine
874            .planner()
875            .plan(&stmt, QueryContext::arc())
876            .await
877            .unwrap();
878
879        let DescribeResult {
880            schema,
881            logical_plan,
882        } = engine.describe(plan, QueryContext::arc()).await.unwrap();
883
884        assert_eq!(
885            schema.column_schemas()[0],
886            ColumnSchema::new(
887                "sum(numbers.number)",
888                ConcreteDataType::uint64_datatype(),
889                true
890            )
891        );
892        assert_eq!(
893            "Limit: skip=0, fetch=20\n  Projection: sum(numbers.number)\n    Aggregate: groupBy=[[]], aggr=[[sum(numbers.number)]]\n      TableScan: numbers",
894            format!("{}", logical_plan.display_indent())
895        );
896    }
897
898    #[tokio::test]
899    async fn test_topk_dynamic_filter_pushdown_reaches_region_scan() {
900        let engine = create_test_engine().await;
901        let engine = engine
902            .as_any()
903            .downcast_ref::<DatafusionQueryEngine>()
904            .unwrap();
905        let engine_ctx = engine.engine_context(QueryContext::arc());
906        let state = engine_ctx.state();
907
908        let schema = Arc::new(datatypes::schema::Schema::new(vec![ColumnSchema::new(
909            "ts",
910            ConcreteDataType::timestamp_millisecond_datatype(),
911            false,
912        )]));
913
914        let mut metadata_builder = RegionMetadataBuilder::new(RegionId::new(1024, 1));
915        metadata_builder
916            .push_column_metadata(ColumnMetadata {
917                column_schema: ColumnSchema::new(
918                    "ts",
919                    ConcreteDataType::timestamp_millisecond_datatype(),
920                    false,
921                )
922                .with_time_index(true),
923                semantic_type: SemanticType::Timestamp,
924                column_id: 1,
925            })
926            .primary_key(vec![]);
927        let metadata = Arc::new(metadata_builder.build().unwrap());
928
929        let update_calls = Arc::new(AtomicUsize::new(0));
930        let last_filter_len = Arc::new(AtomicUsize::new(0));
931        let scanner = Box::new(RecordingScanner::new(
932            schema,
933            metadata,
934            update_calls.clone(),
935            last_filter_len.clone(),
936        ));
937        let scan = Arc::new(RegionScanExec::new(scanner, ScanRequest::default(), None).unwrap());
938
939        let sort_expr = PhysicalSortExpr {
940            expr: Arc::new(Column::new("ts", 0)),
941            options: SortOptions {
942                descending: true,
943                ..Default::default()
944            },
945        };
946        let partition_ranges: Vec<Vec<PartitionRange>> = vec![vec![]];
947        let mut plan: Arc<dyn ExecutionPlan> =
948            Arc::new(PartSortExec::try_new(sort_expr, Some(3), partition_ranges, scan).unwrap());
949
950        for optimizer in state.physical_optimizers() {
951            plan = optimizer.optimize(plan, state.config_options()).unwrap();
952        }
953
954        assert!(update_calls.load(Ordering::Relaxed) > 0);
955        assert!(last_filter_len.load(Ordering::Relaxed) > 0);
956    }
957
958    #[tokio::test]
959    async fn test_join_dynamic_filter_pushdown_reaches_region_scan() {
960        let engine = create_test_engine().await;
961        let engine = engine
962            .as_any()
963            .downcast_ref::<DatafusionQueryEngine>()
964            .unwrap();
965        let engine_ctx = engine.engine_context(QueryContext::arc());
966        let state = engine_ctx.state();
967
968        assert!(
969            state
970                .config_options()
971                .optimizer
972                .enable_join_dynamic_filter_pushdown
973        );
974
975        let schema = Arc::new(datatypes::schema::Schema::new(vec![ColumnSchema::new(
976            "ts",
977            ConcreteDataType::timestamp_millisecond_datatype(),
978            false,
979        )]));
980
981        let mut left_metadata_builder = RegionMetadataBuilder::new(RegionId::new(2048, 1));
982        left_metadata_builder
983            .push_column_metadata(ColumnMetadata {
984                column_schema: ColumnSchema::new(
985                    "ts",
986                    ConcreteDataType::timestamp_millisecond_datatype(),
987                    false,
988                )
989                .with_time_index(true),
990                semantic_type: SemanticType::Timestamp,
991                column_id: 1,
992            })
993            .primary_key(vec![]);
994        let left_metadata = Arc::new(left_metadata_builder.build().unwrap());
995
996        let mut right_metadata_builder = RegionMetadataBuilder::new(RegionId::new(2048, 2));
997        right_metadata_builder
998            .push_column_metadata(ColumnMetadata {
999                column_schema: ColumnSchema::new(
1000                    "ts",
1001                    ConcreteDataType::timestamp_millisecond_datatype(),
1002                    false,
1003                )
1004                .with_time_index(true),
1005                semantic_type: SemanticType::Timestamp,
1006                column_id: 1,
1007            })
1008            .primary_key(vec![]);
1009        let right_metadata = Arc::new(right_metadata_builder.build().unwrap());
1010
1011        let left_update_calls = Arc::new(AtomicUsize::new(0));
1012        let left_last_filter_len = Arc::new(AtomicUsize::new(0));
1013        let right_update_calls = Arc::new(AtomicUsize::new(0));
1014        let right_last_filter_len = Arc::new(AtomicUsize::new(0));
1015
1016        let left_scan = Arc::new(
1017            RegionScanExec::new(
1018                Box::new(RecordingScanner::new(
1019                    schema.clone(),
1020                    left_metadata,
1021                    left_update_calls.clone(),
1022                    left_last_filter_len.clone(),
1023                )),
1024                ScanRequest::default(),
1025                None,
1026            )
1027            .unwrap(),
1028        );
1029        let right_scan = Arc::new(
1030            RegionScanExec::new(
1031                Box::new(RecordingScanner::new(
1032                    schema,
1033                    right_metadata,
1034                    right_update_calls.clone(),
1035                    right_last_filter_len.clone(),
1036                )),
1037                ScanRequest::default(),
1038                None,
1039            )
1040            .unwrap(),
1041        );
1042
1043        let on: JoinOn = vec![(
1044            Arc::new(Column::new("ts", 0)) as Arc<dyn PhysicalExpr>,
1045            Arc::new(Column::new("ts", 0)) as Arc<dyn PhysicalExpr>,
1046        )];
1047
1048        let mut plan: Arc<dyn ExecutionPlan> = Arc::new(
1049            HashJoinExec::try_new(
1050                left_scan,
1051                right_scan,
1052                on,
1053                None,
1054                &JoinType::Inner,
1055                None,
1056                PartitionMode::CollectLeft,
1057                NullEquality::NullEqualsNull,
1058                false,
1059            )
1060            .unwrap(),
1061        );
1062
1063        for optimizer in state.physical_optimizers() {
1064            plan = optimizer.optimize(plan, state.config_options()).unwrap();
1065        }
1066
1067        assert!(left_update_calls.load(Ordering::Relaxed) > 0);
1068        assert_eq!(0, left_last_filter_len.load(Ordering::Relaxed));
1069        assert!(right_update_calls.load(Ordering::Relaxed) > 0);
1070        assert!(right_last_filter_len.load(Ordering::Relaxed) > 0);
1071    }
1072}