// sql/parser.rs
//
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15use std::str::FromStr;
16
17use snafu::ResultExt;
18use sqlparser::ast::{Ident, Query, Value};
19use sqlparser::dialect::Dialect;
20use sqlparser::keywords::Keyword;
21use sqlparser::parser::{Parser, ParserError, ParserOptions};
22use sqlparser::tokenizer::{Token, TokenWithSpan};
23
24use crate::ast::{Expr, ObjectName};
25use crate::error::{self, Result, SyntaxSnafu};
26use crate::parsers::tql_parser;
27use crate::statements::kill::Kill;
28use crate::statements::statement::Statement;
29use crate::statements::transform_statements;
30
/// The text of the `FLOW` keyword, kept as a plain string constant.
///
/// NOTE(review): presumably matched by hand because sqlparser's `Keyword` enum has
/// no `FLOW` variant; confirm against the flow statement parsers that use it.
pub const FLOW: &str = "FLOW";

/// SQL parser options.
///
/// Currently has no fields. [`ParserContext::create_with_dialect`] accepts a value of
/// this type but does not read it.
#[derive(Clone, Debug, Default)]
pub struct ParseOptions {}

37/// GrepTime SQL parser context, a simple wrapper for Datafusion SQL parser.
38pub struct ParserContext<'a> {
39    pub(crate) parser: Parser<'a>,
40    pub(crate) sql: &'a str,
41}
42
43impl ParserContext<'_> {
44    /// Construct a new ParserContext.
45    pub fn new<'a>(dialect: &'a dyn Dialect, sql: &'a str) -> Result<ParserContext<'a>> {
46        let parser = Parser::new(dialect)
47            .with_options(ParserOptions::new().with_trailing_commas(true))
48            .try_with_sql(sql)
49            .context(SyntaxSnafu)?;
50
51        Ok(ParserContext { parser, sql })
52    }
53
54    /// Parses parser context to Query.
55    pub fn parser_query(&mut self) -> Result<Box<Query>> {
56        self.parser.parse_query().context(SyntaxSnafu)
57    }
58
59    /// Parses SQL with given dialect
60    pub fn create_with_dialect(
61        sql: &str,
62        dialect: &dyn Dialect,
63        _opts: ParseOptions,
64    ) -> Result<Vec<Statement>> {
65        let mut stmts: Vec<Statement> = Vec::new();
66
67        let mut parser_ctx = ParserContext::new(dialect, sql)?;
68
69        let mut expecting_statement_delimiter = false;
70        loop {
71            // ignore empty statements (between successive statement delimiters)
72            while parser_ctx.parser.consume_token(&Token::SemiColon) {
73                expecting_statement_delimiter = false;
74            }
75
76            if parser_ctx.parser.peek_token() == Token::EOF {
77                break;
78            }
79            if expecting_statement_delimiter {
80                return parser_ctx.unsupported(parser_ctx.peek_token_as_string());
81            }
82
83            let statement = parser_ctx.parse_statement()?;
84            stmts.push(statement);
85            expecting_statement_delimiter = true;
86        }
87
88        transform_statements(&mut stmts)?;
89
90        Ok(stmts)
91    }
92
93    pub fn parse_table_name(sql: &str, dialect: &dyn Dialect) -> Result<ObjectName> {
94        let parser = Parser::new(dialect)
95            .with_options(ParserOptions::new().with_trailing_commas(true))
96            .try_with_sql(sql)
97            .context(SyntaxSnafu)?;
98        ParserContext { parser, sql }.intern_parse_table_name()
99    }
100
101    pub(crate) fn intern_parse_table_name(&mut self) -> Result<ObjectName> {
102        let raw_table_name =
103            self.parser
104                .parse_object_name(false)
105                .context(error::UnexpectedSnafu {
106                    expected: "a table name",
107                    actual: self.parser.peek_token().to_string(),
108                })?;
109        Ok(Self::canonicalize_object_name(raw_table_name))
110    }
111
112    pub fn parse_function(sql: &str, dialect: &dyn Dialect) -> Result<Expr> {
113        let mut parser = Parser::new(dialect)
114            .with_options(ParserOptions::new().with_trailing_commas(true))
115            .try_with_sql(sql)
116            .context(SyntaxSnafu)?;
117
118        let function_name = parser.parse_identifier().context(SyntaxSnafu)?;
119        parser
120            .parse_function(ObjectName(vec![function_name]))
121            .context(SyntaxSnafu)
122    }
123
124    /// Parses parser context to a set of statements.
125    pub fn parse_statement(&mut self) -> Result<Statement> {
126        match self.parser.peek_token().token {
127            Token::Word(w) => match w.keyword {
128                Keyword::CREATE => {
129                    let _ = self.parser.next_token();
130                    self.parse_create()
131                }
132
133                Keyword::EXPLAIN => {
134                    let _ = self.parser.next_token();
135                    self.parse_explain()
136                }
137
138                Keyword::SHOW => {
139                    let _ = self.parser.next_token();
140                    self.parse_show()
141                }
142
143                Keyword::DELETE => self.parse_delete(),
144
145                Keyword::DESCRIBE | Keyword::DESC => {
146                    let _ = self.parser.next_token();
147                    self.parse_describe()
148                }
149
150                Keyword::INSERT => self.parse_insert(),
151
152                Keyword::REPLACE => self.parse_replace(),
153
154                Keyword::SELECT | Keyword::VALUES => self.parse_query(),
155
156                Keyword::WITH => self.parse_with_tql(),
157
158                Keyword::ALTER => self.parse_alter(),
159
160                Keyword::DROP => self.parse_drop(),
161
162                Keyword::COPY => self.parse_copy(),
163
164                Keyword::TRUNCATE => self.parse_truncate(),
165
166                Keyword::SET => self.parse_set_variables(),
167
168                Keyword::ADMIN => self.parse_admin_command(),
169
170                Keyword::NoKeyword
171                    if w.quote_style.is_none() && w.value.to_uppercase() == tql_parser::TQL =>
172                {
173                    self.parse_tql()
174                }
175
176                Keyword::DECLARE => self.parse_declare_cursor(),
177
178                Keyword::FETCH => self.parse_fetch_cursor(),
179
180                Keyword::CLOSE => self.parse_close_cursor(),
181
182                Keyword::USE => {
183                    let _ = self.parser.next_token();
184
185                    let database_name = self.parser.parse_identifier().with_context(|_| {
186                        error::UnexpectedSnafu {
187                            expected: "a database name",
188                            actual: self.peek_token_as_string(),
189                        }
190                    })?;
191                    Ok(Statement::Use(
192                        Self::canonicalize_identifier(database_name).value,
193                    ))
194                }
195
196                Keyword::KILL => {
197                    let _ = self.parser.next_token();
198                    let kill = if self.parser.parse_keyword(Keyword::QUERY) {
199                        // MySQL KILL QUERY <connection id> statements
200                        let connection_id_exp =
201                            self.parser.parse_number_value().with_context(|_| {
202                                error::UnexpectedSnafu {
203                                    expected: "MySQL numeric connection id",
204                                    actual: self.peek_token_as_string(),
205                                }
206                            })?;
207                        let Value::Number(s, _) = connection_id_exp else {
208                            return error::UnexpectedTokenSnafu {
209                                expected: "MySQL numeric connection id",
210                                actual: connection_id_exp.to_string(),
211                            }
212                            .fail();
213                        };
214
215                        let connection_id = u32::from_str(&s).map_err(|_| {
216                            error::UnexpectedTokenSnafu {
217                                expected: "MySQL numeric connection id",
218                                actual: s,
219                            }
220                            .build()
221                        })?;
222                        Kill::ConnectionId(connection_id)
223                    } else {
224                        let process_id_ident =
225                            self.parser.parse_literal_string().with_context(|_| {
226                                error::UnexpectedSnafu {
227                                    expected: "process id string literal",
228                                    actual: self.peek_token_as_string(),
229                                }
230                            })?;
231                        Kill::ProcessId(process_id_ident)
232                    };
233
234                    Ok(Statement::Kill(kill))
235                }
236
237                _ => self.unsupported(self.peek_token_as_string()),
238            },
239            Token::LParen => self.parse_query(),
240            unexpected => self.unsupported(unexpected.to_string()),
241        }
242    }
243
244    /// Parses MySQL style 'PREPARE stmt_name FROM stmt' into a (stmt_name, stmt) tuple.
245    pub fn parse_mysql_prepare_stmt(sql: &str, dialect: &dyn Dialect) -> Result<(String, String)> {
246        ParserContext::new(dialect, sql)?.parse_mysql_prepare()
247    }
248
249    /// Parses MySQL style 'EXECUTE stmt_name USING param_list' into a stmt_name string and a list of parameters.
250    pub fn parse_mysql_execute_stmt(
251        sql: &str,
252        dialect: &dyn Dialect,
253    ) -> Result<(String, Vec<Expr>)> {
254        ParserContext::new(dialect, sql)?.parse_mysql_execute()
255    }
256
257    /// Parses MySQL style 'DEALLOCATE stmt_name' into a stmt_name string.
258    pub fn parse_mysql_deallocate_stmt(sql: &str, dialect: &dyn Dialect) -> Result<String> {
259        ParserContext::new(dialect, sql)?.parse_deallocate()
260    }
261
262    /// Raises an "unsupported statement" error.
263    pub fn unsupported<T>(&self, keyword: String) -> Result<T> {
264        error::UnsupportedSnafu { keyword }.fail()
265    }
266
267    // Report unexpected token
268    pub(crate) fn expected<T>(&self, expected: &str, found: TokenWithSpan) -> Result<T> {
269        Err(ParserError::ParserError(format!(
270            "Expected {expected}, found: {found}",
271        )))
272        .context(SyntaxSnafu)
273    }
274
275    pub fn matches_keyword(&mut self, expected: Keyword) -> bool {
276        match self.parser.peek_token().token {
277            Token::Word(w) => w.keyword == expected,
278            _ => false,
279        }
280    }
281
282    pub fn consume_token(&mut self, expected: &str) -> bool {
283        if self.peek_token_as_string().to_uppercase() == *expected.to_uppercase() {
284            let _ = self.parser.next_token();
285            true
286        } else {
287            false
288        }
289    }
290
291    #[inline]
292    pub(crate) fn peek_token_as_string(&self) -> String {
293        self.parser.peek_token().to_string()
294    }
295
296    /// Canonicalize the identifier to lowercase if it's not quoted.
297    pub fn canonicalize_identifier(ident: Ident) -> Ident {
298        if ident.quote_style.is_some() {
299            ident
300        } else {
301            Ident::new(ident.value.to_lowercase())
302        }
303    }
304
305    /// Like [canonicalize_identifier] but for [ObjectName].
306    pub fn canonicalize_object_name(object_name: ObjectName) -> ObjectName {
307        ObjectName(
308            object_name
309                .0
310                .into_iter()
311                .map(Self::canonicalize_identifier)
312                .collect(),
313        )
314    }
315
316    /// Simply a shortcut for sqlparser's same name method `parse_object_name`,
317    /// but with constant argument "false".
318    /// Because the argument is always "false" for us (it's introduced by BigQuery),
319    /// we don't want to write it again and again.
320    pub(crate) fn parse_object_name(&mut self) -> std::result::Result<ObjectName, ParserError> {
321        self.parser.parse_object_name(false)
322    }
323}
324
#[cfg(test)]
mod tests {
    use datatypes::prelude::ConcreteDataType;
    use sqlparser::dialect::MySqlDialect;

    use super::*;
    use crate::dialect::GreptimeDbDialect;
    use crate::statements::create::CreateTable;
    use crate::statements::kill::Kill;
    use crate::statements::sql_data_type_to_concrete_data_type;

    /// Parses `sql` with the GreptimeDB dialect and default parse options.
    fn parse_greptime(sql: &str) -> crate::error::Result<Vec<Statement>> {
        ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
    }

    /// Parses `sql`, asserts that exactly one statement came out, and returns it.
    fn parse_one(sql: &str) -> Statement {
        let mut statements = parse_greptime(sql).unwrap();
        assert_eq!(statements.len(), 1);
        statements.remove(0)
    }

    /// Asserts that `sql` parses into `Kill::ConnectionId(expected)`.
    fn assert_kill_connection(sql: &str, expected: u32) {
        let Statement::Kill(Kill::ConnectionId(connection_id)) = parse_one(sql) else {
            panic!("Expected Kill::ConnectionId statement");
        };
        assert_eq!(connection_id, expected);
    }

    /// Asserts that `sql` parses into `Kill::ProcessId(expected)`.
    fn assert_kill_process(sql: &str, expected: &str) {
        let Statement::Kill(Kill::ProcessId(process_id)) = parse_one(sql) else {
            panic!("Expected Kill::ProcessId statement");
        };
        assert_eq!(process_id, expected);
    }

    /// Parses a single CREATE TABLE and checks the concrete type of its first column.
    fn test_timestamp_precision(sql: &str, expected_type: ConcreteDataType) {
        let statement = parse_greptime(sql).unwrap().pop().unwrap();
        let Statement::CreateTable(CreateTable { columns, .. }) = statement else {
            unreachable!()
        };
        let ts_col = columns.first().unwrap();
        assert_eq!(
            expected_type,
            sql_data_type_to_concrete_data_type(ts_col.data_type()).unwrap()
        );
    }

    #[test]
    pub fn test_create_table_with_precision() {
        let cases = [
            (
                "create table demo (ts timestamp time index, cnt int);",
                ConcreteDataType::timestamp_millisecond_datatype(),
            ),
            (
                "create table demo (ts timestamp(0) time index, cnt int);",
                ConcreteDataType::timestamp_second_datatype(),
            ),
            (
                "create table demo (ts timestamp(3) time index, cnt int);",
                ConcreteDataType::timestamp_millisecond_datatype(),
            ),
            (
                "create table demo (ts timestamp(6) time index, cnt int);",
                ConcreteDataType::timestamp_microsecond_datatype(),
            ),
            (
                "create table demo (ts timestamp(9) time index, cnt int);",
                ConcreteDataType::timestamp_nanosecond_datatype(),
            ),
        ];
        for (sql, expected_type) in cases {
            test_timestamp_precision(sql, expected_type);
        }
    }

    #[test]
    #[should_panic]
    pub fn test_create_table_with_invalid_precision() {
        test_timestamp_precision(
            "create table demo (ts timestamp(1) time index, cnt int);",
            ConcreteDataType::timestamp_millisecond_datatype(),
        );
    }

    #[test]
    pub fn test_parse_table_name() {
        // (input, number of name parts). Unquoted parts come back lowercased.
        let cases = [
            ("a.b.c", 3),
            ("a.b", 2),
            ("Test.\"public-test\"", 2),
            ("HelloWorld", 1),
        ];
        for (table_name, parts) in cases {
            let object_name =
                ParserContext::parse_table_name(table_name, &GreptimeDbDialect {}).unwrap();
            assert_eq!(object_name.0.len(), parts);
            assert_eq!(object_name.to_string(), table_name.to_ascii_lowercase());
        }
    }

    #[test]
    pub fn test_parse_mysql_prepare_stmt() {
        let cases = [
            ("PREPARE stmt1 FROM 'SELECT * FROM t1 WHERE id = ?';", "stmt1"),
            ("PREPARE stmt2 FROM \"SELECT * FROM t1 WHERE id = ?\"", "stmt2"),
        ];
        for (sql, expected_name) in cases {
            let (stmt_name, stmt) =
                ParserContext::parse_mysql_prepare_stmt(sql, &MySqlDialect {}).unwrap();
            assert_eq!(stmt_name, expected_name);
            assert_eq!(stmt, "SELECT * FROM t1 WHERE id = ?");
        }
    }

    #[test]
    pub fn test_parse_mysql_execute_stmt() {
        // Returns the statement name and the rendered parameters.
        let execute = |sql: &str| {
            let (name, params) =
                ParserContext::parse_mysql_execute_stmt(sql, &GreptimeDbDialect {}).unwrap();
            let rendered: Vec<String> = params.iter().map(ToString::to_string).collect();
            (name, rendered)
        };

        let (stmt_name, params) = execute("EXECUTE stmt1 USING 1, 'hello';");
        assert_eq!(stmt_name, "stmt1");
        assert_eq!(params, ["1", "'hello'"]);

        let (stmt_name, params) = execute("EXECUTE stmt2;");
        assert_eq!(stmt_name, "stmt2");
        assert!(params.is_empty());

        let (stmt_name, params) =
            execute("EXECUTE stmt3 USING 231, 'hello', \"2003-03-1\", NULL, ;");
        assert_eq!(stmt_name, "stmt3");
        assert_eq!(params, ["231", "'hello'", "\"2003-03-1\"", "NULL"]);
    }

    #[test]
    pub fn test_parse_mysql_deallocate_stmt() {
        for (sql, expected_name) in [("DEALLOCATE stmt1;", "stmt1"), ("DEALLOCATE stmt2", "stmt2")]
        {
            let stmt_name =
                ParserContext::parse_mysql_deallocate_stmt(sql, &MySqlDialect {}).unwrap();
            assert_eq!(stmt_name, expected_name);
        }
    }

    #[test]
    pub fn test_parse_kill_query_statement() {
        // MySQL-style KILL QUERY with a connection id.
        assert_kill_connection("KILL QUERY 123", 123);
        // A larger connection id.
        assert_kill_connection("KILL QUERY 999999", 999999);
    }

    #[test]
    pub fn test_parse_kill_process_statement() {
        // Single-quoted process id.
        assert_kill_process("KILL 'process-123'", "process-123");
        // Double-quoted process id.
        assert_kill_process("KILL \"process-456\"", "process-456");
        // UUID-like process id.
        assert_kill_process(
            "KILL 'f47ac10b-58cc-4372-a567-0e02b2c3d479'",
            "f47ac10b-58cc-4372-a567-0e02b2c3d479",
        );
    }

    #[test]
    pub fn test_parse_kill_statement_errors() {
        let invalid = [
            // KILL QUERY without a connection id.
            "KILL QUERY",
            // Non-numeric connection id.
            "KILL QUERY 'not-a-number'",
            // KILL without any argument.
            "KILL",
            // u32::MAX + 1 does not fit the connection id type.
            "KILL QUERY 4294967296",
        ];
        for sql in invalid {
            assert!(parse_greptime(sql).is_err(), "expected `{sql}` to be rejected");
        }
    }

    #[test]
    pub fn test_parse_kill_statement_edge_cases() {
        // Zero connection id.
        assert_kill_connection("KILL QUERY 0", 0);
        // u32::MAX connection id.
        assert_kill_connection("KILL QUERY 4294967295", 4294967295);
        // Empty process id.
        assert_kill_process("KILL ''", "");
    }

    #[test]
    pub fn test_parse_kill_statement_case_insensitive() {
        // Lowercase keywords.
        assert_kill_connection("kill query 123", 123);
        // Mixed-case keywords.
        assert_kill_connection("Kill Query 456", 456);
    }
}