1use std::path::PathBuf;
16use std::sync::Arc;
17use std::time::Duration;
18
19use async_trait::async_trait;
20use clap::{Parser, ValueEnum};
21use common_catalog::consts::DEFAULT_SCHEMA_NAME;
22use common_error::ext::BoxedError;
23use common_telemetry::{error, info, warn};
24use snafu::{OptionExt, ResultExt, ensure};
25use tokio::sync::Semaphore;
26use tokio::time::Instant;
27
28use crate::data::{COPY_PATH_PLACEHOLDER, default_database};
29use crate::database::{DatabaseClient, parse_proxy_opts};
30use crate::error::{Error, FileIoSnafu, InvalidArgumentsSnafu, Result, SchemaNotFoundSnafu};
31use crate::{Tool, database};
32
// Which import phases `Import::do_work` runs; selected with `--target` / `-t`.
// Plain `//` comments are used on this clap-derived type on purpose: `///` doc
// comments would be picked up by clap and change the `--help` output.
#[derive(Debug, Default, Clone, ValueEnum)]
enum ImportTarget {
    // Only the schema phase: `create_database.sql`, then `create_tables.sql`.
    Schema,
    // Only the data phase: each database's `copy_from.sql`.
    Data,
    // Schema phase followed by the data phase; the data phase is skipped when
    // the schema phase returns an error.
    #[default]
    All,
}
43
// Command-line options for importing a previously exported directory tree.
// Plain `//` comments are used on this clap-derived struct on purpose: `///`
// doc comments would be picked up by clap and change the `--help` output.
#[derive(Debug, Default, Parser)]
pub struct ImportCommand {
    // Server address handed to `DatabaseClient::new`.
    #[clap(long)]
    addr: String,

    // Root directory to read from; scripts are looked up under
    // `<input_dir>/<catalog>/<database>/`.
    #[clap(long)]
    input_dir: String,

    // Target database, split by `database::split_database` into a catalog and
    // an optional schema. Without a schema, every database directory found
    // under the catalog directory is imported.
    #[clap(long, default_value_t = default_database())]
    database: String,

    // Maximum number of databases processed concurrently (`-j`, also accepted
    // as `--import-jobs`); becomes `Import::parallelism`.
    #[clap(long, short = 'j', default_value = "1", alias = "import-jobs")]
    db_parallelism: usize,

    // NOTE(review): not read by `build`, so no retry logic is visible in this
    // file — confirm whether this option is still used.
    #[clap(long, default_value = "3")]
    max_retry: usize,

    // Which import phases to run (defaults to all of them).
    #[clap(long, short = 't', value_enum, default_value = "all")]
    target: ImportTarget,

    // Basic-auth credentials handed to `DatabaseClient::new`
    // (presumably `user:password` — TODO confirm).
    #[clap(long)]
    auth_basic: Option<String>,

    // Request timeout parsed with `humantime` (e.g. `30s`). When omitted,
    // `Duration::default()` (zero) is passed to the client; how the client
    // interprets a zero timeout is not visible here.
    #[clap(long, value_parser = humantime::parse_duration)]
    timeout: Option<Duration>,

    // Proxy setting forwarded to `parse_proxy_opts`.
    #[clap(long)]
    proxy: Option<String>,

    // Forwarded to both `parse_proxy_opts` and `DatabaseClient::new`;
    // presumably disables any proxy — TODO confirm.
    #[clap(long, default_value = "false")]
    no_proxy: bool,
}
97
98impl ImportCommand {
99 pub async fn build(&self) -> std::result::Result<Box<dyn Tool>, BoxedError> {
100 let (catalog, schema) =
101 database::split_database(&self.database).map_err(BoxedError::new)?;
102 let proxy = parse_proxy_opts(self.proxy.clone(), self.no_proxy)?;
103 let database_client = DatabaseClient::new(
104 self.addr.clone(),
105 catalog.clone(),
106 self.auth_basic.clone(),
107 self.timeout.unwrap_or_default(),
109 proxy,
110 self.no_proxy,
111 );
112
113 Ok(Box::new(Import {
114 catalog,
115 schema,
116 database_client,
117 input_dir: self.input_dir.clone(),
118 parallelism: self.db_parallelism,
119 target: self.target.clone(),
120 }))
121 }
122}
123
124pub struct Import {
125 catalog: String,
126 schema: Option<String>,
127 database_client: DatabaseClient,
128 input_dir: String,
129 parallelism: usize,
130 target: ImportTarget,
131}
132
133impl Import {
134 async fn import_create_table(&self) -> Result<()> {
135 self.do_sql_job("create_database.sql", Some(DEFAULT_SCHEMA_NAME))
137 .await?;
138 self.do_sql_job("create_tables.sql", None).await
139 }
140
141 async fn import_database_data(&self) -> Result<()> {
142 self.do_sql_job("copy_from.sql", None).await
143 }
144
145 async fn do_sql_job(&self, filename: &str, exec_db: Option<&str>) -> Result<()> {
146 let timer = Instant::now();
147 let semaphore = Arc::new(Semaphore::new(self.parallelism));
148 let db_names = self.get_db_names().await?;
149 let db_count = db_names.len();
150 let mut tasks = Vec::with_capacity(db_count);
151 for schema in db_names {
152 let semaphore_moved = semaphore.clone();
153 tasks.push(async move {
154 let _permit = semaphore_moved.acquire().await.unwrap();
155 let database_input_dir = self.catalog_path().join(&schema);
156 let sql_file = database_input_dir.join(filename);
157 let mut sql = tokio::fs::read_to_string(sql_file)
158 .await
159 .context(FileIoSnafu)?;
160 if sql.trim().is_empty() {
161 info!("Empty `{filename}` {database_input_dir:?}");
162 } else {
163 if filename == "copy_from.sql" {
164 sql = self.rewrite_copy_database_sql(&schema, &sql)?;
165 }
166 let db = exec_db.unwrap_or(&schema);
167 self.database_client.sql(&sql, db).await?;
168 info!("Imported `{filename}` for database {schema}");
169 }
170
171 Ok::<(), Error>(())
172 })
173 }
174
175 let success = futures::future::join_all(tasks)
176 .await
177 .into_iter()
178 .filter(|r| match r {
179 Ok(_) => true,
180 Err(e) => {
181 error!(e; "import {filename} job failed");
182 false
183 }
184 })
185 .count();
186 let elapsed = timer.elapsed();
187 info!("Success {success}/{db_count} `{filename}` jobs, cost: {elapsed:?}");
188
189 Ok(())
190 }
191
192 fn catalog_path(&self) -> PathBuf {
193 PathBuf::from(&self.input_dir).join(&self.catalog)
194 }
195
196 async fn get_db_names(&self) -> Result<Vec<String>> {
197 let db_names = self.all_db_names().await?;
198 let Some(schema) = &self.schema else {
199 return Ok(db_names);
200 };
201
202 db_names
204 .into_iter()
205 .find(|db_name| db_name.to_lowercase() == schema.to_lowercase())
206 .map(|name| vec![name])
207 .context(SchemaNotFoundSnafu {
208 catalog: &self.catalog,
209 schema,
210 })
211 }
212
213 async fn all_db_names(&self) -> Result<Vec<String>> {
220 let mut db_names = vec![];
221 let path = self.catalog_path();
222 let mut entries = tokio::fs::read_dir(path).await.context(FileIoSnafu)?;
223 while let Some(entry) = entries.next_entry().await.context(FileIoSnafu)? {
224 let path = entry.path();
225 if path.is_dir() {
226 let db_name = match path.file_name() {
227 Some(name) => name.to_string_lossy().to_string(),
228 None => {
229 warn!("Failed to get the file name of {:?}", path);
230 continue;
231 }
232 };
233 db_names.push(db_name);
234 }
235 }
236 Ok(db_names)
237 }
238
239 fn rewrite_copy_database_sql(&self, schema: &str, sql: &str) -> Result<String> {
240 let target_location = self.build_copy_database_location(schema);
241 let escaped_location = target_location.replace('\'', "''");
242
243 let mut first_stmt_checked = false;
244 for line in sql.lines() {
245 let trimmed = line.trim_start();
246 if trimmed.is_empty() || trimmed.starts_with("--") {
247 continue;
248 }
249
250 ensure!(
251 trimmed.starts_with("COPY DATABASE"),
252 InvalidArgumentsSnafu {
253 msg: "Expected COPY DATABASE statement at start of copy_from.sql"
254 }
255 );
256 first_stmt_checked = true;
257 break;
258 }
259
260 ensure!(
261 first_stmt_checked,
262 InvalidArgumentsSnafu {
263 msg: "COPY DATABASE statement not found in copy_from.sql"
264 }
265 );
266
267 ensure!(
268 sql.contains(COPY_PATH_PLACEHOLDER),
269 InvalidArgumentsSnafu {
270 msg: format!(
271 "Placeholder `{}` not found in COPY DATABASE statement",
272 COPY_PATH_PLACEHOLDER
273 )
274 }
275 );
276
277 Ok(sql.replacen(COPY_PATH_PLACEHOLDER, &escaped_location, 1))
278 }
279
280 fn build_copy_database_location(&self, schema: &str) -> String {
281 let mut path = self.catalog_path();
282 path.push(schema);
283 let mut path_str = path.to_string_lossy().into_owned();
284 if !path_str.ends_with('/') {
285 path_str.push('/');
286 }
287 path_str
288 }
289}
290
291#[async_trait]
292impl Tool for Import {
293 async fn do_work(&self) -> std::result::Result<(), BoxedError> {
294 match self.target {
295 ImportTarget::Schema => self.import_create_table().await.map_err(BoxedError::new),
296 ImportTarget::Data => self.import_database_data().await.map_err(BoxedError::new),
297 ImportTarget::All => {
298 self.import_create_table().await.map_err(BoxedError::new)?;
299 self.import_database_data().await.map_err(BoxedError::new)
300 }
301 }
302 }
303}
304
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;

    /// Builds an `Import` rooted at `input_dir` with catalog `catalog`; the
    /// client is never contacted by these tests.
    fn build_import(input_dir: &str) -> Import {
        Import {
            catalog: "catalog".to_string(),
            schema: None,
            database_client: DatabaseClient::new(
                "127.0.0.1:4000".to_string(),
                "catalog".to_string(),
                None,
                Duration::from_secs(0),
                None,
                false,
            ),
            input_dir: input_dir.to_string(),
            parallelism: 1,
            target: ImportTarget::Data,
        }
    }

    #[test]
    fn rewrite_copy_database_sql_replaces_placeholder() {
        let import = build_import("/tmp/export-path");
        let comment = "-- COPY DATABASE \"catalog\".\"schema\" FROM 's3://bucket/demo/' WITH (format = 'parquet') CONNECTION (region = 'us-west-2')";
        let sql = format!(
            "{comment}\nCOPY DATABASE \"catalog\".\"schema\" FROM '{}' WITH (format = 'parquet');",
            COPY_PATH_PLACEHOLDER
        );

        let rewritten = import.rewrite_copy_database_sql("schema", &sql).unwrap();
        let expected_location = import.build_copy_database_location("schema");
        let escaped = expected_location.replace('\'', "''");

        assert!(rewritten.starts_with(comment));
        assert!(rewritten.contains(&format!("FROM '{escaped}'")));
        assert!(!rewritten.contains(COPY_PATH_PLACEHOLDER));
    }

    #[test]
    fn rewrite_copy_database_sql_requires_placeholder() {
        let import = build_import("/tmp/export-path");
        let sql = "COPY DATABASE \"catalog\".\"schema\" FROM '/tmp/export-path/catalog/schema/' WITH (format = 'parquet');";
        assert!(import.rewrite_copy_database_sql("schema", sql).is_err());
    }

    #[test]
    fn rewrite_copy_database_sql_requires_a_statement() {
        let import = build_import("/tmp/export-path");
        let sql = "-- only a comment\n\n   \n";
        assert!(import.rewrite_copy_database_sql("schema", sql).is_err());
    }

    #[test]
    fn rewrite_copy_database_sql_rejects_non_copy_statement() {
        let import = build_import("/tmp/export-path");
        let sql = format!("-- leading comment\nSELECT '{}';", COPY_PATH_PLACEHOLDER);
        assert!(import.rewrite_copy_database_sql("schema", &sql).is_err());
    }

    #[test]
    fn rewrite_copy_database_sql_escapes_single_quotes() {
        let import = build_import("/tmp/it's");
        let sql = format!(
            "COPY DATABASE \"catalog\".\"schema\" FROM '{}' WITH (format = 'parquet');",
            COPY_PATH_PLACEHOLDER
        );

        let rewritten = import.rewrite_copy_database_sql("schema", &sql).unwrap();

        assert!(rewritten.contains("it''s"));
        assert!(!rewritten.contains(COPY_PATH_PLACEHOLDER));
    }

    #[test]
    fn build_copy_database_location_ends_with_slash() {
        let import = build_import("/tmp/export-path");
        let location = import.build_copy_database_location("schema");
        assert!(location.ends_with('/'));
        assert!(location.contains("catalog"));
        assert!(location.contains("schema"));
    }
}