Skip to main content

query/
datafusion.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Planner, QueryEngine implementations based on DataFusion.
16
17mod error;
18mod json_expr_planner;
19mod planner;
20
21use std::any::Any;
22use std::collections::HashMap;
23use std::sync::Arc;
24
25use async_trait::async_trait;
26use common_base::Plugins;
27use common_catalog::consts::is_readonly_schema;
28use common_error::ext::BoxedError;
29use common_function::function::FunctionContext;
30use common_function::function_factory::ScalarFunctionFactory;
31use common_query::{Output, OutputData, OutputMeta};
32use common_recordbatch::adapter::RecordBatchStreamAdapter;
33use common_recordbatch::{EmptyRecordBatchStream, SendableRecordBatchStream};
34use common_telemetry::tracing;
35use datafusion::catalog::TableFunction;
36use datafusion::dataframe::DataFrame;
37use datafusion::physical_plan::ExecutionPlan;
38use datafusion::physical_plan::analyze::AnalyzeExec;
39use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
40use datafusion_common::ResolvedTableReference;
41use datafusion_expr::{
42    AggregateUDF, DmlStatement, LogicalPlan as DfLogicalPlan, LogicalPlan, WindowUDF, WriteOp,
43};
44use datatypes::prelude::VectorRef;
45use datatypes::schema::Schema;
46use futures_util::StreamExt;
47use session::context::QueryContextRef;
48use snafu::{OptionExt, ResultExt, ensure};
49use sqlparser::ast::AnalyzeFormat;
50use table::TableRef;
51use table::requests::{DeleteRequest, InsertRequest};
52use tracing::Span;
53
54use crate::analyze::DistAnalyzeExec;
55pub use crate::datafusion::planner::DfContextProviderAdapter;
56use crate::dist_plan::{
57    DistPlannerOptions, MergeScanLogicalPlan, RemoteDynFilterReceiverInjectorRef,
58};
59use crate::error::{
60    CatalogSnafu, CreateRecordBatchSnafu, MissingTableMutationHandlerSnafu,
61    MissingTimestampColumnSnafu, QueryExecutionSnafu, Result, TableMutationSnafu,
62    TableNotFoundSnafu, TableReadOnlySnafu, UnsupportedExprSnafu,
63};
64use crate::executor::QueryExecutor;
65use crate::metrics::{
66    OnDone, QUERY_STAGE_ELAPSED, maybe_attach_region_watermark_metrics,
67    should_collect_region_watermark_from_query_ctx,
68};
69use crate::physical_wrapper::PhysicalPlanWrapperRef;
70use crate::planner::{DfLogicalPlanner, LogicalPlanner};
71use crate::query_engine::{DescribeResult, QueryEngineContext, QueryEngineState};
72use crate::{QueryEngine, metrics};
73
/// Query-context extension key overriding the parallelism of a single query.
///
/// When present and parsable as a positive integer, the value becomes DataFusion's
/// `target_partitions` for that query; any other value leaves the default untouched.
/// Hints sent in the `x-greptime-hints` header arrive in the query context under this key.
pub const QUERY_PARALLELISM_HINT: &str = "query_parallelism";

/// Query-context extension key that, when set to `true` (case-insensitive), allows the
/// distributed planner to fall back to the original plan if push down fails.
pub const QUERY_FALLBACK_HINT: &str = "query_fallback";
80
81pub struct DatafusionQueryEngine {
82    state: Arc<QueryEngineState>,
83    plugins: Plugins,
84}
85
86impl DatafusionQueryEngine {
87    pub fn new(state: Arc<QueryEngineState>, plugins: Plugins) -> Self {
88        Self { state, plugins }
89    }
90
91    #[tracing::instrument(skip_all)]
92    async fn exec_query_plan(
93        &self,
94        plan: LogicalPlan,
95        query_ctx: QueryContextRef,
96    ) -> Result<Output> {
97        let mut ctx = self.engine_context(query_ctx.clone());
98        let plan = if let Some(receiver_injector) =
99            self.plugins.get::<RemoteDynFilterReceiverInjectorRef>()
100        {
101            receiver_injector.maybe_inject(plan, query_ctx.clone())
102        } else {
103            plan
104        };
105
106        // `create_physical_plan` will optimize logical plan internally
107        let physical_plan = self.create_physical_plan(&mut ctx, &plan).await?;
108        let physical_plan = self.optimize_physical_plan(&mut ctx, physical_plan)?;
109        let physical_plan = if let Some(wrapper) = self.plugins.get::<PhysicalPlanWrapperRef>() {
110            wrapper.wrap(physical_plan, query_ctx)
111        } else {
112            physical_plan
113        };
114
115        let stream = self.execute_stream(&ctx, &physical_plan)?;
116
117        Ok(Output::new(
118            OutputData::Stream(stream),
119            OutputMeta::new_with_plan(physical_plan),
120        ))
121    }
122
123    #[tracing::instrument(skip_all)]
124    async fn exec_dml_statement(
125        &self,
126        dml: DmlStatement,
127        query_ctx: QueryContextRef,
128    ) -> Result<Output> {
129        ensure!(
130            matches!(dml.op, WriteOp::Insert(_) | WriteOp::Delete),
131            UnsupportedExprSnafu {
132                name: format!("DML op {}", dml.op),
133            }
134        );
135
136        let _timer = QUERY_STAGE_ELAPSED
137            .with_label_values(&[dml.op.name()])
138            .start_timer();
139
140        let default_catalog = &query_ctx.current_catalog().to_owned();
141        let default_schema = &query_ctx.current_schema();
142        let table_name = dml.table_name.resolve(default_catalog, default_schema);
143        let table = self.find_table(&table_name, &query_ctx).await?;
144
145        let Output { data, meta } = self
146            .exec_query_plan((*dml.input).clone(), query_ctx.clone())
147            .await?;
148        let mut stream = match data {
149            OutputData::RecordBatches(batches) => batches.as_stream(),
150            OutputData::Stream(stream) => stream,
151            _ => unreachable!(),
152        };
153
154        let mut affected_rows = 0;
155        let mut insert_cost = 0;
156
157        while let Some(batch) = stream.next().await {
158            let batch = batch.context(CreateRecordBatchSnafu)?;
159            let column_vectors = batch
160                .column_vectors(&table_name.to_string(), table.schema())
161                .map_err(BoxedError::new)
162                .context(QueryExecutionSnafu)?;
163
164            match dml.op {
165                WriteOp::Insert(_) => {
166                    // We ignore the insert op.
167                    let output = self
168                        .insert(&table_name, column_vectors, query_ctx.clone())
169                        .await?;
170                    let (rows, cost) = output.extract_rows_and_cost();
171                    affected_rows += rows;
172                    insert_cost += cost;
173                }
174                WriteOp::Delete => {
175                    affected_rows += self
176                        .delete(&table_name, &table, column_vectors, query_ctx.clone())
177                        .await?;
178                }
179                _ => unreachable!("guarded by the 'ensure!' at the beginning"),
180            }
181        }
182        Ok(Output::new(
183            OutputData::AffectedRows(affected_rows),
184            OutputMeta::new(meta.plan, insert_cost),
185        ))
186    }
187
188    #[tracing::instrument(skip_all)]
189    async fn delete(
190        &self,
191        table_name: &ResolvedTableReference,
192        table: &TableRef,
193        column_vectors: HashMap<String, VectorRef>,
194        query_ctx: QueryContextRef,
195    ) -> Result<usize> {
196        let catalog_name = table_name.catalog.to_string();
197        let schema_name = table_name.schema.to_string();
198        let table_name = table_name.table.to_string();
199        let table_schema = table.schema();
200
201        ensure!(
202            !is_readonly_schema(&schema_name),
203            TableReadOnlySnafu { table: table_name }
204        );
205
206        let ts_column = table_schema
207            .timestamp_column()
208            .map(|x| &x.name)
209            .with_context(|| MissingTimestampColumnSnafu {
210                table_name: table_name.clone(),
211            })?;
212
213        let table_info = table.table_info();
214        let rowkey_columns = table_info
215            .meta
216            .row_key_column_names()
217            .collect::<Vec<&String>>();
218        let column_vectors = column_vectors
219            .into_iter()
220            .filter(|x| &x.0 == ts_column || rowkey_columns.contains(&&x.0))
221            .collect::<HashMap<_, _>>();
222
223        let request = DeleteRequest {
224            catalog_name,
225            schema_name,
226            table_name,
227            key_column_values: column_vectors,
228        };
229
230        self.state
231            .table_mutation_handler()
232            .context(MissingTableMutationHandlerSnafu)?
233            .delete(request, query_ctx)
234            .await
235            .context(TableMutationSnafu)
236    }
237
238    #[tracing::instrument(skip_all)]
239    async fn insert(
240        &self,
241        table_name: &ResolvedTableReference,
242        column_vectors: HashMap<String, VectorRef>,
243        query_ctx: QueryContextRef,
244    ) -> Result<Output> {
245        let catalog_name = table_name.catalog.to_string();
246        let schema_name = table_name.schema.to_string();
247        let table_name = table_name.table.to_string();
248
249        ensure!(
250            !is_readonly_schema(&schema_name),
251            TableReadOnlySnafu { table: table_name }
252        );
253
254        let request = InsertRequest {
255            catalog_name,
256            schema_name,
257            table_name,
258            columns_values: column_vectors,
259        };
260
261        self.state
262            .table_mutation_handler()
263            .context(MissingTableMutationHandlerSnafu)?
264            .insert(request, query_ctx)
265            .await
266            .context(TableMutationSnafu)
267    }
268
269    async fn find_table(
270        &self,
271        table_name: &ResolvedTableReference,
272        query_context: &QueryContextRef,
273    ) -> Result<TableRef> {
274        let catalog_name = table_name.catalog.as_ref();
275        let schema_name = table_name.schema.as_ref();
276        let table_name = table_name.table.as_ref();
277
278        self.state
279            .catalog_manager()
280            .table(catalog_name, schema_name, table_name, Some(query_context))
281            .await
282            .context(CatalogSnafu)?
283            .with_context(|| TableNotFoundSnafu { table: table_name })
284    }
285
286    #[tracing::instrument(skip_all)]
287    async fn create_physical_plan(
288        &self,
289        ctx: &mut QueryEngineContext,
290        logical_plan: &LogicalPlan,
291    ) -> Result<Arc<dyn ExecutionPlan>> {
292        /// Only print context on panic, to avoid cluttering logs.
293        ///
294        /// TODO(discord9): remove this once we catch the bug
295        #[derive(Debug)]
296        struct PanicLogger<'a> {
297            input_logical_plan: &'a LogicalPlan,
298            after_analyze: Option<LogicalPlan>,
299            after_optimize: Option<LogicalPlan>,
300            phy_plan: Option<Arc<dyn ExecutionPlan>>,
301        }
302        impl Drop for PanicLogger<'_> {
303            fn drop(&mut self) {
304                if std::thread::panicking() {
305                    common_telemetry::error!(
306                        "Panic while creating physical plan, input logical plan: {:?}, after analyze: {:?}, after optimize: {:?}, final physical plan: {:?}",
307                        self.input_logical_plan,
308                        self.after_analyze,
309                        self.after_optimize,
310                        self.phy_plan
311                    );
312                }
313            }
314        }
315
316        let mut logger = PanicLogger {
317            input_logical_plan: logical_plan,
318            after_analyze: None,
319            after_optimize: None,
320            phy_plan: None,
321        };
322
323        let _timer = metrics::CREATE_PHYSICAL_ELAPSED.start_timer();
324        let state = ctx.state();
325
326        common_telemetry::debug!("Create physical plan, input plan: {logical_plan}");
327
328        // special handle EXPLAIN plan
329        if matches!(logical_plan, DfLogicalPlan::Explain(_)) {
330            return state
331                .create_physical_plan(logical_plan)
332                .await
333                .map_err(Into::into);
334        }
335
336        // analyze first
337        let analyzed_plan = state.analyzer().execute_and_check(
338            logical_plan.clone(),
339            state.config_options(),
340            |_, _| {},
341        )?;
342
343        logger.after_analyze = Some(analyzed_plan.clone());
344
345        common_telemetry::debug!("Create physical plan, analyzed plan: {analyzed_plan}");
346
347        // skip optimize for MergeScan
348        let optimized_plan = if let DfLogicalPlan::Extension(ext) = &analyzed_plan
349            && ext.node.name() == MergeScanLogicalPlan::name()
350        {
351            analyzed_plan.clone()
352        } else {
353            state
354                .optimizer()
355                .optimize(analyzed_plan, state, |_, _| {})?
356        };
357
358        common_telemetry::debug!("Create physical plan, optimized plan: {optimized_plan}");
359        logger.after_optimize = Some(optimized_plan.clone());
360
361        let physical_plan = state
362            .query_planner()
363            .create_physical_plan(&optimized_plan, state)
364            .await?;
365
366        logger.phy_plan = Some(physical_plan.clone());
367        drop(logger);
368        Ok(physical_plan)
369    }
370
371    #[tracing::instrument(skip_all)]
372    fn optimize_physical_plan(
373        &self,
374        ctx: &mut QueryEngineContext,
375        plan: Arc<dyn ExecutionPlan>,
376    ) -> Result<Arc<dyn ExecutionPlan>> {
377        let _timer = metrics::OPTIMIZE_PHYSICAL_ELAPSED.start_timer();
378
379        // TODO(ruihang): `self.create_physical_plan()` already optimize the plan, check
380        // if we need to optimize it again here.
381        // let state = ctx.state();
382        // let config = state.config_options();
383
384        // skip optimize AnalyzeExec plan
385        let optimized_plan = if let Some(analyze_plan) = plan.as_any().downcast_ref::<AnalyzeExec>()
386        {
387            let format = if let Some(format) = ctx.query_ctx().explain_format()
388                && format.to_lowercase() == "json"
389            {
390                AnalyzeFormat::JSON
391            } else {
392                AnalyzeFormat::TEXT
393            };
394            // Sets the verbose flag of the query context.
395            // The MergeScanExec plan uses the verbose flag to determine whether to print the plan in verbose mode.
396            ctx.query_ctx().set_explain_verbose(analyze_plan.verbose());
397
398            Arc::new(DistAnalyzeExec::new(
399                analyze_plan.input().clone(),
400                analyze_plan.verbose(),
401                format,
402            ))
403            // let mut new_plan = analyze_plan.input().clone();
404            // for optimizer in state.physical_optimizers() {
405            //     new_plan = optimizer
406            //         .optimize(new_plan, config)
407            //         .context(DataFusionSnafu)?;
408            // }
409            // Arc::new(DistAnalyzeExec::new(new_plan))
410        } else {
411            plan
412            // let mut new_plan = plan;
413            // for optimizer in state.physical_optimizers() {
414            //     new_plan = optimizer
415            //         .optimize(new_plan, config)
416            //         .context(DataFusionSnafu)?;
417            // }
418            // new_plan
419        };
420
421        Ok(optimized_plan)
422    }
423}
424
425#[async_trait]
426impl QueryEngine for DatafusionQueryEngine {
427    fn as_any(&self) -> &dyn Any {
428        self
429    }
430
431    fn planner(&self) -> Arc<dyn LogicalPlanner> {
432        Arc::new(DfLogicalPlanner::new(self.state.clone()))
433    }
434
435    fn name(&self) -> &str {
436        "datafusion"
437    }
438
439    async fn describe(
440        &self,
441        plan: LogicalPlan,
442        _query_ctx: QueryContextRef,
443    ) -> Result<DescribeResult> {
444        Ok(DescribeResult { logical_plan: plan })
445    }
446
447    async fn execute(&self, plan: LogicalPlan, query_ctx: QueryContextRef) -> Result<Output> {
448        match plan {
449            LogicalPlan::Dml(dml) => self.exec_dml_statement(dml, query_ctx).await,
450            _ => self.exec_query_plan(plan, query_ctx).await,
451        }
452    }
453
454    /// Note in SQL queries, aggregate names are looked up using
455    /// lowercase unless the query uses quotes. For example,
456    ///
457    /// `SELECT MY_UDAF(x)...` will look for an aggregate named `"my_udaf"`
458    /// `SELECT "my_UDAF"(x)` will look for an aggregate named `"my_UDAF"`
459    ///
460    /// So it's better to make UDAF name lowercase when creating one.
461    fn register_aggregate_function(&self, func: AggregateUDF) {
462        self.state.register_aggr_function(func);
463    }
464
465    /// Register an scalar function.
466    /// Will override if the function with same name is already registered.
467    fn register_scalar_function(&self, func: ScalarFunctionFactory) {
468        self.state.register_scalar_function(func);
469    }
470
471    fn register_table_function(&self, func: Arc<TableFunction>) {
472        self.state.register_table_function(func);
473    }
474
475    fn register_window_function(&self, func: WindowUDF) {
476        self.state.register_window_function(func);
477    }
478
479    fn read_table(&self, table: TableRef) -> Result<DataFrame> {
480        self.state.read_table(table).map_err(Into::into)
481    }
482
483    fn engine_context(&self, query_ctx: QueryContextRef) -> QueryEngineContext {
484        let mut state = self.state.session_state();
485        state.config_mut().set_extension(query_ctx.clone());
486        state.config_mut().set_extension(self.state.clone());
487        // note that hints in "x-greptime-hints" is automatically parsed
488        // and set to query context's extension, so we can get it from query context.
489        if let Some(parallelism) = query_ctx.extension(QUERY_PARALLELISM_HINT) {
490            if let Ok(n) = parallelism.parse::<u64>() {
491                if n > 0 {
492                    let new_cfg = state.config().clone().with_target_partitions(n as usize);
493                    *state.config_mut() = new_cfg;
494                }
495            } else {
496                common_telemetry::warn!(
497                    "Failed to parse query_parallelism: {}, using default value",
498                    parallelism
499                );
500            }
501        }
502
503        // configure execution options
504        state.config_mut().options_mut().execution.time_zone =
505            Some(query_ctx.timezone().to_string());
506
507        // usually it's impossible to have both `set variable` set by sql client and
508        // hint in header by grpc client, so only need to deal with them separately
509        if query_ctx.configuration_parameter().allow_query_fallback() {
510            state
511                .config_mut()
512                .options_mut()
513                .extensions
514                .insert(DistPlannerOptions {
515                    allow_query_fallback: true,
516                });
517        } else if let Some(fallback) = query_ctx.extension(QUERY_FALLBACK_HINT) {
518            // also check the query context for fallback hint
519            // if it is set, we will enable the fallback
520            if fallback.to_lowercase().parse::<bool>().unwrap_or(false) {
521                state
522                    .config_mut()
523                    .options_mut()
524                    .extensions
525                    .insert(DistPlannerOptions {
526                        allow_query_fallback: true,
527                    });
528            }
529        }
530
531        state
532            .config_mut()
533            .options_mut()
534            .extensions
535            .insert(FunctionContext {
536                query_ctx: query_ctx.clone(),
537                state: self.engine_state().function_state(),
538            });
539
540        let config_options = state.config_options().clone();
541        let _ = state
542            .execution_props_mut()
543            .config_options
544            .insert(config_options);
545
546        QueryEngineContext::new(state, query_ctx)
547    }
548
549    fn engine_state(&self) -> &QueryEngineState {
550        &self.state
551    }
552}
553
554impl QueryExecutor for DatafusionQueryEngine {
555    #[tracing::instrument(skip_all)]
556    fn execute_stream(
557        &self,
558        ctx: &QueryEngineContext,
559        plan: &Arc<dyn ExecutionPlan>,
560    ) -> Result<SendableRecordBatchStream> {
561        let query_ctx = ctx.query_ctx();
562        let explain_verbose = query_ctx.explain_verbose();
563        let should_collect_region_watermark =
564            should_collect_region_watermark_from_query_ctx(&query_ctx)?;
565        let output_partitions = plan.properties().output_partitioning().partition_count();
566        if explain_verbose {
567            common_telemetry::info!("Executing query plan, output_partitions: {output_partitions}");
568        }
569
570        let exec_timer = metrics::EXEC_PLAN_ELAPSED.start_timer();
571        let task_ctx = ctx.build_task_ctx();
572        let span = Span::current();
573
574        match plan.properties().output_partitioning().partition_count() {
575            0 => {
576                let schema = Arc::new(
577                    Schema::try_from(plan.schema())
578                        .map_err(BoxedError::new)
579                        .context(QueryExecutionSnafu)?,
580                );
581                Ok(Box::pin(EmptyRecordBatchStream::new(schema)))
582            }
583            1 => {
584                let df_stream = plan.execute(0, task_ctx)?;
585                let mut stream = RecordBatchStreamAdapter::try_new_with_span(df_stream, span)
586                    .context(error::ConvertDfRecordBatchStreamSnafu)
587                    .map_err(BoxedError::new)
588                    .context(QueryExecutionSnafu)?;
589                stream.set_metrics2(plan.clone());
590                stream.set_explain_verbose(explain_verbose);
591                let stream = OnDone::new(Box::pin(stream), move || {
592                    let exec_cost = exec_timer.stop_and_record();
593                    if explain_verbose {
594                        common_telemetry::info!(
595                            "DatafusionQueryEngine execute 1 stream, cost: {:?}s",
596                            exec_cost,
597                        );
598                    }
599                });
600                Ok(maybe_attach_region_watermark_metrics(
601                    Box::pin(stream),
602                    plan.clone(),
603                    should_collect_region_watermark,
604                ))
605            }
606            _ => {
607                // merge into a single partition
608                let merged_plan = CoalescePartitionsExec::new(plan.clone());
609                // CoalescePartitionsExec must produce a single partition
610                assert_eq!(
611                    1,
612                    merged_plan
613                        .properties()
614                        .output_partitioning()
615                        .partition_count()
616                );
617                let df_stream = merged_plan.execute(0, task_ctx)?;
618                let mut stream = RecordBatchStreamAdapter::try_new_with_span(df_stream, span)
619                    .context(error::ConvertDfRecordBatchStreamSnafu)
620                    .map_err(BoxedError::new)
621                    .context(QueryExecutionSnafu)?;
622                stream.set_metrics2(plan.clone());
623                stream.set_explain_verbose(explain_verbose);
624                let stream = OnDone::new(Box::pin(stream), move || {
625                    let exec_cost = exec_timer.stop_and_record();
626                    if explain_verbose {
627                        common_telemetry::info!(
628                            "DatafusionQueryEngine execute {output_partitions} stream, cost: {:?}s",
629                            exec_cost
630                        );
631                    }
632                });
633                Ok(maybe_attach_region_watermark_metrics(
634                    Box::pin(stream),
635                    plan.clone(),
636                    should_collect_region_watermark,
637                ))
638            }
639        }
640    }
641}
642
643#[cfg(test)]
644mod tests {
645    use std::fmt;
646    use std::sync::Arc;
647    use std::sync::atomic::{AtomicUsize, Ordering};
648
649    use api::v1::SemanticType;
650    use arrow::array::{ArrayRef, UInt64Array};
651    use arrow_schema::SortOptions;
652    use catalog::RegisterTableRequest;
653    use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, NUMBERS_TABLE_ID};
654    use common_error::ext::BoxedError;
655    use common_recordbatch::{EmptyRecordBatchStream, SendableRecordBatchStream, util};
656    use datafusion::physical_plan::display::{DisplayAs, DisplayFormatType};
657    use datafusion::physical_plan::expressions::PhysicalSortExpr;
658    use datafusion::physical_plan::joins::{HashJoinExec, JoinOn, PartitionMode};
659    use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet;
660    use datafusion::physical_plan::{ExecutionPlan, PhysicalExpr};
661    use datafusion::prelude::{col, lit};
662    use datafusion_common::{JoinType, NullEquality};
663    use datafusion_physical_expr::expressions::Column;
664    use datatypes::prelude::ConcreteDataType;
665    use datatypes::schema::{ColumnSchema, SchemaRef};
666    use datatypes::vectors::{Helper, UInt32Vector, VectorRef};
667    use session::context::{QueryContext, QueryContextBuilder};
668    use store_api::metadata::{ColumnMetadata, RegionMetadataBuilder, RegionMetadataRef};
669    use store_api::region_engine::{
670        PartitionRange, PrepareRequest, QueryScanContext, RegionScanner, ScannerProperties,
671    };
672    use store_api::storage::{RegionId, ScanRequest};
673    use table::table::numbers::{NUMBERS_TABLE_NAME, NumbersTable};
674    use table::table::scan::RegionScanExec;
675
676    use super::*;
677    use crate::options::QueryOptions;
678    use crate::parser::QueryLanguageParser;
679    use crate::part_sort::PartSortExec;
680    use crate::query_engine::{QueryEngineFactory, QueryEngineRef};
681
682    #[derive(Debug)]
683    struct RecordingScanner {
684        schema: SchemaRef,
685        metadata: RegionMetadataRef,
686        properties: ScannerProperties,
687        update_calls: Arc<AtomicUsize>,
688        last_filter_len: Arc<AtomicUsize>,
689    }
690
691    impl RecordingScanner {
692        fn new(
693            schema: SchemaRef,
694            metadata: RegionMetadataRef,
695            update_calls: Arc<AtomicUsize>,
696            last_filter_len: Arc<AtomicUsize>,
697        ) -> Self {
698            Self {
699                schema,
700                metadata,
701                properties: ScannerProperties::default(),
702                update_calls,
703                last_filter_len,
704            }
705        }
706    }
707
708    impl RegionScanner for RecordingScanner {
709        fn name(&self) -> &str {
710            "RecordingScanner"
711        }
712
713        fn properties(&self) -> &ScannerProperties {
714            &self.properties
715        }
716
717        fn schema(&self) -> SchemaRef {
718            self.schema.clone()
719        }
720
721        fn metadata(&self) -> RegionMetadataRef {
722            self.metadata.clone()
723        }
724
725        fn prepare(&mut self, request: PrepareRequest) -> std::result::Result<(), BoxedError> {
726            self.properties.prepare(request);
727            Ok(())
728        }
729
730        fn scan_partition(
731            &self,
732            _ctx: &QueryScanContext,
733            _metrics_set: &ExecutionPlanMetricsSet,
734            _partition: usize,
735        ) -> std::result::Result<SendableRecordBatchStream, BoxedError> {
736            Ok(Box::pin(EmptyRecordBatchStream::new(self.schema.clone())))
737        }
738
739        fn has_predicate_without_region(&self) -> bool {
740            true
741        }
742
743        fn add_dyn_filter_to_predicate(
744            &mut self,
745            filter_exprs: Vec<Arc<dyn PhysicalExpr>>,
746        ) -> Vec<bool> {
747            self.update_calls.fetch_add(1, Ordering::Relaxed);
748            self.last_filter_len
749                .store(filter_exprs.len(), Ordering::Relaxed);
750            vec![true; filter_exprs.len()]
751        }
752
753        fn set_logical_region(&mut self, logical_region: bool) {
754            self.properties.set_logical_region(logical_region);
755        }
756    }
757
758    impl DisplayAs for RecordingScanner {
759        fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
760            write!(f, "RecordingScanner")
761        }
762    }
763
764    async fn create_test_engine() -> QueryEngineRef {
765        let catalog_manager = catalog::memory::new_memory_catalog_manager().unwrap();
766        let req = RegisterTableRequest {
767            catalog: DEFAULT_CATALOG_NAME.to_string(),
768            schema: DEFAULT_SCHEMA_NAME.to_string(),
769            table_name: NUMBERS_TABLE_NAME.to_string(),
770            table_id: NUMBERS_TABLE_ID,
771            table: NumbersTable::table(NUMBERS_TABLE_ID),
772        };
773        catalog_manager.register_table_sync(req).unwrap();
774
775        QueryEngineFactory::new(
776            catalog_manager,
777            None,
778            None,
779            None,
780            None,
781            false,
782            QueryOptions::default(),
783        )
784        .query_engine()
785    }
786
787    #[tokio::test]
788    async fn test_sql_to_plan() {
789        let engine = create_test_engine().await;
790        let sql = "select sum(number) from numbers limit 20";
791
792        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
793        let plan = engine
794            .planner()
795            .plan(&stmt, QueryContext::arc())
796            .await
797            .unwrap();
798
799        assert_eq!(
800            plan.to_string(),
801            r#"Limit: skip=0, fetch=20
802  Projection: sum(numbers.number)
803    Aggregate: groupBy=[[]], aggr=[[sum(numbers.number)]]
804      TableScan: numbers"#
805        );
806    }
807
808    #[tokio::test]
809    async fn test_execute() {
810        let engine = create_test_engine().await;
811        let sql = "select sum(number) from numbers limit 20";
812
813        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
814        let plan = engine
815            .planner()
816            .plan(&stmt, QueryContext::arc())
817            .await
818            .unwrap();
819
820        let output = engine.execute(plan, QueryContext::arc()).await.unwrap();
821
822        match output.data {
823            OutputData::Stream(recordbatch) => {
824                let numbers = util::collect(recordbatch).await.unwrap();
825                assert_eq!(1, numbers.len());
826                assert_eq!(numbers[0].num_columns(), 1);
827                assert_eq!(1, numbers[0].schema.num_columns());
828                assert_eq!(
829                    "sum(numbers.number)",
830                    numbers[0].schema.column_schemas()[0].name
831                );
832
833                let batch = &numbers[0];
834                assert_eq!(1, batch.num_columns());
835                assert_eq!(batch.column(0).len(), 1);
836
837                let expected = Arc::new(UInt64Array::from_iter_values([4950])) as ArrayRef;
838                assert_eq!(batch.column(0), &expected);
839            }
840            _ => unreachable!(),
841        }
842    }
843
844    #[tokio::test]
845    async fn test_read_table() {
846        let engine = create_test_engine().await;
847
848        let engine = engine
849            .as_any()
850            .downcast_ref::<DatafusionQueryEngine>()
851            .unwrap();
852        let query_ctx = Arc::new(QueryContextBuilder::default().build());
853        let table = engine
854            .find_table(
855                &ResolvedTableReference {
856                    catalog: "greptime".into(),
857                    schema: "public".into(),
858                    table: "numbers".into(),
859                },
860                &query_ctx,
861            )
862            .await
863            .unwrap();
864
865        let df = engine.read_table(table).unwrap();
866        let df = df
867            .select_columns(&["number"])
868            .unwrap()
869            .filter(col("number").lt(lit(10)))
870            .unwrap();
871        let batches = df.collect().await.unwrap();
872        assert_eq!(1, batches.len());
873        let batch = &batches[0];
874
875        assert_eq!(1, batch.num_columns());
876        assert_eq!(batch.column(0).len(), 10);
877
878        assert_eq!(
879            Helper::try_into_vector(batch.column(0)).unwrap(),
880            Arc::new(UInt32Vector::from_slice([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])) as VectorRef
881        );
882    }
883
884    #[tokio::test]
885    async fn test_describe() {
886        let engine = create_test_engine().await;
887        let sql = "select sum(number) from numbers limit 20";
888
889        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
890
891        let plan = engine
892            .planner()
893            .plan(&stmt, QueryContext::arc())
894            .await
895            .unwrap();
896
897        let DescribeResult { logical_plan } =
898            engine.describe(plan, QueryContext::arc()).await.unwrap();
899
900        let schema: Schema = logical_plan.schema().clone().try_into().unwrap();
901
902        assert_eq!(
903            schema.column_schemas()[0],
904            ColumnSchema::new(
905                "sum(numbers.number)",
906                ConcreteDataType::uint64_datatype(),
907                true
908            )
909        );
910        assert_eq!(
911            "Limit: skip=0, fetch=20\n  Projection: sum(numbers.number)\n    Aggregate: groupBy=[[]], aggr=[[sum(numbers.number)]]\n      TableScan: numbers",
912            format!("{}", logical_plan.display_indent())
913        );
914    }
915
916    #[tokio::test]
917    async fn test_topk_dynamic_filter_pushdown_reaches_region_scan() {
918        let engine = create_test_engine().await;
919        let engine = engine
920            .as_any()
921            .downcast_ref::<DatafusionQueryEngine>()
922            .unwrap();
923        let engine_ctx = engine.engine_context(QueryContext::arc());
924        let state = engine_ctx.state();
925
926        let schema = Arc::new(datatypes::schema::Schema::new(vec![ColumnSchema::new(
927            "ts",
928            ConcreteDataType::timestamp_millisecond_datatype(),
929            false,
930        )]));
931
932        let mut metadata_builder = RegionMetadataBuilder::new(RegionId::new(1024, 1));
933        metadata_builder
934            .push_column_metadata(ColumnMetadata {
935                column_schema: ColumnSchema::new(
936                    "ts",
937                    ConcreteDataType::timestamp_millisecond_datatype(),
938                    false,
939                )
940                .with_time_index(true),
941                semantic_type: SemanticType::Timestamp,
942                column_id: 1,
943            })
944            .primary_key(vec![]);
945        let metadata = Arc::new(metadata_builder.build().unwrap());
946
947        let update_calls = Arc::new(AtomicUsize::new(0));
948        let last_filter_len = Arc::new(AtomicUsize::new(0));
949        let scanner = Box::new(RecordingScanner::new(
950            schema,
951            metadata,
952            update_calls.clone(),
953            last_filter_len.clone(),
954        ));
955        let scan = Arc::new(RegionScanExec::new(scanner, ScanRequest::default(), None).unwrap());
956
957        let sort_expr = PhysicalSortExpr {
958            expr: Arc::new(Column::new("ts", 0)),
959            options: SortOptions {
960                descending: true,
961                ..Default::default()
962            },
963        };
964        let partition_ranges: Vec<Vec<PartitionRange>> = vec![vec![]];
965        let mut plan: Arc<dyn ExecutionPlan> =
966            Arc::new(PartSortExec::try_new(sort_expr, Some(3), partition_ranges, scan).unwrap());
967
968        for optimizer in state.physical_optimizers() {
969            plan = optimizer.optimize(plan, state.config_options()).unwrap();
970        }
971
972        assert!(update_calls.load(Ordering::Relaxed) > 0);
973        assert!(last_filter_len.load(Ordering::Relaxed) > 0);
974    }
975
976    #[tokio::test]
977    async fn test_join_dynamic_filter_pushdown_reaches_region_scan() {
978        let engine = create_test_engine().await;
979        let engine = engine
980            .as_any()
981            .downcast_ref::<DatafusionQueryEngine>()
982            .unwrap();
983        let engine_ctx = engine.engine_context(QueryContext::arc());
984        let state = engine_ctx.state();
985
986        assert!(
987            state
988                .config_options()
989                .optimizer
990                .enable_join_dynamic_filter_pushdown
991        );
992
993        let schema = Arc::new(datatypes::schema::Schema::new(vec![ColumnSchema::new(
994            "ts",
995            ConcreteDataType::timestamp_millisecond_datatype(),
996            false,
997        )]));
998
999        let mut left_metadata_builder = RegionMetadataBuilder::new(RegionId::new(2048, 1));
1000        left_metadata_builder
1001            .push_column_metadata(ColumnMetadata {
1002                column_schema: ColumnSchema::new(
1003                    "ts",
1004                    ConcreteDataType::timestamp_millisecond_datatype(),
1005                    false,
1006                )
1007                .with_time_index(true),
1008                semantic_type: SemanticType::Timestamp,
1009                column_id: 1,
1010            })
1011            .primary_key(vec![]);
1012        let left_metadata = Arc::new(left_metadata_builder.build().unwrap());
1013
1014        let mut right_metadata_builder = RegionMetadataBuilder::new(RegionId::new(2048, 2));
1015        right_metadata_builder
1016            .push_column_metadata(ColumnMetadata {
1017                column_schema: ColumnSchema::new(
1018                    "ts",
1019                    ConcreteDataType::timestamp_millisecond_datatype(),
1020                    false,
1021                )
1022                .with_time_index(true),
1023                semantic_type: SemanticType::Timestamp,
1024                column_id: 1,
1025            })
1026            .primary_key(vec![]);
1027        let right_metadata = Arc::new(right_metadata_builder.build().unwrap());
1028
1029        let left_update_calls = Arc::new(AtomicUsize::new(0));
1030        let left_last_filter_len = Arc::new(AtomicUsize::new(0));
1031        let right_update_calls = Arc::new(AtomicUsize::new(0));
1032        let right_last_filter_len = Arc::new(AtomicUsize::new(0));
1033
1034        let left_scan = Arc::new(
1035            RegionScanExec::new(
1036                Box::new(RecordingScanner::new(
1037                    schema.clone(),
1038                    left_metadata,
1039                    left_update_calls.clone(),
1040                    left_last_filter_len.clone(),
1041                )),
1042                ScanRequest::default(),
1043                None,
1044            )
1045            .unwrap(),
1046        );
1047        let right_scan = Arc::new(
1048            RegionScanExec::new(
1049                Box::new(RecordingScanner::new(
1050                    schema,
1051                    right_metadata,
1052                    right_update_calls.clone(),
1053                    right_last_filter_len.clone(),
1054                )),
1055                ScanRequest::default(),
1056                None,
1057            )
1058            .unwrap(),
1059        );
1060
1061        let on: JoinOn = vec![(
1062            Arc::new(Column::new("ts", 0)) as Arc<dyn PhysicalExpr>,
1063            Arc::new(Column::new("ts", 0)) as Arc<dyn PhysicalExpr>,
1064        )];
1065
1066        let mut plan: Arc<dyn ExecutionPlan> = Arc::new(
1067            HashJoinExec::try_new(
1068                left_scan,
1069                right_scan,
1070                on,
1071                None,
1072                &JoinType::Inner,
1073                None,
1074                PartitionMode::CollectLeft,
1075                NullEquality::NullEqualsNull,
1076                false,
1077            )
1078            .unwrap(),
1079        );
1080
1081        for optimizer in state.physical_optimizers() {
1082            plan = optimizer.optimize(plan, state.config_options()).unwrap();
1083        }
1084
1085        assert!(left_update_calls.load(Ordering::Relaxed) > 0);
1086        assert_eq!(0, left_last_filter_len.load(Ordering::Relaxed));
1087        assert!(right_update_calls.load(Ordering::Relaxed) > 0);
1088        assert!(right_last_filter_len.load(Ordering::Relaxed) > 0);
1089    }
1090}