Skip to main content

sql/parsers/
utils.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::sync::Arc;
17
18use datafusion::config::ConfigOptions;
19use datafusion::error::Result as DfResult;
20use datafusion::execution::SessionStateBuilder;
21use datafusion::execution::context::SessionState;
22use datafusion::optimizer::simplify_expressions::ExprSimplifier;
23use datafusion_common::tree_node::{TreeNode, TreeNodeVisitor};
24use datafusion_common::{DFSchema, ScalarValue};
25use datafusion_expr::simplify::SimplifyContext;
26use datafusion_expr::{AggregateUDF, Expr, ScalarUDF, TableSource, WindowUDF};
27use datafusion_sql::TableReference;
28use datafusion_sql::planner::{ContextProvider, SqlToRel};
29use datatypes::arrow::datatypes::DataType;
30use datatypes::schema::{
31    COLUMN_FULLTEXT_OPT_KEY_ANALYZER, COLUMN_FULLTEXT_OPT_KEY_BACKEND,
32    COLUMN_FULLTEXT_OPT_KEY_CASE_SENSITIVE, COLUMN_FULLTEXT_OPT_KEY_FALSE_POSITIVE_RATE,
33    COLUMN_FULLTEXT_OPT_KEY_GRANULARITY, COLUMN_SKIPPING_INDEX_OPT_KEY_FALSE_POSITIVE_RATE,
34    COLUMN_SKIPPING_INDEX_OPT_KEY_GRANULARITY, COLUMN_SKIPPING_INDEX_OPT_KEY_TYPE,
35    COLUMN_VECTOR_INDEX_OPT_KEY_CONNECTIVITY, COLUMN_VECTOR_INDEX_OPT_KEY_ENGINE,
36    COLUMN_VECTOR_INDEX_OPT_KEY_EXPANSION_ADD, COLUMN_VECTOR_INDEX_OPT_KEY_EXPANSION_SEARCH,
37    COLUMN_VECTOR_INDEX_OPT_KEY_METRIC,
38};
39use snafu::{ResultExt, ensure};
40use sqlparser::dialect::Dialect;
41use sqlparser::keywords::Keyword;
42use sqlparser::parser::Parser;
43use table::requests::{SEMANTIC_PREFIX, validate_semantic_option, validate_table_option};
44
45use crate::error::{
46    ConvertToLogicalExpressionSnafu, InvalidSqlSnafu, InvalidTableOptionSnafu, ParseSqlValueSnafu,
47    Result, SimplificationSnafu, SyntaxSnafu,
48};
49use crate::parser::{ParseOptions, ParserContext};
50use crate::parsers::with_tql_parser::CteContent;
51use crate::statements::OptionMap;
52use crate::statements::query::Query;
53use crate::statements::statement::Statement;
54use crate::util::{OptionValue, parse_option_string};
55
56/// Check if the given SQL query is a TQL statement. Simple tql cte query is also considered as TQL statement.
57pub fn is_tql(dialect: &dyn Dialect, sql: &str) -> Result<bool> {
58    let stmts = ParserContext::create_with_dialect(sql, dialect, ParseOptions::default())?;
59
60    ensure!(
61        stmts.len() == 1,
62        InvalidSqlSnafu {
63            msg: format!("Expect only one statement, found {}", stmts.len())
64        }
65    );
66    let stmt = &stmts[0];
67    match stmt {
68        Statement::Tql(_) => Ok(true),
69        Statement::Query(query) => Ok(is_simple_tql_cte_query(query)),
70        _ => Ok(false),
71    }
72}
73
74pub(crate) fn is_simple_tql_cte_query(query: &Query) -> bool {
75    use crate::parser::ParserContext;
76
77    let Some(hybrid_cte) = &query.hybrid_cte else {
78        return false;
79    };
80
81    if !has_only_hybrid_tql_cte(query) {
82        return false;
83    }
84
85    let Some(cte) = hybrid_cte.cte_tables.first() else {
86        return false;
87    };
88    if hybrid_cte.cte_tables.len() != 1 || !matches!(cte.content, CteContent::Tql(_)) {
89        return false;
90    }
91
92    let Some(reference) = extract_simple_select_star_reference(query) else {
93        return false;
94    };
95
96    let reference = ParserContext::canonicalize_identifier(reference).value;
97    let cte_name = ParserContext::canonicalize_identifier(cte.name.clone()).value;
98    reference == cte_name
99}
100
101pub(crate) fn has_tql_cte(query: &Query) -> bool {
102    query.hybrid_cte.as_ref().is_some_and(|with| {
103        with.cte_tables
104            .iter()
105            .any(|cte| matches!(cte.content, CteContent::Tql(_)))
106    })
107}
108
109fn has_only_hybrid_tql_cte(query: &Query) -> bool {
110    query
111        .inner
112        .with
113        .as_ref()
114        .is_none_or(|with| with.cte_tables.is_empty())
115}
116
117fn extract_simple_select_star_reference(query: &Query) -> Option<sqlparser::ast::Ident> {
118    use sqlparser::ast::{SetExpr, TableFactor};
119
120    if !is_plain_query_root(&query.inner) {
121        return None;
122    }
123
124    let SetExpr::Select(select) = &*query.inner.body else {
125        return None;
126    };
127    if !is_plain_select(select) || !is_plain_wildcard_projection(select.projection.as_slice()) {
128        return None;
129    }
130
131    let [table_with_joins] = select.from.as_slice() else {
132        return None;
133    };
134    if !table_with_joins.joins.is_empty() {
135        return None;
136    }
137
138    let TableFactor::Table { name, .. } = &table_with_joins.relation else {
139        return None;
140    };
141    if name.0.len() != 1 {
142        return None;
143    }
144
145    name.0[0].as_ident().cloned()
146}
147
148fn is_plain_query_root(query: &sqlparser::ast::Query) -> bool {
149    query.order_by.is_none()
150        && query.limit_clause.is_none()
151        && query.fetch.is_none()
152        && query.locks.is_empty()
153        && query.for_clause.is_none()
154        && query.settings.is_none()
155        && query.format_clause.is_none()
156        && query.pipe_operators.is_empty()
157}
158
159fn is_plain_select(select: &sqlparser::ast::Select) -> bool {
160    use sqlparser::ast::GroupByExpr;
161
162    select.distinct.is_none()
163        && select.top.is_none()
164        && select.exclude.is_none()
165        && select.into.is_none()
166        && select.lateral_views.is_empty()
167        && select.prewhere.is_none()
168        && select.selection.is_none()
169        && matches!(select.group_by, GroupByExpr::Expressions(ref exprs, _) if exprs.is_empty())
170        && select.cluster_by.is_empty()
171        && select.distribute_by.is_empty()
172        && select.sort_by.is_empty()
173        && select.having.is_none()
174        && select.named_window.is_empty()
175        && select.qualify.is_none()
176        && select.value_table_mode.is_none()
177        && select.connect_by.is_empty()
178}
179
180fn is_plain_wildcard_projection(projection: &[sqlparser::ast::SelectItem]) -> bool {
181    use sqlparser::ast::SelectItem;
182
183    matches!(
184        projection,
185        [SelectItem::Wildcard(options)]
186            if options.opt_ilike.is_none()
187                && options.opt_exclude.is_none()
188                && options.opt_except.is_none()
189                && options.opt_replace.is_none()
190                && options.opt_rename.is_none()
191    )
192}
193
194/// Convert a parser expression to a scalar value. This function will try the
195/// best to resolve and reduce constants. Exprs like `1 + 1` or `now()` can be
196/// handled properly.
197///
198/// if `require_now_expr` is true, it will ensure that the expression contains a `now()` function.
199/// If the expression does not contain `now()`, it will return an error.
200///
201pub fn parser_expr_to_scalar_value_literal(
202    expr: sqlparser::ast::Expr,
203    require_now_expr: bool,
204) -> Result<ScalarValue> {
205    // 1. convert parser expr to logical expr
206    let empty_df_schema = DFSchema::empty();
207    let logical_expr = SqlToRel::new(&StubContextProvider::default())
208        .sql_to_expr(expr, &empty_df_schema, &mut Default::default())
209        .context(ConvertToLogicalExpressionSnafu)?;
210
211    struct FindNow {
212        found: bool,
213    }
214
215    impl TreeNodeVisitor<'_> for FindNow {
216        type Node = Expr;
217        fn f_down(
218            &mut self,
219            node: &Self::Node,
220        ) -> DfResult<datafusion_common::tree_node::TreeNodeRecursion> {
221            if let Expr::ScalarFunction(func) = node
222                && func.name().to_lowercase() == "now"
223            {
224                if !func.args.is_empty() {
225                    return Err(datafusion_common::DataFusionError::Plan(
226                        "now() function should not have arguments".to_string(),
227                    ));
228                }
229                self.found = true;
230                return Ok(datafusion_common::tree_node::TreeNodeRecursion::Stop);
231            }
232            Ok(datafusion_common::tree_node::TreeNodeRecursion::Continue)
233        }
234    }
235
236    if require_now_expr {
237        let have_now = {
238            let mut visitor = FindNow { found: false };
239            logical_expr.visit(&mut visitor).unwrap();
240            visitor.found
241        };
242        if !have_now {
243            return ParseSqlValueSnafu {
244                msg: format!(
245                    "expected now() expression, but not found in {}",
246                    logical_expr
247                ),
248            }
249            .fail();
250        }
251    }
252
253    // 2. simplify logical expr
254    let info = SimplifyContext::default().with_current_time();
255    let simplifier = ExprSimplifier::new(info);
256
257    // Coerce the logical expression so simplifier can handle it correctly. This is necessary for const eval with possible type mismatch. i.e.: `now() - now() + '15s'::interval` which is `TimestampNanosecond - TimestampNanosecond + IntervalMonthDayNano`.
258    let logical_expr = simplifier
259        .coerce(logical_expr, &empty_df_schema)
260        .context(SimplificationSnafu)?;
261
262    let simplified_expr = simplifier
263        .simplify(logical_expr)
264        .context(SimplificationSnafu)?;
265
266    if let datafusion::logical_expr::Expr::Literal(lit, _) = simplified_expr {
267        Ok(lit)
268    } else {
269        // Err(ParseSqlValue)
270        ParseSqlValueSnafu {
271            msg: format!("expected literal value, but found {:?}", simplified_expr),
272        }
273        .fail()
274    }
275}
276
277/// Helper struct for [`parser_expr_to_scalar_value`].
278struct StubContextProvider {
279    state: SessionState,
280}
281
282impl Default for StubContextProvider {
283    fn default() -> Self {
284        Self {
285            state: SessionStateBuilder::new()
286                .with_config(Default::default())
287                .with_runtime_env(Default::default())
288                .with_default_features()
289                .build(),
290        }
291    }
292}
293
294impl ContextProvider for StubContextProvider {
295    fn get_table_source(&self, _name: TableReference) -> DfResult<Arc<dyn TableSource>> {
296        unimplemented!()
297    }
298
299    fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
300        self.state.scalar_functions().get(name).cloned()
301    }
302
303    fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
304        self.state.aggregate_functions().get(name).cloned()
305    }
306
307    fn get_window_meta(&self, _name: &str) -> Option<Arc<WindowUDF>> {
308        unimplemented!()
309    }
310
311    fn get_variable_type(&self, _variable_names: &[String]) -> Option<DataType> {
312        unimplemented!()
313    }
314
315    fn options(&self) -> &ConfigOptions {
316        self.state.config_options()
317    }
318
319    fn udf_names(&self) -> Vec<String> {
320        self.state.scalar_functions().keys().cloned().collect()
321    }
322
323    fn udaf_names(&self) -> Vec<String> {
324        self.state.aggregate_functions().keys().cloned().collect()
325    }
326
327    fn udwf_names(&self) -> Vec<String> {
328        self.state.window_functions().keys().cloned().collect()
329    }
330}
331
332pub fn validate_column_fulltext_create_option(key: &str) -> bool {
333    [
334        COLUMN_FULLTEXT_OPT_KEY_ANALYZER,
335        COLUMN_FULLTEXT_OPT_KEY_CASE_SENSITIVE,
336        COLUMN_FULLTEXT_OPT_KEY_BACKEND,
337        COLUMN_FULLTEXT_OPT_KEY_GRANULARITY,
338        COLUMN_FULLTEXT_OPT_KEY_FALSE_POSITIVE_RATE,
339    ]
340    .contains(&key)
341}
342
343pub fn validate_column_skipping_index_create_option(key: &str) -> bool {
344    [
345        COLUMN_SKIPPING_INDEX_OPT_KEY_GRANULARITY,
346        COLUMN_SKIPPING_INDEX_OPT_KEY_TYPE,
347        COLUMN_SKIPPING_INDEX_OPT_KEY_FALSE_POSITIVE_RATE,
348    ]
349    .contains(&key)
350}
351
352pub fn validate_column_vector_index_create_option(key: &str) -> bool {
353    [
354        COLUMN_VECTOR_INDEX_OPT_KEY_ENGINE,
355        COLUMN_VECTOR_INDEX_OPT_KEY_METRIC,
356        COLUMN_VECTOR_INDEX_OPT_KEY_CONNECTIVITY,
357        COLUMN_VECTOR_INDEX_OPT_KEY_EXPANSION_ADD,
358        COLUMN_VECTOR_INDEX_OPT_KEY_EXPANSION_SEARCH,
359    ]
360    .contains(&key)
361}
362
/// Convert an [`IntervalMonthDayNano`] to a [`Duration`].
///
/// A month is counted as 30.44 days and a day as 86 400 seconds. The three
/// components are summed, so they may have mixed signs as long as the total
/// is not negative.
///
/// # Errors
///
/// Returns an invalid-interval error when the total is negative.
#[cfg(feature = "enterprise")]
pub fn convert_month_day_nano_to_duration(
    interval: arrow_buffer::IntervalMonthDayNano,
) -> Result<std::time::Duration> {
    const NANOS_PER_SECOND: i64 = 1_000_000_000;
    const SECONDS_PER_DAY: i64 = 60 * 60 * 24;

    let seconds_from_months = i64::from(interval.months) * SECONDS_PER_DAY * 3044 / 1000;
    let seconds_from_days = i64::from(interval.days) * SECONDS_PER_DAY;

    // Euclidean division keeps the sub-second part in `0..NANOS_PER_SECOND`
    // and borrows a whole second when the nanosecond component is negative.
    let seconds_from_nanos = interval.nanoseconds.div_euclid(NANOS_PER_SECOND);
    let subsec_nanos = interval.nanoseconds.rem_euclid(NANOS_PER_SECOND);

    let whole_seconds = seconds_from_months + seconds_from_days + seconds_from_nanos;

    snafu::ensure!(
        whole_seconds >= 0,
        crate::error::InvalidIntervalSnafu {
            reason: "must be a positive interval",
        }
    );

    // Cast safety: `whole_seconds` was just checked to be non-negative, and
    // `subsec_nanos` lies in `0..1_000_000_000`, which fits in a `u32`.
    Ok(std::time::Duration::new(
        whole_seconds as u64,
        subsec_nanos as u32,
    ))
}
398
399pub fn parse_with_options(parser: &mut Parser) -> Result<OptionMap> {
400    let options = parser
401        .parse_options(Keyword::WITH)
402        .context(SyntaxSnafu)?
403        .into_iter()
404        .map(parse_option_string)
405        .collect::<Result<HashMap<String, OptionValue>>>()?;
406    for (key, value) in &options {
407        if key.starts_with(SEMANTIC_PREFIX) {
408            // Semantic keys are whitelisted and value-checked against their domain,
409            // so a user cannot set an unknown key or an out-of-range value.
410            let value = value.as_string().unwrap_or_default();
411            ensure!(
412                validate_semantic_option(key, value),
413                InvalidTableOptionSnafu { key }
414            );
415        } else {
416            ensure!(validate_table_option(key), InvalidTableOptionSnafu { key });
417        }
418    }
419    Ok(OptionMap::new(options))
420}
421
#[cfg(test)]
mod tests {
    use chrono::DateTime;
    use datafusion::functions::datetime::expr_fn::now;
    use datafusion_expr::lit;
    use datatypes::arrow::datatypes::TimestampNanosecondType;

    use super::*;
    use crate::dialect::GreptimeDbDialect;
    use crate::parser::{ParseOptions, ParserContext};
    use crate::statements::statement::Statement;

    /// Covers plain TQL, simple TQL CTE queries, and the shapes that must not
    /// be classified as TQL.
    #[test]
    fn test_is_tql() {
        let dialect = GreptimeDbDialect {};
        let check = |sql: &str| is_tql(&dialect, sql).unwrap();

        // Plain statements.
        assert!(check("TQL EVAL (0, 10, '1s') cpu_usage_total"));
        assert!(!check("SELECT 1"));

        // A single TQL CTE selected with `SELECT *` counts as TQL.
        let tql_cte = r#"
WITH tql_cte(ts, val) AS (
    TQL EVAL (0, 15, '5s') metric
)
SELECT * FROM tql_cte
"#;
        assert!(check(tql_cte));

        // Column renaming keeps the query a TQL one and is preserved by the parser.
        let rename_cols = r#"
WITH tql (the_timestamp, the_value) AS (
    TQL EVAL (0, 40, '10s') metric
)
SELECT * FROM tql
"#;
        assert!(check(rename_cols));
        let parsed =
            ParserContext::create_with_dialect(rename_cols, &dialect, ParseOptions::default())
                .unwrap();
        let Statement::Query(q) = &parsed[0] else {
            panic!("Expected Query statement");
        };
        let hybrid = q.hybrid_cte.as_ref().expect("Expected hybrid cte");
        assert_eq!(hybrid.cte_tables.len(), 1);
        let renamed = &hybrid.cte_tables[0].columns;
        assert_eq!(renamed.len(), 2);
        assert_eq!(renamed[0].to_string(), "the_timestamp");
        assert_eq!(renamed[1].to_string(), "the_value");

        // A pure SQL CTE is not TQL.
        let sql_cte = r#"
WITH cte AS (SELECT 1)
SELECT * FROM cte
"#;
        assert!(!check(sql_cte));

        // Mixing in a SQL CTE disqualifies the query.
        let extra_sql_cte = r#"
WITH sql_cte AS (SELECT 1), tql_cte(ts, val) AS (
    TQL EVAL (0, 15, '5s') metric
)
SELECT * FROM tql_cte
"#;
        assert!(!check(extra_sql_cte));

        // Anything other than `SELECT *` disqualifies the query.
        let not_select_star = r#"
WITH tql_cte(ts, val) AS (
    TQL EVAL (0, 15, '5s') metric
)
SELECT ts FROM tql_cte
"#;
        assert!(!check(not_select_star));

        // So does a filter.
        let with_filter = r#"
WITH tql_cte(ts, val) AS (
    TQL EVAL (0, 15, '5s') metric
)
SELECT * FROM tql_cte WHERE ts > 0
"#;
        assert!(!check(with_filter));
    }

    /// Keep this test to make sure we are using datafusion's `ExprSimplifier` correctly.
    #[test]
    fn test_simplifier() {
        // Pin "now" to 61 seconds after the epoch so results are deterministic.
        let now_time = DateTime::from_timestamp(61, 0).unwrap();
        let timestamp_ns =
            |nanos: Option<i64>| lit(ScalarValue::new_timestamp::<TimestampNanosecondType>(nanos, None));
        let duration_ns = |nanos: i64| lit(ScalarValue::DurationNanosecond(Some(nanos)));
        let interval_1500ms = || lit(ScalarValue::new_interval_dt(0, 1500));
        let interval_61_days = || {
            lit(ScalarValue::new_interval_mdn(
                0,
                0,
                61 * 86400 * 1_000_000_000,
            ))
        };

        let testcases = vec![
            (now(), timestamp_ns(now_time.timestamp_nanos_opt())),
            (now() - now(), duration_ns(0)),
            (now() + interval_1500ms(), timestamp_ns(Some(62500000000))),
            (
                now() - (now() + interval_1500ms()),
                duration_ns(-1500000000),
            ),
            // this one failed if type is not coerced
            (
                now() - now() + interval_1500ms(),
                lit(ScalarValue::new_interval_mdn(0, 0, 1500000000)),
            ),
            (interval_61_days(), interval_61_days()),
        ];

        let info = SimplifyContext::default().with_query_execution_start_time(Some(now_time));
        let simplifier = ExprSimplifier::new(info);
        let empty_schema = DFSchema::empty();
        for (expr, expected) in testcases {
            let expr_name = expr.schema_name().to_string();
            let coerced = simplifier.coerce(expr, &empty_schema).unwrap();
            let simplified_expr = simplifier.simplify(coerced).unwrap();
            assert_eq!(
                simplified_expr, expected,
                "Failed to simplify expression: {expr_name}"
            );
        }
    }
}