1use std::time::Duration;
16
17use base64::Engine;
18use base64::engine::general_purpose;
19use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
20use common_error::ext::BoxedError;
21use humantime::format_duration;
22use serde_json::Value;
23use servers::http::GreptimeQueryOutput;
24use servers::http::header::constants::GREPTIME_DB_HEADER_TIMEOUT;
25use servers::http::result::greptime_result_v1::GreptimedbV1Response;
26use snafu::ResultExt;
27
28use crate::error::{
29 BuildClientSnafu, HttpQuerySqlSnafu, ParseProxyOptsSnafu, Result, SerdeJsonSnafu,
30};
31
32#[derive(Debug, Clone)]
33pub struct DatabaseClient {
34 addr: String,
35 catalog: String,
36 auth_header: Option<String>,
37 timeout: Duration,
38 proxy: Option<reqwest::Proxy>,
39 no_proxy: bool,
40}
41
42pub fn parse_proxy_opts(
43 proxy: Option<String>,
44 no_proxy: bool,
45) -> std::result::Result<Option<reqwest::Proxy>, BoxedError> {
46 if no_proxy {
47 return Ok(None);
48 }
49 proxy
50 .map(|proxy| {
51 reqwest::Proxy::all(proxy)
52 .context(ParseProxyOptsSnafu)
53 .map_err(BoxedError::new)
54 })
55 .transpose()
56}
57
58impl DatabaseClient {
59 pub fn new(
60 addr: String,
61 catalog: String,
62 auth_basic: Option<String>,
63 timeout: Duration,
64 proxy: Option<reqwest::Proxy>,
65 no_proxy: bool,
66 ) -> Self {
67 let auth_header = if let Some(basic) = auth_basic {
68 let encoded = general_purpose::STANDARD.encode(basic);
69 Some(format!("basic {}", encoded))
70 } else {
71 None
72 };
73
74 if no_proxy {
75 common_telemetry::info!("Proxy disabled");
76 } else if let Some(ref proxy) = proxy {
77 common_telemetry::info!("Using proxy: {:?}", proxy);
78 } else {
79 common_telemetry::info!("Using system proxy(if any)");
80 }
81
82 Self {
83 addr,
84 catalog,
85 auth_header,
86 timeout,
87 proxy,
88 no_proxy,
89 }
90 }
91
92 pub async fn sql_in_public(&self, sql: &str) -> Result<Option<Vec<Vec<Value>>>> {
93 self.sql(sql, DEFAULT_SCHEMA_NAME).await
94 }
95
96 pub async fn sql(&self, sql: &str, schema: &str) -> Result<Option<Vec<Vec<Value>>>> {
98 let url = format!("http://{}/v1/sql", self.addr);
99 let params = [
100 ("db", format!("{}-{}", self.catalog, schema)),
101 ("sql", sql.to_string()),
102 ];
103 let mut builder = reqwest::Client::builder();
104 if let Some(proxy) = self.proxy.clone() {
105 builder = builder.proxy(proxy);
106 }
107 if self.no_proxy {
108 builder = builder.no_proxy();
109 }
110 let client = builder.build().context(BuildClientSnafu)?;
111 let mut request = client
112 .post(&url)
113 .form(¶ms)
114 .header("Content-Type", "application/x-www-form-urlencoded");
115 if let Some(ref auth) = self.auth_header {
116 request = request.header("Authorization", auth);
117 }
118
119 request = request.header(
120 GREPTIME_DB_HEADER_TIMEOUT,
121 format_duration(self.timeout).to_string(),
122 );
123
124 let response = request.send().await.with_context(|_| HttpQuerySqlSnafu {
125 reason: format!("bad url: {}", url),
126 })?;
127 let response = response
128 .error_for_status()
129 .with_context(|_| HttpQuerySqlSnafu {
130 reason: format!("query failed: {}", sql),
131 })?;
132
133 let text = response.text().await.with_context(|_| HttpQuerySqlSnafu {
134 reason: "cannot get response text".to_string(),
135 })?;
136
137 let body = serde_json::from_str::<GreptimedbV1Response>(&text).context(SerdeJsonSnafu)?;
138 Ok(body.output().first().and_then(|output| match output {
139 GreptimeQueryOutput::Records(records) => Some(records.rows().clone()),
140 GreptimeQueryOutput::AffectedRows(_) => None,
141 }))
142 }
143}
144
145pub(crate) fn split_database(database: &str) -> Result<(String, Option<String>)> {
147 let (catalog, schema) = match database.split_once('-') {
148 Some((catalog, schema)) => (catalog, schema),
149 None => (DEFAULT_CATALOG_NAME, database),
150 };
151
152 if schema == "*" {
153 Ok((catalog.to_string(), None))
154 } else {
155 Ok((catalog.to_string(), Some(schema.to_string())))
156 }
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    /// Table-driven check of `split_database`: explicit catalog, implicit
    /// default catalog, and the `*` wildcard schema in both forms.
    #[test]
    fn test_split_database() {
        let cases = [
            ("catalog-schema", ("catalog", Some("schema"))),
            ("schema", ("greptime", Some("schema"))),
            ("catalog-*", ("catalog", None)),
            ("*", ("greptime", None)),
        ];

        for (input, (catalog, schema)) in cases {
            let expected = (catalog.to_string(), schema.map(str::to_string));
            assert_eq!(split_database(input).unwrap(), expected, "input: {input}");
        }
    }
}