1use std::path::PathBuf;
16use std::sync::Arc;
17use std::time::Duration;
18
19use async_trait::async_trait;
20use clap::{Parser, ValueEnum};
21use common_catalog::consts::DEFAULT_SCHEMA_NAME;
22use common_error::ext::BoxedError;
23use common_telemetry::{error, info, warn};
24use snafu::{OptionExt, ResultExt, ensure};
25use tokio::sync::Semaphore;
26use tokio::time::Instant;
27
28use crate::data::{COPY_PATH_PLACEHOLDER, default_database};
29use crate::database::{DatabaseClient, parse_proxy_opts};
30use crate::error::{Error, FileIoSnafu, InvalidArgumentsSnafu, Result, SchemaNotFoundSnafu};
31use crate::{Tool, database};
32
/// Selects which part of the import `Import::do_work` runs.
#[derive(Debug, Default, Clone, ValueEnum)]
enum ImportTarget {
    /// Import schema definitions only (`create_database.sql` and `create_tables.sql`).
    Schema,
    /// Import data only (`copy_from.sql`).
    Data,
    /// Import schema definitions first, then data.
    #[default]
    All,
}
43
/// Command-line arguments of the import tool.
///
/// [`ImportCommand::build`] turns these arguments into an [`Import`].
#[derive(Debug, Default, Parser)]
pub struct ImportCommand {
    /// Server address handed to the database client.
    #[clap(long)]
    addr: String,

    /// Root directory of the files to import; per-database SQL files are read
    /// from `<input_dir>/<catalog>/<schema>/`.
    #[clap(long)]
    input_dir: String,

    /// Database to import.
    // Split by `database::split_database` into a catalog and an optional
    // schema; without a schema every directory under the catalog is imported.
    #[clap(long, default_value_t = default_database())]
    database: String,

    /// Maximum number of databases imported concurrently.
    #[clap(long, short = 'j', default_value = "1", alias = "import-jobs")]
    db_parallelism: usize,

    /// Maximum number of retries.
    // NOTE(review): this value is not read anywhere in the visible code —
    // confirm whether it is still honored.
    #[clap(long, default_value = "3")]
    max_retry: usize,

    /// What to import: schema definitions, data, or both.
    #[clap(long, short = 't', value_enum, default_value = "all")]
    target: ImportTarget,

    /// Basic authentication setting forwarded to the database client.
    // NOTE(review): the expected format is defined by `DatabaseClient` — confirm there.
    #[clap(long)]
    auth_basic: Option<String>,

    /// Timeout passed to the database client, written in humantime form (for example `30s`).
    // When omitted, `build` passes `Duration::default()` (zero) to the client;
    // how the client interprets zero is not visible here.
    #[clap(long, value_parser = humantime::parse_duration)]
    timeout: Option<Duration>,

    /// Proxy setting forwarded to `parse_proxy_opts`.
    #[clap(long)]
    proxy: Option<String>,

    /// Forwarded to `parse_proxy_opts` together with `proxy`.
    // NOTE(review): presumably disables proxy usage — confirm in `parse_proxy_opts`.
    #[clap(long, default_value = "false")]
    no_proxy: bool,
}
94
95impl ImportCommand {
96 pub async fn build(&self) -> std::result::Result<Box<dyn Tool>, BoxedError> {
97 let (catalog, schema) =
98 database::split_database(&self.database).map_err(BoxedError::new)?;
99 let proxy = parse_proxy_opts(self.proxy.clone(), self.no_proxy)?;
100 let database_client = DatabaseClient::new(
101 self.addr.clone(),
102 catalog.clone(),
103 self.auth_basic.clone(),
104 self.timeout.unwrap_or_default(),
106 proxy,
107 );
108
109 Ok(Box::new(Import {
110 catalog,
111 schema,
112 database_client,
113 input_dir: self.input_dir.clone(),
114 parallelism: self.db_parallelism,
115 target: self.target.clone(),
116 }))
117 }
118}
119
/// The import tool built by `ImportCommand::build`.
pub struct Import {
    /// Catalog name; also the sub-directory of `input_dir` that is scanned.
    catalog: String,
    /// Single schema to import; `None` imports every directory found under the catalog path.
    schema: Option<String>,
    /// Client used to execute the SQL read from the input files.
    database_client: DatabaseClient,
    /// Root directory of the input files.
    input_dir: String,
    /// Upper bound on the number of per-database jobs running at the same time.
    parallelism: usize,
    /// Which part of the import `do_work` runs.
    target: ImportTarget,
}
128
129impl Import {
130 async fn import_create_table(&self) -> Result<()> {
131 self.do_sql_job("create_database.sql", Some(DEFAULT_SCHEMA_NAME))
133 .await?;
134 self.do_sql_job("create_tables.sql", None).await
135 }
136
137 async fn import_database_data(&self) -> Result<()> {
138 self.do_sql_job("copy_from.sql", None).await
139 }
140
141 async fn do_sql_job(&self, filename: &str, exec_db: Option<&str>) -> Result<()> {
142 let timer = Instant::now();
143 let semaphore = Arc::new(Semaphore::new(self.parallelism));
144 let db_names = self.get_db_names().await?;
145 let db_count = db_names.len();
146 let mut tasks = Vec::with_capacity(db_count);
147 for schema in db_names {
148 let semaphore_moved = semaphore.clone();
149 tasks.push(async move {
150 let _permit = semaphore_moved.acquire().await.unwrap();
151 let database_input_dir = self.catalog_path().join(&schema);
152 let sql_file = database_input_dir.join(filename);
153 let mut sql = tokio::fs::read_to_string(sql_file)
154 .await
155 .context(FileIoSnafu)?;
156 if sql.trim().is_empty() {
157 info!("Empty `{filename}` {database_input_dir:?}");
158 } else {
159 if filename == "copy_from.sql" {
160 sql = self.rewrite_copy_database_sql(&schema, &sql)?;
161 }
162 let db = exec_db.unwrap_or(&schema);
163 self.database_client.sql(&sql, db).await?;
164 info!("Imported `{filename}` for database {schema}");
165 }
166
167 Ok::<(), Error>(())
168 })
169 }
170
171 let success = futures::future::join_all(tasks)
172 .await
173 .into_iter()
174 .filter(|r| match r {
175 Ok(_) => true,
176 Err(e) => {
177 error!(e; "import {filename} job failed");
178 false
179 }
180 })
181 .count();
182 let elapsed = timer.elapsed();
183 info!("Success {success}/{db_count} `{filename}` jobs, cost: {elapsed:?}");
184
185 Ok(())
186 }
187
188 fn catalog_path(&self) -> PathBuf {
189 PathBuf::from(&self.input_dir).join(&self.catalog)
190 }
191
192 async fn get_db_names(&self) -> Result<Vec<String>> {
193 let db_names = self.all_db_names().await?;
194 let Some(schema) = &self.schema else {
195 return Ok(db_names);
196 };
197
198 db_names
200 .into_iter()
201 .find(|db_name| db_name.to_lowercase() == schema.to_lowercase())
202 .map(|name| vec![name])
203 .context(SchemaNotFoundSnafu {
204 catalog: &self.catalog,
205 schema,
206 })
207 }
208
209 async fn all_db_names(&self) -> Result<Vec<String>> {
216 let mut db_names = vec![];
217 let path = self.catalog_path();
218 let mut entries = tokio::fs::read_dir(path).await.context(FileIoSnafu)?;
219 while let Some(entry) = entries.next_entry().await.context(FileIoSnafu)? {
220 let path = entry.path();
221 if path.is_dir() {
222 let db_name = match path.file_name() {
223 Some(name) => name.to_string_lossy().to_string(),
224 None => {
225 warn!("Failed to get the file name of {:?}", path);
226 continue;
227 }
228 };
229 db_names.push(db_name);
230 }
231 }
232 Ok(db_names)
233 }
234
235 fn rewrite_copy_database_sql(&self, schema: &str, sql: &str) -> Result<String> {
236 let target_location = self.build_copy_database_location(schema);
237 let escaped_location = target_location.replace('\'', "''");
238
239 let mut first_stmt_checked = false;
240 for line in sql.lines() {
241 let trimmed = line.trim_start();
242 if trimmed.is_empty() || trimmed.starts_with("--") {
243 continue;
244 }
245
246 ensure!(
247 trimmed.starts_with("COPY DATABASE"),
248 InvalidArgumentsSnafu {
249 msg: "Expected COPY DATABASE statement at start of copy_from.sql"
250 }
251 );
252 first_stmt_checked = true;
253 break;
254 }
255
256 ensure!(
257 first_stmt_checked,
258 InvalidArgumentsSnafu {
259 msg: "COPY DATABASE statement not found in copy_from.sql"
260 }
261 );
262
263 ensure!(
264 sql.contains(COPY_PATH_PLACEHOLDER),
265 InvalidArgumentsSnafu {
266 msg: format!(
267 "Placeholder `{}` not found in COPY DATABASE statement",
268 COPY_PATH_PLACEHOLDER
269 )
270 }
271 );
272
273 Ok(sql.replacen(COPY_PATH_PLACEHOLDER, &escaped_location, 1))
274 }
275
276 fn build_copy_database_location(&self, schema: &str) -> String {
277 let mut path = self.catalog_path();
278 path.push(schema);
279 let mut path_str = path.to_string_lossy().into_owned();
280 if !path_str.ends_with('/') {
281 path_str.push('/');
282 }
283 path_str
284 }
285}
286
287#[async_trait]
288impl Tool for Import {
289 async fn do_work(&self) -> std::result::Result<(), BoxedError> {
290 match self.target {
291 ImportTarget::Schema => self.import_create_table().await.map_err(BoxedError::new),
292 ImportTarget::Data => self.import_database_data().await.map_err(BoxedError::new),
293 ImportTarget::All => {
294 self.import_create_table().await.map_err(BoxedError::new)?;
295 self.import_database_data().await.map_err(BoxedError::new)
296 }
297 }
298 }
299}
300
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;

    /// Creates an `Import` rooted at `input_dir` for catalog `catalog`; the
    /// client is only constructed, the tests below never send a request.
    fn build_import(input_dir: &str) -> Import {
        let database_client = DatabaseClient::new(
            String::from("127.0.0.1:4000"),
            String::from("catalog"),
            None,
            Duration::from_secs(0),
            None,
        );

        Import {
            catalog: String::from("catalog"),
            schema: None,
            database_client,
            input_dir: String::from(input_dir),
            parallelism: 1,
            target: ImportTarget::Data,
        }
    }

    /// The placeholder in the statement is replaced while the leading comment is kept verbatim.
    #[test]
    fn rewrite_copy_database_sql_replaces_placeholder() {
        let import = build_import("/tmp/export-path");
        let comment = "-- COPY DATABASE \"catalog\".\"schema\" FROM 's3://bucket/demo/' WITH (format = 'parquet') CONNECTION (region = 'us-west-2')";
        let statement = format!(
            "COPY DATABASE \"catalog\".\"schema\" FROM '{}' WITH (format = 'parquet');",
            COPY_PATH_PLACEHOLDER
        );
        let sql = [comment, statement.as_str()].join("\n");

        let rewritten = import.rewrite_copy_database_sql("schema", &sql).unwrap();

        let escaped = import
            .build_copy_database_location("schema")
            .replace('\'', "''");
        let expected_fragment = format!("FROM '{escaped}'");

        assert!(rewritten.starts_with(comment));
        assert!(rewritten.contains(&expected_fragment));
        assert!(!rewritten.contains(COPY_PATH_PLACEHOLDER));
    }

    /// A statement without the placeholder is rejected.
    #[test]
    fn rewrite_copy_database_sql_requires_placeholder() {
        let import = build_import("/tmp/export-path");
        let sql = "COPY DATABASE \"catalog\".\"schema\" FROM '/tmp/export-path/catalog/schema/' WITH (format = 'parquet');";

        let outcome = import.rewrite_copy_database_sql("schema", sql);

        assert!(outcome.is_err());
    }
}