1use std::str::FromStr;
16
17use snafu::{OptionExt, ResultExt};
18use sqlparser::ast::{Ident, Query, Value};
19use sqlparser::dialect::Dialect;
20use sqlparser::keywords::Keyword;
21use sqlparser::parser::{Parser, ParserError, ParserOptions};
22use sqlparser::tokenizer::{Token, TokenWithSpan};
23
24use crate::ast::{Expr, ObjectName};
25use crate::error::{self, InvalidSqlSnafu, Result, SyntaxSnafu};
26use crate::parsers::tql_parser;
27use crate::statements::kill::Kill;
28use crate::statements::statement::Statement;
29use crate::statements::transform_statements;
30
/// The literal word `FLOW`.
///
/// NOTE(review): not referenced in this part of the file; presumably matched
/// against tokens by the flow-related statement parsers — confirm at the use sites.
pub const FLOW: &str = "FLOW";
32
/// Options accepted by `ParserContext::create_with_dialect`.
///
/// The struct has no fields at the moment, and `create_with_dialect` ignores
/// the value it receives (the parameter is named `_opts`).
#[derive(Clone, Debug, Default)]
pub struct ParseOptions {}
36
/// Parsing state used by the statement parsers: the underlying sqlparser
/// `Parser` together with the SQL text it was created from.
pub struct ParserContext<'a> {
    /// Token-level parser driven by the `parse_*` methods.
    pub(crate) parser: Parser<'a>,
    /// The SQL text that was handed to the constructor.
    pub(crate) sql: &'a str,
}
42
43impl ParserContext<'_> {
44 pub fn new<'a>(dialect: &'a dyn Dialect, sql: &'a str) -> Result<ParserContext<'a>> {
46 let parser = Parser::new(dialect)
47 .with_options(ParserOptions::new().with_trailing_commas(true))
48 .try_with_sql(sql)
49 .context(SyntaxSnafu)?;
50
51 Ok(ParserContext { parser, sql })
52 }
53
54 pub fn parser_query(&mut self) -> Result<Box<Query>> {
56 self.parser.parse_query().context(SyntaxSnafu)
57 }
58
59 pub fn create_with_dialect(
61 sql: &str,
62 dialect: &dyn Dialect,
63 _opts: ParseOptions,
64 ) -> Result<Vec<Statement>> {
65 let mut stmts: Vec<Statement> = Vec::new();
66
67 let mut parser_ctx = ParserContext::new(dialect, sql)?;
68
69 let mut expecting_statement_delimiter = false;
70 loop {
71 while parser_ctx.parser.consume_token(&Token::SemiColon) {
73 expecting_statement_delimiter = false;
74 }
75
76 if parser_ctx.parser.peek_token() == Token::EOF {
77 break;
78 }
79 if expecting_statement_delimiter {
80 return parser_ctx.unsupported(parser_ctx.peek_token_as_string());
81 }
82
83 let statement = parser_ctx.parse_statement()?;
84 stmts.push(statement);
85 expecting_statement_delimiter = true;
86 }
87
88 transform_statements(&mut stmts)?;
89
90 Ok(stmts)
91 }
92
93 pub fn parse_table_name(sql: &str, dialect: &dyn Dialect) -> Result<ObjectName> {
94 let parser = Parser::new(dialect)
95 .with_options(ParserOptions::new().with_trailing_commas(true))
96 .try_with_sql(sql)
97 .context(SyntaxSnafu)?;
98 ParserContext { parser, sql }.intern_parse_table_name()
99 }
100
101 pub(crate) fn intern_parse_table_name(&mut self) -> Result<ObjectName> {
102 let raw_table_name =
103 self.parser
104 .parse_object_name(false)
105 .context(error::UnexpectedSnafu {
106 expected: "a table name",
107 actual: self.parser.peek_token().to_string(),
108 })?;
109 Self::canonicalize_object_name(raw_table_name)
110 }
111
112 pub fn parse_function(sql: &str, dialect: &dyn Dialect) -> Result<Expr> {
113 let mut parser = Parser::new(dialect)
114 .with_options(ParserOptions::new().with_trailing_commas(true))
115 .try_with_sql(sql)
116 .context(SyntaxSnafu)?;
117
118 let function_name = parser.parse_identifier().context(SyntaxSnafu)?;
119 parser
120 .parse_function(vec![function_name].into())
121 .context(SyntaxSnafu)
122 }
123
/// Parses the statement that starts at the current token.
///
/// The next token is peeked and used for dispatch:
/// - a keyword selects the matching `parse_*` routine (`USE` and `KILL` are
///   handled inline);
/// - an opening parenthesis is parsed as a query;
/// - anything else is reported through `unsupported`.
///
/// Note that only some arms consume the leading keyword before delegating
/// (`CREATE`, `EXPLAIN`, `SHOW`, `DESCRIBE`/`DESC`, `USE`, `KILL`); the other
/// arms leave it in place for the callee.
pub fn parse_statement(&mut self) -> Result<Statement> {
    match self.parser.peek_token().token {
        Token::Word(w) => match w.keyword {
            Keyword::CREATE => {
                // Consume `CREATE` before delegating.
                let _ = self.parser.next_token();
                self.parse_create()
            }

            Keyword::EXPLAIN => {
                // Consume `EXPLAIN` before delegating.
                let _ = self.parser.next_token();
                self.parse_explain()
            }

            Keyword::SHOW => {
                // Consume `SHOW` before delegating.
                let _ = self.parser.next_token();
                self.parse_show()
            }

            Keyword::DELETE => self.parse_delete(),

            Keyword::DESCRIBE | Keyword::DESC => {
                // Consume `DESCRIBE`/`DESC` before delegating.
                let _ = self.parser.next_token();
                self.parse_describe()
            }

            Keyword::INSERT => self.parse_insert(),

            Keyword::REPLACE => self.parse_replace(),

            Keyword::SELECT | Keyword::VALUES => self.parse_query(),

            Keyword::WITH => self.parse_with_tql(),

            Keyword::ALTER => self.parse_alter(),

            Keyword::DROP => self.parse_drop(),

            Keyword::COPY => self.parse_copy(),

            Keyword::TRUNCATE => self.parse_truncate(),

            Keyword::COMMENT => self.parse_comment(),

            Keyword::SET => self.parse_set_variables(),

            Keyword::ADMIN => self.parse_admin_command(),

            // An unquoted, non-keyword word whose upper-cased text equals
            // `tql_parser::TQL` starts a TQL statement.
            Keyword::NoKeyword
                if w.quote_style.is_none() && w.value.to_uppercase() == tql_parser::TQL =>
            {
                self.parse_tql(false)
            }

            Keyword::DECLARE => self.parse_declare_cursor(),

            Keyword::FETCH => self.parse_fetch_cursor(),

            Keyword::CLOSE => self.parse_close_cursor(),

            // `USE <identifier>`: handled inline.
            Keyword::USE => {
                let _ = self.parser.next_token();

                let database_name = self.parser.parse_identifier().with_context(|_| {
                    error::UnexpectedSnafu {
                        expected: "a database name",
                        actual: self.peek_token_as_string(),
                    }
                })?;
                // Unquoted names are lower-cased, quoted names are kept as written.
                Ok(Statement::Use(
                    Self::canonicalize_identifier(database_name).value,
                ))
            }

            // `KILL QUERY <number>` or `KILL <string literal>`: handled inline.
            Keyword::KILL => {
                let _ = self.parser.next_token();
                let kill = if self.parser.parse_keyword(Keyword::QUERY) {
                    // `KILL QUERY <number>`: the number must fit into a `u32`.
                    let connection_id_exp =
                        self.parser.parse_number_value().with_context(|_| {
                            error::UnexpectedSnafu {
                                expected: "MySQL numeric connection id",
                                actual: self.peek_token_as_string(),
                            }
                        })?;
                    // Reject any parsed value that is not a numeric literal.
                    let Value::Number(s, _) = connection_id_exp.value else {
                        return error::UnexpectedTokenSnafu {
                            expected: "MySQL numeric connection id",
                            actual: connection_id_exp.to_string(),
                        }
                        .fail();
                    };

                    // Text that does not parse as `u32` (e.g. out of range) is
                    // reported with the original digits.
                    let connection_id = u32::from_str(&s).map_err(|_| {
                        error::UnexpectedTokenSnafu {
                            expected: "MySQL numeric connection id",
                            actual: s,
                        }
                        .build()
                    })?;
                    Kill::ConnectionId(connection_id)
                } else {
                    // `KILL <string literal>`: the literal is the process id.
                    let process_id_ident =
                        self.parser.parse_literal_string().with_context(|_| {
                            error::UnexpectedSnafu {
                                expected: "process id string literal",
                                actual: self.peek_token_as_string(),
                            }
                        })?;
                    Kill::ProcessId(process_id_ident)
                };

                Ok(Statement::Kill(kill))
            }

            // Any other word is reported with the rendered current token.
            _ => self.unsupported(self.peek_token_as_string()),
        },
        // A statement starting with `(` is parsed as a query.
        Token::LParen => self.parse_query(),
        unexpected => self.unsupported(unexpected.to_string()),
    }
}
245
246 pub fn parse_mysql_prepare_stmt(sql: &str, dialect: &dyn Dialect) -> Result<(String, String)> {
248 ParserContext::new(dialect, sql)?.parse_mysql_prepare()
249 }
250
251 pub fn parse_mysql_execute_stmt(
253 sql: &str,
254 dialect: &dyn Dialect,
255 ) -> Result<(String, Vec<Expr>)> {
256 ParserContext::new(dialect, sql)?.parse_mysql_execute()
257 }
258
259 pub fn parse_mysql_deallocate_stmt(sql: &str, dialect: &dyn Dialect) -> Result<String> {
261 ParserContext::new(dialect, sql)?.parse_deallocate()
262 }
263
264 pub fn unsupported<T>(&self, keyword: String) -> Result<T> {
266 error::UnsupportedSnafu { keyword }.fail()
267 }
268
269 pub(crate) fn expected<T>(&self, expected: &str, found: TokenWithSpan) -> Result<T> {
271 Err(ParserError::ParserError(format!(
272 "Expected {expected}, found: {found}",
273 )))
274 .context(SyntaxSnafu)
275 }
276
277 pub fn matches_keyword(&mut self, expected: Keyword) -> bool {
278 match self.parser.peek_token().token {
279 Token::Word(w) => w.keyword == expected,
280 _ => false,
281 }
282 }
283
284 pub fn consume_token(&mut self, expected: &str) -> bool {
285 if self.peek_token_as_string().to_uppercase() == *expected.to_uppercase() {
286 let _ = self.parser.next_token();
287 true
288 } else {
289 false
290 }
291 }
292
293 #[inline]
294 pub(crate) fn peek_token_as_string(&self) -> String {
295 self.parser.peek_token().to_string()
296 }
297
298 pub fn canonicalize_identifier(ident: Ident) -> Ident {
300 if ident.quote_style.is_some() {
301 ident
302 } else {
303 Ident::new(ident.value.to_lowercase())
304 }
305 }
306
307 pub(crate) fn canonicalize_object_name(object_name: ObjectName) -> Result<ObjectName> {
309 object_name
310 .0
311 .into_iter()
312 .map(|x| {
313 x.as_ident()
314 .cloned()
315 .map(Self::canonicalize_identifier)
316 .with_context(|| InvalidSqlSnafu {
317 msg: format!("not an ident: '{x}'"),
318 })
319 })
320 .collect::<Result<Vec<_>>>()
321 .map(Into::into)
322 }
323
324 pub(crate) fn parse_object_name(&mut self) -> std::result::Result<ObjectName, ParserError> {
329 self.parser.parse_object_name(false)
330 }
331}
332
333#[cfg(test)]
334mod tests {
335
336 use datatypes::prelude::ConcreteDataType;
337 use sqlparser::dialect::MySqlDialect;
338
339 use super::*;
340 use crate::dialect::GreptimeDbDialect;
341 use crate::statements::create::CreateTable;
342 use crate::statements::sql_data_type_to_concrete_data_type;
343
344 fn test_timestamp_precision(sql: &str, expected_type: ConcreteDataType) {
345 match ParserContext::create_with_dialect(
346 sql,
347 &GreptimeDbDialect {},
348 ParseOptions::default(),
349 )
350 .unwrap()
351 .pop()
352 .unwrap()
353 {
354 Statement::CreateTable(CreateTable { columns, .. }) => {
355 let ts_col = columns.first().unwrap();
356 assert_eq!(
357 expected_type,
358 sql_data_type_to_concrete_data_type(ts_col.data_type(), &Default::default())
359 .unwrap()
360 );
361 }
362 _ => unreachable!(),
363 }
364 }
365
/// `timestamp` without a precision and `timestamp(3)` map to milliseconds;
/// precisions 0, 6 and 9 map to seconds, microseconds and nanoseconds.
#[test]
pub fn test_create_table_with_precision() {
    let cases = [
        (
            "create table demo (ts timestamp time index, cnt int);",
            ConcreteDataType::timestamp_millisecond_datatype(),
        ),
        (
            "create table demo (ts timestamp(0) time index, cnt int);",
            ConcreteDataType::timestamp_second_datatype(),
        ),
        (
            "create table demo (ts timestamp(3) time index, cnt int);",
            ConcreteDataType::timestamp_millisecond_datatype(),
        ),
        (
            "create table demo (ts timestamp(6) time index, cnt int);",
            ConcreteDataType::timestamp_microsecond_datatype(),
        ),
        (
            "create table demo (ts timestamp(9) time index, cnt int);",
            ConcreteDataType::timestamp_nanosecond_datatype(),
        ),
    ];

    for (sql, expected_type) in cases {
        test_timestamp_precision(sql, expected_type);
    }
}
389
/// A `timestamp(1)` column is expected to make the helper panic.
#[test]
#[should_panic]
pub fn test_create_table_with_invalid_precision() {
    let sql = "create table demo (ts timestamp(1) time index, cnt int);";
    let expected_type = ConcreteDataType::timestamp_millisecond_datatype();
    test_timestamp_precision(sql, expected_type);
}
398
/// Table names keep their number of parts; in the two mixed-case inputs the
/// rendered name equals the ASCII-lower-cased input.
#[test]
pub fn test_parse_table_name() {
    // (input, expected number of parts, whether the output is the lower-cased input)
    let cases = [
        ("a.b.c", 3, false),
        ("a.b", 2, false),
        ("Test.\"public-test\"", 2, true),
        ("HelloWorld", 1, true),
    ];

    for (table_name, expected_parts, lowercased) in cases {
        let object_name =
            ParserContext::parse_table_name(table_name, &GreptimeDbDialect {}).unwrap();

        let expected_display = if lowercased {
            table_name.to_ascii_lowercase()
        } else {
            table_name.to_string()
        };

        assert_eq!(object_name.0.len(), expected_parts);
        assert_eq!(object_name.to_string(), expected_display);
    }
}
433
/// `PREPARE <name> FROM <string>` yields the statement name and the unquoted
/// statement text, for single- and double-quoted strings alike.
#[test]
pub fn test_parse_mysql_prepare_stmt() {
    let cases = [
        (
            "PREPARE stmt1 FROM 'SELECT * FROM t1 WHERE id = ?';",
            "stmt1",
        ),
        (
            "PREPARE stmt2 FROM \"SELECT * FROM t1 WHERE id = ?\"",
            "stmt2",
        ),
    ];

    for (sql, expected_name) in cases {
        let (stmt_name, stmt) =
            ParserContext::parse_mysql_prepare_stmt(sql, &MySqlDialect {}).unwrap();
        assert_eq!(stmt_name, expected_name);
        assert_eq!(stmt, "SELECT * FROM t1 WHERE id = ?");
    }
}
448
/// `EXECUTE <name> [USING ...]` yields the statement name and its parameters;
/// a trailing comma after the last parameter is accepted.
#[test]
pub fn test_parse_mysql_execute_stmt() {
    // (sql, expected statement name, expected rendered parameters)
    let cases: [(&str, &str, &[&str]); 3] = [
        ("EXECUTE stmt1 USING 1, 'hello';", "stmt1", &["1", "'hello'"]),
        ("EXECUTE stmt2;", "stmt2", &[]),
        (
            "EXECUTE stmt3 USING 231, 'hello', \"2003-03-1\", NULL, ;",
            "stmt3",
            &["231", "'hello'", "\"2003-03-1\"", "NULL"],
        ),
    ];

    for (sql, expected_name, expected_params) in cases {
        let (stmt_name, params) =
            ParserContext::parse_mysql_execute_stmt(sql, &GreptimeDbDialect {}).unwrap();
        assert_eq!(stmt_name, expected_name);
        assert_eq!(params.len(), expected_params.len());
        for (param, expected) in params.iter().zip(expected_params) {
            assert_eq!(param.to_string(), *expected);
        }
    }
}
475
/// `DEALLOCATE <name>` yields the statement name, with or without a trailing
/// semicolon.
#[test]
pub fn test_parse_mysql_deallocate_stmt() {
    let cases = [("DEALLOCATE stmt1;", "stmt1"), ("DEALLOCATE stmt2", "stmt2")];

    for (sql, expected_name) in cases {
        let stmt_name =
            ParserContext::parse_mysql_deallocate_stmt(sql, &MySqlDialect {}).unwrap();
        assert_eq!(stmt_name, expected_name);
    }
}
486
/// `KILL QUERY <n>` parses into a single `Kill::ConnectionId(n)` statement.
#[test]
pub fn test_parse_kill_query_statement() {
    use crate::statements::kill::Kill;

    let cases = [("KILL QUERY 123", 123u32), ("KILL QUERY 999999", 999999)];

    for (sql, expected_id) in cases {
        let statements =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .unwrap();

        assert_eq!(statements.len(), 1);
        let Statement::Kill(Kill::ConnectionId(connection_id)) = &statements[0] else {
            panic!("Expected Kill::ConnectionId statement")
        };
        assert_eq!(*connection_id, expected_id);
    }
}
519
/// `KILL <string literal>` parses into a single `Kill::ProcessId` statement
/// holding the unquoted literal, for single- and double-quoted strings.
#[test]
pub fn test_parse_kill_process_statement() {
    use crate::statements::kill::Kill;

    let cases = [
        ("KILL 'process-123'", "process-123"),
        ("KILL \"process-456\"", "process-456"),
        (
            "KILL 'f47ac10b-58cc-4372-a567-0e02b2c3d479'",
            "f47ac10b-58cc-4372-a567-0e02b2c3d479",
        ),
    ];

    for (sql, expected_process_id) in cases {
        let statements =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .unwrap();

        assert_eq!(statements.len(), 1);
        let Statement::Kill(Kill::ProcessId(process_id)) = &statements[0] else {
            panic!("Expected Kill::ProcessId statement")
        };
        assert_eq!(process_id, expected_process_id);
    }
}
566
/// Malformed `KILL` statements are rejected: a missing connection id, a
/// string where a number is required, a bare `KILL`, and a connection id one
/// above `u32::MAX`.
#[test]
pub fn test_parse_kill_statement_errors() {
    let invalid_statements = [
        "KILL QUERY",
        "KILL QUERY 'not-a-number'",
        "KILL",
        "KILL QUERY 4294967296",
    ];

    for sql in invalid_statements {
        let result =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default());
        assert!(result.is_err());
    }
}
593
/// Boundary inputs: connection ids `0` and `u32::MAX` are accepted, and an
/// empty string literal is accepted as a process id.
#[test]
pub fn test_parse_kill_statement_edge_cases() {
    use crate::statements::kill::Kill;

    let connection_cases = [
        ("KILL QUERY 0", 0u32),
        ("KILL QUERY 4294967295", 4294967295),
    ];

    for (sql, expected_id) in connection_cases {
        let statements =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .unwrap();

        assert_eq!(statements.len(), 1);
        let Statement::Kill(Kill::ConnectionId(connection_id)) = &statements[0] else {
            panic!("Expected Kill::ConnectionId statement")
        };
        assert_eq!(*connection_id, expected_id);
    }

    let statements = ParserContext::create_with_dialect(
        "KILL ''",
        &GreptimeDbDialect {},
        ParseOptions::default(),
    )
    .unwrap();

    assert_eq!(statements.len(), 1);
    let Statement::Kill(Kill::ProcessId(process_id)) = &statements[0] else {
        panic!("Expected Kill::ProcessId statement")
    };
    assert_eq!(process_id, "");
}
640
/// The `KILL QUERY` keywords are recognized regardless of letter case.
#[test]
pub fn test_parse_kill_statement_case_insensitive() {
    use crate::statements::kill::Kill;

    let cases = [("kill query 123", 123u32), ("Kill Query 456", 456)];

    for (sql, expected_id) in cases {
        let statements =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .unwrap();

        assert_eq!(statements.len(), 1);
        let Statement::Kill(Kill::ConnectionId(connection_id)) = &statements[0] else {
            panic!("Expected Kill::ConnectionId statement")
        };
        assert_eq!(*connection_id, expected_id);
    }
}
673}