sql/
parser.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::str::FromStr;
16
17use snafu::ResultExt;
18use sqlparser::ast::{Ident, Query, Value};
19use sqlparser::dialect::Dialect;
20use sqlparser::keywords::Keyword;
21use sqlparser::parser::{Parser, ParserError, ParserOptions};
22use sqlparser::tokenizer::{Token, TokenWithSpan};
23
24use crate::ast::{Expr, ObjectName};
25use crate::error::{self, Result, SyntaxSnafu};
26use crate::parsers::tql_parser;
27use crate::statements::kill::Kill;
28use crate::statements::statement::Statement;
29use crate::statements::transform_statements;
30
/// The literal keyword text `FLOW`.
///
/// NOTE(review): not referenced anywhere in this file; presumably matched by
/// sibling parser modules when handling flow statements — confirm against callers.
pub const FLOW: &str = "FLOW";
32
/// SQL parser options, accepted by [`ParserContext::create_with_dialect`].
///
/// The struct currently has no fields, and `create_with_dialect` ignores the
/// value it is given, so `ParseOptions::default()` is the only meaningful
/// value today.
#[derive(Clone, Debug, Default)]
pub struct ParseOptions {}
36
/// GreptimeDB SQL parser context: a thin wrapper around sqlparser's [`Parser`]
/// that adds GreptimeDB-specific statement dispatch and helpers.
pub struct ParserContext<'a> {
    /// The wrapped sqlparser parser; all token inspection and consumption goes
    /// through it.
    pub(crate) parser: Parser<'a>,
    /// The SQL text this context was created for (the same text the constructors
    /// in this file feed to `parser`).
    pub(crate) sql: &'a str,
}
42
43impl ParserContext<'_> {
44    /// Construct a new ParserContext.
45    pub fn new<'a>(dialect: &'a dyn Dialect, sql: &'a str) -> Result<ParserContext<'a>> {
46        let parser = Parser::new(dialect)
47            .with_options(ParserOptions::new().with_trailing_commas(true))
48            .try_with_sql(sql)
49            .context(SyntaxSnafu)?;
50
51        Ok(ParserContext { parser, sql })
52    }
53
54    /// Parses parser context to Query.
55    pub fn parser_query(&mut self) -> Result<Box<Query>> {
56        self.parser.parse_query().context(SyntaxSnafu)
57    }
58
59    /// Parses SQL with given dialect
60    pub fn create_with_dialect(
61        sql: &str,
62        dialect: &dyn Dialect,
63        _opts: ParseOptions,
64    ) -> Result<Vec<Statement>> {
65        let mut stmts: Vec<Statement> = Vec::new();
66
67        let mut parser_ctx = ParserContext::new(dialect, sql)?;
68
69        let mut expecting_statement_delimiter = false;
70        loop {
71            // ignore empty statements (between successive statement delimiters)
72            while parser_ctx.parser.consume_token(&Token::SemiColon) {
73                expecting_statement_delimiter = false;
74            }
75
76            if parser_ctx.parser.peek_token() == Token::EOF {
77                break;
78            }
79            if expecting_statement_delimiter {
80                return parser_ctx.unsupported(parser_ctx.peek_token_as_string());
81            }
82
83            let statement = parser_ctx.parse_statement()?;
84            stmts.push(statement);
85            expecting_statement_delimiter = true;
86        }
87
88        transform_statements(&mut stmts)?;
89
90        Ok(stmts)
91    }
92
93    pub fn parse_table_name(sql: &str, dialect: &dyn Dialect) -> Result<ObjectName> {
94        let parser = Parser::new(dialect)
95            .with_options(ParserOptions::new().with_trailing_commas(true))
96            .try_with_sql(sql)
97            .context(SyntaxSnafu)?;
98        ParserContext { parser, sql }.intern_parse_table_name()
99    }
100
101    pub(crate) fn intern_parse_table_name(&mut self) -> Result<ObjectName> {
102        let raw_table_name =
103            self.parser
104                .parse_object_name(false)
105                .context(error::UnexpectedSnafu {
106                    expected: "a table name",
107                    actual: self.parser.peek_token().to_string(),
108                })?;
109        Ok(Self::canonicalize_object_name(raw_table_name))
110    }
111
112    pub fn parse_function(sql: &str, dialect: &dyn Dialect) -> Result<Expr> {
113        let mut parser = Parser::new(dialect)
114            .with_options(ParserOptions::new().with_trailing_commas(true))
115            .try_with_sql(sql)
116            .context(SyntaxSnafu)?;
117
118        let function_name = parser.parse_identifier().context(SyntaxSnafu)?;
119        parser
120            .parse_function(ObjectName(vec![function_name]))
121            .context(SyntaxSnafu)
122    }
123
124    /// Parses parser context to a set of statements.
125    pub fn parse_statement(&mut self) -> Result<Statement> {
126        match self.parser.peek_token().token {
127            Token::Word(w) => match w.keyword {
128                Keyword::CREATE => {
129                    let _ = self.parser.next_token();
130                    self.parse_create()
131                }
132
133                Keyword::EXPLAIN => {
134                    let _ = self.parser.next_token();
135                    self.parse_explain()
136                }
137
138                Keyword::SHOW => {
139                    let _ = self.parser.next_token();
140                    self.parse_show()
141                }
142
143                Keyword::DELETE => self.parse_delete(),
144
145                Keyword::DESCRIBE | Keyword::DESC => {
146                    let _ = self.parser.next_token();
147                    self.parse_describe()
148                }
149
150                Keyword::INSERT => self.parse_insert(),
151
152                Keyword::REPLACE => self.parse_replace(),
153
154                Keyword::SELECT | Keyword::WITH | Keyword::VALUES => self.parse_query(),
155
156                Keyword::ALTER => self.parse_alter(),
157
158                Keyword::DROP => self.parse_drop(),
159
160                Keyword::COPY => self.parse_copy(),
161
162                Keyword::TRUNCATE => self.parse_truncate(),
163
164                Keyword::SET => self.parse_set_variables(),
165
166                Keyword::ADMIN => self.parse_admin_command(),
167
168                Keyword::NoKeyword
169                    if w.quote_style.is_none() && w.value.to_uppercase() == tql_parser::TQL =>
170                {
171                    self.parse_tql()
172                }
173
174                Keyword::DECLARE => self.parse_declare_cursor(),
175
176                Keyword::FETCH => self.parse_fetch_cursor(),
177
178                Keyword::CLOSE => self.parse_close_cursor(),
179
180                Keyword::USE => {
181                    let _ = self.parser.next_token();
182
183                    let database_name = self.parser.parse_identifier().with_context(|_| {
184                        error::UnexpectedSnafu {
185                            expected: "a database name",
186                            actual: self.peek_token_as_string(),
187                        }
188                    })?;
189                    Ok(Statement::Use(
190                        Self::canonicalize_identifier(database_name).value,
191                    ))
192                }
193
194                Keyword::KILL => {
195                    let _ = self.parser.next_token();
196                    let kill = if self.parser.parse_keyword(Keyword::QUERY) {
197                        // MySQL KILL QUERY <connection id> statements
198                        let connection_id_exp =
199                            self.parser.parse_number_value().with_context(|_| {
200                                error::UnexpectedSnafu {
201                                    expected: "MySQL numeric connection id",
202                                    actual: self.peek_token_as_string(),
203                                }
204                            })?;
205                        let Value::Number(s, _) = connection_id_exp else {
206                            return error::UnexpectedTokenSnafu {
207                                expected: "MySQL numeric connection id",
208                                actual: connection_id_exp.to_string(),
209                            }
210                            .fail();
211                        };
212
213                        let connection_id = u32::from_str(&s).map_err(|_| {
214                            error::UnexpectedTokenSnafu {
215                                expected: "MySQL numeric connection id",
216                                actual: s,
217                            }
218                            .build()
219                        })?;
220                        Kill::ConnectionId(connection_id)
221                    } else {
222                        let process_id_ident =
223                            self.parser.parse_literal_string().with_context(|_| {
224                                error::UnexpectedSnafu {
225                                    expected: "process id string literal",
226                                    actual: self.peek_token_as_string(),
227                                }
228                            })?;
229                        Kill::ProcessId(process_id_ident)
230                    };
231
232                    Ok(Statement::Kill(kill))
233                }
234
235                _ => self.unsupported(self.peek_token_as_string()),
236            },
237            Token::LParen => self.parse_query(),
238            unexpected => self.unsupported(unexpected.to_string()),
239        }
240    }
241
242    /// Parses MySQL style 'PREPARE stmt_name FROM stmt' into a (stmt_name, stmt) tuple.
243    pub fn parse_mysql_prepare_stmt(sql: &str, dialect: &dyn Dialect) -> Result<(String, String)> {
244        ParserContext::new(dialect, sql)?.parse_mysql_prepare()
245    }
246
247    /// Parses MySQL style 'EXECUTE stmt_name USING param_list' into a stmt_name string and a list of parameters.
248    pub fn parse_mysql_execute_stmt(
249        sql: &str,
250        dialect: &dyn Dialect,
251    ) -> Result<(String, Vec<Expr>)> {
252        ParserContext::new(dialect, sql)?.parse_mysql_execute()
253    }
254
255    /// Parses MySQL style 'DEALLOCATE stmt_name' into a stmt_name string.
256    pub fn parse_mysql_deallocate_stmt(sql: &str, dialect: &dyn Dialect) -> Result<String> {
257        ParserContext::new(dialect, sql)?.parse_deallocate()
258    }
259
260    /// Raises an "unsupported statement" error.
261    pub fn unsupported<T>(&self, keyword: String) -> Result<T> {
262        error::UnsupportedSnafu { keyword }.fail()
263    }
264
265    // Report unexpected token
266    pub(crate) fn expected<T>(&self, expected: &str, found: TokenWithSpan) -> Result<T> {
267        Err(ParserError::ParserError(format!(
268            "Expected {expected}, found: {found}",
269        )))
270        .context(SyntaxSnafu)
271    }
272
273    pub fn matches_keyword(&mut self, expected: Keyword) -> bool {
274        match self.parser.peek_token().token {
275            Token::Word(w) => w.keyword == expected,
276            _ => false,
277        }
278    }
279
280    pub fn consume_token(&mut self, expected: &str) -> bool {
281        if self.peek_token_as_string().to_uppercase() == *expected.to_uppercase() {
282            let _ = self.parser.next_token();
283            true
284        } else {
285            false
286        }
287    }
288
289    #[inline]
290    pub(crate) fn peek_token_as_string(&self) -> String {
291        self.parser.peek_token().to_string()
292    }
293
294    /// Canonicalize the identifier to lowercase if it's not quoted.
295    pub fn canonicalize_identifier(ident: Ident) -> Ident {
296        if ident.quote_style.is_some() {
297            ident
298        } else {
299            Ident::new(ident.value.to_lowercase())
300        }
301    }
302
303    /// Like [canonicalize_identifier] but for [ObjectName].
304    pub fn canonicalize_object_name(object_name: ObjectName) -> ObjectName {
305        ObjectName(
306            object_name
307                .0
308                .into_iter()
309                .map(Self::canonicalize_identifier)
310                .collect(),
311        )
312    }
313
314    /// Simply a shortcut for sqlparser's same name method `parse_object_name`,
315    /// but with constant argument "false".
316    /// Because the argument is always "false" for us (it's introduced by BigQuery),
317    /// we don't want to write it again and again.
318    pub(crate) fn parse_object_name(&mut self) -> std::result::Result<ObjectName, ParserError> {
319        self.parser.parse_object_name(false)
320    }
321}
322
323#[cfg(test)]
324mod tests {
325
326    use datatypes::prelude::ConcreteDataType;
327    use sqlparser::dialect::MySqlDialect;
328
329    use super::*;
330    use crate::dialect::GreptimeDbDialect;
331    use crate::statements::create::CreateTable;
332    use crate::statements::sql_data_type_to_concrete_data_type;
333
334    fn test_timestamp_precision(sql: &str, expected_type: ConcreteDataType) {
335        match ParserContext::create_with_dialect(
336            sql,
337            &GreptimeDbDialect {},
338            ParseOptions::default(),
339        )
340        .unwrap()
341        .pop()
342        .unwrap()
343        {
344            Statement::CreateTable(CreateTable { columns, .. }) => {
345                let ts_col = columns.first().unwrap();
346                assert_eq!(
347                    expected_type,
348                    sql_data_type_to_concrete_data_type(ts_col.data_type()).unwrap()
349                );
350            }
351            _ => unreachable!(),
352        }
353    }
354
355    #[test]
356    pub fn test_create_table_with_precision() {
357        test_timestamp_precision(
358            "create table demo (ts timestamp time index, cnt int);",
359            ConcreteDataType::timestamp_millisecond_datatype(),
360        );
361        test_timestamp_precision(
362            "create table demo (ts timestamp(0) time index, cnt int);",
363            ConcreteDataType::timestamp_second_datatype(),
364        );
365        test_timestamp_precision(
366            "create table demo (ts timestamp(3) time index, cnt int);",
367            ConcreteDataType::timestamp_millisecond_datatype(),
368        );
369        test_timestamp_precision(
370            "create table demo (ts timestamp(6) time index, cnt int);",
371            ConcreteDataType::timestamp_microsecond_datatype(),
372        );
373        test_timestamp_precision(
374            "create table demo (ts timestamp(9) time index, cnt int);",
375            ConcreteDataType::timestamp_nanosecond_datatype(),
376        );
377    }
378
379    #[test]
380    #[should_panic]
381    pub fn test_create_table_with_invalid_precision() {
382        test_timestamp_precision(
383            "create table demo (ts timestamp(1) time index, cnt int);",
384            ConcreteDataType::timestamp_millisecond_datatype(),
385        );
386    }
387
388    #[test]
389    pub fn test_parse_table_name() {
390        let table_name = "a.b.c";
391
392        let object_name =
393            ParserContext::parse_table_name(table_name, &GreptimeDbDialect {}).unwrap();
394
395        assert_eq!(object_name.0.len(), 3);
396        assert_eq!(object_name.to_string(), table_name);
397
398        let table_name = "a.b";
399
400        let object_name =
401            ParserContext::parse_table_name(table_name, &GreptimeDbDialect {}).unwrap();
402
403        assert_eq!(object_name.0.len(), 2);
404        assert_eq!(object_name.to_string(), table_name);
405
406        let table_name = "Test.\"public-test\"";
407
408        let object_name =
409            ParserContext::parse_table_name(table_name, &GreptimeDbDialect {}).unwrap();
410
411        assert_eq!(object_name.0.len(), 2);
412        assert_eq!(object_name.to_string(), table_name.to_ascii_lowercase());
413
414        let table_name = "HelloWorld";
415
416        let object_name =
417            ParserContext::parse_table_name(table_name, &GreptimeDbDialect {}).unwrap();
418
419        assert_eq!(object_name.0.len(), 1);
420        assert_eq!(object_name.to_string(), table_name.to_ascii_lowercase());
421    }
422
423    #[test]
424    pub fn test_parse_mysql_prepare_stmt() {
425        let sql = "PREPARE stmt1 FROM 'SELECT * FROM t1 WHERE id = ?';";
426        let (stmt_name, stmt) =
427            ParserContext::parse_mysql_prepare_stmt(sql, &MySqlDialect {}).unwrap();
428        assert_eq!(stmt_name, "stmt1");
429        assert_eq!(stmt, "SELECT * FROM t1 WHERE id = ?");
430
431        let sql = "PREPARE stmt2 FROM \"SELECT * FROM t1 WHERE id = ?\"";
432        let (stmt_name, stmt) =
433            ParserContext::parse_mysql_prepare_stmt(sql, &MySqlDialect {}).unwrap();
434        assert_eq!(stmt_name, "stmt2");
435        assert_eq!(stmt, "SELECT * FROM t1 WHERE id = ?");
436    }
437
438    #[test]
439    pub fn test_parse_mysql_execute_stmt() {
440        let sql = "EXECUTE stmt1 USING 1, 'hello';";
441        let (stmt_name, params) =
442            ParserContext::parse_mysql_execute_stmt(sql, &GreptimeDbDialect {}).unwrap();
443        assert_eq!(stmt_name, "stmt1");
444        assert_eq!(params.len(), 2);
445        assert_eq!(params[0].to_string(), "1");
446        assert_eq!(params[1].to_string(), "'hello'");
447
448        let sql = "EXECUTE stmt2;";
449        let (stmt_name, params) =
450            ParserContext::parse_mysql_execute_stmt(sql, &GreptimeDbDialect {}).unwrap();
451        assert_eq!(stmt_name, "stmt2");
452        assert_eq!(params.len(), 0);
453
454        let sql = "EXECUTE stmt3 USING 231, 'hello', \"2003-03-1\", NULL, ;";
455        let (stmt_name, params) =
456            ParserContext::parse_mysql_execute_stmt(sql, &GreptimeDbDialect {}).unwrap();
457        assert_eq!(stmt_name, "stmt3");
458        assert_eq!(params.len(), 4);
459        assert_eq!(params[0].to_string(), "231");
460        assert_eq!(params[1].to_string(), "'hello'");
461        assert_eq!(params[2].to_string(), "\"2003-03-1\"");
462        assert_eq!(params[3].to_string(), "NULL");
463    }
464
465    #[test]
466    pub fn test_parse_mysql_deallocate_stmt() {
467        let sql = "DEALLOCATE stmt1;";
468        let stmt_name = ParserContext::parse_mysql_deallocate_stmt(sql, &MySqlDialect {}).unwrap();
469        assert_eq!(stmt_name, "stmt1");
470
471        let sql = "DEALLOCATE stmt2";
472        let stmt_name = ParserContext::parse_mysql_deallocate_stmt(sql, &MySqlDialect {}).unwrap();
473        assert_eq!(stmt_name, "stmt2");
474    }
475
476    #[test]
477    pub fn test_parse_kill_query_statement() {
478        use crate::statements::kill::Kill;
479
480        // Test MySQL-style KILL QUERY with connection ID
481        let sql = "KILL QUERY 123";
482        let statements =
483            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
484                .unwrap();
485
486        assert_eq!(statements.len(), 1);
487        match &statements[0] {
488            Statement::Kill(Kill::ConnectionId(connection_id)) => {
489                assert_eq!(*connection_id, 123);
490            }
491            _ => panic!("Expected Kill::ConnectionId statement"),
492        }
493
494        // Test with larger connection ID
495        let sql = "KILL QUERY 999999";
496        let statements =
497            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
498                .unwrap();
499
500        assert_eq!(statements.len(), 1);
501        match &statements[0] {
502            Statement::Kill(Kill::ConnectionId(connection_id)) => {
503                assert_eq!(*connection_id, 999999);
504            }
505            _ => panic!("Expected Kill::ConnectionId statement"),
506        }
507    }
508
509    #[test]
510    pub fn test_parse_kill_process_statement() {
511        use crate::statements::kill::Kill;
512
513        // Test KILL with process ID string
514        let sql = "KILL 'process-123'";
515        let statements =
516            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
517                .unwrap();
518
519        assert_eq!(statements.len(), 1);
520        match &statements[0] {
521            Statement::Kill(Kill::ProcessId(process_id)) => {
522                assert_eq!(process_id, "process-123");
523            }
524            _ => panic!("Expected Kill::ProcessId statement"),
525        }
526
527        // Test with double quotes
528        let sql = "KILL \"process-456\"";
529        let statements =
530            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
531                .unwrap();
532
533        assert_eq!(statements.len(), 1);
534        match &statements[0] {
535            Statement::Kill(Kill::ProcessId(process_id)) => {
536                assert_eq!(process_id, "process-456");
537            }
538            _ => panic!("Expected Kill::ProcessId statement"),
539        }
540
541        // Test with UUID-like process ID
542        let sql = "KILL 'f47ac10b-58cc-4372-a567-0e02b2c3d479'";
543        let statements =
544            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
545                .unwrap();
546
547        assert_eq!(statements.len(), 1);
548        match &statements[0] {
549            Statement::Kill(Kill::ProcessId(process_id)) => {
550                assert_eq!(process_id, "f47ac10b-58cc-4372-a567-0e02b2c3d479");
551            }
552            _ => panic!("Expected Kill::ProcessId statement"),
553        }
554    }
555
556    #[test]
557    pub fn test_parse_kill_statement_errors() {
558        // Test KILL QUERY without connection ID
559        let sql = "KILL QUERY";
560        let result =
561            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default());
562        assert!(result.is_err());
563
564        // Test KILL QUERY with non-numeric connection ID
565        let sql = "KILL QUERY 'not-a-number'";
566        let result =
567            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default());
568        assert!(result.is_err());
569
570        // Test KILL without any argument
571        let sql = "KILL";
572        let result =
573            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default());
574        assert!(result.is_err());
575
576        // Test KILL QUERY with connection ID that's too large for u32
577        let sql = "KILL QUERY 4294967296"; // u32::MAX + 1
578        let result =
579            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default());
580        assert!(result.is_err());
581    }
582
583    #[test]
584    pub fn test_parse_kill_statement_edge_cases() {
585        use crate::statements::kill::Kill;
586
587        // Test KILL QUERY with zero connection ID
588        let sql = "KILL QUERY 0";
589        let statements =
590            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
591                .unwrap();
592
593        assert_eq!(statements.len(), 1);
594        match &statements[0] {
595            Statement::Kill(Kill::ConnectionId(connection_id)) => {
596                assert_eq!(*connection_id, 0);
597            }
598            _ => panic!("Expected Kill::ConnectionId statement"),
599        }
600
601        // Test KILL QUERY with maximum u32 value
602        let sql = "KILL QUERY 4294967295"; // u32::MAX
603        let statements =
604            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
605                .unwrap();
606
607        assert_eq!(statements.len(), 1);
608        match &statements[0] {
609            Statement::Kill(Kill::ConnectionId(connection_id)) => {
610                assert_eq!(*connection_id, 4294967295);
611            }
612            _ => panic!("Expected Kill::ConnectionId statement"),
613        }
614
615        // Test KILL with empty string process ID
616        let sql = "KILL ''";
617        let statements =
618            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
619                .unwrap();
620
621        assert_eq!(statements.len(), 1);
622        match &statements[0] {
623            Statement::Kill(Kill::ProcessId(process_id)) => {
624                assert_eq!(process_id, "");
625            }
626            _ => panic!("Expected Kill::ProcessId statement"),
627        }
628    }
629
630    #[test]
631    pub fn test_parse_kill_statement_case_insensitive() {
632        use crate::statements::kill::Kill;
633
634        // Test lowercase
635        let sql = "kill query 123";
636        let statements =
637            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
638                .unwrap();
639
640        assert_eq!(statements.len(), 1);
641        match &statements[0] {
642            Statement::Kill(Kill::ConnectionId(connection_id)) => {
643                assert_eq!(*connection_id, 123);
644            }
645            _ => panic!("Expected Kill::ConnectionId statement"),
646        }
647
648        // Test mixed case
649        let sql = "Kill Query 456";
650        let statements =
651            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
652                .unwrap();
653
654        assert_eq!(statements.len(), 1);
655        match &statements[0] {
656            Statement::Kill(Kill::ConnectionId(connection_id)) => {
657                assert_eq!(*connection_id, 456);
658            }
659            _ => panic!("Expected Kill::ConnectionId statement"),
660        }
661    }
662}