1use std::str::FromStr;
16
17use snafu::{OptionExt, ResultExt};
18use sqlparser::ast::{Ident, Query, Value};
19use sqlparser::dialect::Dialect;
20use sqlparser::keywords::Keyword;
21use sqlparser::parser::{Parser, ParserError, ParserOptions};
22use sqlparser::tokenizer::{Token, TokenWithSpan};
23
24use crate::ast::{Expr, ObjectName};
25use crate::error::{self, InvalidSqlSnafu, Result, SyntaxSnafu};
26use crate::parsers::tql_parser;
27use crate::statements::kill::Kill;
28use crate::statements::statement::Statement;
29use crate::statements::transform_statements;
30
/// Keyword text `FLOW`.
///
/// NOTE(review): not referenced in this part of the file; presumably matched by
/// flow-related statement parsers elsewhere in the crate — confirm.
pub const FLOW: &str = "FLOW";
32
/// Options accepted by [`ParserContext::create_with_dialect`].
///
/// Currently carries no fields; `create_with_dialect` takes it as `_opts` and does
/// not read it.
#[derive(Clone, Debug, Default)]
pub struct ParseOptions {}
36
/// Parsing state: a `sqlparser` [`Parser`] together with the SQL text it was
/// created from.
pub struct ParserContext<'a> {
    /// The underlying `sqlparser` parser, built over `sql` (see `ParserContext::new`).
    pub(crate) parser: Parser<'a>,
    /// The original SQL text this context was created from.
    pub(crate) sql: &'a str,
}
42
43impl ParserContext<'_> {
44 pub fn new<'a>(dialect: &'a dyn Dialect, sql: &'a str) -> Result<ParserContext<'a>> {
46 let parser = Parser::new(dialect)
47 .with_options(ParserOptions::new().with_trailing_commas(true))
48 .try_with_sql(sql)
49 .context(SyntaxSnafu)?;
50
51 Ok(ParserContext { parser, sql })
52 }
53
54 pub fn parser_query(&mut self) -> Result<Box<Query>> {
56 self.parser.parse_query().context(SyntaxSnafu)
57 }
58
59 pub fn create_with_dialect(
61 sql: &str,
62 dialect: &dyn Dialect,
63 _opts: ParseOptions,
64 ) -> Result<Vec<Statement>> {
65 let mut stmts: Vec<Statement> = Vec::new();
66
67 let mut parser_ctx = ParserContext::new(dialect, sql)?;
68
69 let mut expecting_statement_delimiter = false;
70 loop {
71 while parser_ctx.parser.consume_token(&Token::SemiColon) {
73 expecting_statement_delimiter = false;
74 }
75
76 if parser_ctx.parser.peek_token() == Token::EOF {
77 break;
78 }
79 if expecting_statement_delimiter {
80 return parser_ctx.unsupported(parser_ctx.peek_token_as_string());
81 }
82
83 let statement = parser_ctx.parse_statement()?;
84 stmts.push(statement);
85 expecting_statement_delimiter = true;
86 }
87
88 transform_statements(&mut stmts)?;
89
90 Ok(stmts)
91 }
92
93 pub fn parse_table_name(sql: &str, dialect: &dyn Dialect) -> Result<ObjectName> {
94 let parser = Parser::new(dialect)
95 .with_options(ParserOptions::new().with_trailing_commas(true))
96 .try_with_sql(sql)
97 .context(SyntaxSnafu)?;
98 ParserContext { parser, sql }.intern_parse_table_name()
99 }
100
101 pub(crate) fn intern_parse_table_name(&mut self) -> Result<ObjectName> {
102 let raw_table_name =
103 self.parser
104 .parse_object_name(false)
105 .context(error::UnexpectedSnafu {
106 expected: "a table name",
107 actual: self.parser.peek_token().to_string(),
108 })?;
109 Self::canonicalize_object_name(raw_table_name)
110 }
111
112 pub fn parse_function(sql: &str, dialect: &dyn Dialect) -> Result<Expr> {
113 let mut parser = Parser::new(dialect)
114 .with_options(ParserOptions::new().with_trailing_commas(true))
115 .try_with_sql(sql)
116 .context(SyntaxSnafu)?;
117
118 let function_name = parser.parse_identifier().context(SyntaxSnafu)?;
119 parser
120 .parse_function(vec![function_name].into())
121 .context(SyntaxSnafu)
122 }
123
124 pub fn parse_statement(&mut self) -> Result<Statement> {
126 match self.parser.peek_token().token {
127 Token::Word(w) => match w.keyword {
128 Keyword::CREATE => {
129 let _ = self.parser.next_token();
130 self.parse_create()
131 }
132
133 Keyword::EXPLAIN => {
134 let _ = self.parser.next_token();
135 self.parse_explain()
136 }
137
138 Keyword::SHOW => {
139 let _ = self.parser.next_token();
140 self.parse_show()
141 }
142
143 Keyword::DELETE => self.parse_delete(),
144
145 Keyword::DESCRIBE | Keyword::DESC => {
146 let _ = self.parser.next_token();
147 self.parse_describe()
148 }
149
150 Keyword::INSERT => self.parse_insert(),
151
152 Keyword::REPLACE => self.parse_replace(),
153
154 Keyword::SELECT | Keyword::VALUES => self.parse_query(),
155
156 Keyword::WITH => self.parse_with_tql(),
157
158 Keyword::ALTER => self.parse_alter(),
159
160 Keyword::DROP => self.parse_drop(),
161
162 Keyword::COPY => self.parse_copy(),
163
164 Keyword::TRUNCATE => self.parse_truncate(),
165
166 Keyword::SET => self.parse_set_variables(),
167
168 Keyword::ADMIN => self.parse_admin_command(),
169
170 Keyword::NoKeyword
171 if w.quote_style.is_none() && w.value.to_uppercase() == tql_parser::TQL =>
172 {
173 self.parse_tql(false)
174 }
175
176 Keyword::DECLARE => self.parse_declare_cursor(),
177
178 Keyword::FETCH => self.parse_fetch_cursor(),
179
180 Keyword::CLOSE => self.parse_close_cursor(),
181
182 Keyword::USE => {
183 let _ = self.parser.next_token();
184
185 let database_name = self.parser.parse_identifier().with_context(|_| {
186 error::UnexpectedSnafu {
187 expected: "a database name",
188 actual: self.peek_token_as_string(),
189 }
190 })?;
191 Ok(Statement::Use(
192 Self::canonicalize_identifier(database_name).value,
193 ))
194 }
195
196 Keyword::KILL => {
197 let _ = self.parser.next_token();
198 let kill = if self.parser.parse_keyword(Keyword::QUERY) {
199 let connection_id_exp =
201 self.parser.parse_number_value().with_context(|_| {
202 error::UnexpectedSnafu {
203 expected: "MySQL numeric connection id",
204 actual: self.peek_token_as_string(),
205 }
206 })?;
207 let Value::Number(s, _) = connection_id_exp.value else {
208 return error::UnexpectedTokenSnafu {
209 expected: "MySQL numeric connection id",
210 actual: connection_id_exp.to_string(),
211 }
212 .fail();
213 };
214
215 let connection_id = u32::from_str(&s).map_err(|_| {
216 error::UnexpectedTokenSnafu {
217 expected: "MySQL numeric connection id",
218 actual: s,
219 }
220 .build()
221 })?;
222 Kill::ConnectionId(connection_id)
223 } else {
224 let process_id_ident =
225 self.parser.parse_literal_string().with_context(|_| {
226 error::UnexpectedSnafu {
227 expected: "process id string literal",
228 actual: self.peek_token_as_string(),
229 }
230 })?;
231 Kill::ProcessId(process_id_ident)
232 };
233
234 Ok(Statement::Kill(kill))
235 }
236
237 _ => self.unsupported(self.peek_token_as_string()),
238 },
239 Token::LParen => self.parse_query(),
240 unexpected => self.unsupported(unexpected.to_string()),
241 }
242 }
243
244 pub fn parse_mysql_prepare_stmt(sql: &str, dialect: &dyn Dialect) -> Result<(String, String)> {
246 ParserContext::new(dialect, sql)?.parse_mysql_prepare()
247 }
248
249 pub fn parse_mysql_execute_stmt(
251 sql: &str,
252 dialect: &dyn Dialect,
253 ) -> Result<(String, Vec<Expr>)> {
254 ParserContext::new(dialect, sql)?.parse_mysql_execute()
255 }
256
257 pub fn parse_mysql_deallocate_stmt(sql: &str, dialect: &dyn Dialect) -> Result<String> {
259 ParserContext::new(dialect, sql)?.parse_deallocate()
260 }
261
262 pub fn unsupported<T>(&self, keyword: String) -> Result<T> {
264 error::UnsupportedSnafu { keyword }.fail()
265 }
266
267 pub(crate) fn expected<T>(&self, expected: &str, found: TokenWithSpan) -> Result<T> {
269 Err(ParserError::ParserError(format!(
270 "Expected {expected}, found: {found}",
271 )))
272 .context(SyntaxSnafu)
273 }
274
275 pub fn matches_keyword(&mut self, expected: Keyword) -> bool {
276 match self.parser.peek_token().token {
277 Token::Word(w) => w.keyword == expected,
278 _ => false,
279 }
280 }
281
282 pub fn consume_token(&mut self, expected: &str) -> bool {
283 if self.peek_token_as_string().to_uppercase() == *expected.to_uppercase() {
284 let _ = self.parser.next_token();
285 true
286 } else {
287 false
288 }
289 }
290
291 #[inline]
292 pub(crate) fn peek_token_as_string(&self) -> String {
293 self.parser.peek_token().to_string()
294 }
295
296 pub fn canonicalize_identifier(ident: Ident) -> Ident {
298 if ident.quote_style.is_some() {
299 ident
300 } else {
301 Ident::new(ident.value.to_lowercase())
302 }
303 }
304
305 pub(crate) fn canonicalize_object_name(object_name: ObjectName) -> Result<ObjectName> {
307 object_name
308 .0
309 .into_iter()
310 .map(|x| {
311 x.as_ident()
312 .cloned()
313 .map(Self::canonicalize_identifier)
314 .with_context(|| InvalidSqlSnafu {
315 msg: format!("not an ident: '{x}'"),
316 })
317 })
318 .collect::<Result<Vec<_>>>()
319 .map(Into::into)
320 }
321
322 pub(crate) fn parse_object_name(&mut self) -> std::result::Result<ObjectName, ParserError> {
327 self.parser.parse_object_name(false)
328 }
329}
330
#[cfg(test)]
mod tests {
    use datatypes::prelude::ConcreteDataType;
    use sqlparser::dialect::MySqlDialect;

    use super::*;
    use crate::dialect::GreptimeDbDialect;
    use crate::statements::create::CreateTable;
    use crate::statements::sql_data_type_to_concrete_data_type;

    /// Parses `sql` with [`GreptimeDbDialect`] and default [`ParseOptions`].
    fn parse_greptime(sql: &str) -> Result<Vec<Statement>> {
        ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
    }

    /// Parses `sql`, asserts it produced exactly one statement, and returns it.
    fn parse_single(sql: &str) -> Statement {
        let mut statements = parse_greptime(sql).unwrap();
        assert_eq!(statements.len(), 1);
        statements.pop().unwrap()
    }

    /// Asserts that `sql` parses to `KILL QUERY` targeting `expected`.
    fn assert_kill_connection_id(sql: &str, expected: u32) {
        match parse_single(sql) {
            Statement::Kill(Kill::ConnectionId(connection_id)) => {
                assert_eq!(connection_id, expected);
            }
            _ => panic!("Expected Kill::ConnectionId statement"),
        }
    }

    /// Asserts that `sql` parses to `KILL '<process id>'` targeting `expected`.
    fn assert_kill_process_id(sql: &str, expected: &str) {
        match parse_single(sql) {
            Statement::Kill(Kill::ProcessId(process_id)) => {
                assert_eq!(process_id, expected);
            }
            _ => panic!("Expected Kill::ProcessId statement"),
        }
    }

    /// Parses a `CREATE TABLE` and checks the concrete type of its first column.
    fn test_timestamp_precision(sql: &str, expected_type: ConcreteDataType) {
        let statement = parse_greptime(sql).unwrap().pop().unwrap();
        let Statement::CreateTable(CreateTable { columns, .. }) = statement else {
            unreachable!()
        };
        let ts_col = columns.first().unwrap();
        assert_eq!(
            expected_type,
            sql_data_type_to_concrete_data_type(ts_col.data_type()).unwrap()
        );
    }

    #[test]
    pub fn test_create_table_with_precision() {
        let cases = [
            (
                "create table demo (ts timestamp time index, cnt int);",
                ConcreteDataType::timestamp_millisecond_datatype(),
            ),
            (
                "create table demo (ts timestamp(0) time index, cnt int);",
                ConcreteDataType::timestamp_second_datatype(),
            ),
            (
                "create table demo (ts timestamp(3) time index, cnt int);",
                ConcreteDataType::timestamp_millisecond_datatype(),
            ),
            (
                "create table demo (ts timestamp(6) time index, cnt int);",
                ConcreteDataType::timestamp_microsecond_datatype(),
            ),
            (
                "create table demo (ts timestamp(9) time index, cnt int);",
                ConcreteDataType::timestamp_nanosecond_datatype(),
            ),
        ];
        for (sql, expected_type) in cases {
            test_timestamp_precision(sql, expected_type);
        }
    }

    #[test]
    #[should_panic]
    pub fn test_create_table_with_invalid_precision() {
        test_timestamp_precision(
            "create table demo (ts timestamp(1) time index, cnt int);",
            ConcreteDataType::timestamp_millisecond_datatype(),
        );
    }

    #[test]
    pub fn test_parse_table_name() {
        // (input, expected part count, expected canonical rendering)
        let cases = [
            ("a.b.c", 3, "a.b.c"),
            ("a.b", 2, "a.b"),
            ("Test.\"public-test\"", 2, "test.\"public-test\""),
            ("HelloWorld", 1, "helloworld"),
        ];
        for (table_name, part_count, rendered) in cases {
            let object_name =
                ParserContext::parse_table_name(table_name, &GreptimeDbDialect {}).unwrap();
            assert_eq!(object_name.0.len(), part_count);
            assert_eq!(object_name.to_string(), rendered);
        }
    }

    #[test]
    pub fn test_parse_mysql_prepare_stmt() {
        let cases = [
            ("PREPARE stmt1 FROM 'SELECT * FROM t1 WHERE id = ?';", "stmt1"),
            ("PREPARE stmt2 FROM \"SELECT * FROM t1 WHERE id = ?\"", "stmt2"),
        ];
        for (sql, expected_name) in cases {
            let (stmt_name, stmt) =
                ParserContext::parse_mysql_prepare_stmt(sql, &MySqlDialect {}).unwrap();
            assert_eq!(stmt_name, expected_name);
            assert_eq!(stmt, "SELECT * FROM t1 WHERE id = ?");
        }
    }

    #[test]
    pub fn test_parse_mysql_execute_stmt() {
        let execute = |sql: &str| {
            let (name, params) =
                ParserContext::parse_mysql_execute_stmt(sql, &GreptimeDbDialect {}).unwrap();
            let rendered: Vec<String> = params.iter().map(ToString::to_string).collect();
            (name, rendered)
        };

        let (stmt_name, params) = execute("EXECUTE stmt1 USING 1, 'hello';");
        assert_eq!(stmt_name, "stmt1");
        assert_eq!(params, ["1", "'hello'"]);

        let (stmt_name, params) = execute("EXECUTE stmt2;");
        assert_eq!(stmt_name, "stmt2");
        assert!(params.is_empty());

        let (stmt_name, params) =
            execute("EXECUTE stmt3 USING 231, 'hello', \"2003-03-1\", NULL, ;");
        assert_eq!(stmt_name, "stmt3");
        assert_eq!(params, ["231", "'hello'", "\"2003-03-1\"", "NULL"]);
    }

    #[test]
    pub fn test_parse_mysql_deallocate_stmt() {
        for (sql, expected_name) in [("DEALLOCATE stmt1;", "stmt1"), ("DEALLOCATE stmt2", "stmt2")] {
            let stmt_name =
                ParserContext::parse_mysql_deallocate_stmt(sql, &MySqlDialect {}).unwrap();
            assert_eq!(stmt_name, expected_name);
        }
    }

    #[test]
    pub fn test_parse_kill_query_statement() {
        assert_kill_connection_id("KILL QUERY 123", 123);
        assert_kill_connection_id("KILL QUERY 999999", 999999);
    }

    #[test]
    pub fn test_parse_kill_process_statement() {
        assert_kill_process_id("KILL 'process-123'", "process-123");
        assert_kill_process_id("KILL \"process-456\"", "process-456");
        assert_kill_process_id(
            "KILL 'f47ac10b-58cc-4372-a567-0e02b2c3d479'",
            "f47ac10b-58cc-4372-a567-0e02b2c3d479",
        );
    }

    #[test]
    pub fn test_parse_kill_statement_errors() {
        let invalid = [
            // Missing connection id.
            "KILL QUERY",
            // Connection id must be numeric.
            "KILL QUERY 'not-a-number'",
            // Missing target entirely.
            "KILL",
            // Exceeds u32::MAX.
            "KILL QUERY 4294967296",
        ];
        for sql in invalid {
            assert!(parse_greptime(sql).is_err(), "expected `{sql}` to fail");
        }
    }

    #[test]
    pub fn test_parse_kill_statement_edge_cases() {
        // Boundaries of the u32 connection-id range.
        assert_kill_connection_id("KILL QUERY 0", 0);
        assert_kill_connection_id("KILL QUERY 4294967295", 4294967295);
        // An empty process id is accepted.
        assert_kill_process_id("KILL ''", "");
    }

    #[test]
    pub fn test_parse_kill_statement_case_insensitive() {
        assert_kill_connection_id("kill query 123", 123);
        assert_kill_connection_id("Kill Query 456", 456);
    }
}
670}