sql/parsers/
utils.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use chrono::Utc;
18use datafusion::config::ConfigOptions;
19use datafusion::error::Result as DfResult;
20use datafusion::execution::SessionStateBuilder;
21use datafusion::execution::context::SessionState;
22use datafusion::optimizer::simplify_expressions::ExprSimplifier;
23use datafusion_common::tree_node::{TreeNode, TreeNodeVisitor};
24use datafusion_common::{DFSchema, ScalarValue};
25use datafusion_expr::execution_props::ExecutionProps;
26use datafusion_expr::simplify::SimplifyContext;
27use datafusion_expr::{AggregateUDF, Expr, ScalarUDF, TableSource, WindowUDF};
28use datafusion_sql::TableReference;
29use datafusion_sql::planner::{ContextProvider, SqlToRel};
30use datatypes::arrow::datatypes::DataType;
31use datatypes::schema::{
32    COLUMN_FULLTEXT_OPT_KEY_ANALYZER, COLUMN_FULLTEXT_OPT_KEY_BACKEND,
33    COLUMN_FULLTEXT_OPT_KEY_CASE_SENSITIVE, COLUMN_FULLTEXT_OPT_KEY_FALSE_POSITIVE_RATE,
34    COLUMN_FULLTEXT_OPT_KEY_GRANULARITY, COLUMN_SKIPPING_INDEX_OPT_KEY_FALSE_POSITIVE_RATE,
35    COLUMN_SKIPPING_INDEX_OPT_KEY_GRANULARITY, COLUMN_SKIPPING_INDEX_OPT_KEY_TYPE,
36};
37use snafu::{ResultExt, ensure};
38use sqlparser::dialect::Dialect;
39
40use crate::error::{
41    ConvertToLogicalExpressionSnafu, InvalidSqlSnafu, ParseSqlValueSnafu, Result,
42    SimplificationSnafu,
43};
44use crate::parser::{ParseOptions, ParserContext};
45use crate::statements::statement::Statement;
46
47/// Check if the given SQL query is a TQL statement.
48pub fn is_tql(dialect: &dyn Dialect, sql: &str) -> Result<bool> {
49    let stmts = ParserContext::create_with_dialect(sql, dialect, ParseOptions::default())?;
50
51    ensure!(
52        stmts.len() == 1,
53        InvalidSqlSnafu {
54            msg: format!("Expect only one statement, found {}", stmts.len())
55        }
56    );
57    let stmt = &stmts[0];
58    match stmt {
59        Statement::Tql(_) => Ok(true),
60        _ => Ok(false),
61    }
62}
63
64/// Convert a parser expression to a scalar value. This function will try the
65/// best to resolve and reduce constants. Exprs like `1 + 1` or `now()` can be
66/// handled properly.
67///
68/// if `require_now_expr` is true, it will ensure that the expression contains a `now()` function.
69/// If the expression does not contain `now()`, it will return an error.
70///
71pub fn parser_expr_to_scalar_value_literal(
72    expr: sqlparser::ast::Expr,
73    require_now_expr: bool,
74) -> Result<ScalarValue> {
75    // 1. convert parser expr to logical expr
76    let empty_df_schema = DFSchema::empty();
77    let logical_expr = SqlToRel::new(&StubContextProvider::default())
78        .sql_to_expr(expr.into(), &empty_df_schema, &mut Default::default())
79        .context(ConvertToLogicalExpressionSnafu)?;
80
81    struct FindNow {
82        found: bool,
83    }
84
85    impl TreeNodeVisitor<'_> for FindNow {
86        type Node = Expr;
87        fn f_down(
88            &mut self,
89            node: &Self::Node,
90        ) -> DfResult<datafusion_common::tree_node::TreeNodeRecursion> {
91            if let Expr::ScalarFunction(func) = node
92                && func.name().to_lowercase() == "now"
93            {
94                if !func.args.is_empty() {
95                    return Err(datafusion_common::DataFusionError::Plan(
96                        "now() function should not have arguments".to_string(),
97                    ));
98                }
99                self.found = true;
100                return Ok(datafusion_common::tree_node::TreeNodeRecursion::Stop);
101            }
102            Ok(datafusion_common::tree_node::TreeNodeRecursion::Continue)
103        }
104    }
105
106    if require_now_expr {
107        let have_now = {
108            let mut visitor = FindNow { found: false };
109            logical_expr.visit(&mut visitor).unwrap();
110            visitor.found
111        };
112        if !have_now {
113            return ParseSqlValueSnafu {
114                msg: format!(
115                    "expected now() expression, but not found in {}",
116                    logical_expr
117                ),
118            }
119            .fail();
120        }
121    }
122
123    // 2. simplify logical expr
124    let execution_props = ExecutionProps::new().with_query_execution_start_time(Utc::now());
125    let info =
126        SimplifyContext::new(&execution_props).with_schema(Arc::new(empty_df_schema.clone()));
127
128    let simplifier = ExprSimplifier::new(info);
129
130    // Coerce the logical expression so simplifier can handle it correctly. This is necessary for const eval with possible type mismatch. i.e.: `now() - now() + '15s'::interval` which is `TimestampNanosecond - TimestampNanosecond + IntervalMonthDayNano`.
131    let logical_expr = simplifier
132        .coerce(logical_expr, &empty_df_schema)
133        .context(SimplificationSnafu)?;
134
135    let simplified_expr = simplifier
136        .simplify(logical_expr)
137        .context(SimplificationSnafu)?;
138
139    if let datafusion::logical_expr::Expr::Literal(lit, _) = simplified_expr {
140        Ok(lit)
141    } else {
142        // Err(ParseSqlValue)
143        ParseSqlValueSnafu {
144            msg: format!("expected literal value, but found {:?}", simplified_expr),
145        }
146        .fail()
147    }
148}
149
150/// Helper struct for [`parser_expr_to_scalar_value`].
151struct StubContextProvider {
152    state: SessionState,
153}
154
155impl Default for StubContextProvider {
156    fn default() -> Self {
157        Self {
158            state: SessionStateBuilder::new()
159                .with_config(Default::default())
160                .with_runtime_env(Default::default())
161                .with_default_features()
162                .build(),
163        }
164    }
165}
166
167impl ContextProvider for StubContextProvider {
168    fn get_table_source(&self, _name: TableReference) -> DfResult<Arc<dyn TableSource>> {
169        unimplemented!()
170    }
171
172    fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
173        self.state.scalar_functions().get(name).cloned()
174    }
175
176    fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
177        self.state.aggregate_functions().get(name).cloned()
178    }
179
180    fn get_window_meta(&self, _name: &str) -> Option<Arc<WindowUDF>> {
181        unimplemented!()
182    }
183
184    fn get_variable_type(&self, _variable_names: &[String]) -> Option<DataType> {
185        unimplemented!()
186    }
187
188    fn options(&self) -> &ConfigOptions {
189        self.state.config_options()
190    }
191
192    fn udf_names(&self) -> Vec<String> {
193        self.state.scalar_functions().keys().cloned().collect()
194    }
195
196    fn udaf_names(&self) -> Vec<String> {
197        self.state.aggregate_functions().keys().cloned().collect()
198    }
199
200    fn udwf_names(&self) -> Vec<String> {
201        self.state.window_functions().keys().cloned().collect()
202    }
203}
204
205pub fn validate_column_fulltext_create_option(key: &str) -> bool {
206    [
207        COLUMN_FULLTEXT_OPT_KEY_ANALYZER,
208        COLUMN_FULLTEXT_OPT_KEY_CASE_SENSITIVE,
209        COLUMN_FULLTEXT_OPT_KEY_BACKEND,
210        COLUMN_FULLTEXT_OPT_KEY_GRANULARITY,
211        COLUMN_FULLTEXT_OPT_KEY_FALSE_POSITIVE_RATE,
212    ]
213    .contains(&key)
214}
215
216pub fn validate_column_skipping_index_create_option(key: &str) -> bool {
217    [
218        COLUMN_SKIPPING_INDEX_OPT_KEY_GRANULARITY,
219        COLUMN_SKIPPING_INDEX_OPT_KEY_TYPE,
220        COLUMN_SKIPPING_INDEX_OPT_KEY_FALSE_POSITIVE_RATE,
221    ]
222    .contains(&key)
223}
224
/// Convert an [`arrow_buffer::IntervalMonthDayNano`] to a [`std::time::Duration`].
///
/// Calendar units are approximated with fixed lengths: one month is 30.44 days
/// (2,630,016 seconds) and one day is 86,400 seconds. The month, day and
/// nanosecond components are summed, so mixed-sign intervals are allowed as
/// long as the total is not negative.
///
/// # Errors
///
/// Returns an `InvalidInterval` error if the combined interval is negative.
#[cfg(feature = "enterprise")]
pub fn convert_month_day_nano_to_duration(
    interval: arrow_buffer::IntervalMonthDayNano,
) -> Result<std::time::Duration> {
    const SECONDS_PER_DAY: i64 = 60 * 60 * 24;
    // 30.44 days per month, kept exact in integer arithmetic:
    // 86_400 * 3044 / 100 = 2_630_016. The previous `* 3044 / 1000` produced
    // 3.044 days per month, which was off by a factor of ten.
    const SECONDS_PER_MONTH: i64 = SECONDS_PER_DAY * 3044 / 100;
    const NANOS_PER_SECOND: i64 = 1_000_000_000;

    let months: i64 = interval.months.into();
    let days: i64 = interval.days.into();
    // No overflow: |months| <= 2^31 and 2^31 * 2_630_016 is far below i64::MAX;
    // likewise for days * 86_400.
    let months_in_seconds = months * SECONDS_PER_MONTH;
    let days_in_seconds = days * SECONDS_PER_DAY;

    // Euclidean division floors the seconds and keeps the remainder in
    // [0, NANOS_PER_SECOND), even for negative nanosecond components.
    let seconds_from_nanos = interval.nanoseconds.div_euclid(NANOS_PER_SECOND);
    let nanos_remainder = interval.nanoseconds.rem_euclid(NANOS_PER_SECOND);

    let total_seconds = months_in_seconds + days_in_seconds + seconds_from_nanos;

    snafu::ensure!(
        total_seconds >= 0,
        crate::error::InvalidIntervalSnafu {
            reason: "must be a positive interval",
        }
    );

    // Cast safety: `total_seconds` was checked to be non-negative above.
    let total_seconds = total_seconds as u64;
    // Cast safety: `rem_euclid` guarantees 0 <= nanos_remainder < 1_000_000_000,
    // which fits in u32.
    let nanos_remainder = nanos_remainder as u32;

    Ok(std::time::Duration::new(total_seconds, nanos_remainder))
}
260
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use chrono::DateTime;
    use datafusion::functions::datetime::expr_fn::now;
    use datafusion_expr::lit;
    use datatypes::arrow::datatypes::TimestampNanosecondType;

    use super::*;

    /// Builds a timezone-less nanosecond timestamp literal.
    fn timestamp_nanos_lit(nanos: Option<i64>) -> Expr {
        lit(ScalarValue::new_timestamp::<TimestampNanosecondType>(
            nanos, None,
        ))
    }

    /// Keep this test to make sure we are using datafusion's `ExprSimplifier` correctly.
    ///
    /// Each input is coerced and then simplified against a fixed query start
    /// time (61 s after the Unix epoch), and must fold to the expected literal.
    #[test]
    fn test_simplifier() {
        let start_time = DateTime::from_timestamp(61, 0).unwrap();
        let one_and_half_secs = || lit(ScalarValue::new_interval_dt(0, 1500));
        let sixty_one_days_mdn =
            || lit(ScalarValue::new_interval_mdn(0, 0, 61 * 86400 * 1_000_000_000));

        let cases = [
            (now(), timestamp_nanos_lit(start_time.timestamp_nanos_opt())),
            (now() - now(), lit(ScalarValue::DurationNanosecond(Some(0)))),
            (
                now() + one_and_half_secs(),
                timestamp_nanos_lit(Some(62_500_000_000)),
            ),
            (
                now() - (now() + one_and_half_secs()),
                lit(ScalarValue::DurationNanosecond(Some(-1_500_000_000))),
            ),
            // Without type coercion this case fails to fold.
            (
                now() - now() + one_and_half_secs(),
                lit(ScalarValue::new_interval_mdn(0, 0, 1_500_000_000)),
            ),
            (sixty_one_days_mdn(), sixty_one_days_mdn()),
        ];

        let props = ExecutionProps::new().with_query_execution_start_time(start_time);
        let simplifier = ExprSimplifier::new(
            SimplifyContext::new(&props).with_schema(Arc::new(DFSchema::empty())),
        );
        let schema = DFSchema::empty();

        for (input, expected) in cases {
            let expr_name = input.schema_name().to_string();
            let coerced = simplifier.coerce(input, &schema).unwrap();
            let actual = simplifier.simplify(coerced).unwrap();
            assert_eq!(
                actual, expected,
                "Failed to simplify expression: {expr_name}"
            );
        }
    }
}