use std::collections::HashMap;
use std::path::Path;
use std::str::FromStr;
use std::sync::Arc;

use client::{Output, OutputData, OutputMeta};
use common_catalog::format_full_table_name;
use common_datasource::file_format::Format;
use common_datasource::lister::{Lister, Source};
use common_datasource::object_store::build_backend;
use common_stat::get_total_cpu_cores;
use common_telemetry::{debug, error, info, tracing};
use futures::future::try_join_all;
use object_store::Entry;
use regex::Regex;
use session::context::QueryContextRef;
use snafu::{OptionExt, ResultExt, ensure};
use store_api::metric_engine_consts::{LOGICAL_TABLE_METADATA_KEY, METRIC_ENGINE_NAME};
use table::requests::{CopyDatabaseRequest, CopyDirection, CopyTableRequest};
use table::table_reference::TableReference;
use tokio::sync::Semaphore;

use crate::error;
use crate::error::{CatalogSnafu, InvalidCopyDatabasePathSnafu};
use crate::statement::StatementExecutor;

pub(crate) const COPY_DATABASE_TIME_START_KEY: &str = "start_time";
pub(crate) const COPY_DATABASE_TIME_END_KEY: &str = "end_time";
pub(crate) const CONTINUE_ON_ERROR_KEY: &str = "continue_on_error";
pub(crate) const PARALLELISM_KEY: &str = "parallelism";

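/// Parses the `parallelism` option from the copy options map. Falls back to the
/// total number of CPU cores when the value is absent or unparsable, and always
/// returns at least 1.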
fn parse_parallelism_from_option_map(options: &HashMap<String, String>) -> usize {
    options
        .get(PARALLELISM_KEY)
        .and_then(|v| v.parse::<usize>().ok())
        .unwrap_or_else(get_total_cpu_cores)
        .max(1)
}

impl StatementExecutor {
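    /// Exports all eligible tables in a database to files under `req.location`,
    /// copying up to `parallelism` tables concurrently. An illustrative statement
    /// that maps to this request (exact syntax assumed, not verified here):
    /// `COPY DATABASE db TO '/path/to/dir/' WITH (FORMAT = 'parquet');`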
    #[tracing::instrument(skip_all)]
    pub(crate) async fn copy_database_to(
        &self,
        req: CopyDatabaseRequest,
        ctx: QueryContextRef,
    ) -> error::Result<Output> {
        // The location must end with '/' to indicate a directory.
        ensure!(
            req.location.ends_with('/'),
            InvalidCopyDatabasePathSnafu {
                value: req.location,
            }
        );

        let parallelism = parse_parallelism_from_option_map(&req.with);
        info!(
            "Copy database {}.{} to dir: {}, time: {:?}, parallelism: {}",
            req.catalog_name, req.schema_name, req.location, req.time_range, parallelism
        );
        let table_names = self
            .catalog_manager
            .table_names(&req.catalog_name, &req.schema_name, Some(&ctx))
            .await
            .context(CatalogSnafu)?;
        let num_tables = table_names.len();

        let suffix = Format::try_from(&req.with)
            .context(error::ParseFileFormatSnafu)?
            .suffix();

        let mut tasks = Vec::with_capacity(num_tables);
        let semaphore = Arc::new(Semaphore::new(parallelism));

        for (i, table_name) in table_names.into_iter().enumerate() {
            let table = self
                .get_table(&TableReference {
                    catalog: &req.catalog_name,
                    schema: &req.schema_name,
                    table: &table_name,
                })
                .await?;
            // Only base tables are exported; views and other table types are skipped.
            if table.table_type() != table::metadata::TableType::Base {
                continue;
            }
            // Skips physical tables of the metric engine (tables without the
            // logical-table metadata key); their data is exported via the
            // corresponding logical tables.
            if table.table_info().meta.engine == METRIC_ENGINE_NAME
                && !table
                    .table_info()
                    .meta
                    .options
                    .extra_options
                    .contains_key(LOGICAL_TABLE_METADATA_KEY)
            {
                continue;
            }

            let semaphore_moved = semaphore.clone();
            let mut table_file = req.location.clone();
            table_file.push_str(&table_name);
            table_file.push_str(suffix);
            let table_no = i + 1;
            let moved_ctx = ctx.clone();
            let full_table_name =
                format_full_table_name(&req.catalog_name, &req.schema_name, &table_name);
            let copy_table_req = CopyTableRequest {
                catalog_name: req.catalog_name.clone(),
                schema_name: req.schema_name.clone(),
                table_name,
                location: table_file.clone(),
                with: req.with.clone(),
                connection: req.connection.clone(),
                pattern: None,
                direction: CopyDirection::Export,
                timestamp_range: req.time_range,
                limit: None,
            };

            tasks.push(async move {
                let _permit = semaphore_moved.acquire().await.unwrap();
                info!(
                    "Copy table({}/{}): {} to {}",
                    table_no, num_tables, full_table_name, table_file
                );
                self.copy_table_to(copy_table_req, moved_ctx).await
            });
        }

        let results = try_join_all(tasks).await?;
        let exported_rows = results.into_iter().sum();
        Ok(Output::new_with_affected_rows(exported_rows))
    }

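    /// Imports files under `req.location` into a database, deriving each target
    /// table name from the file stem (see `parse_file_name_to_copy`). Respects the
    /// `continue_on_error` and `parallelism` options in `req.with`.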
    #[tracing::instrument(skip_all)]
    pub(crate) async fn copy_database_from(
        &self,
        req: CopyDatabaseRequest,
        ctx: QueryContextRef,
    ) -> error::Result<Output> {
        // The location must end with '/' to indicate a directory.
        ensure!(
            req.location.ends_with('/'),
            InvalidCopyDatabasePathSnafu {
                value: req.location,
            }
        );

        let parallelism = parse_parallelism_from_option_map(&req.with);
        info!(
            "Copy database {}.{} from dir: {}, time: {:?}, parallelism: {}",
            req.catalog_name, req.schema_name, req.location, req.time_range, parallelism
        );
        let suffix = Format::try_from(&req.with)
            .context(error::ParseFileFormatSnafu)?
            .suffix();

        let entries = list_files_to_copy(&req, suffix).await?;

        let continue_on_error = req
            .with
            .get(CONTINUE_ON_ERROR_KEY)
            .and_then(|v| bool::from_str(v).ok())
            .unwrap_or(false);

        let mut tasks = Vec::with_capacity(entries.len());
        let semaphore = Arc::new(Semaphore::new(parallelism));

        for e in entries {
            let table_name = match parse_file_name_to_copy(&e) {
                Ok(table_name) => table_name,
                Err(err) => {
                    if continue_on_error {
                        error!(err; "Failed to import table from file: {:?}", e);
                        continue;
                    } else {
                        return Err(err);
                    }
                }
            };

            let req = CopyTableRequest {
                catalog_name: req.catalog_name.clone(),
                schema_name: req.schema_name.clone(),
                table_name: table_name.clone(),
                location: format!("{}/{}", req.location, e.path()),
                with: req.with.clone(),
                connection: req.connection.clone(),
                pattern: None,
                direction: CopyDirection::Import,
                timestamp_range: None,
                limit: None,
            };
            let moved_ctx = ctx.clone();
            let moved_table_name = table_name.clone();
            let moved_semaphore = semaphore.clone();
            tasks.push(async move {
                let _permit = moved_semaphore.acquire().await.unwrap();
                debug!("Copy table, arg: {:?}", req);
                match self.copy_table_from(req, moved_ctx).await {
                    Ok(o) => {
                        let (rows, cost) = o.extract_rows_and_cost();
                        Ok((rows, cost))
                    }
                    Err(err) => {
                        if continue_on_error {
                            error!(err; "Failed to import file to table: {}", moved_table_name);
                            Ok((0, 0))
                        } else {
                            Err(err)
                        }
                    }
                }
            });
        }

        let results = try_join_all(tasks).await?;
        let (rows_inserted, insert_cost) = results
            .into_iter()
            .fold((0, 0), |(acc_rows, acc_cost), (rows, cost)| {
                (acc_rows + rows, acc_cost + cost)
            });

        Ok(Output::new(
            OutputData::AffectedRows(rows_inserted),
            OutputMeta::new_with_cost(insert_cost),
        ))
    }
}

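/// Parses the table name to copy from a file entry's name, i.e. its file stem:
/// `a.parquet` maps to table `a`, and `e.f.parquet` to `e.f`.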
fn parse_file_name_to_copy(e: &Entry) -> error::Result<String> {
    Path::new(e.name())
        .file_stem()
        .and_then(|os_str| os_str.to_str())
        .map(|s| s.to_string())
        .context(error::InvalidTableNameSnafu {
            table_name: e.name().to_string(),
        })
}

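/// Lists all entries under the request's location whose names match the expected
/// format suffix, as candidates for import.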
async fn list_files_to_copy(req: &CopyDatabaseRequest, suffix: &str) -> error::Result<Vec<Entry>> {
    let object_store =
        build_backend(&req.location, &req.connection).context(error::BuildBackendSnafu)?;

    let pattern = Regex::try_from(format!(".*{}", suffix)).context(error::BuildRegexSnafu)?;
    let lister = Lister::new(
        object_store.clone(),
        Source::Dir,
        "/".to_string(),
        Some(pattern),
    );
    lister.list().await.context(error::ListObjectsSnafu)
}

#[cfg(test)]
mod tests {
    use std::collections::{HashMap, HashSet};

    use common_stat::get_total_cpu_cores;
    use object_store::ObjectStore;
    use object_store::services::Fs;
    use object_store::util::normalize_dir;
    use path_slash::PathExt;
    use table::requests::CopyDatabaseRequest;

    use crate::statement::copy_database::{
        list_files_to_copy, parse_file_name_to_copy, parse_parallelism_from_option_map,
    };

    #[tokio::test]
    async fn test_list_files_and_parse_table_name() {
        let dir = common_test_util::temp_dir::create_temp_dir("test_list_files_to_copy");
        let store_dir = normalize_dir(dir.path().to_str().unwrap());
        let builder = Fs::default().root(&store_dir);
        let object_store = ObjectStore::new(builder).unwrap().finish();
        object_store.write("a.parquet", "").await.unwrap();
        object_store.write("b.parquet", "").await.unwrap();
        object_store.write("c.csv", "").await.unwrap();
        object_store.write("d", "").await.unwrap();
        object_store.write("e.f.parquet", "").await.unwrap();

        let location = normalize_dir(&dir.path().to_slash().unwrap());
        let request = CopyDatabaseRequest {
            catalog_name: "catalog_0".to_string(),
            schema_name: "schema_0".to_string(),
            location,
            with: [("FORMAT".to_string(), "parquet".to_string())]
                .into_iter()
                .collect(),
            connection: Default::default(),
            time_range: None,
        };
        let listed = list_files_to_copy(&request, ".parquet")
            .await
            .unwrap()
            .into_iter()
            .map(|e| parse_file_name_to_copy(&e).unwrap())
            .collect::<HashSet<_>>();

        assert_eq!(
            ["a".to_string(), "b".to_string(), "e.f".to_string()]
                .into_iter()
                .collect::<HashSet<_>>(),
            listed
        );
    }

    #[test]
    fn test_parse_parallelism_from_option_map() {
        let options = HashMap::new();
        assert_eq!(
            parse_parallelism_from_option_map(&options),
            get_total_cpu_cores()
        );

        let options = HashMap::from([("parallelism".to_string(), "0".to_string())]);
        assert_eq!(parse_parallelism_from_option_map(&options), 1);
    }
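
    // Additional cases, added as an illustrative sketch: an explicit positive
    // `parallelism` value is used as-is, while an unparsable value falls back to
    // the CPU-core default, mirroring the assertions above.
    #[test]
    fn test_parse_parallelism_explicit_and_invalid() {
        let options = HashMap::from([("parallelism".to_string(), "4".to_string())]);
        assert_eq!(parse_parallelism_from_option_map(&options), 4);

        let options = HashMap::from([("parallelism".to_string(), "abc".to_string())]);
        assert_eq!(
            parse_parallelism_from_option_map(&options),
            get_total_cpu_cores().max(1)
        );
    }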
}