sql/parsers/
copy_parser.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16
17use snafu::ResultExt;
18use sqlparser::keywords::Keyword;
19use sqlparser::tokenizer::Token;
20use sqlparser::tokenizer::Token::Word;
21
22use crate::error::{self, Result};
23use crate::parser::ParserContext;
24use crate::statements::copy::{
25    CopyDatabase, CopyDatabaseArgument, CopyQueryTo, CopyQueryToArgument, CopyTable,
26    CopyTableArgument,
27};
28use crate::statements::statement::Statement;
29use crate::util::parse_option_string;
30
/// Options from the `WITH (...)` clause of a `COPY` statement (option name -> value).
pub type With = HashMap<String, String>;

/// Options from the `CONNECTION (...)` clause of a `COPY` statement (option name -> value).
pub type Connection = HashMap<String, String>;
33
34// COPY tbl TO 'output.parquet';
35impl ParserContext<'_> {
36    pub(crate) fn parse_copy(&mut self) -> Result<Statement> {
37        let _ = self.parser.next_token();
38        let next = self.parser.peek_token();
39        let copy = if next.token == Token::LParen {
40            let copy_query = self.parse_copy_query_to()?;
41
42            crate::statements::copy::Copy::CopyQueryTo(copy_query)
43        } else if let Word(word) = next.token
44            && word.keyword == Keyword::DATABASE
45        {
46            let _ = self.parser.next_token();
47            let copy_database = self.parser_copy_database()?;
48            crate::statements::copy::Copy::CopyDatabase(copy_database)
49        } else {
50            let copy_table = self.parse_copy_table()?;
51            crate::statements::copy::Copy::CopyTable(copy_table)
52        };
53
54        Ok(Statement::Copy(copy))
55    }
56
57    fn parser_copy_database(&mut self) -> Result<CopyDatabase> {
58        let database_name = self
59            .parse_object_name()
60            .with_context(|_| error::UnexpectedSnafu {
61                expected: "a database name",
62                actual: self.peek_token_as_string(),
63            })?;
64
65        let req = if self.parser.parse_keyword(Keyword::TO) {
66            let (with, connection, location, limit) = self.parse_copy_parameters()?;
67            if limit.is_some() {
68                return error::InvalidSqlSnafu {
69                    msg: "limit is not supported",
70                }
71                .fail();
72            }
73
74            let argument = CopyDatabaseArgument {
75                database_name,
76                with: with.into(),
77                connection: connection.into(),
78                location,
79            };
80            CopyDatabase::To(argument)
81        } else {
82            self.parser
83                .expect_keyword(Keyword::FROM)
84                .context(error::SyntaxSnafu)?;
85            let (with, connection, location, limit) = self.parse_copy_parameters()?;
86            if limit.is_some() {
87                return error::InvalidSqlSnafu {
88                    msg: "limit is not supported",
89                }
90                .fail();
91            }
92
93            let argument = CopyDatabaseArgument {
94                database_name,
95                with: with.into(),
96                connection: connection.into(),
97                location,
98            };
99            CopyDatabase::From(argument)
100        };
101        Ok(req)
102    }
103
104    fn parse_copy_table(&mut self) -> Result<CopyTable> {
105        let raw_table_name = self
106            .parse_object_name()
107            .with_context(|_| error::UnexpectedSnafu {
108                expected: "a table name",
109                actual: self.peek_token_as_string(),
110            })?;
111        let table_name = Self::canonicalize_object_name(raw_table_name);
112
113        if self.parser.parse_keyword(Keyword::TO) {
114            let (with, connection, location, limit) = self.parse_copy_parameters()?;
115            Ok(CopyTable::To(CopyTableArgument {
116                table_name,
117                with: with.into(),
118                connection: connection.into(),
119                location,
120                limit,
121            }))
122        } else {
123            self.parser
124                .expect_keyword(Keyword::FROM)
125                .context(error::SyntaxSnafu)?;
126            let (with, connection, location, limit) = self.parse_copy_parameters()?;
127            Ok(CopyTable::From(CopyTableArgument {
128                table_name,
129                with: with.into(),
130                connection: connection.into(),
131                location,
132                limit,
133            }))
134        }
135    }
136
137    fn parse_copy_query_to(&mut self) -> Result<CopyQueryTo> {
138        self.parser
139            .expect_token(&Token::LParen)
140            .with_context(|_| error::UnexpectedSnafu {
141                expected: "'('",
142                actual: self.peek_token_as_string(),
143            })?;
144        let query = self.parse_query()?;
145        self.parser
146            .expect_token(&Token::RParen)
147            .with_context(|_| error::UnexpectedSnafu {
148                expected: "')'",
149                actual: self.peek_token_as_string(),
150            })?;
151        self.parser
152            .expect_keyword(Keyword::TO)
153            .context(error::SyntaxSnafu)?;
154        let (with, connection, location, limit) = self.parse_copy_parameters()?;
155        if limit.is_some() {
156            return error::InvalidSqlSnafu {
157                msg: "limit is not supported",
158            }
159            .fail();
160        }
161        Ok(CopyQueryTo {
162            query: Box::new(query),
163            arg: CopyQueryToArgument {
164                with: with.into(),
165                connection: connection.into(),
166                location,
167            },
168        })
169    }
170
171    fn parse_copy_parameters(&mut self) -> Result<(With, Connection, String, Option<u64>)> {
172        let location =
173            self.parser
174                .parse_literal_string()
175                .with_context(|_| error::UnexpectedSnafu {
176                    expected: "a file name",
177                    actual: self.peek_token_as_string(),
178                })?;
179
180        let options = self
181            .parser
182            .parse_options(Keyword::WITH)
183            .context(error::SyntaxSnafu)?;
184
185        let with = options
186            .into_iter()
187            .map(parse_option_string)
188            .collect::<Result<With>>()?;
189
190        let connection_options = self
191            .parser
192            .parse_options(Keyword::CONNECTION)
193            .context(error::SyntaxSnafu)?;
194
195        let connection = connection_options
196            .into_iter()
197            .map(parse_option_string)
198            .collect::<Result<Connection>>()?;
199
200        let limit = if self.parser.parse_keyword(Keyword::LIMIT) {
201            Some(
202                self.parser
203                    .parse_literal_uint()
204                    .with_context(|_| error::UnexpectedSnafu {
205                        expected: "the number of maximum rows",
206                        actual: self.peek_token_as_string(),
207                    })?,
208            )
209        } else {
210            None
211        };
212
213        Ok((with, connection, location, limit))
214    }
215}
216
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;
    use std::collections::HashMap;

    use sqlparser::ast::{Ident, ObjectName};

    use super::*;
    use crate::dialect::GreptimeDbDialect;
    use crate::parser::ParseOptions;
    use crate::statements::statement::Statement::Copy;

    #[test]
    fn test_parse_copy_table() {
        let sql0 = "COPY catalog0.schema0.tbl TO 'tbl_file.parquet'";
        let sql1 = "COPY catalog0.schema0.tbl TO 'tbl_file.parquet' WITH (FORMAT = 'parquet')";
        let result0 = ParserContext::create_with_dialect(
            sql0,
            &GreptimeDbDialect {},
            ParseOptions::default(),
        )
        .unwrap();
        let result1 = ParserContext::create_with_dialect(
            sql1,
            &GreptimeDbDialect {},
            ParseOptions::default(),
        )
        .unwrap();

        for mut result in [result0, result1] {
            assert_eq!(1, result.len());

            let statement = result.remove(0);
            assert_matches!(statement, Statement::Copy { .. });
            match statement {
                Copy(copy) => {
                    let crate::statements::copy::Copy::CopyTable(CopyTable::To(copy_table)) = copy
                    else {
                        unreachable!()
                    };
                    let (catalog, schema, table) =
                        if let [catalog, schema, table] = &copy_table.table_name.0[..] {
                            (
                                catalog.value.clone(),
                                schema.value.clone(),
                                table.value.clone(),
                            )
                        } else {
                            unreachable!()
                        };

                    assert_eq!("catalog0", catalog);
                    assert_eq!("schema0", schema);
                    assert_eq!("tbl", table);

                    let file_name = &copy_table.location;
                    assert_eq!("tbl_file.parquet", file_name);

                    let format = copy_table.format().unwrap();
                    assert_eq!("parquet", format.to_lowercase());
                }
                _ => unreachable!(),
            }
        }
    }

    #[test]
    fn test_parse_copy_table_from_basic() {
        let results = [
            "COPY catalog0.schema0.tbl FROM 'tbl_file.parquet'",
            "COPY catalog0.schema0.tbl FROM 'tbl_file.parquet' WITH (FORMAT = 'parquet')",
        ]
        .iter()
        .map(|sql| {
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .unwrap()
        })
        .collect::<Vec<_>>();

        for mut result in results {
            assert_eq!(1, result.len());

            let statement = result.remove(0);
            assert_matches!(statement, Statement::Copy { .. });
            match statement {
                Statement::Copy(crate::statements::copy::Copy::CopyTable(CopyTable::From(
                    copy_table,
                ))) => {
                    let (catalog, schema, table) =
                        if let [catalog, schema, table] = &copy_table.table_name.0[..] {
                            (
                                catalog.value.clone(),
                                schema.value.clone(),
                                table.value.clone(),
                            )
                        } else {
                            unreachable!()
                        };

                    assert_eq!("catalog0", catalog);
                    assert_eq!("schema0", schema);
                    assert_eq!("tbl", table);

                    let file_name = &copy_table.location;
                    assert_eq!("tbl_file.parquet", file_name);

                    let format = copy_table.format().unwrap();
                    assert_eq!("parquet", format.to_lowercase());
                }
                _ => unreachable!(),
            }
        }
    }

    #[test]
    fn test_parse_copy_table_from() {
        struct Test<'a> {
            sql: &'a str,
            expected_pattern: Option<String>,
            expected_connection: HashMap<String, String>,
        }

        let tests = [
            Test {
                sql: "COPY catalog0.schema0.tbl FROM 'tbl_file.parquet' WITH (PATTERN = 'demo.*')",
                expected_pattern: Some("demo.*".into()),
                expected_connection: HashMap::new(),
            },
            Test {
                sql: "COPY catalog0.schema0.tbl FROM 'tbl_file.parquet' WITH (PATTERN = 'demo.*') CONNECTION (FOO='Bar', ONE='two')",
                expected_pattern: Some("demo.*".into()),
                expected_connection: [("foo","Bar"),("one","two")].into_iter().map(|(k,v)|{(k.to_string(),v.to_string())}).collect()
            },
        ];

        for test in tests {
            let mut result = ParserContext::create_with_dialect(
                test.sql,
                &GreptimeDbDialect {},
                ParseOptions::default(),
            )
            .unwrap();
            assert_eq!(1, result.len());

            let statement = result.remove(0);
            assert_matches!(statement, Statement::Copy { .. });
            match statement {
                Statement::Copy(crate::statements::copy::Copy::CopyTable(CopyTable::From(
                    copy_table,
                ))) => {
                    if let Some(expected_pattern) = test.expected_pattern {
                        assert_eq!(copy_table.pattern().unwrap(), expected_pattern);
                    }
                    assert_eq!(
                        copy_table.connection.clone(),
                        test.expected_connection.into()
                    );
                }
                _ => unreachable!(),
            }
        }
    }

    #[test]
    fn test_parse_copy_table_to() {
        struct Test<'a> {
            sql: &'a str,
            expected_connection: HashMap<String, String>,
        }

        let tests = [
            Test {
                sql: "COPY catalog0.schema0.tbl TO 'tbl_file.parquet' ",
                expected_connection: HashMap::new(),
            },
            Test {
                sql: "COPY catalog0.schema0.tbl TO 'tbl_file.parquet' CONNECTION (FOO='Bar', ONE='two')",
                expected_connection: [("foo","Bar"),("one","two")].into_iter().map(|(k,v)|{(k.to_string(),v.to_string())}).collect()
            },
            Test {
                sql:"COPY catalog0.schema0.tbl TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')",
                expected_connection: [("foo","Bar"),("one","two")].into_iter().map(|(k,v)|{(k.to_string(),v.to_string())}).collect()
            },
        ];

        for test in tests {
            let mut result = ParserContext::create_with_dialect(
                test.sql,
                &GreptimeDbDialect {},
                ParseOptions::default(),
            )
            .unwrap();
            assert_eq!(1, result.len());

            let statement = result.remove(0);
            assert_matches!(statement, Statement::Copy { .. });
            match statement {
                Statement::Copy(crate::statements::copy::Copy::CopyTable(CopyTable::To(
                    copy_table,
                ))) => {
                    assert_eq!(
                        copy_table.connection.clone(),
                        test.expected_connection.into()
                    );
                }
                _ => unreachable!(),
            }
        }
    }

    /// `LIMIT` is optional for table copies (both directions) and is stored as-is.
    #[test]
    fn test_parse_copy_table_limit() {
        let cases: [(&str, Option<u64>); 3] = [
            ("COPY catalog0.schema0.tbl TO 'tbl_file.parquet'", None),
            (
                "COPY catalog0.schema0.tbl TO 'tbl_file.parquet' LIMIT 3",
                Some(3),
            ),
            (
                "COPY catalog0.schema0.tbl FROM 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar') LIMIT 10",
                Some(10),
            ),
        ];

        for (sql, expected_limit) in cases {
            let stmt = ParserContext::create_with_dialect(
                sql,
                &GreptimeDbDialect {},
                ParseOptions::default(),
            )
            .unwrap()
            .pop()
            .unwrap();

            let limit = match stmt {
                Copy(crate::statements::copy::Copy::CopyTable(
                    CopyTable::To(arg) | CopyTable::From(arg),
                )) => arg.limit,
                _ => unreachable!(),
            };
            assert_eq!(expected_limit, limit, "sql: {sql}");
        }
    }

    /// `LIMIT` is not supported for database copies or query exports.
    #[test]
    fn test_limit_rejected_for_copy_database_and_query() {
        let sqls = [
            "COPY DATABASE catalog0.schema0 TO '/a/b/c/' LIMIT 10",
            "COPY DATABASE catalog0.schema0 FROM '/a/b/c/' WITH (FORMAT = 'parquet') LIMIT 10",
            "COPY (SELECT * FROM tbl WHERE ts > 10) TO 'tbl_file.parquet' LIMIT 10",
        ];

        for sql in sqls {
            assert!(
                ParserContext::create_with_dialect(
                    sql,
                    &GreptimeDbDialect {},
                    ParseOptions::default()
                )
                .is_err(),
                "expected LIMIT to be rejected: {sql}"
            );
        }
    }

    /// Table and database copies must name a direction with `TO` or `FROM`.
    #[test]
    fn test_copy_requires_to_or_from() {
        let sqls = [
            "COPY catalog0.schema0.tbl 'tbl_file.parquet'",
            "COPY DATABASE catalog0.schema0 '/a/b/c/'",
        ];

        for sql in sqls {
            assert!(
                ParserContext::create_with_dialect(
                    sql,
                    &GreptimeDbDialect {},
                    ParseOptions::default()
                )
                .is_err(),
                "expected missing TO/FROM to be rejected: {sql}"
            );
        }
    }

    #[test]
    fn test_copy_database_to() {
        let sql = "COPY DATABASE catalog0.schema0 TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')";
        let stmt =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .unwrap()
                .pop()
                .unwrap();

        let Copy(crate::statements::copy::Copy::CopyDatabase(stmt)) = stmt else {
            unreachable!()
        };

        let CopyDatabase::To(stmt) = stmt else {
            unreachable!()
        };

        assert_eq!(
            ObjectName(vec![Ident::new("catalog0"), Ident::new("schema0")]),
            stmt.database_name
        );
        assert_eq!(
            [("format", "parquet")]
                .into_iter()
                .collect::<HashMap<_, _>>(),
            stmt.with.to_str_map()
        );

        assert_eq!(
            [("foo", "Bar"), ("one", "two")]
                .into_iter()
                .collect::<HashMap<_, _>>(),
            stmt.connection.to_str_map()
        );
    }

    #[test]
    fn test_copy_database_from() {
        let sql = "COPY DATABASE catalog0.schema0 FROM '/a/b/c/' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')";
        let stmt =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .unwrap()
                .pop()
                .unwrap();

        let Copy(crate::statements::copy::Copy::CopyDatabase(stmt)) = stmt else {
            unreachable!()
        };

        let CopyDatabase::From(stmt) = stmt else {
            unreachable!()
        };

        assert_eq!(
            ObjectName(vec![Ident::new("catalog0"), Ident::new("schema0")]),
            stmt.database_name
        );
        assert_eq!(
            [("format", "parquet")]
                .into_iter()
                .collect::<HashMap<_, _>>(),
            stmt.with.to_str_map()
        );

        assert_eq!(
            [("foo", "Bar"), ("one", "two")]
                .into_iter()
                .collect::<HashMap<_, _>>(),
            stmt.connection.to_str_map()
        );
    }

    #[test]
    fn test_copy_query_to() {
        let sql = "COPY (SELECT * FROM tbl WHERE ts > 10) TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')";
        let stmt =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .unwrap()
                .pop()
                .unwrap();

        let Copy(crate::statements::copy::Copy::CopyQueryTo(stmt)) = stmt else {
            unreachable!()
        };

        let query = ParserContext::create_with_dialect(
            "SELECT * FROM tbl WHERE ts > 10",
            &GreptimeDbDialect {},
            ParseOptions::default(),
        )
        .unwrap()
        .remove(0);

        assert_eq!(&query, stmt.query.as_ref());
        assert_eq!(
            [("format", "parquet")]
                .into_iter()
                .collect::<HashMap<_, _>>(),
            stmt.arg.with.to_str_map()
        );

        assert_eq!(
            [("foo", "Bar"), ("one", "two")]
                .into_iter()
                .collect::<HashMap<_, _>>(),
            stmt.arg.connection.to_str_map()
        );
    }

    #[test]
    fn test_invalid_copy_query_to() {
        {
            let sql = "COPY SELECT * FROM tbl WHERE ts > 10 TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')";

            assert!(ParserContext::create_with_dialect(
                sql,
                &GreptimeDbDialect {},
                ParseOptions::default()
            )
            .is_err())
        }
        {
            let sql = "COPY SELECT * FROM tbl WHERE ts > 10) TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')";

            assert!(ParserContext::create_with_dialect(
                sql,
                &GreptimeDbDialect {},
                ParseOptions::default()
            )
            .is_err())
        }
        {
            let sql = "COPY (SELECT * FROM tbl WHERE ts > 10 TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')";

            assert!(ParserContext::create_with_dialect(
                sql,
                &GreptimeDbDialect {},
                ParseOptions::default()
            )
            .is_err())
        }
    }
}