sql/parsers/
copy_parser.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16
17use snafu::ResultExt;
18use sqlparser::keywords::Keyword;
19use sqlparser::tokenizer::Token;
20use sqlparser::tokenizer::Token::Word;
21
22use crate::error::{self, Result};
23use crate::parser::ParserContext;
24use crate::statements::copy::{
25    CopyDatabase, CopyDatabaseArgument, CopyQueryTo, CopyQueryToArgument, CopyTable,
26    CopyTableArgument,
27};
28use crate::statements::statement::Statement;
29use crate::util::parse_option_string;
30
/// String key/value options collected from the `WITH (...)` clause of a `COPY` statement.
pub type With = HashMap<String, String>;
/// String key/value options collected from the `CONNECTION (...)` clause of a `COPY` statement.
pub type Connection = HashMap<String, String>;
33
34// COPY tbl TO 'output.parquet';
35impl ParserContext<'_> {
36    pub(crate) fn parse_copy(&mut self) -> Result<Statement> {
37        let _ = self.parser.next_token();
38        let next = self.parser.peek_token();
39        let copy = if next.token == Token::LParen {
40            let copy_query = self.parse_copy_query_to()?;
41
42            crate::statements::copy::Copy::CopyQueryTo(copy_query)
43        } else if let Word(word) = next.token
44            && word.keyword == Keyword::DATABASE
45        {
46            let _ = self.parser.next_token();
47            let copy_database = self.parser_copy_database()?;
48            crate::statements::copy::Copy::CopyDatabase(copy_database)
49        } else {
50            let copy_table = self.parse_copy_table()?;
51            crate::statements::copy::Copy::CopyTable(copy_table)
52        };
53
54        Ok(Statement::Copy(copy))
55    }
56
57    fn parser_copy_database(&mut self) -> Result<CopyDatabase> {
58        let database_name = self
59            .parse_object_name()
60            .with_context(|_| error::UnexpectedSnafu {
61                expected: "a database name",
62                actual: self.peek_token_as_string(),
63            })?;
64
65        let req = if self.parser.parse_keyword(Keyword::TO) {
66            let (with, connection, location, limit) = self.parse_copy_parameters()?;
67            if limit.is_some() {
68                return error::InvalidSqlSnafu {
69                    msg: "limit is not supported",
70                }
71                .fail();
72            }
73
74            let argument = CopyDatabaseArgument {
75                database_name,
76                with: with.into(),
77                connection: connection.into(),
78                location,
79            };
80            CopyDatabase::To(argument)
81        } else {
82            self.parser
83                .expect_keyword(Keyword::FROM)
84                .context(error::SyntaxSnafu)?;
85            let (with, connection, location, limit) = self.parse_copy_parameters()?;
86            if limit.is_some() {
87                return error::InvalidSqlSnafu {
88                    msg: "limit is not supported",
89                }
90                .fail();
91            }
92
93            let argument = CopyDatabaseArgument {
94                database_name,
95                with: with.into(),
96                connection: connection.into(),
97                location,
98            };
99            CopyDatabase::From(argument)
100        };
101        Ok(req)
102    }
103
104    fn parse_copy_table(&mut self) -> Result<CopyTable> {
105        let raw_table_name = self
106            .parse_object_name()
107            .with_context(|_| error::UnexpectedSnafu {
108                expected: "a table name",
109                actual: self.peek_token_as_string(),
110            })?;
111        let table_name = Self::canonicalize_object_name(raw_table_name);
112
113        if self.parser.parse_keyword(Keyword::TO) {
114            let (with, connection, location, limit) = self.parse_copy_parameters()?;
115            Ok(CopyTable::To(CopyTableArgument {
116                table_name,
117                with: with.into(),
118                connection: connection.into(),
119                location,
120                limit,
121            }))
122        } else {
123            self.parser
124                .expect_keyword(Keyword::FROM)
125                .context(error::SyntaxSnafu)?;
126            let (with, connection, location, limit) = self.parse_copy_parameters()?;
127            Ok(CopyTable::From(CopyTableArgument {
128                table_name,
129                with: with.into(),
130                connection: connection.into(),
131                location,
132                limit,
133            }))
134        }
135    }
136
137    fn parse_copy_query_to(&mut self) -> Result<CopyQueryTo> {
138        self.parser
139            .expect_token(&Token::LParen)
140            .with_context(|_| error::UnexpectedSnafu {
141                expected: "'('",
142                actual: self.peek_token_as_string(),
143            })?;
144        let query = self.parse_query()?;
145        self.parser
146            .expect_token(&Token::RParen)
147            .with_context(|_| error::UnexpectedSnafu {
148                expected: "')'",
149                actual: self.peek_token_as_string(),
150            })?;
151        self.parser
152            .expect_keyword(Keyword::TO)
153            .context(error::SyntaxSnafu)?;
154        let (with, connection, location, limit) = self.parse_copy_parameters()?;
155        if limit.is_some() {
156            return error::InvalidSqlSnafu {
157                msg: "limit is not supported",
158            }
159            .fail();
160        }
161        Ok(CopyQueryTo {
162            query: Box::new(query),
163            arg: CopyQueryToArgument {
164                with: with.into(),
165                connection: connection.into(),
166                location,
167            },
168        })
169    }
170
171    fn parse_copy_parameters(&mut self) -> Result<(With, Connection, String, Option<u64>)> {
172        let location =
173            self.parser
174                .parse_literal_string()
175                .with_context(|_| error::UnexpectedSnafu {
176                    expected: "a file name",
177                    actual: self.peek_token_as_string(),
178                })?;
179
180        let options = self
181            .parser
182            .parse_options(Keyword::WITH)
183            .context(error::SyntaxSnafu)?;
184
185        let with = options
186            .into_iter()
187            .map(parse_option_string)
188            .collect::<Result<With>>()?;
189
190        let connection_options = self
191            .parser
192            .parse_options(Keyword::CONNECTION)
193            .context(error::SyntaxSnafu)?;
194
195        let connection = connection_options
196            .into_iter()
197            .map(parse_option_string)
198            .collect::<Result<Connection>>()?;
199
200        let limit = if self.parser.parse_keyword(Keyword::LIMIT) {
201            Some(
202                self.parser
203                    .parse_literal_uint()
204                    .with_context(|_| error::UnexpectedSnafu {
205                        expected: "the number of maximum rows",
206                        actual: self.peek_token_as_string(),
207                    })?,
208            )
209        } else {
210            None
211        };
212
213        Ok((with, connection, location, limit))
214    }
215}
216
217#[cfg(test)]
218mod tests {
219    use std::assert_matches::assert_matches;
220    use std::collections::HashMap;
221
222    use sqlparser::ast::{Ident, ObjectName};
223
224    use super::*;
225    use crate::dialect::GreptimeDbDialect;
226    use crate::parser::ParseOptions;
227    use crate::statements::statement::Statement::Copy;
228
229    #[test]
230    fn test_parse_copy_table() {
231        let sql0 = "COPY catalog0.schema0.tbl TO 'tbl_file.parquet'";
232        let sql1 = "COPY catalog0.schema0.tbl TO 'tbl_file.parquet' WITH (FORMAT = 'parquet')";
233        let result0 = ParserContext::create_with_dialect(
234            sql0,
235            &GreptimeDbDialect {},
236            ParseOptions::default(),
237        )
238        .unwrap();
239        let result1 = ParserContext::create_with_dialect(
240            sql1,
241            &GreptimeDbDialect {},
242            ParseOptions::default(),
243        )
244        .unwrap();
245
246        for mut result in [result0, result1] {
247            assert_eq!(1, result.len());
248
249            let statement = result.remove(0);
250            assert_matches!(statement, Statement::Copy { .. });
251            match statement {
252                Copy(copy) => {
253                    let crate::statements::copy::Copy::CopyTable(CopyTable::To(copy_table)) = copy
254                    else {
255                        unreachable!()
256                    };
257                    let table = copy_table.table_name.to_string();
258                    assert_eq!("catalog0.schema0.tbl", table);
259
260                    let file_name = &copy_table.location;
261                    assert_eq!("tbl_file.parquet", file_name);
262
263                    let format = copy_table.format().unwrap();
264                    assert_eq!("parquet", format.to_lowercase());
265                }
266                _ => unreachable!(),
267            }
268        }
269    }
270
271    #[test]
272    fn test_parse_copy_table_from_basic() {
273        let results = [
274            "COPY catalog0.schema0.tbl FROM 'tbl_file.parquet'",
275            "COPY catalog0.schema0.tbl FROM 'tbl_file.parquet' WITH (FORMAT = 'parquet')",
276        ]
277        .iter()
278        .map(|sql| {
279            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
280                .unwrap()
281        })
282        .collect::<Vec<_>>();
283
284        for mut result in results {
285            assert_eq!(1, result.len());
286
287            let statement = result.remove(0);
288            assert_matches!(statement, Statement::Copy { .. });
289            match statement {
290                Statement::Copy(crate::statements::copy::Copy::CopyTable(CopyTable::From(
291                    copy_table,
292                ))) => {
293                    let table = copy_table.table_name.to_string();
294                    assert_eq!("catalog0.schema0.tbl", table);
295
296                    let file_name = &copy_table.location;
297                    assert_eq!("tbl_file.parquet", file_name);
298
299                    let format = copy_table.format().unwrap();
300                    assert_eq!("parquet", format.to_lowercase());
301                }
302                _ => unreachable!(),
303            }
304        }
305    }
306
307    #[test]
308    fn test_parse_copy_table_from() {
309        struct Test<'a> {
310            sql: &'a str,
311            expected_pattern: Option<String>,
312            expected_connection: HashMap<String, String>,
313        }
314
315        let tests = [
316            Test {
317                sql: "COPY catalog0.schema0.tbl FROM 'tbl_file.parquet' WITH (PATTERN = 'demo.*')",
318                expected_pattern: Some("demo.*".into()),
319                expected_connection: HashMap::new(),
320            },
321            Test {
322                sql: "COPY catalog0.schema0.tbl FROM 'tbl_file.parquet' WITH (PATTERN = 'demo.*') CONNECTION (FOO='Bar', ONE='two')",
323                expected_pattern: Some("demo.*".into()),
324                expected_connection: [("foo", "Bar"), ("one", "two")]
325                    .into_iter()
326                    .map(|(k, v)| (k.to_string(), v.to_string()))
327                    .collect(),
328            },
329        ];
330
331        for test in tests {
332            let mut result = ParserContext::create_with_dialect(
333                test.sql,
334                &GreptimeDbDialect {},
335                ParseOptions::default(),
336            )
337            .unwrap();
338            assert_eq!(1, result.len());
339
340            let statement = result.remove(0);
341            assert_matches!(statement, Statement::Copy { .. });
342            match statement {
343                Statement::Copy(crate::statements::copy::Copy::CopyTable(CopyTable::From(
344                    copy_table,
345                ))) => {
346                    if let Some(expected_pattern) = test.expected_pattern {
347                        assert_eq!(copy_table.pattern().unwrap(), expected_pattern);
348                    }
349                    assert_eq!(
350                        copy_table.connection.clone(),
351                        test.expected_connection.into()
352                    );
353                }
354                _ => unreachable!(),
355            }
356        }
357    }
358
359    #[test]
360    fn test_parse_copy_table_to() {
361        struct Test<'a> {
362            sql: &'a str,
363            expected_connection: HashMap<String, String>,
364        }
365
366        let tests = [
367            Test {
368                sql: "COPY catalog0.schema0.tbl TO 'tbl_file.parquet' ",
369                expected_connection: HashMap::new(),
370            },
371            Test {
372                sql: "COPY catalog0.schema0.tbl TO 'tbl_file.parquet' CONNECTION (FOO='Bar', ONE='two')",
373                expected_connection: [("foo", "Bar"), ("one", "two")]
374                    .into_iter()
375                    .map(|(k, v)| (k.to_string(), v.to_string()))
376                    .collect(),
377            },
378            Test {
379                sql: "COPY catalog0.schema0.tbl TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')",
380                expected_connection: [("foo", "Bar"), ("one", "two")]
381                    .into_iter()
382                    .map(|(k, v)| (k.to_string(), v.to_string()))
383                    .collect(),
384            },
385        ];
386
387        for test in tests {
388            let mut result = ParserContext::create_with_dialect(
389                test.sql,
390                &GreptimeDbDialect {},
391                ParseOptions::default(),
392            )
393            .unwrap();
394            assert_eq!(1, result.len());
395
396            let statement = result.remove(0);
397            assert_matches!(statement, Statement::Copy { .. });
398            match statement {
399                Statement::Copy(crate::statements::copy::Copy::CopyTable(CopyTable::To(
400                    copy_table,
401                ))) => {
402                    assert_eq!(
403                        copy_table.connection.clone(),
404                        test.expected_connection.into()
405                    );
406                }
407                _ => unreachable!(),
408            }
409        }
410    }
411
412    #[test]
413    fn test_copy_database_to() {
414        let sql = "COPY DATABASE catalog0.schema0 TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')";
415        let stmt =
416            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
417                .unwrap()
418                .pop()
419                .unwrap();
420
421        let Copy(crate::statements::copy::Copy::CopyDatabase(stmt)) = stmt else {
422            unreachable!()
423        };
424
425        let CopyDatabase::To(stmt) = stmt else {
426            unreachable!()
427        };
428
429        assert_eq!(
430            ObjectName::from(vec![Ident::new("catalog0"), Ident::new("schema0")]),
431            stmt.database_name
432        );
433        assert_eq!(
434            [("format", "parquet")]
435                .into_iter()
436                .collect::<HashMap<_, _>>(),
437            stmt.with.to_str_map()
438        );
439
440        assert_eq!(
441            [("foo", "Bar"), ("one", "two")]
442                .into_iter()
443                .collect::<HashMap<_, _>>(),
444            stmt.connection.to_str_map()
445        );
446    }
447
448    #[test]
449    fn test_copy_database_from() {
450        let sql = "COPY DATABASE catalog0.schema0 FROM '/a/b/c/' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')";
451        let stmt =
452            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
453                .unwrap()
454                .pop()
455                .unwrap();
456
457        let Copy(crate::statements::copy::Copy::CopyDatabase(stmt)) = stmt else {
458            unreachable!()
459        };
460
461        let CopyDatabase::From(stmt) = stmt else {
462            unreachable!()
463        };
464
465        assert_eq!(
466            ObjectName::from(vec![Ident::new("catalog0"), Ident::new("schema0")]),
467            stmt.database_name
468        );
469        assert_eq!(
470            [("format", "parquet")]
471                .into_iter()
472                .collect::<HashMap<_, _>>(),
473            stmt.with.to_str_map()
474        );
475
476        assert_eq!(
477            [("foo", "Bar"), ("one", "two")]
478                .into_iter()
479                .collect::<HashMap<_, _>>(),
480            stmt.connection.to_str_map()
481        );
482    }
483
484    #[test]
485    fn test_copy_query_to() {
486        let sql = "COPY (SELECT * FROM tbl WHERE ts > 10) TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')";
487        let stmt =
488            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
489                .unwrap()
490                .pop()
491                .unwrap();
492
493        let Copy(crate::statements::copy::Copy::CopyQueryTo(stmt)) = stmt else {
494            unreachable!()
495        };
496
497        let query = ParserContext::create_with_dialect(
498            "SELECT * FROM tbl WHERE ts > 10",
499            &GreptimeDbDialect {},
500            ParseOptions::default(),
501        )
502        .unwrap()
503        .remove(0);
504
505        assert_eq!(&query, stmt.query.as_ref());
506        assert_eq!(
507            [("format", "parquet")]
508                .into_iter()
509                .collect::<HashMap<_, _>>(),
510            stmt.arg.with.to_str_map()
511        );
512
513        assert_eq!(
514            [("foo", "Bar"), ("one", "two")]
515                .into_iter()
516                .collect::<HashMap<_, _>>(),
517            stmt.arg.connection.to_str_map()
518        );
519    }
520
521    #[test]
522    fn test_invalid_copy_query_to() {
523        {
524            let sql = "COPY SELECT * FROM tbl WHERE ts > 10 TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')";
525
526            assert!(
527                ParserContext::create_with_dialect(
528                    sql,
529                    &GreptimeDbDialect {},
530                    ParseOptions::default()
531                )
532                .is_err()
533            )
534        }
535        {
536            let sql = "COPY SELECT * FROM tbl WHERE ts > 10) TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')";
537
538            assert!(
539                ParserContext::create_with_dialect(
540                    sql,
541                    &GreptimeDbDialect {},
542                    ParseOptions::default()
543                )
544                .is_err()
545            )
546        }
547        {
548            let sql = "COPY (SELECT * FROM tbl WHERE ts > 10 TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')";
549
550            assert!(
551                ParserContext::create_with_dialect(
552                    sql,
553                    &GreptimeDbDialect {},
554                    ParseOptions::default()
555                )
556                .is_err()
557            )
558        }
559    }
560}