// query/planner.rs
1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::any::Any;
16use std::borrow::Cow;
17use std::collections::{HashMap, HashSet};
18use std::str::FromStr;
19use std::sync::Arc;
20
21use arrow_schema::DataType;
22use async_trait::async_trait;
23use catalog::table_source::DfTableSourceProvider;
24use common_error::ext::BoxedError;
25use common_telemetry::tracing;
26use datafusion::common::{DFSchema, plan_err};
27use datafusion::execution::context::SessionState;
28use datafusion::sql::planner::PlannerContext;
29use datafusion_common::ToDFSchema;
30use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
31use datafusion_expr::expr::{Exists, InSubquery};
32use datafusion_expr::{
33    Analyze, Explain, ExplainFormat, Expr as DfExpr, LogicalPlan, LogicalPlanBuilder, PlanType,
34    ToStringifiedPlan, col,
35};
36use datafusion_sql::planner::{ParserOptions, SqlToRel};
37use log_query::LogQuery;
38use promql_parser::parser::EvalStmt;
39use session::context::QueryContextRef;
40use snafu::{ResultExt, ensure};
41use sql::CteContent;
42use sql::ast::Expr as SqlExpr;
43use sql::statements::explain::ExplainStatement;
44use sql::statements::query::Query;
45use sql::statements::statement::Statement;
46use sql::statements::tql::Tql;
47
48use crate::error::{
49    CteColumnSchemaMismatchSnafu, PlanSqlSnafu, QueryPlanSnafu, Result, SqlSnafu,
50    UnimplementedSnafu,
51};
52use crate::log_query::planner::LogQueryPlanner;
53use crate::parser::{DEFAULT_LOOKBACK_STRING, PromQuery, QueryLanguageParser, QueryStatement};
54use crate::promql::planner::PromPlanner;
55use crate::query_engine::{DefaultPlanDecoder, QueryEngineState};
56use crate::range_select::plan_rewrite::RangePlanRewriter;
57use crate::{DfContextProviderAdapter, QueryEngineContext};
58
/// High-level planner interface: turns parsed statements (SQL or PromQL) and
/// log queries into DataFusion [`LogicalPlan`]s, and optimizes them.
#[async_trait]
pub trait LogicalPlanner: Send + Sync {
    /// Plans a parsed query statement (SQL or PromQL) into a logical plan.
    async fn plan(&self, stmt: &QueryStatement, query_ctx: QueryContextRef) -> Result<LogicalPlan>;

    /// Plans a log query into a logical plan.
    async fn plan_logs_query(
        &self,
        query: LogQuery,
        query_ctx: QueryContextRef,
    ) -> Result<LogicalPlan>;

    /// Optimizes `plan` with the engine's logical optimizer rules.
    fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan>;

    /// Returns `self` as [`Any`] so callers can downcast to a concrete planner.
    fn as_any(&self) -> &dyn Any;
}
73
/// Default [`LogicalPlanner`] implementation backed by DataFusion.
pub struct DfLogicalPlanner {
    // Shared engine state (catalog manager, optimizer/extension rules, ...).
    engine_state: Arc<QueryEngineState>,
    // DataFusion session state snapshot used for planning.
    session_state: SessionState,
}
78
79impl DfLogicalPlanner {
80    pub fn new(engine_state: Arc<QueryEngineState>) -> Self {
81        let session_state = engine_state.session_state();
82        Self {
83            engine_state,
84            session_state,
85        }
86    }
87
    /// Basically the same with `explain_to_plan` in DataFusion, but adapted to Greptime's
    /// `plan_sql` to support Greptime Statements.
    ///
    /// Produces either a `LogicalPlan::Analyze` (for `EXPLAIN ANALYZE`) or a
    /// `LogicalPlan::Explain` wrapping the inner statement's plan.
    async fn explain_to_plan(
        &self,
        explain: &ExplainStatement,
        query_ctx: QueryContextRef,
    ) -> Result<LogicalPlan> {
        // Plan the inner statement through our own `plan_sql` so that
        // Greptime-specific statements are supported under EXPLAIN.
        let plan = self.plan_sql(&explain.statement, query_ctx).await?;
        if matches!(plan, LogicalPlan::Explain(_)) {
            return plan_err!("Nested EXPLAINs are not supported").context(PlanSqlSnafu);
        }

        let verbose = explain.verbose;
        let analyze = explain.analyze;
        let format = explain.format.map(|f| f.to_string());

        let plan = Arc::new(plan);
        // EXPLAIN output uses DataFusion's fixed explain schema.
        let schema = LogicalPlan::explain_schema();
        let schema = ToDFSchema::to_dfschema_ref(schema)?;

        if verbose && format.is_some() {
            return plan_err!("EXPLAIN VERBOSE with FORMAT is not supported").context(PlanSqlSnafu);
        }

        if analyze {
            // notice format is already set in query context, so can be ignore here
            Ok(LogicalPlan::Analyze(Analyze {
                verbose,
                input: plan,
                schema,
            }))
        } else {
            let stringified_plans = vec![plan.to_stringified(PlanType::InitialLogicalPlan)];

            // default to configuration value
            let options = self.session_state.config().options();
            let format = format
                .map(|x| ExplainFormat::from_str(&x))
                .transpose()?
                .unwrap_or_else(|| options.explain.format.clone());

            Ok(LogicalPlan::Explain(Explain {
                verbose,
                explain_format: format,
                plan,
                stringified_plans,
                schema,
                logical_optimization_succeeded: false,
            }))
        }
    }
139
    /// Plans a Greptime SQL [`Statement`] into a DataFusion [`LogicalPlan`].
    ///
    /// Handled before delegating to DataFusion's `SqlToRel`:
    /// - `EXPLAIN ...` goes through [`Self::explain_to_plan`] (which recurses
    ///   back into this function, hence `async_recursion`);
    /// - hybrid (SQL + TQL) CTEs are pre-planned into `planner_context` and
    ///   stripped from the query before the SQL part is planned.
    ///
    /// The resulting plan is rewritten for range selects and then optimized by
    /// the engine's extension rules.
    #[tracing::instrument(skip_all)]
    #[async_recursion::async_recursion]
    async fn plan_sql(&self, stmt: &Statement, query_ctx: QueryContextRef) -> Result<LogicalPlan> {
        let mut planner_context = PlannerContext::new();
        // `Cow` lets us replace the statement only when hybrid CTEs force a rewrite.
        let mut stmt = Cow::Borrowed(stmt);
        let mut is_tql_cte = false;

        // handle explain before normal processing so we can explain Greptime Statements
        if let Statement::Explain(explain) = stmt.as_ref() {
            return self.explain_to_plan(explain, query_ctx).await;
        }

        // Check for hybrid CTEs before normal processing
        if self.has_hybrid_ctes(stmt.as_ref()) {
            let stmt_owned = stmt.into_owned();
            let mut query = match stmt_owned {
                Statement::Query(query) => query.as_ref().clone(),
                _ => unreachable!("has_hybrid_ctes should only return true for Query statements"),
            };
            // Registers each TQL CTE's plan into `planner_context`.
            self.plan_query_with_hybrid_ctes(&query, query_ctx.clone(), &mut planner_context)
                .await?;

            // remove the processed TQL CTEs from the query
            query.hybrid_cte = None;
            stmt = Cow::Owned(Statement::Query(Box::new(query)));
            is_tql_cte = true;
        }

        let mut df_stmt = stmt.as_ref().try_into().context(SqlSnafu)?;

        // TODO(LFC): Remove this when Datafusion supports **both** the syntax and implementation of "explain with format".
        if let datafusion::sql::parser::Statement::Statement(
            box datafusion::sql::sqlparser::ast::Statement::Explain { .. },
        ) = &mut df_stmt
        {
            UnimplementedSnafu {
                operation: "EXPLAIN with FORMAT using raw datafusion planner",
            }
            .fail()?;
        }

        let table_provider = DfTableSourceProvider::new(
            self.engine_state.catalog_manager().clone(),
            self.engine_state.disallow_cross_catalog_query(),
            query_ctx.clone(),
            Arc::new(DefaultPlanDecoder::new(
                self.session_state.clone(),
                &query_ctx,
            )?),
            self.session_state
                .config_options()
                .sql_parser
                .enable_ident_normalization,
        );

        let context_provider = DfContextProviderAdapter::try_new(
            self.engine_state.clone(),
            self.session_state.clone(),
            Some(&df_stmt),
            query_ctx.clone(),
        )
        .await?;

        let config_options = self.session_state.config().options();
        let parser_options = &config_options.sql_parser;
        // Keep plain `Utf8` for string types instead of `Utf8View`.
        let parser_options = ParserOptions {
            map_string_types_to_utf8view: false,
            ..parser_options.into()
        };

        let sql_to_rel = SqlToRel::new_with_options(&context_provider, parser_options);

        // this IF is to handle different version of ASTs
        let result = if is_tql_cte {
            let Statement::Query(query) = stmt.into_owned() else {
                unreachable!("is_tql_cte should only be true for Query statements");
            };
            let sqlparser_stmt = sqlparser::ast::Statement::Query(Box::new(query.inner));
            sql_to_rel
                .sql_statement_to_plan_with_context(sqlparser_stmt, &mut planner_context)
                .context(PlanSqlSnafu)?
        } else {
            sql_to_rel
                .statement_to_plan(df_stmt)
                .context(PlanSqlSnafu)?
        };

        common_telemetry::debug!("Logical planner, statement to plan result: {result}");
        let plan = RangePlanRewriter::new(table_provider, query_ctx.clone())
            .rewrite(result)
            .await?;

        // Optimize logical plan by extension rules
        let context = QueryEngineContext::new(self.session_state.clone(), query_ctx);
        let plan = self
            .engine_state
            .optimize_by_extension_rules(plan, &context)?;
        common_telemetry::debug!("Logical planner, optimize result: {plan}");

        Ok(plan)
    }
241
    /// Generate a relational expression from a SQL expression
    ///
    /// `normalize_ident` controls whether identifiers in `sql` are normalized
    /// before being resolved against `schema`.
    #[tracing::instrument(skip_all)]
    pub(crate) async fn sql_to_expr(
        &self,
        sql: SqlExpr,
        schema: &DFSchema,
        normalize_ident: bool,
        query_ctx: QueryContextRef,
    ) -> Result<DfExpr> {
        // No statement is passed here; the provider is only needed to resolve
        // functions and other context for a standalone expression.
        let context_provider = DfContextProviderAdapter::try_new(
            self.engine_state.clone(),
            self.session_state.clone(),
            None,
            query_ctx,
        )
        .await?;

        let config_options = self.session_state.config().options();
        let parser_options = &config_options.sql_parser;
        // Keep plain `Utf8` for string types instead of `Utf8View`.
        let parser_options: ParserOptions = ParserOptions {
            map_string_types_to_utf8view: false,
            enable_ident_normalization: normalize_ident,
            ..parser_options.into()
        };

        let sql_to_rel = SqlToRel::new_with_options(&context_provider, parser_options);

        Ok(sql_to_rel.sql_to_expr(sql, schema, &mut PlannerContext::new())?)
    }
271
    /// Plans a parsed PromQL [`EvalStmt`] into a logical plan and runs the
    /// engine's extension optimizer rules on the result.
    #[tracing::instrument(skip_all)]
    async fn plan_pql(&self, stmt: &EvalStmt, query_ctx: QueryContextRef) -> Result<LogicalPlan> {
        let plan_decoder = Arc::new(DefaultPlanDecoder::new(
            self.session_state.clone(),
            &query_ctx,
        )?);
        let table_provider = DfTableSourceProvider::new(
            self.engine_state.catalog_manager().clone(),
            self.engine_state.disallow_cross_catalog_query(),
            query_ctx.clone(),
            plan_decoder,
            self.session_state
                .config_options()
                .sql_parser
                .enable_ident_normalization,
        );
        // PromPlanner errors are boxed into the generic query-plan error.
        let plan = PromPlanner::stmt_to_plan(table_provider, stmt, &self.engine_state)
            .await
            .map_err(BoxedError::new)
            .context(QueryPlanSnafu)?;

        let context = QueryEngineContext::new(self.session_state.clone(), query_ctx);
        Ok(self
            .engine_state
            .optimize_by_extension_rules(plan, &context)?)
    }
298
299    #[tracing::instrument(skip_all)]
300    fn optimize_logical_plan(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
301        Ok(self.engine_state.optimize_logical_plan(plan)?)
302    }
303
304    /// Check if a statement contains hybrid CTEs (mix of SQL and TQL)
305    fn has_hybrid_ctes(&self, stmt: &Statement) -> bool {
306        if let Statement::Query(query) = stmt {
307            query
308                .hybrid_cte
309                .as_ref()
310                .map(|hybrid_cte| !hybrid_cte.cte_tables.is_empty())
311                .unwrap_or(false)
312        } else {
313            false
314        }
315    }
316
    /// Plan a query with hybrid CTEs using DataFusion's native PlannerContext
    ///
    /// Each TQL CTE is planned into a `LogicalPlan`, optionally re-projected to
    /// the user-declared column aliases, wrapped in a `SubqueryAlias`, and
    /// registered into `planner_context` under the CTE's name.
    async fn plan_query_with_hybrid_ctes(
        &self,
        query: &Query,
        query_ctx: QueryContextRef,
        planner_context: &mut PlannerContext,
    ) -> Result<()> {
        // Caller (`plan_sql`) only invokes this after `has_hybrid_ctes`
        // returned true, so `hybrid_cte` is present.
        let hybrid_cte = query.hybrid_cte.as_ref().unwrap();

        for cte in &hybrid_cte.cte_tables {
            match &cte.content {
                CteContent::Tql(tql) => {
                    // Plan TQL and register in PlannerContext
                    let mut logical_plan = self.tql_to_logical_plan(tql, query_ctx.clone()).await?;
                    if !cte.columns.is_empty() {
                        let schema = logical_plan.schema();
                        let schema_fields = schema.fields().to_vec();
                        // The declared alias list must match the plan's column count.
                        ensure!(
                            schema_fields.len() == cte.columns.len(),
                            CteColumnSchemaMismatchSnafu {
                                cte_name: cte.name.value.clone(),
                                original: schema_fields
                                    .iter()
                                    .map(|field| field.name().clone())
                                    .collect::<Vec<_>>(),
                                expected: cte
                                    .columns
                                    .iter()
                                    .map(|column| column.to_string())
                                    .collect::<Vec<_>>(),
                            }
                        );
                        // Rename the plan's columns positionally to the CTE aliases.
                        let aliases = cte
                            .columns
                            .iter()
                            .zip(schema_fields.iter())
                            .map(|(column, field)| col(field.name()).alias(column.to_string()));
                        logical_plan = LogicalPlanBuilder::from(logical_plan)
                            .project(aliases)
                            .context(PlanSqlSnafu)?
                            .build()
                            .context(PlanSqlSnafu)?;
                    }

                    // Wrap in SubqueryAlias to ensure proper table qualification for CTE
                    logical_plan = LogicalPlan::SubqueryAlias(
                        datafusion_expr::SubqueryAlias::try_new(
                            Arc::new(logical_plan),
                            cte.name.value.clone(),
                        )
                        .context(PlanSqlSnafu)?,
                    );

                    planner_context.insert_cte(&cte.name.value, logical_plan);
                }
                CteContent::Sql(_) => {
                    // SQL CTEs should have been moved to the main query's WITH clause
                    // during parsing, so we shouldn't encounter them here
                    unreachable!("SQL CTEs should not be in hybrid_cte.cte_tables");
                }
            }
        }

        Ok(())
    }
382
    /// Convert TQL to LogicalPlan directly
    ///
    /// Only `TQL EVAL` is supported inside CTEs; `TQL EXPLAIN` and
    /// `TQL ANALYZE` return an `Unimplemented` error.
    async fn tql_to_logical_plan(
        &self,
        tql: &Tql,
        query_ctx: QueryContextRef,
    ) -> Result<LogicalPlan> {
        match tql {
            Tql::Eval(eval) => {
                // Convert TqlEval to PromQuery then to QueryStatement::Promql
                let prom_query = PromQuery {
                    query: eval.query.clone(),
                    start: eval.start.clone(),
                    end: eval.end.clone(),
                    step: eval.step.clone(),
                    // Fall back to the default lookback delta when unspecified.
                    lookback: eval
                        .lookback
                        .clone()
                        .unwrap_or_else(|| DEFAULT_LOOKBACK_STRING.to_string()),
                    alias: eval.alias.clone(),
                };
                let stmt = QueryLanguageParser::parse_promql(&prom_query, &query_ctx)?;

                self.plan(&stmt, query_ctx).await
            }
            Tql::Explain(_) => UnimplementedSnafu {
                operation: "TQL EXPLAIN in CTEs",
            }
            .fail(),
            Tql::Analyze(_) => UnimplementedSnafu {
                operation: "TQL ANALYZE in CTEs",
            }
            .fail(),
        }
    }
417
418    /// Extracts cast types for all placeholders in a logical plan.
419    /// Returns a map where each placeholder ID is mapped to:
420    /// - Some(DataType) if the placeholder is cast to a specific type
421    /// - None if the placeholder exists but has no cast
422    ///
423    /// Example: `$1::TEXT` returns `{"$1": Some(DataType::Utf8)}`
424    ///
425    /// This function walks through all expressions in the logical plan,
426    /// including subqueries, to identify placeholders and their cast types.
427    fn extract_placeholder_cast_types(
428        plan: &LogicalPlan,
429    ) -> Result<HashMap<String, Option<DataType>>> {
430        let mut placeholder_types = HashMap::new();
431        let mut casted_placeholders = HashSet::new();
432
433        Self::extract_from_plan(plan, &mut placeholder_types, &mut casted_placeholders)?;
434
435        Ok(placeholder_types)
436    }
437
438    fn extract_from_plan(
439        plan: &LogicalPlan,
440        placeholder_types: &mut HashMap<String, Option<DataType>>,
441        casted_placeholders: &mut HashSet<String>,
442    ) -> Result<()> {
443        plan.apply(|node| {
444            for expr in node.expressions() {
445                let _ = expr.apply(|e| {
446                    // Handle casted placeholders
447                    if let DfExpr::Cast(cast) = e
448                        && let DfExpr::Placeholder(ph) = &*cast.expr
449                    {
450                        placeholder_types.insert(ph.id.clone(), Some(cast.data_type.clone()));
451                        casted_placeholders.insert(ph.id.clone());
452                    }
453
454                    // Handle bare (non-casted) placeholders
455                    if let DfExpr::Placeholder(ph) = e
456                        && !casted_placeholders.contains(&ph.id)
457                        && !placeholder_types.contains_key(&ph.id)
458                    {
459                        placeholder_types.insert(ph.id.clone(), None);
460                    }
461
462                    // Recurse into subquery plans embedded in expressions
463                    match e {
464                        DfExpr::Exists(Exists { subquery, .. })
465                        | DfExpr::InSubquery(InSubquery { subquery, .. })
466                        | DfExpr::ScalarSubquery(subquery) => {
467                            Self::extract_from_plan(
468                                &subquery.subquery,
469                                placeholder_types,
470                                casted_placeholders,
471                            )?;
472                        }
473                        _ => {}
474                    }
475
476                    Ok(TreeNodeRecursion::Continue)
477                });
478            }
479            Ok(TreeNodeRecursion::Continue)
480        })?;
481        Ok(())
482    }
483
484    /// Gets inferred parameter types from a logical plan.
485    /// Returns a map where each parameter ID is mapped to:
486    /// - Some(DataType) if the parameter type could be inferred
487    /// - None if the parameter type could not be inferred
488    ///
489    /// This function first uses DataFusion's `get_parameter_types()` to infer types.
490    /// If any parameters have `None` values (i.e., DataFusion couldn't infer their types),
491    /// it falls back to using `extract_placeholder_cast_types()` to detect explicit casts.
492    ///
493    /// This is because datafusion can only infer types for a limited cases.
494    ///
495    /// Example: For query `WHERE $1::TEXT AND $2`, DataFusion may not infer `$2`'s type,
496    /// but this function will return `{"$1": Some(DataType::Utf8), "$2": None}`.
497    pub fn get_inferred_parameter_types(
498        plan: &LogicalPlan,
499    ) -> Result<HashMap<String, Option<DataType>>> {
500        let param_types = plan.get_parameter_types().context(PlanSqlSnafu)?;
501
502        let has_none = param_types.values().any(|v| v.is_none());
503
504        if !has_none {
505            Ok(param_types)
506        } else {
507            let cast_types = Self::extract_placeholder_cast_types(plan)?;
508
509            let mut merged = param_types;
510
511            for (id, opt_type) in cast_types {
512                merged
513                    .entry(id)
514                    .and_modify(|existing| {
515                        if existing.is_none() {
516                            *existing = opt_type.clone();
517                        }
518                    })
519                    .or_insert(opt_type);
520            }
521
522            Ok(merged)
523        }
524    }
525}
526
#[async_trait]
impl LogicalPlanner for DfLogicalPlanner {
    /// Dispatches to the SQL or PromQL planner depending on the statement kind.
    #[tracing::instrument(skip_all)]
    async fn plan(&self, stmt: &QueryStatement, query_ctx: QueryContextRef) -> Result<LogicalPlan> {
        match stmt {
            QueryStatement::Sql(stmt) => self.plan_sql(stmt, query_ctx).await,
            QueryStatement::Promql(stmt, _alias) => self.plan_pql(stmt, query_ctx).await,
        }
    }

    /// Plans a log query via [`LogQueryPlanner`].
    async fn plan_logs_query(
        &self,
        query: LogQuery,
        query_ctx: QueryContextRef,
    ) -> Result<LogicalPlan> {
        let plan_decoder = Arc::new(DefaultPlanDecoder::new(
            self.session_state.clone(),
            &query_ctx,
        )?);
        let table_provider = DfTableSourceProvider::new(
            self.engine_state.catalog_manager().clone(),
            self.engine_state.disallow_cross_catalog_query(),
            query_ctx,
            plan_decoder,
            self.session_state
                .config_options()
                .sql_parser
                .enable_ident_normalization,
        );

        let mut planner = LogQueryPlanner::new(table_provider, self.session_state.clone());
        // Log-query planner errors are boxed into the generic query-plan error.
        planner
            .query_to_plan(query)
            .await
            .map_err(BoxedError::new)
            .context(QueryPlanSnafu)
    }

    fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
        self.optimize_logical_plan(plan)
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}
573
574#[cfg(test)]
575mod tests {
576    use std::sync::Arc;
577
578    use arrow_schema::DataType;
579    use catalog::RegisterTableRequest;
580    use catalog::memory::MemoryCatalogManager;
581    use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
582    use datatypes::prelude::ConcreteDataType;
583    use datatypes::schema::{ColumnSchema, Schema};
584    use session::context::QueryContext;
585    use store_api::metric_engine_consts::{
586        DATA_SCHEMA_TABLE_ID_COLUMN_NAME, DATA_SCHEMA_TSID_COLUMN_NAME, LOGICAL_TABLE_METADATA_KEY,
587        METRIC_ENGINE_NAME,
588    };
589    use table::metadata::{TableInfoBuilder, TableMetaBuilder};
590    use table::test_util::EmptyTable;
591
592    use super::*;
593    use crate::parser::{PromQuery, QueryLanguageParser};
594    use crate::{QueryEngineFactory, QueryEngineRef};
595
    /// Builds a query engine containing a single empty table
    /// `test(id INT NOT NULL, name STRING NULL)` with `id` as primary key.
    async fn create_test_engine() -> QueryEngineRef {
        let columns = vec![
            ColumnSchema::new("id", ConcreteDataType::int32_datatype(), false),
            ColumnSchema::new("name", ConcreteDataType::string_datatype(), true),
        ];
        let schema = Arc::new(Schema::new(columns));
        let table_meta = TableMetaBuilder::empty()
            .schema(schema)
            .primary_key_indices(vec![0])
            .value_indices(vec![1])
            .next_column_id(1024)
            .build()
            .unwrap();
        let table_info = TableInfoBuilder::new("test", table_meta).build().unwrap();
        let table = EmptyTable::from_table_info(&table_info);

        crate::tests::new_query_engine_with_table(table)
    }
614
    /// Builds a query engine with a metric-engine physical table `phy` and a
    /// logical table `some_metric(tag_0, tag_1, timestamp, field_0)` linked to
    /// it, for exercising the PromQL planner.
    fn create_promql_test_engine() -> QueryEngineRef {
        let catalog_manager = MemoryCatalogManager::with_default_setup();
        let physical_table_name = "phy";
        let physical_table_id = 999u32;

        // Physical table schema includes the metric engine's internal
        // table-id and tsid columns ahead of the user-visible columns.
        let physical_schema = Arc::new(Schema::new(vec![
            ColumnSchema::new(
                DATA_SCHEMA_TABLE_ID_COLUMN_NAME.to_string(),
                ConcreteDataType::uint32_datatype(),
                false,
            ),
            ColumnSchema::new(
                DATA_SCHEMA_TSID_COLUMN_NAME.to_string(),
                ConcreteDataType::uint64_datatype(),
                false,
            ),
            ColumnSchema::new("tag_0", ConcreteDataType::string_datatype(), false),
            ColumnSchema::new("tag_1", ConcreteDataType::string_datatype(), false),
            ColumnSchema::new(
                "timestamp",
                ConcreteDataType::timestamp_millisecond_datatype(),
                false,
            )
            .with_time_index(true),
            ColumnSchema::new("field_0", ConcreteDataType::float64_datatype(), true),
        ]));
        let physical_meta = TableMetaBuilder::empty()
            .schema(physical_schema)
            .primary_key_indices(vec![0, 1, 2, 3])
            .value_indices(vec![4, 5])
            .engine(METRIC_ENGINE_NAME.to_string())
            .next_column_id(1024)
            .build()
            .unwrap();
        let physical_info = TableInfoBuilder::default()
            .table_id(physical_table_id)
            .name(physical_table_name)
            .meta(physical_meta)
            .build()
            .unwrap();
        catalog_manager
            .register_table_sync(RegisterTableRequest {
                catalog: DEFAULT_CATALOG_NAME.to_string(),
                schema: DEFAULT_SCHEMA_NAME.to_string(),
                table_name: physical_table_name.to_string(),
                table_id: physical_table_id,
                table: EmptyTable::from_table_info(&physical_info),
            })
            .unwrap();

        // Link the logical table to its physical table via table options.
        let mut options = table::requests::TableOptions::default();
        options.extra_options.insert(
            LOGICAL_TABLE_METADATA_KEY.to_string(),
            physical_table_name.to_string(),
        );
        let logical_schema = Arc::new(Schema::new(vec![
            ColumnSchema::new("tag_0", ConcreteDataType::string_datatype(), false),
            ColumnSchema::new("tag_1", ConcreteDataType::string_datatype(), false),
            ColumnSchema::new(
                "timestamp",
                ConcreteDataType::timestamp_millisecond_datatype(),
                false,
            )
            .with_time_index(true),
            ColumnSchema::new("field_0", ConcreteDataType::float64_datatype(), true),
        ]));
        let logical_meta = TableMetaBuilder::empty()
            .schema(logical_schema)
            .primary_key_indices(vec![0, 1])
            .value_indices(vec![3])
            .engine(METRIC_ENGINE_NAME.to_string())
            .options(options)
            .next_column_id(1024)
            .build()
            .unwrap();
        let logical_info = TableInfoBuilder::default()
            .table_id(1024)
            .name("some_metric")
            .meta(logical_meta)
            .build()
            .unwrap();
        catalog_manager
            .register_table_sync(RegisterTableRequest {
                catalog: DEFAULT_CATALOG_NAME.to_string(),
                schema: DEFAULT_SCHEMA_NAME.to_string(),
                table_name: "some_metric".to_string(),
                table_id: 1024,
                table: EmptyTable::from_table_info(&logical_info),
            })
            .unwrap();

        QueryEngineFactory::new(
            catalog_manager,
            None,
            None,
            None,
            None,
            false,
            crate::options::QueryOptions::default(),
        )
        .query_engine()
    }
717
    /// Parses `sql` and plans it against the `test`-table engine, panicking on
    /// any parse or plan error.
    async fn parse_sql_to_plan(sql: &str) -> LogicalPlan {
        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
        let engine = create_test_engine().await;
        engine
            .planner()
            .plan(&stmt, QueryContext::arc())
            .await
            .unwrap()
    }
727
    /// Parses a PromQL `query` over a fixed [0, 10] range (5s step, 300s
    /// lookback) and plans it against the metric-engine test tables.
    async fn parse_promql_to_plan(query: &str) -> LogicalPlan {
        let engine = create_promql_test_engine();
        let query_ctx = QueryContext::arc();
        let stmt = QueryLanguageParser::parse_promql(
            &PromQuery {
                query: query.to_string(),
                start: "0".to_string(),
                end: "10".to_string(),
                step: "5s".to_string(),
                lookback: "300s".to_string(),
                alias: None,
            },
            &query_ctx,
        )
        .unwrap();

        engine.planner().plan(&stmt, query_ctx).await.unwrap()
    }
746
    // Casted placeholders map to Some(type); the bare `$3` maps to None.
    #[tokio::test]
    async fn test_extract_placeholder_cast_types_multiple() {
        let plan = parse_sql_to_plan(
            "SELECT $1::INT, $2::TEXT, $3, $4::INTEGER FROM test WHERE $5::FLOAT > 0",
        )
        .await;
        let types = DfLogicalPlanner::extract_placeholder_cast_types(&plan).unwrap();

        assert_eq!(types.len(), 5);
        assert_eq!(types.get("$1"), Some(&Some(DataType::Int32)));
        assert_eq!(types.get("$2"), Some(&Some(DataType::Utf8)));
        assert_eq!(types.get("$3"), Some(&None));
        assert_eq!(types.get("$4"), Some(&Some(DataType::Int32)));
        assert_eq!(types.get("$5"), Some(&Some(DataType::Float32)));
    }
762
    // The cast-extraction fallback fills in `$2` (explicit cast) where
    // DataFusion's inference cannot, while `$1` stays unknown.
    #[tokio::test]
    async fn test_get_inferred_parameter_types_fallback_for_udf_args() {
        // datafusion is not able to infer type for scalar function arguments
        let plan = parse_sql_to_plan(
            "SELECT parse_ident($1), parse_ident($2::TEXT) FROM test WHERE id > $3",
        )
        .await;
        let types = DfLogicalPlanner::get_inferred_parameter_types(&plan).unwrap();

        assert_eq!(types.len(), 3);

        let type_1 = types.get("$1").unwrap();
        let type_2 = types.get("$2").unwrap();
        let type_3 = types.get("$3").unwrap();

        assert!(type_1.is_none(), "Expected $1 to be None");
        assert_eq!(type_2, &Some(DataType::Utf8));
        assert_eq!(type_3, &Some(DataType::Int32));
    }
782
    // These `scalar(count(<agg>(...) by (...)))` shapes should be rewritten by
    // the extension rules into a plan containing a Distinct node.
    #[tokio::test]
    async fn test_plan_pql_applies_extension_rules() {
        for inner_agg in ["count", "sum", "avg", "min", "max", "stddev", "stdvar"] {
            let plan = parse_promql_to_plan(&format!(
                "sum(irate(some_metric[1h])) / scalar(count({inner_agg}(some_metric) by (tag_0)))"
            ))
            .await;
            let plan_str = plan.display_indent_schema().to_string();
            assert!(plan_str.contains("Distinct:"), "{inner_agg}: {plan_str}");
        }
    }
794
    // Non-`count` inner aggregates should produce a `field_0 IS NOT NULL`
    // filter in the rewritten plan; a plain inner `count` should not.
    #[tokio::test]
    async fn test_plan_pql_filters_null_only_groups_for_non_count_inner_aggs() {
        let count_plan = parse_promql_to_plan("scalar(count(count(some_metric) by (tag_0)))").await;
        let count_plan_str = count_plan.display_indent_schema().to_string();
        assert!(
            !count_plan_str.contains("field_0 IS NOT NULL"),
            "{count_plan_str}"
        );

        for inner_agg in ["sum", "avg", "min", "max", "stddev", "stdvar"] {
            let plan = parse_promql_to_plan(&format!(
                "scalar(count({inner_agg}(some_metric) by (tag_0)))"
            ))
            .await;
            let plan_str = plan.display_indent_schema().to_string();
            assert!(
                plan_str.contains("field_0 IS NOT NULL"),
                "{inner_agg}: {plan_str}"
            );
        }
    }
816
    // When the inner aggregate is not applied directly to a selector, or is an
    // unsupported one (`group`), no Distinct rewrite should be applied.
    #[tokio::test]
    async fn test_plan_pql_skips_extension_rules_for_non_direct_or_unsupported_inner_agg() {
        for query in [
            "sum(irate(some_metric[1h])) / scalar(count(sum(irate(some_metric[1h])) by (tag_0)))",
            "sum(irate(some_metric[1h])) / scalar(count(group(some_metric) by (tag_0)))",
        ] {
            let plan = parse_promql_to_plan(query).await;
            let plan_str = plan.display_indent_schema().to_string();
            assert!(!plan_str.contains("Distinct:"), "{query}: {plan_str}");
        }
    }
828
    // A SQL nested-count query must not trigger the Distinct rewrite that the
    // PromQL path applies.
    #[tokio::test]
    async fn test_plan_sql_does_not_apply_nested_count_rule() {
        let plan = parse_sql_to_plan(
            "SELECT id, count(inner_count) \
             FROM ( \
                 SELECT id, count(name) AS inner_count \
                 FROM test \
                 GROUP BY id \
                 ORDER BY id \
                 LIMIT 1000000 \
             ) t \
             GROUP BY id \
             ORDER BY id",
        )
        .await;

        let plan_str = plan.display_indent_schema().to_string();
        assert!(!plan_str.contains("Distinct:"), "{plan_str}");
    }
848
    // Cast extraction recurses into scalar subqueries to find `$1::TEXT`.
    #[tokio::test]
    async fn test_get_inferred_parameter_types_subquery() {
        let plan = parse_sql_to_plan(
            r#"SELECT * FROM test WHERE id = (SELECT id FROM test CROSS JOIN (SELECT parse_ident($1::TEXT) AS parts) p LIMIT 1)"#,
        ).await;
        let types = DfLogicalPlanner::get_inferred_parameter_types(&plan).unwrap();

        assert_eq!(types.len(), 1);
        let type_1 = types.get("$1").unwrap();
        assert_eq!(type_1, &Some(DataType::Utf8));
    }
860
    // INSERT placeholder types are inferred from the target columns
    // (`id` INT, `name` STRING) for every VALUES row.
    #[tokio::test]
    async fn test_get_inferred_parameter_types_insert() {
        let plan = parse_sql_to_plan("INSERT INTO test (id, name) VALUES ($1, $2), ($3, $4)").await;
        let types = DfLogicalPlanner::get_inferred_parameter_types(&plan).unwrap();

        assert_eq!(types.len(), 4);
        assert_eq!(types.get("$1"), Some(&Some(DataType::Int32)));
        assert_eq!(types.get("$2"), Some(&Some(DataType::Utf8)));
        assert_eq!(types.get("$3"), Some(&Some(DataType::Int32)));
        assert_eq!(types.get("$4"), Some(&Some(DataType::Utf8)));
    }
872}