sql/parsers/utils.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use chrono::Utc;
18use datafusion::config::ConfigOptions;
19use datafusion::error::Result as DfResult;
20use datafusion::execution::SessionStateBuilder;
21use datafusion::execution::context::SessionState;
22use datafusion::optimizer::simplify_expressions::ExprSimplifier;
23use datafusion_common::tree_node::{TreeNode, TreeNodeVisitor};
24use datafusion_common::{DFSchema, ScalarValue};
25use datafusion_expr::execution_props::ExecutionProps;
26use datafusion_expr::simplify::SimplifyContext;
27use datafusion_expr::{AggregateUDF, Expr, ScalarUDF, TableSource, WindowUDF};
28use datafusion_sql::TableReference;
29use datafusion_sql::planner::{ContextProvider, SqlToRel};
30use datatypes::arrow::datatypes::DataType;
31use datatypes::schema::{
32    COLUMN_FULLTEXT_OPT_KEY_ANALYZER, COLUMN_FULLTEXT_OPT_KEY_BACKEND,
33    COLUMN_FULLTEXT_OPT_KEY_CASE_SENSITIVE, COLUMN_FULLTEXT_OPT_KEY_FALSE_POSITIVE_RATE,
34    COLUMN_FULLTEXT_OPT_KEY_GRANULARITY, COLUMN_SKIPPING_INDEX_OPT_KEY_FALSE_POSITIVE_RATE,
35    COLUMN_SKIPPING_INDEX_OPT_KEY_GRANULARITY, COLUMN_SKIPPING_INDEX_OPT_KEY_TYPE,
36};
37use snafu::{ResultExt, ensure};
38use sqlparser::dialect::Dialect;
39
40use crate::error::{
41    ConvertToLogicalExpressionSnafu, InvalidSqlSnafu, ParseSqlValueSnafu, Result,
42    SimplificationSnafu,
43};
44use crate::parser::{ParseOptions, ParserContext};
45use crate::statements::statement::Statement;
46
47/// Check if the given SQL query is a TQL statement.
48pub fn is_tql(dialect: &dyn Dialect, sql: &str) -> Result<bool> {
49    let stmts = ParserContext::create_with_dialect(sql, dialect, ParseOptions::default())?;
50
51    ensure!(
52        stmts.len() == 1,
53        InvalidSqlSnafu {
54            msg: format!("Expect only one statement, found {}", stmts.len())
55        }
56    );
57    let stmt = &stmts[0];
58    match stmt {
59        Statement::Tql(_) => Ok(true),
60        _ => Ok(false),
61    }
62}
63
64/// Convert a parser expression to a scalar value. This function will try the
65/// best to resolve and reduce constants. Exprs like `1 + 1` or `now()` can be
66/// handled properly.
67///
68/// if `require_now_expr` is true, it will ensure that the expression contains a `now()` function.
69/// If the expression does not contain `now()`, it will return an error.
70///
71pub fn parser_expr_to_scalar_value_literal(
72    expr: sqlparser::ast::Expr,
73    require_now_expr: bool,
74) -> Result<ScalarValue> {
75    // 1. convert parser expr to logical expr
76    let empty_df_schema = DFSchema::empty();
77    let logical_expr = SqlToRel::new(&StubContextProvider::default())
78        .sql_to_expr(expr, &empty_df_schema, &mut Default::default())
79        .context(ConvertToLogicalExpressionSnafu)?;
80
81    struct FindNow {
82        found: bool,
83    }
84
85    impl TreeNodeVisitor<'_> for FindNow {
86        type Node = Expr;
87        fn f_down(
88            &mut self,
89            node: &Self::Node,
90        ) -> DfResult<datafusion_common::tree_node::TreeNodeRecursion> {
91            if let Expr::ScalarFunction(func) = node
92                && func.name().to_lowercase() == "now"
93            {
94                if !func.args.is_empty() {
95                    return Err(datafusion_common::DataFusionError::Plan(
96                        "now() function should not have arguments".to_string(),
97                    ));
98                }
99                self.found = true;
100                return Ok(datafusion_common::tree_node::TreeNodeRecursion::Stop);
101            }
102            Ok(datafusion_common::tree_node::TreeNodeRecursion::Continue)
103        }
104    }
105
106    if require_now_expr {
107        let have_now = {
108            let mut visitor = FindNow { found: false };
109            logical_expr.visit(&mut visitor).unwrap();
110            visitor.found
111        };
112        if !have_now {
113            return ParseSqlValueSnafu {
114                msg: format!(
115                    "expected now() expression, but not found in {}",
116                    logical_expr
117                ),
118            }
119            .fail();
120        }
121    }
122
123    // 2. simplify logical expr
124    let execution_props = ExecutionProps::new().with_query_execution_start_time(Utc::now());
125    let info =
126        SimplifyContext::new(&execution_props).with_schema(Arc::new(empty_df_schema.clone()));
127
128    let simplifier = ExprSimplifier::new(info);
129
130    // Coerce the logical expression so simplifier can handle it correctly. This is necessary for const eval with possible type mismatch. i.e.: `now() - now() + '15s'::interval` which is `TimestampNanosecond - TimestampNanosecond + IntervalMonthDayNano`.
131    let logical_expr = simplifier
132        .coerce(logical_expr, &empty_df_schema)
133        .context(SimplificationSnafu)?;
134
135    let simplified_expr = simplifier
136        .simplify(logical_expr)
137        .context(SimplificationSnafu)?;
138
139    if let datafusion::logical_expr::Expr::Literal(lit, _) = simplified_expr {
140        Ok(lit)
141    } else {
142        // Err(ParseSqlValue)
143        ParseSqlValueSnafu {
144            msg: format!("expected literal value, but found {:?}", simplified_expr),
145        }
146        .fail()
147    }
148}
149
150/// Helper struct for [`parser_expr_to_scalar_value`].
151struct StubContextProvider {
152    state: SessionState,
153}
154
155impl Default for StubContextProvider {
156    fn default() -> Self {
157        Self {
158            state: SessionStateBuilder::new()
159                .with_config(Default::default())
160                .with_runtime_env(Default::default())
161                .with_default_features()
162                .build(),
163        }
164    }
165}
166
167impl ContextProvider for StubContextProvider {
168    fn get_table_source(&self, _name: TableReference) -> DfResult<Arc<dyn TableSource>> {
169        unimplemented!()
170    }
171
172    fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
173        self.state.scalar_functions().get(name).cloned()
174    }
175
176    fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
177        self.state.aggregate_functions().get(name).cloned()
178    }
179
180    fn get_window_meta(&self, _name: &str) -> Option<Arc<WindowUDF>> {
181        unimplemented!()
182    }
183
184    fn get_variable_type(&self, _variable_names: &[String]) -> Option<DataType> {
185        unimplemented!()
186    }
187
188    fn options(&self) -> &ConfigOptions {
189        self.state.config_options()
190    }
191
192    fn udf_names(&self) -> Vec<String> {
193        self.state.scalar_functions().keys().cloned().collect()
194    }
195
196    fn udaf_names(&self) -> Vec<String> {
197        self.state.aggregate_functions().keys().cloned().collect()
198    }
199
200    fn udwf_names(&self) -> Vec<String> {
201        self.state.window_functions().keys().cloned().collect()
202    }
203}
204
205pub fn validate_column_fulltext_create_option(key: &str) -> bool {
206    [
207        COLUMN_FULLTEXT_OPT_KEY_ANALYZER,
208        COLUMN_FULLTEXT_OPT_KEY_CASE_SENSITIVE,
209        COLUMN_FULLTEXT_OPT_KEY_BACKEND,
210        COLUMN_FULLTEXT_OPT_KEY_GRANULARITY,
211        COLUMN_FULLTEXT_OPT_KEY_FALSE_POSITIVE_RATE,
212    ]
213    .contains(&key)
214}
215
216pub fn validate_column_skipping_index_create_option(key: &str) -> bool {
217    [
218        COLUMN_SKIPPING_INDEX_OPT_KEY_GRANULARITY,
219        COLUMN_SKIPPING_INDEX_OPT_KEY_TYPE,
220        COLUMN_SKIPPING_INDEX_OPT_KEY_FALSE_POSITIVE_RATE,
221    ]
222    .contains(&key)
223}
224
/// VECTOR INDEX option key: vector index engine (usearch).
pub const COLUMN_VECTOR_INDEX_OPT_KEY_ENGINE: &str = "engine";
/// VECTOR INDEX option key: distance metric (l2sq, cosine, inner_product).
pub const COLUMN_VECTOR_INDEX_OPT_KEY_METRIC: &str = "metric";
/// VECTOR INDEX option key: HNSW `M` parameter.
pub const COLUMN_VECTOR_INDEX_OPT_KEY_CONNECTIVITY: &str = "connectivity";
/// VECTOR INDEX option key: `ef_construction` parameter.
pub const COLUMN_VECTOR_INDEX_OPT_KEY_EXPANSION_ADD: &str = "expansion_add";
/// VECTOR INDEX option key: `ef_search` parameter.
pub const COLUMN_VECTOR_INDEX_OPT_KEY_EXPANSION_SEARCH: &str = "expansion_search";

/// Returns `true` if `key` is a valid VECTOR INDEX option.
///
/// Valid options (matched exactly, case-sensitive):
/// - `engine`: vector index engine (usearch)
/// - `metric`: distance metric (l2sq, cosine, inner_product)
/// - `connectivity`: HNSW M parameter
/// - `expansion_add`: ef_construction parameter
/// - `expansion_search`: ef_search parameter
pub fn validate_column_vector_index_create_option(key: &str) -> bool {
    [
        COLUMN_VECTOR_INDEX_OPT_KEY_ENGINE,
        COLUMN_VECTOR_INDEX_OPT_KEY_METRIC,
        COLUMN_VECTOR_INDEX_OPT_KEY_CONNECTIVITY,
        COLUMN_VECTOR_INDEX_OPT_KEY_EXPANSION_ADD,
        COLUMN_VECTOR_INDEX_OPT_KEY_EXPANSION_SEARCH,
    ]
    .contains(&key)
}
247
/// Converts an [`arrow_buffer::IntervalMonthDayNano`] to a [`std::time::Duration`].
///
/// The conversion works as follows:
/// - Each month is approximated as 30.44 days, i.e. `months * 86400 * 3044 / 1000`
///   seconds with integer division.
/// - Each day counts as exactly 86400 seconds.
/// - The nanoseconds are added on top.
///
/// Components with different signs may cancel each other out.
///
/// # Errors
///
/// Returns an `InvalidInterval` error if the combined interval is negative. A
/// zero interval is accepted.
#[cfg(feature = "enterprise")]
pub fn convert_month_day_nano_to_duration(
    interval: arrow_buffer::IntervalMonthDayNano,
) -> Result<std::time::Duration> {
    const NANOS_PER_SEC: i64 = 1_000_000_000;
    const SECS_PER_DAY: i64 = 60 * 60 * 24;

    let months = i64::from(interval.months);
    let days = i64::from(interval.days);
    let months_in_seconds = months * SECS_PER_DAY * 3044 / 1000;
    let days_in_seconds = days * SECS_PER_DAY;

    // Floor-divide the nanoseconds so the remainder always lies in
    // `0..NANOS_PER_SEC`; a negative nanosecond part borrows whole seconds.
    let seconds_from_nanos = interval.nanoseconds.div_euclid(NANOS_PER_SEC);
    let nanos_remainder = interval.nanoseconds.rem_euclid(NANOS_PER_SEC);

    let total_seconds = months_in_seconds + days_in_seconds + seconds_from_nanos;

    snafu::ensure!(
        total_seconds >= 0,
        crate::error::InvalidIntervalSnafu {
            reason: "must be a positive interval",
        }
    );

    // Cast safety: `total_seconds` was checked to be non-negative just above.
    let seconds = total_seconds as u64;
    // Cast safety: `rem_euclid` with a positive divisor yields a value in
    // `0..1_000_000_000`, which fits in a `u32`.
    let nanos = nanos_remainder as u32;

    Ok(std::time::Duration::new(seconds, nanos))
}
283
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use chrono::DateTime;
    use datafusion::functions::datetime::expr_fn::now;
    use datafusion_expr::lit;
    use datatypes::arrow::datatypes::TimestampNanosecondType;

    use super::*;

    /// Pins how datafusion's `ExprSimplifier` is driven here (coerce, then
    /// simplify, with `now()` fixed to a known start time), so upgrades that
    /// change its behavior are caught.
    #[test]
    fn test_simplifier() {
        let start_time = DateTime::from_timestamp(61, 0).unwrap();
        let timestamp_ns = |nanos: Option<i64>| {
            lit(ScalarValue::new_timestamp::<TimestampNanosecondType>(
                nanos, None,
            ))
        };
        let one_and_half_secs = || lit(ScalarValue::new_interval_dt(0, 1500));
        let sixty_one_days_ns: i64 = 61 * 86400 * 1_000_000_000;

        let cases = [
            (now(), timestamp_ns(start_time.timestamp_nanos_opt())),
            (now() - now(), lit(ScalarValue::DurationNanosecond(Some(0)))),
            (now() + one_and_half_secs(), timestamp_ns(Some(62500000000))),
            (
                now() - (now() + one_and_half_secs()),
                lit(ScalarValue::DurationNanosecond(Some(-1500000000))),
            ),
            // Fails unless the expression is type-coerced before simplification.
            (
                now() - now() + one_and_half_secs(),
                lit(ScalarValue::new_interval_mdn(0, 0, 1500000000)),
            ),
            (
                lit(ScalarValue::new_interval_mdn(0, 0, sixty_one_days_ns)),
                lit(ScalarValue::new_interval_mdn(0, 0, sixty_one_days_ns)),
            ),
        ];

        let props = ExecutionProps::new().with_query_execution_start_time(start_time);
        let simplifier = ExprSimplifier::new(
            SimplifyContext::new(&props).with_schema(Arc::new(DFSchema::empty())),
        );

        for (input, expected) in cases {
            let label = input.schema_name().to_string();
            let coerced = simplifier.coerce(input, &DFSchema::empty()).unwrap();
            let actual = simplifier.simplify(coerced).unwrap();
            assert_eq!(actual, expected, "Failed to simplify expression: {label}");
        }
    }
}