cli/data/import.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::path::PathBuf;
16use std::sync::Arc;
17use std::time::Duration;
18
19use async_trait::async_trait;
20use clap::{Parser, ValueEnum};
21use common_catalog::consts::DEFAULT_SCHEMA_NAME;
22use common_error::ext::BoxedError;
23use common_telemetry::{error, info, warn};
24use snafu::{OptionExt, ResultExt};
25use tokio::sync::Semaphore;
26use tokio::time::Instant;
27
28use crate::database::{parse_proxy_opts, DatabaseClient};
29use crate::error::{Error, FileIoSnafu, Result, SchemaNotFoundSnafu};
30use crate::{database, Tool};
31
32#[derive(Debug, Default, Clone, ValueEnum)]
33enum ImportTarget {
34    /// Import all table schemas into the database.
35    Schema,
36    /// Import all table data into the database.
37    Data,
38    /// Export all table schemas and data at once.
39    #[default]
40    All,
41}
42
43/// Command to import data from a directory into a GreptimeDB instance.
44#[derive(Debug, Default, Parser)]
45pub struct ImportCommand {
46    /// Server address to connect
47    #[clap(long)]
48    addr: String,
49
50    /// Directory of the data. E.g.: /tmp/greptimedb-backup
51    #[clap(long)]
52    input_dir: String,
53
54    /// The name of the catalog to import.
55    #[clap(long, default_value = "greptime-*")]
56    database: String,
57
58    /// Parallelism of the import.
59    #[clap(long, short = 'j', default_value = "1")]
60    import_jobs: usize,
61
62    /// Max retry times for each job.
63    #[clap(long, default_value = "3")]
64    max_retry: usize,
65
66    /// Things to export
67    #[clap(long, short = 't', value_enum, default_value = "all")]
68    target: ImportTarget,
69
70    /// The basic authentication for connecting to the server
71    #[clap(long)]
72    auth_basic: Option<String>,
73
74    /// The timeout of invoking the database.
75    ///
76    /// It is used to override the server-side timeout setting.
77    /// The default behavior will disable server-side default timeout(i.e. `0s`).
78    #[clap(long, value_parser = humantime::parse_duration)]
79    timeout: Option<Duration>,
80
81    /// The proxy server address to connect, if set, will override the system proxy.
82    ///
83    /// The default behavior will use the system proxy if neither `proxy` nor `no_proxy` is set.
84    #[clap(long)]
85    proxy: Option<String>,
86
87    /// Disable proxy server, if set, will not use any proxy.
88    #[clap(long, default_value = "false")]
89    no_proxy: bool,
90}
91
92impl ImportCommand {
93    pub async fn build(&self) -> std::result::Result<Box<dyn Tool>, BoxedError> {
94        let (catalog, schema) =
95            database::split_database(&self.database).map_err(BoxedError::new)?;
96        let proxy = parse_proxy_opts(self.proxy.clone(), self.no_proxy)?;
97        let database_client = DatabaseClient::new(
98            self.addr.clone(),
99            catalog.clone(),
100            self.auth_basic.clone(),
101            // Treats `None` as `0s` to disable server-side default timeout.
102            self.timeout.unwrap_or_default(),
103            proxy,
104        );
105
106        Ok(Box::new(Import {
107            catalog,
108            schema,
109            database_client,
110            input_dir: self.input_dir.clone(),
111            parallelism: self.import_jobs,
112            target: self.target.clone(),
113        }))
114    }
115}
116
117pub struct Import {
118    catalog: String,
119    schema: Option<String>,
120    database_client: DatabaseClient,
121    input_dir: String,
122    parallelism: usize,
123    target: ImportTarget,
124}
125
126impl Import {
127    async fn import_create_table(&self) -> Result<()> {
128        // Use default db to creates other dbs
129        self.do_sql_job("create_database.sql", Some(DEFAULT_SCHEMA_NAME))
130            .await?;
131        self.do_sql_job("create_tables.sql", None).await
132    }
133
134    async fn import_database_data(&self) -> Result<()> {
135        self.do_sql_job("copy_from.sql", None).await
136    }
137
138    async fn do_sql_job(&self, filename: &str, exec_db: Option<&str>) -> Result<()> {
139        let timer = Instant::now();
140        let semaphore = Arc::new(Semaphore::new(self.parallelism));
141        let db_names = self.get_db_names().await?;
142        let db_count = db_names.len();
143        let mut tasks = Vec::with_capacity(db_count);
144        for schema in db_names {
145            let semaphore_moved = semaphore.clone();
146            tasks.push(async move {
147                let _permit = semaphore_moved.acquire().await.unwrap();
148                let database_input_dir = self.catalog_path().join(&schema);
149                let sql_file = database_input_dir.join(filename);
150                let sql = tokio::fs::read_to_string(sql_file)
151                    .await
152                    .context(FileIoSnafu)?;
153                if sql.is_empty() {
154                    info!("Empty `{filename}` {database_input_dir:?}");
155                } else {
156                    let db = exec_db.unwrap_or(&schema);
157                    self.database_client.sql(&sql, db).await?;
158                    info!("Imported `{filename}` for database {schema}");
159                }
160
161                Ok::<(), Error>(())
162            })
163        }
164
165        let success = futures::future::join_all(tasks)
166            .await
167            .into_iter()
168            .filter(|r| match r {
169                Ok(_) => true,
170                Err(e) => {
171                    error!(e; "import {filename} job failed");
172                    false
173                }
174            })
175            .count();
176        let elapsed = timer.elapsed();
177        info!("Success {success}/{db_count} `{filename}` jobs, cost: {elapsed:?}");
178
179        Ok(())
180    }
181
182    fn catalog_path(&self) -> PathBuf {
183        PathBuf::from(&self.input_dir).join(&self.catalog)
184    }
185
186    async fn get_db_names(&self) -> Result<Vec<String>> {
187        let db_names = self.all_db_names().await?;
188        let Some(schema) = &self.schema else {
189            return Ok(db_names);
190        };
191
192        // Check if the schema exists
193        db_names
194            .into_iter()
195            .find(|db_name| db_name.to_lowercase() == schema.to_lowercase())
196            .map(|name| vec![name])
197            .context(SchemaNotFoundSnafu {
198                catalog: &self.catalog,
199                schema,
200            })
201    }
202
203    // Get all database names in the input directory.
204    // The directory structure should be like:
205    // /tmp/greptimedb-backup
206    // ├── greptime-1
207    // │   ├── db1
208    // │   └── db2
209    async fn all_db_names(&self) -> Result<Vec<String>> {
210        let mut db_names = vec![];
211        let path = self.catalog_path();
212        let mut entries = tokio::fs::read_dir(path).await.context(FileIoSnafu)?;
213        while let Some(entry) = entries.next_entry().await.context(FileIoSnafu)? {
214            let path = entry.path();
215            if path.is_dir() {
216                let db_name = match path.file_name() {
217                    Some(name) => name.to_string_lossy().to_string(),
218                    None => {
219                        warn!("Failed to get the file name of {:?}", path);
220                        continue;
221                    }
222                };
223                db_names.push(db_name);
224            }
225        }
226        Ok(db_names)
227    }
228}
229
230#[async_trait]
231impl Tool for Import {
232    async fn do_work(&self) -> std::result::Result<(), BoxedError> {
233        match self.target {
234            ImportTarget::Schema => self.import_create_table().await.map_err(BoxedError::new),
235            ImportTarget::Data => self.import_database_data().await.map_err(BoxedError::new),
236            ImportTarget::All => {
237                self.import_create_table().await.map_err(BoxedError::new)?;
238                self.import_database_data().await.map_err(BoxedError::new)
239            }
240        }
241    }
242}