mod error;
mod planner;

use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;

use async_trait::async_trait;
use common_base::Plugins;
use common_catalog::consts::is_readonly_schema;
use common_error::ext::BoxedError;
use common_function::function::FunctionRef;
use common_function::scalars::aggregate::AggregateFunctionMetaRef;
use common_query::{Output, OutputData, OutputMeta};
use common_recordbatch::adapter::RecordBatchStreamAdapter;
use common_recordbatch::{EmptyRecordBatchStream, SendableRecordBatchStream};
use common_telemetry::tracing;
use datafusion::physical_plan::analyze::AnalyzeExec;
use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion::physical_plan::ExecutionPlan;
use datafusion_common::ResolvedTableReference;
use datafusion_expr::{DmlStatement, LogicalPlan as DfLogicalPlan, LogicalPlan, WriteOp};
use datatypes::prelude::VectorRef;
use datatypes::schema::Schema;
use futures_util::StreamExt;
use session::context::QueryContextRef;
use snafu::{ensure, OptionExt, ResultExt};
use sqlparser::ast::AnalyzeFormat;
use table::requests::{DeleteRequest, InsertRequest};
use table::TableRef;

use crate::analyze::DistAnalyzeExec;
use crate::dataframe::DataFrame;
pub use crate::datafusion::planner::DfContextProviderAdapter;
use crate::dist_plan::MergeScanLogicalPlan;
use crate::error::{
    CatalogSnafu, ConvertSchemaSnafu, CreateRecordBatchSnafu, MissingTableMutationHandlerSnafu,
    MissingTimestampColumnSnafu, QueryExecutionSnafu, Result, TableMutationSnafu,
    TableNotFoundSnafu, TableReadOnlySnafu, UnsupportedExprSnafu,
};
use crate::executor::QueryExecutor;
use crate::metrics::{OnDone, QUERY_STAGE_ELAPSED};
use crate::physical_wrapper::PhysicalPlanWrapperRef;
use crate::planner::{DfLogicalPlanner, LogicalPlanner};
use crate::query_engine::{DescribeResult, QueryEngineContext, QueryEngineState};
use crate::{metrics, QueryEngine};

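/// [QueryEngine] implementation backed by Apache DataFusion. It holds the
/// shared [QueryEngineState] plus the engine [Plugins], and drives planning,
/// optimization, and execution for both queries and DML.
///
/// A minimal usage sketch (assuming an engine built through
/// [QueryEngineFactory](crate::query_engine::QueryEngineFactory), as in the
/// tests below; the elided arguments are environment-specific):
///
/// ```ignore
/// let engine = QueryEngineFactory::new(/* ... */).query_engine();
/// let plan = engine.planner().plan(&stmt, query_ctx.clone()).await?;
/// let output = engine.execute(plan, query_ctx).await?;
/// ```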
pub struct DatafusionQueryEngine {
    state: Arc<QueryEngineState>,
    plugins: Plugins,
}

impl DatafusionQueryEngine {
    pub fn new(state: Arc<QueryEngineState>, plugins: Plugins) -> Self {
        Self { state, plugins }
    }

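    /// Executes a (read-only) query plan: lowers it to a physical plan,
    /// optimizes it, optionally wraps it via the [PhysicalPlanWrapperRef]
    /// plugin, and returns the result as a record batch stream.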
    #[tracing::instrument(skip_all)]
    async fn exec_query_plan(
        &self,
        plan: LogicalPlan,
        query_ctx: QueryContextRef,
    ) -> Result<Output> {
        let mut ctx = self.engine_context(query_ctx.clone());

        let physical_plan = self.create_physical_plan(&mut ctx, &plan).await?;
        let optimized_physical_plan = self.optimize_physical_plan(&mut ctx, physical_plan)?;

        let physical_plan = if let Some(wrapper) = self.plugins.get::<PhysicalPlanWrapperRef>() {
            wrapper.wrap(optimized_physical_plan, query_ctx)
        } else {
            optimized_physical_plan
        };

        Ok(Output::new(
            OutputData::Stream(self.execute_stream(&ctx, &physical_plan)?),
            OutputMeta::new_with_plan(physical_plan),
        ))
    }

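    /// Executes an `INSERT` or `DELETE` statement by running its input plan
    /// and feeding the resulting batches to the table mutation handler. Any
    /// other write operation is rejected as unsupported.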
    #[tracing::instrument(skip_all)]
    async fn exec_dml_statement(
        &self,
        dml: DmlStatement,
        query_ctx: QueryContextRef,
    ) -> Result<Output> {
        ensure!(
            matches!(dml.op, WriteOp::Insert(_) | WriteOp::Delete),
            UnsupportedExprSnafu {
                name: format!("DML op {}", dml.op),
            }
        );

        let _timer = QUERY_STAGE_ELAPSED
            .with_label_values(&[dml.op.name()])
            .start_timer();

        let default_catalog = &query_ctx.current_catalog().to_owned();
        let default_schema = &query_ctx.current_schema();
        let table_name = dml.table_name.resolve(default_catalog, default_schema);
        let table = self.find_table(&table_name, &query_ctx).await?;

        let output = self
            .exec_query_plan((*dml.input).clone(), query_ctx.clone())
            .await?;
        let mut stream = match output.data {
            OutputData::RecordBatches(batches) => batches.as_stream(),
            OutputData::Stream(stream) => stream,
            _ => unreachable!(),
        };

        let mut affected_rows = 0;
        let mut insert_cost = 0;

        while let Some(batch) = stream.next().await {
            let batch = batch.context(CreateRecordBatchSnafu)?;
            let column_vectors = batch
                .column_vectors(&table_name.to_string(), table.schema())
                .map_err(BoxedError::new)
                .context(QueryExecutionSnafu)?;

            match dml.op {
                WriteOp::Insert(_) => {
                    let output = self
                        .insert(&table_name, column_vectors, query_ctx.clone())
                        .await?;
                    let (rows, cost) = output.extract_rows_and_cost();
                    affected_rows += rows;
                    insert_cost += cost;
                }
                WriteOp::Delete => {
                    affected_rows += self
                        .delete(&table_name, &table, column_vectors, query_ctx.clone())
                        .await?;
                }
                _ => unreachable!("guarded by the 'ensure!' at the beginning"),
            }
        }
        Ok(Output::new(
            OutputData::AffectedRows(affected_rows),
            OutputMeta::new_with_cost(insert_cost),
        ))
    }

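    /// Deletes rows from a table, keying on the timestamp and row-key
    /// columns; all other columns are filtered out of the request. Writes to
    /// read-only schemas are rejected.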
    #[tracing::instrument(skip_all)]
    async fn delete(
        &self,
        table_name: &ResolvedTableReference,
        table: &TableRef,
        column_vectors: HashMap<String, VectorRef>,
        query_ctx: QueryContextRef,
    ) -> Result<usize> {
        let catalog_name = table_name.catalog.to_string();
        let schema_name = table_name.schema.to_string();
        let table_name = table_name.table.to_string();
        let table_schema = table.schema();

        ensure!(
            !is_readonly_schema(&schema_name),
            TableReadOnlySnafu { table: table_name }
        );

        let ts_column = table_schema
            .timestamp_column()
            .map(|x| &x.name)
            .with_context(|| MissingTimestampColumnSnafu {
                table_name: table_name.to_string(),
            })?;

        let table_info = table.table_info();
        let rowkey_columns = table_info
            .meta
            .row_key_column_names()
            .collect::<Vec<&String>>();
        let column_vectors = column_vectors
            .into_iter()
            .filter(|x| &x.0 == ts_column || rowkey_columns.contains(&&x.0))
            .collect::<HashMap<_, _>>();

        let request = DeleteRequest {
            catalog_name,
            schema_name,
            table_name,
            key_column_values: column_vectors,
        };

        self.state
            .table_mutation_handler()
            .context(MissingTableMutationHandlerSnafu)?
            .delete(request, query_ctx)
            .await
            .context(TableMutationSnafu)
    }

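    /// Inserts the given column vectors into a table through the table
    /// mutation handler. Writes to read-only schemas are rejected.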
    #[tracing::instrument(skip_all)]
    async fn insert(
        &self,
        table_name: &ResolvedTableReference,
        column_vectors: HashMap<String, VectorRef>,
        query_ctx: QueryContextRef,
    ) -> Result<Output> {
        let catalog_name = table_name.catalog.to_string();
        let schema_name = table_name.schema.to_string();
        let table_name = table_name.table.to_string();

        ensure!(
            !is_readonly_schema(&schema_name),
            TableReadOnlySnafu { table: table_name }
        );

        let request = InsertRequest {
            catalog_name,
            schema_name,
            table_name,
            columns_values: column_vectors,
        };

        self.state
            .table_mutation_handler()
            .context(MissingTableMutationHandlerSnafu)?
            .insert(request, query_ctx)
            .await
            .context(TableMutationSnafu)
    }

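    /// Resolves a fully qualified table reference through the catalog
    /// manager.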
    async fn find_table(
        &self,
        table_name: &ResolvedTableReference,
        query_context: &QueryContextRef,
    ) -> Result<TableRef> {
        let catalog_name = table_name.catalog.as_ref();
        let schema_name = table_name.schema.as_ref();
        let table_name = table_name.table.as_ref();

        self.state
            .catalog_manager()
            .table(catalog_name, schema_name, table_name, Some(query_context))
            .await
            .context(CatalogSnafu)?
            .with_context(|| TableNotFoundSnafu { table: table_name })
    }

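    /// Lowers a logical plan into a physical plan. `EXPLAIN` plans are
    /// handed to DataFusion directly; everything else is analyzed first and,
    /// unless the plan is already a [MergeScanLogicalPlan] extension node,
    /// run through the logical optimizer before physical planning.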
    #[tracing::instrument(skip_all)]
    async fn create_physical_plan(
        &self,
        ctx: &mut QueryEngineContext,
        logical_plan: &LogicalPlan,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let _timer = metrics::CREATE_PHYSICAL_ELAPSED.start_timer();
        let state = ctx.state();

        common_telemetry::debug!("Create physical plan, input plan: {logical_plan}");

        if matches!(logical_plan, DfLogicalPlan::Explain(_)) {
            return state
                .create_physical_plan(logical_plan)
                .await
                .context(error::DatafusionSnafu)
                .map_err(BoxedError::new)
                .context(QueryExecutionSnafu);
        }

        let analyzed_plan = state
            .analyzer()
            .execute_and_check(logical_plan.clone(), state.config_options(), |_, _| {})
            .context(error::DatafusionSnafu)
            .map_err(BoxedError::new)
            .context(QueryExecutionSnafu)?;

        common_telemetry::debug!("Create physical plan, analyzed plan: {analyzed_plan}");

        let optimized_plan = if let DfLogicalPlan::Extension(ext) = &analyzed_plan
            && ext.node.name() == MergeScanLogicalPlan::name()
        {
            analyzed_plan.clone()
        } else {
            state
                .optimizer()
                .optimize(analyzed_plan, state, |_, _| {})
                .context(error::DatafusionSnafu)
                .map_err(BoxedError::new)
                .context(QueryExecutionSnafu)?
        };

        common_telemetry::debug!("Create physical plan, optimized plan: {optimized_plan}");

        let physical_plan = state
            .query_planner()
            .create_physical_plan(&optimized_plan, state)
            .await?;

        Ok(physical_plan)
    }

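    /// Optimizes a logical plan with this engine's extension rules, then
    /// with the DataFusion session optimizer.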
    #[tracing::instrument(skip_all)]
    pub fn optimize(
        &self,
        context: &QueryEngineContext,
        plan: &LogicalPlan,
    ) -> Result<LogicalPlan> {
        let _timer = metrics::OPTIMIZE_LOGICAL_ELAPSED.start_timer();

        let optimized_plan = self
            .state
            .optimize_by_extension_rules(plan.clone(), context)
            .context(error::DatafusionSnafu)
            .map_err(BoxedError::new)
            .context(QueryExecutionSnafu)?;

        let optimized_plan = self
            .state
            .session_state()
            .optimize(&optimized_plan)
            .context(error::DatafusionSnafu)
            .map_err(BoxedError::new)
            .context(QueryExecutionSnafu)?;

        Ok(optimized_plan)
    }

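    /// Rewrites a top-level [AnalyzeExec] into a [DistAnalyzeExec], choosing
    /// `JSON` or `TEXT` output according to the query context's explain
    /// format; any other plan is passed through unchanged.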
    #[tracing::instrument(skip_all)]
    fn optimize_physical_plan(
        &self,
        ctx: &mut QueryEngineContext,
        plan: Arc<dyn ExecutionPlan>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let _timer = metrics::OPTIMIZE_PHYSICAL_ELAPSED.start_timer();

        let optimized_plan = if let Some(analyze_plan) = plan.as_any().downcast_ref::<AnalyzeExec>()
        {
            let format = if let Some(format) = ctx.query_ctx().explain_format()
                && format.to_lowercase() == "json"
            {
                AnalyzeFormat::JSON
            } else {
                AnalyzeFormat::TEXT
            };
            ctx.query_ctx().set_explain_verbose(analyze_plan.verbose());

            Arc::new(DistAnalyzeExec::new(
                analyze_plan.input().clone(),
                analyze_plan.verbose(),
                format,
            ))
        } else {
            plan
        };

        Ok(optimized_plan)
    }
}

#[async_trait]
impl QueryEngine for DatafusionQueryEngine {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn planner(&self) -> Arc<dyn LogicalPlanner> {
        Arc::new(DfLogicalPlanner::new(self.state.clone()))
    }

    fn name(&self) -> &str {
        "datafusion"
    }

    async fn describe(
        &self,
        plan: LogicalPlan,
        query_ctx: QueryContextRef,
    ) -> Result<DescribeResult> {
        let ctx = self.engine_context(query_ctx);
        if let Ok(optimised_plan) = self.optimize(&ctx, &plan) {
            let schema = optimised_plan
                .schema()
                .clone()
                .try_into()
                .context(ConvertSchemaSnafu)?;
            Ok(DescribeResult {
                schema,
                logical_plan: optimised_plan,
            })
        } else {
            let schema = plan
                .schema()
                .clone()
                .try_into()
                .context(ConvertSchemaSnafu)?;
            Ok(DescribeResult {
                schema,
                logical_plan: plan,
            })
        }
    }

    async fn execute(&self, plan: LogicalPlan, query_ctx: QueryContextRef) -> Result<Output> {
        match plan {
            LogicalPlan::Dml(dml) => self.exec_dml_statement(dml, query_ctx).await,
            _ => self.exec_query_plan(plan, query_ctx).await,
        }
    }

    fn register_aggregate_function(&self, func: AggregateFunctionMetaRef) {
        self.state.register_aggregate_function(func);
    }

    fn register_function(&self, func: FunctionRef) {
        self.state.register_function(func);
    }

    fn read_table(&self, table: TableRef) -> Result<DataFrame> {
        Ok(DataFrame::DataFusion(
            self.state
                .read_table(table)
                .context(error::DatafusionSnafu)
                .map_err(BoxedError::new)
                .context(QueryExecutionSnafu)?,
        ))
    }

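    /// Builds a per-query [QueryEngineContext], registering the query
    /// context as a session config extension so downstream components can
    /// retrieve it.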
    fn engine_context(&self, query_ctx: QueryContextRef) -> QueryEngineContext {
        let mut state = self.state.session_state();
        state.config_mut().set_extension(query_ctx.clone());
        QueryEngineContext::new(state, query_ctx)
    }

    fn engine_state(&self) -> &QueryEngineState {
        &self.state
    }
}

impl QueryExecutor for DatafusionQueryEngine {
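    /// Executes a physical plan as one record batch stream: zero output
    /// partitions yield an empty stream, a single partition executes
    /// directly, and multiple partitions are merged through
    /// [CoalescePartitionsExec] first.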
    #[tracing::instrument(skip_all)]
    fn execute_stream(
        &self,
        ctx: &QueryEngineContext,
        plan: &Arc<dyn ExecutionPlan>,
    ) -> Result<SendableRecordBatchStream> {
        let exec_timer = metrics::EXEC_PLAN_ELAPSED.start_timer();
        let task_ctx = ctx.build_task_ctx();

        match plan.properties().output_partitioning().partition_count() {
            0 => {
                let schema = Arc::new(
                    Schema::try_from(plan.schema())
                        .map_err(BoxedError::new)
                        .context(QueryExecutionSnafu)?,
                );
                Ok(Box::pin(EmptyRecordBatchStream::new(schema)))
            }
            1 => {
                let df_stream = plan
                    .execute(0, task_ctx)
                    .context(error::DatafusionSnafu)
                    .map_err(BoxedError::new)
                    .context(QueryExecutionSnafu)?;
                let mut stream = RecordBatchStreamAdapter::try_new(df_stream)
                    .context(error::ConvertDfRecordBatchStreamSnafu)
                    .map_err(BoxedError::new)
                    .context(QueryExecutionSnafu)?;
                stream.set_metrics2(plan.clone());
                stream.set_explain_verbose(ctx.query_ctx().explain_verbose());
                let stream = OnDone::new(Box::pin(stream), move || {
                    exec_timer.observe_duration();
                });
                Ok(Box::pin(stream))
            }
            _ => {
                let merged_plan = CoalescePartitionsExec::new(plan.clone());
                assert_eq!(
                    1,
                    merged_plan
                        .properties()
                        .output_partitioning()
                        .partition_count()
                );
                let df_stream = merged_plan
                    .execute(0, task_ctx)
                    .context(error::DatafusionSnafu)
                    .map_err(BoxedError::new)
                    .context(QueryExecutionSnafu)?;
                let mut stream = RecordBatchStreamAdapter::try_new(df_stream)
                    .context(error::ConvertDfRecordBatchStreamSnafu)
                    .map_err(BoxedError::new)
                    .context(QueryExecutionSnafu)?;
                stream.set_metrics2(plan.clone());
                stream.set_explain_verbose(ctx.query_ctx().explain_verbose());
                let stream = OnDone::new(Box::pin(stream), move || {
                    exec_timer.observe_duration();
                });
                Ok(Box::pin(stream))
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use catalog::RegisterTableRequest;
    use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, NUMBERS_TABLE_ID};
    use common_recordbatch::util;
    use datafusion::prelude::{col, lit};
    use datatypes::prelude::ConcreteDataType;
    use datatypes::schema::ColumnSchema;
    use datatypes::vectors::{Helper, UInt32Vector, UInt64Vector, VectorRef};
    use session::context::{QueryContext, QueryContextBuilder};
    use table::table::numbers::{NumbersTable, NUMBERS_TABLE_NAME};

    use super::*;
    use crate::options::QueryOptions;
    use crate::parser::QueryLanguageParser;
    use crate::query_engine::{QueryEngineFactory, QueryEngineRef};

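    /// Builds a query engine over an in-memory catalog with the `numbers`
    /// test table registered.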
    async fn create_test_engine() -> QueryEngineRef {
        let catalog_manager = catalog::memory::new_memory_catalog_manager().unwrap();
        let req = RegisterTableRequest {
            catalog: DEFAULT_CATALOG_NAME.to_string(),
            schema: DEFAULT_SCHEMA_NAME.to_string(),
            table_name: NUMBERS_TABLE_NAME.to_string(),
            table_id: NUMBERS_TABLE_ID,
            table: NumbersTable::table(NUMBERS_TABLE_ID),
        };
        catalog_manager.register_table_sync(req).unwrap();

        QueryEngineFactory::new(
            catalog_manager,
            None,
            None,
            None,
            None,
            false,
            QueryOptions::default(),
        )
        .query_engine()
    }

    #[tokio::test]
    async fn test_sql_to_plan() {
        let engine = create_test_engine().await;
        let sql = "select sum(number) from numbers limit 20";

        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
        let plan = engine
            .planner()
            .plan(&stmt, QueryContext::arc())
            .await
            .unwrap();

        assert_eq!(
            plan.to_string(),
            r#"Limit: skip=0, fetch=20
  Projection: sum(numbers.number)
    Aggregate: groupBy=[[]], aggr=[[sum(numbers.number)]]
      TableScan: numbers"#
        );
    }

    #[tokio::test]
    async fn test_execute() {
        let engine = create_test_engine().await;
        let sql = "select sum(number) from numbers limit 20";

        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
        let plan = engine
            .planner()
            .plan(&stmt, QueryContext::arc())
            .await
            .unwrap();

        let output = engine.execute(plan, QueryContext::arc()).await.unwrap();

        match output.data {
            OutputData::Stream(recordbatch) => {
                let numbers = util::collect(recordbatch).await.unwrap();
                assert_eq!(1, numbers.len());
                assert_eq!(numbers[0].num_columns(), 1);
                assert_eq!(1, numbers[0].schema.num_columns());
                assert_eq!(
                    "sum(numbers.number)",
                    numbers[0].schema.column_schemas()[0].name
                );

                let batch = &numbers[0];
                assert_eq!(1, batch.num_columns());
                assert_eq!(batch.column(0).len(), 1);

                assert_eq!(
                    *batch.column(0),
                    Arc::new(UInt64Vector::from_slice([4950])) as VectorRef
                );
            }
            _ => unreachable!(),
        }
    }

    #[tokio::test]
    async fn test_read_table() {
        let engine = create_test_engine().await;

        let engine = engine
            .as_any()
            .downcast_ref::<DatafusionQueryEngine>()
            .unwrap();
        let query_ctx = Arc::new(QueryContextBuilder::default().build());
        let table = engine
            .find_table(
                &ResolvedTableReference {
                    catalog: "greptime".into(),
                    schema: "public".into(),
                    table: "numbers".into(),
                },
                &query_ctx,
            )
            .await
            .unwrap();

        let DataFrame::DataFusion(df) = engine.read_table(table).unwrap();
        let df = df
            .select_columns(&["number"])
            .unwrap()
            .filter(col("number").lt(lit(10)))
            .unwrap();
        let batches = df.collect().await.unwrap();
        assert_eq!(1, batches.len());
        let batch = &batches[0];

        assert_eq!(1, batch.num_columns());
        assert_eq!(batch.column(0).len(), 10);

        assert_eq!(
            Helper::try_into_vector(batch.column(0)).unwrap(),
            Arc::new(UInt32Vector::from_slice([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])) as VectorRef
        );
    }

    #[tokio::test]
    async fn test_describe() {
        let engine = create_test_engine().await;
        let sql = "select sum(number) from numbers limit 20";

        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();

        let plan = engine
            .planner()
            .plan(&stmt, QueryContext::arc())
            .await
            .unwrap();

        let DescribeResult {
            schema,
            logical_plan,
        } = engine.describe(plan, QueryContext::arc()).await.unwrap();

        assert_eq!(
            schema.column_schemas()[0],
            ColumnSchema::new(
                "sum(numbers.number)",
                ConcreteDataType::uint64_datatype(),
                true
            )
        );
        assert_eq!(
            "Limit: skip=0, fetch=20\n  Aggregate: groupBy=[[]], aggr=[[sum(CAST(numbers.number AS UInt64))]]\n    TableScan: numbers projection=[number]",
            format!("{}", logical_plan.display_indent())
        );
    }
}
724}