cli/
database.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::time::Duration;
16
17use base64::engine::general_purpose;
18use base64::Engine;
19use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
20use common_error::ext::BoxedError;
21use humantime::format_duration;
22use serde_json::Value;
23use servers::http::header::constants::GREPTIME_DB_HEADER_TIMEOUT;
24use servers::http::result::greptime_result_v1::GreptimedbV1Response;
25use servers::http::GreptimeQueryOutput;
26use snafu::ResultExt;
27
28use crate::error::{
29    BuildClientSnafu, HttpQuerySqlSnafu, ParseProxyOptsSnafu, Result, SerdeJsonSnafu,
30};
31
32#[derive(Debug, Clone)]
33pub struct DatabaseClient {
34    addr: String,
35    catalog: String,
36    auth_header: Option<String>,
37    timeout: Duration,
38    proxy: Option<reqwest::Proxy>,
39}
40
41pub fn parse_proxy_opts(
42    proxy: Option<String>,
43    no_proxy: bool,
44) -> std::result::Result<Option<reqwest::Proxy>, BoxedError> {
45    if no_proxy {
46        return Ok(None);
47    }
48    proxy
49        .map(|proxy| {
50            reqwest::Proxy::all(proxy)
51                .context(ParseProxyOptsSnafu)
52                .map_err(BoxedError::new)
53        })
54        .transpose()
55}
56
57impl DatabaseClient {
58    pub fn new(
59        addr: String,
60        catalog: String,
61        auth_basic: Option<String>,
62        timeout: Duration,
63        proxy: Option<reqwest::Proxy>,
64    ) -> Self {
65        let auth_header = if let Some(basic) = auth_basic {
66            let encoded = general_purpose::STANDARD.encode(basic);
67            Some(format!("basic {}", encoded))
68        } else {
69            None
70        };
71
72        if let Some(ref proxy) = proxy {
73            common_telemetry::info!("Using proxy: {:?}", proxy);
74        } else {
75            common_telemetry::info!("Using system proxy(if any)");
76        }
77
78        Self {
79            addr,
80            catalog,
81            auth_header,
82            timeout,
83            proxy,
84        }
85    }
86
87    pub async fn sql_in_public(&self, sql: &str) -> Result<Option<Vec<Vec<Value>>>> {
88        self.sql(sql, DEFAULT_SCHEMA_NAME).await
89    }
90
91    /// Execute sql query.
92    pub async fn sql(&self, sql: &str, schema: &str) -> Result<Option<Vec<Vec<Value>>>> {
93        let url = format!("http://{}/v1/sql", self.addr);
94        let params = [
95            ("db", format!("{}-{}", self.catalog, schema)),
96            ("sql", sql.to_string()),
97        ];
98        let client = self
99            .proxy
100            .clone()
101            .map(|proxy| reqwest::Client::builder().proxy(proxy).build())
102            .unwrap_or_else(|| Ok(reqwest::Client::new()))
103            .context(BuildClientSnafu)?;
104        let mut request = client
105            .post(&url)
106            .form(&params)
107            .header("Content-Type", "application/x-www-form-urlencoded");
108        if let Some(ref auth) = self.auth_header {
109            request = request.header("Authorization", auth);
110        }
111
112        request = request.header(
113            GREPTIME_DB_HEADER_TIMEOUT,
114            format_duration(self.timeout).to_string(),
115        );
116
117        let response = request.send().await.with_context(|_| HttpQuerySqlSnafu {
118            reason: format!("bad url: {}", url),
119        })?;
120        let response = response
121            .error_for_status()
122            .with_context(|_| HttpQuerySqlSnafu {
123                reason: format!("query failed: {}", sql),
124            })?;
125
126        let text = response.text().await.with_context(|_| HttpQuerySqlSnafu {
127            reason: "cannot get response text".to_string(),
128        })?;
129
130        let body = serde_json::from_str::<GreptimedbV1Response>(&text).context(SerdeJsonSnafu)?;
131        Ok(body.output().first().and_then(|output| match output {
132            GreptimeQueryOutput::Records(records) => Some(records.rows().clone()),
133            GreptimeQueryOutput::AffectedRows(_) => None,
134        }))
135    }
136}
137
138/// Split at `-`.
139pub(crate) fn split_database(database: &str) -> Result<(String, Option<String>)> {
140    let (catalog, schema) = match database.split_once('-') {
141        Some((catalog, schema)) => (catalog, schema),
142        None => (DEFAULT_CATALOG_NAME, database),
143    };
144
145    if schema == "*" {
146        Ok((catalog.to_string(), None))
147    } else {
148        Ok((catalog.to_string(), Some(schema.to_string())))
149    }
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    /// Covers the four input shapes: explicit catalog and schema, schema only,
    /// and the `*` schema with and without a catalog.
    #[test]
    fn test_split_database() {
        // (input, expected catalog, expected schema)
        let cases = [
            ("catalog-schema", "catalog", Some("schema")),
            ("schema", "greptime", Some("schema")),
            ("catalog-*", "catalog", None),
            ("*", "greptime", None),
        ];

        for (input, catalog, schema) in cases {
            let expected = (catalog.to_string(), schema.map(str::to_string));
            assert_eq!(split_database(input).unwrap(), expected);
        }
    }
}
170}