cli/import.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;

use async_trait::async_trait;
use clap::{Parser, ValueEnum};
use common_catalog::consts::DEFAULT_SCHEMA_NAME;
use common_error::ext::BoxedError;
use common_telemetry::{error, info, warn};
use snafu::{OptionExt, ResultExt};
use tokio::sync::Semaphore;
use tokio::time::Instant;

use crate::database::{parse_proxy_opts, DatabaseClient};
use crate::error::{Error, FileIoSnafu, Result, SchemaNotFoundSnafu};
use crate::{database, Tool};

#[derive(Debug, Default, Clone, ValueEnum)]
enum ImportTarget {
    /// Import all table schemas into the database.
    Schema,
    /// Import all table data into the database.
    Data,
    /// Import all table schemas and data at once.
    #[default]
    All,
}

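/// Command-line arguments for importing table schemas and data from a backup
/// directory (e.g. one produced by the corresponding export command).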
#[derive(Debug, Default, Parser)]
pub struct ImportCommand {
    /// Server address to connect to.
    #[clap(long)]
    addr: String,

    /// Directory of the data. E.g.: /tmp/greptimedb-backup
    #[clap(long)]
    input_dir: String,

    /// The name of the catalog to import.
    #[clap(long, default_value = "greptime-*")]
    database: String,

    /// Parallelism of the import.
    #[clap(long, short = 'j', default_value = "1")]
    import_jobs: usize,

    /// The maximum number of retries for each job.
    #[clap(long, default_value = "3")]
    max_retry: usize,

    /// Things to import.
    #[clap(long, short = 't', value_enum, default_value = "all")]
    target: ImportTarget,

    /// The basic authentication credentials for connecting to the server.
    #[clap(long)]
    auth_basic: Option<String>,

    /// The timeout for invoking the database.
    ///
    /// It is used to override the server-side timeout setting.
    /// By default, the server-side default timeout is disabled (i.e. `0s`).
    #[clap(long, value_parser = humantime::parse_duration)]
    timeout: Option<Duration>,

    /// The proxy server address to connect to. If set, it overrides the system proxy.
    ///
    /// By default, the system proxy is used if neither `proxy` nor `no_proxy` is set.
    #[clap(long)]
    proxy: Option<String>,

    /// Disable the proxy server. If set, no proxy will be used at all.
    #[clap(long, default_value = "false")]
    no_proxy: bool,
}

impl ImportCommand {
    pub async fn build(&self) -> std::result::Result<Box<dyn Tool>, BoxedError> {
        let (catalog, schema) =
            database::split_database(&self.database).map_err(BoxedError::new)?;
        let proxy = parse_proxy_opts(self.proxy.clone(), self.no_proxy)?;
        let database_client = DatabaseClient::new(
            self.addr.clone(),
            catalog.clone(),
            self.auth_basic.clone(),
            // Treats `None` as `0s` to disable server-side default timeout.
            self.timeout.unwrap_or_default(),
            proxy,
        );

        Ok(Box::new(Import {
            catalog,
            schema,
            database_client,
            input_dir: self.input_dir.clone(),
            parallelism: self.import_jobs,
            target: self.target.clone(),
        }))
    }
}

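/// Imports table schemas and data by executing the SQL files found under
/// `<input_dir>/<catalog>/<schema>/` against the target server.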
pub struct Import {
    catalog: String,
    schema: Option<String>,
    database_client: DatabaseClient,
    input_dir: String,
    parallelism: usize,
    target: ImportTarget,
}

impl Import {
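    /// Recreates the databases and tables by running the exported
    /// `create_database.sql` and `create_tables.sql` files.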
    async fn import_create_table(&self) -> Result<()> {
        // Use the default database to create the other databases.
        self.do_sql_job("create_database.sql", Some(DEFAULT_SCHEMA_NAME))
            .await?;
        self.do_sql_job("create_tables.sql", None).await
    }

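    /// Restores table data by running the exported `copy_from.sql` files.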
    async fn import_database_data(&self) -> Result<()> {
        self.do_sql_job("copy_from.sql", None).await
    }

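    /// Executes the SQL in `filename` for every database, running at most
    /// `parallelism` jobs concurrently. When `exec_db` is set, the SQL is
    /// executed against that database instead of the one being imported.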
    async fn do_sql_job(&self, filename: &str, exec_db: Option<&str>) -> Result<()> {
        let timer = Instant::now();
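        // Bound the number of concurrently running import jobs by `parallelism`.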
        let semaphore = Arc::new(Semaphore::new(self.parallelism));
        let db_names = self.get_db_names().await?;
        let db_count = db_names.len();
        let mut tasks = Vec::with_capacity(db_count);
        for schema in db_names {
            let semaphore_moved = semaphore.clone();
            tasks.push(async move {
                let _permit = semaphore_moved.acquire().await.unwrap();
                let database_input_dir = self.catalog_path().join(&schema);
                let sql_file = database_input_dir.join(filename);
                let sql = tokio::fs::read_to_string(sql_file)
                    .await
                    .context(FileIoSnafu)?;
                if sql.is_empty() {
                    info!("Empty `{filename}` {database_input_dir:?}");
                } else {
                    let db = exec_db.unwrap_or(&schema);
                    self.database_client.sql(&sql, db).await?;
                    info!("Imported `{filename}` for database {schema}");
                }

                Ok::<(), Error>(())
            })
        }

        let success = futures::future::join_all(tasks)
            .await
            .into_iter()
            .filter(|r| match r {
                Ok(_) => true,
                Err(e) => {
                    error!(e; "import {filename} job failed");
                    false
                }
            })
            .count();
        let elapsed = timer.elapsed();
        info!("Success {success}/{db_count} `{filename}` jobs, cost: {elapsed:?}");

        Ok(())
    }

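    /// Returns the directory containing the exported files for the catalog,
    /// i.e. `<input_dir>/<catalog>`.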
    fn catalog_path(&self) -> PathBuf {
        PathBuf::from(&self.input_dir).join(&self.catalog)
    }

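    /// Returns the databases to import: every database found under the catalog
    /// directory, or only the requested schema when one was specified.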
    async fn get_db_names(&self) -> Result<Vec<String>> {
        let db_names = self.all_db_names().await?;
        let Some(schema) = &self.schema else {
            return Ok(db_names);
        };

        // Check if the schema exists
        db_names
            .into_iter()
            .find(|db_name| db_name.to_lowercase() == schema.to_lowercase())
            .map(|name| vec![name])
            .context(SchemaNotFoundSnafu {
                catalog: &self.catalog,
                schema,
            })
    }

    // Get all database names in the input directory.
    // The directory structure should be like:
    // /tmp/greptimedb-backup
    // ├── greptime-1
    // │   ├── db1
    // │   └── db2
    async fn all_db_names(&self) -> Result<Vec<String>> {
        let mut db_names = vec![];
        let path = self.catalog_path();
        let mut entries = tokio::fs::read_dir(path).await.context(FileIoSnafu)?;
        while let Some(entry) = entries.next_entry().await.context(FileIoSnafu)? {
            let path = entry.path();
            if path.is_dir() {
                let db_name = match path.file_name() {
                    Some(name) => name.to_string_lossy().to_string(),
                    None => {
                        warn!("Failed to get the file name of {:?}", path);
                        continue;
                    }
                };
                db_names.push(db_name);
            }
        }
        Ok(db_names)
    }
}

#[async_trait]
impl Tool for Import {
    async fn do_work(&self) -> std::result::Result<(), BoxedError> {
        match self.target {
            ImportTarget::Schema => self.import_create_table().await.map_err(BoxedError::new),
            ImportTarget::Data => self.import_database_data().await.map_err(BoxedError::new),
            ImportTarget::All => {
                self.import_create_table().await.map_err(BoxedError::new)?;
                self.import_database_data().await.map_err(BoxedError::new)
            }
        }
    }
}