client/
database.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::pin::Pin;
16use std::str::FromStr;
17
18use api::v1::auth_header::AuthScheme;
19use api::v1::ddl_request::Expr as DdlExpr;
20use api::v1::greptime_database_client::GreptimeDatabaseClient;
21use api::v1::greptime_request::Request;
22use api::v1::query_request::Query;
23use api::v1::{
24    AlterTableExpr, AuthHeader, Basic, CreateTableExpr, DdlRequest, GreptimeRequest,
25    InsertRequests, QueryRequest, RequestHeader,
26};
27use arrow_flight::{FlightData, Ticket};
28use async_stream::stream;
29use base64::prelude::BASE64_STANDARD;
30use base64::Engine;
31use common_catalog::build_db_string;
32use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
33use common_error::ext::{BoxedError, ErrorExt};
34use common_grpc::flight::do_put::DoPutResponse;
35use common_grpc::flight::{FlightDecoder, FlightMessage};
36use common_query::Output;
37use common_recordbatch::error::ExternalSnafu;
38use common_recordbatch::RecordBatchStreamWrapper;
39use common_telemetry::tracing_context::W3cTrace;
40use common_telemetry::{error, warn};
41use futures::future;
42use futures_util::{Stream, StreamExt, TryStreamExt};
43use prost::Message;
44use snafu::{ensure, ResultExt};
45use tonic::metadata::{AsciiMetadataKey, AsciiMetadataValue, MetadataMap, MetadataValue};
46use tonic::transport::Channel;
47
48use crate::error::{
49    ConvertFlightDataSnafu, Error, FlightGetSnafu, IllegalFlightMessagesSnafu,
50    InvalidTonicMetadataValueSnafu, ServerSnafu,
51};
52use crate::{from_grpc_response, Client, Result};
53
54type FlightDataStream = Pin<Box<dyn Stream<Item = FlightData> + Send>>;
55
56type DoPutResponseStream = Pin<Box<dyn Stream<Item = Result<DoPutResponse>>>>;
57
58#[derive(Clone, Debug, Default)]
59pub struct Database {
60    // The "catalog" and "schema" to be used in processing the requests at the server side.
61    // They are the "hint" or "context", just like how the "database" in "USE" statement is treated in MySQL.
62    // They will be carried in the request header.
63    catalog: String,
64    schema: String,
65    // The dbname follows naming rule as out mysql, postgres and http
66    // protocol. The server treat dbname in priority of catalog/schema.
67    dbname: String,
68    // The time zone indicates the time zone where the user is located.
69    // Some queries need to be aware of the user's time zone to perform some specific actions.
70    timezone: String,
71
72    client: Client,
73    ctx: FlightContext,
74}
75
76pub struct DatabaseClient {
77    pub inner: GreptimeDatabaseClient<Channel>,
78}
79
80fn make_database_client(client: &Client) -> Result<DatabaseClient> {
81    let (_, channel) = client.find_channel()?;
82    Ok(DatabaseClient {
83        inner: GreptimeDatabaseClient::new(channel)
84            .max_decoding_message_size(client.max_grpc_recv_message_size())
85            .max_encoding_message_size(client.max_grpc_send_message_size()),
86    })
87}
88
89impl Database {
90    /// Create database service client using catalog and schema
91    pub fn new(catalog: impl Into<String>, schema: impl Into<String>, client: Client) -> Self {
92        Self {
93            catalog: catalog.into(),
94            schema: schema.into(),
95            dbname: String::default(),
96            timezone: String::default(),
97            client,
98            ctx: FlightContext::default(),
99        }
100    }
101
102    /// Create database service client using dbname.
103    ///
104    /// This API is designed for external usage. `dbname` is:
105    ///
106    /// - the name of database when using GreptimeDB standalone or cluster
107    /// - the name provided by GreptimeCloud or other multi-tenant GreptimeDB
108    ///   environment
109    pub fn new_with_dbname(dbname: impl Into<String>, client: Client) -> Self {
110        Self {
111            catalog: String::default(),
112            schema: String::default(),
113            timezone: String::default(),
114            dbname: dbname.into(),
115            client,
116            ctx: FlightContext::default(),
117        }
118    }
119
120    pub fn set_catalog(&mut self, catalog: impl Into<String>) {
121        self.catalog = catalog.into();
122    }
123
124    fn catalog_or_default(&self) -> &str {
125        if self.catalog.is_empty() {
126            DEFAULT_CATALOG_NAME
127        } else {
128            &self.catalog
129        }
130    }
131
132    pub fn set_schema(&mut self, schema: impl Into<String>) {
133        self.schema = schema.into();
134    }
135
136    fn schema_or_default(&self) -> &str {
137        if self.schema.is_empty() {
138            DEFAULT_SCHEMA_NAME
139        } else {
140            &self.schema
141        }
142    }
143
144    pub fn set_timezone(&mut self, timezone: impl Into<String>) {
145        self.timezone = timezone.into();
146    }
147
148    pub fn set_auth(&mut self, auth: AuthScheme) {
149        self.ctx.auth_header = Some(AuthHeader {
150            auth_scheme: Some(auth),
151        });
152    }
153
154    pub async fn insert(&self, requests: InsertRequests) -> Result<u32> {
155        self.handle(Request::Inserts(requests)).await
156    }
157
158    pub async fn insert_with_hints(
159        &self,
160        requests: InsertRequests,
161        hints: &[(&str, &str)],
162    ) -> Result<u32> {
163        let mut client = make_database_client(&self.client)?.inner;
164        let request = self.to_rpc_request(Request::Inserts(requests));
165
166        let mut request = tonic::Request::new(request);
167        let metadata = request.metadata_mut();
168        Self::put_hints(metadata, hints)?;
169
170        let response = client.handle(request).await?.into_inner();
171        from_grpc_response(response)
172    }
173
174    fn put_hints(metadata: &mut MetadataMap, hints: &[(&str, &str)]) -> Result<()> {
175        let Some(value) = hints
176            .iter()
177            .map(|(k, v)| format!("{}={}", k, v))
178            .reduce(|a, b| format!("{},{}", a, b))
179        else {
180            return Ok(());
181        };
182
183        let key = AsciiMetadataKey::from_static("x-greptime-hints");
184        let value = AsciiMetadataValue::from_str(&value).context(InvalidTonicMetadataValueSnafu)?;
185        metadata.insert(key, value);
186        Ok(())
187    }
188
189    pub async fn handle(&self, request: Request) -> Result<u32> {
190        let mut client = make_database_client(&self.client)?.inner;
191        let request = self.to_rpc_request(request);
192        let response = client.handle(request).await?.into_inner();
193        from_grpc_response(response)
194    }
195
196    /// Retry if connection fails, max_retries is the max number of retries, so the total wait time
197    /// is `max_retries * GRPC_CONN_TIMEOUT`
198    pub async fn handle_with_retry(&self, request: Request, max_retries: u32) -> Result<u32> {
199        let mut client = make_database_client(&self.client)?.inner;
200        let mut retries = 0;
201        let request = self.to_rpc_request(request);
202        loop {
203            let raw_response = client.handle(request.clone()).await;
204            match (raw_response, retries < max_retries) {
205                (Ok(resp), _) => return from_grpc_response(resp.into_inner()),
206                (Err(err), true) => {
207                    // determine if the error is retryable
208                    if is_grpc_retryable(&err) {
209                        // retry
210                        retries += 1;
211                        warn!("Retrying {} times with error = {:?}", retries, err);
212                        continue;
213                    }
214                }
215                (Err(err), false) => {
216                    error!(
217                        "Failed to send request to grpc handle after {} retries, error = {:?}",
218                        retries, err
219                    );
220                    return Err(err.into());
221                }
222            }
223        }
224    }
225
226    #[inline]
227    fn to_rpc_request(&self, request: Request) -> GreptimeRequest {
228        GreptimeRequest {
229            header: Some(RequestHeader {
230                catalog: self.catalog.clone(),
231                schema: self.schema.clone(),
232                authorization: self.ctx.auth_header.clone(),
233                dbname: self.dbname.clone(),
234                timezone: self.timezone.clone(),
235                // TODO(Taylor-lagrange): add client grpc tracing
236                tracing_context: W3cTrace::new(),
237            }),
238            request: Some(request),
239        }
240    }
241
242    pub async fn sql<S>(&self, sql: S) -> Result<Output>
243    where
244        S: AsRef<str>,
245    {
246        self.sql_with_hint(sql, &[]).await
247    }
248
249    pub async fn sql_with_hint<S>(&self, sql: S, hints: &[(&str, &str)]) -> Result<Output>
250    where
251        S: AsRef<str>,
252    {
253        let request = Request::Query(QueryRequest {
254            query: Some(Query::Sql(sql.as_ref().to_string())),
255        });
256        self.do_get(request, hints).await
257    }
258
259    pub async fn logical_plan(&self, logical_plan: Vec<u8>) -> Result<Output> {
260        let request = Request::Query(QueryRequest {
261            query: Some(Query::LogicalPlan(logical_plan)),
262        });
263        self.do_get(request, &[]).await
264    }
265
266    pub async fn create(&self, expr: CreateTableExpr) -> Result<Output> {
267        let request = Request::Ddl(DdlRequest {
268            expr: Some(DdlExpr::CreateTable(expr)),
269        });
270        self.do_get(request, &[]).await
271    }
272
273    pub async fn alter(&self, expr: AlterTableExpr) -> Result<Output> {
274        let request = Request::Ddl(DdlRequest {
275            expr: Some(DdlExpr::AlterTable(expr)),
276        });
277        self.do_get(request, &[]).await
278    }
279
280    async fn do_get(&self, request: Request, hints: &[(&str, &str)]) -> Result<Output> {
281        let request = self.to_rpc_request(request);
282        let request = Ticket {
283            ticket: request.encode_to_vec().into(),
284        };
285
286        let mut request = tonic::Request::new(request);
287        Self::put_hints(request.metadata_mut(), hints)?;
288
289        let mut client = self.client.make_flight_client()?;
290
291        let response = client.mut_inner().do_get(request).await.or_else(|e| {
292            let tonic_code = e.code();
293            let e: Error = e.into();
294            let code = e.status_code();
295            let msg = e.to_string();
296            let error =
297                Err(BoxedError::new(ServerSnafu { code, msg }.build())).with_context(|_| {
298                    FlightGetSnafu {
299                        addr: client.addr().to_string(),
300                        tonic_code,
301                    }
302                });
303            error!(
304                "Failed to do Flight get, addr: {}, code: {}, source: {:?}",
305                client.addr(),
306                tonic_code,
307                error
308            );
309            error
310        })?;
311
312        let flight_data_stream = response.into_inner();
313        let mut decoder = FlightDecoder::default();
314
315        let mut flight_message_stream = flight_data_stream.map(move |flight_data| {
316            flight_data
317                .map_err(Error::from)
318                .and_then(|data| decoder.try_decode(&data).context(ConvertFlightDataSnafu))
319        });
320
321        let Some(first_flight_message) = flight_message_stream.next().await else {
322            return IllegalFlightMessagesSnafu {
323                reason: "Expect the response not to be empty",
324            }
325            .fail();
326        };
327
328        let first_flight_message = first_flight_message?;
329
330        match first_flight_message {
331            FlightMessage::AffectedRows(rows) => {
332                ensure!(
333                    flight_message_stream.next().await.is_none(),
334                    IllegalFlightMessagesSnafu {
335                        reason: "Expect 'AffectedRows' Flight messages to be the one and the only!"
336                    }
337                );
338                Ok(Output::new_with_affected_rows(rows))
339            }
340            FlightMessage::Recordbatch(_) | FlightMessage::Metrics(_) => {
341                IllegalFlightMessagesSnafu {
342                    reason: "The first flight message cannot be a RecordBatch or Metrics message",
343                }
344                .fail()
345            }
346            FlightMessage::Schema(schema) => {
347                let stream = Box::pin(stream!({
348                    while let Some(flight_message) = flight_message_stream.next().await {
349                        let flight_message = flight_message
350                            .map_err(BoxedError::new)
351                            .context(ExternalSnafu)?;
352                        match flight_message {
353                            FlightMessage::Recordbatch(record_batch) => yield Ok(record_batch),
354                            FlightMessage::Metrics(_) => {}
355                            FlightMessage::AffectedRows(_) | FlightMessage::Schema(_) => {
356                                yield IllegalFlightMessagesSnafu {reason: format!("A Schema message must be succeeded exclusively by a set of RecordBatch messages, flight_message: {:?}", flight_message)}
357                                        .fail()
358                                        .map_err(BoxedError::new)
359                                        .context(ExternalSnafu);
360                                break;
361                            }
362                        }
363                    }
364                }));
365                let record_batch_stream = RecordBatchStreamWrapper {
366                    schema,
367                    stream,
368                    output_ordering: None,
369                    metrics: Default::default(),
370                };
371                Ok(Output::new_with_stream(Box::pin(record_batch_stream)))
372            }
373        }
374    }
375
376    /// Ingest a stream of [RecordBatch]es that belong to a table, using Arrow Flight's "`DoPut`"
377    /// method. The return value is also a stream, produces [DoPutResponse]s.
378    pub async fn do_put(&self, stream: FlightDataStream) -> Result<DoPutResponseStream> {
379        let mut request = tonic::Request::new(stream);
380
381        if let Some(AuthHeader {
382            auth_scheme: Some(AuthScheme::Basic(Basic { username, password })),
383        }) = &self.ctx.auth_header
384        {
385            let encoded = BASE64_STANDARD.encode(format!("{username}:{password}"));
386            let value =
387                MetadataValue::from_str(&encoded).context(InvalidTonicMetadataValueSnafu)?;
388            request.metadata_mut().insert("x-greptime-auth", value);
389        }
390
391        let db_to_put = if !self.dbname.is_empty() {
392            &self.dbname
393        } else {
394            &build_db_string(self.catalog_or_default(), self.schema_or_default())
395        };
396        request.metadata_mut().insert(
397            "x-greptime-db-name",
398            MetadataValue::from_str(db_to_put).context(InvalidTonicMetadataValueSnafu)?,
399        );
400
401        let mut client = self.client.make_flight_client()?;
402        let response = client.mut_inner().do_put(request).await?;
403        let response = response
404            .into_inner()
405            .map_err(Into::into)
406            .and_then(|x| future::ready(DoPutResponse::try_from(x).context(ConvertFlightDataSnafu)))
407            .boxed();
408        Ok(response)
409    }
410}
411
412/// by grpc standard, only `Unavailable` is retryable, see: https://github.com/grpc/grpc/blob/master/doc/statuscodes.md#status-codes-and-their-use-in-grpc
413pub fn is_grpc_retryable(err: &tonic::Status) -> bool {
414    matches!(err.code(), tonic::Code::Unavailable)
415}
416
417#[derive(Default, Debug, Clone)]
418struct FlightContext {
419    auth_header: Option<AuthHeader>,
420}
421
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use api::v1::auth_header::AuthScheme;
    use api::v1::{AuthHeader, Basic};

    use super::*;

    /// A default `FlightContext` has no auth header; after a basic-auth header is assigned, it
    /// is kept as the `Basic` scheme.
    #[test]
    fn test_flight_ctx() {
        let mut ctx = FlightContext::default();
        assert!(ctx.auth_header.is_none());

        let credentials = Basic {
            username: "u".to_string(),
            password: "p".to_string(),
        };
        ctx.auth_header = Some(AuthHeader {
            auth_scheme: Some(AuthScheme::Basic(credentials)),
        });

        assert_matches!(
            ctx.auth_header,
            Some(AuthHeader {
                auth_scheme: Some(AuthScheme::Basic(_)),
            })
        );
    }
}