sql/parsers/set_var_parser.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use snafu::ResultExt;
16use sqlparser::ast::{Set, Statement as SpStatement};
17
18use crate::ast::{Ident, ObjectName};
19use crate::error::{self, Result};
20use crate::parser::ParserContext;
21use crate::statements::set_variables::SetVariables;
22use crate::statements::statement::Statement;
23
24/// SET variables statement parser implementation
25impl ParserContext<'_> {
26    pub(crate) fn parse_set_variables(&mut self) -> Result<Statement> {
27        let _ = self.parser.next_token();
28        let spstatement = self.parser.parse_set().context(error::SyntaxSnafu)?;
29        match spstatement {
30            SpStatement::Set(set) => match set {
31                Set::SingleAssignment {
32                    scope: _,
33                    hivevar,
34                    variable,
35                    values,
36                } if !hivevar => Ok(Statement::SetVariables(SetVariables {
37                    variable,
38                    value: values,
39                })),
40
41                Set::SetTimeZone { local: _, value } => Ok(Statement::SetVariables(SetVariables {
42                    variable: ObjectName::from(vec![Ident::new("TIMEZONE")]),
43                    value: vec![value],
44                })),
45
46                set => error::UnsupportedSnafu {
47                    keyword: set.to_string(),
48                }
49                .fail(),
50            },
51            unexp => error::UnsupportedSnafu {
52                keyword: unexp.to_string(),
53            }
54            .fail(),
55        }
56    }
57}
58
#[cfg(test)]
mod tests {
    use sqlparser::ast::{Expr, Ident, ObjectName, Value};

    use super::*;
    use crate::dialect::GreptimeDbDialect;
    use crate::parser::ParseOptions;

    /// Parses `sql` with the GreptimeDB dialect and asserts that it yields exactly
    /// one statement: a [`SetVariables`] assigning `expr` to the single-part
    /// variable `ident`.
    ///
    /// MySQL- and PostgreSQL-style `SET` statements are both normalized into
    /// [`SetVariables`], so a single helper covers both syntaxes.
    fn assert_set_variables(sql: &str, ident: &str, expr: Expr) {
        let mut stmts =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .unwrap_or_else(|e| panic!("failed to parse `{sql}`: {e:?}"));
        // Guard against silently ignoring extra statements: the original helpers
        // only inspected the last one.
        assert_eq!(1, stmts.len(), "expected exactly one statement from `{sql}`");
        assert_eq!(
            stmts.pop().unwrap(),
            Statement::SetVariables(SetVariables {
                variable: ObjectName::from(vec![Ident::new(ident)]),
                value: vec![expr],
            }),
            "unexpected statement parsed from `{sql}`"
        );
    }

    #[test]
    fn test_set_timezone() {
        let expected_utc_expr = Expr::Value(Value::SingleQuotedString("UTC".to_string()).into());

        // MySQL style, with and without a SESSION/LOCAL scope modifier.
        for sql in [
            "SET time_zone = 'UTC'",
            "SET LOCAL time_zone = 'UTC'",
            "SET SESSION time_zone = 'UTC'",
        ] {
            assert_set_variables(sql, "time_zone", expected_utc_expr.clone());
        }

        // PostgreSQL style: both forms are reported under the `TIMEZONE` variable.
        for sql in ["SET TIMEZONE TO 'UTC'", "SET TIMEZONE 'UTC'"] {
            assert_set_variables(sql, "TIMEZONE", expected_utc_expr.clone());
        }
    }

    #[test]
    fn test_set_query_timeout() {
        let expected_query_timeout_expr =
            Expr::Value(Value::Number("5000".to_string(), false).into());

        // MySQL style, with and without a SESSION/LOCAL scope modifier.
        for sql in [
            "SET MAX_EXECUTION_TIME = 5000",
            "SET LOCAL MAX_EXECUTION_TIME = 5000",
            "SET SESSION MAX_EXECUTION_TIME = 5000",
        ] {
            assert_set_variables(sql, "MAX_EXECUTION_TIME", expected_query_timeout_expr.clone());
        }

        // PostgreSQL style, using either `=` or `TO`.
        for sql in ["SET STATEMENT_TIMEOUT = 5000", "SET STATEMENT_TIMEOUT TO 5000"] {
            assert_set_variables(sql, "STATEMENT_TIMEOUT", expected_query_timeout_expr.clone());
        }
    }
}