mod error;
mod planner;

use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;

use async_trait::async_trait;
use common_base::Plugins;
use common_catalog::consts::is_readonly_schema;
use common_error::ext::BoxedError;
use common_function::function_factory::ScalarFunctionFactory;
use common_query::{Output, OutputData, OutputMeta};
use common_recordbatch::adapter::RecordBatchStreamAdapter;
use common_recordbatch::{EmptyRecordBatchStream, SendableRecordBatchStream};
use common_telemetry::tracing;
use datafusion::physical_plan::analyze::AnalyzeExec;
use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion::physical_plan::ExecutionPlan;
use datafusion_common::ResolvedTableReference;
use datafusion_expr::{
    AggregateUDF, DmlStatement, LogicalPlan as DfLogicalPlan, LogicalPlan, WriteOp,
};
use datatypes::prelude::VectorRef;
use datatypes::schema::Schema;
use futures_util::StreamExt;
use session::context::QueryContextRef;
use snafu::{ensure, OptionExt, ResultExt};
use sqlparser::ast::AnalyzeFormat;
use table::requests::{DeleteRequest, InsertRequest};
use table::TableRef;

use crate::analyze::DistAnalyzeExec;
use crate::dataframe::DataFrame;
pub use crate::datafusion::planner::DfContextProviderAdapter;
use crate::dist_plan::MergeScanLogicalPlan;
use crate::error::{
    CatalogSnafu, ConvertSchemaSnafu, CreateRecordBatchSnafu, MissingTableMutationHandlerSnafu,
    MissingTimestampColumnSnafu, QueryExecutionSnafu, Result, TableMutationSnafu,
    TableNotFoundSnafu, TableReadOnlySnafu, UnsupportedExprSnafu,
};
use crate::executor::QueryExecutor;
use crate::metrics::{OnDone, QUERY_STAGE_ELAPSED};
use crate::physical_wrapper::PhysicalPlanWrapperRef;
use crate::planner::{DfLogicalPlanner, LogicalPlanner};
use crate::query_engine::{DescribeResult, QueryEngineContext, QueryEngineState};
use crate::{metrics, QueryEngine};

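/// Query context extension key carrying a per-query parallelism hint.
/// `engine_context` reads it and overrides the target partition count when the
/// value parses to a positive integer.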
pub const QUERY_PARALLELISM_HINT: &str = "query_parallelism";

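/// A query engine implementation backed by DataFusion.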
pub struct DatafusionQueryEngine {
    state: Arc<QueryEngineState>,
    plugins: Plugins,
}

impl DatafusionQueryEngine {
    pub fn new(state: Arc<QueryEngineState>, plugins: Plugins) -> Self {
        Self { state, plugins }
    }

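    /// Plans and runs a query: creates and optimizes the physical plan, wraps it
    /// with the `PhysicalPlanWrapperRef` plugin when one is registered, and
    /// executes it as a record batch stream.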
    #[tracing::instrument(skip_all)]
    async fn exec_query_plan(
        &self,
        plan: LogicalPlan,
        query_ctx: QueryContextRef,
    ) -> Result<Output> {
        let mut ctx = self.engine_context(query_ctx.clone());

        let physical_plan = self.create_physical_plan(&mut ctx, &plan).await?;
        let optimized_physical_plan = self.optimize_physical_plan(&mut ctx, physical_plan)?;

        let physical_plan = if let Some(wrapper) = self.plugins.get::<PhysicalPlanWrapperRef>() {
            wrapper.wrap(optimized_physical_plan, query_ctx)
        } else {
            optimized_physical_plan
        };

        Ok(Output::new(
            OutputData::Stream(self.execute_stream(&ctx, &physical_plan)?),
            OutputMeta::new_with_plan(physical_plan),
        ))
    }

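    /// Executes an `INSERT` or `DELETE` statement by running its input plan as a
    /// query and applying the resulting batches to the target table, returning
    /// the number of affected rows.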
    #[tracing::instrument(skip_all)]
    async fn exec_dml_statement(
        &self,
        dml: DmlStatement,
        query_ctx: QueryContextRef,
    ) -> Result<Output> {
        ensure!(
            matches!(dml.op, WriteOp::Insert(_) | WriteOp::Delete),
            UnsupportedExprSnafu {
                name: format!("DML op {}", dml.op),
            }
        );

        let _timer = QUERY_STAGE_ELAPSED
            .with_label_values(&[dml.op.name()])
            .start_timer();

        let default_catalog = &query_ctx.current_catalog().to_owned();
        let default_schema = &query_ctx.current_schema();
        let table_name = dml.table_name.resolve(default_catalog, default_schema);
        let table = self.find_table(&table_name, &query_ctx).await?;

        let output = self
            .exec_query_plan((*dml.input).clone(), query_ctx.clone())
            .await?;
        let mut stream = match output.data {
            OutputData::RecordBatches(batches) => batches.as_stream(),
            OutputData::Stream(stream) => stream,
            _ => unreachable!(),
        };

        let mut affected_rows = 0;
        let mut insert_cost = 0;

        while let Some(batch) = stream.next().await {
            let batch = batch.context(CreateRecordBatchSnafu)?;
            let column_vectors = batch
                .column_vectors(&table_name.to_string(), table.schema())
                .map_err(BoxedError::new)
                .context(QueryExecutionSnafu)?;

            match dml.op {
                WriteOp::Insert(_) => {
                    let output = self
                        .insert(&table_name, column_vectors, query_ctx.clone())
                        .await?;
                    let (rows, cost) = output.extract_rows_and_cost();
                    affected_rows += rows;
                    insert_cost += cost;
                }
                WriteOp::Delete => {
                    affected_rows += self
                        .delete(&table_name, &table, column_vectors, query_ctx.clone())
                        .await?;
                }
                _ => unreachable!("guarded by the 'ensure!' at the beginning"),
            }
        }
        Ok(Output::new(
            OutputData::AffectedRows(affected_rows),
            OutputMeta::new_with_cost(insert_cost),
        ))
    }

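    /// Deletes rows from the table, keeping only the timestamp and row-key
    /// columns of the input batches as the key column values of the request.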
    #[tracing::instrument(skip_all)]
    async fn delete(
        &self,
        table_name: &ResolvedTableReference,
        table: &TableRef,
        column_vectors: HashMap<String, VectorRef>,
        query_ctx: QueryContextRef,
    ) -> Result<usize> {
        let catalog_name = table_name.catalog.to_string();
        let schema_name = table_name.schema.to_string();
        let table_name = table_name.table.to_string();
        let table_schema = table.schema();

        ensure!(
            !is_readonly_schema(&schema_name),
            TableReadOnlySnafu { table: table_name }
        );

        let ts_column = table_schema
            .timestamp_column()
            .map(|x| &x.name)
            .with_context(|| MissingTimestampColumnSnafu {
                table_name: table_name.to_string(),
            })?;

        let table_info = table.table_info();
        let rowkey_columns = table_info
            .meta
            .row_key_column_names()
            .collect::<Vec<&String>>();
        let column_vectors = column_vectors
            .into_iter()
            .filter(|x| &x.0 == ts_column || rowkey_columns.contains(&&x.0))
            .collect::<HashMap<_, _>>();

        let request = DeleteRequest {
            catalog_name,
            schema_name,
            table_name,
            key_column_values: column_vectors,
        };

        self.state
            .table_mutation_handler()
            .context(MissingTableMutationHandlerSnafu)?
            .delete(request, query_ctx)
            .await
            .context(TableMutationSnafu)
    }

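    /// Inserts the column vectors into the table via the table mutation handler.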
    #[tracing::instrument(skip_all)]
    async fn insert(
        &self,
        table_name: &ResolvedTableReference,
        column_vectors: HashMap<String, VectorRef>,
        query_ctx: QueryContextRef,
    ) -> Result<Output> {
        let catalog_name = table_name.catalog.to_string();
        let schema_name = table_name.schema.to_string();
        let table_name = table_name.table.to_string();

        ensure!(
            !is_readonly_schema(&schema_name),
            TableReadOnlySnafu { table: table_name }
        );

        let request = InsertRequest {
            catalog_name,
            schema_name,
            table_name,
            columns_values: column_vectors,
        };

        self.state
            .table_mutation_handler()
            .context(MissingTableMutationHandlerSnafu)?
            .insert(request, query_ctx)
            .await
            .context(TableMutationSnafu)
    }

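    /// Looks up the table through the catalog manager, failing with
    /// `TableNotFound` if it does not exist.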
    async fn find_table(
        &self,
        table_name: &ResolvedTableReference,
        query_context: &QueryContextRef,
    ) -> Result<TableRef> {
        let catalog_name = table_name.catalog.as_ref();
        let schema_name = table_name.schema.as_ref();
        let table_name = table_name.table.as_ref();

        self.state
            .catalog_manager()
            .table(catalog_name, schema_name, table_name, Some(query_context))
            .await
            .context(CatalogSnafu)?
            .with_context(|| TableNotFoundSnafu { table: table_name })
    }

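    /// Lowers a logical plan to a physical plan. `EXPLAIN` plans are lowered
    /// directly; otherwise the plan is analyzed and then optimized, except when
    /// its root is already a `MergeScan` extension node, which bypasses the
    /// optimizer.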
    #[tracing::instrument(skip_all)]
    async fn create_physical_plan(
        &self,
        ctx: &mut QueryEngineContext,
        logical_plan: &LogicalPlan,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let _timer = metrics::CREATE_PHYSICAL_ELAPSED.start_timer();
        let state = ctx.state();

        common_telemetry::debug!("Create physical plan, input plan: {logical_plan}");

        if matches!(logical_plan, DfLogicalPlan::Explain(_)) {
            return state
                .create_physical_plan(logical_plan)
                .await
                .context(error::DatafusionSnafu)
                .map_err(BoxedError::new)
                .context(QueryExecutionSnafu);
        }

        let analyzed_plan = state
            .analyzer()
            .execute_and_check(logical_plan.clone(), state.config_options(), |_, _| {})
            .context(error::DatafusionSnafu)
            .map_err(BoxedError::new)
            .context(QueryExecutionSnafu)?;

        common_telemetry::debug!("Create physical plan, analyzed plan: {analyzed_plan}");

        let optimized_plan = if let DfLogicalPlan::Extension(ext) = &analyzed_plan
            && ext.node.name() == MergeScanLogicalPlan::name()
        {
            analyzed_plan.clone()
        } else {
            state
                .optimizer()
                .optimize(analyzed_plan, state, |_, _| {})
                .context(error::DatafusionSnafu)
                .map_err(BoxedError::new)
                .context(QueryExecutionSnafu)?
        };

        common_telemetry::debug!("Create physical plan, optimized plan: {optimized_plan}");

        let physical_plan = state
            .query_planner()
            .create_physical_plan(&optimized_plan, state)
            .await?;

        Ok(physical_plan)
    }

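    /// Optimizes a logical plan, first with this engine's extension rules and
    /// then with DataFusion's session-state optimizer.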
    #[tracing::instrument(skip_all)]
    pub fn optimize(
        &self,
        context: &QueryEngineContext,
        plan: &LogicalPlan,
    ) -> Result<LogicalPlan> {
        let _timer = metrics::OPTIMIZE_LOGICAL_ELAPSED.start_timer();

        let optimized_plan = self
            .state
            .optimize_by_extension_rules(plan.clone(), context)
            .context(error::DatafusionSnafu)
            .map_err(BoxedError::new)
            .context(QueryExecutionSnafu)?;

        let optimized_plan = self
            .state
            .session_state()
            .optimize(&optimized_plan)
            .context(error::DatafusionSnafu)
            .map_err(BoxedError::new)
            .context(QueryExecutionSnafu)?;

        Ok(optimized_plan)
    }

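    /// Replaces a top-level `AnalyzeExec` with `DistAnalyzeExec` so `EXPLAIN
    /// ANALYZE` output is rendered by this engine (JSON or text, depending on
    /// the requested explain format); any other plan is returned unchanged.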
    #[tracing::instrument(skip_all)]
    fn optimize_physical_plan(
        &self,
        ctx: &mut QueryEngineContext,
        plan: Arc<dyn ExecutionPlan>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let _timer = metrics::OPTIMIZE_PHYSICAL_ELAPSED.start_timer();

        // The type annotation lets both branches coerce to `Arc<dyn ExecutionPlan>`.
        let optimized_plan: Arc<dyn ExecutionPlan> =
            if let Some(analyze_plan) = plan.as_any().downcast_ref::<AnalyzeExec>() {
                let format = if let Some(format) = ctx.query_ctx().explain_format()
                    && format.to_lowercase() == "json"
                {
                    AnalyzeFormat::JSON
                } else {
                    AnalyzeFormat::TEXT
                };
                ctx.query_ctx().set_explain_verbose(analyze_plan.verbose());

                Arc::new(DistAnalyzeExec::new(
                    analyze_plan.input().clone(),
                    analyze_plan.verbose(),
                    format,
                ))
            } else {
                plan
            };

        Ok(optimized_plan)
    }
}

#[async_trait]
impl QueryEngine for DatafusionQueryEngine {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn planner(&self) -> Arc<dyn LogicalPlanner> {
        Arc::new(DfLogicalPlanner::new(self.state.clone()))
    }

    fn name(&self) -> &str {
        "datafusion"
    }

    async fn describe(
        &self,
        plan: LogicalPlan,
        query_ctx: QueryContextRef,
    ) -> Result<DescribeResult> {
        let ctx = self.engine_context(query_ctx);
        if let Ok(optimised_plan) = self.optimize(&ctx, &plan) {
            let schema = optimised_plan
                .schema()
                .clone()
                .try_into()
                .context(ConvertSchemaSnafu)?;
            Ok(DescribeResult {
                schema,
                logical_plan: optimised_plan,
            })
        } else {
            let schema = plan
                .schema()
                .clone()
                .try_into()
                .context(ConvertSchemaSnafu)?;
            Ok(DescribeResult {
                schema,
                logical_plan: plan,
            })
        }
    }

    async fn execute(&self, plan: LogicalPlan, query_ctx: QueryContextRef) -> Result<Output> {
        match plan {
            LogicalPlan::Dml(dml) => self.exec_dml_statement(dml, query_ctx).await,
            _ => self.exec_query_plan(plan, query_ctx).await,
        }
    }

    fn register_aggregate_function(&self, func: AggregateUDF) {
        self.state.register_aggr_function(func);
    }

    fn register_scalar_function(&self, func: ScalarFunctionFactory) {
        self.state.register_scalar_function(func);
    }

    fn read_table(&self, table: TableRef) -> Result<DataFrame> {
        Ok(DataFrame::DataFusion(
            self.state
                .read_table(table)
                .context(error::DatafusionSnafu)
                .map_err(BoxedError::new)
                .context(QueryExecutionSnafu)?,
        ))
    }

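    /// Builds a `QueryEngineContext` for the query, attaching the query context
    /// as a session extension and applying the `query_parallelism` hint (if any)
    /// to the target partition count.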
    fn engine_context(&self, query_ctx: QueryContextRef) -> QueryEngineContext {
        let mut state = self.state.session_state();
        state.config_mut().set_extension(query_ctx.clone());
        if let Some(parallelism) = query_ctx.extension(QUERY_PARALLELISM_HINT) {
            if let Ok(n) = parallelism.parse::<u64>() {
                if n > 0 {
                    let new_cfg = state.config().clone().with_target_partitions(n as usize);
                    *state.config_mut() = new_cfg;
                }
            } else {
                common_telemetry::warn!(
                    "Failed to parse query_parallelism: {}, using default value",
                    parallelism
                );
            }
        }
        QueryEngineContext::new(state, query_ctx)
    }

    fn engine_state(&self) -> &QueryEngineState {
        &self.state
    }
}


impl QueryExecutor for DatafusionQueryEngine {
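    /// Executes the physical plan as a single record batch stream: zero output
    /// partitions yield an empty stream, one partition is executed directly, and
    /// multiple partitions are first merged with `CoalescePartitionsExec`.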
    #[tracing::instrument(skip_all)]
    fn execute_stream(
        &self,
        ctx: &QueryEngineContext,
        plan: &Arc<dyn ExecutionPlan>,
    ) -> Result<SendableRecordBatchStream> {
        let exec_timer = metrics::EXEC_PLAN_ELAPSED.start_timer();
        let task_ctx = ctx.build_task_ctx();

        match plan.properties().output_partitioning().partition_count() {
            0 => {
                let schema = Arc::new(
                    Schema::try_from(plan.schema())
                        .map_err(BoxedError::new)
                        .context(QueryExecutionSnafu)?,
                );
                Ok(Box::pin(EmptyRecordBatchStream::new(schema)))
            }
            1 => {
                let df_stream = plan
                    .execute(0, task_ctx)
                    .context(error::DatafusionSnafu)
                    .map_err(BoxedError::new)
                    .context(QueryExecutionSnafu)?;
                let mut stream = RecordBatchStreamAdapter::try_new(df_stream)
                    .context(error::ConvertDfRecordBatchStreamSnafu)
                    .map_err(BoxedError::new)
                    .context(QueryExecutionSnafu)?;
                stream.set_metrics2(plan.clone());
                stream.set_explain_verbose(ctx.query_ctx().explain_verbose());
                let stream = OnDone::new(Box::pin(stream), move || {
                    exec_timer.observe_duration();
                });
                Ok(Box::pin(stream))
            }
            _ => {
                let merged_plan = CoalescePartitionsExec::new(plan.clone());
                assert_eq!(
                    1,
                    merged_plan
                        .properties()
                        .output_partitioning()
                        .partition_count()
                );
                let df_stream = merged_plan
                    .execute(0, task_ctx)
                    .context(error::DatafusionSnafu)
                    .map_err(BoxedError::new)
                    .context(QueryExecutionSnafu)?;
                let mut stream = RecordBatchStreamAdapter::try_new(df_stream)
                    .context(error::ConvertDfRecordBatchStreamSnafu)
                    .map_err(BoxedError::new)
                    .context(QueryExecutionSnafu)?;
                stream.set_metrics2(plan.clone());
                stream.set_explain_verbose(ctx.query_ctx().explain_verbose());
                let stream = OnDone::new(Box::pin(stream), move || {
                    exec_timer.observe_duration();
                });
                Ok(Box::pin(stream))
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use catalog::RegisterTableRequest;
    use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, NUMBERS_TABLE_ID};
    use common_recordbatch::util;
    use datafusion::prelude::{col, lit};
    use datatypes::prelude::ConcreteDataType;
    use datatypes::schema::ColumnSchema;
    use datatypes::vectors::{Helper, UInt32Vector, UInt64Vector, VectorRef};
    use session::context::{QueryContext, QueryContextBuilder};
    use table::table::numbers::{NumbersTable, NUMBERS_TABLE_NAME};

    use super::*;
    use crate::options::QueryOptions;
    use crate::parser::QueryLanguageParser;
    use crate::query_engine::{QueryEngineFactory, QueryEngineRef};

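    /// Builds a test engine over an in-memory catalog with the `numbers` table
    /// registered in the default schema.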
    async fn create_test_engine() -> QueryEngineRef {
        let catalog_manager = catalog::memory::new_memory_catalog_manager().unwrap();
        let req = RegisterTableRequest {
            catalog: DEFAULT_CATALOG_NAME.to_string(),
            schema: DEFAULT_SCHEMA_NAME.to_string(),
            table_name: NUMBERS_TABLE_NAME.to_string(),
            table_id: NUMBERS_TABLE_ID,
            table: NumbersTable::table(NUMBERS_TABLE_ID),
        };
        catalog_manager.register_table_sync(req).unwrap();

        QueryEngineFactory::new(
            catalog_manager,
            None,
            None,
            None,
            None,
            false,
            QueryOptions::default(),
        )
        .query_engine()
    }

    #[tokio::test]
    async fn test_sql_to_plan() {
        let engine = create_test_engine().await;
        let sql = "select sum(number) from numbers limit 20";

        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
        let plan = engine
            .planner()
            .plan(&stmt, QueryContext::arc())
            .await
            .unwrap();

        assert_eq!(
            plan.to_string(),
            r#"Limit: skip=0, fetch=20
  Projection: sum(numbers.number)
    Aggregate: groupBy=[[]], aggr=[[sum(numbers.number)]]
      TableScan: numbers"#
        );
    }

    #[tokio::test]
    async fn test_execute() {
        let engine = create_test_engine().await;
        let sql = "select sum(number) from numbers limit 20";

        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
        let plan = engine
            .planner()
            .plan(&stmt, QueryContext::arc())
            .await
            .unwrap();

        let output = engine.execute(plan, QueryContext::arc()).await.unwrap();

        match output.data {
            OutputData::Stream(recordbatch) => {
                let numbers = util::collect(recordbatch).await.unwrap();
                assert_eq!(1, numbers.len());
                assert_eq!(numbers[0].num_columns(), 1);
                assert_eq!(1, numbers[0].schema.num_columns());
                assert_eq!(
                    "sum(numbers.number)",
                    numbers[0].schema.column_schemas()[0].name
                );

                let batch = &numbers[0];
                assert_eq!(1, batch.num_columns());
                assert_eq!(batch.column(0).len(), 1);

                assert_eq!(
                    *batch.column(0),
                    Arc::new(UInt64Vector::from_slice([4950])) as VectorRef
                );
            }
            _ => unreachable!(),
        }
    }

    #[tokio::test]
    async fn test_read_table() {
        let engine = create_test_engine().await;

        let engine = engine
            .as_any()
            .downcast_ref::<DatafusionQueryEngine>()
            .unwrap();
        let query_ctx = Arc::new(QueryContextBuilder::default().build());
        let table = engine
            .find_table(
                &ResolvedTableReference {
                    catalog: "greptime".into(),
                    schema: "public".into(),
                    table: "numbers".into(),
                },
                &query_ctx,
            )
            .await
            .unwrap();

        let DataFrame::DataFusion(df) = engine.read_table(table).unwrap();
        let df = df
            .select_columns(&["number"])
            .unwrap()
            .filter(col("number").lt(lit(10)))
            .unwrap();
        let batches = df.collect().await.unwrap();
        assert_eq!(1, batches.len());
        let batch = &batches[0];

        assert_eq!(1, batch.num_columns());
        assert_eq!(batch.column(0).len(), 10);

        assert_eq!(
            Helper::try_into_vector(batch.column(0)).unwrap(),
            Arc::new(UInt32Vector::from_slice([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])) as VectorRef
        );
    }

    #[tokio::test]
    async fn test_describe() {
        let engine = create_test_engine().await;
        let sql = "select sum(number) from numbers limit 20";

        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();

        let plan = engine
            .planner()
            .plan(&stmt, QueryContext::arc())
            .await
            .unwrap();

        let DescribeResult {
            schema,
            logical_plan,
        } = engine.describe(plan, QueryContext::arc()).await.unwrap();

        assert_eq!(
            schema.column_schemas()[0],
            ColumnSchema::new(
                "sum(numbers.number)",
                ConcreteDataType::uint64_datatype(),
                true
            )
        );
        assert_eq!("Limit: skip=0, fetch=20\n  Aggregate: groupBy=[[]], aggr=[[sum(CAST(numbers.number AS UInt64))]]\n    TableScan: numbers projection=[number]", format!("{}", logical_plan.display_indent()));
    }
}