sql/parsers/
copy_parser.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16
17use snafu::ResultExt;
18use sqlparser::keywords::Keyword;
19use sqlparser::tokenizer::Token;
20use sqlparser::tokenizer::Token::Word;
21
22use crate::error::{self, Result};
23use crate::parser::ParserContext;
24use crate::statements::copy::{
25    CopyDatabase, CopyDatabaseArgument, CopyQueryTo, CopyQueryToArgument, CopyTable,
26    CopyTableArgument,
27};
28use crate::statements::statement::Statement;
29use crate::util::parse_option_string;
30
/// Key/value options collected from the `WITH (...)` clause of a `COPY` statement.
pub type With = HashMap<String, String>;

/// Key/value options collected from the `CONNECTION (...)` clause of a `COPY` statement.
pub type Connection = HashMap<String, String>;
33
34// COPY tbl TO 'output.parquet';
35impl ParserContext<'_> {
36    pub(crate) fn parse_copy(&mut self) -> Result<Statement> {
37        let _ = self.parser.next_token();
38        let next = self.parser.peek_token();
39        let copy = if next.token == Token::LParen {
40            let copy_query = self.parse_copy_query_to()?;
41
42            crate::statements::copy::Copy::CopyQueryTo(copy_query)
43        } else if let Word(word) = next.token
44            && word.keyword == Keyword::DATABASE
45        {
46            let _ = self.parser.next_token();
47            let copy_database = self.parser_copy_database()?;
48            crate::statements::copy::Copy::CopyDatabase(copy_database)
49        } else {
50            let copy_table = self.parse_copy_table()?;
51            crate::statements::copy::Copy::CopyTable(copy_table)
52        };
53
54        Ok(Statement::Copy(copy))
55    }
56
57    fn parser_copy_database(&mut self) -> Result<CopyDatabase> {
58        let database_name = self
59            .parse_object_name()
60            .with_context(|_| error::UnexpectedSnafu {
61                expected: "a database name",
62                actual: self.peek_token_as_string(),
63            })?;
64
65        let req = if self.parser.parse_keyword(Keyword::TO) {
66            let (with, connection, location, limit) = self.parse_copy_parameters()?;
67            if limit.is_some() {
68                return error::InvalidSqlSnafu {
69                    msg: "limit is not supported",
70                }
71                .fail();
72            }
73
74            let argument = CopyDatabaseArgument {
75                database_name,
76                with: with.into(),
77                connection: connection.into(),
78                location,
79            };
80            CopyDatabase::To(argument)
81        } else {
82            self.parser
83                .expect_keyword(Keyword::FROM)
84                .context(error::SyntaxSnafu)?;
85            let (with, connection, location, limit) = self.parse_copy_parameters()?;
86            if limit.is_some() {
87                return error::InvalidSqlSnafu {
88                    msg: "limit is not supported",
89                }
90                .fail();
91            }
92
93            let argument = CopyDatabaseArgument {
94                database_name,
95                with: with.into(),
96                connection: connection.into(),
97                location,
98            };
99            CopyDatabase::From(argument)
100        };
101        Ok(req)
102    }
103
104    fn parse_copy_table(&mut self) -> Result<CopyTable> {
105        let raw_table_name = self
106            .parse_object_name()
107            .with_context(|_| error::UnexpectedSnafu {
108                expected: "a table name",
109                actual: self.peek_token_as_string(),
110            })?;
111        let table_name = Self::canonicalize_object_name(raw_table_name);
112
113        if self.parser.parse_keyword(Keyword::TO) {
114            let (with, connection, location, limit) = self.parse_copy_parameters()?;
115            Ok(CopyTable::To(CopyTableArgument {
116                table_name,
117                with: with.into(),
118                connection: connection.into(),
119                location,
120                limit,
121            }))
122        } else {
123            self.parser
124                .expect_keyword(Keyword::FROM)
125                .context(error::SyntaxSnafu)?;
126            let (with, connection, location, limit) = self.parse_copy_parameters()?;
127            Ok(CopyTable::From(CopyTableArgument {
128                table_name,
129                with: with.into(),
130                connection: connection.into(),
131                location,
132                limit,
133            }))
134        }
135    }
136
137    fn parse_copy_query_to(&mut self) -> Result<CopyQueryTo> {
138        self.parser
139            .expect_token(&Token::LParen)
140            .with_context(|_| error::UnexpectedSnafu {
141                expected: "'('",
142                actual: self.peek_token_as_string(),
143            })?;
144        let query = self.parse_query()?;
145        self.parser
146            .expect_token(&Token::RParen)
147            .with_context(|_| error::UnexpectedSnafu {
148                expected: "')'",
149                actual: self.peek_token_as_string(),
150            })?;
151        self.parser
152            .expect_keyword(Keyword::TO)
153            .context(error::SyntaxSnafu)?;
154        let (with, connection, location, limit) = self.parse_copy_parameters()?;
155        if limit.is_some() {
156            return error::InvalidSqlSnafu {
157                msg: "limit is not supported",
158            }
159            .fail();
160        }
161        Ok(CopyQueryTo {
162            query: Box::new(query),
163            arg: CopyQueryToArgument {
164                with: with.into(),
165                connection: connection.into(),
166                location,
167            },
168        })
169    }
170
171    fn parse_copy_parameters(&mut self) -> Result<(With, Connection, String, Option<u64>)> {
172        let location =
173            self.parser
174                .parse_literal_string()
175                .with_context(|_| error::UnexpectedSnafu {
176                    expected: "a file name",
177                    actual: self.peek_token_as_string(),
178                })?;
179
180        let options = self
181            .parser
182            .parse_options(Keyword::WITH)
183            .context(error::SyntaxSnafu)?;
184
185        let with = options
186            .into_iter()
187            .map(parse_option_string)
188            .collect::<Result<With>>()?;
189
190        let connection_options = self
191            .parser
192            .parse_options(Keyword::CONNECTION)
193            .context(error::SyntaxSnafu)?;
194
195        let connection = connection_options
196            .into_iter()
197            .map(parse_option_string)
198            .collect::<Result<Connection>>()?;
199
200        let limit = if self.parser.parse_keyword(Keyword::LIMIT) {
201            Some(
202                self.parser
203                    .parse_literal_uint()
204                    .with_context(|_| error::UnexpectedSnafu {
205                        expected: "the number of maximum rows",
206                        actual: self.peek_token_as_string(),
207                    })?,
208            )
209        } else {
210            None
211        };
212
213        Ok((with, connection, location, limit))
214    }
215}
216
217#[cfg(test)]
218mod tests {
219    use std::assert_matches::assert_matches;
220    use std::collections::HashMap;
221
222    use sqlparser::ast::{Ident, ObjectName};
223
224    use super::*;
225    use crate::dialect::GreptimeDbDialect;
226    use crate::parser::ParseOptions;
227    use crate::statements::statement::Statement::Copy;
228
229    #[test]
230    fn test_parse_copy_table() {
231        let sql0 = "COPY catalog0.schema0.tbl TO 'tbl_file.parquet'";
232        let sql1 = "COPY catalog0.schema0.tbl TO 'tbl_file.parquet' WITH (FORMAT = 'parquet')";
233        let result0 = ParserContext::create_with_dialect(
234            sql0,
235            &GreptimeDbDialect {},
236            ParseOptions::default(),
237        )
238        .unwrap();
239        let result1 = ParserContext::create_with_dialect(
240            sql1,
241            &GreptimeDbDialect {},
242            ParseOptions::default(),
243        )
244        .unwrap();
245
246        for mut result in [result0, result1] {
247            assert_eq!(1, result.len());
248
249            let statement = result.remove(0);
250            assert_matches!(statement, Statement::Copy { .. });
251            match statement {
252                Copy(copy) => {
253                    let crate::statements::copy::Copy::CopyTable(CopyTable::To(copy_table)) = copy
254                    else {
255                        unreachable!()
256                    };
257                    let table = copy_table.table_name.to_string();
258                    assert_eq!("catalog0.schema0.tbl", table);
259
260                    let file_name = &copy_table.location;
261                    assert_eq!("tbl_file.parquet", file_name);
262
263                    let format = copy_table.format().unwrap();
264                    assert_eq!("parquet", format.to_lowercase());
265                }
266                _ => unreachable!(),
267            }
268        }
269    }
270
271    #[test]
272    fn test_parse_copy_table_from_basic() {
273        let results = [
274            "COPY catalog0.schema0.tbl FROM 'tbl_file.parquet'",
275            "COPY catalog0.schema0.tbl FROM 'tbl_file.parquet' WITH (FORMAT = 'parquet')",
276        ]
277        .iter()
278        .map(|sql| {
279            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
280                .unwrap()
281        })
282        .collect::<Vec<_>>();
283
284        for mut result in results {
285            assert_eq!(1, result.len());
286
287            let statement = result.remove(0);
288            assert_matches!(statement, Statement::Copy { .. });
289            match statement {
290                Statement::Copy(crate::statements::copy::Copy::CopyTable(CopyTable::From(
291                    copy_table,
292                ))) => {
293                    let table = copy_table.table_name.to_string();
294                    assert_eq!("catalog0.schema0.tbl", table);
295
296                    let file_name = &copy_table.location;
297                    assert_eq!("tbl_file.parquet", file_name);
298
299                    let format = copy_table.format().unwrap();
300                    assert_eq!("parquet", format.to_lowercase());
301                }
302                _ => unreachable!(),
303            }
304        }
305    }
306
307    #[test]
308    fn test_parse_copy_table_from() {
309        struct Test<'a> {
310            sql: &'a str,
311            expected_pattern: Option<String>,
312            expected_connection: HashMap<String, String>,
313        }
314
315        let tests = [
316            Test {
317                sql: "COPY catalog0.schema0.tbl FROM 'tbl_file.parquet' WITH (PATTERN = 'demo.*')",
318                expected_pattern: Some("demo.*".into()),
319                expected_connection: HashMap::new(),
320            },
321            Test {
322                sql: "COPY catalog0.schema0.tbl FROM 'tbl_file.parquet' WITH (PATTERN = 'demo.*') CONNECTION (FOO='Bar', ONE='two')",
323                expected_pattern: Some("demo.*".into()),
324                expected_connection: [("foo","Bar"),("one","two")].into_iter().map(|(k,v)|{(k.to_string(),v.to_string())}).collect()
325            },
326        ];
327
328        for test in tests {
329            let mut result = ParserContext::create_with_dialect(
330                test.sql,
331                &GreptimeDbDialect {},
332                ParseOptions::default(),
333            )
334            .unwrap();
335            assert_eq!(1, result.len());
336
337            let statement = result.remove(0);
338            assert_matches!(statement, Statement::Copy { .. });
339            match statement {
340                Statement::Copy(crate::statements::copy::Copy::CopyTable(CopyTable::From(
341                    copy_table,
342                ))) => {
343                    if let Some(expected_pattern) = test.expected_pattern {
344                        assert_eq!(copy_table.pattern().unwrap(), expected_pattern);
345                    }
346                    assert_eq!(
347                        copy_table.connection.clone(),
348                        test.expected_connection.into()
349                    );
350                }
351                _ => unreachable!(),
352            }
353        }
354    }
355
356    #[test]
357    fn test_parse_copy_table_to() {
358        struct Test<'a> {
359            sql: &'a str,
360            expected_connection: HashMap<String, String>,
361        }
362
363        let tests = [
364            Test {
365                sql: "COPY catalog0.schema0.tbl TO 'tbl_file.parquet' ",
366                expected_connection: HashMap::new(),
367            },
368            Test {
369                sql: "COPY catalog0.schema0.tbl TO 'tbl_file.parquet' CONNECTION (FOO='Bar', ONE='two')",
370                expected_connection: [("foo","Bar"),("one","two")].into_iter().map(|(k,v)|{(k.to_string(),v.to_string())}).collect()
371            },
372            Test {
373                sql:"COPY catalog0.schema0.tbl TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')",
374                expected_connection: [("foo","Bar"),("one","two")].into_iter().map(|(k,v)|{(k.to_string(),v.to_string())}).collect()
375            },
376        ];
377
378        for test in tests {
379            let mut result = ParserContext::create_with_dialect(
380                test.sql,
381                &GreptimeDbDialect {},
382                ParseOptions::default(),
383            )
384            .unwrap();
385            assert_eq!(1, result.len());
386
387            let statement = result.remove(0);
388            assert_matches!(statement, Statement::Copy { .. });
389            match statement {
390                Statement::Copy(crate::statements::copy::Copy::CopyTable(CopyTable::To(
391                    copy_table,
392                ))) => {
393                    assert_eq!(
394                        copy_table.connection.clone(),
395                        test.expected_connection.into()
396                    );
397                }
398                _ => unreachable!(),
399            }
400        }
401    }
402
403    #[test]
404    fn test_copy_database_to() {
405        let sql = "COPY DATABASE catalog0.schema0 TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')";
406        let stmt =
407            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
408                .unwrap()
409                .pop()
410                .unwrap();
411
412        let Copy(crate::statements::copy::Copy::CopyDatabase(stmt)) = stmt else {
413            unreachable!()
414        };
415
416        let CopyDatabase::To(stmt) = stmt else {
417            unreachable!()
418        };
419
420        assert_eq!(
421            ObjectName::from(vec![Ident::new("catalog0"), Ident::new("schema0")]),
422            stmt.database_name
423        );
424        assert_eq!(
425            [("format", "parquet")]
426                .into_iter()
427                .collect::<HashMap<_, _>>(),
428            stmt.with.to_str_map()
429        );
430
431        assert_eq!(
432            [("foo", "Bar"), ("one", "two")]
433                .into_iter()
434                .collect::<HashMap<_, _>>(),
435            stmt.connection.to_str_map()
436        );
437    }
438
439    #[test]
440    fn test_copy_database_from() {
441        let sql = "COPY DATABASE catalog0.schema0 FROM '/a/b/c/' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')";
442        let stmt =
443            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
444                .unwrap()
445                .pop()
446                .unwrap();
447
448        let Copy(crate::statements::copy::Copy::CopyDatabase(stmt)) = stmt else {
449            unreachable!()
450        };
451
452        let CopyDatabase::From(stmt) = stmt else {
453            unreachable!()
454        };
455
456        assert_eq!(
457            ObjectName::from(vec![Ident::new("catalog0"), Ident::new("schema0")]),
458            stmt.database_name
459        );
460        assert_eq!(
461            [("format", "parquet")]
462                .into_iter()
463                .collect::<HashMap<_, _>>(),
464            stmt.with.to_str_map()
465        );
466
467        assert_eq!(
468            [("foo", "Bar"), ("one", "two")]
469                .into_iter()
470                .collect::<HashMap<_, _>>(),
471            stmt.connection.to_str_map()
472        );
473    }
474
475    #[test]
476    fn test_copy_query_to() {
477        let sql = "COPY (SELECT * FROM tbl WHERE ts > 10) TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')";
478        let stmt =
479            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
480                .unwrap()
481                .pop()
482                .unwrap();
483
484        let Copy(crate::statements::copy::Copy::CopyQueryTo(stmt)) = stmt else {
485            unreachable!()
486        };
487
488        let query = ParserContext::create_with_dialect(
489            "SELECT * FROM tbl WHERE ts > 10",
490            &GreptimeDbDialect {},
491            ParseOptions::default(),
492        )
493        .unwrap()
494        .remove(0);
495
496        assert_eq!(&query, stmt.query.as_ref());
497        assert_eq!(
498            [("format", "parquet")]
499                .into_iter()
500                .collect::<HashMap<_, _>>(),
501            stmt.arg.with.to_str_map()
502        );
503
504        assert_eq!(
505            [("foo", "Bar"), ("one", "two")]
506                .into_iter()
507                .collect::<HashMap<_, _>>(),
508            stmt.arg.connection.to_str_map()
509        );
510    }
511
512    #[test]
513    fn test_invalid_copy_query_to() {
514        {
515            let sql = "COPY SELECT * FROM tbl WHERE ts > 10 TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')";
516
517            assert!(ParserContext::create_with_dialect(
518                sql,
519                &GreptimeDbDialect {},
520                ParseOptions::default()
521            )
522            .is_err())
523        }
524        {
525            let sql = "COPY SELECT * FROM tbl WHERE ts > 10) TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')";
526
527            assert!(ParserContext::create_with_dialect(
528                sql,
529                &GreptimeDbDialect {},
530                ParseOptions::default()
531            )
532            .is_err())
533        }
534        {
535            let sql = "COPY (SELECT * FROM tbl WHERE ts > 10 TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')";
536
537            assert!(ParserContext::create_with_dialect(
538                sql,
539                &GreptimeDbDialect {},
540                ParseOptions::default()
541            )
542            .is_err())
543        }
544    }
545}