sql/parsers/
utils.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::sync::Arc;
17
18use datafusion::config::ConfigOptions;
19use datafusion::error::Result as DfResult;
20use datafusion::execution::SessionStateBuilder;
21use datafusion::execution::context::SessionState;
22use datafusion::optimizer::simplify_expressions::ExprSimplifier;
23use datafusion_common::tree_node::{TreeNode, TreeNodeVisitor};
24use datafusion_common::{DFSchema, ScalarValue};
25use datafusion_expr::simplify::SimplifyContext;
26use datafusion_expr::{AggregateUDF, Expr, ScalarUDF, TableSource, WindowUDF};
27use datafusion_sql::TableReference;
28use datafusion_sql::planner::{ContextProvider, SqlToRel};
29use datatypes::arrow::datatypes::DataType;
30use datatypes::schema::{
31    COLUMN_FULLTEXT_OPT_KEY_ANALYZER, COLUMN_FULLTEXT_OPT_KEY_BACKEND,
32    COLUMN_FULLTEXT_OPT_KEY_CASE_SENSITIVE, COLUMN_FULLTEXT_OPT_KEY_FALSE_POSITIVE_RATE,
33    COLUMN_FULLTEXT_OPT_KEY_GRANULARITY, COLUMN_SKIPPING_INDEX_OPT_KEY_FALSE_POSITIVE_RATE,
34    COLUMN_SKIPPING_INDEX_OPT_KEY_GRANULARITY, COLUMN_SKIPPING_INDEX_OPT_KEY_TYPE,
35    COLUMN_VECTOR_INDEX_OPT_KEY_CONNECTIVITY, COLUMN_VECTOR_INDEX_OPT_KEY_ENGINE,
36    COLUMN_VECTOR_INDEX_OPT_KEY_EXPANSION_ADD, COLUMN_VECTOR_INDEX_OPT_KEY_EXPANSION_SEARCH,
37    COLUMN_VECTOR_INDEX_OPT_KEY_METRIC,
38};
39use snafu::{ResultExt, ensure};
40use sqlparser::dialect::Dialect;
41use sqlparser::keywords::Keyword;
42use sqlparser::parser::Parser;
43use table::requests::validate_table_option;
44
45use crate::error::{
46    ConvertToLogicalExpressionSnafu, InvalidSqlSnafu, InvalidTableOptionSnafu, ParseSqlValueSnafu,
47    Result, SimplificationSnafu, SyntaxSnafu,
48};
49use crate::parser::{ParseOptions, ParserContext};
50use crate::parsers::with_tql_parser::CteContent;
51use crate::statements::OptionMap;
52use crate::statements::query::Query;
53use crate::statements::statement::Statement;
54use crate::util::{OptionValue, parse_option_string};
55
56/// Check if the given SQL query is a TQL statement. Simple tql cte query is also considered as TQL statement.
57pub fn is_tql(dialect: &dyn Dialect, sql: &str) -> Result<bool> {
58    let stmts = ParserContext::create_with_dialect(sql, dialect, ParseOptions::default())?;
59
60    ensure!(
61        stmts.len() == 1,
62        InvalidSqlSnafu {
63            msg: format!("Expect only one statement, found {}", stmts.len())
64        }
65    );
66    let stmt = &stmts[0];
67    match stmt {
68        Statement::Tql(_) => Ok(true),
69        Statement::Query(query) => Ok(is_simple_tql_cte_query(query)),
70        _ => Ok(false),
71    }
72}
73
74pub(crate) fn is_simple_tql_cte_query(query: &Query) -> bool {
75    use crate::parser::ParserContext;
76
77    let Some(hybrid_cte) = &query.hybrid_cte else {
78        return false;
79    };
80
81    if !has_only_hybrid_tql_cte(query) {
82        return false;
83    }
84
85    let Some(cte) = hybrid_cte.cte_tables.first() else {
86        return false;
87    };
88    if hybrid_cte.cte_tables.len() != 1 || !matches!(cte.content, CteContent::Tql(_)) {
89        return false;
90    }
91
92    let Some(reference) = extract_simple_select_star_reference(query) else {
93        return false;
94    };
95
96    let reference = ParserContext::canonicalize_identifier(reference).value;
97    let cte_name = ParserContext::canonicalize_identifier(cte.name.clone()).value;
98    reference == cte_name
99}
100
101fn has_only_hybrid_tql_cte(query: &Query) -> bool {
102    query
103        .inner
104        .with
105        .as_ref()
106        .is_none_or(|with| with.cte_tables.is_empty())
107}
108
109fn extract_simple_select_star_reference(query: &Query) -> Option<sqlparser::ast::Ident> {
110    use sqlparser::ast::{SetExpr, TableFactor};
111
112    if !is_plain_query_root(&query.inner) {
113        return None;
114    }
115
116    let SetExpr::Select(select) = &*query.inner.body else {
117        return None;
118    };
119    if !is_plain_select(select) || !is_plain_wildcard_projection(select.projection.as_slice()) {
120        return None;
121    }
122
123    let [table_with_joins] = select.from.as_slice() else {
124        return None;
125    };
126    if !table_with_joins.joins.is_empty() {
127        return None;
128    }
129
130    let TableFactor::Table { name, .. } = &table_with_joins.relation else {
131        return None;
132    };
133    if name.0.len() != 1 {
134        return None;
135    }
136
137    name.0[0].as_ident().cloned()
138}
139
140fn is_plain_query_root(query: &sqlparser::ast::Query) -> bool {
141    query.order_by.is_none()
142        && query.limit_clause.is_none()
143        && query.fetch.is_none()
144        && query.locks.is_empty()
145        && query.for_clause.is_none()
146        && query.settings.is_none()
147        && query.format_clause.is_none()
148        && query.pipe_operators.is_empty()
149}
150
151fn is_plain_select(select: &sqlparser::ast::Select) -> bool {
152    use sqlparser::ast::GroupByExpr;
153
154    select.distinct.is_none()
155        && select.top.is_none()
156        && select.exclude.is_none()
157        && select.into.is_none()
158        && select.lateral_views.is_empty()
159        && select.prewhere.is_none()
160        && select.selection.is_none()
161        && matches!(select.group_by, GroupByExpr::Expressions(ref exprs, _) if exprs.is_empty())
162        && select.cluster_by.is_empty()
163        && select.distribute_by.is_empty()
164        && select.sort_by.is_empty()
165        && select.having.is_none()
166        && select.named_window.is_empty()
167        && select.qualify.is_none()
168        && select.value_table_mode.is_none()
169        && select.connect_by.is_empty()
170}
171
172fn is_plain_wildcard_projection(projection: &[sqlparser::ast::SelectItem]) -> bool {
173    use sqlparser::ast::SelectItem;
174
175    matches!(
176        projection,
177        [SelectItem::Wildcard(options)]
178            if options.opt_ilike.is_none()
179                && options.opt_exclude.is_none()
180                && options.opt_except.is_none()
181                && options.opt_replace.is_none()
182                && options.opt_rename.is_none()
183    )
184}
185
186/// Convert a parser expression to a scalar value. This function will try the
187/// best to resolve and reduce constants. Exprs like `1 + 1` or `now()` can be
188/// handled properly.
189///
190/// if `require_now_expr` is true, it will ensure that the expression contains a `now()` function.
191/// If the expression does not contain `now()`, it will return an error.
192///
193pub fn parser_expr_to_scalar_value_literal(
194    expr: sqlparser::ast::Expr,
195    require_now_expr: bool,
196) -> Result<ScalarValue> {
197    // 1. convert parser expr to logical expr
198    let empty_df_schema = DFSchema::empty();
199    let logical_expr = SqlToRel::new(&StubContextProvider::default())
200        .sql_to_expr(expr, &empty_df_schema, &mut Default::default())
201        .context(ConvertToLogicalExpressionSnafu)?;
202
203    struct FindNow {
204        found: bool,
205    }
206
207    impl TreeNodeVisitor<'_> for FindNow {
208        type Node = Expr;
209        fn f_down(
210            &mut self,
211            node: &Self::Node,
212        ) -> DfResult<datafusion_common::tree_node::TreeNodeRecursion> {
213            if let Expr::ScalarFunction(func) = node
214                && func.name().to_lowercase() == "now"
215            {
216                if !func.args.is_empty() {
217                    return Err(datafusion_common::DataFusionError::Plan(
218                        "now() function should not have arguments".to_string(),
219                    ));
220                }
221                self.found = true;
222                return Ok(datafusion_common::tree_node::TreeNodeRecursion::Stop);
223            }
224            Ok(datafusion_common::tree_node::TreeNodeRecursion::Continue)
225        }
226    }
227
228    if require_now_expr {
229        let have_now = {
230            let mut visitor = FindNow { found: false };
231            logical_expr.visit(&mut visitor).unwrap();
232            visitor.found
233        };
234        if !have_now {
235            return ParseSqlValueSnafu {
236                msg: format!(
237                    "expected now() expression, but not found in {}",
238                    logical_expr
239                ),
240            }
241            .fail();
242        }
243    }
244
245    // 2. simplify logical expr
246    let info = SimplifyContext::default().with_current_time();
247    let simplifier = ExprSimplifier::new(info);
248
249    // Coerce the logical expression so simplifier can handle it correctly. This is necessary for const eval with possible type mismatch. i.e.: `now() - now() + '15s'::interval` which is `TimestampNanosecond - TimestampNanosecond + IntervalMonthDayNano`.
250    let logical_expr = simplifier
251        .coerce(logical_expr, &empty_df_schema)
252        .context(SimplificationSnafu)?;
253
254    let simplified_expr = simplifier
255        .simplify(logical_expr)
256        .context(SimplificationSnafu)?;
257
258    if let datafusion::logical_expr::Expr::Literal(lit, _) = simplified_expr {
259        Ok(lit)
260    } else {
261        // Err(ParseSqlValue)
262        ParseSqlValueSnafu {
263            msg: format!("expected literal value, but found {:?}", simplified_expr),
264        }
265        .fail()
266    }
267}
268
/// Helper struct for [`parser_expr_to_scalar_value_literal`].
///
/// Wraps a [`SessionState`] so that function lookups and configuration
/// options can be served while converting a standalone SQL expression.
struct StubContextProvider {
    /// Session state backing the function registries and config options.
    state: SessionState,
}
273
274impl Default for StubContextProvider {
275    fn default() -> Self {
276        Self {
277            state: SessionStateBuilder::new()
278                .with_config(Default::default())
279                .with_runtime_env(Default::default())
280                .with_default_features()
281                .build(),
282        }
283    }
284}
285
286impl ContextProvider for StubContextProvider {
287    fn get_table_source(&self, _name: TableReference) -> DfResult<Arc<dyn TableSource>> {
288        unimplemented!()
289    }
290
291    fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
292        self.state.scalar_functions().get(name).cloned()
293    }
294
295    fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
296        self.state.aggregate_functions().get(name).cloned()
297    }
298
299    fn get_window_meta(&self, _name: &str) -> Option<Arc<WindowUDF>> {
300        unimplemented!()
301    }
302
303    fn get_variable_type(&self, _variable_names: &[String]) -> Option<DataType> {
304        unimplemented!()
305    }
306
307    fn options(&self) -> &ConfigOptions {
308        self.state.config_options()
309    }
310
311    fn udf_names(&self) -> Vec<String> {
312        self.state.scalar_functions().keys().cloned().collect()
313    }
314
315    fn udaf_names(&self) -> Vec<String> {
316        self.state.aggregate_functions().keys().cloned().collect()
317    }
318
319    fn udwf_names(&self) -> Vec<String> {
320        self.state.window_functions().keys().cloned().collect()
321    }
322}
323
324pub fn validate_column_fulltext_create_option(key: &str) -> bool {
325    [
326        COLUMN_FULLTEXT_OPT_KEY_ANALYZER,
327        COLUMN_FULLTEXT_OPT_KEY_CASE_SENSITIVE,
328        COLUMN_FULLTEXT_OPT_KEY_BACKEND,
329        COLUMN_FULLTEXT_OPT_KEY_GRANULARITY,
330        COLUMN_FULLTEXT_OPT_KEY_FALSE_POSITIVE_RATE,
331    ]
332    .contains(&key)
333}
334
335pub fn validate_column_skipping_index_create_option(key: &str) -> bool {
336    [
337        COLUMN_SKIPPING_INDEX_OPT_KEY_GRANULARITY,
338        COLUMN_SKIPPING_INDEX_OPT_KEY_TYPE,
339        COLUMN_SKIPPING_INDEX_OPT_KEY_FALSE_POSITIVE_RATE,
340    ]
341    .contains(&key)
342}
343
344pub fn validate_column_vector_index_create_option(key: &str) -> bool {
345    [
346        COLUMN_VECTOR_INDEX_OPT_KEY_ENGINE,
347        COLUMN_VECTOR_INDEX_OPT_KEY_METRIC,
348        COLUMN_VECTOR_INDEX_OPT_KEY_CONNECTIVITY,
349        COLUMN_VECTOR_INDEX_OPT_KEY_EXPANSION_ADD,
350        COLUMN_VECTOR_INDEX_OPT_KEY_EXPANSION_SEARCH,
351    ]
352    .contains(&key)
353}
354
/// Convert an [`IntervalMonthDayNano`] to a [`Duration`].
///
/// Months are converted to seconds with the integer factor `3044 / 1000` days
/// per month, and days as 24 hours each. The conversion fails with an
/// invalid-interval error if the resulting whole-second count is negative.
#[cfg(feature = "enterprise")]
pub fn convert_month_day_nano_to_duration(
    interval: arrow_buffer::IntervalMonthDayNano,
) -> Result<std::time::Duration> {
    const SECONDS_PER_DAY: i64 = 60 * 60 * 24;
    const NANOS_PER_SECOND: i64 = 1_000_000_000;

    let month_seconds = i64::from(interval.months) * SECONDS_PER_DAY * 3044 / 1000;
    let day_seconds = i64::from(interval.days) * SECONDS_PER_DAY;
    let whole_seconds = month_seconds + day_seconds + interval.nanoseconds / NANOS_PER_SECOND;
    let leftover_nanos = interval.nanoseconds % NANOS_PER_SECOND;

    // A negative sub-second remainder borrows one second so that the
    // nanosecond part ends up in `0..NANOS_PER_SECOND`.
    let (seconds, subsec_nanos) = if leftover_nanos < 0 {
        (whole_seconds - 1, leftover_nanos + NANOS_PER_SECOND)
    } else {
        (whole_seconds, leftover_nanos)
    };

    snafu::ensure!(
        seconds >= 0,
        crate::error::InvalidIntervalSnafu {
            reason: "must be a positive interval",
        }
    );

    // Cast safety: `seconds` was just checked to be non-negative, and
    // `subsec_nanos` was normalized above to be non-negative and below
    // `NANOS_PER_SECOND`, which fits in `u32`.
    Ok(std::time::Duration::new(
        seconds as u64,
        subsec_nanos as u32,
    ))
}
390
391pub fn parse_with_options(parser: &mut Parser) -> Result<OptionMap> {
392    let options = parser
393        .parse_options(Keyword::WITH)
394        .context(SyntaxSnafu)?
395        .into_iter()
396        .map(parse_option_string)
397        .collect::<Result<HashMap<String, OptionValue>>>()?;
398    for key in options.keys() {
399        ensure!(validate_table_option(key), InvalidTableOptionSnafu { key });
400    }
401    Ok(OptionMap::new(options))
402}
403
#[cfg(test)]
mod tests {
    use chrono::DateTime;
    use datafusion::functions::datetime::expr_fn::now;
    use datafusion_expr::lit;
    use datatypes::arrow::datatypes::TimestampNanosecondType;

    use super::*;
    use crate::dialect::GreptimeDbDialect;
    use crate::parser::{ParseOptions, ParserContext};
    use crate::statements::statement::Statement;

    /// Exercises `is_tql` on plain TQL, plain SQL, and several CTE shapes.
    #[test]
    fn test_is_tql() {
        let dialect = GreptimeDbDialect {};

        // A bare TQL statement is TQL; a bare SELECT is not.
        assert!(is_tql(&dialect, "TQL EVAL (0, 10, '1s') cpu_usage_total").unwrap());
        assert!(!is_tql(&dialect, "SELECT 1").unwrap());

        // A single TQL CTE followed by `SELECT *` from it counts as TQL.
        let tql_cte = r#"
WITH tql_cte(ts, val) AS (
    TQL EVAL (0, 15, '5s') metric
)
SELECT * FROM tql_cte
"#;
        assert!(is_tql(&dialect, tql_cte).unwrap());

        // Renaming the CTE columns does not change the verdict, and the
        // column aliases are kept on the parsed hybrid CTE.
        let rename_cols = r#"
WITH tql (the_timestamp, the_value) AS (
    TQL EVAL (0, 40, '10s') metric
)
SELECT * FROM tql
"#;
        assert!(is_tql(&dialect, rename_cols).unwrap());
        let stmts =
            ParserContext::create_with_dialect(rename_cols, &dialect, ParseOptions::default())
                .unwrap();
        let Statement::Query(q) = &stmts[0] else {
            panic!("Expected Query statement");
        };
        let hybrid = q.hybrid_cte.as_ref().expect("Expected hybrid cte");
        assert_eq!(hybrid.cte_tables.len(), 1);
        assert_eq!(hybrid.cte_tables[0].columns.len(), 2);
        assert_eq!(hybrid.cte_tables[0].columns[0].to_string(), "the_timestamp");
        assert_eq!(hybrid.cte_tables[0].columns[1].to_string(), "the_value");

        // A regular SQL CTE is not TQL.
        let sql_cte = r#"
WITH cte AS (SELECT 1)
SELECT * FROM cte
"#;
        assert!(!is_tql(&dialect, sql_cte).unwrap());

        // Mixing a SQL CTE with a TQL CTE is not a simple TQL CTE query.
        let extra_sql_cte = r#"
WITH sql_cte AS (SELECT 1), tql_cte(ts, val) AS (
    TQL EVAL (0, 15, '5s') metric
)
SELECT * FROM tql_cte
"#;
        assert!(!is_tql(&dialect, extra_sql_cte).unwrap());

        // Projecting a specific column instead of `*` disqualifies the query.
        let not_select_star = r#"
WITH tql_cte(ts, val) AS (
    TQL EVAL (0, 15, '5s') metric
)
SELECT ts FROM tql_cte
"#;
        assert!(!is_tql(&dialect, not_select_star).unwrap());

        // Adding a WHERE filter disqualifies the query.
        let with_filter = r#"
WITH tql_cte(ts, val) AS (
    TQL EVAL (0, 15, '5s') metric
)
SELECT * FROM tql_cte WHERE ts > 0
"#;
        assert!(!is_tql(&dialect, with_filter).unwrap());
    }

    /// Keep this test to make sure we are using datafusion's `ExprSimplifier` correctly.
    #[test]
    fn test_simplifier() {
        // Fixed "current time" (61 seconds after the epoch) so that `now()`
        // folds to a known value.
        let now_time = DateTime::from_timestamp(61, 0).unwrap();
        let lit_now = lit(ScalarValue::new_timestamp::<TimestampNanosecondType>(
            now_time.timestamp_nanos_opt(),
            None,
        ));
        // Pairs of (input expression, expected simplified literal).
        let testcases = vec![
            (now(), lit_now),
            (now() - now(), lit(ScalarValue::DurationNanosecond(Some(0)))),
            (
                now() + lit(ScalarValue::new_interval_dt(0, 1500)),
                lit(ScalarValue::new_timestamp::<TimestampNanosecondType>(
                    Some(62500000000),
                    None,
                )),
            ),
            (
                now() - (now() + lit(ScalarValue::new_interval_dt(0, 1500))),
                lit(ScalarValue::DurationNanosecond(Some(-1500000000))),
            ),
            // this one failed if type is not coerced
            (
                now() - now() + lit(ScalarValue::new_interval_dt(0, 1500)),
                lit(ScalarValue::new_interval_mdn(0, 0, 1500000000)),
            ),
            (
                lit(ScalarValue::new_interval_mdn(
                    0,
                    0,
                    61 * 86400 * 1_000_000_000,
                )),
                lit(ScalarValue::new_interval_mdn(
                    0,
                    0,
                    61 * 86400 * 1_000_000_000,
                )),
            ),
        ];

        let info = SimplifyContext::default().with_query_execution_start_time(Some(now_time));
        let simplifier = ExprSimplifier::new(info);
        for (expr, expected) in testcases {
            // Capture the name before `expr` is consumed, for the failure message.
            let expr_name = expr.schema_name().to_string();
            // Coerce first, mirroring `parser_expr_to_scalar_value_literal`.
            let expr = simplifier.coerce(expr, &DFSchema::empty()).unwrap();

            let simplified_expr = simplifier.simplify(expr).unwrap();
            assert_eq!(
                simplified_expr, expected,
                "Failed to simplify expression: {expr_name}"
            );
        }
    }
}