client/
database.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::pin::Pin;
16use std::str::FromStr;
17use std::sync::Arc;
18
19use api::v1::auth_header::AuthScheme;
20use api::v1::ddl_request::Expr as DdlExpr;
21use api::v1::greptime_database_client::GreptimeDatabaseClient;
22use api::v1::greptime_request::Request;
23use api::v1::query_request::Query;
24use api::v1::{
25    AlterTableExpr, AuthHeader, Basic, CreateTableExpr, DdlRequest, GreptimeRequest,
26    InsertRequests, QueryRequest, RequestHeader, RowInsertRequests,
27};
28use arrow_flight::{FlightData, Ticket};
29use async_stream::stream;
30use base64::Engine;
31use base64::prelude::BASE64_STANDARD;
32use common_catalog::build_db_string;
33use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
34use common_error::ext::BoxedError;
35use common_grpc::flight::do_put::DoPutResponse;
36use common_grpc::flight::{FlightDecoder, FlightMessage};
37use common_query::Output;
38use common_recordbatch::error::ExternalSnafu;
39use common_recordbatch::{RecordBatch, RecordBatchStreamWrapper};
40use common_telemetry::tracing::Span;
41use common_telemetry::tracing_context::W3cTrace;
42use common_telemetry::{error, warn};
43use futures::future;
44use futures_util::{Stream, StreamExt, TryStreamExt};
45use prost::Message;
46use snafu::{OptionExt, ResultExt, ensure};
47use tonic::metadata::{AsciiMetadataKey, AsciiMetadataValue, MetadataMap, MetadataValue};
48use tonic::transport::Channel;
49
50use crate::error::{
51    ConvertFlightDataSnafu, Error, FlightGetSnafu, IllegalFlightMessagesSnafu,
52    InvalidTonicMetadataValueSnafu,
53};
54use crate::{Client, Result, error, from_grpc_response};
55
/// Boxed, pinned, `Send` stream of [`FlightData`] items; the input accepted by
/// `Database::do_put`.
type FlightDataStream = Pin<Box<dyn Stream<Item = FlightData> + Send>>;

/// Boxed, pinned stream of `Result<DoPutResponse>` items; the value returned by
/// `Database::do_put`. Note that this alias carries no `Send` bound.
type DoPutResponseStream = Pin<Box<dyn Stream<Item = Result<DoPutResponse>>>>;
59
/// Handle for issuing gRPC and Arrow Flight requests against a database through a [`Client`].
///
/// Besides the client itself it stores the request context (catalog/schema or dbname,
/// timezone and auth header) that is sent along with the requests.
#[derive(Clone, Debug, Default)]
pub struct Database {
    // The "catalog" and "schema" to be used in processing the requests at the server side.
    // They are the "hint" or "context", just like how the "database" in "USE" statement is treated in MySQL.
    // They will be carried in the request header.
    catalog: String,
    schema: String,
    // The dbname follows the same naming rule as our MySQL, PostgreSQL and HTTP
    // protocols. The server treats dbname in priority of catalog/schema.
    dbname: String,
    // The time zone indicates the time zone where the user is located.
    // Some queries need to be aware of the user's time zone to perform some specific actions.
    timezone: String,

    // The underlying client used to obtain gRPC channels and Flight clients.
    client: Client,
    // Extra per-request context; currently holds the optional auth header.
    ctx: FlightContext,
}
77
/// A generated `GreptimeDatabaseClient` paired with the address of the peer it talks to.
pub struct DatabaseClient {
    /// Peer address obtained together with the channel; included in error logs.
    pub addr: String,
    /// The gRPC client bound to the channel of that peer.
    pub inner: GreptimeDatabaseClient<Channel>,
}
82
83impl DatabaseClient {
84    /// Returns a closure that logs the error when the request fails.
85    pub fn inspect_err<'a>(&'a self, context: &'a str) -> impl Fn(&tonic::Status) + 'a {
86        let addr = &self.addr;
87        move |status| {
88            error!("Failed to {context} request, peer: {addr}, status: {status:?}");
89        }
90    }
91}
92
93fn make_database_client(client: &Client) -> Result<DatabaseClient> {
94    let (addr, channel) = client.find_channel()?;
95    Ok(DatabaseClient {
96        addr,
97        inner: GreptimeDatabaseClient::new(channel)
98            .max_decoding_message_size(client.max_grpc_recv_message_size())
99            .max_encoding_message_size(client.max_grpc_send_message_size()),
100    })
101}
102
103impl Database {
104    /// Create database service client using catalog and schema
105    pub fn new(catalog: impl Into<String>, schema: impl Into<String>, client: Client) -> Self {
106        Self {
107            catalog: catalog.into(),
108            schema: schema.into(),
109            dbname: String::default(),
110            timezone: String::default(),
111            client,
112            ctx: FlightContext::default(),
113        }
114    }
115
116    /// Create database service client using dbname.
117    ///
118    /// This API is designed for external usage. `dbname` is:
119    ///
120    /// - the name of database when using GreptimeDB standalone or cluster
121    /// - the name provided by GreptimeCloud or other multi-tenant GreptimeDB
122    ///   environment
123    pub fn new_with_dbname(dbname: impl Into<String>, client: Client) -> Self {
124        Self {
125            catalog: String::default(),
126            schema: String::default(),
127            timezone: String::default(),
128            dbname: dbname.into(),
129            client,
130            ctx: FlightContext::default(),
131        }
132    }
133
134    /// Set the catalog for the database client.
135    pub fn set_catalog(&mut self, catalog: impl Into<String>) {
136        self.catalog = catalog.into();
137    }
138
139    fn catalog_or_default(&self) -> &str {
140        if self.catalog.is_empty() {
141            DEFAULT_CATALOG_NAME
142        } else {
143            &self.catalog
144        }
145    }
146
147    /// Set the schema for the database client.
148    pub fn set_schema(&mut self, schema: impl Into<String>) {
149        self.schema = schema.into();
150    }
151
152    fn schema_or_default(&self) -> &str {
153        if self.schema.is_empty() {
154            DEFAULT_SCHEMA_NAME
155        } else {
156            &self.schema
157        }
158    }
159
160    /// Set the timezone for the database client.
161    pub fn set_timezone(&mut self, timezone: impl Into<String>) {
162        self.timezone = timezone.into();
163    }
164
165    /// Set the auth scheme for the database client.
166    pub fn set_auth(&mut self, auth: AuthScheme) {
167        self.ctx.auth_header = Some(AuthHeader {
168            auth_scheme: Some(auth),
169        });
170    }
171
172    /// Make an InsertRequests request to the database.
173    pub async fn insert(&self, requests: InsertRequests) -> Result<u32> {
174        self.handle(Request::Inserts(requests)).await
175    }
176
177    /// Make an InsertRequests request to the database with hints.
178    pub async fn insert_with_hints(
179        &self,
180        requests: InsertRequests,
181        hints: &[(&str, &str)],
182    ) -> Result<u32> {
183        let mut client = make_database_client(&self.client)?;
184        let request = self.to_rpc_request(Request::Inserts(requests));
185
186        let mut request = tonic::Request::new(request);
187        let metadata = request.metadata_mut();
188        Self::put_hints(metadata, hints)?;
189
190        let response = client
191            .inner
192            .handle(request)
193            .await
194            .inspect_err(client.inspect_err("insert_with_hints"))?
195            .into_inner();
196        from_grpc_response(response)
197    }
198
199    /// Make a RowInsertRequests request to the database.
200    pub async fn row_inserts(&self, requests: RowInsertRequests) -> Result<u32> {
201        self.handle(Request::RowInserts(requests)).await
202    }
203
204    /// Make a RowInsertRequests request to the database with hints.
205    pub async fn row_inserts_with_hints(
206        &self,
207        requests: RowInsertRequests,
208        hints: &[(&str, &str)],
209    ) -> Result<u32> {
210        let mut client = make_database_client(&self.client)?;
211        let request = self.to_rpc_request(Request::RowInserts(requests));
212
213        let mut request = tonic::Request::new(request);
214        let metadata = request.metadata_mut();
215        Self::put_hints(metadata, hints)?;
216
217        let response = client
218            .inner
219            .handle(request)
220            .await
221            .inspect_err(client.inspect_err("row_inserts_with_hints"))?
222            .into_inner();
223        from_grpc_response(response)
224    }
225
226    fn put_hints(metadata: &mut MetadataMap, hints: &[(&str, &str)]) -> Result<()> {
227        let Some(value) = hints
228            .iter()
229            .map(|(k, v)| format!("{}={}", k, v))
230            .reduce(|a, b| format!("{},{}", a, b))
231        else {
232            return Ok(());
233        };
234
235        let key = AsciiMetadataKey::from_static("x-greptime-hints");
236        let value = AsciiMetadataValue::from_str(&value).context(InvalidTonicMetadataValueSnafu)?;
237        metadata.insert(key, value);
238        Ok(())
239    }
240
241    /// Make a request to the database.
242    pub async fn handle(&self, request: Request) -> Result<u32> {
243        let mut client = make_database_client(&self.client)?;
244        let request = self.to_rpc_request(request);
245        let response = client
246            .inner
247            .handle(request)
248            .await
249            .inspect_err(client.inspect_err("handle"))?
250            .into_inner();
251        from_grpc_response(response)
252    }
253
254    /// Retry if connection fails, max_retries is the max number of retries, so the total wait time
255    /// is `max_retries * GRPC_CONN_TIMEOUT`
256    pub async fn handle_with_retry(
257        &self,
258        request: Request,
259        max_retries: u32,
260        hints: &[(&str, &str)],
261    ) -> Result<u32> {
262        let mut client = make_database_client(&self.client)?;
263        let mut retries = 0;
264
265        let request = self.to_rpc_request(request);
266
267        loop {
268            let mut tonic_request = tonic::Request::new(request.clone());
269            let metadata = tonic_request.metadata_mut();
270            Self::put_hints(metadata, hints)?;
271            let raw_response = client
272                .inner
273                .handle(tonic_request)
274                .await
275                .inspect_err(client.inspect_err("handle"));
276            match (raw_response, retries < max_retries) {
277                (Ok(resp), _) => return from_grpc_response(resp.into_inner()),
278                (Err(err), true) => {
279                    // determine if the error is retryable
280                    if is_grpc_retryable(&err) {
281                        // retry
282                        retries += 1;
283                        warn!("Retrying {} times with error = {:?}", retries, err);
284                        continue;
285                    } else {
286                        error!(
287                            err; "Failed to send request to grpc handle, retries = {}, not retryable error, aborting",
288                            retries
289                        );
290                        return Err(err.into());
291                    }
292                }
293                (Err(err), false) => {
294                    error!(
295                        err; "Failed to send request to grpc handle after {} retries",
296                        retries,
297                    );
298                    return Err(err.into());
299                }
300            }
301        }
302    }
303
304    #[inline]
305    fn to_rpc_request(&self, request: Request) -> GreptimeRequest {
306        GreptimeRequest {
307            header: Some(RequestHeader {
308                catalog: self.catalog.clone(),
309                schema: self.schema.clone(),
310                authorization: self.ctx.auth_header.clone(),
311                dbname: self.dbname.clone(),
312                timezone: self.timezone.clone(),
313                // TODO(Taylor-lagrange): add client grpc tracing
314                tracing_context: W3cTrace::new(),
315            }),
316            request: Some(request),
317        }
318    }
319
320    /// Executes a SQL query without any hints.
321    pub async fn sql<S>(&self, sql: S) -> Result<Output>
322    where
323        S: AsRef<str>,
324    {
325        self.sql_with_hint(sql, &[]).await
326    }
327
328    /// Executes a SQL query with optional hints for query optimization.
329    pub async fn sql_with_hint<S>(&self, sql: S, hints: &[(&str, &str)]) -> Result<Output>
330    where
331        S: AsRef<str>,
332    {
333        let request = Request::Query(QueryRequest {
334            query: Some(Query::Sql(sql.as_ref().to_string())),
335        });
336        self.do_get(request, hints).await
337    }
338
339    /// Executes a logical plan directly without SQL parsing.
340    pub async fn logical_plan(&self, logical_plan: Vec<u8>) -> Result<Output> {
341        let request = Request::Query(QueryRequest {
342            query: Some(Query::LogicalPlan(logical_plan)),
343        });
344        self.do_get(request, &[]).await
345    }
346
347    /// Creates a new table using the provided table expression.
348    pub async fn create(&self, expr: CreateTableExpr) -> Result<Output> {
349        let request = Request::Ddl(DdlRequest {
350            expr: Some(DdlExpr::CreateTable(expr)),
351        });
352        self.do_get(request, &[]).await
353    }
354
355    /// Alters an existing table using the provided alter expression.
356    pub async fn alter(&self, expr: AlterTableExpr) -> Result<Output> {
357        let request = Request::Ddl(DdlRequest {
358            expr: Some(DdlExpr::AlterTable(expr)),
359        });
360        self.do_get(request, &[]).await
361    }
362
    /// Sends `request` through Arrow Flight's `DoGet` call and converts the reply into an
    /// [`Output`].
    ///
    /// The request is wrapped with this client's header, encoded into a [`Ticket`], and sent
    /// with `hints` attached as request metadata. The first decoded Flight message decides
    /// the shape of the result:
    ///
    /// - `AffectedRows`: must be the one and only message; yields an affected-rows output.
    /// - `Schema`: yields a record batch stream fed by the messages that follow.
    /// - `RecordBatch` / `Metrics`: rejected as an illegal first message.
    ///
    /// # Errors
    ///
    /// Fails if the hints cannot be encoded as metadata, the flight client cannot be created,
    /// the `DoGet` call fails, or the response is empty or violates the ordering above.
    async fn do_get(&self, request: Request, hints: &[(&str, &str)]) -> Result<Output> {
        // The whole `GreptimeRequest` travels as the opaque bytes of the Flight ticket.
        let request = self.to_rpc_request(request);
        let request = Ticket {
            ticket: request.encode_to_vec().into(),
        };

        let mut request = tonic::Request::new(request);
        Self::put_hints(request.metadata_mut(), hints)?;

        // NOTE(review): the meaning of the two `false` flags is defined by
        // `Client::make_flight_client`, which is not visible here — confirm there.
        let mut client = self.client.make_flight_client(false, false)?;

        // On failure, log the peer address and status code, then wrap the error with the
        // same details so callers can inspect them.
        let response = client.mut_inner().do_get(request).await.or_else(|e| {
            let tonic_code = e.code();
            let e: Error = e.into();
            error!(
                "Failed to do Flight get, addr: {}, code: {}, source: {:?}",
                client.addr(),
                tonic_code,
                e
            );
            Err(BoxedError::new(e)).with_context(|_| FlightGetSnafu {
                addr: client.addr().to_string(),
                tonic_code,
            })
        })?;

        let flight_data_stream = response.into_inner();
        let mut decoder = FlightDecoder::default();

        // Decode every `FlightData` item into a `FlightMessage`. The decoder is moved into
        // the closure and reused across items; a decode that yields no message is reported
        // as an illegal-message error.
        let mut flight_message_stream = flight_data_stream.map(move |flight_data| {
            flight_data
                .map_err(Error::from)
                .and_then(|data| decoder.try_decode(&data).context(ConvertFlightDataSnafu))?
                .context(IllegalFlightMessagesSnafu {
                    reason: "none message",
                })
        });

        // An empty response stream is an error: at least one message is required.
        let Some(first_flight_message) = flight_message_stream.next().await else {
            return IllegalFlightMessagesSnafu {
                reason: "Expect the response not to be empty",
            }
            .fail();
        };

        let first_flight_message = first_flight_message?;

        match first_flight_message {
            FlightMessage::AffectedRows(rows) => {
                // Nothing may follow an `AffectedRows` message.
                ensure!(
                    flight_message_stream.next().await.is_none(),
                    IllegalFlightMessagesSnafu {
                        reason: "Expect 'AffectedRows' Flight messages to be the one and the only!"
                    }
                );
                Ok(Output::new_with_affected_rows(rows))
            }
            FlightMessage::RecordBatch(_) | FlightMessage::Metrics(_) => {
                IllegalFlightMessagesSnafu {
                    reason: "The first flight message cannot be a RecordBatch or Metrics message",
                }
                .fail()
            }
            FlightMessage::Schema(schema) => {
                let schema = Arc::new(
                    datatypes::schema::Schema::try_from(schema)
                        .context(error::ConvertSchemaSnafu)?,
                );
                let schema_cloned = schema.clone();
                // Stream over the remaining messages: record batches are yielded, metrics
                // messages are skipped, and any other message yields an error and ends the
                // stream.
                let stream = Box::pin(stream!({
                    while let Some(flight_message) = flight_message_stream.next().await {
                        let flight_message = flight_message
                            .map_err(BoxedError::new)
                            .context(ExternalSnafu)?;
                        match flight_message {
                            FlightMessage::RecordBatch(arrow_batch) => {
                                yield Ok(RecordBatch::from_df_record_batch(
                                    schema_cloned.clone(),
                                    arrow_batch,
                                ))
                            }
                            FlightMessage::Metrics(_) => {}
                            FlightMessage::AffectedRows(_) | FlightMessage::Schema(_) => {
                                yield IllegalFlightMessagesSnafu {reason: format!("A Schema message must be succeeded exclusively by a set of RecordBatch messages, flight_message: {:?}", flight_message)}
                                        .fail()
                                        .map_err(BoxedError::new)
                                        .context(ExternalSnafu);
                                break;
                            }
                        }
                    }
                }));
                let record_batch_stream = RecordBatchStreamWrapper {
                    schema,
                    stream,
                    output_ordering: None,
                    metrics: Default::default(),
                    span: Span::current(),
                };
                Ok(Output::new_with_stream(Box::pin(record_batch_stream)))
            }
        }
    }
466
467    /// Ingest a stream of [RecordBatch]es that belong to a table, using Arrow Flight's "`DoPut`"
468    /// method. The return value is also a stream, produces [DoPutResponse]s.
469    pub async fn do_put(&self, stream: FlightDataStream) -> Result<DoPutResponseStream> {
470        let mut request = tonic::Request::new(stream);
471
472        if let Some(AuthHeader {
473            auth_scheme: Some(AuthScheme::Basic(Basic { username, password })),
474        }) = &self.ctx.auth_header
475        {
476            let encoded = BASE64_STANDARD.encode(format!("{username}:{password}"));
477            let value = MetadataValue::from_str(&format!("Basic {encoded}"))
478                .context(InvalidTonicMetadataValueSnafu)?;
479            request.metadata_mut().insert("x-greptime-auth", value);
480        }
481
482        let db_to_put = if !self.dbname.is_empty() {
483            &self.dbname
484        } else {
485            &build_db_string(self.catalog_or_default(), self.schema_or_default())
486        };
487        request.metadata_mut().insert(
488            "x-greptime-db-name",
489            MetadataValue::from_str(db_to_put).context(InvalidTonicMetadataValueSnafu)?,
490        );
491
492        let mut client = self.client.make_flight_client(false, false)?;
493        let response = client.mut_inner().do_put(request).await?;
494        let response = response
495            .into_inner()
496            .map_err(Into::into)
497            .and_then(|x| future::ready(DoPutResponse::try_from(x).context(ConvertFlightDataSnafu)))
498            .boxed();
499        Ok(response)
500    }
501}
502
503/// by grpc standard, only `Unavailable` is retryable, see: https://github.com/grpc/grpc/blob/master/doc/statuscodes.md#status-codes-and-their-use-in-grpc
504pub fn is_grpc_retryable(err: &tonic::Status) -> bool {
505    matches!(err.code(), tonic::Code::Unavailable)
506}
507
/// Extra context attached to outgoing requests; currently only the optional auth header.
#[derive(Default, Debug, Clone)]
struct FlightContext {
    // Set through `Database::set_auth`; `None` (the default) means no auth is sent.
    auth_header: Option<AuthHeader>,
}
512
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use api::v1::auth_header::AuthScheme;
    use api::v1::{AuthHeader, Basic};
    use common_error::status_code::StatusCode;
    use tonic::{Code, Status};

    use super::*;
    use crate::error::TonicSnafu;

    /// A default context has no auth header; one assigned afterwards is kept as given.
    #[test]
    fn test_flight_ctx() {
        let mut ctx = FlightContext::default();
        assert!(ctx.auth_header.is_none());

        ctx.auth_header = Some(AuthHeader {
            auth_scheme: Some(AuthScheme::Basic(Basic {
                username: "u".to_string(),
                password: "p".to_string(),
            })),
        });

        assert_matches!(
            ctx.auth_header,
            Some(AuthHeader {
                auth_scheme: Some(AuthScheme::Basic(_)),
            })
        )
    }

    /// Converting a `tonic::Status` into `Error` renders like the matching `Tonic` error.
    #[test]
    fn test_from_tonic_status() {
        let actual = Error::from(Status::new(Code::Internal, "blabla"));

        let expected = TonicSnafu {
            code: StatusCode::Internal,
            msg: "blabla".to_string(),
            tonic_code: Code::Internal,
        }
        .build();

        assert_eq!(expected.to_string(), actual.to_string());
    }
}