1use std::fmt::Display;
16
17use serde::Serialize;
18use sqlparser::ast::ObjectName;
19use sqlparser_derive::{Visit, VisitMut};
20
21use crate::statements::OptionMap;
22use crate::statements::statement::Statement;
23
24#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)]
25pub enum Copy {
26 CopyTable(CopyTable),
27 CopyDatabase(CopyDatabase),
28 CopyQueryTo(CopyQueryTo),
29}
30
31impl Display for Copy {
32 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33 match self {
34 Copy::CopyTable(s) => s.fmt(f),
35 Copy::CopyDatabase(s) => s.fmt(f),
36 Copy::CopyQueryTo(s) => s.fmt(f),
37 }
38 }
39}
40
41#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)]
42pub struct CopyQueryTo {
43 pub query: Box<Statement>,
44 pub arg: CopyQueryToArgument,
45}
46
47impl Display for CopyQueryTo {
48 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49 write!(f, "COPY ({}) TO {}", &self.query, &self.arg.location)?;
50 if !self.arg.with.is_empty() {
51 let options = self.arg.with.kv_pairs();
52 write!(f, " WITH ({})", options.join(", "))?;
53 }
54 if !self.arg.connection.is_empty() {
55 let options = self.arg.connection.kv_pairs();
56 write!(f, " CONNECTION ({})", options.join(", "))?;
57 }
58 Ok(())
59 }
60}
61
62#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)]
63pub enum CopyTable {
64 To(CopyTableArgument),
65 From(CopyTableArgument),
66}
67
68impl Display for CopyTable {
69 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70 write!(f, "COPY ")?;
71 let (with, connection) = match self {
72 CopyTable::To(args) => {
73 write!(f, "{} TO {}", &args.table_name, &args.location)?;
74 (&args.with, &args.connection)
75 }
76 CopyTable::From(args) => {
77 write!(f, "{} FROM {}", &args.table_name, &args.location)?;
78 (&args.with, &args.connection)
79 }
80 };
81 if !with.is_empty() {
82 let options = with.kv_pairs();
83 write!(f, " WITH ({})", options.join(", "))?;
84 }
85 if !connection.is_empty() {
86 let options = connection.kv_pairs();
87 write!(f, " CONNECTION ({})", options.join(", "))?;
88 }
89 Ok(())
90 }
91}
92
93#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)]
94pub enum CopyDatabase {
95 To(CopyDatabaseArgument),
96 From(CopyDatabaseArgument),
97}
98
99impl Display for CopyDatabase {
100 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101 write!(f, "COPY DATABASE ")?;
102 let (with, connection) = match self {
103 CopyDatabase::To(args) => {
104 write!(f, "{} TO {}", &args.database_name, &args.location)?;
105 (&args.with, &args.connection)
106 }
107 CopyDatabase::From(args) => {
108 write!(f, "{} FROM {}", &args.database_name, &args.location)?;
109 (&args.with, &args.connection)
110 }
111 };
112 if !with.is_empty() {
113 let options = with.kv_pairs();
114 write!(f, " WITH ({})", options.join(", "))?;
115 }
116 if !connection.is_empty() {
117 let options = connection.kv_pairs();
118 write!(f, " CONNECTION ({})", options.join(", "))?;
119 }
120 Ok(())
121 }
122}
123
124#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)]
125pub struct CopyDatabaseArgument {
126 pub database_name: ObjectName,
127 pub with: OptionMap,
128 pub connection: OptionMap,
129 pub location: String,
130}
131
132#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)]
133pub struct CopyTableArgument {
134 pub table_name: ObjectName,
135 pub with: OptionMap,
136 pub connection: OptionMap,
137 pub location: String,
139 pub limit: Option<u64>,
140}
141
142#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)]
143pub struct CopyQueryToArgument {
144 pub with: OptionMap,
145 pub connection: OptionMap,
146 pub location: String,
147}
148
/// Test-only accessors that read well-known keys out of the `WITH` options.
#[cfg(test)]
impl CopyTableArgument {
    /// Looks up `key` in the `WITH` options and returns an owned copy of the
    /// value, or `None` when the key is absent.
    fn with_option(&self, key: &str) -> Option<String> {
        self.with.get(key).cloned()
    }

    /// Returns the value stored under `FORMAT_TYPE`, falling back to
    /// `"PARQUET"` when the option is absent. The result is always `Some`.
    pub fn format(&self) -> Option<String> {
        let explicit = self.with_option(common_datasource::file_format::FORMAT_TYPE);
        Some(explicit.unwrap_or_else(|| "PARQUET".to_string()))
    }

    /// Returns the value stored under `FILE_PATTERN`, if any.
    pub fn pattern(&self) -> Option<String> {
        self.with_option(common_datasource::file_format::FILE_PATTERN)
    }

    /// Returns the value stored under `TIMESTAMP_FORMAT`, if any.
    pub fn timestamp_pattern(&self) -> Option<String> {
        self.with_option(common_datasource::file_format::TIMESTAMP_FORMAT)
    }
}
170
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use crate::dialect::GreptimeDbDialect;
    use crate::parser::{ParseOptions, ParserContext};
    use crate::statements::statement::Statement;

    /// Parses `sql` with the GreptimeDB dialect, checks that it yields exactly
    /// one `Statement::Copy`, and returns that statement's `Display` output.
    fn render_single_copy(sql: &str) -> String {
        let stmts =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .unwrap();
        assert_eq!(1, stmts.len());
        assert_matches!(&stmts[0], Statement::Copy { .. });

        let Statement::Copy(copy) = &stmts[0] else {
            unreachable!();
        };
        format!("{}", copy)
    }

    // In every case below the expected output shows the location unquoted and
    // the `secret_access_key` value replaced by `******`.

    #[test]
    fn test_display_copy_from_tb() {
        let sql = r"copy tbl from 's3://my-bucket/data.parquet'
 with (format = 'parquet', pattern = '.*parquet.*')
 connection(region = 'us-west-2', secret_access_key = '12345678');";
        let new_sql = render_single_copy(sql);
        assert_eq!(
            r#"COPY tbl FROM s3://my-bucket/data.parquet WITH (format = 'parquet', pattern = '.*parquet.*') CONNECTION (region = 'us-west-2', secret_access_key = '******')"#,
            &new_sql
        );
    }

    #[test]
    fn test_display_copy_to_tb() {
        let sql = r"copy tbl to 's3://my-bucket/data.parquet'
 with (format = 'parquet', pattern = '.*parquet.*')
 connection(region = 'us-west-2', secret_access_key = '12345678');";
        let new_sql = render_single_copy(sql);
        assert_eq!(
            r#"COPY tbl TO s3://my-bucket/data.parquet WITH (format = 'parquet', pattern = '.*parquet.*') CONNECTION (region = 'us-west-2', secret_access_key = '******')"#,
            &new_sql
        );
    }

    #[test]
    fn test_display_copy_from_db() {
        let sql = r"copy database db1 from 's3://my-bucket/data.parquet'
 with (format = 'parquet', pattern = '.*parquet.*')
 connection(region = 'us-west-2', secret_access_key = '12345678');";
        let new_sql = render_single_copy(sql);
        assert_eq!(
            r#"COPY DATABASE db1 FROM s3://my-bucket/data.parquet WITH (format = 'parquet', pattern = '.*parquet.*') CONNECTION (region = 'us-west-2', secret_access_key = '******')"#,
            &new_sql
        );
    }

    #[test]
    fn test_display_copy_to_db() {
        let sql = r"copy database db1 to 's3://my-bucket/data.parquet'
 with (format = 'parquet', pattern = '.*parquet.*')
 connection(region = 'us-west-2', secret_access_key = '12345678');";
        let new_sql = render_single_copy(sql);
        assert_eq!(
            r#"COPY DATABASE db1 TO s3://my-bucket/data.parquet WITH (format = 'parquet', pattern = '.*parquet.*') CONNECTION (region = 'us-west-2', secret_access_key = '******')"#,
            &new_sql
        );
    }

    #[test]
    fn test_display_copy_query_to() {
        let sql = r"copy (SELECT * FROM tbl WHERE ts > 10) to 's3://my-bucket/data.parquet'
 with (format = 'parquet', pattern = '.*parquet.*')
 connection(region = 'us-west-2', secret_access_key = '12345678');";
        let new_sql = render_single_copy(sql);
        assert_eq!(
            r#"COPY (SELECT * FROM tbl WHERE ts > 10) TO s3://my-bucket/data.parquet WITH (format = 'parquet', pattern = '.*parquet.*') CONNECTION (region = 'us-west-2', secret_access_key = '******')"#,
            &new_sql
        );
    }
}