1use std::str::FromStr;
16
17use snafu::ResultExt;
18use sqlparser::ast::{Ident, Query, Value};
19use sqlparser::dialect::Dialect;
20use sqlparser::keywords::Keyword;
21use sqlparser::parser::{Parser, ParserError, ParserOptions};
22use sqlparser::tokenizer::{Token, TokenWithSpan};
23
24use crate::ast::{Expr, ObjectName};
25use crate::error::{self, Result, SyntaxSnafu};
26use crate::parsers::tql_parser;
27use crate::statements::kill::Kill;
28use crate::statements::statement::Statement;
29use crate::statements::transform_statements;
30
/// Upper-case spelling of the `FLOW` keyword.
///
/// NOTE(review): not referenced in this file; presumably matched as a plain word by
/// flow-related parsers elsewhere in the crate — confirm before relying on this.
pub const FLOW: &str = "FLOW";
32
/// Options accepted by [`ParserContext::create_with_dialect`].
///
/// Currently carries no fields; `create_with_dialect` accepts it but does not read it.
#[derive(Clone, Debug, Default)]
pub struct ParseOptions {}
36
/// Wraps an `sqlparser` [`Parser`] together with the SQL text it was built from.
pub struct ParserContext<'a> {
    /// The underlying `sqlparser` parser, constructed over the tokens of `sql`.
    pub(crate) parser: Parser<'a>,
    /// The original SQL text handed to the parser.
    pub(crate) sql: &'a str,
}
42
43impl ParserContext<'_> {
44 pub fn new<'a>(dialect: &'a dyn Dialect, sql: &'a str) -> Result<ParserContext<'a>> {
46 let parser = Parser::new(dialect)
47 .with_options(ParserOptions::new().with_trailing_commas(true))
48 .try_with_sql(sql)
49 .context(SyntaxSnafu)?;
50
51 Ok(ParserContext { parser, sql })
52 }
53
54 pub fn parser_query(&mut self) -> Result<Box<Query>> {
56 self.parser.parse_query().context(SyntaxSnafu)
57 }
58
59 pub fn create_with_dialect(
61 sql: &str,
62 dialect: &dyn Dialect,
63 _opts: ParseOptions,
64 ) -> Result<Vec<Statement>> {
65 let mut stmts: Vec<Statement> = Vec::new();
66
67 let mut parser_ctx = ParserContext::new(dialect, sql)?;
68
69 let mut expecting_statement_delimiter = false;
70 loop {
71 while parser_ctx.parser.consume_token(&Token::SemiColon) {
73 expecting_statement_delimiter = false;
74 }
75
76 if parser_ctx.parser.peek_token() == Token::EOF {
77 break;
78 }
79 if expecting_statement_delimiter {
80 return parser_ctx.unsupported(parser_ctx.peek_token_as_string());
81 }
82
83 let statement = parser_ctx.parse_statement()?;
84 stmts.push(statement);
85 expecting_statement_delimiter = true;
86 }
87
88 transform_statements(&mut stmts)?;
89
90 Ok(stmts)
91 }
92
93 pub fn parse_table_name(sql: &str, dialect: &dyn Dialect) -> Result<ObjectName> {
94 let parser = Parser::new(dialect)
95 .with_options(ParserOptions::new().with_trailing_commas(true))
96 .try_with_sql(sql)
97 .context(SyntaxSnafu)?;
98 ParserContext { parser, sql }.intern_parse_table_name()
99 }
100
101 pub(crate) fn intern_parse_table_name(&mut self) -> Result<ObjectName> {
102 let raw_table_name =
103 self.parser
104 .parse_object_name(false)
105 .context(error::UnexpectedSnafu {
106 expected: "a table name",
107 actual: self.parser.peek_token().to_string(),
108 })?;
109 Ok(Self::canonicalize_object_name(raw_table_name))
110 }
111
112 pub fn parse_function(sql: &str, dialect: &dyn Dialect) -> Result<Expr> {
113 let mut parser = Parser::new(dialect)
114 .with_options(ParserOptions::new().with_trailing_commas(true))
115 .try_with_sql(sql)
116 .context(SyntaxSnafu)?;
117
118 let function_name = parser.parse_identifier().context(SyntaxSnafu)?;
119 parser
120 .parse_function(ObjectName(vec![function_name]))
121 .context(SyntaxSnafu)
122 }
123
124 pub fn parse_statement(&mut self) -> Result<Statement> {
126 match self.parser.peek_token().token {
127 Token::Word(w) => match w.keyword {
128 Keyword::CREATE => {
129 let _ = self.parser.next_token();
130 self.parse_create()
131 }
132
133 Keyword::EXPLAIN => {
134 let _ = self.parser.next_token();
135 self.parse_explain()
136 }
137
138 Keyword::SHOW => {
139 let _ = self.parser.next_token();
140 self.parse_show()
141 }
142
143 Keyword::DELETE => self.parse_delete(),
144
145 Keyword::DESCRIBE | Keyword::DESC => {
146 let _ = self.parser.next_token();
147 self.parse_describe()
148 }
149
150 Keyword::INSERT => self.parse_insert(),
151
152 Keyword::REPLACE => self.parse_replace(),
153
154 Keyword::SELECT | Keyword::WITH | Keyword::VALUES => self.parse_query(),
155
156 Keyword::ALTER => self.parse_alter(),
157
158 Keyword::DROP => self.parse_drop(),
159
160 Keyword::COPY => self.parse_copy(),
161
162 Keyword::TRUNCATE => self.parse_truncate(),
163
164 Keyword::SET => self.parse_set_variables(),
165
166 Keyword::ADMIN => self.parse_admin_command(),
167
168 Keyword::NoKeyword
169 if w.quote_style.is_none() && w.value.to_uppercase() == tql_parser::TQL =>
170 {
171 self.parse_tql()
172 }
173
174 Keyword::DECLARE => self.parse_declare_cursor(),
175
176 Keyword::FETCH => self.parse_fetch_cursor(),
177
178 Keyword::CLOSE => self.parse_close_cursor(),
179
180 Keyword::USE => {
181 let _ = self.parser.next_token();
182
183 let database_name = self.parser.parse_identifier().with_context(|_| {
184 error::UnexpectedSnafu {
185 expected: "a database name",
186 actual: self.peek_token_as_string(),
187 }
188 })?;
189 Ok(Statement::Use(
190 Self::canonicalize_identifier(database_name).value,
191 ))
192 }
193
194 Keyword::KILL => {
195 let _ = self.parser.next_token();
196 let kill = if self.parser.parse_keyword(Keyword::QUERY) {
197 let connection_id_exp =
199 self.parser.parse_number_value().with_context(|_| {
200 error::UnexpectedSnafu {
201 expected: "MySQL numeric connection id",
202 actual: self.peek_token_as_string(),
203 }
204 })?;
205 let Value::Number(s, _) = connection_id_exp else {
206 return error::UnexpectedTokenSnafu {
207 expected: "MySQL numeric connection id",
208 actual: connection_id_exp.to_string(),
209 }
210 .fail();
211 };
212
213 let connection_id = u32::from_str(&s).map_err(|_| {
214 error::UnexpectedTokenSnafu {
215 expected: "MySQL numeric connection id",
216 actual: s,
217 }
218 .build()
219 })?;
220 Kill::ConnectionId(connection_id)
221 } else {
222 let process_id_ident =
223 self.parser.parse_literal_string().with_context(|_| {
224 error::UnexpectedSnafu {
225 expected: "process id string literal",
226 actual: self.peek_token_as_string(),
227 }
228 })?;
229 Kill::ProcessId(process_id_ident)
230 };
231
232 Ok(Statement::Kill(kill))
233 }
234
235 _ => self.unsupported(self.peek_token_as_string()),
236 },
237 Token::LParen => self.parse_query(),
238 unexpected => self.unsupported(unexpected.to_string()),
239 }
240 }
241
242 pub fn parse_mysql_prepare_stmt(sql: &str, dialect: &dyn Dialect) -> Result<(String, String)> {
244 ParserContext::new(dialect, sql)?.parse_mysql_prepare()
245 }
246
247 pub fn parse_mysql_execute_stmt(
249 sql: &str,
250 dialect: &dyn Dialect,
251 ) -> Result<(String, Vec<Expr>)> {
252 ParserContext::new(dialect, sql)?.parse_mysql_execute()
253 }
254
255 pub fn parse_mysql_deallocate_stmt(sql: &str, dialect: &dyn Dialect) -> Result<String> {
257 ParserContext::new(dialect, sql)?.parse_deallocate()
258 }
259
260 pub fn unsupported<T>(&self, keyword: String) -> Result<T> {
262 error::UnsupportedSnafu { keyword }.fail()
263 }
264
265 pub(crate) fn expected<T>(&self, expected: &str, found: TokenWithSpan) -> Result<T> {
267 Err(ParserError::ParserError(format!(
268 "Expected {expected}, found: {found}",
269 )))
270 .context(SyntaxSnafu)
271 }
272
273 pub fn matches_keyword(&mut self, expected: Keyword) -> bool {
274 match self.parser.peek_token().token {
275 Token::Word(w) => w.keyword == expected,
276 _ => false,
277 }
278 }
279
280 pub fn consume_token(&mut self, expected: &str) -> bool {
281 if self.peek_token_as_string().to_uppercase() == *expected.to_uppercase() {
282 let _ = self.parser.next_token();
283 true
284 } else {
285 false
286 }
287 }
288
289 #[inline]
290 pub(crate) fn peek_token_as_string(&self) -> String {
291 self.parser.peek_token().to_string()
292 }
293
294 pub fn canonicalize_identifier(ident: Ident) -> Ident {
296 if ident.quote_style.is_some() {
297 ident
298 } else {
299 Ident::new(ident.value.to_lowercase())
300 }
301 }
302
303 pub fn canonicalize_object_name(object_name: ObjectName) -> ObjectName {
305 ObjectName(
306 object_name
307 .0
308 .into_iter()
309 .map(Self::canonicalize_identifier)
310 .collect(),
311 )
312 }
313
314 pub(crate) fn parse_object_name(&mut self) -> std::result::Result<ObjectName, ParserError> {
319 self.parser.parse_object_name(false)
320 }
321}
322
#[cfg(test)]
mod tests {

    use datatypes::prelude::ConcreteDataType;
    use sqlparser::dialect::MySqlDialect;

    use super::*;
    use crate::dialect::GreptimeDbDialect;
    use crate::statements::create::CreateTable;
    use crate::statements::sql_data_type_to_concrete_data_type;

    /// Parses `sql` with the GreptimeDB dialect and default options, panicking on error.
    fn parse_all(sql: &str) -> Vec<Statement> {
        ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
            .unwrap()
    }

    /// Parses `sql`, asserts it yields exactly one statement, and returns that statement.
    fn parse_single(sql: &str) -> Statement {
        let mut statements = parse_all(sql);
        assert_eq!(statements.len(), 1);
        statements.pop().unwrap()
    }

    /// Asserts that parsing `sql` with the GreptimeDB dialect fails.
    fn assert_parse_fails(sql: &str) {
        let outcome =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default());
        assert!(outcome.is_err());
    }

    /// Asserts that `sql` parses to `KILL QUERY <expected>`.
    fn assert_kill_connection(sql: &str, expected: u32) {
        match parse_single(sql) {
            Statement::Kill(Kill::ConnectionId(id)) => assert_eq!(id, expected),
            _ => panic!("Expected Kill::ConnectionId statement"),
        }
    }

    /// Asserts that `sql` parses to `KILL '<expected>'`.
    fn assert_kill_process(sql: &str, expected: &str) {
        match parse_single(sql) {
            Statement::Kill(Kill::ProcessId(id)) => assert_eq!(id, expected),
            _ => panic!("Expected Kill::ProcessId statement"),
        }
    }

    /// Parses a `CREATE TABLE` whose first column is the timestamp, and checks its type.
    fn test_timestamp_precision(sql: &str, expected_type: ConcreteDataType) {
        let last = parse_all(sql).pop().unwrap();
        let Statement::CreateTable(CreateTable { columns, .. }) = last else {
            unreachable!()
        };
        let first_column = columns.first().unwrap();
        assert_eq!(
            expected_type,
            sql_data_type_to_concrete_data_type(first_column.data_type()).unwrap()
        );
    }

    #[test]
    pub fn test_create_table_with_precision() {
        let cases = [
            (
                "create table demo (ts timestamp time index, cnt int);",
                ConcreteDataType::timestamp_millisecond_datatype(),
            ),
            (
                "create table demo (ts timestamp(0) time index, cnt int);",
                ConcreteDataType::timestamp_second_datatype(),
            ),
            (
                "create table demo (ts timestamp(3) time index, cnt int);",
                ConcreteDataType::timestamp_millisecond_datatype(),
            ),
            (
                "create table demo (ts timestamp(6) time index, cnt int);",
                ConcreteDataType::timestamp_microsecond_datatype(),
            ),
            (
                "create table demo (ts timestamp(9) time index, cnt int);",
                ConcreteDataType::timestamp_nanosecond_datatype(),
            ),
        ];
        for (sql, expected) in cases {
            test_timestamp_precision(sql, expected);
        }
    }

    #[test]
    #[should_panic]
    pub fn test_create_table_with_invalid_precision() {
        test_timestamp_precision(
            "create table demo (ts timestamp(1) time index, cnt int);",
            ConcreteDataType::timestamp_millisecond_datatype(),
        );
    }

    #[test]
    pub fn test_parse_table_name() {
        // (input, number of name parts); rendered output is the lowercased input
        // (quoted parts are already lowercase in these inputs).
        let cases = [
            ("a.b.c", 3),
            ("a.b", 2),
            ("Test.\"public-test\"", 2),
            ("HelloWorld", 1),
        ];
        for (table_name, parts) in cases {
            let object_name =
                ParserContext::parse_table_name(table_name, &GreptimeDbDialect {}).unwrap();
            assert_eq!(object_name.0.len(), parts);
            assert_eq!(object_name.to_string(), table_name.to_ascii_lowercase());
        }
    }

    #[test]
    pub fn test_parse_mysql_prepare_stmt() {
        let cases = [
            ("PREPARE stmt1 FROM 'SELECT * FROM t1 WHERE id = ?';", "stmt1"),
            ("PREPARE stmt2 FROM \"SELECT * FROM t1 WHERE id = ?\"", "stmt2"),
        ];
        for (sql, expected_name) in cases {
            let (stmt_name, stmt) =
                ParserContext::parse_mysql_prepare_stmt(sql, &MySqlDialect {}).unwrap();
            assert_eq!(stmt_name, expected_name);
            assert_eq!(stmt, "SELECT * FROM t1 WHERE id = ?");
        }
    }

    #[test]
    pub fn test_parse_mysql_execute_stmt() {
        let execute = |sql: &str| {
            ParserContext::parse_mysql_execute_stmt(sql, &GreptimeDbDialect {}).unwrap()
        };

        let (stmt_name, params) = execute("EXECUTE stmt1 USING 1, 'hello';");
        assert_eq!(stmt_name, "stmt1");
        let rendered: Vec<String> = params.iter().map(ToString::to_string).collect();
        assert_eq!(rendered, ["1", "'hello'"]);

        let (stmt_name, params) = execute("EXECUTE stmt2;");
        assert_eq!(stmt_name, "stmt2");
        assert!(params.is_empty());

        // A trailing comma before `;` is tolerated.
        let (stmt_name, params) =
            execute("EXECUTE stmt3 USING 231, 'hello', \"2003-03-1\", NULL, ;");
        assert_eq!(stmt_name, "stmt3");
        let rendered: Vec<String> = params.iter().map(ToString::to_string).collect();
        assert_eq!(rendered, ["231", "'hello'", "\"2003-03-1\"", "NULL"]);
    }

    #[test]
    pub fn test_parse_mysql_deallocate_stmt() {
        for (sql, expected_name) in [("DEALLOCATE stmt1;", "stmt1"), ("DEALLOCATE stmt2", "stmt2")] {
            let stmt_name =
                ParserContext::parse_mysql_deallocate_stmt(sql, &MySqlDialect {}).unwrap();
            assert_eq!(stmt_name, expected_name);
        }
    }

    #[test]
    pub fn test_parse_kill_query_statement() {
        assert_kill_connection("KILL QUERY 123", 123);
        assert_kill_connection("KILL QUERY 999999", 999999);
    }

    #[test]
    pub fn test_parse_kill_process_statement() {
        assert_kill_process("KILL 'process-123'", "process-123");
        assert_kill_process("KILL \"process-456\"", "process-456");
        assert_kill_process(
            "KILL 'f47ac10b-58cc-4372-a567-0e02b2c3d479'",
            "f47ac10b-58cc-4372-a567-0e02b2c3d479",
        );
    }

    #[test]
    pub fn test_parse_kill_statement_errors() {
        let rejected = [
            "KILL QUERY",
            "KILL QUERY 'not-a-number'",
            "KILL",
            // One past u32::MAX cannot be a connection id.
            "KILL QUERY 4294967296",
        ];
        for sql in rejected {
            assert_parse_fails(sql);
        }
    }

    #[test]
    pub fn test_parse_kill_statement_edge_cases() {
        assert_kill_connection("KILL QUERY 0", 0);
        // u32::MAX is still accepted.
        assert_kill_connection("KILL QUERY 4294967295", 4294967295);
        assert_kill_process("KILL ''", "");
    }

    #[test]
    pub fn test_parse_kill_statement_case_insensitive() {
        assert_kill_connection("kill query 123", 123);
        assert_kill_connection("Kill Query 456", 456);
    }
}