1use std::collections::HashMap;
16
17use snafu::ResultExt;
18use sqlparser::keywords::Keyword;
19use sqlparser::tokenizer::Token;
20use sqlparser::tokenizer::Token::Word;
21
22use crate::error::{self, Result};
23use crate::parser::ParserContext;
24use crate::statements::copy::{
25 CopyDatabase, CopyDatabaseArgument, CopyQueryTo, CopyQueryToArgument, CopyTable,
26 CopyTableArgument,
27};
28use crate::statements::statement::Statement;
29use crate::util::parse_option_string;
30
/// Key/value options collected from the optional `WITH (...)` clause of a `COPY` statement.
pub type With = HashMap<String, String>;
/// Key/value options collected from the optional `CONNECTION (...)` clause of a `COPY` statement.
pub type Connection = HashMap<String, String>;
33
34impl ParserContext<'_> {
36 pub(crate) fn parse_copy(&mut self) -> Result<Statement> {
37 let _ = self.parser.next_token();
38 let next = self.parser.peek_token();
39 let copy = if next.token == Token::LParen {
40 let copy_query = self.parse_copy_query_to()?;
41
42 crate::statements::copy::Copy::CopyQueryTo(copy_query)
43 } else if let Word(word) = next.token
44 && word.keyword == Keyword::DATABASE
45 {
46 let _ = self.parser.next_token();
47 let copy_database = self.parser_copy_database()?;
48 crate::statements::copy::Copy::CopyDatabase(copy_database)
49 } else {
50 let copy_table = self.parse_copy_table()?;
51 crate::statements::copy::Copy::CopyTable(copy_table)
52 };
53
54 Ok(Statement::Copy(copy))
55 }
56
57 fn parser_copy_database(&mut self) -> Result<CopyDatabase> {
58 let database_name = self
59 .parse_object_name()
60 .with_context(|_| error::UnexpectedSnafu {
61 expected: "a database name",
62 actual: self.peek_token_as_string(),
63 })?;
64
65 let req = if self.parser.parse_keyword(Keyword::TO) {
66 let (with, connection, location, limit) = self.parse_copy_parameters()?;
67 if limit.is_some() {
68 return error::InvalidSqlSnafu {
69 msg: "limit is not supported",
70 }
71 .fail();
72 }
73
74 let argument = CopyDatabaseArgument {
75 database_name,
76 with: with.into(),
77 connection: connection.into(),
78 location,
79 };
80 CopyDatabase::To(argument)
81 } else {
82 self.parser
83 .expect_keyword(Keyword::FROM)
84 .context(error::SyntaxSnafu)?;
85 let (with, connection, location, limit) = self.parse_copy_parameters()?;
86 if limit.is_some() {
87 return error::InvalidSqlSnafu {
88 msg: "limit is not supported",
89 }
90 .fail();
91 }
92
93 let argument = CopyDatabaseArgument {
94 database_name,
95 with: with.into(),
96 connection: connection.into(),
97 location,
98 };
99 CopyDatabase::From(argument)
100 };
101 Ok(req)
102 }
103
104 fn parse_copy_table(&mut self) -> Result<CopyTable> {
105 let raw_table_name = self
106 .parse_object_name()
107 .with_context(|_| error::UnexpectedSnafu {
108 expected: "a table name",
109 actual: self.peek_token_as_string(),
110 })?;
111 let table_name = Self::canonicalize_object_name(raw_table_name);
112
113 if self.parser.parse_keyword(Keyword::TO) {
114 let (with, connection, location, limit) = self.parse_copy_parameters()?;
115 Ok(CopyTable::To(CopyTableArgument {
116 table_name,
117 with: with.into(),
118 connection: connection.into(),
119 location,
120 limit,
121 }))
122 } else {
123 self.parser
124 .expect_keyword(Keyword::FROM)
125 .context(error::SyntaxSnafu)?;
126 let (with, connection, location, limit) = self.parse_copy_parameters()?;
127 Ok(CopyTable::From(CopyTableArgument {
128 table_name,
129 with: with.into(),
130 connection: connection.into(),
131 location,
132 limit,
133 }))
134 }
135 }
136
137 fn parse_copy_query_to(&mut self) -> Result<CopyQueryTo> {
138 self.parser
139 .expect_token(&Token::LParen)
140 .with_context(|_| error::UnexpectedSnafu {
141 expected: "'('",
142 actual: self.peek_token_as_string(),
143 })?;
144 let query = self.parse_query()?;
145 self.parser
146 .expect_token(&Token::RParen)
147 .with_context(|_| error::UnexpectedSnafu {
148 expected: "')'",
149 actual: self.peek_token_as_string(),
150 })?;
151 self.parser
152 .expect_keyword(Keyword::TO)
153 .context(error::SyntaxSnafu)?;
154 let (with, connection, location, limit) = self.parse_copy_parameters()?;
155 if limit.is_some() {
156 return error::InvalidSqlSnafu {
157 msg: "limit is not supported",
158 }
159 .fail();
160 }
161 Ok(CopyQueryTo {
162 query: Box::new(query),
163 arg: CopyQueryToArgument {
164 with: with.into(),
165 connection: connection.into(),
166 location,
167 },
168 })
169 }
170
171 fn parse_copy_parameters(&mut self) -> Result<(With, Connection, String, Option<u64>)> {
172 let location =
173 self.parser
174 .parse_literal_string()
175 .with_context(|_| error::UnexpectedSnafu {
176 expected: "a file name",
177 actual: self.peek_token_as_string(),
178 })?;
179
180 let options = self
181 .parser
182 .parse_options(Keyword::WITH)
183 .context(error::SyntaxSnafu)?;
184
185 let with = options
186 .into_iter()
187 .map(parse_option_string)
188 .collect::<Result<With>>()?;
189
190 let connection_options = self
191 .parser
192 .parse_options(Keyword::CONNECTION)
193 .context(error::SyntaxSnafu)?;
194
195 let connection = connection_options
196 .into_iter()
197 .map(parse_option_string)
198 .collect::<Result<Connection>>()?;
199
200 let limit = if self.parser.parse_keyword(Keyword::LIMIT) {
201 Some(
202 self.parser
203 .parse_literal_uint()
204 .with_context(|_| error::UnexpectedSnafu {
205 expected: "the number of maximum rows",
206 actual: self.peek_token_as_string(),
207 })?,
208 )
209 } else {
210 None
211 };
212
213 Ok((with, connection, location, limit))
214 }
215}
216
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;
    use std::collections::HashMap;

    use sqlparser::ast::{Ident, ObjectName};

    use super::*;
    use crate::dialect::GreptimeDbDialect;
    use crate::parser::ParseOptions;
    use crate::statements::statement::Statement::Copy;

    #[test]
    fn test_parse_copy_table() {
        let sql0 = "COPY catalog0.schema0.tbl TO 'tbl_file.parquet'";
        let sql1 = "COPY catalog0.schema0.tbl TO 'tbl_file.parquet' WITH (FORMAT = 'parquet')";
        let result0 = ParserContext::create_with_dialect(
            sql0,
            &GreptimeDbDialect {},
            ParseOptions::default(),
        )
        .unwrap();
        let result1 = ParserContext::create_with_dialect(
            sql1,
            &GreptimeDbDialect {},
            ParseOptions::default(),
        )
        .unwrap();

        for mut result in [result0, result1] {
            assert_eq!(1, result.len());

            let statement = result.remove(0);
            assert_matches!(statement, Statement::Copy { .. });
            match statement {
                Copy(copy) => {
                    let crate::statements::copy::Copy::CopyTable(CopyTable::To(copy_table)) = copy
                    else {
                        unreachable!()
                    };
                    let (catalog, schema, table) =
                        if let [catalog, schema, table] = &copy_table.table_name.0[..] {
                            (
                                catalog.value.clone(),
                                schema.value.clone(),
                                table.value.clone(),
                            )
                        } else {
                            unreachable!()
                        };

                    assert_eq!("catalog0", catalog);
                    assert_eq!("schema0", schema);
                    assert_eq!("tbl", table);

                    let file_name = &copy_table.location;
                    assert_eq!("tbl_file.parquet", file_name);

                    let format = copy_table.format().unwrap();
                    assert_eq!("parquet", format.to_lowercase());
                }
                _ => unreachable!(),
            }
        }
    }

    #[test]
    fn test_parse_copy_table_from_basic() {
        let results = [
            "COPY catalog0.schema0.tbl FROM 'tbl_file.parquet'",
            "COPY catalog0.schema0.tbl FROM 'tbl_file.parquet' WITH (FORMAT = 'parquet')",
        ]
        .iter()
        .map(|sql| {
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .unwrap()
        })
        .collect::<Vec<_>>();

        for mut result in results {
            assert_eq!(1, result.len());

            let statement = result.remove(0);
            assert_matches!(statement, Statement::Copy { .. });
            match statement {
                Statement::Copy(crate::statements::copy::Copy::CopyTable(CopyTable::From(
                    copy_table,
                ))) => {
                    let (catalog, schema, table) =
                        if let [catalog, schema, table] = &copy_table.table_name.0[..] {
                            (
                                catalog.value.clone(),
                                schema.value.clone(),
                                table.value.clone(),
                            )
                        } else {
                            unreachable!()
                        };

                    assert_eq!("catalog0", catalog);
                    assert_eq!("schema0", schema);
                    assert_eq!("tbl", table);

                    let file_name = &copy_table.location;
                    assert_eq!("tbl_file.parquet", file_name);

                    let format = copy_table.format().unwrap();
                    assert_eq!("parquet", format.to_lowercase());
                }
                _ => unreachable!(),
            }
        }
    }

    #[test]
    fn test_parse_copy_table_from() {
        struct Test<'a> {
            sql: &'a str,
            expected_pattern: Option<String>,
            expected_connection: HashMap<String, String>,
        }

        let tests = [
            Test {
                sql: "COPY catalog0.schema0.tbl FROM 'tbl_file.parquet' WITH (PATTERN = 'demo.*')",
                expected_pattern: Some("demo.*".into()),
                expected_connection: HashMap::new(),
            },
            Test {
                sql: "COPY catalog0.schema0.tbl FROM 'tbl_file.parquet' WITH (PATTERN = 'demo.*') CONNECTION (FOO='Bar', ONE='two')",
                expected_pattern: Some("demo.*".into()),
                expected_connection: [("foo","Bar"),("one","two")].into_iter().map(|(k,v)|{(k.to_string(),v.to_string())}).collect()
            },
        ];

        for test in tests {
            let mut result = ParserContext::create_with_dialect(
                test.sql,
                &GreptimeDbDialect {},
                ParseOptions::default(),
            )
            .unwrap();
            assert_eq!(1, result.len());

            let statement = result.remove(0);
            assert_matches!(statement, Statement::Copy { .. });
            match statement {
                Statement::Copy(crate::statements::copy::Copy::CopyTable(CopyTable::From(
                    copy_table,
                ))) => {
                    if let Some(expected_pattern) = test.expected_pattern {
                        assert_eq!(copy_table.pattern().unwrap(), expected_pattern);
                    }
                    assert_eq!(
                        copy_table.connection.clone(),
                        test.expected_connection.into()
                    );
                }
                _ => unreachable!(),
            }
        }
    }

    #[test]
    fn test_parse_copy_table_to() {
        struct Test<'a> {
            sql: &'a str,
            expected_connection: HashMap<String, String>,
        }

        let tests = [
            Test {
                sql: "COPY catalog0.schema0.tbl TO 'tbl_file.parquet' ",
                expected_connection: HashMap::new(),
            },
            Test {
                sql: "COPY catalog0.schema0.tbl TO 'tbl_file.parquet' CONNECTION (FOO='Bar', ONE='two')",
                expected_connection: [("foo","Bar"),("one","two")].into_iter().map(|(k,v)|{(k.to_string(),v.to_string())}).collect()
            },
            Test {
                sql:"COPY catalog0.schema0.tbl TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')",
                expected_connection: [("foo","Bar"),("one","two")].into_iter().map(|(k,v)|{(k.to_string(),v.to_string())}).collect()
            },
        ];

        for test in tests {
            let mut result = ParserContext::create_with_dialect(
                test.sql,
                &GreptimeDbDialect {},
                ParseOptions::default(),
            )
            .unwrap();
            assert_eq!(1, result.len());

            let statement = result.remove(0);
            assert_matches!(statement, Statement::Copy { .. });
            match statement {
                Statement::Copy(crate::statements::copy::Copy::CopyTable(CopyTable::To(
                    copy_table,
                ))) => {
                    assert_eq!(
                        copy_table.connection.clone(),
                        test.expected_connection.into()
                    );
                }
                _ => unreachable!(),
            }
        }
    }

    #[test]
    fn test_parse_copy_table_with_limit() {
        // `LIMIT` is accepted for table copies.
        let sql = "COPY catalog0.schema0.tbl TO 'tbl_file.parquet' LIMIT 1000";
        let mut result =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .unwrap();
        assert_eq!(1, result.len());
        assert_matches!(
            result.remove(0),
            Copy(crate::statements::copy::Copy::CopyTable(CopyTable::To(_)))
        );
    }

    #[test]
    fn test_copy_database_to() {
        let sql = "COPY DATABASE catalog0.schema0 TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')";
        let stmt =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .unwrap()
                .pop()
                .unwrap();

        let Copy(crate::statements::copy::Copy::CopyDatabase(stmt)) = stmt else {
            unreachable!()
        };

        let CopyDatabase::To(stmt) = stmt else {
            unreachable!()
        };

        assert_eq!(
            ObjectName(vec![Ident::new("catalog0"), Ident::new("schema0")]),
            stmt.database_name
        );
        assert_eq!(
            [("format", "parquet")]
                .into_iter()
                .collect::<HashMap<_, _>>(),
            stmt.with.to_str_map()
        );

        assert_eq!(
            [("foo", "Bar"), ("one", "two")]
                .into_iter()
                .collect::<HashMap<_, _>>(),
            stmt.connection.to_str_map()
        );
    }

    #[test]
    fn test_copy_database_from() {
        let sql = "COPY DATABASE catalog0.schema0 FROM '/a/b/c/' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')";
        let stmt =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .unwrap()
                .pop()
                .unwrap();

        let Copy(crate::statements::copy::Copy::CopyDatabase(stmt)) = stmt else {
            unreachable!()
        };

        let CopyDatabase::From(stmt) = stmt else {
            unreachable!()
        };

        assert_eq!(
            ObjectName(vec![Ident::new("catalog0"), Ident::new("schema0")]),
            stmt.database_name
        );
        assert_eq!(
            [("format", "parquet")]
                .into_iter()
                .collect::<HashMap<_, _>>(),
            stmt.with.to_str_map()
        );

        assert_eq!(
            [("foo", "Bar"), ("one", "two")]
                .into_iter()
                .collect::<HashMap<_, _>>(),
            stmt.connection.to_str_map()
        );
    }

    #[test]
    fn test_copy_query_to() {
        let sql = "COPY (SELECT * FROM tbl WHERE ts > 10) TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')";
        let stmt =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .unwrap()
                .pop()
                .unwrap();

        let Copy(crate::statements::copy::Copy::CopyQueryTo(stmt)) = stmt else {
            unreachable!()
        };

        let query = ParserContext::create_with_dialect(
            "SELECT * FROM tbl WHERE ts > 10",
            &GreptimeDbDialect {},
            ParseOptions::default(),
        )
        .unwrap()
        .remove(0);

        assert_eq!(&query, stmt.query.as_ref());
        assert_eq!(
            [("format", "parquet")]
                .into_iter()
                .collect::<HashMap<_, _>>(),
            stmt.arg.with.to_str_map()
        );

        assert_eq!(
            [("foo", "Bar"), ("one", "two")]
                .into_iter()
                .collect::<HashMap<_, _>>(),
            stmt.arg.connection.to_str_map()
        );
    }

    #[test]
    fn test_invalid_copy_query_to() {
        {
            let sql = "COPY SELECT * FROM tbl WHERE ts > 10 TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')";

            assert!(ParserContext::create_with_dialect(
                sql,
                &GreptimeDbDialect {},
                ParseOptions::default()
            )
            .is_err())
        }
        {
            let sql = "COPY SELECT * FROM tbl WHERE ts > 10) TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')";

            assert!(ParserContext::create_with_dialect(
                sql,
                &GreptimeDbDialect {},
                ParseOptions::default()
            )
            .is_err())
        }
        {
            let sql = "COPY (SELECT * FROM tbl WHERE ts > 10 TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')";

            assert!(ParserContext::create_with_dialect(
                sql,
                &GreptimeDbDialect {},
                ParseOptions::default()
            )
            .is_err())
        }
    }

    #[test]
    fn test_limit_is_rejected_for_database_and_query() {
        // Only table copies accept `LIMIT`; the other forms must fail to parse.
        let sqls = [
            "COPY DATABASE catalog0.schema0 TO '/a/b/c/' LIMIT 10",
            "COPY DATABASE catalog0.schema0 FROM '/a/b/c/' LIMIT 10",
            "COPY (SELECT * FROM tbl WHERE ts > 10) TO 'tbl_file.parquet' LIMIT 10",
        ];
        for sql in sqls {
            let result = ParserContext::create_with_dialect(
                sql,
                &GreptimeDbDialect {},
                ParseOptions::default(),
            );
            assert!(result.is_err(), "expected LIMIT to be rejected in: {sql}");
        }
    }

    #[test]
    fn test_copy_requires_to_or_from() {
        let sqls = [
            "COPY catalog0.schema0.tbl 'tbl_file.parquet'",
            "COPY DATABASE catalog0.schema0 '/a/b/c/'",
        ];
        for sql in sqls {
            let result = ParserContext::create_with_dialect(
                sql,
                &GreptimeDbDialect {},
                ParseOptions::default(),
            );
            assert!(result.is_err(), "expected missing TO/FROM to fail in: {sql}");
        }
    }
}