operator/statement/
copy_database.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::path::Path;
16use std::str::FromStr;
17
18use client::{Output, OutputData, OutputMeta};
19use common_datasource::file_format::Format;
20use common_datasource::lister::{Lister, Source};
21use common_datasource::object_store::build_backend;
22use common_telemetry::{debug, error, info, tracing};
23use object_store::Entry;
24use regex::Regex;
25use session::context::QueryContextRef;
26use snafu::{ensure, OptionExt, ResultExt};
27use store_api::metric_engine_consts::{LOGICAL_TABLE_METADATA_KEY, METRIC_ENGINE_NAME};
28use table::requests::{CopyDatabaseRequest, CopyDirection, CopyTableRequest};
29use table::table_reference::TableReference;
30
31use crate::error;
32use crate::error::{CatalogSnafu, InvalidCopyDatabasePathSnafu};
33use crate::statement::StatementExecutor;
34
/// Option key with the literal value `start_time`.
/// NOTE(review): not referenced in this file; presumably names the lower bound of the time range
/// of a `COPY DATABASE` statement -- confirm against the code that builds `CopyDatabaseRequest`.
35pub(crate) const COPY_DATABASE_TIME_START_KEY: &str = "start_time";
/// Option key with the literal value `end_time`.
/// NOTE(review): not referenced in this file; presumably names the upper bound of the time range
/// of a `COPY DATABASE` statement -- confirm against the code that builds `CopyDatabaseRequest`.
36pub(crate) const COPY_DATABASE_TIME_END_KEY: &str = "end_time";
/// Key looked up in the `with` options of a `CopyDatabaseRequest` by `copy_database_from`: when
/// its value parses to `true`, a file that fails to import is logged and skipped instead of
/// aborting the whole import. A missing or unparsable value means `false`.
37pub(crate) const CONTINUE_ON_ERROR_KEY: &str = "continue_on_error";
38
39impl StatementExecutor {
40    #[tracing::instrument(skip_all)]
41    pub(crate) async fn copy_database_to(
42        &self,
43        req: CopyDatabaseRequest,
44        ctx: QueryContextRef,
45    ) -> error::Result<Output> {
46        // location must end with / so that every table is exported to a file.
47        ensure!(
48            req.location.ends_with('/'),
49            InvalidCopyDatabasePathSnafu {
50                value: req.location,
51            }
52        );
53
54        info!(
55            "Copy database {}.{} to dir: {}, time: {:?}",
56            req.catalog_name, req.schema_name, req.location, req.time_range
57        );
58        let table_names = self
59            .catalog_manager
60            .table_names(&req.catalog_name, &req.schema_name, Some(&ctx))
61            .await
62            .context(CatalogSnafu)?;
63
64        let suffix = Format::try_from(&req.with)
65            .context(error::ParseFileFormatSnafu)?
66            .suffix();
67
68        let mut exported_rows = 0;
69        for table_name in table_names {
70            let table = self
71                .get_table(&TableReference {
72                    catalog: &req.catalog_name,
73                    schema: &req.schema_name,
74                    table: &table_name,
75                })
76                .await?;
77            // Only base tables, ignores views and temporary tables.
78            if table.table_type() != table::metadata::TableType::Base {
79                continue;
80            }
81            // Ignores physical tables of metric engine.
82            if table.table_info().meta.engine == METRIC_ENGINE_NAME
83                && !table
84                    .table_info()
85                    .meta
86                    .options
87                    .extra_options
88                    .contains_key(LOGICAL_TABLE_METADATA_KEY)
89            {
90                continue;
91            }
92            let mut table_file = req.location.clone();
93            table_file.push_str(&table_name);
94            table_file.push_str(suffix);
95            info!(
96                "Copy table: {}.{}.{} to {}",
97                req.catalog_name, req.schema_name, table_name, table_file
98            );
99
100            let exported = self
101                .copy_table_to(
102                    CopyTableRequest {
103                        catalog_name: req.catalog_name.clone(),
104                        schema_name: req.schema_name.clone(),
105                        table_name,
106                        location: table_file,
107                        with: req.with.clone(),
108                        connection: req.connection.clone(),
109                        pattern: None,
110                        direction: CopyDirection::Export,
111                        timestamp_range: req.time_range,
112                        limit: None,
113                    },
114                    ctx.clone(),
115                )
116                .await?;
117            exported_rows += exported;
118        }
119        Ok(Output::new_with_affected_rows(exported_rows))
120    }
121
122    /// Imports data to database from a given location and returns total rows imported.
123    #[tracing::instrument(skip_all)]
124    pub(crate) async fn copy_database_from(
125        &self,
126        req: CopyDatabaseRequest,
127        ctx: QueryContextRef,
128    ) -> error::Result<Output> {
129        // location must end with /
130        ensure!(
131            req.location.ends_with('/'),
132            InvalidCopyDatabasePathSnafu {
133                value: req.location,
134            }
135        );
136
137        info!(
138            "Copy database {}.{} from dir: {}, time: {:?}",
139            req.catalog_name, req.schema_name, req.location, req.time_range
140        );
141        let suffix = Format::try_from(&req.with)
142            .context(error::ParseFileFormatSnafu)?
143            .suffix();
144
145        let entries = list_files_to_copy(&req, suffix).await?;
146
147        let continue_on_error = req
148            .with
149            .get(CONTINUE_ON_ERROR_KEY)
150            .and_then(|v| bool::from_str(v).ok())
151            .unwrap_or(false);
152
153        let mut rows_inserted = 0;
154        let mut insert_cost = 0;
155
156        for e in entries {
157            let table_name = match parse_file_name_to_copy(&e) {
158                Ok(table_name) => table_name,
159                Err(err) => {
160                    if continue_on_error {
161                        error!(err; "Failed to import table from file: {:?}", e);
162                        continue;
163                    } else {
164                        return Err(err);
165                    }
166                }
167            };
168            let req = CopyTableRequest {
169                catalog_name: req.catalog_name.clone(),
170                schema_name: req.schema_name.clone(),
171                table_name: table_name.clone(),
172                location: format!("{}/{}", req.location, e.path()),
173                with: req.with.clone(),
174                connection: req.connection.clone(),
175                pattern: None,
176                direction: CopyDirection::Import,
177                timestamp_range: None,
178                limit: None,
179            };
180            debug!("Copy table, arg: {:?}", req);
181            match self.copy_table_from(req, ctx.clone()).await {
182                Ok(o) => {
183                    let (rows, cost) = o.extract_rows_and_cost();
184                    rows_inserted += rows;
185                    insert_cost += cost;
186                }
187                Err(err) => {
188                    if continue_on_error {
189                        error!(err; "Failed to import file to table: {}", table_name);
190                        continue;
191                    } else {
192                        return Err(err);
193                    }
194                }
195            }
196        }
197        Ok(Output::new(
198            OutputData::AffectedRows(rows_inserted),
199            OutputMeta::new_with_cost(insert_cost),
200        ))
201    }
202}
203
204/// Parses table names from files' names.
205fn parse_file_name_to_copy(e: &Entry) -> error::Result<String> {
206    Path::new(e.name())
207        .file_stem()
208        .and_then(|os_str| os_str.to_str())
209        .map(|s| s.to_string())
210        .context(error::InvalidTableNameSnafu {
211            table_name: e.name().to_string(),
212        })
213}
214
215/// Lists all files with expected suffix that can be imported to database.
216async fn list_files_to_copy(req: &CopyDatabaseRequest, suffix: &str) -> error::Result<Vec<Entry>> {
217    let object_store =
218        build_backend(&req.location, &req.connection).context(error::BuildBackendSnafu)?;
219
220    let pattern = Regex::try_from(format!(".*{}", suffix)).context(error::BuildRegexSnafu)?;
221    let lister = Lister::new(
222        object_store.clone(),
223        Source::Dir,
224        "/".to_string(),
225        Some(pattern),
226    );
227    lister.list().await.context(error::ListObjectsSnafu)
228}
229
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use object_store::services::Fs;
    use object_store::util::normalize_dir;
    use object_store::ObjectStore;
    use path_slash::PathExt;
    use table::requests::CopyDatabaseRequest;

    use crate::statement::copy_database::{list_files_to_copy, parse_file_name_to_copy};

    /// Writes a mix of files into a temporary directory and checks that only the `.parquet`
    /// files are listed and that their table names are the file stems.
    #[tokio::test]
    async fn test_list_files_and_parse_table_name() {
        let dir = common_test_util::temp_dir::create_temp_dir("test_list_files_to_copy");

        // Populate the directory through an fs-backed object store rooted at it.
        let root = normalize_dir(dir.path().to_str().unwrap());
        let object_store = ObjectStore::new(Fs::default().root(&root))
            .unwrap()
            .finish();
        for file in ["a.parquet", "b.parquet", "c.csv", "d", "e.f.parquet"] {
            object_store.write(file, "").await.unwrap();
        }

        let request = CopyDatabaseRequest {
            catalog_name: "catalog_0".to_string(),
            schema_name: "schema_0".to_string(),
            location: normalize_dir(&dir.path().to_slash().unwrap()),
            with: [("FORMAT".to_string(), "parquet".to_string())]
                .into_iter()
                .collect(),
            connection: Default::default(),
            time_range: None,
        };

        let entries = list_files_to_copy(&request, ".parquet").await.unwrap();
        let mut listed = HashSet::new();
        for entry in &entries {
            listed.insert(parse_file_name_to_copy(entry).unwrap());
        }

        let expected: HashSet<String> = ["a", "b", "e.f"].into_iter().map(String::from).collect();
        assert_eq!(expected, listed);
    }
}
279}