1mod error;
18mod planner;
19
20use std::any::Any;
21use std::collections::HashMap;
22use std::sync::Arc;
23
24use async_trait::async_trait;
25use common_base::Plugins;
26use common_catalog::consts::is_readonly_schema;
27use common_error::ext::BoxedError;
28use common_function::function::FunctionContext;
29use common_function::function_factory::ScalarFunctionFactory;
30use common_query::{Output, OutputData, OutputMeta};
31use common_recordbatch::adapter::RecordBatchStreamAdapter;
32use common_recordbatch::{EmptyRecordBatchStream, SendableRecordBatchStream};
33use common_telemetry::tracing;
34use datafusion::catalog::TableFunction;
35use datafusion::dataframe::DataFrame;
36use datafusion::physical_plan::ExecutionPlan;
37use datafusion::physical_plan::analyze::AnalyzeExec;
38use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
39use datafusion_common::ResolvedTableReference;
40use datafusion_expr::{
41 AggregateUDF, DmlStatement, LogicalPlan as DfLogicalPlan, LogicalPlan, WindowUDF, WriteOp,
42};
43use datatypes::prelude::VectorRef;
44use datatypes::schema::Schema;
45use futures_util::StreamExt;
46use session::context::QueryContextRef;
47use snafu::{OptionExt, ResultExt, ensure};
48use sqlparser::ast::AnalyzeFormat;
49use table::TableRef;
50use table::requests::{DeleteRequest, InsertRequest};
51use tracing::Span;
52
53use crate::analyze::DistAnalyzeExec;
54pub use crate::datafusion::planner::DfContextProviderAdapter;
55use crate::dist_plan::{DistPlannerOptions, MergeScanLogicalPlan};
56use crate::error::{
57 CatalogSnafu, ConvertSchemaSnafu, CreateRecordBatchSnafu, MissingTableMutationHandlerSnafu,
58 MissingTimestampColumnSnafu, QueryExecutionSnafu, Result, TableMutationSnafu,
59 TableNotFoundSnafu, TableReadOnlySnafu, UnsupportedExprSnafu,
60};
61use crate::executor::QueryExecutor;
62use crate::metrics::{OnDone, QUERY_STAGE_ELAPSED};
63use crate::physical_wrapper::PhysicalPlanWrapperRef;
64use crate::planner::{DfLogicalPlanner, LogicalPlanner};
65use crate::query_engine::{DescribeResult, QueryEngineContext, QueryEngineState};
66use crate::{QueryEngine, metrics};
67
/// Query-context extension key that overrides DataFusion's target partition count
/// (`target_partitions`) for a single query.
///
/// The value must parse as a positive integer; `0` is ignored and an unparsable value
/// is logged and ignored, keeping the session default.
pub const QUERY_PARALLELISM_HINT: &str = "query_parallelism";

/// Query-context extension key that enables query fallback for a single query.
///
/// Only consulted when the configuration parameter does not already allow fallback;
/// the value is lowercased and parsed as a `bool`, so `true`/`TRUE` enable it and any
/// other value leaves it disabled.
pub const QUERY_FALLBACK_HINT: &str = "query_fallback";
74
75pub struct DatafusionQueryEngine {
76 state: Arc<QueryEngineState>,
77 plugins: Plugins,
78}
79
80impl DatafusionQueryEngine {
81 pub fn new(state: Arc<QueryEngineState>, plugins: Plugins) -> Self {
82 Self { state, plugins }
83 }
84
85 #[tracing::instrument(skip_all)]
86 async fn exec_query_plan(
87 &self,
88 plan: LogicalPlan,
89 query_ctx: QueryContextRef,
90 ) -> Result<Output> {
91 let mut ctx = self.engine_context(query_ctx.clone());
92
93 let physical_plan = self.create_physical_plan(&mut ctx, &plan).await?;
95 let optimized_physical_plan = self.optimize_physical_plan(&mut ctx, physical_plan)?;
96
97 let physical_plan = if let Some(wrapper) = self.plugins.get::<PhysicalPlanWrapperRef>() {
98 wrapper.wrap(optimized_physical_plan, query_ctx)
99 } else {
100 optimized_physical_plan
101 };
102
103 Ok(Output::new(
104 OutputData::Stream(self.execute_stream(&ctx, &physical_plan)?),
105 OutputMeta::new_with_plan(physical_plan),
106 ))
107 }
108
109 #[tracing::instrument(skip_all)]
110 async fn exec_dml_statement(
111 &self,
112 dml: DmlStatement,
113 query_ctx: QueryContextRef,
114 ) -> Result<Output> {
115 ensure!(
116 matches!(dml.op, WriteOp::Insert(_) | WriteOp::Delete),
117 UnsupportedExprSnafu {
118 name: format!("DML op {}", dml.op),
119 }
120 );
121
122 let _timer = QUERY_STAGE_ELAPSED
123 .with_label_values(&[dml.op.name()])
124 .start_timer();
125
126 let default_catalog = &query_ctx.current_catalog().to_owned();
127 let default_schema = &query_ctx.current_schema();
128 let table_name = dml.table_name.resolve(default_catalog, default_schema);
129 let table = self.find_table(&table_name, &query_ctx).await?;
130
131 let output = self
132 .exec_query_plan((*dml.input).clone(), query_ctx.clone())
133 .await?;
134 let mut stream = match output.data {
135 OutputData::RecordBatches(batches) => batches.as_stream(),
136 OutputData::Stream(stream) => stream,
137 _ => unreachable!(),
138 };
139
140 let mut affected_rows = 0;
141 let mut insert_cost = 0;
142
143 while let Some(batch) = stream.next().await {
144 let batch = batch.context(CreateRecordBatchSnafu)?;
145 let column_vectors = batch
146 .column_vectors(&table_name.to_string(), table.schema())
147 .map_err(BoxedError::new)
148 .context(QueryExecutionSnafu)?;
149
150 match dml.op {
151 WriteOp::Insert(_) => {
152 let output = self
154 .insert(&table_name, column_vectors, query_ctx.clone())
155 .await?;
156 let (rows, cost) = output.extract_rows_and_cost();
157 affected_rows += rows;
158 insert_cost += cost;
159 }
160 WriteOp::Delete => {
161 affected_rows += self
162 .delete(&table_name, &table, column_vectors, query_ctx.clone())
163 .await?;
164 }
165 _ => unreachable!("guarded by the 'ensure!' at the beginning"),
166 }
167 }
168 Ok(Output::new(
169 OutputData::AffectedRows(affected_rows),
170 OutputMeta::new_with_cost(insert_cost),
171 ))
172 }
173
174 #[tracing::instrument(skip_all)]
175 async fn delete(
176 &self,
177 table_name: &ResolvedTableReference,
178 table: &TableRef,
179 column_vectors: HashMap<String, VectorRef>,
180 query_ctx: QueryContextRef,
181 ) -> Result<usize> {
182 let catalog_name = table_name.catalog.to_string();
183 let schema_name = table_name.schema.to_string();
184 let table_name = table_name.table.to_string();
185 let table_schema = table.schema();
186
187 ensure!(
188 !is_readonly_schema(&schema_name),
189 TableReadOnlySnafu { table: table_name }
190 );
191
192 let ts_column = table_schema
193 .timestamp_column()
194 .map(|x| &x.name)
195 .with_context(|| MissingTimestampColumnSnafu {
196 table_name: table_name.clone(),
197 })?;
198
199 let table_info = table.table_info();
200 let rowkey_columns = table_info
201 .meta
202 .row_key_column_names()
203 .collect::<Vec<&String>>();
204 let column_vectors = column_vectors
205 .into_iter()
206 .filter(|x| &x.0 == ts_column || rowkey_columns.contains(&&x.0))
207 .collect::<HashMap<_, _>>();
208
209 let request = DeleteRequest {
210 catalog_name,
211 schema_name,
212 table_name,
213 key_column_values: column_vectors,
214 };
215
216 self.state
217 .table_mutation_handler()
218 .context(MissingTableMutationHandlerSnafu)?
219 .delete(request, query_ctx)
220 .await
221 .context(TableMutationSnafu)
222 }
223
224 #[tracing::instrument(skip_all)]
225 async fn insert(
226 &self,
227 table_name: &ResolvedTableReference,
228 column_vectors: HashMap<String, VectorRef>,
229 query_ctx: QueryContextRef,
230 ) -> Result<Output> {
231 let catalog_name = table_name.catalog.to_string();
232 let schema_name = table_name.schema.to_string();
233 let table_name = table_name.table.to_string();
234
235 ensure!(
236 !is_readonly_schema(&schema_name),
237 TableReadOnlySnafu { table: table_name }
238 );
239
240 let request = InsertRequest {
241 catalog_name,
242 schema_name,
243 table_name,
244 columns_values: column_vectors,
245 };
246
247 self.state
248 .table_mutation_handler()
249 .context(MissingTableMutationHandlerSnafu)?
250 .insert(request, query_ctx)
251 .await
252 .context(TableMutationSnafu)
253 }
254
255 async fn find_table(
256 &self,
257 table_name: &ResolvedTableReference,
258 query_context: &QueryContextRef,
259 ) -> Result<TableRef> {
260 let catalog_name = table_name.catalog.as_ref();
261 let schema_name = table_name.schema.as_ref();
262 let table_name = table_name.table.as_ref();
263
264 self.state
265 .catalog_manager()
266 .table(catalog_name, schema_name, table_name, Some(query_context))
267 .await
268 .context(CatalogSnafu)?
269 .with_context(|| TableNotFoundSnafu { table: table_name })
270 }
271
272 #[tracing::instrument(skip_all)]
273 async fn create_physical_plan(
274 &self,
275 ctx: &mut QueryEngineContext,
276 logical_plan: &LogicalPlan,
277 ) -> Result<Arc<dyn ExecutionPlan>> {
278 #[derive(Debug)]
282 struct PanicLogger<'a> {
283 input_logical_plan: &'a LogicalPlan,
284 after_analyze: Option<LogicalPlan>,
285 after_optimize: Option<LogicalPlan>,
286 phy_plan: Option<Arc<dyn ExecutionPlan>>,
287 }
288 impl Drop for PanicLogger<'_> {
289 fn drop(&mut self) {
290 if std::thread::panicking() {
291 common_telemetry::error!(
292 "Panic while creating physical plan, input logical plan: {:?}, after analyze: {:?}, after optimize: {:?}, final physical plan: {:?}",
293 self.input_logical_plan,
294 self.after_analyze,
295 self.after_optimize,
296 self.phy_plan
297 );
298 }
299 }
300 }
301
302 let mut logger = PanicLogger {
303 input_logical_plan: logical_plan,
304 after_analyze: None,
305 after_optimize: None,
306 phy_plan: None,
307 };
308
309 let _timer = metrics::CREATE_PHYSICAL_ELAPSED.start_timer();
310 let state = ctx.state();
311
312 common_telemetry::debug!("Create physical plan, input plan: {logical_plan}");
313
314 if matches!(logical_plan, DfLogicalPlan::Explain(_)) {
316 return state
317 .create_physical_plan(logical_plan)
318 .await
319 .map_err(Into::into);
320 }
321
322 let analyzed_plan = state.analyzer().execute_and_check(
324 logical_plan.clone(),
325 state.config_options(),
326 |_, _| {},
327 )?;
328
329 logger.after_analyze = Some(analyzed_plan.clone());
330
331 common_telemetry::debug!("Create physical plan, analyzed plan: {analyzed_plan}");
332
333 let optimized_plan = if let DfLogicalPlan::Extension(ext) = &analyzed_plan
335 && ext.node.name() == MergeScanLogicalPlan::name()
336 {
337 analyzed_plan.clone()
338 } else {
339 state
340 .optimizer()
341 .optimize(analyzed_plan, state, |_, _| {})?
342 };
343
344 common_telemetry::debug!("Create physical plan, optimized plan: {optimized_plan}");
345 logger.after_optimize = Some(optimized_plan.clone());
346
347 let physical_plan = state
348 .query_planner()
349 .create_physical_plan(&optimized_plan, state)
350 .await?;
351
352 logger.phy_plan = Some(physical_plan.clone());
353 drop(logger);
354 Ok(physical_plan)
355 }
356
357 #[tracing::instrument(skip_all)]
358 pub fn optimize(
359 &self,
360 context: &QueryEngineContext,
361 plan: &LogicalPlan,
362 ) -> Result<LogicalPlan> {
363 let _timer = metrics::OPTIMIZE_LOGICAL_ELAPSED.start_timer();
364
365 let optimized_plan = self
367 .state
368 .optimize_by_extension_rules(plan.clone(), context)?;
369
370 let optimized_plan = self.state.session_state().optimize(&optimized_plan)?;
372
373 Ok(optimized_plan)
374 }
375
376 #[tracing::instrument(skip_all)]
377 fn optimize_physical_plan(
378 &self,
379 ctx: &mut QueryEngineContext,
380 plan: Arc<dyn ExecutionPlan>,
381 ) -> Result<Arc<dyn ExecutionPlan>> {
382 let _timer = metrics::OPTIMIZE_PHYSICAL_ELAPSED.start_timer();
383
384 let optimized_plan = if let Some(analyze_plan) = plan.as_any().downcast_ref::<AnalyzeExec>()
391 {
392 let format = if let Some(format) = ctx.query_ctx().explain_format()
393 && format.to_lowercase() == "json"
394 {
395 AnalyzeFormat::JSON
396 } else {
397 AnalyzeFormat::TEXT
398 };
399 ctx.query_ctx().set_explain_verbose(analyze_plan.verbose());
402
403 Arc::new(DistAnalyzeExec::new(
404 analyze_plan.input().clone(),
405 analyze_plan.verbose(),
406 format,
407 ))
408 } else {
416 plan
417 };
425
426 Ok(optimized_plan)
427 }
428}
429
430#[async_trait]
431impl QueryEngine for DatafusionQueryEngine {
432 fn as_any(&self) -> &dyn Any {
433 self
434 }
435
436 fn planner(&self) -> Arc<dyn LogicalPlanner> {
437 Arc::new(DfLogicalPlanner::new(self.state.clone()))
438 }
439
440 fn name(&self) -> &str {
441 "datafusion"
442 }
443
444 async fn describe(
445 &self,
446 plan: LogicalPlan,
447 query_ctx: QueryContextRef,
448 ) -> Result<DescribeResult> {
449 let ctx = self.engine_context(query_ctx);
450 if let Ok(optimised_plan) = self.optimize(&ctx, &plan) {
451 let schema = optimised_plan
452 .schema()
453 .clone()
454 .try_into()
455 .context(ConvertSchemaSnafu)?;
456 Ok(DescribeResult {
457 schema,
458 logical_plan: optimised_plan,
459 })
460 } else {
461 let schema = plan
464 .schema()
465 .clone()
466 .try_into()
467 .context(ConvertSchemaSnafu)?;
468 Ok(DescribeResult {
469 schema,
470 logical_plan: plan,
471 })
472 }
473 }
474
475 async fn execute(&self, plan: LogicalPlan, query_ctx: QueryContextRef) -> Result<Output> {
476 match plan {
477 LogicalPlan::Dml(dml) => self.exec_dml_statement(dml, query_ctx).await,
478 _ => self.exec_query_plan(plan, query_ctx).await,
479 }
480 }
481
482 fn register_aggregate_function(&self, func: AggregateUDF) {
490 self.state.register_aggr_function(func);
491 }
492
493 fn register_scalar_function(&self, func: ScalarFunctionFactory) {
496 self.state.register_scalar_function(func);
497 }
498
499 fn register_table_function(&self, func: Arc<TableFunction>) {
500 self.state.register_table_function(func);
501 }
502
503 fn register_window_function(&self, func: WindowUDF) {
504 self.state.register_window_function(func);
505 }
506
507 fn read_table(&self, table: TableRef) -> Result<DataFrame> {
508 self.state.read_table(table).map_err(Into::into)
509 }
510
511 fn engine_context(&self, query_ctx: QueryContextRef) -> QueryEngineContext {
512 let mut state = self.state.session_state();
513 state.config_mut().set_extension(query_ctx.clone());
514 if let Some(parallelism) = query_ctx.extension(QUERY_PARALLELISM_HINT) {
517 if let Ok(n) = parallelism.parse::<u64>() {
518 if n > 0 {
519 let new_cfg = state.config().clone().with_target_partitions(n as usize);
520 *state.config_mut() = new_cfg;
521 }
522 } else {
523 common_telemetry::warn!(
524 "Failed to parse query_parallelism: {}, using default value",
525 parallelism
526 );
527 }
528 }
529
530 state.config_mut().options_mut().execution.time_zone =
532 Some(query_ctx.timezone().to_string());
533
534 if query_ctx.configuration_parameter().allow_query_fallback() {
537 state
538 .config_mut()
539 .options_mut()
540 .extensions
541 .insert(DistPlannerOptions {
542 allow_query_fallback: true,
543 });
544 } else if let Some(fallback) = query_ctx.extension(QUERY_FALLBACK_HINT) {
545 if fallback.to_lowercase().parse::<bool>().unwrap_or(false) {
548 state
549 .config_mut()
550 .options_mut()
551 .extensions
552 .insert(DistPlannerOptions {
553 allow_query_fallback: true,
554 });
555 }
556 }
557
558 state
559 .config_mut()
560 .options_mut()
561 .extensions
562 .insert(FunctionContext {
563 query_ctx: query_ctx.clone(),
564 state: self.engine_state().function_state(),
565 });
566
567 let config_options = state.config_options().clone();
568 let _ = state
569 .execution_props_mut()
570 .config_options
571 .insert(config_options);
572
573 QueryEngineContext::new(state, query_ctx)
574 }
575
576 fn engine_state(&self) -> &QueryEngineState {
577 &self.state
578 }
579}
580
581impl QueryExecutor for DatafusionQueryEngine {
582 #[tracing::instrument(skip_all)]
583 fn execute_stream(
584 &self,
585 ctx: &QueryEngineContext,
586 plan: &Arc<dyn ExecutionPlan>,
587 ) -> Result<SendableRecordBatchStream> {
588 let explain_verbose = ctx.query_ctx().explain_verbose();
589 let output_partitions = plan.properties().output_partitioning().partition_count();
590 if explain_verbose {
591 common_telemetry::info!("Executing query plan, output_partitions: {output_partitions}");
592 }
593
594 let exec_timer = metrics::EXEC_PLAN_ELAPSED.start_timer();
595 let task_ctx = ctx.build_task_ctx();
596 let span = Span::current();
597
598 match plan.properties().output_partitioning().partition_count() {
599 0 => {
600 let schema = Arc::new(
601 Schema::try_from(plan.schema())
602 .map_err(BoxedError::new)
603 .context(QueryExecutionSnafu)?,
604 );
605 Ok(Box::pin(EmptyRecordBatchStream::new(schema)))
606 }
607 1 => {
608 let df_stream = plan.execute(0, task_ctx)?;
609 let mut stream = RecordBatchStreamAdapter::try_new_with_span(df_stream, span)
610 .context(error::ConvertDfRecordBatchStreamSnafu)
611 .map_err(BoxedError::new)
612 .context(QueryExecutionSnafu)?;
613 stream.set_metrics2(plan.clone());
614 stream.set_explain_verbose(explain_verbose);
615 let stream = OnDone::new(Box::pin(stream), move || {
616 let exec_cost = exec_timer.stop_and_record();
617 if explain_verbose {
618 common_telemetry::info!(
619 "DatafusionQueryEngine execute 1 stream, cost: {:?}s",
620 exec_cost,
621 );
622 }
623 });
624 Ok(Box::pin(stream))
625 }
626 _ => {
627 let merged_plan = CoalescePartitionsExec::new(plan.clone());
629 assert_eq!(
631 1,
632 merged_plan
633 .properties()
634 .output_partitioning()
635 .partition_count()
636 );
637 let df_stream = merged_plan.execute(0, task_ctx)?;
638 let mut stream = RecordBatchStreamAdapter::try_new_with_span(df_stream, span)
639 .context(error::ConvertDfRecordBatchStreamSnafu)
640 .map_err(BoxedError::new)
641 .context(QueryExecutionSnafu)?;
642 stream.set_metrics2(plan.clone());
643 stream.set_explain_verbose(ctx.query_ctx().explain_verbose());
644 let stream = OnDone::new(Box::pin(stream), move || {
645 let exec_cost = exec_timer.stop_and_record();
646 if explain_verbose {
647 common_telemetry::info!(
648 "DatafusionQueryEngine execute {output_partitions} stream, cost: {:?}s",
649 exec_cost
650 );
651 }
652 });
653 Ok(Box::pin(stream))
654 }
655 }
656 }
657}
658
#[cfg(test)]
mod tests {
    use std::fmt;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use api::v1::SemanticType;
    use arrow::array::{ArrayRef, UInt64Array};
    use arrow_schema::SortOptions;
    use catalog::RegisterTableRequest;
    use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, NUMBERS_TABLE_ID};
    use common_error::ext::BoxedError;
    use common_recordbatch::{EmptyRecordBatchStream, SendableRecordBatchStream, util};
    use datafusion::physical_plan::display::{DisplayAs, DisplayFormatType};
    use datafusion::physical_plan::expressions::PhysicalSortExpr;
    use datafusion::physical_plan::joins::{HashJoinExec, JoinOn, PartitionMode};
    use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet;
    use datafusion::physical_plan::{ExecutionPlan, PhysicalExpr};
    use datafusion::prelude::{col, lit};
    use datafusion_common::{JoinType, NullEquality};
    use datafusion_physical_expr::expressions::Column;
    use datatypes::prelude::ConcreteDataType;
    use datatypes::schema::{ColumnSchema, SchemaRef};
    use datatypes::vectors::{Helper, UInt32Vector, VectorRef};
    use session::context::{QueryContext, QueryContextBuilder};
    use store_api::metadata::{ColumnMetadata, RegionMetadataBuilder, RegionMetadataRef};
    use store_api::region_engine::{
        PartitionRange, PrepareRequest, QueryScanContext, RegionScanner, ScannerProperties,
    };
    use store_api::storage::{RegionId, ScanRequest};
    use table::table::numbers::{NUMBERS_TABLE_NAME, NumbersTable};
    use table::table::scan::RegionScanExec;

    use super::*;
    use crate::options::QueryOptions;
    use crate::parser::QueryLanguageParser;
    use crate::part_sort::PartSortExec;
    use crate::query_engine::{QueryEngineFactory, QueryEngineRef};

    /// Region scanner that yields no rows and records every dynamic-filter pushdown.
    #[derive(Debug)]
    struct RecordingScanner {
        schema: SchemaRef,
        metadata: RegionMetadataRef,
        properties: ScannerProperties,
        /// Number of `add_dyn_filter_to_predicate` calls.
        update_calls: Arc<AtomicUsize>,
        /// Number of filter expressions passed in the most recent call.
        last_filter_len: Arc<AtomicUsize>,
    }

    impl RecordingScanner {
        fn new(
            schema: SchemaRef,
            metadata: RegionMetadataRef,
            update_calls: Arc<AtomicUsize>,
            last_filter_len: Arc<AtomicUsize>,
        ) -> Self {
            Self {
                schema,
                metadata,
                properties: ScannerProperties::default(),
                update_calls,
                last_filter_len,
            }
        }
    }

    impl RegionScanner for RecordingScanner {
        fn name(&self) -> &str {
            "RecordingScanner"
        }

        fn properties(&self) -> &ScannerProperties {
            &self.properties
        }

        fn schema(&self) -> SchemaRef {
            self.schema.clone()
        }

        fn metadata(&self) -> RegionMetadataRef {
            self.metadata.clone()
        }

        fn prepare(&mut self, request: PrepareRequest) -> std::result::Result<(), BoxedError> {
            self.properties.prepare(request);
            Ok(())
        }

        fn scan_partition(
            &self,
            _ctx: &QueryScanContext,
            _metrics_set: &ExecutionPlanMetricsSet,
            _partition: usize,
        ) -> std::result::Result<SendableRecordBatchStream, BoxedError> {
            Ok(Box::pin(EmptyRecordBatchStream::new(self.schema.clone())))
        }

        fn has_predicate_without_region(&self) -> bool {
            true
        }

        fn add_dyn_filter_to_predicate(
            &mut self,
            filter_exprs: Vec<Arc<dyn PhysicalExpr>>,
        ) -> Vec<bool> {
            self.update_calls.fetch_add(1, Ordering::Relaxed);
            self.last_filter_len
                .store(filter_exprs.len(), Ordering::Relaxed);
            vec![true; filter_exprs.len()]
        }

        fn set_logical_region(&mut self, logical_region: bool) {
            self.properties.set_logical_region(logical_region);
        }
    }

    impl DisplayAs for RecordingScanner {
        fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "RecordingScanner")
        }
    }

    /// Non-nullable millisecond-timestamp column named `ts`.
    fn ts_column_schema() -> ColumnSchema {
        ColumnSchema::new(
            "ts",
            ConcreteDataType::timestamp_millisecond_datatype(),
            false,
        )
    }

    /// Region metadata whose only column is the `ts` time index (column id 1).
    fn ts_region_metadata(region_id: RegionId) -> RegionMetadataRef {
        let mut builder = RegionMetadataBuilder::new(region_id);
        builder
            .push_column_metadata(ColumnMetadata {
                column_schema: ts_column_schema().with_time_index(true),
                semantic_type: SemanticType::Timestamp,
                column_id: 1,
            })
            .primary_key(vec![]);
        Arc::new(builder.build().unwrap())
    }

    /// A [`RegionScanExec`] over a [`RecordingScanner`] for region `region_id`, whose
    /// schema is the single `ts` column.
    fn recording_region_scan(
        region_id: RegionId,
        update_calls: Arc<AtomicUsize>,
        last_filter_len: Arc<AtomicUsize>,
    ) -> Arc<RegionScanExec> {
        let schema = Arc::new(datatypes::schema::Schema::new(vec![ts_column_schema()]));
        let scanner = Box::new(RecordingScanner::new(
            schema,
            ts_region_metadata(region_id),
            update_calls,
            last_filter_len,
        ));
        Arc::new(RegionScanExec::new(scanner, ScanRequest::default(), None).unwrap())
    }

    /// Query engine over an in-memory catalog containing the `numbers` table.
    async fn create_test_engine() -> QueryEngineRef {
        let catalog_manager = catalog::memory::new_memory_catalog_manager().unwrap();
        let req = RegisterTableRequest {
            catalog: DEFAULT_CATALOG_NAME.to_string(),
            schema: DEFAULT_SCHEMA_NAME.to_string(),
            table_name: NUMBERS_TABLE_NAME.to_string(),
            table_id: NUMBERS_TABLE_ID,
            table: NumbersTable::table(NUMBERS_TABLE_ID),
        };
        catalog_manager.register_table_sync(req).unwrap();

        QueryEngineFactory::new(
            catalog_manager,
            None,
            None,
            None,
            None,
            false,
            QueryOptions::default(),
        )
        .query_engine()
    }

    #[tokio::test]
    async fn test_sql_to_plan() {
        let engine = create_test_engine().await;
        let sql = "select sum(number) from numbers limit 20";

        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
        let plan = engine
            .planner()
            .plan(&stmt, QueryContext::arc())
            .await
            .unwrap();

        assert_eq!(
            plan.to_string(),
            r#"Limit: skip=0, fetch=20
  Projection: sum(numbers.number)
    Aggregate: groupBy=[[]], aggr=[[sum(numbers.number)]]
      TableScan: numbers"#
        );
    }

    #[tokio::test]
    async fn test_execute() {
        let engine = create_test_engine().await;
        let sql = "select sum(number) from numbers limit 20";

        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
        let plan = engine
            .planner()
            .plan(&stmt, QueryContext::arc())
            .await
            .unwrap();

        let output = engine.execute(plan, QueryContext::arc()).await.unwrap();

        match output.data {
            OutputData::Stream(recordbatch) => {
                let numbers = util::collect(recordbatch).await.unwrap();
                assert_eq!(1, numbers.len());
                assert_eq!(numbers[0].num_columns(), 1);
                assert_eq!(1, numbers[0].schema.num_columns());
                assert_eq!(
                    "sum(numbers.number)",
                    numbers[0].schema.column_schemas()[0].name
                );

                let batch = &numbers[0];
                assert_eq!(1, batch.num_columns());
                assert_eq!(batch.column(0).len(), 1);

                let expected = Arc::new(UInt64Array::from_iter_values([4950])) as ArrayRef;
                assert_eq!(batch.column(0), &expected);
            }
            _ => unreachable!(),
        }
    }

    #[tokio::test]
    async fn test_read_table() {
        let engine = create_test_engine().await;

        let engine = engine
            .as_any()
            .downcast_ref::<DatafusionQueryEngine>()
            .unwrap();
        let query_ctx = Arc::new(QueryContextBuilder::default().build());
        let table = engine
            .find_table(
                &ResolvedTableReference {
                    catalog: "greptime".into(),
                    schema: "public".into(),
                    table: "numbers".into(),
                },
                &query_ctx,
            )
            .await
            .unwrap();

        let df = engine.read_table(table).unwrap();
        let df = df
            .select_columns(&["number"])
            .unwrap()
            .filter(col("number").lt(lit(10)))
            .unwrap();
        let batches = df.collect().await.unwrap();
        assert_eq!(1, batches.len());
        let batch = &batches[0];

        assert_eq!(1, batch.num_columns());
        assert_eq!(batch.column(0).len(), 10);

        assert_eq!(
            Helper::try_into_vector(batch.column(0)).unwrap(),
            Arc::new(UInt32Vector::from_slice([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])) as VectorRef
        );
    }

    #[tokio::test]
    async fn test_describe() {
        let engine = create_test_engine().await;
        let sql = "select sum(number) from numbers limit 20";

        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();

        let plan = engine
            .planner()
            .plan(&stmt, QueryContext::arc())
            .await
            .unwrap();

        let DescribeResult {
            schema,
            logical_plan,
        } = engine.describe(plan, QueryContext::arc()).await.unwrap();

        assert_eq!(
            schema.column_schemas()[0],
            ColumnSchema::new(
                "sum(numbers.number)",
                ConcreteDataType::uint64_datatype(),
                true
            )
        );
        assert_eq!(
            "Limit: skip=0, fetch=20\n  Aggregate: groupBy=[[]], aggr=[[sum(CAST(numbers.number AS UInt64))]]\n    TableScan: numbers projection=[number]",
            format!("{}", logical_plan.display_indent())
        );
    }

    #[tokio::test]
    async fn test_topk_dynamic_filter_pushdown_reaches_region_scan() {
        let engine = create_test_engine().await;
        let engine = engine
            .as_any()
            .downcast_ref::<DatafusionQueryEngine>()
            .unwrap();
        let engine_ctx = engine.engine_context(QueryContext::arc());
        let state = engine_ctx.state();

        let update_calls = Arc::new(AtomicUsize::new(0));
        let last_filter_len = Arc::new(AtomicUsize::new(0));
        let scan = recording_region_scan(
            RegionId::new(1024, 1),
            update_calls.clone(),
            last_filter_len.clone(),
        );

        let sort_expr = PhysicalSortExpr {
            expr: Arc::new(Column::new("ts", 0)),
            options: SortOptions {
                descending: true,
                ..Default::default()
            },
        };
        let partition_ranges: Vec<Vec<PartitionRange>> = vec![vec![]];
        let mut plan: Arc<dyn ExecutionPlan> =
            Arc::new(PartSortExec::try_new(sort_expr, Some(3), partition_ranges, scan).unwrap());

        for optimizer in state.physical_optimizers() {
            plan = optimizer.optimize(plan, state.config_options()).unwrap();
        }

        assert!(update_calls.load(Ordering::Relaxed) > 0);
        assert!(last_filter_len.load(Ordering::Relaxed) > 0);
    }

    #[tokio::test]
    async fn test_join_dynamic_filter_pushdown_reaches_region_scan() {
        let engine = create_test_engine().await;
        let engine = engine
            .as_any()
            .downcast_ref::<DatafusionQueryEngine>()
            .unwrap();
        let engine_ctx = engine.engine_context(QueryContext::arc());
        let state = engine_ctx.state();

        assert!(
            state
                .config_options()
                .optimizer
                .enable_join_dynamic_filter_pushdown
        );

        let left_update_calls = Arc::new(AtomicUsize::new(0));
        let left_last_filter_len = Arc::new(AtomicUsize::new(0));
        let right_update_calls = Arc::new(AtomicUsize::new(0));
        let right_last_filter_len = Arc::new(AtomicUsize::new(0));

        let left_scan = recording_region_scan(
            RegionId::new(2048, 1),
            left_update_calls.clone(),
            left_last_filter_len.clone(),
        );
        let right_scan = recording_region_scan(
            RegionId::new(2048, 2),
            right_update_calls.clone(),
            right_last_filter_len.clone(),
        );

        let on: JoinOn = vec![(
            Arc::new(Column::new("ts", 0)) as Arc<dyn PhysicalExpr>,
            Arc::new(Column::new("ts", 0)) as Arc<dyn PhysicalExpr>,
        )];

        let mut plan: Arc<dyn ExecutionPlan> = Arc::new(
            HashJoinExec::try_new(
                left_scan,
                right_scan,
                on,
                None,
                &JoinType::Inner,
                None,
                PartitionMode::CollectLeft,
                NullEquality::NullEqualsNull,
                false,
            )
            .unwrap(),
        );

        for optimizer in state.physical_optimizers() {
            plan = optimizer.optimize(plan, state.config_options()).unwrap();
        }

        // The build (left) side is consulted but receives no filters; the probe (right)
        // side receives the join's dynamic filter.
        assert!(left_update_calls.load(Ordering::Relaxed) > 0);
        assert_eq!(0, left_last_filter_len.load(Ordering::Relaxed));
        assert!(right_update_calls.load(Ordering::Relaxed) > 0);
        assert!(right_last_filter_len.load(Ordering::Relaxed) > 0);
    }
}