Skip to main content

cli/
database.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::time::Duration;
16
17use base64::Engine;
18use base64::engine::general_purpose;
19use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
20use common_error::ext::BoxedError;
21use humantime::format_duration;
22use serde_json::Value;
23use servers::http::GreptimeQueryOutput;
24use servers::http::header::constants::GREPTIME_DB_HEADER_TIMEOUT;
25use servers::http::result::greptime_result_v1::GreptimedbV1Response;
26use snafu::ResultExt;
27
28use crate::error::{
29    BuildClientSnafu, HttpQuerySqlSnafu, ParseProxyOptsSnafu, Result, SerdeJsonSnafu,
30};
31
32#[derive(Debug, Clone)]
33pub struct DatabaseClient {
34    addr: String,
35    catalog: String,
36    auth_header: Option<String>,
37    timeout: Duration,
38    proxy: Option<reqwest::Proxy>,
39    no_proxy: bool,
40}
41
42pub fn parse_proxy_opts(
43    proxy: Option<String>,
44    no_proxy: bool,
45) -> std::result::Result<Option<reqwest::Proxy>, BoxedError> {
46    if no_proxy {
47        return Ok(None);
48    }
49    proxy
50        .map(|proxy| {
51            reqwest::Proxy::all(proxy)
52                .context(ParseProxyOptsSnafu)
53                .map_err(BoxedError::new)
54        })
55        .transpose()
56}
57
58impl DatabaseClient {
59    pub fn new(
60        addr: String,
61        catalog: String,
62        auth_basic: Option<String>,
63        timeout: Duration,
64        proxy: Option<reqwest::Proxy>,
65        no_proxy: bool,
66    ) -> Self {
67        let auth_header = if let Some(basic) = auth_basic {
68            let encoded = general_purpose::STANDARD.encode(basic);
69            Some(format!("basic {}", encoded))
70        } else {
71            None
72        };
73
74        if no_proxy {
75            common_telemetry::info!("Proxy disabled");
76        } else if let Some(ref proxy) = proxy {
77            common_telemetry::info!("Using proxy: {:?}", proxy);
78        } else {
79            common_telemetry::info!("Using system proxy(if any)");
80        }
81
82        Self {
83            addr,
84            catalog,
85            auth_header,
86            timeout,
87            proxy,
88            no_proxy,
89        }
90    }
91
92    pub async fn sql_in_public(&self, sql: &str) -> Result<Option<Vec<Vec<Value>>>> {
93        self.sql(sql, DEFAULT_SCHEMA_NAME).await
94    }
95
96    /// Execute sql query.
97    pub async fn sql(&self, sql: &str, schema: &str) -> Result<Option<Vec<Vec<Value>>>> {
98        let url = format!("http://{}/v1/sql", self.addr);
99        let params = [
100            ("db", format!("{}-{}", self.catalog, schema)),
101            ("sql", sql.to_string()),
102        ];
103        let mut builder = reqwest::Client::builder();
104        if let Some(proxy) = self.proxy.clone() {
105            builder = builder.proxy(proxy);
106        }
107        if self.no_proxy {
108            builder = builder.no_proxy();
109        }
110        let client = builder.build().context(BuildClientSnafu)?;
111        let mut request = client
112            .post(&url)
113            .form(&params)
114            .header("Content-Type", "application/x-www-form-urlencoded");
115        if let Some(ref auth) = self.auth_header {
116            request = request.header("Authorization", auth);
117        }
118
119        request = request.header(
120            GREPTIME_DB_HEADER_TIMEOUT,
121            format_duration(self.timeout).to_string(),
122        );
123
124        let response = request.send().await.with_context(|_| HttpQuerySqlSnafu {
125            reason: format!("bad url: {}", url),
126        })?;
127        let response = response
128            .error_for_status()
129            .with_context(|_| HttpQuerySqlSnafu {
130                reason: format!("query failed: {}", sql),
131            })?;
132
133        let text = response.text().await.with_context(|_| HttpQuerySqlSnafu {
134            reason: "cannot get response text".to_string(),
135        })?;
136
137        let body = serde_json::from_str::<GreptimedbV1Response>(&text).context(SerdeJsonSnafu)?;
138        Ok(body.output().first().and_then(|output| match output {
139            GreptimeQueryOutput::Records(records) => Some(records.rows().clone()),
140            GreptimeQueryOutput::AffectedRows(_) => None,
141        }))
142    }
143}
144
145/// Split at `-`.
146pub(crate) fn split_database(database: &str) -> Result<(String, Option<String>)> {
147    let (catalog, schema) = match database.split_once('-') {
148        Some((catalog, schema)) => (catalog, schema),
149        None => (DEFAULT_CATALOG_NAME, database),
150    };
151
152    if schema == "*" {
153        Ok((catalog.to_string(), None))
154    } else {
155        Ok((catalog.to_string(), Some(schema.to_string())))
156    }
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    /// Table-driven check of `split_database`: explicit catalog, default catalog,
    /// and the `*` wildcard schema with and without a catalog.
    #[test]
    fn test_split_database() {
        let cases = [
            ("catalog-schema", ("catalog", Some("schema"))),
            ("schema", ("greptime", Some("schema"))),
            ("catalog-*", ("catalog", None)),
            ("*", ("greptime", None)),
        ];

        for (input, (catalog, schema)) in cases {
            let expected = (catalog.to_string(), schema.map(str::to_string));
            assert_eq!(split_database(input).unwrap(), expected, "input: {input}");
        }
    }
}