1use snafu::ResultExt;
16use sqlparser::keywords::Keyword;
17use sqlparser::tokenizer::Token;
18use sqlparser::tokenizer::Token::Word;
19
20use crate::error::{self, Result};
21use crate::parser::ParserContext;
22use crate::statements::OptionMap;
23use crate::statements::copy::{
24 CopyDatabase, CopyDatabaseArgument, CopyQueryTo, CopyQueryToArgument, CopyTable,
25 CopyTableArgument,
26};
27use crate::statements::statement::Statement;
28use crate::util::parse_option_string;
29
30impl ParserContext<'_> {
32 pub(crate) fn parse_copy(&mut self) -> Result<Statement> {
33 let _ = self.parser.next_token();
34 let next = self.parser.peek_token();
35 let copy = if next.token == Token::LParen {
36 let copy_query = self.parse_copy_query_to()?;
37
38 crate::statements::copy::Copy::CopyQueryTo(copy_query)
39 } else if let Word(word) = next.token
40 && word.keyword == Keyword::DATABASE
41 {
42 let _ = self.parser.next_token();
43 let copy_database = self.parser_copy_database()?;
44 crate::statements::copy::Copy::CopyDatabase(copy_database)
45 } else {
46 let copy_table = self.parse_copy_table()?;
47 crate::statements::copy::Copy::CopyTable(copy_table)
48 };
49
50 Ok(Statement::Copy(copy))
51 }
52
53 fn parser_copy_database(&mut self) -> Result<CopyDatabase> {
54 let database_name = self
55 .parse_object_name()
56 .with_context(|_| error::UnexpectedSnafu {
57 expected: "a database name",
58 actual: self.peek_token_as_string(),
59 })?;
60
61 let req = if self.parser.parse_keyword(Keyword::TO) {
62 let (with, connection, location, limit) = self.parse_copy_parameters()?;
63 if limit.is_some() {
64 return error::InvalidSqlSnafu {
65 msg: "limit is not supported",
66 }
67 .fail();
68 }
69
70 let argument = CopyDatabaseArgument {
71 database_name,
72 with,
73 connection,
74 location,
75 };
76 CopyDatabase::To(argument)
77 } else {
78 self.parser
79 .expect_keyword(Keyword::FROM)
80 .context(error::SyntaxSnafu)?;
81 let (with, connection, location, limit) = self.parse_copy_parameters()?;
82 if limit.is_some() {
83 return error::InvalidSqlSnafu {
84 msg: "limit is not supported",
85 }
86 .fail();
87 }
88
89 let argument = CopyDatabaseArgument {
90 database_name,
91 with,
92 connection,
93 location,
94 };
95 CopyDatabase::From(argument)
96 };
97 Ok(req)
98 }
99
100 fn parse_copy_table(&mut self) -> Result<CopyTable> {
101 let raw_table_name = self
102 .parse_object_name()
103 .with_context(|_| error::UnexpectedSnafu {
104 expected: "a table name",
105 actual: self.peek_token_as_string(),
106 })?;
107 let table_name = Self::canonicalize_object_name(raw_table_name)?;
108
109 if self.parser.parse_keyword(Keyword::TO) {
110 let (with, connection, location, limit) = self.parse_copy_parameters()?;
111 Ok(CopyTable::To(CopyTableArgument {
112 table_name,
113 with,
114 connection,
115 location,
116 limit,
117 }))
118 } else {
119 self.parser
120 .expect_keyword(Keyword::FROM)
121 .context(error::SyntaxSnafu)?;
122 let (with, connection, location, limit) = self.parse_copy_parameters()?;
123 Ok(CopyTable::From(CopyTableArgument {
124 table_name,
125 with,
126 connection,
127 location,
128 limit,
129 }))
130 }
131 }
132
133 fn parse_copy_query_to(&mut self) -> Result<CopyQueryTo> {
134 self.parser
135 .expect_token(&Token::LParen)
136 .with_context(|_| error::UnexpectedSnafu {
137 expected: "'('",
138 actual: self.peek_token_as_string(),
139 })?;
140 let query = self.parse_query()?;
141 self.parser
142 .expect_token(&Token::RParen)
143 .with_context(|_| error::UnexpectedSnafu {
144 expected: "')'",
145 actual: self.peek_token_as_string(),
146 })?;
147 self.parser
148 .expect_keyword(Keyword::TO)
149 .context(error::SyntaxSnafu)?;
150 let (with, connection, location, limit) = self.parse_copy_parameters()?;
151 if limit.is_some() {
152 return error::InvalidSqlSnafu {
153 msg: "limit is not supported",
154 }
155 .fail();
156 }
157 Ok(CopyQueryTo {
158 query: Box::new(query),
159 arg: CopyQueryToArgument {
160 with,
161 connection,
162 location,
163 },
164 })
165 }
166
167 fn parse_copy_parameters(&mut self) -> Result<(OptionMap, OptionMap, String, Option<u64>)> {
168 let location =
169 self.parser
170 .parse_literal_string()
171 .with_context(|_| error::UnexpectedSnafu {
172 expected: "a file name",
173 actual: self.peek_token_as_string(),
174 })?;
175
176 let options = self
177 .parser
178 .parse_options(Keyword::WITH)
179 .context(error::SyntaxSnafu)?;
180
181 let with = options
182 .into_iter()
183 .map(parse_option_string)
184 .collect::<Result<Vec<_>>>()?;
185 let with = OptionMap::new(with);
186
187 let connection_options = self
188 .parser
189 .parse_options(Keyword::CONNECTION)
190 .context(error::SyntaxSnafu)?;
191
192 let connection = connection_options
193 .into_iter()
194 .map(parse_option_string)
195 .collect::<Result<Vec<_>>>()?;
196 let connection = OptionMap::new(connection);
197
198 let limit = if self.parser.parse_keyword(Keyword::LIMIT) {
199 Some(
200 self.parser
201 .parse_literal_uint()
202 .with_context(|_| error::UnexpectedSnafu {
203 expected: "the number of maximum rows",
204 actual: self.peek_token_as_string(),
205 })?,
206 )
207 } else {
208 None
209 };
210
211 Ok((with, connection, location, limit))
212 }
213}
214
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;
    use std::collections::HashMap;

    use sqlparser::ast::{Ident, ObjectName};

    use super::*;
    use crate::dialect::GreptimeDbDialect;
    use crate::parser::ParseOptions;
    use crate::statements::statement::Statement::Copy;

    /// Parses `sql` with the GreptimeDB dialect and default parse options,
    /// panicking if parsing fails.
    fn parse_ok(sql: &str) -> Vec<Statement> {
        ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
            .unwrap()
    }

    /// `COPY <table> TO` keeps the table name and location, with and without
    /// an explicit `FORMAT` option.
    #[test]
    fn test_parse_copy_table() {
        let sqls = [
            "COPY catalog0.schema0.tbl TO 'tbl_file.parquet'",
            "COPY catalog0.schema0.tbl TO 'tbl_file.parquet' WITH (FORMAT = 'parquet')",
        ];

        for sql in sqls {
            let mut result = parse_ok(sql);
            assert_eq!(1, result.len());

            let statement = result.remove(0);
            assert_matches!(statement, Statement::Copy { .. });
            let Copy(crate::statements::copy::Copy::CopyTable(CopyTable::To(copy_table))) =
                statement
            else {
                unreachable!()
            };

            let table = copy_table.table_name.to_string();
            assert_eq!("catalog0.schema0.tbl", table);

            let file_name = &copy_table.location;
            assert_eq!("tbl_file.parquet", file_name);

            let format = copy_table.format().unwrap();
            assert_eq!("parquet", format.to_lowercase());
        }
    }

    /// `COPY <table> FROM` keeps the table name and location, with and without
    /// an explicit `FORMAT` option.
    #[test]
    fn test_parse_copy_table_from_basic() {
        let sqls = [
            "COPY catalog0.schema0.tbl FROM 'tbl_file.parquet'",
            "COPY catalog0.schema0.tbl FROM 'tbl_file.parquet' WITH (FORMAT = 'parquet')",
        ];

        for sql in sqls {
            let mut result = parse_ok(sql);
            assert_eq!(1, result.len());

            let statement = result.remove(0);
            assert_matches!(statement, Statement::Copy { .. });
            let Copy(crate::statements::copy::Copy::CopyTable(CopyTable::From(copy_table))) =
                statement
            else {
                unreachable!()
            };

            let table = copy_table.table_name.to_string();
            assert_eq!("catalog0.schema0.tbl", table);

            let file_name = &copy_table.location;
            assert_eq!("tbl_file.parquet", file_name);

            let format = copy_table.format().unwrap();
            assert_eq!("parquet", format.to_lowercase());
        }
    }

    /// `COPY <table> FROM` exposes the `PATTERN` option and the `CONNECTION`
    /// options (keys lower-cased, values untouched).
    #[test]
    fn test_parse_copy_table_from() {
        struct Test<'a> {
            sql: &'a str,
            expected_pattern: Option<String>,
            expected_connection: HashMap<&'a str, &'a str>,
        }

        let tests = [
            Test {
                sql: "COPY catalog0.schema0.tbl FROM 'tbl_file.parquet' WITH (PATTERN = 'demo.*')",
                expected_pattern: Some("demo.*".into()),
                expected_connection: HashMap::new(),
            },
            Test {
                sql: "COPY catalog0.schema0.tbl FROM 'tbl_file.parquet' WITH (PATTERN = 'demo.*') CONNECTION (FOO='Bar', ONE='two')",
                expected_pattern: Some("demo.*".into()),
                expected_connection: HashMap::from([("foo", "Bar"), ("one", "two")]),
            },
        ];

        for test in tests {
            let mut result = parse_ok(test.sql);
            assert_eq!(1, result.len());

            let statement = result.remove(0);
            assert_matches!(statement, Statement::Copy { .. });
            let Copy(crate::statements::copy::Copy::CopyTable(CopyTable::From(copy_table))) =
                statement
            else {
                unreachable!()
            };

            if let Some(expected_pattern) = test.expected_pattern {
                assert_eq!(copy_table.pattern().unwrap(), expected_pattern);
            }
            assert_eq!(copy_table.connection.to_str_map(), test.expected_connection);
        }
    }

    /// `COPY <table> TO` exposes the `CONNECTION` options, whether or not a
    /// `WITH` clause precedes them.
    #[test]
    fn test_parse_copy_table_to() {
        struct Test<'a> {
            sql: &'a str,
            expected_connection: HashMap<&'a str, &'a str>,
        }

        let tests = [
            Test {
                sql: "COPY catalog0.schema0.tbl TO 'tbl_file.parquet' ",
                expected_connection: HashMap::new(),
            },
            Test {
                sql: "COPY catalog0.schema0.tbl TO 'tbl_file.parquet' CONNECTION (FOO='Bar', ONE='two')",
                expected_connection: HashMap::from([("foo", "Bar"), ("one", "two")]),
            },
            Test {
                sql: "COPY catalog0.schema0.tbl TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')",
                expected_connection: HashMap::from([("foo", "Bar"), ("one", "two")]),
            },
        ];

        for test in tests {
            let mut result = parse_ok(test.sql);
            assert_eq!(1, result.len());

            let statement = result.remove(0);
            assert_matches!(statement, Statement::Copy { .. });
            let Copy(crate::statements::copy::Copy::CopyTable(CopyTable::To(copy_table))) =
                statement
            else {
                unreachable!()
            };

            assert_eq!(copy_table.connection.to_str_map(), test.expected_connection);
        }
    }

    /// `COPY DATABASE ... TO` keeps the database name, `WITH` options and
    /// `CONNECTION` options.
    #[test]
    fn test_copy_database_to() {
        let sql = "COPY DATABASE catalog0.schema0 TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')";
        let stmt = parse_ok(sql).pop().unwrap();

        let Copy(crate::statements::copy::Copy::CopyDatabase(CopyDatabase::To(stmt))) = stmt
        else {
            unreachable!()
        };

        assert_eq!(
            ObjectName::from(vec![Ident::new("catalog0"), Ident::new("schema0")]),
            stmt.database_name
        );
        assert_eq!(
            HashMap::from([("format", "parquet")]),
            stmt.with.to_str_map()
        );
        assert_eq!(
            HashMap::from([("foo", "Bar"), ("one", "two")]),
            stmt.connection.to_str_map()
        );
    }

    /// `COPY DATABASE ... FROM` keeps the database name, `WITH` options and
    /// `CONNECTION` options.
    #[test]
    fn test_copy_database_from() {
        let sql = "COPY DATABASE catalog0.schema0 FROM '/a/b/c/' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')";
        let stmt = parse_ok(sql).pop().unwrap();

        let Copy(crate::statements::copy::Copy::CopyDatabase(CopyDatabase::From(stmt))) = stmt
        else {
            unreachable!()
        };

        assert_eq!(
            ObjectName::from(vec![Ident::new("catalog0"), Ident::new("schema0")]),
            stmt.database_name
        );
        assert_eq!(
            HashMap::from([("format", "parquet")]),
            stmt.with.to_str_map()
        );
        assert_eq!(
            HashMap::from([("foo", "Bar"), ("one", "two")]),
            stmt.connection.to_str_map()
        );
    }

    /// `COPY (<query>) TO` stores the inner query exactly as it parses on its
    /// own, together with the `WITH` and `CONNECTION` options.
    #[test]
    fn test_copy_query_to() {
        let sql = "COPY (SELECT * FROM tbl WHERE ts > 10) TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')";
        let stmt = parse_ok(sql).pop().unwrap();

        let Copy(crate::statements::copy::Copy::CopyQueryTo(stmt)) = stmt else {
            unreachable!()
        };

        let query = parse_ok("SELECT * FROM tbl WHERE ts > 10").remove(0);

        assert_eq!(&query, stmt.query.as_ref());
        assert_eq!(
            HashMap::from([("format", "parquet")]),
            stmt.arg.with.to_str_map()
        );
        assert_eq!(
            HashMap::from([("foo", "Bar"), ("one", "two")]),
            stmt.arg.connection.to_str_map()
        );
    }

    /// A query that is not wrapped in a balanced pair of parentheses must be
    /// rejected: no parentheses, only the closing one, only the opening one.
    #[test]
    fn test_invalid_copy_query_to() {
        let invalid_sqls = [
            "COPY SELECT * FROM tbl WHERE ts > 10 TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')",
            "COPY SELECT * FROM tbl WHERE ts > 10) TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')",
            "COPY (SELECT * FROM tbl WHERE ts > 10 TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')",
        ];

        for sql in invalid_sqls {
            let result = ParserContext::create_with_dialect(
                sql,
                &GreptimeDbDialect {},
                ParseOptions::default(),
            );
            assert!(result.is_err(), "expected a parse error for: {sql}");
        }
    }
}