use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;

use async_trait::async_trait;
use clap::{Parser, ValueEnum};
use common_catalog::consts::DEFAULT_SCHEMA_NAME;
use common_error::ext::BoxedError;
use common_telemetry::{error, info, warn};
use snafu::{OptionExt, ResultExt};
use tokio::sync::Semaphore;
use tokio::time::Instant;

use crate::database::{parse_proxy_opts, DatabaseClient};
use crate::error::{Error, FileIoSnafu, Result, SchemaNotFoundSnafu};
use crate::{database, Tool};

#[derive(Debug, Default, Clone, ValueEnum)]
enum ImportTarget {
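    /// Import only the table schemas.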
    Schema,
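    /// Import only the table data.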
    Data,
    /// Import both schemas and data.
    #[default]
    All,
}

#[derive(Debug, Default, Parser)]
pub struct ImportCommand {
    #[clap(long)]
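    /// Server address to connect to.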
    addr: String,

    #[clap(long)]
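    /// Directory of the data to import.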
    input_dir: String,

    #[clap(long, default_value = "greptime-*")]
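    /// The name of the database to import.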
    database: String,

    #[clap(long, short = 'j', default_value = "1")]
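    /// The number of databases imported in parallel.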
    import_jobs: usize,

    #[clap(long, default_value = "3")]
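    /// Max retry times for each job.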
    max_retry: usize,

    #[clap(long, short = 't', value_enum, default_value = "all")]
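    /// What to import: schemas, data, or both.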
    target: ImportTarget,

    #[clap(long)]
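    /// Basic authentication for the server.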
    auth_basic: Option<String>,

    #[clap(long, value_parser = humantime::parse_duration)]
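    /// Timeout for the database client, e.g. `30s`.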
    timeout: Option<Duration>,

    #[clap(long)]
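    /// Address of the proxy server.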
    proxy: Option<String>,

    #[clap(long, default_value = "false")]
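    /// Disable any proxy for the connection.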
    no_proxy: bool,
}

impl ImportCommand {
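    /// Builds the `Import` tool from the parsed command-line arguments.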
    pub async fn build(&self) -> std::result::Result<Box<dyn Tool>, BoxedError> {
        let (catalog, schema) =
            database::split_database(&self.database).map_err(BoxedError::new)?;
        let proxy = parse_proxy_opts(self.proxy.clone(), self.no_proxy)?;
        let database_client = DatabaseClient::new(
            self.addr.clone(),
            catalog.clone(),
            self.auth_basic.clone(),
            self.timeout.unwrap_or_default(),
            proxy,
        );

        Ok(Box::new(Import {
            catalog,
            schema,
            database_client,
            input_dir: self.input_dir.clone(),
            parallelism: self.import_jobs,
            target: self.target.clone(),
        }))
    }
}

pub struct Import {
    catalog: String,
    schema: Option<String>,
    database_client: DatabaseClient,
    input_dir: String,
    parallelism: usize,
    target: ImportTarget,
}

impl Import {
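    /// Imports the schemas: runs `create_database.sql` against the default
    /// schema, then `create_tables.sql` against each database.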
    async fn import_create_table(&self) -> Result<()> {
        self.do_sql_job("create_database.sql", Some(DEFAULT_SCHEMA_NAME))
            .await?;
        self.do_sql_job("create_tables.sql", None).await
    }

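    /// Imports the table data by running `copy_from.sql` against each database.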
    async fn import_database_data(&self) -> Result<()> {
        self.do_sql_job("copy_from.sql", None).await
    }

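    /// Executes the SQL file `filename` for every database, with at most
    /// `parallelism` jobs in flight. When `exec_db` is set, the SQL runs
    /// against that database instead of the one being imported.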
    async fn do_sql_job(&self, filename: &str, exec_db: Option<&str>) -> Result<()> {
        let timer = Instant::now();
        let semaphore = Arc::new(Semaphore::new(self.parallelism));
        let db_names = self.get_db_names().await?;
        let db_count = db_names.len();
        let mut tasks = Vec::with_capacity(db_count);
        for schema in db_names {
            let semaphore_moved = semaphore.clone();
            tasks.push(async move {
                let _permit = semaphore_moved.acquire().await.unwrap();
                let database_input_dir = self.catalog_path().join(&schema);
                let sql_file = database_input_dir.join(filename);
                let sql = tokio::fs::read_to_string(sql_file)
                    .await
                    .context(FileIoSnafu)?;
                if sql.is_empty() {
                    info!("Empty `{filename}` {database_input_dir:?}");
                } else {
                    let db = exec_db.unwrap_or(&schema);
                    self.database_client.sql(&sql, db).await?;
                    info!("Imported `{filename}` for database {schema}");
                }

                Ok::<(), Error>(())
            });
        }

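        // Run all jobs and count the successful ones; failures are only logged.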
        let success = futures::future::join_all(tasks)
            .await
            .into_iter()
            .filter(|r| match r {
                Ok(_) => true,
                Err(e) => {
                    error!(e; "import {filename} job failed");
                    false
                }
            })
            .count();
        let elapsed = timer.elapsed();
        info!("Success {success}/{db_count} `{filename}` jobs, cost: {elapsed:?}");

        Ok(())
    }

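    /// Returns the directory holding the catalog's data: `<input_dir>/<catalog>`.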
    fn catalog_path(&self) -> PathBuf {
        PathBuf::from(&self.input_dir).join(&self.catalog)
    }

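    /// Returns the databases to import: every database under the catalog, or
    /// only the selected schema if one was specified.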
    async fn get_db_names(&self) -> Result<Vec<String>> {
        let db_names = self.all_db_names().await?;
        let Some(schema) = &self.schema else {
            return Ok(db_names);
        };

        db_names
            .into_iter()
            .find(|db_name| db_name.to_lowercase() == schema.to_lowercase())
            .map(|name| vec![name])
            .context(SchemaNotFoundSnafu {
                catalog: &self.catalog,
                schema,
            })
    }

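    /// Lists all databases by reading the directory names under the catalog directory.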
    async fn all_db_names(&self) -> Result<Vec<String>> {
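        // Every sub-directory of the catalog directory is treated as a database.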
        let mut db_names = vec![];
        let path = self.catalog_path();
        let mut entries = tokio::fs::read_dir(path).await.context(FileIoSnafu)?;
        while let Some(entry) = entries.next_entry().await.context(FileIoSnafu)? {
            let path = entry.path();
            if path.is_dir() {
                let db_name = match path.file_name() {
                    Some(name) => name.to_string_lossy().to_string(),
                    None => {
                        warn!("Failed to get the file name of {:?}", path);
                        continue;
                    }
                };
                db_names.push(db_name);
            }
        }
        Ok(db_names)
    }
}

#[async_trait]
impl Tool for Import {
    async fn do_work(&self) -> std::result::Result<(), BoxedError> {
        match self.target {
            ImportTarget::Schema => self.import_create_table().await.map_err(BoxedError::new),
            ImportTarget::Data => self.import_database_data().await.map_err(BoxedError::new),
            ImportTarget::All => {
                self.import_create_table().await.map_err(BoxedError::new)?;
                self.import_database_data().await.map_err(BoxedError::new)
            }
        }
    }
}