sql/parsers/
utils.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::sync::Arc;
17
18use datafusion::config::ConfigOptions;
19use datafusion::error::Result as DfResult;
20use datafusion::execution::SessionStateBuilder;
21use datafusion::execution::context::SessionState;
22use datafusion::optimizer::simplify_expressions::ExprSimplifier;
23use datafusion_common::tree_node::{TreeNode, TreeNodeVisitor};
24use datafusion_common::{DFSchema, ScalarValue};
25use datafusion_expr::simplify::SimplifyContext;
26use datafusion_expr::{AggregateUDF, Expr, ScalarUDF, TableSource, WindowUDF};
27use datafusion_sql::TableReference;
28use datafusion_sql::planner::{ContextProvider, SqlToRel};
29use datatypes::arrow::datatypes::DataType;
30use datatypes::schema::{
31    COLUMN_FULLTEXT_OPT_KEY_ANALYZER, COLUMN_FULLTEXT_OPT_KEY_BACKEND,
32    COLUMN_FULLTEXT_OPT_KEY_CASE_SENSITIVE, COLUMN_FULLTEXT_OPT_KEY_FALSE_POSITIVE_RATE,
33    COLUMN_FULLTEXT_OPT_KEY_GRANULARITY, COLUMN_SKIPPING_INDEX_OPT_KEY_FALSE_POSITIVE_RATE,
34    COLUMN_SKIPPING_INDEX_OPT_KEY_GRANULARITY, COLUMN_SKIPPING_INDEX_OPT_KEY_TYPE,
35    COLUMN_VECTOR_INDEX_OPT_KEY_CONNECTIVITY, COLUMN_VECTOR_INDEX_OPT_KEY_ENGINE,
36    COLUMN_VECTOR_INDEX_OPT_KEY_EXPANSION_ADD, COLUMN_VECTOR_INDEX_OPT_KEY_EXPANSION_SEARCH,
37    COLUMN_VECTOR_INDEX_OPT_KEY_METRIC,
38};
39use snafu::{ResultExt, ensure};
40use sqlparser::dialect::Dialect;
41use sqlparser::keywords::Keyword;
42use sqlparser::parser::Parser;
43use table::requests::validate_table_option;
44
45use crate::error::{
46    ConvertToLogicalExpressionSnafu, InvalidSqlSnafu, InvalidTableOptionSnafu, ParseSqlValueSnafu,
47    Result, SimplificationSnafu, SyntaxSnafu,
48};
49use crate::parser::{ParseOptions, ParserContext};
50use crate::statements::OptionMap;
51use crate::statements::statement::Statement;
52use crate::util::{OptionValue, parse_option_string};
53
54/// Check if the given SQL query is a TQL statement.
55pub fn is_tql(dialect: &dyn Dialect, sql: &str) -> Result<bool> {
56    let stmts = ParserContext::create_with_dialect(sql, dialect, ParseOptions::default())?;
57
58    ensure!(
59        stmts.len() == 1,
60        InvalidSqlSnafu {
61            msg: format!("Expect only one statement, found {}", stmts.len())
62        }
63    );
64    let stmt = &stmts[0];
65    match stmt {
66        Statement::Tql(_) => Ok(true),
67        _ => Ok(false),
68    }
69}
70
71/// Convert a parser expression to a scalar value. This function will try the
72/// best to resolve and reduce constants. Exprs like `1 + 1` or `now()` can be
73/// handled properly.
74///
75/// if `require_now_expr` is true, it will ensure that the expression contains a `now()` function.
76/// If the expression does not contain `now()`, it will return an error.
77///
78pub fn parser_expr_to_scalar_value_literal(
79    expr: sqlparser::ast::Expr,
80    require_now_expr: bool,
81) -> Result<ScalarValue> {
82    // 1. convert parser expr to logical expr
83    let empty_df_schema = DFSchema::empty();
84    let logical_expr = SqlToRel::new(&StubContextProvider::default())
85        .sql_to_expr(expr, &empty_df_schema, &mut Default::default())
86        .context(ConvertToLogicalExpressionSnafu)?;
87
88    struct FindNow {
89        found: bool,
90    }
91
92    impl TreeNodeVisitor<'_> for FindNow {
93        type Node = Expr;
94        fn f_down(
95            &mut self,
96            node: &Self::Node,
97        ) -> DfResult<datafusion_common::tree_node::TreeNodeRecursion> {
98            if let Expr::ScalarFunction(func) = node
99                && func.name().to_lowercase() == "now"
100            {
101                if !func.args.is_empty() {
102                    return Err(datafusion_common::DataFusionError::Plan(
103                        "now() function should not have arguments".to_string(),
104                    ));
105                }
106                self.found = true;
107                return Ok(datafusion_common::tree_node::TreeNodeRecursion::Stop);
108            }
109            Ok(datafusion_common::tree_node::TreeNodeRecursion::Continue)
110        }
111    }
112
113    if require_now_expr {
114        let have_now = {
115            let mut visitor = FindNow { found: false };
116            logical_expr.visit(&mut visitor).unwrap();
117            visitor.found
118        };
119        if !have_now {
120            return ParseSqlValueSnafu {
121                msg: format!(
122                    "expected now() expression, but not found in {}",
123                    logical_expr
124                ),
125            }
126            .fail();
127        }
128    }
129
130    // 2. simplify logical expr
131    let info = SimplifyContext::default().with_current_time();
132    let simplifier = ExprSimplifier::new(info);
133
134    // Coerce the logical expression so simplifier can handle it correctly. This is necessary for const eval with possible type mismatch. i.e.: `now() - now() + '15s'::interval` which is `TimestampNanosecond - TimestampNanosecond + IntervalMonthDayNano`.
135    let logical_expr = simplifier
136        .coerce(logical_expr, &empty_df_schema)
137        .context(SimplificationSnafu)?;
138
139    let simplified_expr = simplifier
140        .simplify(logical_expr)
141        .context(SimplificationSnafu)?;
142
143    if let datafusion::logical_expr::Expr::Literal(lit, _) = simplified_expr {
144        Ok(lit)
145    } else {
146        // Err(ParseSqlValue)
147        ParseSqlValueSnafu {
148            msg: format!("expected literal value, but found {:?}", simplified_expr),
149        }
150        .fail()
151    }
152}
153
154/// Helper struct for [`parser_expr_to_scalar_value`].
155struct StubContextProvider {
156    state: SessionState,
157}
158
159impl Default for StubContextProvider {
160    fn default() -> Self {
161        Self {
162            state: SessionStateBuilder::new()
163                .with_config(Default::default())
164                .with_runtime_env(Default::default())
165                .with_default_features()
166                .build(),
167        }
168    }
169}
170
171impl ContextProvider for StubContextProvider {
172    fn get_table_source(&self, _name: TableReference) -> DfResult<Arc<dyn TableSource>> {
173        unimplemented!()
174    }
175
176    fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
177        self.state.scalar_functions().get(name).cloned()
178    }
179
180    fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
181        self.state.aggregate_functions().get(name).cloned()
182    }
183
184    fn get_window_meta(&self, _name: &str) -> Option<Arc<WindowUDF>> {
185        unimplemented!()
186    }
187
188    fn get_variable_type(&self, _variable_names: &[String]) -> Option<DataType> {
189        unimplemented!()
190    }
191
192    fn options(&self) -> &ConfigOptions {
193        self.state.config_options()
194    }
195
196    fn udf_names(&self) -> Vec<String> {
197        self.state.scalar_functions().keys().cloned().collect()
198    }
199
200    fn udaf_names(&self) -> Vec<String> {
201        self.state.aggregate_functions().keys().cloned().collect()
202    }
203
204    fn udwf_names(&self) -> Vec<String> {
205        self.state.window_functions().keys().cloned().collect()
206    }
207}
208
209pub fn validate_column_fulltext_create_option(key: &str) -> bool {
210    [
211        COLUMN_FULLTEXT_OPT_KEY_ANALYZER,
212        COLUMN_FULLTEXT_OPT_KEY_CASE_SENSITIVE,
213        COLUMN_FULLTEXT_OPT_KEY_BACKEND,
214        COLUMN_FULLTEXT_OPT_KEY_GRANULARITY,
215        COLUMN_FULLTEXT_OPT_KEY_FALSE_POSITIVE_RATE,
216    ]
217    .contains(&key)
218}
219
220pub fn validate_column_skipping_index_create_option(key: &str) -> bool {
221    [
222        COLUMN_SKIPPING_INDEX_OPT_KEY_GRANULARITY,
223        COLUMN_SKIPPING_INDEX_OPT_KEY_TYPE,
224        COLUMN_SKIPPING_INDEX_OPT_KEY_FALSE_POSITIVE_RATE,
225    ]
226    .contains(&key)
227}
228
229pub fn validate_column_vector_index_create_option(key: &str) -> bool {
230    [
231        COLUMN_VECTOR_INDEX_OPT_KEY_ENGINE,
232        COLUMN_VECTOR_INDEX_OPT_KEY_METRIC,
233        COLUMN_VECTOR_INDEX_OPT_KEY_CONNECTIVITY,
234        COLUMN_VECTOR_INDEX_OPT_KEY_EXPANSION_ADD,
235        COLUMN_VECTOR_INDEX_OPT_KEY_EXPANSION_SEARCH,
236    ]
237    .contains(&key)
238}
239
/// Converts an [`arrow_buffer::IntervalMonthDayNano`] into a
/// [`std::time::Duration`].
///
/// A month is counted as 30.44 days (`3044 / 1000`) and a day as 24 hours.
///
/// # Errors
///
/// Fails with an invalid-interval error when the interval adds up to a
/// negative amount of time.
#[cfg(feature = "enterprise")]
pub fn convert_month_day_nano_to_duration(
    interval: arrow_buffer::IntervalMonthDayNano,
) -> Result<std::time::Duration> {
    const NANOS_PER_SECOND: i64 = 1_000_000_000;
    const SECONDS_PER_DAY: i64 = 60 * 60 * 24;

    let month_seconds = i64::from(interval.months) * SECONDS_PER_DAY * 3044 / 1000;
    let day_seconds = i64::from(interval.days) * SECONDS_PER_DAY;

    // Euclidean division floors towards negative infinity, so the sub-second
    // part always lands in `0..NANOS_PER_SECOND` and a negative nanosecond
    // component borrows from the whole seconds.
    let whole_seconds =
        month_seconds + day_seconds + interval.nanoseconds.div_euclid(NANOS_PER_SECOND);
    let subsec_nanos = interval.nanoseconds.rem_euclid(NANOS_PER_SECOND);

    snafu::ensure!(
        whole_seconds >= 0,
        crate::error::InvalidIntervalSnafu {
            reason: "must be a positive interval",
        }
    );

    // Cast safety: `whole_seconds` was just verified to be non-negative, and
    // `subsec_nanos` is in `0..1_000_000_000`, which fits in a `u32`.
    Ok(std::time::Duration::new(
        whole_seconds as u64,
        subsec_nanos as u32,
    ))
}
275
276pub fn parse_with_options(parser: &mut Parser) -> Result<OptionMap> {
277    let options = parser
278        .parse_options(Keyword::WITH)
279        .context(SyntaxSnafu)?
280        .into_iter()
281        .map(parse_option_string)
282        .collect::<Result<HashMap<String, OptionValue>>>()?;
283    for key in options.keys() {
284        ensure!(validate_table_option(key), InvalidTableOptionSnafu { key });
285    }
286    Ok(OptionMap::new(options))
287}
288
#[cfg(test)]
mod tests {
    use chrono::DateTime;
    use datafusion::functions::datetime::expr_fn::now;
    use datafusion_expr::lit;
    use datatypes::arrow::datatypes::TimestampNanosecondType;

    use super::*;

    /// Keep this test to make sure we are using datafusion's `ExprSimplifier` correctly.
    ///
    /// Every case is coerced and then simplified with the query start time
    /// pinned to 61 seconds after the epoch.
    #[test]
    fn test_simplifier() {
        let start_time = DateTime::from_timestamp(61, 0).unwrap();

        // Small builders for the literals that show up repeatedly below.
        let timestamp_ns = |nanos: Option<i64>| {
            lit(ScalarValue::new_timestamp::<TimestampNanosecondType>(
                nanos, None,
            ))
        };
        let interval_1500_ms = || lit(ScalarValue::new_interval_dt(0, 1500));
        let interval_61_days_in_nanos =
            || lit(ScalarValue::new_interval_mdn(0, 0, 61 * 86400 * 1_000_000_000));

        let cases = [
            (now(), timestamp_ns(start_time.timestamp_nanos_opt())),
            (now() - now(), lit(ScalarValue::DurationNanosecond(Some(0)))),
            (now() + interval_1500_ms(), timestamp_ns(Some(62500000000))),
            (
                now() - (now() + interval_1500_ms()),
                lit(ScalarValue::DurationNanosecond(Some(-1500000000))),
            ),
            // this one failed if type is not coerced
            (
                now() - now() + interval_1500_ms(),
                lit(ScalarValue::new_interval_mdn(0, 0, 1500000000)),
            ),
            (interval_61_days_in_nanos(), interval_61_days_in_nanos()),
        ];

        let context = SimplifyContext::default().with_query_execution_start_time(Some(start_time));
        let simplifier = ExprSimplifier::new(context);
        let empty_schema = DFSchema::empty();

        for (input, expected) in cases {
            let input_name = input.schema_name().to_string();
            let coerced = simplifier.coerce(input, &empty_schema).unwrap();
            let simplified = simplifier.simplify(coerced).unwrap();

            assert_eq!(
                simplified, expected,
                "Failed to simplify expression: {input_name}"
            );
        }
    }
}