sql/
parser.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use snafu::ResultExt;
16use sqlparser::ast::{Ident, Query};
17use sqlparser::dialect::Dialect;
18use sqlparser::keywords::Keyword;
19use sqlparser::parser::{Parser, ParserError, ParserOptions};
20use sqlparser::tokenizer::{Token, TokenWithSpan};
21
22use crate::ast::{Expr, ObjectName};
23use crate::error::{self, Result, SyntaxSnafu};
24use crate::parsers::tql_parser;
25use crate::statements::statement::Statement;
26use crate::statements::transform_statements;
27
/// Text of the `FLOW` keyword.
///
/// NOTE(review): not referenced in this file; presumably matched by parsers
/// elsewhere in the crate — confirm at the use sites.
pub const FLOW: &str = "FLOW";
29
/// SQL parser options.
///
/// The struct currently has no fields; [`ParserContext::create_with_dialect`]
/// accepts a value of this type but does not read it.
#[derive(Clone, Debug, Default)]
pub struct ParseOptions {}
33
/// GreptimeDB SQL parser context: a thin wrapper around sqlparser's [`Parser`]
/// that also keeps the SQL text the parser was built from.
pub struct ParserContext<'a> {
    /// The underlying sqlparser parser, created from `sql`.
    pub(crate) parser: Parser<'a>,
    /// The SQL text that was handed to the parser.
    pub(crate) sql: &'a str,
}
39
40impl ParserContext<'_> {
41    /// Construct a new ParserContext.
42    pub fn new<'a>(dialect: &'a dyn Dialect, sql: &'a str) -> Result<ParserContext<'a>> {
43        let parser = Parser::new(dialect)
44            .with_options(ParserOptions::new().with_trailing_commas(true))
45            .try_with_sql(sql)
46            .context(SyntaxSnafu)?;
47
48        Ok(ParserContext { parser, sql })
49    }
50
51    /// Parses parser context to Query.
52    pub fn parser_query(&mut self) -> Result<Box<Query>> {
53        self.parser.parse_query().context(SyntaxSnafu)
54    }
55
56    /// Parses SQL with given dialect
57    pub fn create_with_dialect(
58        sql: &str,
59        dialect: &dyn Dialect,
60        _opts: ParseOptions,
61    ) -> Result<Vec<Statement>> {
62        let mut stmts: Vec<Statement> = Vec::new();
63
64        let mut parser_ctx = ParserContext::new(dialect, sql)?;
65
66        let mut expecting_statement_delimiter = false;
67        loop {
68            // ignore empty statements (between successive statement delimiters)
69            while parser_ctx.parser.consume_token(&Token::SemiColon) {
70                expecting_statement_delimiter = false;
71            }
72
73            if parser_ctx.parser.peek_token() == Token::EOF {
74                break;
75            }
76            if expecting_statement_delimiter {
77                return parser_ctx.unsupported(parser_ctx.peek_token_as_string());
78            }
79
80            let statement = parser_ctx.parse_statement()?;
81            stmts.push(statement);
82            expecting_statement_delimiter = true;
83        }
84
85        transform_statements(&mut stmts)?;
86
87        Ok(stmts)
88    }
89
90    pub fn parse_table_name(sql: &str, dialect: &dyn Dialect) -> Result<ObjectName> {
91        let parser = Parser::new(dialect)
92            .with_options(ParserOptions::new().with_trailing_commas(true))
93            .try_with_sql(sql)
94            .context(SyntaxSnafu)?;
95        ParserContext { parser, sql }.intern_parse_table_name()
96    }
97
98    pub(crate) fn intern_parse_table_name(&mut self) -> Result<ObjectName> {
99        let raw_table_name =
100            self.parser
101                .parse_object_name(false)
102                .context(error::UnexpectedSnafu {
103                    expected: "a table name",
104                    actual: self.parser.peek_token().to_string(),
105                })?;
106        Ok(Self::canonicalize_object_name(raw_table_name))
107    }
108
109    pub fn parse_function(sql: &str, dialect: &dyn Dialect) -> Result<Expr> {
110        let mut parser = Parser::new(dialect)
111            .with_options(ParserOptions::new().with_trailing_commas(true))
112            .try_with_sql(sql)
113            .context(SyntaxSnafu)?;
114
115        let function_name = parser.parse_identifier().context(SyntaxSnafu)?;
116        parser
117            .parse_function(ObjectName(vec![function_name]))
118            .context(SyntaxSnafu)
119    }
120
121    /// Parses parser context to a set of statements.
122    pub fn parse_statement(&mut self) -> Result<Statement> {
123        match self.parser.peek_token().token {
124            Token::Word(w) => {
125                match w.keyword {
126                    Keyword::CREATE => {
127                        let _ = self.parser.next_token();
128                        self.parse_create()
129                    }
130
131                    Keyword::EXPLAIN => {
132                        let _ = self.parser.next_token();
133                        self.parse_explain()
134                    }
135
136                    Keyword::SHOW => {
137                        let _ = self.parser.next_token();
138                        self.parse_show()
139                    }
140
141                    Keyword::DELETE => self.parse_delete(),
142
143                    Keyword::DESCRIBE | Keyword::DESC => {
144                        let _ = self.parser.next_token();
145                        self.parse_describe()
146                    }
147
148                    Keyword::INSERT => self.parse_insert(),
149
150                    Keyword::REPLACE => self.parse_replace(),
151
152                    Keyword::SELECT | Keyword::WITH | Keyword::VALUES => self.parse_query(),
153
154                    Keyword::ALTER => self.parse_alter(),
155
156                    Keyword::DROP => self.parse_drop(),
157
158                    Keyword::COPY => self.parse_copy(),
159
160                    Keyword::TRUNCATE => self.parse_truncate(),
161
162                    Keyword::SET => self.parse_set_variables(),
163
164                    Keyword::ADMIN => self.parse_admin_command(),
165
166                    Keyword::NoKeyword
167                        if w.quote_style.is_none() && w.value.to_uppercase() == tql_parser::TQL =>
168                    {
169                        self.parse_tql()
170                    }
171
172                    Keyword::DECLARE => self.parse_declare_cursor(),
173
174                    Keyword::FETCH => self.parse_fetch_cursor(),
175
176                    Keyword::CLOSE => self.parse_close_cursor(),
177
178                    Keyword::USE => {
179                        let _ = self.parser.next_token();
180
181                        let database_name = self.parser.parse_identifier().with_context(|_| {
182                            error::UnexpectedSnafu {
183                                expected: "a database name",
184                                actual: self.peek_token_as_string(),
185                            }
186                        })?;
187                        Ok(Statement::Use(
188                            Self::canonicalize_identifier(database_name).value,
189                        ))
190                    }
191
192                    // todo(hl) support more statements.
193                    _ => self.unsupported(self.peek_token_as_string()),
194                }
195            }
196            Token::LParen => self.parse_query(),
197            unexpected => self.unsupported(unexpected.to_string()),
198        }
199    }
200
201    /// Parses MySQL style 'PREPARE stmt_name FROM stmt' into a (stmt_name, stmt) tuple.
202    pub fn parse_mysql_prepare_stmt(sql: &str, dialect: &dyn Dialect) -> Result<(String, String)> {
203        ParserContext::new(dialect, sql)?.parse_mysql_prepare()
204    }
205
206    /// Parses MySQL style 'EXECUTE stmt_name USING param_list' into a stmt_name string and a list of parameters.
207    pub fn parse_mysql_execute_stmt(
208        sql: &str,
209        dialect: &dyn Dialect,
210    ) -> Result<(String, Vec<Expr>)> {
211        ParserContext::new(dialect, sql)?.parse_mysql_execute()
212    }
213
214    /// Parses MySQL style 'DEALLOCATE stmt_name' into a stmt_name string.
215    pub fn parse_mysql_deallocate_stmt(sql: &str, dialect: &dyn Dialect) -> Result<String> {
216        ParserContext::new(dialect, sql)?.parse_deallocate()
217    }
218
219    /// Raises an "unsupported statement" error.
220    pub fn unsupported<T>(&self, keyword: String) -> Result<T> {
221        error::UnsupportedSnafu { keyword }.fail()
222    }
223
224    // Report unexpected token
225    pub(crate) fn expected<T>(&self, expected: &str, found: TokenWithSpan) -> Result<T> {
226        Err(ParserError::ParserError(format!(
227            "Expected {expected}, found: {found}",
228        )))
229        .context(SyntaxSnafu)
230    }
231
232    pub fn matches_keyword(&mut self, expected: Keyword) -> bool {
233        match self.parser.peek_token().token {
234            Token::Word(w) => w.keyword == expected,
235            _ => false,
236        }
237    }
238
239    pub fn consume_token(&mut self, expected: &str) -> bool {
240        if self.peek_token_as_string().to_uppercase() == *expected.to_uppercase() {
241            let _ = self.parser.next_token();
242            true
243        } else {
244            false
245        }
246    }
247
248    #[inline]
249    pub(crate) fn peek_token_as_string(&self) -> String {
250        self.parser.peek_token().to_string()
251    }
252
253    /// Canonicalize the identifier to lowercase if it's not quoted.
254    pub fn canonicalize_identifier(ident: Ident) -> Ident {
255        if ident.quote_style.is_some() {
256            ident
257        } else {
258            Ident::new(ident.value.to_lowercase())
259        }
260    }
261
262    /// Like [canonicalize_identifier] but for [ObjectName].
263    pub fn canonicalize_object_name(object_name: ObjectName) -> ObjectName {
264        ObjectName(
265            object_name
266                .0
267                .into_iter()
268                .map(Self::canonicalize_identifier)
269                .collect(),
270        )
271    }
272
273    /// Simply a shortcut for sqlparser's same name method `parse_object_name`,
274    /// but with constant argument "false".
275    /// Because the argument is always "false" for us (it's introduced by BigQuery),
276    /// we don't want to write it again and again.
277    pub(crate) fn parse_object_name(&mut self) -> std::result::Result<ObjectName, ParserError> {
278        self.parser.parse_object_name(false)
279    }
280}
281
282#[cfg(test)]
283mod tests {
284
285    use datatypes::prelude::ConcreteDataType;
286    use sqlparser::dialect::MySqlDialect;
287
288    use super::*;
289    use crate::dialect::GreptimeDbDialect;
290    use crate::statements::create::CreateTable;
291    use crate::statements::sql_data_type_to_concrete_data_type;
292
293    fn test_timestamp_precision(sql: &str, expected_type: ConcreteDataType) {
294        match ParserContext::create_with_dialect(
295            sql,
296            &GreptimeDbDialect {},
297            ParseOptions::default(),
298        )
299        .unwrap()
300        .pop()
301        .unwrap()
302        {
303            Statement::CreateTable(CreateTable { columns, .. }) => {
304                let ts_col = columns.first().unwrap();
305                assert_eq!(
306                    expected_type,
307                    sql_data_type_to_concrete_data_type(ts_col.data_type()).unwrap()
308                );
309            }
310            _ => unreachable!(),
311        }
312    }
313
314    #[test]
315    pub fn test_create_table_with_precision() {
316        test_timestamp_precision(
317            "create table demo (ts timestamp time index, cnt int);",
318            ConcreteDataType::timestamp_millisecond_datatype(),
319        );
320        test_timestamp_precision(
321            "create table demo (ts timestamp(0) time index, cnt int);",
322            ConcreteDataType::timestamp_second_datatype(),
323        );
324        test_timestamp_precision(
325            "create table demo (ts timestamp(3) time index, cnt int);",
326            ConcreteDataType::timestamp_millisecond_datatype(),
327        );
328        test_timestamp_precision(
329            "create table demo (ts timestamp(6) time index, cnt int);",
330            ConcreteDataType::timestamp_microsecond_datatype(),
331        );
332        test_timestamp_precision(
333            "create table demo (ts timestamp(9) time index, cnt int);",
334            ConcreteDataType::timestamp_nanosecond_datatype(),
335        );
336    }
337
338    #[test]
339    #[should_panic]
340    pub fn test_create_table_with_invalid_precision() {
341        test_timestamp_precision(
342            "create table demo (ts timestamp(1) time index, cnt int);",
343            ConcreteDataType::timestamp_millisecond_datatype(),
344        );
345    }
346
347    #[test]
348    pub fn test_parse_table_name() {
349        let table_name = "a.b.c";
350
351        let object_name =
352            ParserContext::parse_table_name(table_name, &GreptimeDbDialect {}).unwrap();
353
354        assert_eq!(object_name.0.len(), 3);
355        assert_eq!(object_name.to_string(), table_name);
356
357        let table_name = "a.b";
358
359        let object_name =
360            ParserContext::parse_table_name(table_name, &GreptimeDbDialect {}).unwrap();
361
362        assert_eq!(object_name.0.len(), 2);
363        assert_eq!(object_name.to_string(), table_name);
364
365        let table_name = "Test.\"public-test\"";
366
367        let object_name =
368            ParserContext::parse_table_name(table_name, &GreptimeDbDialect {}).unwrap();
369
370        assert_eq!(object_name.0.len(), 2);
371        assert_eq!(object_name.to_string(), table_name.to_ascii_lowercase());
372
373        let table_name = "HelloWorld";
374
375        let object_name =
376            ParserContext::parse_table_name(table_name, &GreptimeDbDialect {}).unwrap();
377
378        assert_eq!(object_name.0.len(), 1);
379        assert_eq!(object_name.to_string(), table_name.to_ascii_lowercase());
380    }
381
382    #[test]
383    pub fn test_parse_mysql_prepare_stmt() {
384        let sql = "PREPARE stmt1 FROM 'SELECT * FROM t1 WHERE id = ?';";
385        let (stmt_name, stmt) =
386            ParserContext::parse_mysql_prepare_stmt(sql, &MySqlDialect {}).unwrap();
387        assert_eq!(stmt_name, "stmt1");
388        assert_eq!(stmt, "SELECT * FROM t1 WHERE id = ?");
389
390        let sql = "PREPARE stmt2 FROM \"SELECT * FROM t1 WHERE id = ?\"";
391        let (stmt_name, stmt) =
392            ParserContext::parse_mysql_prepare_stmt(sql, &MySqlDialect {}).unwrap();
393        assert_eq!(stmt_name, "stmt2");
394        assert_eq!(stmt, "SELECT * FROM t1 WHERE id = ?");
395    }
396
397    #[test]
398    pub fn test_parse_mysql_execute_stmt() {
399        let sql = "EXECUTE stmt1 USING 1, 'hello';";
400        let (stmt_name, params) =
401            ParserContext::parse_mysql_execute_stmt(sql, &GreptimeDbDialect {}).unwrap();
402        assert_eq!(stmt_name, "stmt1");
403        assert_eq!(params.len(), 2);
404        assert_eq!(params[0].to_string(), "1");
405        assert_eq!(params[1].to_string(), "'hello'");
406
407        let sql = "EXECUTE stmt2;";
408        let (stmt_name, params) =
409            ParserContext::parse_mysql_execute_stmt(sql, &GreptimeDbDialect {}).unwrap();
410        assert_eq!(stmt_name, "stmt2");
411        assert_eq!(params.len(), 0);
412
413        let sql = "EXECUTE stmt3 USING 231, 'hello', \"2003-03-1\", NULL, ;";
414        let (stmt_name, params) =
415            ParserContext::parse_mysql_execute_stmt(sql, &GreptimeDbDialect {}).unwrap();
416        assert_eq!(stmt_name, "stmt3");
417        assert_eq!(params.len(), 4);
418        assert_eq!(params[0].to_string(), "231");
419        assert_eq!(params[1].to_string(), "'hello'");
420        assert_eq!(params[2].to_string(), "\"2003-03-1\"");
421        assert_eq!(params[3].to_string(), "NULL");
422    }
423
424    #[test]
425    pub fn test_parse_mysql_deallocate_stmt() {
426        let sql = "DEALLOCATE stmt1;";
427        let stmt_name = ParserContext::parse_mysql_deallocate_stmt(sql, &MySqlDialect {}).unwrap();
428        assert_eq!(stmt_name, "stmt1");
429
430        let sql = "DEALLOCATE stmt2";
431        let stmt_name = ParserContext::parse_mysql_deallocate_stmt(sql, &MySqlDialect {}).unwrap();
432        assert_eq!(stmt_name, "stmt2");
433    }
434}