cli/data/
import.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::path::PathBuf;
16use std::sync::Arc;
17use std::time::Duration;
18
19use async_trait::async_trait;
20use clap::{Parser, ValueEnum};
21use common_catalog::consts::DEFAULT_SCHEMA_NAME;
22use common_error::ext::BoxedError;
23use common_telemetry::{error, info, warn};
24use snafu::{OptionExt, ResultExt, ensure};
25use tokio::sync::Semaphore;
26use tokio::time::Instant;
27
28use crate::data::{COPY_PATH_PLACEHOLDER, default_database};
29use crate::database::{DatabaseClient, parse_proxy_opts};
30use crate::error::{Error, FileIoSnafu, InvalidArgumentsSnafu, Result, SchemaNotFoundSnafu};
31use crate::{Tool, database};
32
33#[derive(Debug, Default, Clone, ValueEnum)]
34enum ImportTarget {
35    /// Import all table schemas into the database.
36    Schema,
37    /// Import all table data into the database.
38    Data,
39    /// Export all table schemas and data at once.
40    #[default]
41    All,
42}
43
44/// Command to import data from a directory into a GreptimeDB instance.
45#[derive(Debug, Default, Parser)]
46pub struct ImportCommand {
47    /// Server address to connect
48    #[clap(long)]
49    addr: String,
50
51    /// Directory of the data. E.g.: /tmp/greptimedb-backup
52    #[clap(long)]
53    input_dir: String,
54
55    /// The name of the catalog to import.
56    #[clap(long, default_value_t = default_database())]
57    database: String,
58
59    /// The number of databases imported in parallel.
60    /// For example, if there are 20 databases and `db_parallelism` is 4,
61    /// 4 databases will be imported concurrently.
62    #[clap(long, short = 'j', default_value = "1", alias = "import-jobs")]
63    db_parallelism: usize,
64
65    /// Max retry times for each job.
66    #[clap(long, default_value = "3")]
67    max_retry: usize,
68
69    /// Things to export
70    #[clap(long, short = 't', value_enum, default_value = "all")]
71    target: ImportTarget,
72
73    /// The basic authentication for connecting to the server
74    #[clap(long)]
75    auth_basic: Option<String>,
76
77    /// The timeout of invoking the database.
78    ///
79    /// It is used to override the server-side timeout setting.
80    /// The default behavior will disable server-side default timeout(i.e. `0s`).
81    #[clap(long, value_parser = humantime::parse_duration)]
82    timeout: Option<Duration>,
83
84    /// The proxy server address to connect, if set, will override the system proxy.
85    ///
86    /// The default behavior will use the system proxy if neither `proxy` nor `no_proxy` is set.
87    #[clap(long)]
88    proxy: Option<String>,
89
90    /// Disable proxy server, if set, will not use any proxy.
91    #[clap(long, default_value = "false")]
92    no_proxy: bool,
93}
94
95impl ImportCommand {
96    pub async fn build(&self) -> std::result::Result<Box<dyn Tool>, BoxedError> {
97        let (catalog, schema) =
98            database::split_database(&self.database).map_err(BoxedError::new)?;
99        let proxy = parse_proxy_opts(self.proxy.clone(), self.no_proxy)?;
100        let database_client = DatabaseClient::new(
101            self.addr.clone(),
102            catalog.clone(),
103            self.auth_basic.clone(),
104            // Treats `None` as `0s` to disable server-side default timeout.
105            self.timeout.unwrap_or_default(),
106            proxy,
107        );
108
109        Ok(Box::new(Import {
110            catalog,
111            schema,
112            database_client,
113            input_dir: self.input_dir.clone(),
114            parallelism: self.db_parallelism,
115            target: self.target.clone(),
116        }))
117    }
118}
119
/// The import tool produced by `ImportCommand::build`.
pub struct Import {
    /// Catalog name; also the sub-directory of `input_dir` that is scanned.
    catalog: String,
    /// When set, only the database directory matching this name
    /// (case-insensitively) is imported; otherwise every directory is.
    schema: Option<String>,
    /// Client used to execute the SQL scripts.
    database_client: DatabaseClient,
    /// Root directory of the exported data.
    input_dir: String,
    /// Number of permits of the semaphore that gates the per-database jobs.
    parallelism: usize,
    /// Which parts of the backup to import.
    target: ImportTarget,
}
128
129impl Import {
130    async fn import_create_table(&self) -> Result<()> {
131        // Use default db to creates other dbs
132        self.do_sql_job("create_database.sql", Some(DEFAULT_SCHEMA_NAME))
133            .await?;
134        self.do_sql_job("create_tables.sql", None).await
135    }
136
137    async fn import_database_data(&self) -> Result<()> {
138        self.do_sql_job("copy_from.sql", None).await
139    }
140
141    async fn do_sql_job(&self, filename: &str, exec_db: Option<&str>) -> Result<()> {
142        let timer = Instant::now();
143        let semaphore = Arc::new(Semaphore::new(self.parallelism));
144        let db_names = self.get_db_names().await?;
145        let db_count = db_names.len();
146        let mut tasks = Vec::with_capacity(db_count);
147        for schema in db_names {
148            let semaphore_moved = semaphore.clone();
149            tasks.push(async move {
150                let _permit = semaphore_moved.acquire().await.unwrap();
151                let database_input_dir = self.catalog_path().join(&schema);
152                let sql_file = database_input_dir.join(filename);
153                let mut sql = tokio::fs::read_to_string(sql_file)
154                    .await
155                    .context(FileIoSnafu)?;
156                if sql.trim().is_empty() {
157                    info!("Empty `{filename}` {database_input_dir:?}");
158                } else {
159                    if filename == "copy_from.sql" {
160                        sql = self.rewrite_copy_database_sql(&schema, &sql)?;
161                    }
162                    let db = exec_db.unwrap_or(&schema);
163                    self.database_client.sql(&sql, db).await?;
164                    info!("Imported `{filename}` for database {schema}");
165                }
166
167                Ok::<(), Error>(())
168            })
169        }
170
171        let success = futures::future::join_all(tasks)
172            .await
173            .into_iter()
174            .filter(|r| match r {
175                Ok(_) => true,
176                Err(e) => {
177                    error!(e; "import {filename} job failed");
178                    false
179                }
180            })
181            .count();
182        let elapsed = timer.elapsed();
183        info!("Success {success}/{db_count} `{filename}` jobs, cost: {elapsed:?}");
184
185        Ok(())
186    }
187
188    fn catalog_path(&self) -> PathBuf {
189        PathBuf::from(&self.input_dir).join(&self.catalog)
190    }
191
192    async fn get_db_names(&self) -> Result<Vec<String>> {
193        let db_names = self.all_db_names().await?;
194        let Some(schema) = &self.schema else {
195            return Ok(db_names);
196        };
197
198        // Check if the schema exists
199        db_names
200            .into_iter()
201            .find(|db_name| db_name.to_lowercase() == schema.to_lowercase())
202            .map(|name| vec![name])
203            .context(SchemaNotFoundSnafu {
204                catalog: &self.catalog,
205                schema,
206            })
207    }
208
209    // Get all database names in the input directory.
210    // The directory structure should be like:
211    // /tmp/greptimedb-backup
212    // ├── greptime-1
213    // │   ├── db1
214    // │   └── db2
215    async fn all_db_names(&self) -> Result<Vec<String>> {
216        let mut db_names = vec![];
217        let path = self.catalog_path();
218        let mut entries = tokio::fs::read_dir(path).await.context(FileIoSnafu)?;
219        while let Some(entry) = entries.next_entry().await.context(FileIoSnafu)? {
220            let path = entry.path();
221            if path.is_dir() {
222                let db_name = match path.file_name() {
223                    Some(name) => name.to_string_lossy().to_string(),
224                    None => {
225                        warn!("Failed to get the file name of {:?}", path);
226                        continue;
227                    }
228                };
229                db_names.push(db_name);
230            }
231        }
232        Ok(db_names)
233    }
234
235    fn rewrite_copy_database_sql(&self, schema: &str, sql: &str) -> Result<String> {
236        let target_location = self.build_copy_database_location(schema);
237        let escaped_location = target_location.replace('\'', "''");
238
239        let mut first_stmt_checked = false;
240        for line in sql.lines() {
241            let trimmed = line.trim_start();
242            if trimmed.is_empty() || trimmed.starts_with("--") {
243                continue;
244            }
245
246            ensure!(
247                trimmed.starts_with("COPY DATABASE"),
248                InvalidArgumentsSnafu {
249                    msg: "Expected COPY DATABASE statement at start of copy_from.sql"
250                }
251            );
252            first_stmt_checked = true;
253            break;
254        }
255
256        ensure!(
257            first_stmt_checked,
258            InvalidArgumentsSnafu {
259                msg: "COPY DATABASE statement not found in copy_from.sql"
260            }
261        );
262
263        ensure!(
264            sql.contains(COPY_PATH_PLACEHOLDER),
265            InvalidArgumentsSnafu {
266                msg: format!(
267                    "Placeholder `{}` not found in COPY DATABASE statement",
268                    COPY_PATH_PLACEHOLDER
269                )
270            }
271        );
272
273        Ok(sql.replacen(COPY_PATH_PLACEHOLDER, &escaped_location, 1))
274    }
275
276    fn build_copy_database_location(&self, schema: &str) -> String {
277        let mut path = self.catalog_path();
278        path.push(schema);
279        let mut path_str = path.to_string_lossy().into_owned();
280        if !path_str.ends_with('/') {
281            path_str.push('/');
282        }
283        path_str
284    }
285}
286
287#[async_trait]
288impl Tool for Import {
289    async fn do_work(&self) -> std::result::Result<(), BoxedError> {
290        match self.target {
291            ImportTarget::Schema => self.import_create_table().await.map_err(BoxedError::new),
292            ImportTarget::Data => self.import_database_data().await.map_err(BoxedError::new),
293            ImportTarget::All => {
294                self.import_create_table().await.map_err(BoxedError::new)?;
295                self.import_database_data().await.map_err(BoxedError::new)
296            }
297        }
298    }
299}
300
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;

    /// Builds an [`Import`] rooted at `input_dir` for the catalog `catalog`.
    /// The methods exercised by the tests below do not use the database client.
    fn build_import(input_dir: &str) -> Import {
        let database_client = DatabaseClient::new(
            String::from("127.0.0.1:4000"),
            String::from("catalog"),
            None,
            Duration::from_secs(0),
            None,
        );

        Import {
            catalog: String::from("catalog"),
            schema: None,
            database_client,
            input_dir: input_dir.to_owned(),
            parallelism: 1,
            target: ImportTarget::Data,
        }
    }

    /// The placeholder in the statement is replaced by the escaped local
    /// location, while the leading comment line is left as it was.
    #[test]
    fn rewrite_copy_database_sql_replaces_placeholder() {
        let import = build_import("/tmp/export-path");
        let comment = "-- COPY DATABASE \"catalog\".\"schema\" FROM 's3://bucket/demo/' WITH (format = 'parquet') CONNECTION (region = 'us-west-2')";
        let sql = format!(
            "{comment}\nCOPY DATABASE \"catalog\".\"schema\" FROM '{}' WITH (format = 'parquet');",
            COPY_PATH_PLACEHOLDER
        );

        let location = import.build_copy_database_location("schema");
        let expected_clause = format!("FROM '{}'", location.replace('\'', "''"));

        let rewritten = import
            .rewrite_copy_database_sql("schema", &sql)
            .unwrap();

        assert!(rewritten.starts_with(comment));
        assert!(rewritten.contains(&expected_clause));
        assert!(!rewritten.contains(COPY_PATH_PLACEHOLDER));
    }

    /// A script without the placeholder is rejected.
    #[test]
    fn rewrite_copy_database_sql_requires_placeholder() {
        let import = build_import("/tmp/export-path");
        let sql = "COPY DATABASE \"catalog\".\"schema\" FROM '/tmp/export-path/catalog/schema/' WITH (format = 'parquet');";

        let result = import.rewrite_copy_database_sql("schema", sql);
        assert!(result.is_err());
    }
}