1use std::path::PathBuf;
16use std::sync::Arc;
17use std::time::Duration;
18
19use async_trait::async_trait;
20use clap::{Parser, ValueEnum};
21use common_catalog::consts::DEFAULT_SCHEMA_NAME;
22use common_error::ext::BoxedError;
23use common_telemetry::{error, info, warn};
24use snafu::{OptionExt, ResultExt};
25use tokio::sync::Semaphore;
26use tokio::time::Instant;
27
28use crate::database::{parse_proxy_opts, DatabaseClient};
29use crate::error::{Error, FileIoSnafu, Result, SchemaNotFoundSnafu};
30use crate::{database, Tool};
31
32#[derive(Debug, Default, Clone, ValueEnum)]
33enum ImportTarget {
34 Schema,
36 Data,
38 #[default]
40 All,
41}
42
43#[derive(Debug, Default, Parser)]
44pub struct ImportCommand {
45 #[clap(long)]
47 addr: String,
48
49 #[clap(long)]
51 input_dir: String,
52
53 #[clap(long, default_value = "greptime-*")]
55 database: String,
56
57 #[clap(long, short = 'j', default_value = "1")]
59 import_jobs: usize,
60
61 #[clap(long, default_value = "3")]
63 max_retry: usize,
64
65 #[clap(long, short = 't', value_enum, default_value = "all")]
67 target: ImportTarget,
68
69 #[clap(long)]
71 auth_basic: Option<String>,
72
73 #[clap(long, value_parser = humantime::parse_duration)]
78 timeout: Option<Duration>,
79
80 #[clap(long)]
84 proxy: Option<String>,
85
86 #[clap(long, default_value = "false")]
88 no_proxy: bool,
89}
90
91impl ImportCommand {
92 pub async fn build(&self) -> std::result::Result<Box<dyn Tool>, BoxedError> {
93 let (catalog, schema) =
94 database::split_database(&self.database).map_err(BoxedError::new)?;
95 let proxy = parse_proxy_opts(self.proxy.clone(), self.no_proxy)?;
96 let database_client = DatabaseClient::new(
97 self.addr.clone(),
98 catalog.clone(),
99 self.auth_basic.clone(),
100 self.timeout.unwrap_or_default(),
102 proxy,
103 );
104
105 Ok(Box::new(Import {
106 catalog,
107 schema,
108 database_client,
109 input_dir: self.input_dir.clone(),
110 parallelism: self.import_jobs,
111 target: self.target.clone(),
112 }))
113 }
114}
115
116pub struct Import {
117 catalog: String,
118 schema: Option<String>,
119 database_client: DatabaseClient,
120 input_dir: String,
121 parallelism: usize,
122 target: ImportTarget,
123}
124
125impl Import {
126 async fn import_create_table(&self) -> Result<()> {
127 self.do_sql_job("create_database.sql", Some(DEFAULT_SCHEMA_NAME))
129 .await?;
130 self.do_sql_job("create_tables.sql", None).await
131 }
132
133 async fn import_database_data(&self) -> Result<()> {
134 self.do_sql_job("copy_from.sql", None).await
135 }
136
137 async fn do_sql_job(&self, filename: &str, exec_db: Option<&str>) -> Result<()> {
138 let timer = Instant::now();
139 let semaphore = Arc::new(Semaphore::new(self.parallelism));
140 let db_names = self.get_db_names().await?;
141 let db_count = db_names.len();
142 let mut tasks = Vec::with_capacity(db_count);
143 for schema in db_names {
144 let semaphore_moved = semaphore.clone();
145 tasks.push(async move {
146 let _permit = semaphore_moved.acquire().await.unwrap();
147 let database_input_dir = self.catalog_path().join(&schema);
148 let sql_file = database_input_dir.join(filename);
149 let sql = tokio::fs::read_to_string(sql_file)
150 .await
151 .context(FileIoSnafu)?;
152 if sql.is_empty() {
153 info!("Empty `{filename}` {database_input_dir:?}");
154 } else {
155 let db = exec_db.unwrap_or(&schema);
156 self.database_client.sql(&sql, db).await?;
157 info!("Imported `{filename}` for database {schema}");
158 }
159
160 Ok::<(), Error>(())
161 })
162 }
163
164 let success = futures::future::join_all(tasks)
165 .await
166 .into_iter()
167 .filter(|r| match r {
168 Ok(_) => true,
169 Err(e) => {
170 error!(e; "import {filename} job failed");
171 false
172 }
173 })
174 .count();
175 let elapsed = timer.elapsed();
176 info!("Success {success}/{db_count} `{filename}` jobs, cost: {elapsed:?}");
177
178 Ok(())
179 }
180
181 fn catalog_path(&self) -> PathBuf {
182 PathBuf::from(&self.input_dir).join(&self.catalog)
183 }
184
185 async fn get_db_names(&self) -> Result<Vec<String>> {
186 let db_names = self.all_db_names().await?;
187 let Some(schema) = &self.schema else {
188 return Ok(db_names);
189 };
190
191 db_names
193 .into_iter()
194 .find(|db_name| db_name.to_lowercase() == schema.to_lowercase())
195 .map(|name| vec![name])
196 .context(SchemaNotFoundSnafu {
197 catalog: &self.catalog,
198 schema,
199 })
200 }
201
202 async fn all_db_names(&self) -> Result<Vec<String>> {
209 let mut db_names = vec![];
210 let path = self.catalog_path();
211 let mut entries = tokio::fs::read_dir(path).await.context(FileIoSnafu)?;
212 while let Some(entry) = entries.next_entry().await.context(FileIoSnafu)? {
213 let path = entry.path();
214 if path.is_dir() {
215 let db_name = match path.file_name() {
216 Some(name) => name.to_string_lossy().to_string(),
217 None => {
218 warn!("Failed to get the file name of {:?}", path);
219 continue;
220 }
221 };
222 db_names.push(db_name);
223 }
224 }
225 Ok(db_names)
226 }
227}
228
229#[async_trait]
230impl Tool for Import {
231 async fn do_work(&self) -> std::result::Result<(), BoxedError> {
232 match self.target {
233 ImportTarget::Schema => self.import_create_table().await.map_err(BoxedError::new),
234 ImportTarget::Data => self.import_database_data().await.map_err(BoxedError::new),
235 ImportTarget::All => {
236 self.import_create_table().await.map_err(BoxedError::new)?;
237 self.import_database_data().await.map_err(BoxedError::new)
238 }
239 }
240 }
241}