1use std::fmt::Display;
16
17use serde::Serialize;
18use sqlparser::ast::ObjectName;
19use sqlparser_derive::{Visit, VisitMut};
20
21use crate::statements::statement::Statement;
22use crate::statements::OptionMap;
23
/// A `COPY` statement in one of the three shapes this module can render.
///
/// The `Display` implementation forwards to the wrapped value, so the SQL
/// text produced depends entirely on the active variant.
#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)]
pub enum Copy {
    /// Rendered as `COPY <table> TO|FROM <location> ...`.
    CopyTable(CopyTable),
    /// Rendered as `COPY DATABASE <database> TO|FROM <location> ...`.
    CopyDatabase(CopyDatabase),
    /// Rendered as `COPY (<query>) TO <location> ...`.
    CopyQueryTo(CopyQueryTo),
}
30
31impl Display for Copy {
32 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33 match self {
34 Copy::CopyTable(s) => s.fmt(f),
35 Copy::CopyDatabase(s) => s.fmt(f),
36 Copy::CopyQueryTo(s) => s.fmt(f),
37 }
38 }
39}
40
/// `COPY (<query>) TO <location>` statement.
#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)]
pub struct CopyQueryTo {
    /// Statement rendered between the parentheses that follow `COPY`.
    pub query: Box<Statement>,
    /// Target location together with the `WITH` and `CONNECTION` option maps.
    pub arg: CopyQueryToArgument,
}
46
47impl Display for CopyQueryTo {
48 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49 write!(f, "COPY ({}) TO {}", &self.query, &self.arg.location)?;
50 if !self.arg.with.is_empty() {
51 let options = self.arg.with.kv_pairs();
52 write!(f, " WITH ({})", options.join(", "))?;
53 }
54 if !self.arg.connection.is_empty() {
55 let options = self.arg.connection.kv_pairs();
56 write!(f, " CONNECTION ({})", options.join(", "))?;
57 }
58 Ok(())
59 }
60}
61
/// Table-level `COPY`; the variant decides which keyword separates the table
/// name from the location when the statement is rendered.
#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)]
pub enum CopyTable {
    /// Rendered as `COPY <table> TO <location> ...`.
    To(CopyTableArgument),
    /// Rendered as `COPY <table> FROM <location> ...`.
    From(CopyTableArgument),
}
67
68impl Display for CopyTable {
69 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70 write!(f, "COPY ")?;
71 let (with, connection) = match self {
72 CopyTable::To(args) => {
73 write!(f, "{} TO {}", &args.table_name, &args.location)?;
74 (&args.with, &args.connection)
75 }
76 CopyTable::From(args) => {
77 write!(f, "{} FROM {}", &args.table_name, &args.location)?;
78 (&args.with, &args.connection)
79 }
80 };
81 if !with.is_empty() {
82 let options = with.kv_pairs();
83 write!(f, " WITH ({})", options.join(", "))?;
84 }
85 if !connection.is_empty() {
86 let options = connection.kv_pairs();
87 write!(f, " CONNECTION ({})", options.join(", "))?;
88 }
89 Ok(())
90 }
91}
92
/// Database-level `COPY`; the variant decides which keyword separates the
/// database name from the location when the statement is rendered.
#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)]
pub enum CopyDatabase {
    /// Rendered as `COPY DATABASE <database> TO <location> ...`.
    To(CopyDatabaseArgument),
    /// Rendered as `COPY DATABASE <database> FROM <location> ...`.
    From(CopyDatabaseArgument),
}
98
99impl Display for CopyDatabase {
100 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101 write!(f, "COPY DATABASE ")?;
102 let (with, connection) = match self {
103 CopyDatabase::To(args) => {
104 write!(f, "{} TO {}", &args.database_name, &args.location)?;
105 (&args.with, &args.connection)
106 }
107 CopyDatabase::From(args) => {
108 write!(f, "{} FROM {}", &args.database_name, &args.location)?;
109 (&args.with, &args.connection)
110 }
111 };
112 if !with.is_empty() {
113 let options = with.kv_pairs();
114 write!(f, " WITH ({})", options.join(", "))?;
115 }
116 if !connection.is_empty() {
117 let options = connection.kv_pairs();
118 write!(f, " CONNECTION ({})", options.join(", "))?;
119 }
120 Ok(())
121 }
122}
123
/// Operands carried by both directions of `COPY DATABASE`.
#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)]
pub struct CopyDatabaseArgument {
    /// Database name, rendered immediately after `COPY DATABASE`.
    pub database_name: ObjectName,
    /// Options rendered inside `WITH (...)`; the clause is omitted when empty.
    pub with: OptionMap,
    /// Options rendered inside `CONNECTION (...)`; the clause is omitted when empty.
    pub connection: OptionMap,
    /// Location rendered after `TO` / `FROM`, written without quotes.
    pub location: String,
}
131
/// Operands carried by both directions of a table-level `COPY`.
#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)]
pub struct CopyTableArgument {
    /// Table name, rendered immediately after `COPY`.
    pub table_name: ObjectName,
    /// Options rendered inside `WITH (...)`; the clause is omitted when empty.
    pub with: OptionMap,
    /// Options rendered inside `CONNECTION (...)`; the clause is omitted when empty.
    pub connection: OptionMap,
    /// Location rendered after `TO` / `FROM`, written without quotes.
    pub location: String,
    /// Not emitted by `Display for CopyTable`.
    /// NOTE(review): presumably a cap on the number of rows copied — confirm
    /// against the parser and the statement executor.
    pub limit: Option<u64>,
}
141
/// Operands of `COPY (<query>) TO`, i.e. everything except the query itself.
#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)]
pub struct CopyQueryToArgument {
    /// Options rendered inside `WITH (...)`; the clause is omitted when empty.
    pub with: OptionMap,
    /// Options rendered inside `CONNECTION (...)`; the clause is omitted when empty.
    pub connection: OptionMap,
    /// Location rendered after `TO`, written without quotes.
    pub location: String,
}
148
#[cfg(test)]
impl CopyTableArgument {
    /// Returns the value stored under `FORMAT_TYPE` in the `WITH` options,
    /// or `"PARQUET"` when that key is absent. The result is always `Some`.
    pub fn format(&self) -> Option<String> {
        let key = common_datasource::file_format::FORMAT_TYPE;
        let configured = self.with.get(key).cloned();
        Some(configured.unwrap_or_else(|| "PARQUET".to_string()))
    }

    /// Returns a copy of the value stored under `FILE_PATTERN` in the `WITH`
    /// options, or `None` when that key is absent.
    pub fn pattern(&self) -> Option<String> {
        let key = common_datasource::file_format::FILE_PATTERN;
        self.with.get(key).cloned()
    }
}
164
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use crate::dialect::GreptimeDbDialect;
    use crate::parser::{ParseOptions, ParserContext};
    use crate::statements::statement::Statement;

    /// Parses `sql`, checks that it produced exactly one `COPY` statement,
    /// and returns that statement's `Display` output.
    fn render_single_copy(sql: &str) -> String {
        let stmts =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .unwrap();
        assert_eq!(1, stmts.len());
        assert_matches!(&stmts[0], Statement::Copy { .. });

        match &stmts[0] {
            Statement::Copy(copy) => copy.to_string(),
            _ => unreachable!(),
        }
    }

    #[test]
    fn test_display_copy_from_tb() {
        let rendered = render_single_copy(
            r"copy tbl from 's3://my-bucket/data.parquet'
            with (format = 'parquet', pattern = '.*parquet.*')
            connection(region = 'us-west-2', secret_access_key = '12345678');",
        );
        assert_eq!(
            r#"COPY tbl FROM s3://my-bucket/data.parquet WITH (format = 'parquet', pattern = '.*parquet.*') CONNECTION (region = 'us-west-2', secret_access_key = '******')"#,
            &rendered
        );
    }

    #[test]
    fn test_display_copy_to_tb() {
        let rendered = render_single_copy(
            r"copy tbl to 's3://my-bucket/data.parquet'
            with (format = 'parquet', pattern = '.*parquet.*')
            connection(region = 'us-west-2', secret_access_key = '12345678');",
        );
        assert_eq!(
            r#"COPY tbl TO s3://my-bucket/data.parquet WITH (format = 'parquet', pattern = '.*parquet.*') CONNECTION (region = 'us-west-2', secret_access_key = '******')"#,
            &rendered
        );
    }

    #[test]
    fn test_display_copy_from_db() {
        let rendered = render_single_copy(
            r"copy database db1 from 's3://my-bucket/data.parquet'
            with (format = 'parquet', pattern = '.*parquet.*')
            connection(region = 'us-west-2', secret_access_key = '12345678');",
        );
        assert_eq!(
            r#"COPY DATABASE db1 FROM s3://my-bucket/data.parquet WITH (format = 'parquet', pattern = '.*parquet.*') CONNECTION (region = 'us-west-2', secret_access_key = '******')"#,
            &rendered
        );
    }

    #[test]
    fn test_display_copy_to_db() {
        let rendered = render_single_copy(
            r"copy database db1 to 's3://my-bucket/data.parquet'
            with (format = 'parquet', pattern = '.*parquet.*')
            connection(region = 'us-west-2', secret_access_key = '12345678');",
        );
        assert_eq!(
            r#"COPY DATABASE db1 TO s3://my-bucket/data.parquet WITH (format = 'parquet', pattern = '.*parquet.*') CONNECTION (region = 'us-west-2', secret_access_key = '******')"#,
            &rendered
        );
    }

    #[test]
    fn test_display_copy_query_to() {
        let rendered = render_single_copy(
            r"copy (SELECT * FROM tbl WHERE ts > 10) to 's3://my-bucket/data.parquet'
            with (format = 'parquet', pattern = '.*parquet.*')
            connection(region = 'us-west-2', secret_access_key = '12345678');",
        );
        assert_eq!(
            r#"COPY (SELECT * FROM tbl WHERE ts > 10) TO s3://my-bucket/data.parquet WITH (format = 'parquet', pattern = '.*parquet.*') CONNECTION (region = 'us-west-2', secret_access_key = '******')"#,
            &rendered
        );
    }
}