sql/statements/
copy.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Display;
16
17use serde::Serialize;
18use sqlparser::ast::ObjectName;
19use sqlparser_derive::{Visit, VisitMut};
20
21use crate::statements::OptionMap;
22use crate::statements::statement::Statement;
23
/// A `COPY` statement.
///
/// Each variant wraps one flavor of the statement; formatting a `Copy` with
/// `Display` forwards to the wrapped value.
#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)]
pub enum Copy {
    /// `COPY <table> TO|FROM <location>`.
    CopyTable(CopyTable),
    /// `COPY DATABASE <database> TO|FROM <location>`.
    CopyDatabase(CopyDatabase),
    /// `COPY (<query>) TO <location>`.
    CopyQueryTo(CopyQueryTo),
}
30
31impl Display for Copy {
32    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33        match self {
34            Copy::CopyTable(s) => s.fmt(f),
35            Copy::CopyDatabase(s) => s.fmt(f),
36            Copy::CopyQueryTo(s) => s.fmt(f),
37        }
38    }
39}
40
/// `COPY (<query>) TO <location>`: a query together with the arguments that
/// describe where its result is copied to.
#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)]
pub struct CopyQueryTo {
    /// The statement that is rendered between the parentheses.
    pub query: Box<Statement>,
    /// Target location plus the `WITH` and `CONNECTION` option maps.
    pub arg: CopyQueryToArgument,
}
46
47impl Display for CopyQueryTo {
48    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49        write!(f, "COPY ({}) TO {}", &self.query, &self.arg.location)?;
50        if !self.arg.with.is_empty() {
51            let options = self.arg.with.kv_pairs();
52            write!(f, " WITH ({})", options.join(", "))?;
53        }
54        if !self.arg.connection.is_empty() {
55            let options = self.arg.connection.kv_pairs();
56            write!(f, " CONNECTION ({})", options.join(", "))?;
57        }
58        Ok(())
59    }
60}
61
/// `COPY <table> ...`: the variant selects the keyword placed between the
/// table name and the location.
#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)]
pub enum CopyTable {
    /// `COPY <table> TO <location>`.
    To(CopyTableArgument),
    /// `COPY <table> FROM <location>`.
    From(CopyTableArgument),
}
67
68impl Display for CopyTable {
69    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70        write!(f, "COPY ")?;
71        let (with, connection) = match self {
72            CopyTable::To(args) => {
73                write!(f, "{} TO {}", &args.table_name, &args.location)?;
74                (&args.with, &args.connection)
75            }
76            CopyTable::From(args) => {
77                write!(f, "{} FROM {}", &args.table_name, &args.location)?;
78                (&args.with, &args.connection)
79            }
80        };
81        if !with.is_empty() {
82            let options = with.kv_pairs();
83            write!(f, " WITH ({})", options.join(", "))?;
84        }
85        if !connection.is_empty() {
86            let options = connection.kv_pairs();
87            write!(f, " CONNECTION ({})", options.join(", "))?;
88        }
89        Ok(())
90    }
91}
92
/// `COPY DATABASE <database> ...`: the variant selects the keyword placed
/// between the database name and the location.
#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)]
pub enum CopyDatabase {
    /// `COPY DATABASE <database> TO <location>`.
    To(CopyDatabaseArgument),
    /// `COPY DATABASE <database> FROM <location>`.
    From(CopyDatabaseArgument),
}
98
99impl Display for CopyDatabase {
100    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101        write!(f, "COPY DATABASE ")?;
102        let (with, connection) = match self {
103            CopyDatabase::To(args) => {
104                write!(f, "{} TO {}", &args.database_name, &args.location)?;
105                (&args.with, &args.connection)
106            }
107            CopyDatabase::From(args) => {
108                write!(f, "{} FROM {}", &args.database_name, &args.location)?;
109                (&args.with, &args.connection)
110            }
111        };
112        if !with.is_empty() {
113            let options = with.kv_pairs();
114            write!(f, " WITH ({})", options.join(", "))?;
115        }
116        if !connection.is_empty() {
117            let options = connection.kv_pairs();
118            write!(f, " CONNECTION ({})", options.join(", "))?;
119        }
120        Ok(())
121    }
122}
123
/// Arguments shared by both directions of `COPY DATABASE`.
#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)]
pub struct CopyDatabaseArgument {
    /// Name written right after `COPY DATABASE`.
    pub database_name: ObjectName,
    /// Options rendered inside `WITH (...)`.
    pub with: OptionMap,
    /// Options rendered inside `CONNECTION (...)`.
    pub connection: OptionMap,
    /// Copy database [To|From] 'location'.
    pub location: String,
}
131
/// Arguments shared by both directions of `COPY <table>`.
#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)]
pub struct CopyTableArgument {
    /// Name written right after `COPY`.
    pub table_name: ObjectName,
    /// Options rendered inside `WITH (...)`.
    pub with: OptionMap,
    /// Options rendered inside `CONNECTION (...)`.
    pub connection: OptionMap,
    /// Copy tbl [To|From] 'location'.
    pub location: String,
    /// Optional limit attached to the statement.
    /// NOTE(review): presumably a cap on the number of rows copied — confirm
    /// against the parser and the executor.
    pub limit: Option<u64>,
}
141
/// Arguments of `COPY (<query>) TO <location>`.
#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)]
pub struct CopyQueryToArgument {
    /// Options rendered inside `WITH (...)`.
    pub with: OptionMap,
    /// Options rendered inside `CONNECTION (...)`.
    pub connection: OptionMap,
    /// Copy (query) To 'location'.
    pub location: String,
}
148
/// Convenience accessors over the `WITH` options; compiled for tests only.
#[cfg(test)]
impl CopyTableArgument {
    /// Returns the `FORMAT_TYPE` entry of the `WITH` options, falling back to
    /// `"PARQUET"` when the entry is absent (so the result is always `Some`).
    pub fn format(&self) -> Option<String> {
        let key = common_datasource::file_format::FORMAT_TYPE;
        let configured = self.with.get(key).cloned();
        Some(configured.unwrap_or_else(|| "PARQUET".to_string()))
    }

    /// Returns the `FILE_PATTERN` entry of the `WITH` options, if present.
    pub fn pattern(&self) -> Option<String> {
        let key = common_datasource::file_format::FILE_PATTERN;
        self.with.get(key).cloned()
    }

    /// Returns the `TIMESTAMP_FORMAT` entry of the `WITH` options, if present.
    pub fn timestamp_pattern(&self) -> Option<String> {
        let key = common_datasource::file_format::TIMESTAMP_FORMAT;
        self.with.get(key).cloned()
    }
}
170
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use crate::dialect::GreptimeDbDialect;
    use crate::parser::{ParseOptions, ParserContext};
    use crate::statements::statement::Statement;

    /// Parses `sql` with the GreptimeDB dialect, checks that it yields exactly
    /// one `COPY` statement and returns the text produced by formatting that
    /// statement with `Display`.
    fn display_of_single_copy(sql: &str) -> String {
        let stmts =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .unwrap();
        assert_eq!(1, stmts.len());
        assert_matches!(&stmts[0], Statement::Copy { .. });

        match &stmts[0] {
            Statement::Copy(copy) => copy.to_string(),
            _ => unreachable!(),
        }
    }

    #[test]
    fn test_display_copy_from_tb() {
        let sql = r"copy tbl from 's3://my-bucket/data.parquet'
            with (format = 'parquet', pattern = '.*parquet.*')
            connection(region = 'us-west-2', secret_access_key = '12345678');";
        let rendered = display_of_single_copy(sql);
        assert_eq!(
            r#"COPY tbl FROM s3://my-bucket/data.parquet WITH (format = 'parquet', pattern = '.*parquet.*') CONNECTION (region = 'us-west-2', secret_access_key = '******')"#,
            rendered
        );
    }

    #[test]
    fn test_display_copy_to_tb() {
        let sql = r"copy tbl to 's3://my-bucket/data.parquet'
            with (format = 'parquet', pattern = '.*parquet.*')
            connection(region = 'us-west-2', secret_access_key = '12345678');";
        let rendered = display_of_single_copy(sql);
        assert_eq!(
            r#"COPY tbl TO s3://my-bucket/data.parquet WITH (format = 'parquet', pattern = '.*parquet.*') CONNECTION (region = 'us-west-2', secret_access_key = '******')"#,
            rendered
        );
    }

    #[test]
    fn test_display_copy_from_db() {
        let sql = r"copy database db1 from 's3://my-bucket/data.parquet'
            with (format = 'parquet', pattern = '.*parquet.*')
            connection(region = 'us-west-2', secret_access_key = '12345678');";
        let rendered = display_of_single_copy(sql);
        assert_eq!(
            r#"COPY DATABASE db1 FROM s3://my-bucket/data.parquet WITH (format = 'parquet', pattern = '.*parquet.*') CONNECTION (region = 'us-west-2', secret_access_key = '******')"#,
            rendered
        );
    }

    #[test]
    fn test_display_copy_to_db() {
        let sql = r"copy database db1 to 's3://my-bucket/data.parquet'
            with (format = 'parquet', pattern = '.*parquet.*')
            connection(region = 'us-west-2', secret_access_key = '12345678');";
        let rendered = display_of_single_copy(sql);
        assert_eq!(
            r#"COPY DATABASE db1 TO s3://my-bucket/data.parquet WITH (format = 'parquet', pattern = '.*parquet.*') CONNECTION (region = 'us-west-2', secret_access_key = '******')"#,
            rendered
        );
    }

    #[test]
    fn test_display_copy_query_to() {
        let sql = r"copy (SELECT * FROM tbl WHERE ts > 10) to 's3://my-bucket/data.parquet'
            with (format = 'parquet', pattern = '.*parquet.*')
            connection(region = 'us-west-2', secret_access_key = '12345678');";
        let rendered = display_of_single_copy(sql);
        assert_eq!(
            r#"COPY (SELECT * FROM tbl WHERE ts > 10) TO s3://my-bucket/data.parquet WITH (format = 'parquet', pattern = '.*parquet.*') CONNECTION (region = 'us-west-2', secret_access_key = '******')"#,
            rendered
        );
    }
}
303}