sql/statements/
copy.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Display;
16
17use serde::Serialize;
18use sqlparser::ast::ObjectName;
19use sqlparser_derive::{Visit, VisitMut};
20
21use crate::statements::statement::Statement;
22use crate::statements::OptionMap;
23
24#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)]
25pub enum Copy {
26    CopyTable(CopyTable),
27    CopyDatabase(CopyDatabase),
28    CopyQueryTo(CopyQueryTo),
29}
30
31impl Display for Copy {
32    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33        match self {
34            Copy::CopyTable(s) => s.fmt(f),
35            Copy::CopyDatabase(s) => s.fmt(f),
36            Copy::CopyQueryTo(s) => s.fmt(f),
37        }
38    }
39}
40
41#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)]
42pub struct CopyQueryTo {
43    pub query: Box<Statement>,
44    pub arg: CopyQueryToArgument,
45}
46
47impl Display for CopyQueryTo {
48    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49        write!(f, "COPY ({}) TO {}", &self.query, &self.arg.location)?;
50        if !self.arg.with.is_empty() {
51            let options = self.arg.with.kv_pairs();
52            write!(f, " WITH ({})", options.join(", "))?;
53        }
54        if !self.arg.connection.is_empty() {
55            let options = self.arg.connection.kv_pairs();
56            write!(f, " CONNECTION ({})", options.join(", "))?;
57        }
58        Ok(())
59    }
60}
61
62#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)]
63pub enum CopyTable {
64    To(CopyTableArgument),
65    From(CopyTableArgument),
66}
67
68impl Display for CopyTable {
69    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70        write!(f, "COPY ")?;
71        let (with, connection) = match self {
72            CopyTable::To(args) => {
73                write!(f, "{} TO {}", &args.table_name, &args.location)?;
74                (&args.with, &args.connection)
75            }
76            CopyTable::From(args) => {
77                write!(f, "{} FROM {}", &args.table_name, &args.location)?;
78                (&args.with, &args.connection)
79            }
80        };
81        if !with.is_empty() {
82            let options = with.kv_pairs();
83            write!(f, " WITH ({})", options.join(", "))?;
84        }
85        if !connection.is_empty() {
86            let options = connection.kv_pairs();
87            write!(f, " CONNECTION ({})", options.join(", "))?;
88        }
89        Ok(())
90    }
91}
92
93#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)]
94pub enum CopyDatabase {
95    To(CopyDatabaseArgument),
96    From(CopyDatabaseArgument),
97}
98
99impl Display for CopyDatabase {
100    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101        write!(f, "COPY DATABASE ")?;
102        let (with, connection) = match self {
103            CopyDatabase::To(args) => {
104                write!(f, "{} TO {}", &args.database_name, &args.location)?;
105                (&args.with, &args.connection)
106            }
107            CopyDatabase::From(args) => {
108                write!(f, "{} FROM {}", &args.database_name, &args.location)?;
109                (&args.with, &args.connection)
110            }
111        };
112        if !with.is_empty() {
113            let options = with.kv_pairs();
114            write!(f, " WITH ({})", options.join(", "))?;
115        }
116        if !connection.is_empty() {
117            let options = connection.kv_pairs();
118            write!(f, " CONNECTION ({})", options.join(", "))?;
119        }
120        Ok(())
121    }
122}
123
/// Arguments shared by both directions of `COPY DATABASE <name> TO|FROM <location>`.
#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)]
pub struct CopyDatabaseArgument {
    /// Name of the database named in the statement.
    pub database_name: ObjectName,
    /// Options from the `WITH (...)` clause.
    pub with: OptionMap,
    /// Options from the `CONNECTION (...)` clause.
    pub connection: OptionMap,
    /// Destination (`TO`) or source (`FROM`) location.
    pub location: String,
}
131
/// Arguments shared by both directions of `COPY <table> TO|FROM <location>`.
#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)]
pub struct CopyTableArgument {
    /// Name of the table named in the statement.
    pub table_name: ObjectName,
    /// Options from the `WITH (...)` clause.
    pub with: OptionMap,
    /// Options from the `CONNECTION (...)` clause.
    pub connection: OptionMap,
    /// Copy tbl [To|From] 'location'.
    pub location: String,
    /// Optional maximum number of rows.
    /// NOTE(review): the parser is not visible here; presumably this is populated from a
    /// trailing `LIMIT <n>` clause — confirm against the COPY parser.
    pub limit: Option<u64>,
}
141
/// Arguments of `COPY (<query>) TO <location>`.
#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)]
pub struct CopyQueryToArgument {
    /// Options from the `WITH (...)` clause.
    pub with: OptionMap,
    /// Options from the `CONNECTION (...)` clause.
    pub connection: OptionMap,
    /// Destination location the query result is copied to.
    pub location: String,
}
148
#[cfg(test)]
impl CopyTableArgument {
    /// Returns the `WITH` option keyed by `FORMAT_TYPE`, falling back to `"PARQUET"`
    /// when that option is absent.
    ///
    /// Always returns `Some`.
    pub fn format(&self) -> Option<String> {
        let configured = self
            .with
            .get(common_datasource::file_format::FORMAT_TYPE)
            .cloned();
        Some(configured.unwrap_or_else(|| "PARQUET".to_string()))
    }

    /// Returns the `WITH` option keyed by `FILE_PATTERN`, if present.
    pub fn pattern(&self) -> Option<String> {
        let key = common_datasource::file_format::FILE_PATTERN;
        self.with.get(key).cloned()
    }
}
164
#[cfg(test)]
mod tests {
    use crate::dialect::GreptimeDbDialect;
    use crate::parser::{ParseOptions, ParserContext};
    use crate::statements::statement::Statement;

    /// Parses `sql` with the GreptimeDB dialect and asserts two things: it yields exactly one
    /// `COPY` statement, and that statement's `Display` output equals `expected`.
    fn assert_copy_display(sql: &str, expected: &str) {
        let stmts =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .unwrap();
        assert_eq!(1, stmts.len());
        match &stmts[0] {
            Statement::Copy(copy) => assert_eq!(expected, copy.to_string()),
            other => panic!("expected a COPY statement, got {:?}", other),
        }
    }

    #[test]
    fn test_display_copy_from_tb() {
        let sql = r"copy tbl from 's3://my-bucket/data.parquet'
            with (format = 'parquet', pattern = '.*parquet.*')
            connection(region = 'us-west-2', secret_access_key = '12345678');";
        assert_copy_display(
            sql,
            r#"COPY tbl FROM s3://my-bucket/data.parquet WITH (format = 'parquet', pattern = '.*parquet.*') CONNECTION (region = 'us-west-2', secret_access_key = '******')"#,
        );
    }

    #[test]
    fn test_display_copy_to_tb() {
        let sql = r"copy tbl to 's3://my-bucket/data.parquet'
            with (format = 'parquet', pattern = '.*parquet.*')
            connection(region = 'us-west-2', secret_access_key = '12345678');";
        assert_copy_display(
            sql,
            r#"COPY tbl TO s3://my-bucket/data.parquet WITH (format = 'parquet', pattern = '.*parquet.*') CONNECTION (region = 'us-west-2', secret_access_key = '******')"#,
        );
    }

    #[test]
    fn test_display_copy_from_db() {
        let sql = r"copy database db1 from 's3://my-bucket/data.parquet'
            with (format = 'parquet', pattern = '.*parquet.*')
            connection(region = 'us-west-2', secret_access_key = '12345678');";
        assert_copy_display(
            sql,
            r#"COPY DATABASE db1 FROM s3://my-bucket/data.parquet WITH (format = 'parquet', pattern = '.*parquet.*') CONNECTION (region = 'us-west-2', secret_access_key = '******')"#,
        );
    }

    #[test]
    fn test_display_copy_to_db() {
        let sql = r"copy database db1 to 's3://my-bucket/data.parquet'
            with (format = 'parquet', pattern = '.*parquet.*')
            connection(region = 'us-west-2', secret_access_key = '12345678');";
        assert_copy_display(
            sql,
            r#"COPY DATABASE db1 TO s3://my-bucket/data.parquet WITH (format = 'parquet', pattern = '.*parquet.*') CONNECTION (region = 'us-west-2', secret_access_key = '******')"#,
        );
    }

    #[test]
    fn test_display_copy_query_to() {
        let sql = r"copy (SELECT * FROM tbl WHERE ts > 10) to 's3://my-bucket/data.parquet'
            with (format = 'parquet', pattern = '.*parquet.*')
            connection(region = 'us-west-2', secret_access_key = '12345678');";
        assert_copy_display(
            sql,
            r#"COPY (SELECT * FROM tbl WHERE ts > 10) TO s3://my-bucket/data.parquet WITH (format = 'parquet', pattern = '.*parquet.*') CONNECTION (region = 'us-west-2', secret_access_key = '******')"#,
        );
    }
}
297}