1use std::collections::HashMap;
16
17use snafu::ResultExt;
18use sqlparser::keywords::Keyword;
19use sqlparser::tokenizer::Token;
20use sqlparser::tokenizer::Token::Word;
21
22use crate::error::{self, Result};
23use crate::parser::ParserContext;
24use crate::statements::copy::{
25 CopyDatabase, CopyDatabaseArgument, CopyQueryTo, CopyQueryToArgument, CopyTable,
26 CopyTableArgument,
27};
28use crate::statements::statement::Statement;
29use crate::util::parse_option_string;
30
/// Key-value options collected from the `WITH (...)` clause of a `COPY` statement.
pub type With = HashMap<String, String>;
/// Key-value options collected from the `CONNECTION (...)` clause of a `COPY` statement.
pub type Connection = HashMap<String, String>;
33
34impl ParserContext<'_> {
36 pub(crate) fn parse_copy(&mut self) -> Result<Statement> {
37 let _ = self.parser.next_token();
38 let next = self.parser.peek_token();
39 let copy = if next.token == Token::LParen {
40 let copy_query = self.parse_copy_query_to()?;
41
42 crate::statements::copy::Copy::CopyQueryTo(copy_query)
43 } else if let Word(word) = next.token
44 && word.keyword == Keyword::DATABASE
45 {
46 let _ = self.parser.next_token();
47 let copy_database = self.parser_copy_database()?;
48 crate::statements::copy::Copy::CopyDatabase(copy_database)
49 } else {
50 let copy_table = self.parse_copy_table()?;
51 crate::statements::copy::Copy::CopyTable(copy_table)
52 };
53
54 Ok(Statement::Copy(copy))
55 }
56
57 fn parser_copy_database(&mut self) -> Result<CopyDatabase> {
58 let database_name = self
59 .parse_object_name()
60 .with_context(|_| error::UnexpectedSnafu {
61 expected: "a database name",
62 actual: self.peek_token_as_string(),
63 })?;
64
65 let req = if self.parser.parse_keyword(Keyword::TO) {
66 let (with, connection, location, limit) = self.parse_copy_parameters()?;
67 if limit.is_some() {
68 return error::InvalidSqlSnafu {
69 msg: "limit is not supported",
70 }
71 .fail();
72 }
73
74 let argument = CopyDatabaseArgument {
75 database_name,
76 with: with.into(),
77 connection: connection.into(),
78 location,
79 };
80 CopyDatabase::To(argument)
81 } else {
82 self.parser
83 .expect_keyword(Keyword::FROM)
84 .context(error::SyntaxSnafu)?;
85 let (with, connection, location, limit) = self.parse_copy_parameters()?;
86 if limit.is_some() {
87 return error::InvalidSqlSnafu {
88 msg: "limit is not supported",
89 }
90 .fail();
91 }
92
93 let argument = CopyDatabaseArgument {
94 database_name,
95 with: with.into(),
96 connection: connection.into(),
97 location,
98 };
99 CopyDatabase::From(argument)
100 };
101 Ok(req)
102 }
103
104 fn parse_copy_table(&mut self) -> Result<CopyTable> {
105 let raw_table_name = self
106 .parse_object_name()
107 .with_context(|_| error::UnexpectedSnafu {
108 expected: "a table name",
109 actual: self.peek_token_as_string(),
110 })?;
111 let table_name = Self::canonicalize_object_name(raw_table_name);
112
113 if self.parser.parse_keyword(Keyword::TO) {
114 let (with, connection, location, limit) = self.parse_copy_parameters()?;
115 Ok(CopyTable::To(CopyTableArgument {
116 table_name,
117 with: with.into(),
118 connection: connection.into(),
119 location,
120 limit,
121 }))
122 } else {
123 self.parser
124 .expect_keyword(Keyword::FROM)
125 .context(error::SyntaxSnafu)?;
126 let (with, connection, location, limit) = self.parse_copy_parameters()?;
127 Ok(CopyTable::From(CopyTableArgument {
128 table_name,
129 with: with.into(),
130 connection: connection.into(),
131 location,
132 limit,
133 }))
134 }
135 }
136
137 fn parse_copy_query_to(&mut self) -> Result<CopyQueryTo> {
138 self.parser
139 .expect_token(&Token::LParen)
140 .with_context(|_| error::UnexpectedSnafu {
141 expected: "'('",
142 actual: self.peek_token_as_string(),
143 })?;
144 let query = self.parse_query()?;
145 self.parser
146 .expect_token(&Token::RParen)
147 .with_context(|_| error::UnexpectedSnafu {
148 expected: "')'",
149 actual: self.peek_token_as_string(),
150 })?;
151 self.parser
152 .expect_keyword(Keyword::TO)
153 .context(error::SyntaxSnafu)?;
154 let (with, connection, location, limit) = self.parse_copy_parameters()?;
155 if limit.is_some() {
156 return error::InvalidSqlSnafu {
157 msg: "limit is not supported",
158 }
159 .fail();
160 }
161 Ok(CopyQueryTo {
162 query: Box::new(query),
163 arg: CopyQueryToArgument {
164 with: with.into(),
165 connection: connection.into(),
166 location,
167 },
168 })
169 }
170
171 fn parse_copy_parameters(&mut self) -> Result<(With, Connection, String, Option<u64>)> {
172 let location =
173 self.parser
174 .parse_literal_string()
175 .with_context(|_| error::UnexpectedSnafu {
176 expected: "a file name",
177 actual: self.peek_token_as_string(),
178 })?;
179
180 let options = self
181 .parser
182 .parse_options(Keyword::WITH)
183 .context(error::SyntaxSnafu)?;
184
185 let with = options
186 .into_iter()
187 .map(parse_option_string)
188 .collect::<Result<With>>()?;
189
190 let connection_options = self
191 .parser
192 .parse_options(Keyword::CONNECTION)
193 .context(error::SyntaxSnafu)?;
194
195 let connection = connection_options
196 .into_iter()
197 .map(parse_option_string)
198 .collect::<Result<Connection>>()?;
199
200 let limit = if self.parser.parse_keyword(Keyword::LIMIT) {
201 Some(
202 self.parser
203 .parse_literal_uint()
204 .with_context(|_| error::UnexpectedSnafu {
205 expected: "the number of maximum rows",
206 actual: self.peek_token_as_string(),
207 })?,
208 )
209 } else {
210 None
211 };
212
213 Ok((with, connection, location, limit))
214 }
215}
216
217#[cfg(test)]
218mod tests {
219 use std::assert_matches::assert_matches;
220 use std::collections::HashMap;
221
222 use sqlparser::ast::{Ident, ObjectName};
223
224 use super::*;
225 use crate::dialect::GreptimeDbDialect;
226 use crate::parser::ParseOptions;
227 use crate::statements::statement::Statement::Copy;
228
/// `COPY <table> TO <file>` parses both with and without a `WITH (FORMAT = ...)` clause.
#[test]
fn test_parse_copy_table() {
    let sql0 = "COPY catalog0.schema0.tbl TO 'tbl_file.parquet'";
    let sql1 = "COPY catalog0.schema0.tbl TO 'tbl_file.parquet' WITH (FORMAT = 'parquet')";
    let result0 = ParserContext::create_with_dialect(
        sql0,
        &GreptimeDbDialect {},
        ParseOptions::default(),
    )
    .unwrap();
    let result1 = ParserContext::create_with_dialect(
        sql1,
        &GreptimeDbDialect {},
        ParseOptions::default(),
    )
    .unwrap();

    for mut result in [result0, result1] {
        assert_eq!(1, result.len());

        let statement = result.remove(0);
        assert_matches!(statement, Statement::Copy { .. });
        match statement {
            Copy(copy) => {
                let crate::statements::copy::Copy::CopyTable(CopyTable::To(copy_table)) = copy
                else {
                    unreachable!()
                };
                let table = copy_table.table_name.to_string();
                assert_eq!("catalog0.schema0.tbl", table);

                // Borrow the location; the corrupted source had `&copy_table`
                // mis-decoded into a copyright sign.
                let file_name = &copy_table.location;
                assert_eq!("tbl_file.parquet", file_name);

                // `sql0` has no FORMAT option, yet "parquet" is still expected.
                // NOTE(review): presumably `format()` falls back to a parquet
                // default — confirm in `CopyTableArgument`.
                let format = copy_table.format().unwrap();
                assert_eq!("parquet", format.to_lowercase());
            }
            _ => unreachable!(),
        }
    }
}
270
/// `COPY <table> FROM <file>` parses both with and without a `WITH (FORMAT = ...)` clause.
#[test]
fn test_parse_copy_table_from_basic() {
    let results = [
        "COPY catalog0.schema0.tbl FROM 'tbl_file.parquet'",
        "COPY catalog0.schema0.tbl FROM 'tbl_file.parquet' WITH (FORMAT = 'parquet')",
    ]
    .iter()
    .map(|sql| {
        ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
            .unwrap()
    })
    .collect::<Vec<_>>();

    for mut result in results {
        assert_eq!(1, result.len());

        let statement = result.remove(0);
        assert_matches!(statement, Statement::Copy { .. });
        match statement {
            Statement::Copy(crate::statements::copy::Copy::CopyTable(CopyTable::From(
                copy_table,
            ))) => {
                let table = copy_table.table_name.to_string();
                assert_eq!("catalog0.schema0.tbl", table);

                // Borrow the location; the corrupted source had `&copy_table`
                // mis-decoded into a copyright sign.
                let file_name = &copy_table.location;
                assert_eq!("tbl_file.parquet", file_name);

                // The first statement has no FORMAT option, yet "parquet" is expected.
                // NOTE(review): presumably `format()` falls back to a parquet
                // default — confirm in `CopyTableArgument`.
                let format = copy_table.format().unwrap();
                assert_eq!("parquet", format.to_lowercase());
            }
            _ => unreachable!(),
        }
    }
}
306
/// `COPY <table> FROM` exposes the `PATTERN` option and the `CONNECTION` pairs.
///
/// The expectations use lower-cased connection keys (`foo`, `one`) even though the
/// SQL spells them in upper case.
#[test]
fn test_parse_copy_table_from() {
    // Turns borrowed pairs into the owned map used for comparison.
    fn owned_map(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|(k, v)| (k.to_string(), v.to_string()))
            .collect()
    }

    // (sql, expected PATTERN option, expected CONNECTION options)
    let cases: [(&str, Option<String>, HashMap<String, String>); 2] = [
        (
            "COPY catalog0.schema0.tbl FROM 'tbl_file.parquet' WITH (PATTERN = 'demo.*')",
            Some("demo.*".into()),
            HashMap::new(),
        ),
        (
            "COPY catalog0.schema0.tbl FROM 'tbl_file.parquet' WITH (PATTERN = 'demo.*') CONNECTION (FOO='Bar', ONE='two')",
            Some("demo.*".into()),
            owned_map(&[("foo", "Bar"), ("one", "two")]),
        ),
    ];

    for (sql, expected_pattern, expected_connection) in cases {
        let mut statements = ParserContext::create_with_dialect(
            sql,
            &GreptimeDbDialect {},
            ParseOptions::default(),
        )
        .unwrap();
        assert_eq!(1, statements.len());

        let statement = statements.remove(0);
        assert_matches!(statement, Statement::Copy { .. });
        let Statement::Copy(crate::statements::copy::Copy::CopyTable(CopyTable::From(
            copy_table,
        ))) = statement
        else {
            unreachable!()
        };

        if let Some(expected_pattern) = expected_pattern {
            assert_eq!(copy_table.pattern().unwrap(), expected_pattern);
        }
        assert_eq!(copy_table.connection.clone(), expected_connection.into());
    }
}
355
/// `COPY <table> TO` carries the `CONNECTION` pairs, with or without a `WITH` clause.
#[test]
fn test_parse_copy_table_to() {
    // Turns borrowed pairs into the owned map used for comparison.
    fn owned_map(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|(k, v)| (k.to_string(), v.to_string()))
            .collect()
    }

    // (sql, expected CONNECTION options)
    let cases: [(&str, HashMap<String, String>); 3] = [
        (
            "COPY catalog0.schema0.tbl TO 'tbl_file.parquet' ",
            HashMap::new(),
        ),
        (
            "COPY catalog0.schema0.tbl TO 'tbl_file.parquet' CONNECTION (FOO='Bar', ONE='two')",
            owned_map(&[("foo", "Bar"), ("one", "two")]),
        ),
        (
            "COPY catalog0.schema0.tbl TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')",
            owned_map(&[("foo", "Bar"), ("one", "two")]),
        ),
    ];

    for (sql, expected_connection) in cases {
        let mut statements = ParserContext::create_with_dialect(
            sql,
            &GreptimeDbDialect {},
            ParseOptions::default(),
        )
        .unwrap();
        assert_eq!(1, statements.len());

        let statement = statements.remove(0);
        assert_matches!(statement, Statement::Copy { .. });
        let Statement::Copy(crate::statements::copy::Copy::CopyTable(CopyTable::To(
            copy_table,
        ))) = statement
        else {
            unreachable!()
        };

        assert_eq!(copy_table.connection.clone(), expected_connection.into());
    }
}
402
/// `COPY DATABASE ... TO` keeps the database name and both option lists.
#[test]
fn test_copy_database_to() {
    let sql = "COPY DATABASE catalog0.schema0 TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')";
    let mut statements =
        ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
            .unwrap();
    let parsed = statements.pop().unwrap();

    // `Copy` is `Statement::Copy` here; peel the statement down to the `TO` argument.
    let Copy(crate::statements::copy::Copy::CopyDatabase(copy_database)) = parsed else {
        unreachable!()
    };
    let CopyDatabase::To(stmt) = copy_database else {
        unreachable!()
    };

    let expected_name = ObjectName::from(vec![Ident::new("catalog0"), Ident::new("schema0")]);
    assert_eq!(expected_name, stmt.database_name);

    let expected_with = HashMap::from([("format", "parquet")]);
    assert_eq!(expected_with, stmt.with.to_str_map());

    let expected_connection = HashMap::from([("foo", "Bar"), ("one", "two")]);
    assert_eq!(expected_connection, stmt.connection.to_str_map());
}
438
/// `COPY DATABASE ... FROM` keeps the database name and both option lists.
#[test]
fn test_copy_database_from() {
    let sql = "COPY DATABASE catalog0.schema0 FROM '/a/b/c/' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')";
    let mut statements =
        ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
            .unwrap();
    let parsed = statements.pop().unwrap();

    // `Copy` is `Statement::Copy` here; peel the statement down to the `FROM` argument.
    let Copy(crate::statements::copy::Copy::CopyDatabase(copy_database)) = parsed else {
        unreachable!()
    };
    let CopyDatabase::From(stmt) = copy_database else {
        unreachable!()
    };

    let expected_name = ObjectName::from(vec![Ident::new("catalog0"), Ident::new("schema0")]);
    assert_eq!(expected_name, stmt.database_name);

    let expected_with = HashMap::from([("format", "parquet")]);
    assert_eq!(expected_with, stmt.with.to_str_map());

    let expected_connection = HashMap::from([("foo", "Bar"), ("one", "two")]);
    assert_eq!(expected_connection, stmt.connection.to_str_map());
}
474
/// `COPY (<query>) TO` stores the inner query as parsed on its own, plus both option lists.
#[test]
fn test_copy_query_to() {
    let sql = "COPY (SELECT * FROM tbl WHERE ts > 10) TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')";
    let mut statements =
        ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
            .unwrap();
    let parsed = statements.pop().unwrap();

    // `Copy` is `Statement::Copy` here.
    let Copy(crate::statements::copy::Copy::CopyQueryTo(stmt)) = parsed else {
        unreachable!()
    };

    // Parsing the bare query separately gives the statement the COPY must embed.
    let mut query_statements = ParserContext::create_with_dialect(
        "SELECT * FROM tbl WHERE ts > 10",
        &GreptimeDbDialect {},
        ParseOptions::default(),
    )
    .unwrap();
    let query = query_statements.remove(0);
    assert_eq!(&query, stmt.query.as_ref());

    let expected_with = HashMap::from([("format", "parquet")]);
    assert_eq!(expected_with, stmt.arg.with.to_str_map());

    let expected_connection = HashMap::from([("foo", "Bar"), ("one", "two")]);
    assert_eq!(expected_connection, stmt.arg.connection.to_str_map());
}
511
/// A query-form `COPY` with missing or unbalanced parentheses must be rejected.
#[test]
fn test_invalid_copy_query_to() {
    let invalid_sqls = [
        // No parentheses around the query at all.
        "COPY SELECT * FROM tbl WHERE ts > 10 TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')",
        // Opening parenthesis missing.
        "COPY SELECT * FROM tbl WHERE ts > 10) TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')",
        // Closing parenthesis missing.
        "COPY (SELECT * FROM tbl WHERE ts > 10 TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')",
    ];

    for sql in invalid_sqls {
        let parsed =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default());
        assert!(parsed.is_err());
    }
}
545}