sql/parsers/
utils.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::sync::Arc;
17
18use chrono::Utc;
19use datafusion::config::ConfigOptions;
20use datafusion::error::Result as DfResult;
21use datafusion::execution::SessionStateBuilder;
22use datafusion::execution::context::SessionState;
23use datafusion::optimizer::simplify_expressions::ExprSimplifier;
24use datafusion_common::tree_node::{TreeNode, TreeNodeVisitor};
25use datafusion_common::{DFSchema, ScalarValue};
26use datafusion_expr::execution_props::ExecutionProps;
27use datafusion_expr::simplify::SimplifyContext;
28use datafusion_expr::{AggregateUDF, Expr, ScalarUDF, TableSource, WindowUDF};
29use datafusion_sql::TableReference;
30use datafusion_sql::planner::{ContextProvider, SqlToRel};
31use datatypes::arrow::datatypes::DataType;
32use datatypes::schema::{
33    COLUMN_FULLTEXT_OPT_KEY_ANALYZER, COLUMN_FULLTEXT_OPT_KEY_BACKEND,
34    COLUMN_FULLTEXT_OPT_KEY_CASE_SENSITIVE, COLUMN_FULLTEXT_OPT_KEY_FALSE_POSITIVE_RATE,
35    COLUMN_FULLTEXT_OPT_KEY_GRANULARITY, COLUMN_SKIPPING_INDEX_OPT_KEY_FALSE_POSITIVE_RATE,
36    COLUMN_SKIPPING_INDEX_OPT_KEY_GRANULARITY, COLUMN_SKIPPING_INDEX_OPT_KEY_TYPE,
37    COLUMN_VECTOR_INDEX_OPT_KEY_CONNECTIVITY, COLUMN_VECTOR_INDEX_OPT_KEY_ENGINE,
38    COLUMN_VECTOR_INDEX_OPT_KEY_EXPANSION_ADD, COLUMN_VECTOR_INDEX_OPT_KEY_EXPANSION_SEARCH,
39    COLUMN_VECTOR_INDEX_OPT_KEY_METRIC,
40};
41use snafu::{ResultExt, ensure};
42use sqlparser::dialect::Dialect;
43use sqlparser::keywords::Keyword;
44use sqlparser::parser::Parser;
45use table::requests::validate_table_option;
46
47use crate::error::{
48    ConvertToLogicalExpressionSnafu, InvalidSqlSnafu, InvalidTableOptionSnafu, ParseSqlValueSnafu,
49    Result, SimplificationSnafu, SyntaxSnafu,
50};
51use crate::parser::{ParseOptions, ParserContext};
52use crate::statements::OptionMap;
53use crate::statements::statement::Statement;
54use crate::util::{OptionValue, parse_option_string};
55
56/// Check if the given SQL query is a TQL statement.
57pub fn is_tql(dialect: &dyn Dialect, sql: &str) -> Result<bool> {
58    let stmts = ParserContext::create_with_dialect(sql, dialect, ParseOptions::default())?;
59
60    ensure!(
61        stmts.len() == 1,
62        InvalidSqlSnafu {
63            msg: format!("Expect only one statement, found {}", stmts.len())
64        }
65    );
66    let stmt = &stmts[0];
67    match stmt {
68        Statement::Tql(_) => Ok(true),
69        _ => Ok(false),
70    }
71}
72
73/// Convert a parser expression to a scalar value. This function will try the
74/// best to resolve and reduce constants. Exprs like `1 + 1` or `now()` can be
75/// handled properly.
76///
77/// if `require_now_expr` is true, it will ensure that the expression contains a `now()` function.
78/// If the expression does not contain `now()`, it will return an error.
79///
80pub fn parser_expr_to_scalar_value_literal(
81    expr: sqlparser::ast::Expr,
82    require_now_expr: bool,
83) -> Result<ScalarValue> {
84    // 1. convert parser expr to logical expr
85    let empty_df_schema = DFSchema::empty();
86    let logical_expr = SqlToRel::new(&StubContextProvider::default())
87        .sql_to_expr(expr, &empty_df_schema, &mut Default::default())
88        .context(ConvertToLogicalExpressionSnafu)?;
89
90    struct FindNow {
91        found: bool,
92    }
93
94    impl TreeNodeVisitor<'_> for FindNow {
95        type Node = Expr;
96        fn f_down(
97            &mut self,
98            node: &Self::Node,
99        ) -> DfResult<datafusion_common::tree_node::TreeNodeRecursion> {
100            if let Expr::ScalarFunction(func) = node
101                && func.name().to_lowercase() == "now"
102            {
103                if !func.args.is_empty() {
104                    return Err(datafusion_common::DataFusionError::Plan(
105                        "now() function should not have arguments".to_string(),
106                    ));
107                }
108                self.found = true;
109                return Ok(datafusion_common::tree_node::TreeNodeRecursion::Stop);
110            }
111            Ok(datafusion_common::tree_node::TreeNodeRecursion::Continue)
112        }
113    }
114
115    if require_now_expr {
116        let have_now = {
117            let mut visitor = FindNow { found: false };
118            logical_expr.visit(&mut visitor).unwrap();
119            visitor.found
120        };
121        if !have_now {
122            return ParseSqlValueSnafu {
123                msg: format!(
124                    "expected now() expression, but not found in {}",
125                    logical_expr
126                ),
127            }
128            .fail();
129        }
130    }
131
132    // 2. simplify logical expr
133    let execution_props = ExecutionProps::new().with_query_execution_start_time(Utc::now());
134    let info =
135        SimplifyContext::new(&execution_props).with_schema(Arc::new(empty_df_schema.clone()));
136
137    let simplifier = ExprSimplifier::new(info);
138
139    // Coerce the logical expression so simplifier can handle it correctly. This is necessary for const eval with possible type mismatch. i.e.: `now() - now() + '15s'::interval` which is `TimestampNanosecond - TimestampNanosecond + IntervalMonthDayNano`.
140    let logical_expr = simplifier
141        .coerce(logical_expr, &empty_df_schema)
142        .context(SimplificationSnafu)?;
143
144    let simplified_expr = simplifier
145        .simplify(logical_expr)
146        .context(SimplificationSnafu)?;
147
148    if let datafusion::logical_expr::Expr::Literal(lit, _) = simplified_expr {
149        Ok(lit)
150    } else {
151        // Err(ParseSqlValue)
152        ParseSqlValueSnafu {
153            msg: format!("expected literal value, but found {:?}", simplified_expr),
154        }
155        .fail()
156    }
157}
158
159/// Helper struct for [`parser_expr_to_scalar_value`].
160struct StubContextProvider {
161    state: SessionState,
162}
163
164impl Default for StubContextProvider {
165    fn default() -> Self {
166        Self {
167            state: SessionStateBuilder::new()
168                .with_config(Default::default())
169                .with_runtime_env(Default::default())
170                .with_default_features()
171                .build(),
172        }
173    }
174}
175
176impl ContextProvider for StubContextProvider {
177    fn get_table_source(&self, _name: TableReference) -> DfResult<Arc<dyn TableSource>> {
178        unimplemented!()
179    }
180
181    fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
182        self.state.scalar_functions().get(name).cloned()
183    }
184
185    fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
186        self.state.aggregate_functions().get(name).cloned()
187    }
188
189    fn get_window_meta(&self, _name: &str) -> Option<Arc<WindowUDF>> {
190        unimplemented!()
191    }
192
193    fn get_variable_type(&self, _variable_names: &[String]) -> Option<DataType> {
194        unimplemented!()
195    }
196
197    fn options(&self) -> &ConfigOptions {
198        self.state.config_options()
199    }
200
201    fn udf_names(&self) -> Vec<String> {
202        self.state.scalar_functions().keys().cloned().collect()
203    }
204
205    fn udaf_names(&self) -> Vec<String> {
206        self.state.aggregate_functions().keys().cloned().collect()
207    }
208
209    fn udwf_names(&self) -> Vec<String> {
210        self.state.window_functions().keys().cloned().collect()
211    }
212}
213
214pub fn validate_column_fulltext_create_option(key: &str) -> bool {
215    [
216        COLUMN_FULLTEXT_OPT_KEY_ANALYZER,
217        COLUMN_FULLTEXT_OPT_KEY_CASE_SENSITIVE,
218        COLUMN_FULLTEXT_OPT_KEY_BACKEND,
219        COLUMN_FULLTEXT_OPT_KEY_GRANULARITY,
220        COLUMN_FULLTEXT_OPT_KEY_FALSE_POSITIVE_RATE,
221    ]
222    .contains(&key)
223}
224
225pub fn validate_column_skipping_index_create_option(key: &str) -> bool {
226    [
227        COLUMN_SKIPPING_INDEX_OPT_KEY_GRANULARITY,
228        COLUMN_SKIPPING_INDEX_OPT_KEY_TYPE,
229        COLUMN_SKIPPING_INDEX_OPT_KEY_FALSE_POSITIVE_RATE,
230    ]
231    .contains(&key)
232}
233
234pub fn validate_column_vector_index_create_option(key: &str) -> bool {
235    [
236        COLUMN_VECTOR_INDEX_OPT_KEY_ENGINE,
237        COLUMN_VECTOR_INDEX_OPT_KEY_METRIC,
238        COLUMN_VECTOR_INDEX_OPT_KEY_CONNECTIVITY,
239        COLUMN_VECTOR_INDEX_OPT_KEY_EXPANSION_ADD,
240        COLUMN_VECTOR_INDEX_OPT_KEY_EXPANSION_SEARCH,
241    ]
242    .contains(&key)
243}
244
/// Converts an [`arrow_buffer::IntervalMonthDayNano`] into a [`std::time::Duration`].
///
/// Each month counts as `86_400 * 3044 / 1000` seconds (about 30.44 days), with
/// integer division applied to the whole month product. Each day counts as
/// 86 400 seconds. The nanosecond part is split with floor division, so a
/// negative nanosecond component borrows one second from the total.
///
/// # Errors
///
/// Returns an `InvalidInterval` error when the combined interval is negative.
#[cfg(feature = "enterprise")]
pub fn convert_month_day_nano_to_duration(
    interval: arrow_buffer::IntervalMonthDayNano,
) -> Result<std::time::Duration> {
    const NANOS_PER_SECOND: i64 = 1_000_000_000;
    const SECONDS_PER_DAY: i64 = 60 * 60 * 24;

    let months = i64::from(interval.months);
    let days = i64::from(interval.days);

    // `div_euclid`/`rem_euclid` with a positive divisor floor the quotient and
    // keep the remainder in `0..NANOS_PER_SECOND`.
    let whole_seconds = months * SECONDS_PER_DAY * 3044 / 1000
        + days * SECONDS_PER_DAY
        + interval.nanoseconds.div_euclid(NANOS_PER_SECOND);
    let sub_second_nanos = interval.nanoseconds.rem_euclid(NANOS_PER_SECOND);

    snafu::ensure!(
        whole_seconds >= 0,
        crate::error::InvalidIntervalSnafu {
            reason: "must be a positive interval",
        }
    );

    // Both casts are lossless: `whole_seconds` was just checked to be
    // non-negative, and `sub_second_nanos` lies in `0..1_000_000_000`.
    Ok(std::time::Duration::new(
        whole_seconds as u64,
        sub_second_nanos as u32,
    ))
}
280
281pub fn parse_with_options(parser: &mut Parser) -> Result<OptionMap> {
282    let options = parser
283        .parse_options(Keyword::WITH)
284        .context(SyntaxSnafu)?
285        .into_iter()
286        .map(parse_option_string)
287        .collect::<Result<HashMap<String, OptionValue>>>()?;
288    for key in options.keys() {
289        ensure!(validate_table_option(key), InvalidTableOptionSnafu { key });
290    }
291    Ok(OptionMap::new(options))
292}
293
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use chrono::DateTime;
    use datafusion::functions::datetime::expr_fn::now;
    use datafusion_expr::lit;
    use datatypes::arrow::datatypes::TimestampNanosecondType;

    use super::*;

    /// Nanosecond timestamp literal without a timezone.
    fn timestamp_nanos_lit(nanos: Option<i64>) -> Expr {
        lit(ScalarValue::new_timestamp::<TimestampNanosecondType>(
            nanos, None,
        ))
    }

    /// Day-time interval literal of 1.5 seconds.
    fn one_and_a_half_seconds() -> Expr {
        lit(ScalarValue::new_interval_dt(0, 1500))
    }

    /// Keep this test to make sure we are using datafusion's `ExprSimplifier` correctly.
    #[test]
    fn test_simplifier() {
        let frozen_now = DateTime::from_timestamp(61, 0).unwrap();
        let sixty_one_days_nanos: i64 = 61 * 86400 * 1_000_000_000;

        let cases = [
            (now(), timestamp_nanos_lit(frozen_now.timestamp_nanos_opt())),
            (now() - now(), lit(ScalarValue::DurationNanosecond(Some(0)))),
            (
                now() + one_and_a_half_seconds(),
                timestamp_nanos_lit(Some(62_500_000_000)),
            ),
            (
                now() - (now() + one_and_a_half_seconds()),
                lit(ScalarValue::DurationNanosecond(Some(-1_500_000_000))),
            ),
            // Fails unless the expression is type-coerced before simplification.
            (
                now() - now() + one_and_a_half_seconds(),
                lit(ScalarValue::new_interval_mdn(0, 0, 1_500_000_000)),
            ),
            (
                lit(ScalarValue::new_interval_mdn(0, 0, sixty_one_days_nanos)),
                lit(ScalarValue::new_interval_mdn(0, 0, sixty_one_days_nanos)),
            ),
        ];

        let empty_schema = DFSchema::empty();
        let execution_props = ExecutionProps::new().with_query_execution_start_time(frozen_now);
        let simplifier = ExprSimplifier::new(
            SimplifyContext::new(&execution_props).with_schema(Arc::new(DFSchema::empty())),
        );

        for (input, expected) in cases {
            let label = input.schema_name().to_string();
            let coerced = simplifier.coerce(input, &empty_schema).unwrap();
            let actual = simplifier.simplify(coerced).unwrap();
            assert_eq!(actual, expected, "Failed to simplify expression: {label}");
        }
    }
}