1use std::any::Any;
16use std::borrow::Cow;
17use std::collections::{HashMap, HashSet};
18use std::str::FromStr;
19use std::sync::Arc;
20
21use arrow_schema::DataType;
22use async_trait::async_trait;
23use catalog::table_source::DfTableSourceProvider;
24use common_error::ext::BoxedError;
25use common_telemetry::tracing;
26use datafusion::common::{DFSchema, plan_err};
27use datafusion::execution::context::SessionState;
28use datafusion::sql::planner::PlannerContext;
29use datafusion_common::ToDFSchema;
30use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
31use datafusion_expr::expr::{Exists, InSubquery};
32use datafusion_expr::{
33 Analyze, Explain, ExplainFormat, Expr as DfExpr, LogicalPlan, LogicalPlanBuilder, PlanType,
34 ToStringifiedPlan, col,
35};
36use datafusion_sql::planner::{ParserOptions, SqlToRel};
37use log_query::LogQuery;
38use promql_parser::parser::EvalStmt;
39use session::context::QueryContextRef;
40use snafu::{ResultExt, ensure};
41use sql::CteContent;
42use sql::ast::Expr as SqlExpr;
43use sql::statements::explain::ExplainStatement;
44use sql::statements::query::Query;
45use sql::statements::statement::Statement;
46use sql::statements::tql::Tql;
47
48use crate::error::{
49 CteColumnSchemaMismatchSnafu, PlanSqlSnafu, QueryPlanSnafu, Result, SqlSnafu,
50 UnimplementedSnafu,
51};
52use crate::log_query::planner::LogQueryPlanner;
53use crate::parser::{DEFAULT_LOOKBACK_STRING, PromQuery, QueryLanguageParser, QueryStatement};
54use crate::promql::planner::PromPlanner;
55use crate::query_engine::{DefaultPlanDecoder, QueryEngineState};
56use crate::range_select::plan_rewrite::RangePlanRewriter;
57use crate::{DfContextProviderAdapter, QueryEngineContext};
58
/// Planner interface that turns parsed query statements (SQL or PromQL)
/// and log queries into DataFusion [`LogicalPlan`]s.
#[async_trait]
pub trait LogicalPlanner: Send + Sync {
    /// Plans a parsed [`QueryStatement`] into a [`LogicalPlan`].
    async fn plan(&self, stmt: &QueryStatement, query_ctx: QueryContextRef) -> Result<LogicalPlan>;

    /// Plans a [`LogQuery`] into a [`LogicalPlan`].
    async fn plan_logs_query(
        &self,
        query: LogQuery,
        query_ctx: QueryContextRef,
    ) -> Result<LogicalPlan>;

    /// Runs the planner's logical optimization passes over `plan`.
    fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan>;

    /// Returns `self` as [`Any`] so callers can downcast to a concrete
    /// planner implementation.
    fn as_any(&self) -> &dyn Any;
}
73
/// Default [`LogicalPlanner`] implementation built on DataFusion.
pub struct DfLogicalPlanner {
    // Shared engine-wide state (catalog manager, extension optimizer rules, ...).
    engine_state: Arc<QueryEngineState>,
    // DataFusion session state used to build parser options and planning contexts.
    session_state: SessionState,
}
78
79impl DfLogicalPlanner {
80 pub fn new(engine_state: Arc<QueryEngineState>) -> Self {
81 let session_state = engine_state.session_state();
82 Self {
83 engine_state,
84 session_state,
85 }
86 }
87
88 async fn explain_to_plan(
91 &self,
92 explain: &ExplainStatement,
93 query_ctx: QueryContextRef,
94 ) -> Result<LogicalPlan> {
95 let plan = self.plan_sql(&explain.statement, query_ctx).await?;
96 if matches!(plan, LogicalPlan::Explain(_)) {
97 return plan_err!("Nested EXPLAINs are not supported").context(PlanSqlSnafu);
98 }
99
100 let verbose = explain.verbose;
101 let analyze = explain.analyze;
102 let format = explain.format.map(|f| f.to_string());
103
104 let plan = Arc::new(plan);
105 let schema = LogicalPlan::explain_schema();
106 let schema = ToDFSchema::to_dfschema_ref(schema)?;
107
108 if verbose && format.is_some() {
109 return plan_err!("EXPLAIN VERBOSE with FORMAT is not supported").context(PlanSqlSnafu);
110 }
111
112 if analyze {
113 Ok(LogicalPlan::Analyze(Analyze {
115 verbose,
116 input: plan,
117 schema,
118 }))
119 } else {
120 let stringified_plans = vec![plan.to_stringified(PlanType::InitialLogicalPlan)];
121
122 let options = self.session_state.config().options();
124 let format = format
125 .map(|x| ExplainFormat::from_str(&x))
126 .transpose()?
127 .unwrap_or_else(|| options.explain.format.clone());
128
129 Ok(LogicalPlan::Explain(Explain {
130 verbose,
131 explain_format: format,
132 plan,
133 stringified_plans,
134 schema,
135 logical_optimization_succeeded: false,
136 }))
137 }
138 }
139
    /// Plans a SQL [`Statement`] into a [`LogicalPlan`].
    ///
    /// Special cases handled before delegating to DataFusion's `SqlToRel`:
    /// - `EXPLAIN` statements are routed to `explain_to_plan`;
    /// - queries with hybrid (TQL) CTEs have those CTEs pre-planned into the
    ///   shared [`PlannerContext`] first, then the query is re-planned with
    ///   the hybrid CTEs stripped;
    /// - DataFusion's raw `EXPLAIN` statement form is rejected as
    ///   unimplemented.
    ///
    /// The resulting plan is rewritten for range selects and then run through
    /// the engine's extension optimizer rules.
    #[tracing::instrument(skip_all)]
    #[async_recursion::async_recursion]
    async fn plan_sql(&self, stmt: &Statement, query_ctx: QueryContextRef) -> Result<LogicalPlan> {
        let mut planner_context = PlannerContext::new();
        // Cow lets us keep a borrow in the common case and only clone when the
        // statement must be rewritten (hybrid-CTE stripping below).
        let mut stmt = Cow::Borrowed(stmt);
        let mut is_tql_cte = false;

        if let Statement::Explain(explain) = stmt.as_ref() {
            return self.explain_to_plan(explain, query_ctx).await;
        }

        if self.has_hybrid_ctes(stmt.as_ref()) {
            let stmt_owned = stmt.into_owned();
            let mut query = match stmt_owned {
                Statement::Query(query) => query.as_ref().clone(),
                _ => unreachable!("has_hybrid_ctes should only return true for Query statements"),
            };
            // Phase 1: plan each TQL CTE and register it in the planner
            // context so the main query can reference it by name.
            self.plan_query_with_hybrid_ctes(&query, query_ctx.clone(), &mut planner_context)
                .await?;

            // Phase 2: strip the hybrid CTEs so the remaining query planning
            // only sees plain SQL.
            query.hybrid_cte = None;
            stmt = Cow::Owned(Statement::Query(Box::new(query)));
            is_tql_cte = true;
        }

        let mut df_stmt = stmt.as_ref().try_into().context(SqlSnafu)?;

        if let datafusion::sql::parser::Statement::Statement(
            box datafusion::sql::sqlparser::ast::Statement::Explain { .. },
        ) = &mut df_stmt
        {
            UnimplementedSnafu {
                operation: "EXPLAIN with FORMAT using raw datafusion planner",
            }
            .fail()?;
        }

        let table_provider = DfTableSourceProvider::new(
            self.engine_state.catalog_manager().clone(),
            self.engine_state.disallow_cross_catalog_query(),
            query_ctx.clone(),
            Arc::new(DefaultPlanDecoder::new(
                self.session_state.clone(),
                &query_ctx,
            )?),
            self.session_state
                .config_options()
                .sql_parser
                .enable_ident_normalization,
        );

        let context_provider = DfContextProviderAdapter::try_new(
            self.engine_state.clone(),
            self.session_state.clone(),
            Some(&df_stmt),
            query_ctx.clone(),
        )
        .await?;

        let config_options = self.session_state.config().options();
        let parser_options = &config_options.sql_parser;
        // Keep plain Utf8 for string types regardless of the session default.
        let parser_options = ParserOptions {
            map_string_types_to_utf8view: false,
            ..parser_options.into()
        };

        let sql_to_rel = SqlToRel::new_with_options(&context_provider, parser_options);

        let result = if is_tql_cte {
            // Plan via the sqlparser AST so the pre-planned CTEs in
            // `planner_context` are visible to the planner.
            let Statement::Query(query) = stmt.into_owned() else {
                unreachable!("is_tql_cte should only be true for Query statements");
            };
            let sqlparser_stmt = sqlparser::ast::Statement::Query(Box::new(query.inner));
            sql_to_rel
                .sql_statement_to_plan_with_context(sqlparser_stmt, &mut planner_context)
                .context(PlanSqlSnafu)?
        } else {
            sql_to_rel
                .statement_to_plan(df_stmt)
                .context(PlanSqlSnafu)?
        };

        common_telemetry::debug!("Logical planner, statement to plan result: {result}");
        // Expand range-select constructs into executable plan nodes.
        let plan = RangePlanRewriter::new(table_provider, query_ctx.clone())
            .rewrite(result)
            .await?;

        let context = QueryEngineContext::new(self.session_state.clone(), query_ctx);
        let plan = self
            .engine_state
            .optimize_by_extension_rules(plan, &context)?;
        common_telemetry::debug!("Logical planner, optimize result: {plan}");

        Ok(plan)
    }
241
242 #[tracing::instrument(skip_all)]
244 pub(crate) async fn sql_to_expr(
245 &self,
246 sql: SqlExpr,
247 schema: &DFSchema,
248 normalize_ident: bool,
249 query_ctx: QueryContextRef,
250 ) -> Result<DfExpr> {
251 let context_provider = DfContextProviderAdapter::try_new(
252 self.engine_state.clone(),
253 self.session_state.clone(),
254 None,
255 query_ctx,
256 )
257 .await?;
258
259 let config_options = self.session_state.config().options();
260 let parser_options = &config_options.sql_parser;
261 let parser_options: ParserOptions = ParserOptions {
262 map_string_types_to_utf8view: false,
263 enable_ident_normalization: normalize_ident,
264 ..parser_options.into()
265 };
266
267 let sql_to_rel = SqlToRel::new_with_options(&context_provider, parser_options);
268
269 Ok(sql_to_rel.sql_to_expr(sql, schema, &mut PlannerContext::new())?)
270 }
271
272 #[tracing::instrument(skip_all)]
273 async fn plan_pql(&self, stmt: &EvalStmt, query_ctx: QueryContextRef) -> Result<LogicalPlan> {
274 let plan_decoder = Arc::new(DefaultPlanDecoder::new(
275 self.session_state.clone(),
276 &query_ctx,
277 )?);
278 let table_provider = DfTableSourceProvider::new(
279 self.engine_state.catalog_manager().clone(),
280 self.engine_state.disallow_cross_catalog_query(),
281 query_ctx.clone(),
282 plan_decoder,
283 self.session_state
284 .config_options()
285 .sql_parser
286 .enable_ident_normalization,
287 );
288 let plan = PromPlanner::stmt_to_plan(table_provider, stmt, &self.engine_state)
289 .await
290 .map_err(BoxedError::new)
291 .context(QueryPlanSnafu)?;
292
293 let context = QueryEngineContext::new(self.session_state.clone(), query_ctx);
294 Ok(self
295 .engine_state
296 .optimize_by_extension_rules(plan, &context)?)
297 }
298
299 #[tracing::instrument(skip_all)]
300 fn optimize_logical_plan(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
301 Ok(self.engine_state.optimize_logical_plan(plan)?)
302 }
303
304 fn has_hybrid_ctes(&self, stmt: &Statement) -> bool {
306 if let Statement::Query(query) = stmt {
307 query
308 .hybrid_cte
309 .as_ref()
310 .map(|hybrid_cte| !hybrid_cte.cte_tables.is_empty())
311 .unwrap_or(false)
312 } else {
313 false
314 }
315 }
316
    /// Pre-plans every hybrid (TQL) CTE of `query` and registers the
    /// resulting plans in `planner_context` under the CTE's name, so the main
    /// query can later be planned against them.
    ///
    /// Callers must have checked `has_hybrid_ctes` first; the `unwrap` below
    /// relies on `hybrid_cte` being present.
    async fn plan_query_with_hybrid_ctes(
        &self,
        query: &Query,
        query_ctx: QueryContextRef,
        planner_context: &mut PlannerContext,
    ) -> Result<()> {
        let hybrid_cte = query.hybrid_cte.as_ref().unwrap();

        for cte in &hybrid_cte.cte_tables {
            match &cte.content {
                CteContent::Tql(tql) => {
                    let mut logical_plan = self.tql_to_logical_plan(tql, query_ctx.clone()).await?;
                    // An explicit column list (`WITH cte(a, b) AS ...`) must
                    // match the plan's arity; if so, project each output
                    // field under its declared alias.
                    if !cte.columns.is_empty() {
                        let schema = logical_plan.schema();
                        let schema_fields = schema.fields().to_vec();
                        ensure!(
                            schema_fields.len() == cte.columns.len(),
                            CteColumnSchemaMismatchSnafu {
                                cte_name: cte.name.value.clone(),
                                original: schema_fields
                                    .iter()
                                    .map(|field| field.name().clone())
                                    .collect::<Vec<_>>(),
                                expected: cte
                                    .columns
                                    .iter()
                                    .map(|column| column.to_string())
                                    .collect::<Vec<_>>(),
                            }
                        );
                        let aliases = cte
                            .columns
                            .iter()
                            .zip(schema_fields.iter())
                            .map(|(column, field)| col(field.name()).alias(column.to_string()));
                        logical_plan = LogicalPlanBuilder::from(logical_plan)
                            .project(aliases)
                            .context(PlanSqlSnafu)?
                            .build()
                            .context(PlanSqlSnafu)?;
                    }

                    // Alias the whole plan with the CTE name so column
                    // references like `cte_name.col` resolve.
                    logical_plan = LogicalPlan::SubqueryAlias(
                        datafusion_expr::SubqueryAlias::try_new(
                            Arc::new(logical_plan),
                            cte.name.value.clone(),
                        )
                        .context(PlanSqlSnafu)?,
                    );

                    planner_context.insert_cte(&cte.name.value, logical_plan);
                }
                CteContent::Sql(_) => {
                    // The parser routes SQL CTEs through the normal query
                    // body, never into `hybrid_cte.cte_tables`.
                    unreachable!("SQL CTEs should not be in hybrid_cte.cte_tables");
                }
            }
        }

        Ok(())
    }
382
383 async fn tql_to_logical_plan(
385 &self,
386 tql: &Tql,
387 query_ctx: QueryContextRef,
388 ) -> Result<LogicalPlan> {
389 match tql {
390 Tql::Eval(eval) => {
391 let prom_query = PromQuery {
393 query: eval.query.clone(),
394 start: eval.start.clone(),
395 end: eval.end.clone(),
396 step: eval.step.clone(),
397 lookback: eval
398 .lookback
399 .clone()
400 .unwrap_or_else(|| DEFAULT_LOOKBACK_STRING.to_string()),
401 alias: eval.alias.clone(),
402 };
403 let stmt = QueryLanguageParser::parse_promql(&prom_query, &query_ctx)?;
404
405 self.plan(&stmt, query_ctx).await
406 }
407 Tql::Explain(_) => UnimplementedSnafu {
408 operation: "TQL EXPLAIN in CTEs",
409 }
410 .fail(),
411 Tql::Analyze(_) => UnimplementedSnafu {
412 operation: "TQL ANALYZE in CTEs",
413 }
414 .fail(),
415 }
416 }
417
418 fn extract_placeholder_cast_types(
428 plan: &LogicalPlan,
429 ) -> Result<HashMap<String, Option<DataType>>> {
430 let mut placeholder_types = HashMap::new();
431 let mut casted_placeholders = HashSet::new();
432
433 Self::extract_from_plan(plan, &mut placeholder_types, &mut casted_placeholders)?;
434
435 Ok(placeholder_types)
436 }
437
438 fn extract_from_plan(
439 plan: &LogicalPlan,
440 placeholder_types: &mut HashMap<String, Option<DataType>>,
441 casted_placeholders: &mut HashSet<String>,
442 ) -> Result<()> {
443 plan.apply(|node| {
444 for expr in node.expressions() {
445 let _ = expr.apply(|e| {
446 if let DfExpr::Cast(cast) = e
448 && let DfExpr::Placeholder(ph) = &*cast.expr
449 {
450 placeholder_types.insert(ph.id.clone(), Some(cast.data_type.clone()));
451 casted_placeholders.insert(ph.id.clone());
452 }
453
454 if let DfExpr::Placeholder(ph) = e
456 && !casted_placeholders.contains(&ph.id)
457 && !placeholder_types.contains_key(&ph.id)
458 {
459 placeholder_types.insert(ph.id.clone(), None);
460 }
461
462 match e {
464 DfExpr::Exists(Exists { subquery, .. })
465 | DfExpr::InSubquery(InSubquery { subquery, .. })
466 | DfExpr::ScalarSubquery(subquery) => {
467 Self::extract_from_plan(
468 &subquery.subquery,
469 placeholder_types,
470 casted_placeholders,
471 )?;
472 }
473 _ => {}
474 }
475
476 Ok(TreeNodeRecursion::Continue)
477 });
478 }
479 Ok(TreeNodeRecursion::Continue)
480 })?;
481 Ok(())
482 }
483
484 pub fn get_inferred_parameter_types(
498 plan: &LogicalPlan,
499 ) -> Result<HashMap<String, Option<DataType>>> {
500 let param_types = plan.get_parameter_types().context(PlanSqlSnafu)?;
501
502 let has_none = param_types.values().any(|v| v.is_none());
503
504 if !has_none {
505 Ok(param_types)
506 } else {
507 let cast_types = Self::extract_placeholder_cast_types(plan)?;
508
509 let mut merged = param_types;
510
511 for (id, opt_type) in cast_types {
512 merged
513 .entry(id)
514 .and_modify(|existing| {
515 if existing.is_none() {
516 *existing = opt_type.clone();
517 }
518 })
519 .or_insert(opt_type);
520 }
521
522 Ok(merged)
523 }
524 }
525}
526
#[async_trait]
impl LogicalPlanner for DfLogicalPlanner {
    /// Dispatches to the SQL or PromQL planning path based on the statement
    /// kind.
    #[tracing::instrument(skip_all)]
    async fn plan(&self, stmt: &QueryStatement, query_ctx: QueryContextRef) -> Result<LogicalPlan> {
        match stmt {
            QueryStatement::Sql(stmt) => self.plan_sql(stmt, query_ctx).await,
            QueryStatement::Promql(stmt, _alias) => self.plan_pql(stmt, query_ctx).await,
        }
    }

    /// Plans a [`LogQuery`] via [`LogQueryPlanner`].
    async fn plan_logs_query(
        &self,
        query: LogQuery,
        query_ctx: QueryContextRef,
    ) -> Result<LogicalPlan> {
        let plan_decoder = Arc::new(DefaultPlanDecoder::new(
            self.session_state.clone(),
            &query_ctx,
        )?);
        let table_provider = DfTableSourceProvider::new(
            self.engine_state.catalog_manager().clone(),
            self.engine_state.disallow_cross_catalog_query(),
            query_ctx,
            plan_decoder,
            self.session_state
                .config_options()
                .sql_parser
                .enable_ident_normalization,
        );

        let mut planner = LogQueryPlanner::new(table_provider, self.session_state.clone());
        planner
            .query_to_plan(query)
            .await
            .map_err(BoxedError::new)
            .context(QueryPlanSnafu)
    }

    /// Runs the engine's standard logical optimizer passes.
    fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
        self.optimize_logical_plan(plan)
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}
573
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::DataType;
    use catalog::RegisterTableRequest;
    use catalog::memory::MemoryCatalogManager;
    use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
    use datatypes::prelude::ConcreteDataType;
    use datatypes::schema::{ColumnSchema, Schema};
    use session::context::QueryContext;
    use store_api::metric_engine_consts::{
        DATA_SCHEMA_TABLE_ID_COLUMN_NAME, DATA_SCHEMA_TSID_COLUMN_NAME, LOGICAL_TABLE_METADATA_KEY,
        METRIC_ENGINE_NAME,
    };
    use table::metadata::{TableInfoBuilder, TableMetaBuilder};
    use table::test_util::EmptyTable;

    use super::*;
    use crate::parser::{PromQuery, QueryLanguageParser};
    use crate::{QueryEngineFactory, QueryEngineRef};

    /// Builds a query engine over a single empty table
    /// `test(id INT32 NOT NULL, name STRING)` for SQL planning tests.
    async fn create_test_engine() -> QueryEngineRef {
        let columns = vec![
            ColumnSchema::new("id", ConcreteDataType::int32_datatype(), false),
            ColumnSchema::new("name", ConcreteDataType::string_datatype(), true),
        ];
        let schema = Arc::new(Schema::new(columns));
        let table_meta = TableMetaBuilder::empty()
            .schema(schema)
            .primary_key_indices(vec![0])
            .value_indices(vec![1])
            .next_column_id(1024)
            .build()
            .unwrap();
        let table_info = TableInfoBuilder::new("test", table_meta).build().unwrap();
        let table = EmptyTable::from_table_info(&table_info);

        crate::tests::new_query_engine_with_table(table)
    }

    /// Builds a query engine with a metric-engine physical table `phy` and a
    /// logical table `some_metric` registered on top of it, for PromQL tests.
    fn create_promql_test_engine() -> QueryEngineRef {
        let catalog_manager = MemoryCatalogManager::with_default_setup();
        let physical_table_name = "phy";
        let physical_table_id = 999u32;

        // Physical table schema: internal table-id/tsid columns, two tags,
        // the time index, and one field column.
        let physical_schema = Arc::new(Schema::new(vec![
            ColumnSchema::new(
                DATA_SCHEMA_TABLE_ID_COLUMN_NAME.to_string(),
                ConcreteDataType::uint32_datatype(),
                false,
            ),
            ColumnSchema::new(
                DATA_SCHEMA_TSID_COLUMN_NAME.to_string(),
                ConcreteDataType::uint64_datatype(),
                false,
            ),
            ColumnSchema::new("tag_0", ConcreteDataType::string_datatype(), false),
            ColumnSchema::new("tag_1", ConcreteDataType::string_datatype(), false),
            ColumnSchema::new(
                "timestamp",
                ConcreteDataType::timestamp_millisecond_datatype(),
                false,
            )
            .with_time_index(true),
            ColumnSchema::new("field_0", ConcreteDataType::float64_datatype(), true),
        ]));
        let physical_meta = TableMetaBuilder::empty()
            .schema(physical_schema)
            .primary_key_indices(vec![0, 1, 2, 3])
            .value_indices(vec![4, 5])
            .engine(METRIC_ENGINE_NAME.to_string())
            .next_column_id(1024)
            .build()
            .unwrap();
        let physical_info = TableInfoBuilder::default()
            .table_id(physical_table_id)
            .name(physical_table_name)
            .meta(physical_meta)
            .build()
            .unwrap();
        catalog_manager
            .register_table_sync(RegisterTableRequest {
                catalog: DEFAULT_CATALOG_NAME.to_string(),
                schema: DEFAULT_SCHEMA_NAME.to_string(),
                table_name: physical_table_name.to_string(),
                table_id: physical_table_id,
                table: EmptyTable::from_table_info(&physical_info),
            })
            .unwrap();

        // Mark the logical table as backed by the physical table above.
        let mut options = table::requests::TableOptions::default();
        options.extra_options.insert(
            LOGICAL_TABLE_METADATA_KEY.to_string(),
            physical_table_name.to_string(),
        );
        let logical_schema = Arc::new(Schema::new(vec![
            ColumnSchema::new("tag_0", ConcreteDataType::string_datatype(), false),
            ColumnSchema::new("tag_1", ConcreteDataType::string_datatype(), false),
            ColumnSchema::new(
                "timestamp",
                ConcreteDataType::timestamp_millisecond_datatype(),
                false,
            )
            .with_time_index(true),
            ColumnSchema::new("field_0", ConcreteDataType::float64_datatype(), true),
        ]));
        let logical_meta = TableMetaBuilder::empty()
            .schema(logical_schema)
            .primary_key_indices(vec![0, 1])
            .value_indices(vec![3])
            .engine(METRIC_ENGINE_NAME.to_string())
            .options(options)
            .next_column_id(1024)
            .build()
            .unwrap();
        let logical_info = TableInfoBuilder::default()
            .table_id(1024)
            .name("some_metric")
            .meta(logical_meta)
            .build()
            .unwrap();
        catalog_manager
            .register_table_sync(RegisterTableRequest {
                catalog: DEFAULT_CATALOG_NAME.to_string(),
                schema: DEFAULT_SCHEMA_NAME.to_string(),
                table_name: "some_metric".to_string(),
                table_id: 1024,
                table: EmptyTable::from_table_info(&logical_info),
            })
            .unwrap();

        QueryEngineFactory::new(
            catalog_manager,
            None,
            None,
            None,
            None,
            false,
            crate::options::QueryOptions::default(),
        )
        .query_engine()
    }

    /// Parses `sql` and plans it against the `test` table engine.
    async fn parse_sql_to_plan(sql: &str) -> LogicalPlan {
        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
        let engine = create_test_engine().await;
        engine
            .planner()
            .plan(&stmt, QueryContext::arc())
            .await
            .unwrap()
    }

    /// Parses a PromQL `query` over a fixed [0s, 10s] range with a 5s step
    /// and plans it against the metric-engine test setup.
    async fn parse_promql_to_plan(query: &str) -> LogicalPlan {
        let engine = create_promql_test_engine();
        let query_ctx = QueryContext::arc();
        let stmt = QueryLanguageParser::parse_promql(
            &PromQuery {
                query: query.to_string(),
                start: "0".to_string(),
                end: "10".to_string(),
                step: "5s".to_string(),
                lookback: "300s".to_string(),
                alias: None,
            },
            &query_ctx,
        )
        .unwrap();

        engine.planner().plan(&stmt, query_ctx).await.unwrap()
    }

    // Casts directly over placeholders should yield their target types;
    // uncast placeholders map to None.
    #[tokio::test]
    async fn test_extract_placeholder_cast_types_multiple() {
        let plan = parse_sql_to_plan(
            "SELECT $1::INT, $2::TEXT, $3, $4::INTEGER FROM test WHERE $5::FLOAT > 0",
        )
        .await;
        let types = DfLogicalPlanner::extract_placeholder_cast_types(&plan).unwrap();

        assert_eq!(types.len(), 5);
        assert_eq!(types.get("$1"), Some(&Some(DataType::Int32)));
        assert_eq!(types.get("$2"), Some(&Some(DataType::Utf8)));
        assert_eq!(types.get("$3"), Some(&None));
        assert_eq!(types.get("$4"), Some(&Some(DataType::Int32)));
        assert_eq!(types.get("$5"), Some(&Some(DataType::Float32)));
    }

    // When DataFusion can't infer a UDF argument's type, cast fallback fills
    // in what it can; fully uncast UDF args stay None.
    #[tokio::test]
    async fn test_get_inferred_parameter_types_fallback_for_udf_args() {
        let plan = parse_sql_to_plan(
            "SELECT parse_ident($1), parse_ident($2::TEXT) FROM test WHERE id > $3",
        )
        .await;
        let types = DfLogicalPlanner::get_inferred_parameter_types(&plan).unwrap();

        assert_eq!(types.len(), 3);

        let type_1 = types.get("$1").unwrap();
        let type_2 = types.get("$2").unwrap();
        let type_3 = types.get("$3").unwrap();

        assert!(type_1.is_none(), "Expected $1 to be None");
        assert_eq!(type_2, &Some(DataType::Utf8));
        assert_eq!(type_3, &Some(DataType::Int32));
    }

    // The extension rule should rewrite scalar(count(<agg> by (...))) into a
    // Distinct-based plan for each supported inner aggregate.
    #[tokio::test]
    async fn test_plan_pql_applies_extension_rules() {
        for inner_agg in ["count", "sum", "avg", "min", "max", "stddev", "stdvar"] {
            let plan = parse_promql_to_plan(&format!(
                "sum(irate(some_metric[1h])) / scalar(count({inner_agg}(some_metric) by (tag_0)))"
            ))
            .await;
            let plan_str = plan.display_indent_schema().to_string();
            assert!(plan_str.contains("Distinct:"), "{inner_agg}: {plan_str}");
        }
    }

    // Non-count inner aggregates need a NULL filter on the field column;
    // count does not.
    #[tokio::test]
    async fn test_plan_pql_filters_null_only_groups_for_non_count_inner_aggs() {
        let count_plan = parse_promql_to_plan("scalar(count(count(some_metric) by (tag_0)))").await;
        let count_plan_str = count_plan.display_indent_schema().to_string();
        assert!(
            !count_plan_str.contains("field_0 IS NOT NULL"),
            "{count_plan_str}"
        );

        for inner_agg in ["sum", "avg", "min", "max", "stddev", "stdvar"] {
            let plan = parse_promql_to_plan(&format!(
                "scalar(count({inner_agg}(some_metric) by (tag_0)))"
            ))
            .await;
            let plan_str = plan.display_indent_schema().to_string();
            assert!(
                plan_str.contains("field_0 IS NOT NULL"),
                "{inner_agg}: {plan_str}"
            );
        }
    }

    // Rewrites must not fire when the inner aggregate isn't a direct column
    // aggregate or isn't a supported kind.
    #[tokio::test]
    async fn test_plan_pql_skips_extension_rules_for_non_direct_or_unsupported_inner_agg() {
        for query in [
            "sum(irate(some_metric[1h])) / scalar(count(sum(irate(some_metric[1h])) by (tag_0)))",
            "sum(irate(some_metric[1h])) / scalar(count(group(some_metric) by (tag_0)))",
        ] {
            let plan = parse_promql_to_plan(query).await;
            let plan_str = plan.display_indent_schema().to_string();
            assert!(!plan_str.contains("Distinct:"), "{query}: {plan_str}");
        }
    }

    // The nested-count rewrite is PromQL-specific and must not trigger on
    // plain SQL plans with similar shapes.
    #[tokio::test]
    async fn test_plan_sql_does_not_apply_nested_count_rule() {
        let plan = parse_sql_to_plan(
            "SELECT id, count(inner_count) \
            FROM ( \
                SELECT id, count(name) AS inner_count \
                FROM test \
                GROUP BY id \
                ORDER BY id \
                LIMIT 1000000 \
            ) t \
            GROUP BY id \
            ORDER BY id",
        )
        .await;

        let plan_str = plan.display_indent_schema().to_string();
        assert!(!plan_str.contains("Distinct:"), "{plan_str}");
    }

    // Cast fallback must look inside subquery plans too.
    #[tokio::test]
    async fn test_get_inferred_parameter_types_subquery() {
        let plan = parse_sql_to_plan(
            r#"SELECT * FROM test WHERE id = (SELECT id FROM test CROSS JOIN (SELECT parse_ident($1::TEXT) AS parts) p LIMIT 1)"#,
        ).await;
        let types = DfLogicalPlanner::get_inferred_parameter_types(&plan).unwrap();

        assert_eq!(types.len(), 1);
        let type_1 = types.get("$1").unwrap();
        assert_eq!(type_1, &Some(DataType::Utf8));
    }

    // INSERT placeholders should be typed from the target column schema.
    #[tokio::test]
    async fn test_get_inferred_parameter_types_insert() {
        let plan = parse_sql_to_plan("INSERT INTO test (id, name) VALUES ($1, $2), ($3, $4)").await;
        let types = DfLogicalPlanner::get_inferred_parameter_types(&plan).unwrap();

        assert_eq!(types.len(), 4);
        assert_eq!(types.get("$1"), Some(&Some(DataType::Int32)));
        assert_eq!(types.get("$2"), Some(&Some(DataType::Utf8)));
        assert_eq!(types.get("$3"), Some(&Some(DataType::Int32)));
        assert_eq!(types.get("$4"), Some(&Some(DataType::Utf8)));
    }
}