1use std::collections::HashMap;
16
17use snafu::ResultExt;
18use sqlparser::keywords::Keyword;
19use sqlparser::tokenizer::Token;
20use sqlparser::tokenizer::Token::Word;
21
22use crate::error::{self, Result};
23use crate::parser::ParserContext;
24use crate::statements::copy::{
25 CopyDatabase, CopyDatabaseArgument, CopyQueryTo, CopyQueryToArgument, CopyTable,
26 CopyTableArgument,
27};
28use crate::statements::statement::Statement;
29use crate::util::parse_option_string;
30
31pub type With = HashMap<String, String>;
32pub type Connection = HashMap<String, String>;
33
34impl ParserContext<'_> {
36 pub(crate) fn parse_copy(&mut self) -> Result<Statement> {
37 let _ = self.parser.next_token();
38 let next = self.parser.peek_token();
39 let copy = if next.token == Token::LParen {
40 let copy_query = self.parse_copy_query_to()?;
41
42 crate::statements::copy::Copy::CopyQueryTo(copy_query)
43 } else if let Word(word) = next.token
44 && word.keyword == Keyword::DATABASE
45 {
46 let _ = self.parser.next_token();
47 let copy_database = self.parser_copy_database()?;
48 crate::statements::copy::Copy::CopyDatabase(copy_database)
49 } else {
50 let copy_table = self.parse_copy_table()?;
51 crate::statements::copy::Copy::CopyTable(copy_table)
52 };
53
54 Ok(Statement::Copy(copy))
55 }
56
57 fn parser_copy_database(&mut self) -> Result<CopyDatabase> {
58 let database_name = self
59 .parse_object_name()
60 .with_context(|_| error::UnexpectedSnafu {
61 expected: "a database name",
62 actual: self.peek_token_as_string(),
63 })?;
64
65 let req = if self.parser.parse_keyword(Keyword::TO) {
66 let (with, connection, location, limit) = self.parse_copy_parameters()?;
67 if limit.is_some() {
68 return error::InvalidSqlSnafu {
69 msg: "limit is not supported",
70 }
71 .fail();
72 }
73
74 let argument = CopyDatabaseArgument {
75 database_name,
76 with: with.into(),
77 connection: connection.into(),
78 location,
79 };
80 CopyDatabase::To(argument)
81 } else {
82 self.parser
83 .expect_keyword(Keyword::FROM)
84 .context(error::SyntaxSnafu)?;
85 let (with, connection, location, limit) = self.parse_copy_parameters()?;
86 if limit.is_some() {
87 return error::InvalidSqlSnafu {
88 msg: "limit is not supported",
89 }
90 .fail();
91 }
92
93 let argument = CopyDatabaseArgument {
94 database_name,
95 with: with.into(),
96 connection: connection.into(),
97 location,
98 };
99 CopyDatabase::From(argument)
100 };
101 Ok(req)
102 }
103
104 fn parse_copy_table(&mut self) -> Result<CopyTable> {
105 let raw_table_name = self
106 .parse_object_name()
107 .with_context(|_| error::UnexpectedSnafu {
108 expected: "a table name",
109 actual: self.peek_token_as_string(),
110 })?;
111 let table_name = Self::canonicalize_object_name(raw_table_name);
112
113 if self.parser.parse_keyword(Keyword::TO) {
114 let (with, connection, location, limit) = self.parse_copy_parameters()?;
115 Ok(CopyTable::To(CopyTableArgument {
116 table_name,
117 with: with.into(),
118 connection: connection.into(),
119 location,
120 limit,
121 }))
122 } else {
123 self.parser
124 .expect_keyword(Keyword::FROM)
125 .context(error::SyntaxSnafu)?;
126 let (with, connection, location, limit) = self.parse_copy_parameters()?;
127 Ok(CopyTable::From(CopyTableArgument {
128 table_name,
129 with: with.into(),
130 connection: connection.into(),
131 location,
132 limit,
133 }))
134 }
135 }
136
137 fn parse_copy_query_to(&mut self) -> Result<CopyQueryTo> {
138 self.parser
139 .expect_token(&Token::LParen)
140 .with_context(|_| error::UnexpectedSnafu {
141 expected: "'('",
142 actual: self.peek_token_as_string(),
143 })?;
144 let query = self.parse_query()?;
145 self.parser
146 .expect_token(&Token::RParen)
147 .with_context(|_| error::UnexpectedSnafu {
148 expected: "')'",
149 actual: self.peek_token_as_string(),
150 })?;
151 self.parser
152 .expect_keyword(Keyword::TO)
153 .context(error::SyntaxSnafu)?;
154 let (with, connection, location, limit) = self.parse_copy_parameters()?;
155 if limit.is_some() {
156 return error::InvalidSqlSnafu {
157 msg: "limit is not supported",
158 }
159 .fail();
160 }
161 Ok(CopyQueryTo {
162 query: Box::new(query),
163 arg: CopyQueryToArgument {
164 with: with.into(),
165 connection: connection.into(),
166 location,
167 },
168 })
169 }
170
171 fn parse_copy_parameters(&mut self) -> Result<(With, Connection, String, Option<u64>)> {
172 let location =
173 self.parser
174 .parse_literal_string()
175 .with_context(|_| error::UnexpectedSnafu {
176 expected: "a file name",
177 actual: self.peek_token_as_string(),
178 })?;
179
180 let options = self
181 .parser
182 .parse_options(Keyword::WITH)
183 .context(error::SyntaxSnafu)?;
184
185 let with = options
186 .into_iter()
187 .map(parse_option_string)
188 .collect::<Result<With>>()?;
189
190 let connection_options = self
191 .parser
192 .parse_options(Keyword::CONNECTION)
193 .context(error::SyntaxSnafu)?;
194
195 let connection = connection_options
196 .into_iter()
197 .map(parse_option_string)
198 .collect::<Result<Connection>>()?;
199
200 let limit = if self.parser.parse_keyword(Keyword::LIMIT) {
201 Some(
202 self.parser
203 .parse_literal_uint()
204 .with_context(|_| error::UnexpectedSnafu {
205 expected: "the number of maximum rows",
206 actual: self.peek_token_as_string(),
207 })?,
208 )
209 } else {
210 None
211 };
212
213 Ok((with, connection, location, limit))
214 }
215}
216
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;
    use std::collections::HashMap;

    use sqlparser::ast::{Ident, ObjectName};

    use super::*;
    use crate::dialect::GreptimeDbDialect;
    use crate::parser::ParseOptions;
    use crate::statements::statement::Statement::Copy;

    /// `COPY <table> TO` with and without a `WITH (FORMAT = ...)` clause: checks the
    /// table name, the location, and that `format()` reports parquet in both cases.
    #[test]
    fn test_parse_copy_table() {
        let sql0 = "COPY catalog0.schema0.tbl TO 'tbl_file.parquet'";
        let sql1 = "COPY catalog0.schema0.tbl TO 'tbl_file.parquet' WITH (FORMAT = 'parquet')";
        let result0 = ParserContext::create_with_dialect(
            sql0,
            &GreptimeDbDialect {},
            ParseOptions::default(),
        )
        .unwrap();
        let result1 = ParserContext::create_with_dialect(
            sql1,
            &GreptimeDbDialect {},
            ParseOptions::default(),
        )
        .unwrap();

        for mut result in [result0, result1] {
            assert_eq!(1, result.len());

            let statement = result.remove(0);
            assert_matches!(statement, Statement::Copy { .. });
            match statement {
                Copy(copy) => {
                    let crate::statements::copy::Copy::CopyTable(CopyTable::To(copy_table)) = copy
                    else {
                        unreachable!()
                    };
                    let table = copy_table.table_name.to_string();
                    assert_eq!("catalog0.schema0.tbl", table);

                    let file_name = &copy_table.location;
                    assert_eq!("tbl_file.parquet", file_name);

                    let format = copy_table.format().unwrap();
                    assert_eq!("parquet", format.to_lowercase());
                }
                _ => unreachable!(),
            }
        }
    }

    /// `COPY <table> FROM` with and without a `WITH (FORMAT = ...)` clause: checks the
    /// table name, the location, and that `format()` reports parquet in both cases.
    #[test]
    fn test_parse_copy_table_from_basic() {
        let results = [
            "COPY catalog0.schema0.tbl FROM 'tbl_file.parquet'",
            "COPY catalog0.schema0.tbl FROM 'tbl_file.parquet' WITH (FORMAT = 'parquet')",
        ]
        .iter()
        .map(|sql| {
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .unwrap()
        })
        .collect::<Vec<_>>();

        for mut result in results {
            assert_eq!(1, result.len());

            let statement = result.remove(0);
            assert_matches!(statement, Statement::Copy { .. });
            match statement {
                Statement::Copy(crate::statements::copy::Copy::CopyTable(CopyTable::From(
                    copy_table,
                ))) => {
                    let table = copy_table.table_name.to_string();
                    assert_eq!("catalog0.schema0.tbl", table);

                    let file_name = &copy_table.location;
                    assert_eq!("tbl_file.parquet", file_name);

                    let format = copy_table.format().unwrap();
                    assert_eq!("parquet", format.to_lowercase());
                }
                _ => unreachable!(),
            }
        }
    }

    /// `COPY <table> FROM` with a `PATTERN` option and an optional `CONNECTION` clause.
    /// The expected connection map uses lower-cased keys and case-preserved values.
    #[test]
    fn test_parse_copy_table_from() {
        struct Test<'a> {
            sql: &'a str,
            expected_pattern: Option<String>,
            expected_connection: HashMap<String, String>,
        }

        let tests = [
            Test {
                sql: "COPY catalog0.schema0.tbl FROM 'tbl_file.parquet' WITH (PATTERN = 'demo.*')",
                expected_pattern: Some("demo.*".into()),
                expected_connection: HashMap::new(),
            },
            Test {
                sql: "COPY catalog0.schema0.tbl FROM 'tbl_file.parquet' WITH (PATTERN = 'demo.*') CONNECTION (FOO='Bar', ONE='two')",
                expected_pattern: Some("demo.*".into()),
                expected_connection: [("foo", "Bar"), ("one", "two")]
                    .into_iter()
                    .map(|(k, v)| (k.to_string(), v.to_string()))
                    .collect(),
            },
        ];

        for test in tests {
            let mut result = ParserContext::create_with_dialect(
                test.sql,
                &GreptimeDbDialect {},
                ParseOptions::default(),
            )
            .unwrap();
            assert_eq!(1, result.len());

            let statement = result.remove(0);
            assert_matches!(statement, Statement::Copy { .. });
            match statement {
                Statement::Copy(crate::statements::copy::Copy::CopyTable(CopyTable::From(
                    copy_table,
                ))) => {
                    if let Some(expected_pattern) = test.expected_pattern {
                        assert_eq!(copy_table.pattern().unwrap(), expected_pattern);
                    }
                    assert_eq!(
                        copy_table.connection.clone(),
                        test.expected_connection.into()
                    );
                }
                _ => unreachable!(),
            }
        }
    }

    /// `COPY <table> TO` with no clause, with only `CONNECTION`, and with both `WITH`
    /// and `CONNECTION`: checks the parsed connection options.
    #[test]
    fn test_parse_copy_table_to() {
        struct Test<'a> {
            sql: &'a str,
            expected_connection: HashMap<String, String>,
        }

        let tests = [
            Test {
                sql: "COPY catalog0.schema0.tbl TO 'tbl_file.parquet' ",
                expected_connection: HashMap::new(),
            },
            Test {
                sql: "COPY catalog0.schema0.tbl TO 'tbl_file.parquet' CONNECTION (FOO='Bar', ONE='two')",
                expected_connection: [("foo", "Bar"), ("one", "two")]
                    .into_iter()
                    .map(|(k, v)| (k.to_string(), v.to_string()))
                    .collect(),
            },
            Test {
                sql: "COPY catalog0.schema0.tbl TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')",
                expected_connection: [("foo", "Bar"), ("one", "two")]
                    .into_iter()
                    .map(|(k, v)| (k.to_string(), v.to_string()))
                    .collect(),
            },
        ];

        for test in tests {
            let mut result = ParserContext::create_with_dialect(
                test.sql,
                &GreptimeDbDialect {},
                ParseOptions::default(),
            )
            .unwrap();
            assert_eq!(1, result.len());

            let statement = result.remove(0);
            assert_matches!(statement, Statement::Copy { .. });
            match statement {
                Statement::Copy(crate::statements::copy::Copy::CopyTable(CopyTable::To(
                    copy_table,
                ))) => {
                    assert_eq!(
                        copy_table.connection.clone(),
                        test.expected_connection.into()
                    );
                }
                _ => unreachable!(),
            }
        }
    }

    /// `COPY DATABASE ... TO`: checks the database name and both option maps.
    #[test]
    fn test_copy_database_to() {
        let sql = "COPY DATABASE catalog0.schema0 TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')";
        let stmt =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .unwrap()
                .pop()
                .unwrap();

        let Copy(crate::statements::copy::Copy::CopyDatabase(stmt)) = stmt else {
            unreachable!()
        };

        let CopyDatabase::To(stmt) = stmt else {
            unreachable!()
        };

        assert_eq!(
            ObjectName::from(vec![Ident::new("catalog0"), Ident::new("schema0")]),
            stmt.database_name
        );
        assert_eq!(
            [("format", "parquet")]
                .into_iter()
                .collect::<HashMap<_, _>>(),
            stmt.with.to_str_map()
        );

        assert_eq!(
            [("foo", "Bar"), ("one", "two")]
                .into_iter()
                .collect::<HashMap<_, _>>(),
            stmt.connection.to_str_map()
        );
    }

    /// `COPY DATABASE ... FROM`: checks the database name and both option maps.
    #[test]
    fn test_copy_database_from() {
        let sql = "COPY DATABASE catalog0.schema0 FROM '/a/b/c/' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')";
        let stmt =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .unwrap()
                .pop()
                .unwrap();

        let Copy(crate::statements::copy::Copy::CopyDatabase(stmt)) = stmt else {
            unreachable!()
        };

        let CopyDatabase::From(stmt) = stmt else {
            unreachable!()
        };

        assert_eq!(
            ObjectName::from(vec![Ident::new("catalog0"), Ident::new("schema0")]),
            stmt.database_name
        );
        assert_eq!(
            [("format", "parquet")]
                .into_iter()
                .collect::<HashMap<_, _>>(),
            stmt.with.to_str_map()
        );

        assert_eq!(
            [("foo", "Bar"), ("one", "two")]
                .into_iter()
                .collect::<HashMap<_, _>>(),
            stmt.connection.to_str_map()
        );
    }

    /// `COPY (<query>) TO`: the embedded query must equal the same query parsed on its
    /// own, and both option maps must be populated.
    #[test]
    fn test_copy_query_to() {
        let sql = "COPY (SELECT * FROM tbl WHERE ts > 10) TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')";
        let stmt =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .unwrap()
                .pop()
                .unwrap();

        let Copy(crate::statements::copy::Copy::CopyQueryTo(stmt)) = stmt else {
            unreachable!()
        };

        let query = ParserContext::create_with_dialect(
            "SELECT * FROM tbl WHERE ts > 10",
            &GreptimeDbDialect {},
            ParseOptions::default(),
        )
        .unwrap()
        .remove(0);

        assert_eq!(&query, stmt.query.as_ref());
        assert_eq!(
            [("format", "parquet")]
                .into_iter()
                .collect::<HashMap<_, _>>(),
            stmt.arg.with.to_str_map()
        );

        assert_eq!(
            [("foo", "Bar"), ("one", "two")]
                .into_iter()
                .collect::<HashMap<_, _>>(),
            stmt.arg.connection.to_str_map()
        );
    }

    /// `COPY <query> TO` must be rejected when the query is not wrapped in a balanced
    /// pair of parentheses.
    #[test]
    fn test_invalid_copy_query_to() {
        // No parentheses at all.
        {
            let sql = "COPY SELECT * FROM tbl WHERE ts > 10 TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')";

            assert!(
                ParserContext::create_with_dialect(
                    sql,
                    &GreptimeDbDialect {},
                    ParseOptions::default()
                )
                .is_err()
            )
        }
        // Missing opening parenthesis.
        {
            let sql = "COPY SELECT * FROM tbl WHERE ts > 10) TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')";

            assert!(
                ParserContext::create_with_dialect(
                    sql,
                    &GreptimeDbDialect {},
                    ParseOptions::default()
                )
                .is_err()
            )
        }
        // Missing closing parenthesis.
        {
            let sql = "COPY (SELECT * FROM tbl WHERE ts > 10 TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')";

            assert!(
                ParserContext::create_with_dialect(
                    sql,
                    &GreptimeDbDialect {},
                    ParseOptions::default()
                )
                .is_err()
            )
        }
    }
}