operator/statement/copy_database.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;
use std::path::Path;
use std::str::FromStr;
use std::sync::Arc;

use client::{Output, OutputData, OutputMeta};
use common_catalog::format_full_table_name;
use common_datasource::file_format::Format;
use common_datasource::lister::{Lister, Source};
use common_datasource::object_store::build_backend;
use common_stat::get_total_cpu_cores;
use common_telemetry::{debug, error, info, tracing};
use futures::future::try_join_all;
use object_store::Entry;
use regex::Regex;
use session::context::QueryContextRef;
use snafu::{OptionExt, ResultExt, ensure};
use store_api::metric_engine_consts::{LOGICAL_TABLE_METADATA_KEY, METRIC_ENGINE_NAME};
use table::requests::{CopyDatabaseRequest, CopyDirection, CopyTableRequest};
use table::table_reference::TableReference;
use tokio::sync::Semaphore;

use crate::error;
use crate::error::{CatalogSnafu, InvalidCopyDatabasePathSnafu};
use crate::statement::StatementExecutor;

/// Option key for the start of the copied time range.
pub(crate) const COPY_DATABASE_TIME_START_KEY: &str = "start_time";
/// Option key for the end of the copied time range.
pub(crate) const COPY_DATABASE_TIME_END_KEY: &str = "end_time";
/// Option key: when set to "true", a failed table import is logged and
/// skipped instead of aborting the whole `COPY DATABASE FROM` operation.
pub(crate) const CONTINUE_ON_ERROR_KEY: &str = "continue_on_error";
/// Option key: maximum number of tables copied concurrently.
pub(crate) const PARALLELISM_KEY: &str = "parallelism";

46/// Get parallelism from options, default to total CPU cores.
47fn parse_parallelism_from_option_map(options: &HashMap<String, String>) -> usize {
48    options
49        .get(PARALLELISM_KEY)
50        .and_then(|v| v.parse::<usize>().ok())
51        .unwrap_or_else(get_total_cpu_cores)
52        .max(1)
53}
54
impl StatementExecutor {
    /// Exports every base table of a database to files under `req.location`
    /// (`COPY DATABASE ... TO`), with at most `parallelism` tables copied
    /// concurrently. Returns the total number of exported rows as the
    /// affected-rows output.
    #[tracing::instrument(skip_all)]
    pub(crate) async fn copy_database_to(
        &self,
        req: CopyDatabaseRequest,
        ctx: QueryContextRef,
    ) -> error::Result<Output> {
        // location must end with / so that every table is exported to a file.
        ensure!(
            req.location.ends_with('/'),
            InvalidCopyDatabasePathSnafu {
                value: req.location,
            }
        );

        let parallelism = parse_parallelism_from_option_map(&req.with);
        info!(
            "Copy database {}.{} to dir: {}, time: {:?}, parallelism: {}",
            req.catalog_name, req.schema_name, req.location, req.time_range, parallelism
        );
        let table_names = self
            .catalog_manager
            .table_names(&req.catalog_name, &req.schema_name, Some(&ctx))
            .await
            .context(CatalogSnafu)?;
        let num_tables = table_names.len();

        // File-name suffix (e.g. ".parquet") derived from the FORMAT option.
        let suffix = Format::try_from(&req.with)
            .context(error::ParseFileFormatSnafu)?
            .suffix();

        let mut tasks = Vec::with_capacity(num_tables);
        // Bounds the number of concurrently running export tasks.
        let semaphore = Arc::new(Semaphore::new(parallelism));

        for (i, table_name) in table_names.into_iter().enumerate() {
            let table = self
                .get_table(&TableReference {
                    catalog: &req.catalog_name,
                    schema: &req.schema_name,
                    table: &table_name,
                })
                .await?;
            // Only base tables, ignores views and temporary tables.
            if table.table_type() != table::metadata::TableType::Base {
                continue;
            }
            // Ignores physical tables of metric engine.
            if table.table_info().meta.engine == METRIC_ENGINE_NAME
                && !table
                    .table_info()
                    .meta
                    .options
                    .extra_options
                    .contains_key(LOGICAL_TABLE_METADATA_KEY)
            {
                continue;
            }

            let semaphore_moved = semaphore.clone();
            // Target file: "<location><table_name><suffix>"; location is
            // guaranteed above to end with '/'.
            let mut table_file = req.location.clone();
            table_file.push_str(&table_name);
            table_file.push_str(suffix);
            // 1-based index, only used for progress logging.
            let table_no = i + 1;
            let moved_ctx = ctx.clone();
            let full_table_name =
                format_full_table_name(&req.catalog_name, &req.schema_name, &table_name);
            let copy_table_req = CopyTableRequest {
                catalog_name: req.catalog_name.clone(),
                schema_name: req.schema_name.clone(),
                table_name,
                location: table_file.clone(),
                with: req.with.clone(),
                connection: req.connection.clone(),
                pattern: None,
                direction: CopyDirection::Export,
                timestamp_range: req.time_range,
                limit: None,
            };

            tasks.push(async move {
                // Safe to unwrap: the semaphore is never closed.
                let _permit = semaphore_moved.acquire().await.unwrap();
                info!(
                    "Copy table({}/{}): {} to {}",
                    table_no, num_tables, full_table_name, table_file
                );
                self.copy_table_to(copy_table_req, moved_ctx).await
            });
        }

        // try_join_all short-circuits on the first failed export.
        let results = try_join_all(tasks).await?;
        let exported_rows = results.into_iter().sum();
        Ok(Output::new_with_affected_rows(exported_rows))
    }

    /// Imports data to database from a given location and returns total rows imported.
    ///
    /// Lists files under `req.location` matching the FORMAT suffix, imports
    /// each into the table named after the file stem (at most `parallelism`
    /// files concurrently). With the `continue_on_error` option set, failures
    /// on individual files are logged and counted as 0 rows instead of
    /// aborting the operation.
    #[tracing::instrument(skip_all)]
    pub(crate) async fn copy_database_from(
        &self,
        req: CopyDatabaseRequest,
        ctx: QueryContextRef,
    ) -> error::Result<Output> {
        // location must end with /
        ensure!(
            req.location.ends_with('/'),
            InvalidCopyDatabasePathSnafu {
                value: req.location,
            }
        );

        let parallelism = parse_parallelism_from_option_map(&req.with);
        info!(
            "Copy database {}.{} from dir: {}, time: {:?}, parallelism: {}",
            req.catalog_name, req.schema_name, req.location, req.time_range, parallelism
        );
        let suffix = Format::try_from(&req.with)
            .context(error::ParseFileFormatSnafu)?
            .suffix();

        let entries = list_files_to_copy(&req, suffix).await?;

        let continue_on_error = req
            .with
            .get(CONTINUE_ON_ERROR_KEY)
            .and_then(|v| bool::from_str(v).ok())
            .unwrap_or(false);

        let mut tasks = Vec::with_capacity(entries.len());
        // Bounds the number of concurrently running import tasks.
        let semaphore = Arc::new(Semaphore::new(parallelism));

        for e in entries {
            // The target table name is the file stem; an unparsable name is
            // either skipped (continue_on_error) or fails the whole copy.
            let table_name = match parse_file_name_to_copy(&e) {
                Ok(table_name) => table_name,
                Err(err) => {
                    if continue_on_error {
                        error!(err; "Failed to import table from file: {:?}", e);
                        continue;
                    } else {
                        return Err(err);
                    }
                }
            };

            // NOTE(review): `req.location` already ends with '/', so this
            // join yields a double slash before `e.path()` — presumably
            // normalized by the object-store backend; verify.
            let req = CopyTableRequest {
                catalog_name: req.catalog_name.clone(),
                schema_name: req.schema_name.clone(),
                table_name: table_name.clone(),
                location: format!("{}/{}", req.location, e.path()),
                with: req.with.clone(),
                connection: req.connection.clone(),
                pattern: None,
                direction: CopyDirection::Import,
                timestamp_range: None,
                limit: None,
            };
            let moved_ctx = ctx.clone();
            let moved_table_name = table_name.clone();
            let moved_semaphore = semaphore.clone();
            tasks.push(async move {
                // Safe to unwrap: the semaphore is never closed.
                let _permit = moved_semaphore.acquire().await.unwrap();
                debug!("Copy table, arg: {:?}", req);
                // Each task yields (rows, cost); a failed import becomes
                // (0, 0) when continue_on_error is set.
                match self.copy_table_from(req, moved_ctx).await {
                    Ok(o) => {
                        let (rows, cost) = o.extract_rows_and_cost();
                        Ok((rows, cost))
                    }
                    Err(err) => {
                        if continue_on_error {
                            error!(err; "Failed to import file to table: {}", moved_table_name);
                            Ok((0, 0))
                        } else {
                            Err(err)
                        }
                    }
                }
            });
        }

        // try_join_all short-circuits on the first task that returns Err.
        let results = try_join_all(tasks).await?;
        // Accumulate rows and cost across all imported files.
        let (rows_inserted, insert_cost) = results
            .into_iter()
            .fold((0, 0), |(acc_rows, acc_cost), (rows, cost)| {
                (acc_rows + rows, acc_cost + cost)
            });

        Ok(Output::new(
            OutputData::AffectedRows(rows_inserted),
            OutputMeta::new_with_cost(insert_cost),
        ))
    }
}

246/// Parses table names from files' names.
247fn parse_file_name_to_copy(e: &Entry) -> error::Result<String> {
248    Path::new(e.name())
249        .file_stem()
250        .and_then(|os_str| os_str.to_str())
251        .map(|s| s.to_string())
252        .context(error::InvalidTableNameSnafu {
253            table_name: e.name().to_string(),
254        })
255}
256
257/// Lists all files with expected suffix that can be imported to database.
258async fn list_files_to_copy(req: &CopyDatabaseRequest, suffix: &str) -> error::Result<Vec<Entry>> {
259    let object_store =
260        build_backend(&req.location, &req.connection).context(error::BuildBackendSnafu)?;
261
262    let pattern = Regex::try_from(format!(".*{}", suffix)).context(error::BuildRegexSnafu)?;
263    let lister = Lister::new(
264        object_store.clone(),
265        Source::Dir,
266        "/".to_string(),
267        Some(pattern),
268    );
269    lister.list().await.context(error::ListObjectsSnafu)
270}
271
#[cfg(test)]
mod tests {
    use std::collections::{HashMap, HashSet};

    use common_stat::get_total_cpu_cores;
    use object_store::ObjectStore;
    use object_store::services::Fs;
    use object_store::util::normalize_dir;
    use path_slash::PathExt;
    use table::requests::CopyDatabaseRequest;

    use crate::statement::copy_database::{
        list_files_to_copy, parse_file_name_to_copy, parse_parallelism_from_option_map,
    };

    /// Only files ending with the requested suffix are listed, and the table
    /// name is the file stem (only the last extension is stripped).
    #[tokio::test]
    async fn test_list_files_and_parse_table_name() {
        let dir = common_test_util::temp_dir::create_temp_dir("test_list_files_to_copy");
        let store_dir = normalize_dir(dir.path().to_str().unwrap());
        let builder = Fs::default().root(&store_dir);
        let object_store = ObjectStore::new(builder).unwrap().finish();
        // Mixed content: two parquet files, a csv, an extension-less file,
        // and a parquet file whose stem itself contains a dot.
        object_store.write("a.parquet", "").await.unwrap();
        object_store.write("b.parquet", "").await.unwrap();
        object_store.write("c.csv", "").await.unwrap();
        object_store.write("d", "").await.unwrap();
        object_store.write("e.f.parquet", "").await.unwrap();

        let location = normalize_dir(&dir.path().to_slash().unwrap());
        let request = CopyDatabaseRequest {
            catalog_name: "catalog_0".to_string(),
            schema_name: "schema_0".to_string(),
            location,
            with: [("FORMAT".to_string(), "parquet".to_string())]
                .into_iter()
                .collect(),
            connection: Default::default(),
            time_range: None,
        };
        let listed = list_files_to_copy(&request, ".parquet")
            .await
            .unwrap()
            .into_iter()
            .map(|e| parse_file_name_to_copy(&e).unwrap())
            .collect::<HashSet<_>>();

        assert_eq!(
            ["a".to_string(), "b".to_string(), "e.f".to_string()]
                .into_iter()
                .collect::<HashSet<_>>(),
            listed
        );
    }

    /// A missing option falls back to the CPU core count; a parsed value of
    /// 0 is clamped up to 1.
    #[test]
    fn test_parse_parallelism_from_option_map() {
        let options = HashMap::new();
        assert_eq!(
            parse_parallelism_from_option_map(&options),
            get_total_cpu_cores()
        );

        let options = HashMap::from([("parallelism".to_string(), "0".to_string())]);
        assert_eq!(parse_parallelism_from_option_map(&options), 1);
    }
}