client/
database.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::pin::Pin;
16use std::str::FromStr;
17use std::sync::Arc;
18
19use api::v1::auth_header::AuthScheme;
20use api::v1::ddl_request::Expr as DdlExpr;
21use api::v1::greptime_database_client::GreptimeDatabaseClient;
22use api::v1::greptime_request::Request;
23use api::v1::query_request::Query;
24use api::v1::{
25    AlterTableExpr, AuthHeader, Basic, CreateTableExpr, DdlRequest, GreptimeRequest,
26    InsertRequests, QueryRequest, RequestHeader, RowInsertRequests,
27};
28use arrow_flight::{FlightData, Ticket};
29use async_stream::stream;
30use base64::prelude::BASE64_STANDARD;
31use base64::Engine;
32use common_catalog::build_db_string;
33use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
34use common_error::ext::BoxedError;
35use common_grpc::flight::do_put::DoPutResponse;
36use common_grpc::flight::{FlightDecoder, FlightMessage};
37use common_query::Output;
38use common_recordbatch::error::ExternalSnafu;
39use common_recordbatch::{RecordBatch, RecordBatchStreamWrapper};
40use common_telemetry::tracing_context::W3cTrace;
41use common_telemetry::{error, warn};
42use futures::future;
43use futures_util::{Stream, StreamExt, TryStreamExt};
44use prost::Message;
45use snafu::{ensure, OptionExt, ResultExt};
46use tonic::metadata::{AsciiMetadataKey, AsciiMetadataValue, MetadataMap, MetadataValue};
47use tonic::transport::Channel;
48
49use crate::error::{
50    ConvertFlightDataSnafu, Error, FlightGetSnafu, IllegalFlightMessagesSnafu,
51    InvalidTonicMetadataValueSnafu,
52};
53use crate::{error, from_grpc_response, Client, Result};
54
55type FlightDataStream = Pin<Box<dyn Stream<Item = FlightData> + Send>>;
56
57type DoPutResponseStream = Pin<Box<dyn Stream<Item = Result<DoPutResponse>>>>;
58
59#[derive(Clone, Debug, Default)]
60pub struct Database {
61    // The "catalog" and "schema" to be used in processing the requests at the server side.
62    // They are the "hint" or "context", just like how the "database" in "USE" statement is treated in MySQL.
63    // They will be carried in the request header.
64    catalog: String,
65    schema: String,
66    // The dbname follows naming rule as out mysql, postgres and http
67    // protocol. The server treat dbname in priority of catalog/schema.
68    dbname: String,
69    // The time zone indicates the time zone where the user is located.
70    // Some queries need to be aware of the user's time zone to perform some specific actions.
71    timezone: String,
72
73    client: Client,
74    ctx: FlightContext,
75}
76
77pub struct DatabaseClient {
78    pub addr: String,
79    pub inner: GreptimeDatabaseClient<Channel>,
80}
81
82impl DatabaseClient {
83    /// Returns a closure that logs the error when the request fails.
84    pub fn inspect_err<'a>(&'a self, context: &'a str) -> impl Fn(&tonic::Status) + 'a {
85        let addr = &self.addr;
86        move |status| {
87            error!("Failed to {context} request, peer: {addr}, status: {status:?}");
88        }
89    }
90}
91
92fn make_database_client(client: &Client) -> Result<DatabaseClient> {
93    let (addr, channel) = client.find_channel()?;
94    Ok(DatabaseClient {
95        addr,
96        inner: GreptimeDatabaseClient::new(channel)
97            .max_decoding_message_size(client.max_grpc_recv_message_size())
98            .max_encoding_message_size(client.max_grpc_send_message_size()),
99    })
100}
101
102impl Database {
103    /// Create database service client using catalog and schema
104    pub fn new(catalog: impl Into<String>, schema: impl Into<String>, client: Client) -> Self {
105        Self {
106            catalog: catalog.into(),
107            schema: schema.into(),
108            dbname: String::default(),
109            timezone: String::default(),
110            client,
111            ctx: FlightContext::default(),
112        }
113    }
114
115    /// Create database service client using dbname.
116    ///
117    /// This API is designed for external usage. `dbname` is:
118    ///
119    /// - the name of database when using GreptimeDB standalone or cluster
120    /// - the name provided by GreptimeCloud or other multi-tenant GreptimeDB
121    ///   environment
122    pub fn new_with_dbname(dbname: impl Into<String>, client: Client) -> Self {
123        Self {
124            catalog: String::default(),
125            schema: String::default(),
126            timezone: String::default(),
127            dbname: dbname.into(),
128            client,
129            ctx: FlightContext::default(),
130        }
131    }
132
133    /// Set the catalog for the database client.
134    pub fn set_catalog(&mut self, catalog: impl Into<String>) {
135        self.catalog = catalog.into();
136    }
137
138    fn catalog_or_default(&self) -> &str {
139        if self.catalog.is_empty() {
140            DEFAULT_CATALOG_NAME
141        } else {
142            &self.catalog
143        }
144    }
145
146    /// Set the schema for the database client.
147    pub fn set_schema(&mut self, schema: impl Into<String>) {
148        self.schema = schema.into();
149    }
150
151    fn schema_or_default(&self) -> &str {
152        if self.schema.is_empty() {
153            DEFAULT_SCHEMA_NAME
154        } else {
155            &self.schema
156        }
157    }
158
159    /// Set the timezone for the database client.
160    pub fn set_timezone(&mut self, timezone: impl Into<String>) {
161        self.timezone = timezone.into();
162    }
163
164    /// Set the auth scheme for the database client.
165    pub fn set_auth(&mut self, auth: AuthScheme) {
166        self.ctx.auth_header = Some(AuthHeader {
167            auth_scheme: Some(auth),
168        });
169    }
170
171    /// Make an InsertRequests request to the database.
172    pub async fn insert(&self, requests: InsertRequests) -> Result<u32> {
173        self.handle(Request::Inserts(requests)).await
174    }
175
176    /// Make an InsertRequests request to the database with hints.
177    pub async fn insert_with_hints(
178        &self,
179        requests: InsertRequests,
180        hints: &[(&str, &str)],
181    ) -> Result<u32> {
182        let mut client = make_database_client(&self.client)?;
183        let request = self.to_rpc_request(Request::Inserts(requests));
184
185        let mut request = tonic::Request::new(request);
186        let metadata = request.metadata_mut();
187        Self::put_hints(metadata, hints)?;
188
189        let response = client
190            .inner
191            .handle(request)
192            .await
193            .inspect_err(client.inspect_err("insert_with_hints"))?
194            .into_inner();
195        from_grpc_response(response)
196    }
197
198    /// Make a RowInsertRequests request to the database.
199    pub async fn row_inserts(&self, requests: RowInsertRequests) -> Result<u32> {
200        self.handle(Request::RowInserts(requests)).await
201    }
202
203    /// Make a RowInsertRequests request to the database with hints.
204    pub async fn row_inserts_with_hints(
205        &self,
206        requests: RowInsertRequests,
207        hints: &[(&str, &str)],
208    ) -> Result<u32> {
209        let mut client = make_database_client(&self.client)?;
210        let request = self.to_rpc_request(Request::RowInserts(requests));
211
212        let mut request = tonic::Request::new(request);
213        let metadata = request.metadata_mut();
214        Self::put_hints(metadata, hints)?;
215
216        let response = client
217            .inner
218            .handle(request)
219            .await
220            .inspect_err(client.inspect_err("row_inserts_with_hints"))?
221            .into_inner();
222        from_grpc_response(response)
223    }
224
225    fn put_hints(metadata: &mut MetadataMap, hints: &[(&str, &str)]) -> Result<()> {
226        let Some(value) = hints
227            .iter()
228            .map(|(k, v)| format!("{}={}", k, v))
229            .reduce(|a, b| format!("{},{}", a, b))
230        else {
231            return Ok(());
232        };
233
234        let key = AsciiMetadataKey::from_static("x-greptime-hints");
235        let value = AsciiMetadataValue::from_str(&value).context(InvalidTonicMetadataValueSnafu)?;
236        metadata.insert(key, value);
237        Ok(())
238    }
239
240    /// Make a request to the database.
241    pub async fn handle(&self, request: Request) -> Result<u32> {
242        let mut client = make_database_client(&self.client)?;
243        let request = self.to_rpc_request(request);
244        let response = client
245            .inner
246            .handle(request)
247            .await
248            .inspect_err(client.inspect_err("handle"))?
249            .into_inner();
250        from_grpc_response(response)
251    }
252
253    /// Retry if connection fails, max_retries is the max number of retries, so the total wait time
254    /// is `max_retries * GRPC_CONN_TIMEOUT`
255    pub async fn handle_with_retry(
256        &self,
257        request: Request,
258        max_retries: u32,
259        hints: &[(&str, &str)],
260    ) -> Result<u32> {
261        let mut client = make_database_client(&self.client)?;
262        let mut retries = 0;
263
264        let request = self.to_rpc_request(request);
265
266        loop {
267            let mut tonic_request = tonic::Request::new(request.clone());
268            let metadata = tonic_request.metadata_mut();
269            Self::put_hints(metadata, hints)?;
270            let raw_response = client
271                .inner
272                .handle(tonic_request)
273                .await
274                .inspect_err(client.inspect_err("handle"));
275            match (raw_response, retries < max_retries) {
276                (Ok(resp), _) => return from_grpc_response(resp.into_inner()),
277                (Err(err), true) => {
278                    // determine if the error is retryable
279                    if is_grpc_retryable(&err) {
280                        // retry
281                        retries += 1;
282                        warn!("Retrying {} times with error = {:?}", retries, err);
283                        continue;
284                    } else {
285                        error!(
286                            err; "Failed to send request to grpc handle, retries = {}, not retryable error, aborting",
287                            retries
288                        );
289                        return Err(err.into());
290                    }
291                }
292                (Err(err), false) => {
293                    error!(
294                        err; "Failed to send request to grpc handle after {} retries",
295                        retries,
296                    );
297                    return Err(err.into());
298                }
299            }
300        }
301    }
302
303    #[inline]
304    fn to_rpc_request(&self, request: Request) -> GreptimeRequest {
305        GreptimeRequest {
306            header: Some(RequestHeader {
307                catalog: self.catalog.clone(),
308                schema: self.schema.clone(),
309                authorization: self.ctx.auth_header.clone(),
310                dbname: self.dbname.clone(),
311                timezone: self.timezone.clone(),
312                // TODO(Taylor-lagrange): add client grpc tracing
313                tracing_context: W3cTrace::new(),
314            }),
315            request: Some(request),
316        }
317    }
318
319    /// Executes a SQL query without any hints.
320    pub async fn sql<S>(&self, sql: S) -> Result<Output>
321    where
322        S: AsRef<str>,
323    {
324        self.sql_with_hint(sql, &[]).await
325    }
326
327    /// Executes a SQL query with optional hints for query optimization.
328    pub async fn sql_with_hint<S>(&self, sql: S, hints: &[(&str, &str)]) -> Result<Output>
329    where
330        S: AsRef<str>,
331    {
332        let request = Request::Query(QueryRequest {
333            query: Some(Query::Sql(sql.as_ref().to_string())),
334        });
335        self.do_get(request, hints).await
336    }
337
338    /// Executes a logical plan directly without SQL parsing.
339    pub async fn logical_plan(&self, logical_plan: Vec<u8>) -> Result<Output> {
340        let request = Request::Query(QueryRequest {
341            query: Some(Query::LogicalPlan(logical_plan)),
342        });
343        self.do_get(request, &[]).await
344    }
345
346    /// Creates a new table using the provided table expression.
347    pub async fn create(&self, expr: CreateTableExpr) -> Result<Output> {
348        let request = Request::Ddl(DdlRequest {
349            expr: Some(DdlExpr::CreateTable(expr)),
350        });
351        self.do_get(request, &[]).await
352    }
353
354    /// Alters an existing table using the provided alter expression.
355    pub async fn alter(&self, expr: AlterTableExpr) -> Result<Output> {
356        let request = Request::Ddl(DdlRequest {
357            expr: Some(DdlExpr::AlterTable(expr)),
358        });
359        self.do_get(request, &[]).await
360    }
361
362    async fn do_get(&self, request: Request, hints: &[(&str, &str)]) -> Result<Output> {
363        let request = self.to_rpc_request(request);
364        let request = Ticket {
365            ticket: request.encode_to_vec().into(),
366        };
367
368        let mut request = tonic::Request::new(request);
369        Self::put_hints(request.metadata_mut(), hints)?;
370
371        let mut client = self.client.make_flight_client(false, false)?;
372
373        let response = client.mut_inner().do_get(request).await.or_else(|e| {
374            let tonic_code = e.code();
375            let e: Error = e.into();
376            error!(
377                "Failed to do Flight get, addr: {}, code: {}, source: {:?}",
378                client.addr(),
379                tonic_code,
380                e
381            );
382            let error = Err(BoxedError::new(e)).with_context(|_| FlightGetSnafu {
383                addr: client.addr().to_string(),
384                tonic_code,
385            });
386            error
387        })?;
388
389        let flight_data_stream = response.into_inner();
390        let mut decoder = FlightDecoder::default();
391
392        let mut flight_message_stream = flight_data_stream.map(move |flight_data| {
393            flight_data
394                .map_err(Error::from)
395                .and_then(|data| decoder.try_decode(&data).context(ConvertFlightDataSnafu))?
396                .context(IllegalFlightMessagesSnafu {
397                    reason: "none message",
398                })
399        });
400
401        let Some(first_flight_message) = flight_message_stream.next().await else {
402            return IllegalFlightMessagesSnafu {
403                reason: "Expect the response not to be empty",
404            }
405            .fail();
406        };
407
408        let first_flight_message = first_flight_message?;
409
410        match first_flight_message {
411            FlightMessage::AffectedRows(rows) => {
412                ensure!(
413                    flight_message_stream.next().await.is_none(),
414                    IllegalFlightMessagesSnafu {
415                        reason: "Expect 'AffectedRows' Flight messages to be the one and the only!"
416                    }
417                );
418                Ok(Output::new_with_affected_rows(rows))
419            }
420            FlightMessage::RecordBatch(_) | FlightMessage::Metrics(_) => {
421                IllegalFlightMessagesSnafu {
422                    reason: "The first flight message cannot be a RecordBatch or Metrics message",
423                }
424                .fail()
425            }
426            FlightMessage::Schema(schema) => {
427                let schema = Arc::new(
428                    datatypes::schema::Schema::try_from(schema)
429                        .context(error::ConvertSchemaSnafu)?,
430                );
431                let schema_cloned = schema.clone();
432                let stream = Box::pin(stream!({
433                    while let Some(flight_message) = flight_message_stream.next().await {
434                        let flight_message = flight_message
435                            .map_err(BoxedError::new)
436                            .context(ExternalSnafu)?;
437                        match flight_message {
438                            FlightMessage::RecordBatch(arrow_batch) => {
439                                yield RecordBatch::try_from_df_record_batch(
440                                    schema_cloned.clone(),
441                                    arrow_batch,
442                                )
443                            }
444                            FlightMessage::Metrics(_) => {}
445                            FlightMessage::AffectedRows(_) | FlightMessage::Schema(_) => {
446                                yield IllegalFlightMessagesSnafu {reason: format!("A Schema message must be succeeded exclusively by a set of RecordBatch messages, flight_message: {:?}", flight_message)}
447                                        .fail()
448                                        .map_err(BoxedError::new)
449                                        .context(ExternalSnafu);
450                                break;
451                            }
452                        }
453                    }
454                }));
455                let record_batch_stream = RecordBatchStreamWrapper {
456                    schema,
457                    stream,
458                    output_ordering: None,
459                    metrics: Default::default(),
460                };
461                Ok(Output::new_with_stream(Box::pin(record_batch_stream)))
462            }
463        }
464    }
465
466    /// Ingest a stream of [RecordBatch]es that belong to a table, using Arrow Flight's "`DoPut`"
467    /// method. The return value is also a stream, produces [DoPutResponse]s.
468    pub async fn do_put(&self, stream: FlightDataStream) -> Result<DoPutResponseStream> {
469        let mut request = tonic::Request::new(stream);
470
471        if let Some(AuthHeader {
472            auth_scheme: Some(AuthScheme::Basic(Basic { username, password })),
473        }) = &self.ctx.auth_header
474        {
475            let encoded = BASE64_STANDARD.encode(format!("{username}:{password}"));
476            let value =
477                MetadataValue::from_str(&encoded).context(InvalidTonicMetadataValueSnafu)?;
478            request.metadata_mut().insert("x-greptime-auth", value);
479        }
480
481        let db_to_put = if !self.dbname.is_empty() {
482            &self.dbname
483        } else {
484            &build_db_string(self.catalog_or_default(), self.schema_or_default())
485        };
486        request.metadata_mut().insert(
487            "x-greptime-db-name",
488            MetadataValue::from_str(db_to_put).context(InvalidTonicMetadataValueSnafu)?,
489        );
490
491        let mut client = self.client.make_flight_client(false, false)?;
492        let response = client.mut_inner().do_put(request).await?;
493        let response = response
494            .into_inner()
495            .map_err(Into::into)
496            .and_then(|x| future::ready(DoPutResponse::try_from(x).context(ConvertFlightDataSnafu)))
497            .boxed();
498        Ok(response)
499    }
500}
501
502/// by grpc standard, only `Unavailable` is retryable, see: https://github.com/grpc/grpc/blob/master/doc/statuscodes.md#status-codes-and-their-use-in-grpc
503pub fn is_grpc_retryable(err: &tonic::Status) -> bool {
504    matches!(err.code(), tonic::Code::Unavailable)
505}
506
507#[derive(Default, Debug, Clone)]
508struct FlightContext {
509    auth_header: Option<AuthHeader>,
510}
511
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use api::v1::auth_header::AuthScheme;
    use api::v1::{AuthHeader, Basic};
    use common_error::status_code::StatusCode;
    use tonic::{Code, Status};

    use super::*;
    use crate::error::TonicSnafu;

    /// A default context has no auth header; once assigned, the Basic scheme is retained.
    #[test]
    fn test_flight_ctx() {
        let mut ctx = FlightContext::default();
        assert!(ctx.auth_header.is_none());

        ctx.auth_header = Some(AuthHeader {
            auth_scheme: Some(AuthScheme::Basic(Basic {
                username: String::from("u"),
                password: String::from("p"),
            })),
        });

        assert_matches!(
            ctx.auth_header,
            Some(AuthHeader {
                auth_scheme: Some(AuthScheme::Basic(_)),
            })
        )
    }

    /// Converting a `tonic::Status` renders the same message as the equivalent `Tonic` error.
    #[test]
    fn test_from_tonic_status() {
        let actual = Error::from(Status::new(Code::Internal, "blabla"));

        let expected = TonicSnafu {
            code: StatusCode::Internal,
            msg: "blabla".to_string(),
            tonic_code: Code::Internal,
        }
        .build();

        assert_eq!(expected.to_string(), actual.to_string());
    }
}