sql/parsers/
copy_parser.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use snafu::ResultExt;
16use sqlparser::keywords::Keyword;
17use sqlparser::tokenizer::Token;
18use sqlparser::tokenizer::Token::Word;
19
20use crate::error::{self, Result};
21use crate::parser::ParserContext;
22use crate::statements::OptionMap;
23use crate::statements::copy::{
24    CopyDatabase, CopyDatabaseArgument, CopyQueryTo, CopyQueryToArgument, CopyTable,
25    CopyTableArgument,
26};
27use crate::statements::statement::Statement;
28use crate::util::parse_option_string;
29
30// COPY tbl TO 'output.parquet';
31impl ParserContext<'_> {
32    pub(crate) fn parse_copy(&mut self) -> Result<Statement> {
33        let _ = self.parser.next_token();
34        let next = self.parser.peek_token();
35        let copy = if next.token == Token::LParen {
36            let copy_query = self.parse_copy_query_to()?;
37
38            crate::statements::copy::Copy::CopyQueryTo(copy_query)
39        } else if let Word(word) = next.token
40            && word.keyword == Keyword::DATABASE
41        {
42            let _ = self.parser.next_token();
43            let copy_database = self.parser_copy_database()?;
44            crate::statements::copy::Copy::CopyDatabase(copy_database)
45        } else {
46            let copy_table = self.parse_copy_table()?;
47            crate::statements::copy::Copy::CopyTable(copy_table)
48        };
49
50        Ok(Statement::Copy(copy))
51    }
52
53    fn parser_copy_database(&mut self) -> Result<CopyDatabase> {
54        let database_name = self
55            .parse_object_name()
56            .with_context(|_| error::UnexpectedSnafu {
57                expected: "a database name",
58                actual: self.peek_token_as_string(),
59            })?;
60
61        let req = if self.parser.parse_keyword(Keyword::TO) {
62            let (with, connection, location, limit) = self.parse_copy_parameters()?;
63            if limit.is_some() {
64                return error::InvalidSqlSnafu {
65                    msg: "limit is not supported",
66                }
67                .fail();
68            }
69
70            let argument = CopyDatabaseArgument {
71                database_name,
72                with,
73                connection,
74                location,
75            };
76            CopyDatabase::To(argument)
77        } else {
78            self.parser
79                .expect_keyword(Keyword::FROM)
80                .context(error::SyntaxSnafu)?;
81            let (with, connection, location, limit) = self.parse_copy_parameters()?;
82            if limit.is_some() {
83                return error::InvalidSqlSnafu {
84                    msg: "limit is not supported",
85                }
86                .fail();
87            }
88
89            let argument = CopyDatabaseArgument {
90                database_name,
91                with,
92                connection,
93                location,
94            };
95            CopyDatabase::From(argument)
96        };
97        Ok(req)
98    }
99
100    fn parse_copy_table(&mut self) -> Result<CopyTable> {
101        let raw_table_name = self
102            .parse_object_name()
103            .with_context(|_| error::UnexpectedSnafu {
104                expected: "a table name",
105                actual: self.peek_token_as_string(),
106            })?;
107        let table_name = Self::canonicalize_object_name(raw_table_name)?;
108
109        if self.parser.parse_keyword(Keyword::TO) {
110            let (with, connection, location, limit) = self.parse_copy_parameters()?;
111            Ok(CopyTable::To(CopyTableArgument {
112                table_name,
113                with,
114                connection,
115                location,
116                limit,
117            }))
118        } else {
119            self.parser
120                .expect_keyword(Keyword::FROM)
121                .context(error::SyntaxSnafu)?;
122            let (with, connection, location, limit) = self.parse_copy_parameters()?;
123            Ok(CopyTable::From(CopyTableArgument {
124                table_name,
125                with,
126                connection,
127                location,
128                limit,
129            }))
130        }
131    }
132
133    fn parse_copy_query_to(&mut self) -> Result<CopyQueryTo> {
134        self.parser
135            .expect_token(&Token::LParen)
136            .with_context(|_| error::UnexpectedSnafu {
137                expected: "'('",
138                actual: self.peek_token_as_string(),
139            })?;
140        let query = self.parse_query()?;
141        self.parser
142            .expect_token(&Token::RParen)
143            .with_context(|_| error::UnexpectedSnafu {
144                expected: "')'",
145                actual: self.peek_token_as_string(),
146            })?;
147        self.parser
148            .expect_keyword(Keyword::TO)
149            .context(error::SyntaxSnafu)?;
150        let (with, connection, location, limit) = self.parse_copy_parameters()?;
151        if limit.is_some() {
152            return error::InvalidSqlSnafu {
153                msg: "limit is not supported",
154            }
155            .fail();
156        }
157        Ok(CopyQueryTo {
158            query: Box::new(query),
159            arg: CopyQueryToArgument {
160                with,
161                connection,
162                location,
163            },
164        })
165    }
166
167    fn parse_copy_parameters(&mut self) -> Result<(OptionMap, OptionMap, String, Option<u64>)> {
168        let location =
169            self.parser
170                .parse_literal_string()
171                .with_context(|_| error::UnexpectedSnafu {
172                    expected: "a file name",
173                    actual: self.peek_token_as_string(),
174                })?;
175
176        let options = self
177            .parser
178            .parse_options(Keyword::WITH)
179            .context(error::SyntaxSnafu)?;
180
181        let with = options
182            .into_iter()
183            .map(parse_option_string)
184            .collect::<Result<Vec<_>>>()?;
185        let with = OptionMap::new(with);
186
187        let connection_options = self
188            .parser
189            .parse_options(Keyword::CONNECTION)
190            .context(error::SyntaxSnafu)?;
191
192        let connection = connection_options
193            .into_iter()
194            .map(parse_option_string)
195            .collect::<Result<Vec<_>>>()?;
196        let connection = OptionMap::new(connection);
197
198        let limit = if self.parser.parse_keyword(Keyword::LIMIT) {
199            Some(
200                self.parser
201                    .parse_literal_uint()
202                    .with_context(|_| error::UnexpectedSnafu {
203                        expected: "the number of maximum rows",
204                        actual: self.peek_token_as_string(),
205                    })?,
206            )
207        } else {
208            None
209        };
210
211        Ok((with, connection, location, limit))
212    }
213}
214
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;
    use std::collections::HashMap;

    use sqlparser::ast::{Ident, ObjectName};

    use super::*;
    use crate::dialect::GreptimeDbDialect;
    use crate::parser::ParseOptions;
    use crate::statements::copy::Copy as CopyStatement;

    /// Parses `sql` with the GreptimeDB dialect and default options, panicking on
    /// any parse error.
    fn parse_all(sql: &str) -> Vec<Statement> {
        ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
            .unwrap()
    }

    /// Parses `sql`, checks that it produced exactly one statement and that the
    /// statement is a `COPY`, then hands back the inner copy statement.
    fn parse_single_copy(sql: &str) -> CopyStatement {
        let mut statements = parse_all(sql);
        assert_eq!(1, statements.len());

        let statement = statements.remove(0);
        assert_matches!(statement, Statement::Copy(..));
        let Statement::Copy(copy) = statement else {
            unreachable!()
        };
        copy
    }

    /// Checks the table name, location and format shared by the basic
    /// `COPY <table> TO|FROM` cases.
    fn assert_basic_table_argument(argument: &CopyTableArgument) {
        assert_eq!("catalog0.schema0.tbl", argument.table_name.to_string());
        assert_eq!("tbl_file.parquet", argument.location.as_str());
        assert_eq!("parquet", argument.format().unwrap().to_lowercase());
    }

    /// Checks the `WITH` / `CONNECTION` options used by the database and query cases.
    fn assert_format_and_connection(with: &OptionMap, connection: &OptionMap) {
        assert_eq!(HashMap::from([("format", "parquet")]), with.to_str_map());
        assert_eq!(
            HashMap::from([("foo", "Bar"), ("one", "two")]),
            connection.to_str_map()
        );
    }

    /// Checks the database name used by the `COPY DATABASE` cases.
    fn assert_database_name(argument: &CopyDatabaseArgument) {
        let expected = ObjectName::from(vec![Ident::new("catalog0"), Ident::new("schema0")]);
        assert_eq!(expected, argument.database_name);
    }

    #[test]
    fn test_parse_copy_table() {
        let sqls = [
            "COPY catalog0.schema0.tbl TO 'tbl_file.parquet'",
            "COPY catalog0.schema0.tbl TO 'tbl_file.parquet' WITH (FORMAT = 'parquet')",
        ];

        for sql in sqls {
            let CopyStatement::CopyTable(CopyTable::To(copy_table)) = parse_single_copy(sql)
            else {
                unreachable!()
            };
            assert_basic_table_argument(&copy_table);
        }
    }

    #[test]
    fn test_parse_copy_table_from_basic() {
        let sqls = [
            "COPY catalog0.schema0.tbl FROM 'tbl_file.parquet'",
            "COPY catalog0.schema0.tbl FROM 'tbl_file.parquet' WITH (FORMAT = 'parquet')",
        ];

        for sql in sqls {
            let CopyStatement::CopyTable(CopyTable::From(copy_table)) = parse_single_copy(sql)
            else {
                unreachable!()
            };
            assert_basic_table_argument(&copy_table);
        }
    }

    #[test]
    fn test_parse_copy_table_from() {
        // (sql, expected PATTERN option, expected CONNECTION options)
        let cases: [(&str, Option<String>, HashMap<&str, &str>); 2] = [
            (
                "COPY catalog0.schema0.tbl FROM 'tbl_file.parquet' WITH (PATTERN = 'demo.*')",
                Some("demo.*".into()),
                HashMap::new(),
            ),
            (
                "COPY catalog0.schema0.tbl FROM 'tbl_file.parquet' WITH (PATTERN = 'demo.*') CONNECTION (FOO='Bar', ONE='two')",
                Some("demo.*".into()),
                HashMap::from([("foo", "Bar"), ("one", "two")]),
            ),
        ];

        for (sql, expected_pattern, expected_connection) in cases {
            let CopyStatement::CopyTable(CopyTable::From(copy_table)) = parse_single_copy(sql)
            else {
                unreachable!()
            };

            if let Some(expected_pattern) = expected_pattern {
                assert_eq!(copy_table.pattern().unwrap(), expected_pattern);
            }
            assert_eq!(copy_table.connection.to_str_map(), expected_connection);
        }
    }

    #[test]
    fn test_parse_copy_table_to() {
        // (sql, expected CONNECTION options)
        let cases: [(&str, HashMap<&str, &str>); 3] = [
            (
                "COPY catalog0.schema0.tbl TO 'tbl_file.parquet' ",
                HashMap::new(),
            ),
            (
                "COPY catalog0.schema0.tbl TO 'tbl_file.parquet' CONNECTION (FOO='Bar', ONE='two')",
                HashMap::from([("foo", "Bar"), ("one", "two")]),
            ),
            (
                "COPY catalog0.schema0.tbl TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')",
                HashMap::from([("foo", "Bar"), ("one", "two")]),
            ),
        ];

        for (sql, expected_connection) in cases {
            let CopyStatement::CopyTable(CopyTable::To(copy_table)) = parse_single_copy(sql)
            else {
                unreachable!()
            };
            assert_eq!(copy_table.connection.to_str_map(), expected_connection);
        }
    }

    #[test]
    fn test_copy_database_to() {
        let sql = "COPY DATABASE catalog0.schema0 TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')";
        let statement = parse_all(sql).pop().unwrap();

        let Statement::Copy(CopyStatement::CopyDatabase(CopyDatabase::To(argument))) = statement
        else {
            unreachable!()
        };

        assert_database_name(&argument);
        assert_format_and_connection(&argument.with, &argument.connection);
    }

    #[test]
    fn test_copy_database_from() {
        let sql = "COPY DATABASE catalog0.schema0 FROM '/a/b/c/' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')";
        let statement = parse_all(sql).pop().unwrap();

        let Statement::Copy(CopyStatement::CopyDatabase(CopyDatabase::From(argument))) = statement
        else {
            unreachable!()
        };

        assert_database_name(&argument);
        assert_format_and_connection(&argument.with, &argument.connection);
    }

    #[test]
    fn test_copy_query_to() {
        let sql = "COPY (SELECT * FROM tbl WHERE ts > 10) TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')";
        let statement = parse_all(sql).pop().unwrap();

        let Statement::Copy(CopyStatement::CopyQueryTo(copy_query)) = statement else {
            unreachable!()
        };

        // The embedded query must equal the same query parsed on its own.
        let expected_query = parse_all("SELECT * FROM tbl WHERE ts > 10").remove(0);
        assert_eq!(&expected_query, copy_query.query.as_ref());

        assert_format_and_connection(&copy_query.arg.with, &copy_query.arg.connection);
    }

    #[test]
    fn test_invalid_copy_query_to() {
        // Missing both parentheses, missing the opening one, missing the closing one.
        let invalid_sqls = [
            "COPY SELECT * FROM tbl WHERE ts > 10 TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')",
            "COPY SELECT * FROM tbl WHERE ts > 10) TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')",
            "COPY (SELECT * FROM tbl WHERE ts > 10 TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')",
        ];

        for sql in invalid_sqls {
            let outcome = ParserContext::create_with_dialect(
                sql,
                &GreptimeDbDialect {},
                ParseOptions::default(),
            );
            assert!(outcome.is_err());
        }
    }
}