sql/
parser.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::str::FromStr;
16
17use snafu::{OptionExt, ResultExt};
18use sqlparser::ast::{Ident, Query, Value};
19use sqlparser::dialect::Dialect;
20use sqlparser::keywords::Keyword;
21use sqlparser::parser::{Parser, ParserError, ParserOptions};
22use sqlparser::tokenizer::{Token, TokenWithSpan};
23
24use crate::ast::{Expr, ObjectName};
25use crate::error::{self, InvalidSqlSnafu, Result, SyntaxSnafu};
26use crate::parsers::tql_parser;
27use crate::statements::kill::Kill;
28use crate::statements::statement::Statement;
29use crate::statements::transform_statements;
30
/// The `FLOW` keyword spelled out as plain text.
///
/// NOTE(review): presumably compared against unquoted words by the flow-related
/// statement parsers in sibling modules — confirm at the call sites.
pub const FLOW: &str = "FLOW";
32
/// Options for parsing SQL with [`ParserContext::create_with_dialect`].
///
/// The struct currently has no fields, and `create_with_dialect` does not read
/// it (the parameter is `_opts`); presumably a placeholder for future options.
#[derive(Debug, Default, Clone)]
pub struct ParseOptions {}
36
37/// GrepTime SQL parser context, a simple wrapper for Datafusion SQL parser.
38pub struct ParserContext<'a> {
39    pub(crate) parser: Parser<'a>,
40    pub(crate) sql: &'a str,
41}
42
43impl ParserContext<'_> {
44    /// Construct a new ParserContext.
45    pub fn new<'a>(dialect: &'a dyn Dialect, sql: &'a str) -> Result<ParserContext<'a>> {
46        let parser = Parser::new(dialect)
47            .with_options(ParserOptions::new().with_trailing_commas(true))
48            .try_with_sql(sql)
49            .context(SyntaxSnafu)?;
50
51        Ok(ParserContext { parser, sql })
52    }
53
54    /// Parses parser context to Query.
55    pub fn parser_query(&mut self) -> Result<Box<Query>> {
56        self.parser.parse_query().context(SyntaxSnafu)
57    }
58
59    /// Parses SQL with given dialect
60    pub fn create_with_dialect(
61        sql: &str,
62        dialect: &dyn Dialect,
63        _opts: ParseOptions,
64    ) -> Result<Vec<Statement>> {
65        let mut stmts: Vec<Statement> = Vec::new();
66
67        let mut parser_ctx = ParserContext::new(dialect, sql)?;
68
69        let mut expecting_statement_delimiter = false;
70        loop {
71            // ignore empty statements (between successive statement delimiters)
72            while parser_ctx.parser.consume_token(&Token::SemiColon) {
73                expecting_statement_delimiter = false;
74            }
75
76            if parser_ctx.parser.peek_token() == Token::EOF {
77                break;
78            }
79            if expecting_statement_delimiter {
80                return parser_ctx.unsupported(parser_ctx.peek_token_as_string());
81            }
82
83            let statement = parser_ctx.parse_statement()?;
84            stmts.push(statement);
85            expecting_statement_delimiter = true;
86        }
87
88        transform_statements(&mut stmts)?;
89
90        Ok(stmts)
91    }
92
93    pub fn parse_table_name(sql: &str, dialect: &dyn Dialect) -> Result<ObjectName> {
94        let parser = Parser::new(dialect)
95            .with_options(ParserOptions::new().with_trailing_commas(true))
96            .try_with_sql(sql)
97            .context(SyntaxSnafu)?;
98        ParserContext { parser, sql }.intern_parse_table_name()
99    }
100
101    pub(crate) fn intern_parse_table_name(&mut self) -> Result<ObjectName> {
102        let raw_table_name =
103            self.parser
104                .parse_object_name(false)
105                .context(error::UnexpectedSnafu {
106                    expected: "a table name",
107                    actual: self.parser.peek_token().to_string(),
108                })?;
109        Self::canonicalize_object_name(raw_table_name)
110    }
111
112    pub fn parse_function(sql: &str, dialect: &dyn Dialect) -> Result<Expr> {
113        let mut parser = Parser::new(dialect)
114            .with_options(ParserOptions::new().with_trailing_commas(true))
115            .try_with_sql(sql)
116            .context(SyntaxSnafu)?;
117
118        let function_name = parser.parse_identifier().context(SyntaxSnafu)?;
119        parser
120            .parse_function(vec![function_name].into())
121            .context(SyntaxSnafu)
122    }
123
124    /// Parses parser context to a set of statements.
125    pub fn parse_statement(&mut self) -> Result<Statement> {
126        match self.parser.peek_token().token {
127            Token::Word(w) => match w.keyword {
128                Keyword::CREATE => {
129                    let _ = self.parser.next_token();
130                    self.parse_create()
131                }
132
133                Keyword::EXPLAIN => {
134                    let _ = self.parser.next_token();
135                    self.parse_explain()
136                }
137
138                Keyword::SHOW => {
139                    let _ = self.parser.next_token();
140                    self.parse_show()
141                }
142
143                Keyword::DELETE => self.parse_delete(),
144
145                Keyword::DESCRIBE | Keyword::DESC => {
146                    let _ = self.parser.next_token();
147                    self.parse_describe()
148                }
149
150                Keyword::INSERT => self.parse_insert(),
151
152                Keyword::REPLACE => self.parse_replace(),
153
154                Keyword::SELECT | Keyword::VALUES => self.parse_query(),
155
156                Keyword::WITH => self.parse_with_tql(),
157
158                Keyword::ALTER => self.parse_alter(),
159
160                Keyword::DROP => self.parse_drop(),
161
162                Keyword::COPY => self.parse_copy(),
163
164                Keyword::TRUNCATE => self.parse_truncate(),
165
166                Keyword::COMMENT => self.parse_comment(),
167
168                Keyword::SET => self.parse_set_variables(),
169
170                Keyword::ADMIN => self.parse_admin_command(),
171
172                Keyword::NoKeyword
173                    if w.quote_style.is_none() && w.value.to_uppercase() == tql_parser::TQL =>
174                {
175                    self.parse_tql(false)
176                }
177
178                Keyword::DECLARE => self.parse_declare_cursor(),
179
180                Keyword::FETCH => self.parse_fetch_cursor(),
181
182                Keyword::CLOSE => self.parse_close_cursor(),
183
184                Keyword::USE => {
185                    let _ = self.parser.next_token();
186
187                    let database_name = self.parser.parse_identifier().with_context(|_| {
188                        error::UnexpectedSnafu {
189                            expected: "a database name",
190                            actual: self.peek_token_as_string(),
191                        }
192                    })?;
193                    Ok(Statement::Use(
194                        Self::canonicalize_identifier(database_name).value,
195                    ))
196                }
197
198                Keyword::KILL => {
199                    let _ = self.parser.next_token();
200                    let kill = if self.parser.parse_keyword(Keyword::QUERY) {
201                        // MySQL KILL QUERY <connection id> statements
202                        let connection_id_exp =
203                            self.parser.parse_number_value().with_context(|_| {
204                                error::UnexpectedSnafu {
205                                    expected: "MySQL numeric connection id",
206                                    actual: self.peek_token_as_string(),
207                                }
208                            })?;
209                        let Value::Number(s, _) = connection_id_exp.value else {
210                            return error::UnexpectedTokenSnafu {
211                                expected: "MySQL numeric connection id",
212                                actual: connection_id_exp.to_string(),
213                            }
214                            .fail();
215                        };
216
217                        let connection_id = u32::from_str(&s).map_err(|_| {
218                            error::UnexpectedTokenSnafu {
219                                expected: "MySQL numeric connection id",
220                                actual: s,
221                            }
222                            .build()
223                        })?;
224                        Kill::ConnectionId(connection_id)
225                    } else {
226                        let process_id_ident =
227                            self.parser.parse_literal_string().with_context(|_| {
228                                error::UnexpectedSnafu {
229                                    expected: "process id string literal",
230                                    actual: self.peek_token_as_string(),
231                                }
232                            })?;
233                        Kill::ProcessId(process_id_ident)
234                    };
235
236                    Ok(Statement::Kill(kill))
237                }
238
239                _ => self.unsupported(self.peek_token_as_string()),
240            },
241            Token::LParen => self.parse_query(),
242            unexpected => self.unsupported(unexpected.to_string()),
243        }
244    }
245
246    /// Parses MySQL style 'PREPARE stmt_name FROM stmt' into a (stmt_name, stmt) tuple.
247    pub fn parse_mysql_prepare_stmt(sql: &str, dialect: &dyn Dialect) -> Result<(String, String)> {
248        ParserContext::new(dialect, sql)?.parse_mysql_prepare()
249    }
250
251    /// Parses MySQL style 'EXECUTE stmt_name USING param_list' into a stmt_name string and a list of parameters.
252    pub fn parse_mysql_execute_stmt(
253        sql: &str,
254        dialect: &dyn Dialect,
255    ) -> Result<(String, Vec<Expr>)> {
256        ParserContext::new(dialect, sql)?.parse_mysql_execute()
257    }
258
259    /// Parses MySQL style 'DEALLOCATE stmt_name' into a stmt_name string.
260    pub fn parse_mysql_deallocate_stmt(sql: &str, dialect: &dyn Dialect) -> Result<String> {
261        ParserContext::new(dialect, sql)?.parse_deallocate()
262    }
263
264    /// Raises an "unsupported statement" error.
265    pub fn unsupported<T>(&self, keyword: String) -> Result<T> {
266        error::UnsupportedSnafu { keyword }.fail()
267    }
268
269    // Report unexpected token
270    pub(crate) fn expected<T>(&self, expected: &str, found: TokenWithSpan) -> Result<T> {
271        Err(ParserError::ParserError(format!(
272            "Expected {expected}, found: {found}",
273        )))
274        .context(SyntaxSnafu)
275    }
276
277    pub fn matches_keyword(&mut self, expected: Keyword) -> bool {
278        match self.parser.peek_token().token {
279            Token::Word(w) => w.keyword == expected,
280            _ => false,
281        }
282    }
283
284    pub fn consume_token(&mut self, expected: &str) -> bool {
285        if self.peek_token_as_string().to_uppercase() == *expected.to_uppercase() {
286            let _ = self.parser.next_token();
287            true
288        } else {
289            false
290        }
291    }
292
293    #[inline]
294    pub(crate) fn peek_token_as_string(&self) -> String {
295        self.parser.peek_token().to_string()
296    }
297
298    /// Canonicalize the identifier to lowercase if it's not quoted.
299    pub fn canonicalize_identifier(ident: Ident) -> Ident {
300        if ident.quote_style.is_some() {
301            ident
302        } else {
303            Ident::new(ident.value.to_lowercase())
304        }
305    }
306
307    /// Like [canonicalize_identifier] but for [ObjectName].
308    pub(crate) fn canonicalize_object_name(object_name: ObjectName) -> Result<ObjectName> {
309        object_name
310            .0
311            .into_iter()
312            .map(|x| {
313                x.as_ident()
314                    .cloned()
315                    .map(Self::canonicalize_identifier)
316                    .with_context(|| InvalidSqlSnafu {
317                        msg: format!("not an ident: '{x}'"),
318                    })
319            })
320            .collect::<Result<Vec<_>>>()
321            .map(Into::into)
322    }
323
324    /// Simply a shortcut for sqlparser's same name method `parse_object_name`,
325    /// but with constant argument "false".
326    /// Because the argument is always "false" for us (it's introduced by BigQuery),
327    /// we don't want to write it again and again.
328    pub(crate) fn parse_object_name(&mut self) -> std::result::Result<ObjectName, ParserError> {
329        self.parser.parse_object_name(false)
330    }
331}
332
#[cfg(test)]
mod tests {

    use datatypes::prelude::ConcreteDataType;
    use sqlparser::dialect::MySqlDialect;

    use super::*;
    use crate::dialect::GreptimeDbDialect;
    use crate::statements::create::CreateTable;
    use crate::statements::kill::Kill;
    use crate::statements::sql_data_type_to_concrete_data_type;

    /// Parses `sql` with the GreptimeDB dialect and default parse options.
    fn parse_sql(sql: &str) -> crate::error::Result<Vec<Statement>> {
        ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
    }

    /// Parses `sql`, asserts that it yields exactly one statement and returns it.
    fn parse_single(sql: &str) -> Statement {
        let mut statements = parse_sql(sql).unwrap();
        assert_eq!(statements.len(), 1, "sql: {sql}");
        statements.pop().unwrap()
    }

    /// Asserts that `sql` parses to `KILL QUERY <expected>`.
    fn assert_kill_connection_id(sql: &str, expected: u32) {
        match parse_single(sql) {
            Statement::Kill(Kill::ConnectionId(connection_id)) => {
                assert_eq!(connection_id, expected, "sql: {sql}");
            }
            _ => panic!("Expected Kill::ConnectionId statement"),
        }
    }

    /// Asserts that `sql` parses to `KILL '<expected>'`.
    fn assert_kill_process_id(sql: &str, expected: &str) {
        match parse_single(sql) {
            Statement::Kill(Kill::ProcessId(process_id)) => {
                assert_eq!(process_id, expected, "sql: {sql}");
            }
            _ => panic!("Expected Kill::ProcessId statement"),
        }
    }

    fn test_timestamp_precision(sql: &str, expected_type: ConcreteDataType) {
        match parse_sql(sql).unwrap().pop().unwrap() {
            Statement::CreateTable(CreateTable { columns, .. }) => {
                let ts_col = columns.first().unwrap();
                assert_eq!(
                    expected_type,
                    sql_data_type_to_concrete_data_type(ts_col.data_type(), &Default::default())
                        .unwrap()
                );
            }
            _ => unreachable!(),
        }
    }

    #[test]
    pub fn test_create_table_with_precision() {
        test_timestamp_precision(
            "create table demo (ts timestamp time index, cnt int);",
            ConcreteDataType::timestamp_millisecond_datatype(),
        );
        test_timestamp_precision(
            "create table demo (ts timestamp(0) time index, cnt int);",
            ConcreteDataType::timestamp_second_datatype(),
        );
        test_timestamp_precision(
            "create table demo (ts timestamp(3) time index, cnt int);",
            ConcreteDataType::timestamp_millisecond_datatype(),
        );
        test_timestamp_precision(
            "create table demo (ts timestamp(6) time index, cnt int);",
            ConcreteDataType::timestamp_microsecond_datatype(),
        );
        test_timestamp_precision(
            "create table demo (ts timestamp(9) time index, cnt int);",
            ConcreteDataType::timestamp_nanosecond_datatype(),
        );
    }

    #[test]
    #[should_panic]
    pub fn test_create_table_with_invalid_precision() {
        test_timestamp_precision(
            "create table demo (ts timestamp(1) time index, cnt int);",
            ConcreteDataType::timestamp_millisecond_datatype(),
        );
    }

    #[test]
    pub fn test_parse_statement_delimiters() {
        // Empty input and bare delimiters yield no statements.
        assert!(parse_sql("").unwrap().is_empty());
        assert!(parse_sql(";;").unwrap().is_empty());

        // Leading, repeated and trailing delimiters are skipped.
        assert_eq!(parse_sql(";SELECT 1;; SELECT 2;").unwrap().len(), 2);
    }

    #[test]
    pub fn test_parse_table_name() {
        // (input, expected number of name parts, expected canonical form)
        let cases = [
            ("a.b.c", 3, "a.b.c"),
            ("a.b", 2, "a.b"),
            // Unquoted parts are lowercased, quoted parts are kept verbatim.
            ("Test.\"public-test\"", 2, "test.\"public-test\""),
            ("HelloWorld", 1, "helloworld"),
        ];

        for (table_name, parts, canonical) in cases {
            let object_name =
                ParserContext::parse_table_name(table_name, &GreptimeDbDialect {}).unwrap();

            assert_eq!(object_name.0.len(), parts, "table name: {table_name}");
            assert_eq!(object_name.to_string(), canonical, "table name: {table_name}");
        }
    }

    #[test]
    pub fn test_parse_table_name_error() {
        // There is no object name to parse.
        assert!(ParserContext::parse_table_name("", &GreptimeDbDialect {}).is_err());
    }

    #[test]
    pub fn test_parse_function() {
        match ParserContext::parse_function("now()", &GreptimeDbDialect {}).unwrap() {
            Expr::Function(function) => assert_eq!(function.name.to_string(), "now"),
            _ => panic!("Expected a function expression"),
        }

        // A bare identifier is not a function call.
        assert!(ParserContext::parse_function("now", &GreptimeDbDialect {}).is_err());
    }

    #[test]
    pub fn test_parse_mysql_prepare_stmt() {
        let sql = "PREPARE stmt1 FROM 'SELECT * FROM t1 WHERE id = ?';";
        let (stmt_name, stmt) =
            ParserContext::parse_mysql_prepare_stmt(sql, &MySqlDialect {}).unwrap();
        assert_eq!(stmt_name, "stmt1");
        assert_eq!(stmt, "SELECT * FROM t1 WHERE id = ?");

        let sql = "PREPARE stmt2 FROM \"SELECT * FROM t1 WHERE id = ?\"";
        let (stmt_name, stmt) =
            ParserContext::parse_mysql_prepare_stmt(sql, &MySqlDialect {}).unwrap();
        assert_eq!(stmt_name, "stmt2");
        assert_eq!(stmt, "SELECT * FROM t1 WHERE id = ?");
    }

    #[test]
    pub fn test_parse_mysql_execute_stmt() {
        let sql = "EXECUTE stmt1 USING 1, 'hello';";
        let (stmt_name, params) =
            ParserContext::parse_mysql_execute_stmt(sql, &GreptimeDbDialect {}).unwrap();
        assert_eq!(stmt_name, "stmt1");
        assert_eq!(params.len(), 2);
        assert_eq!(params[0].to_string(), "1");
        assert_eq!(params[1].to_string(), "'hello'");

        let sql = "EXECUTE stmt2;";
        let (stmt_name, params) =
            ParserContext::parse_mysql_execute_stmt(sql, &GreptimeDbDialect {}).unwrap();
        assert_eq!(stmt_name, "stmt2");
        assert_eq!(params.len(), 0);

        let sql = "EXECUTE stmt3 USING 231, 'hello', \"2003-03-1\", NULL, ;";
        let (stmt_name, params) =
            ParserContext::parse_mysql_execute_stmt(sql, &GreptimeDbDialect {}).unwrap();
        assert_eq!(stmt_name, "stmt3");
        assert_eq!(params.len(), 4);
        assert_eq!(params[0].to_string(), "231");
        assert_eq!(params[1].to_string(), "'hello'");
        assert_eq!(params[2].to_string(), "\"2003-03-1\"");
        assert_eq!(params[3].to_string(), "NULL");
    }

    #[test]
    pub fn test_parse_mysql_deallocate_stmt() {
        let sql = "DEALLOCATE stmt1;";
        let stmt_name = ParserContext::parse_mysql_deallocate_stmt(sql, &MySqlDialect {}).unwrap();
        assert_eq!(stmt_name, "stmt1");

        let sql = "DEALLOCATE stmt2";
        let stmt_name = ParserContext::parse_mysql_deallocate_stmt(sql, &MySqlDialect {}).unwrap();
        assert_eq!(stmt_name, "stmt2");
    }

    #[test]
    pub fn test_parse_kill_query_statement() {
        // MySQL-style KILL QUERY with connection ID
        assert_kill_connection_id("KILL QUERY 123", 123);
        // Larger connection ID
        assert_kill_connection_id("KILL QUERY 999999", 999999);
    }

    #[test]
    pub fn test_parse_kill_process_statement() {
        // Process ID as a single-quoted string
        assert_kill_process_id("KILL 'process-123'", "process-123");
        // Process ID in double quotes
        assert_kill_process_id("KILL \"process-456\"", "process-456");
        // UUID-like process ID
        assert_kill_process_id(
            "KILL 'f47ac10b-58cc-4372-a567-0e02b2c3d479'",
            "f47ac10b-58cc-4372-a567-0e02b2c3d479",
        );
    }

    #[test]
    pub fn test_parse_kill_statement_errors() {
        let invalid = [
            // KILL QUERY without connection ID
            "KILL QUERY",
            // KILL QUERY with non-numeric connection ID
            "KILL QUERY 'not-a-number'",
            // KILL without any argument
            "KILL",
            // Connection ID that's too large for u32 (u32::MAX + 1)
            "KILL QUERY 4294967296",
        ];

        for sql in invalid {
            assert!(parse_sql(sql).is_err(), "expected an error for: {sql}");
        }
    }

    #[test]
    pub fn test_parse_kill_statement_edge_cases() {
        // Zero connection ID
        assert_kill_connection_id("KILL QUERY 0", 0);
        // Maximum u32 value
        assert_kill_connection_id("KILL QUERY 4294967295", u32::MAX);
        // Empty string process ID
        assert_kill_process_id("KILL ''", "");
    }

    #[test]
    pub fn test_parse_kill_statement_case_insensitive() {
        // Lowercase
        assert_kill_connection_id("kill query 123", 123);
        // Mixed case
        assert_kill_connection_id("Kill Query 456", 456);
    }
}