1use snafu::ResultExt;
16use sqlparser::ast::{Ident, Query};
17use sqlparser::dialect::Dialect;
18use sqlparser::keywords::Keyword;
19use sqlparser::parser::{Parser, ParserError, ParserOptions};
20use sqlparser::tokenizer::{Token, TokenWithSpan};
21
22use crate::ast::{Expr, ObjectName};
23use crate::error::{self, Result, SyntaxSnafu};
24use crate::parsers::tql_parser;
25use crate::statements::statement::Statement;
26use crate::statements::transform_statements;
27
/// Upper-case text of the `FLOW` keyword.
///
/// NOTE(review): not referenced in the visible part of this file; presumably matched by
/// statement parsers defined elsewhere in the crate — confirm against callers.
pub const FLOW: &str = "FLOW";
29
/// Options accepted by [`ParserContext::create_with_dialect`].
///
/// The struct currently has no fields; it exists so that options can be added later
/// without changing the signature of the parsing entry point.
#[derive(Debug, Clone, Default)]
pub struct ParseOptions {}
33
34pub struct ParserContext<'a> {
36 pub(crate) parser: Parser<'a>,
37 pub(crate) sql: &'a str,
38}
39
40impl ParserContext<'_> {
41 pub fn new<'a>(dialect: &'a dyn Dialect, sql: &'a str) -> Result<ParserContext<'a>> {
43 let parser = Parser::new(dialect)
44 .with_options(ParserOptions::new().with_trailing_commas(true))
45 .try_with_sql(sql)
46 .context(SyntaxSnafu)?;
47
48 Ok(ParserContext { parser, sql })
49 }
50
51 pub fn parser_query(&mut self) -> Result<Box<Query>> {
53 self.parser.parse_query().context(SyntaxSnafu)
54 }
55
56 pub fn create_with_dialect(
58 sql: &str,
59 dialect: &dyn Dialect,
60 _opts: ParseOptions,
61 ) -> Result<Vec<Statement>> {
62 let mut stmts: Vec<Statement> = Vec::new();
63
64 let mut parser_ctx = ParserContext::new(dialect, sql)?;
65
66 let mut expecting_statement_delimiter = false;
67 loop {
68 while parser_ctx.parser.consume_token(&Token::SemiColon) {
70 expecting_statement_delimiter = false;
71 }
72
73 if parser_ctx.parser.peek_token() == Token::EOF {
74 break;
75 }
76 if expecting_statement_delimiter {
77 return parser_ctx.unsupported(parser_ctx.peek_token_as_string());
78 }
79
80 let statement = parser_ctx.parse_statement()?;
81 stmts.push(statement);
82 expecting_statement_delimiter = true;
83 }
84
85 transform_statements(&mut stmts)?;
86
87 Ok(stmts)
88 }
89
90 pub fn parse_table_name(sql: &str, dialect: &dyn Dialect) -> Result<ObjectName> {
91 let parser = Parser::new(dialect)
92 .with_options(ParserOptions::new().with_trailing_commas(true))
93 .try_with_sql(sql)
94 .context(SyntaxSnafu)?;
95 ParserContext { parser, sql }.intern_parse_table_name()
96 }
97
98 pub(crate) fn intern_parse_table_name(&mut self) -> Result<ObjectName> {
99 let raw_table_name =
100 self.parser
101 .parse_object_name(false)
102 .context(error::UnexpectedSnafu {
103 expected: "a table name",
104 actual: self.parser.peek_token().to_string(),
105 })?;
106 Ok(Self::canonicalize_object_name(raw_table_name))
107 }
108
109 pub fn parse_function(sql: &str, dialect: &dyn Dialect) -> Result<Expr> {
110 let mut parser = Parser::new(dialect)
111 .with_options(ParserOptions::new().with_trailing_commas(true))
112 .try_with_sql(sql)
113 .context(SyntaxSnafu)?;
114
115 let function_name = parser.parse_identifier().context(SyntaxSnafu)?;
116 parser
117 .parse_function(ObjectName(vec![function_name]))
118 .context(SyntaxSnafu)
119 }
120
121 pub fn parse_statement(&mut self) -> Result<Statement> {
123 match self.parser.peek_token().token {
124 Token::Word(w) => {
125 match w.keyword {
126 Keyword::CREATE => {
127 let _ = self.parser.next_token();
128 self.parse_create()
129 }
130
131 Keyword::EXPLAIN => {
132 let _ = self.parser.next_token();
133 self.parse_explain()
134 }
135
136 Keyword::SHOW => {
137 let _ = self.parser.next_token();
138 self.parse_show()
139 }
140
141 Keyword::DELETE => self.parse_delete(),
142
143 Keyword::DESCRIBE | Keyword::DESC => {
144 let _ = self.parser.next_token();
145 self.parse_describe()
146 }
147
148 Keyword::INSERT => self.parse_insert(),
149
150 Keyword::REPLACE => self.parse_replace(),
151
152 Keyword::SELECT | Keyword::WITH | Keyword::VALUES => self.parse_query(),
153
154 Keyword::ALTER => self.parse_alter(),
155
156 Keyword::DROP => self.parse_drop(),
157
158 Keyword::COPY => self.parse_copy(),
159
160 Keyword::TRUNCATE => self.parse_truncate(),
161
162 Keyword::SET => self.parse_set_variables(),
163
164 Keyword::ADMIN => self.parse_admin_command(),
165
166 Keyword::NoKeyword
167 if w.quote_style.is_none() && w.value.to_uppercase() == tql_parser::TQL =>
168 {
169 self.parse_tql()
170 }
171
172 Keyword::DECLARE => self.parse_declare_cursor(),
173
174 Keyword::FETCH => self.parse_fetch_cursor(),
175
176 Keyword::CLOSE => self.parse_close_cursor(),
177
178 Keyword::USE => {
179 let _ = self.parser.next_token();
180
181 let database_name = self.parser.parse_identifier().with_context(|_| {
182 error::UnexpectedSnafu {
183 expected: "a database name",
184 actual: self.peek_token_as_string(),
185 }
186 })?;
187 Ok(Statement::Use(
188 Self::canonicalize_identifier(database_name).value,
189 ))
190 }
191
192 _ => self.unsupported(self.peek_token_as_string()),
194 }
195 }
196 Token::LParen => self.parse_query(),
197 unexpected => self.unsupported(unexpected.to_string()),
198 }
199 }
200
201 pub fn parse_mysql_prepare_stmt(sql: &str, dialect: &dyn Dialect) -> Result<(String, String)> {
203 ParserContext::new(dialect, sql)?.parse_mysql_prepare()
204 }
205
206 pub fn parse_mysql_execute_stmt(
208 sql: &str,
209 dialect: &dyn Dialect,
210 ) -> Result<(String, Vec<Expr>)> {
211 ParserContext::new(dialect, sql)?.parse_mysql_execute()
212 }
213
214 pub fn parse_mysql_deallocate_stmt(sql: &str, dialect: &dyn Dialect) -> Result<String> {
216 ParserContext::new(dialect, sql)?.parse_deallocate()
217 }
218
219 pub fn unsupported<T>(&self, keyword: String) -> Result<T> {
221 error::UnsupportedSnafu { keyword }.fail()
222 }
223
224 pub(crate) fn expected<T>(&self, expected: &str, found: TokenWithSpan) -> Result<T> {
226 Err(ParserError::ParserError(format!(
227 "Expected {expected}, found: {found}",
228 )))
229 .context(SyntaxSnafu)
230 }
231
232 pub fn matches_keyword(&mut self, expected: Keyword) -> bool {
233 match self.parser.peek_token().token {
234 Token::Word(w) => w.keyword == expected,
235 _ => false,
236 }
237 }
238
239 pub fn consume_token(&mut self, expected: &str) -> bool {
240 if self.peek_token_as_string().to_uppercase() == *expected.to_uppercase() {
241 let _ = self.parser.next_token();
242 true
243 } else {
244 false
245 }
246 }
247
248 #[inline]
249 pub(crate) fn peek_token_as_string(&self) -> String {
250 self.parser.peek_token().to_string()
251 }
252
253 pub fn canonicalize_identifier(ident: Ident) -> Ident {
255 if ident.quote_style.is_some() {
256 ident
257 } else {
258 Ident::new(ident.value.to_lowercase())
259 }
260 }
261
262 pub fn canonicalize_object_name(object_name: ObjectName) -> ObjectName {
264 ObjectName(
265 object_name
266 .0
267 .into_iter()
268 .map(Self::canonicalize_identifier)
269 .collect(),
270 )
271 }
272
273 pub(crate) fn parse_object_name(&mut self) -> std::result::Result<ObjectName, ParserError> {
278 self.parser.parse_object_name(false)
279 }
280}
281
282#[cfg(test)]
283mod tests {
284
285 use datatypes::prelude::ConcreteDataType;
286 use sqlparser::dialect::MySqlDialect;
287
288 use super::*;
289 use crate::dialect::GreptimeDbDialect;
290 use crate::statements::create::CreateTable;
291 use crate::statements::sql_data_type_to_concrete_data_type;
292
293 fn test_timestamp_precision(sql: &str, expected_type: ConcreteDataType) {
294 match ParserContext::create_with_dialect(
295 sql,
296 &GreptimeDbDialect {},
297 ParseOptions::default(),
298 )
299 .unwrap()
300 .pop()
301 .unwrap()
302 {
303 Statement::CreateTable(CreateTable { columns, .. }) => {
304 let ts_col = columns.first().unwrap();
305 assert_eq!(
306 expected_type,
307 sql_data_type_to_concrete_data_type(ts_col.data_type()).unwrap()
308 );
309 }
310 _ => unreachable!(),
311 }
312 }
313
/// Each supported timestamp precision (none, 0, 3, 6, 9) maps to the expected concrete
/// timestamp type.
#[test]
pub fn test_create_table_with_precision() {
    let cases = [
        (
            "create table demo (ts timestamp time index, cnt int);",
            ConcreteDataType::timestamp_millisecond_datatype(),
        ),
        (
            "create table demo (ts timestamp(0) time index, cnt int);",
            ConcreteDataType::timestamp_second_datatype(),
        ),
        (
            "create table demo (ts timestamp(3) time index, cnt int);",
            ConcreteDataType::timestamp_millisecond_datatype(),
        ),
        (
            "create table demo (ts timestamp(6) time index, cnt int);",
            ConcreteDataType::timestamp_microsecond_datatype(),
        ),
        (
            "create table demo (ts timestamp(9) time index, cnt int);",
            ConcreteDataType::timestamp_nanosecond_datatype(),
        ),
    ];

    for (sql, expected_type) in cases {
        test_timestamp_precision(sql, expected_type);
    }
}
337
/// A timestamp precision of 1 is expected to make the helper panic.
#[test]
#[should_panic]
pub fn test_create_table_with_invalid_precision() {
    let sql = "create table demo (ts timestamp(1) time index, cnt int);";
    let expected_type = ConcreteDataType::timestamp_millisecond_datatype();
    test_timestamp_precision(sql, expected_type);
}
346
/// Table names keep their number of parts; unquoted parts come back lower-cased while
/// quoted parts are preserved.
#[test]
pub fn test_parse_table_name() {
    // (input, expected number of parts, expected rendered name)
    let cases = [
        ("a.b.c", 3, "a.b.c".to_string()),
        ("a.b", 2, "a.b".to_string()),
        (
            "Test.\"public-test\"",
            2,
            "Test.\"public-test\"".to_ascii_lowercase(),
        ),
        ("HelloWorld", 1, "HelloWorld".to_ascii_lowercase()),
    ];

    for (table_name, expected_parts, expected_rendering) in cases {
        let object_name =
            ParserContext::parse_table_name(table_name, &GreptimeDbDialect {}).unwrap();

        assert_eq!(object_name.0.len(), expected_parts);
        assert_eq!(object_name.to_string(), expected_rendering);
    }
}
381
/// `PREPARE` accepts the statement text in single or double quotes, with or without a
/// trailing semicolon.
#[test]
pub fn test_parse_mysql_prepare_stmt() {
    // (input, expected statement name)
    let cases = [
        (
            "PREPARE stmt1 FROM 'SELECT * FROM t1 WHERE id = ?';",
            "stmt1",
        ),
        (
            "PREPARE stmt2 FROM \"SELECT * FROM t1 WHERE id = ?\"",
            "stmt2",
        ),
    ];

    for (sql, expected_name) in cases {
        let (stmt_name, stmt) =
            ParserContext::parse_mysql_prepare_stmt(sql, &MySqlDialect {}).unwrap();
        assert_eq!(stmt_name, expected_name);
        assert_eq!(stmt, "SELECT * FROM t1 WHERE id = ?");
    }
}
396
/// `EXECUTE` yields the statement name and its parameters, including the cases with no
/// `USING` clause and with a trailing comma after the last parameter.
#[test]
pub fn test_parse_mysql_execute_stmt() {
    // (input, expected statement name, expected parameters rendered as text)
    let cases = [
        (
            "EXECUTE stmt1 USING 1, 'hello';",
            "stmt1",
            vec!["1", "'hello'"],
        ),
        ("EXECUTE stmt2;", "stmt2", vec![]),
        (
            "EXECUTE stmt3 USING 231, 'hello', \"2003-03-1\", NULL, ;",
            "stmt3",
            vec!["231", "'hello'", "\"2003-03-1\"", "NULL"],
        ),
    ];

    for (sql, expected_name, expected_params) in cases {
        let (stmt_name, params) =
            ParserContext::parse_mysql_execute_stmt(sql, &GreptimeDbDialect {}).unwrap();
        assert_eq!(stmt_name, expected_name);
        assert_eq!(params.len(), expected_params.len());

        let rendered: Vec<String> = params.iter().map(|param| param.to_string()).collect();
        assert_eq!(rendered, expected_params);
    }
}
423
/// `DEALLOCATE` yields the statement name with or without a trailing semicolon.
#[test]
pub fn test_parse_mysql_deallocate_stmt() {
    // (input, expected statement name)
    let cases = [("DEALLOCATE stmt1;", "stmt1"), ("DEALLOCATE stmt2", "stmt2")];

    for (sql, expected_name) in cases {
        let stmt_name =
            ParserContext::parse_mysql_deallocate_stmt(sql, &MySqlDialect {}).unwrap();
        assert_eq!(stmt_name, expected_name);
    }
}
434}