client/
database.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::pin::Pin;
16use std::str::FromStr;
17use std::sync::Arc;
18
19use api::v1::auth_header::AuthScheme;
20use api::v1::ddl_request::Expr as DdlExpr;
21use api::v1::greptime_database_client::GreptimeDatabaseClient;
22use api::v1::greptime_request::Request;
23use api::v1::query_request::Query;
24use api::v1::{
25    AlterTableExpr, AuthHeader, Basic, CreateTableExpr, DdlRequest, GreptimeRequest,
26    InsertRequests, QueryRequest, RequestHeader, RowInsertRequests,
27};
28use arrow_flight::{FlightData, Ticket};
29use async_stream::stream;
30use base64::prelude::BASE64_STANDARD;
31use base64::Engine;
32use common_catalog::build_db_string;
33use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
34use common_error::ext::BoxedError;
35use common_grpc::flight::do_put::DoPutResponse;
36use common_grpc::flight::{FlightDecoder, FlightMessage};
37use common_query::Output;
38use common_recordbatch::error::ExternalSnafu;
39use common_recordbatch::{RecordBatch, RecordBatchStreamWrapper};
40use common_telemetry::tracing_context::W3cTrace;
41use common_telemetry::{error, warn};
42use futures::future;
43use futures_util::{Stream, StreamExt, TryStreamExt};
44use prost::Message;
45use snafu::{ensure, OptionExt, ResultExt};
46use tonic::metadata::{AsciiMetadataKey, AsciiMetadataValue, MetadataMap, MetadataValue};
47use tonic::transport::Channel;
48
49use crate::error::{
50    ConvertFlightDataSnafu, Error, FlightGetSnafu, IllegalFlightMessagesSnafu,
51    InvalidTonicMetadataValueSnafu,
52};
53use crate::{error, from_grpc_response, Client, Result};
54
/// Boxed, pinned, `Send` stream of [`FlightData`] items; the input accepted by
/// [`Database::do_put`].
type FlightDataStream = Pin<Box<dyn Stream<Item = FlightData> + Send>>;

/// Boxed, pinned stream of [`DoPutResponse`] results; the value returned by
/// [`Database::do_put`].
type DoPutResponseStream = Pin<Box<dyn Stream<Item = Result<DoPutResponse>>>>;
58
/// Client-side handle for issuing requests against one database through a [`Client`].
///
/// The catalog, schema, dbname, timezone and auth header stored here are copied into
/// the header of every request built by `to_rpc_request`.
#[derive(Clone, Debug, Default)]
pub struct Database {
    // The "catalog" and "schema" to be used in processing the requests at the server side.
    // They are the "hint" or "context", just like how the "database" in a "USE" statement is
    // treated in MySQL. They will be carried in the request header.
    catalog: String,
    schema: String,
    // The dbname follows the same naming rule as our MySQL, PostgreSQL and HTTP
    // protocols. The server treats dbname with priority over catalog/schema.
    dbname: String,
    // The time zone indicates the time zone where the user is located.
    // Some queries need to be aware of the user's time zone to perform some specific actions.
    timezone: String,

    // Underlying client used to obtain gRPC channels and Flight clients.
    client: Client,
    // Per-database context; currently holds the optional auth header.
    ctx: FlightContext,
}
76
/// Thin wrapper around the generated [`GreptimeDatabaseClient`] bound to a tonic [`Channel`].
pub struct DatabaseClient {
    /// The wrapped gRPC client.
    pub inner: GreptimeDatabaseClient<Channel>,
}
80
81fn make_database_client(client: &Client) -> Result<DatabaseClient> {
82    let (_, channel) = client.find_channel()?;
83    Ok(DatabaseClient {
84        inner: GreptimeDatabaseClient::new(channel)
85            .max_decoding_message_size(client.max_grpc_recv_message_size())
86            .max_encoding_message_size(client.max_grpc_send_message_size()),
87    })
88}
89
90impl Database {
91    /// Create database service client using catalog and schema
92    pub fn new(catalog: impl Into<String>, schema: impl Into<String>, client: Client) -> Self {
93        Self {
94            catalog: catalog.into(),
95            schema: schema.into(),
96            dbname: String::default(),
97            timezone: String::default(),
98            client,
99            ctx: FlightContext::default(),
100        }
101    }
102
103    /// Create database service client using dbname.
104    ///
105    /// This API is designed for external usage. `dbname` is:
106    ///
107    /// - the name of database when using GreptimeDB standalone or cluster
108    /// - the name provided by GreptimeCloud or other multi-tenant GreptimeDB
109    ///   environment
110    pub fn new_with_dbname(dbname: impl Into<String>, client: Client) -> Self {
111        Self {
112            catalog: String::default(),
113            schema: String::default(),
114            timezone: String::default(),
115            dbname: dbname.into(),
116            client,
117            ctx: FlightContext::default(),
118        }
119    }
120
121    /// Set the catalog for the database client.
122    pub fn set_catalog(&mut self, catalog: impl Into<String>) {
123        self.catalog = catalog.into();
124    }
125
126    fn catalog_or_default(&self) -> &str {
127        if self.catalog.is_empty() {
128            DEFAULT_CATALOG_NAME
129        } else {
130            &self.catalog
131        }
132    }
133
134    /// Set the schema for the database client.
135    pub fn set_schema(&mut self, schema: impl Into<String>) {
136        self.schema = schema.into();
137    }
138
139    fn schema_or_default(&self) -> &str {
140        if self.schema.is_empty() {
141            DEFAULT_SCHEMA_NAME
142        } else {
143            &self.schema
144        }
145    }
146
147    /// Set the timezone for the database client.
148    pub fn set_timezone(&mut self, timezone: impl Into<String>) {
149        self.timezone = timezone.into();
150    }
151
152    /// Set the auth scheme for the database client.
153    pub fn set_auth(&mut self, auth: AuthScheme) {
154        self.ctx.auth_header = Some(AuthHeader {
155            auth_scheme: Some(auth),
156        });
157    }
158
159    /// Make an InsertRequests request to the database.
160    pub async fn insert(&self, requests: InsertRequests) -> Result<u32> {
161        self.handle(Request::Inserts(requests)).await
162    }
163
164    /// Make an InsertRequests request to the database with hints.
165    pub async fn insert_with_hints(
166        &self,
167        requests: InsertRequests,
168        hints: &[(&str, &str)],
169    ) -> Result<u32> {
170        let mut client = make_database_client(&self.client)?.inner;
171        let request = self.to_rpc_request(Request::Inserts(requests));
172
173        let mut request = tonic::Request::new(request);
174        let metadata = request.metadata_mut();
175        Self::put_hints(metadata, hints)?;
176
177        let response = client.handle(request).await?.into_inner();
178        from_grpc_response(response)
179    }
180
181    /// Make a RowInsertRequests request to the database.
182    pub async fn row_inserts(&self, requests: RowInsertRequests) -> Result<u32> {
183        self.handle(Request::RowInserts(requests)).await
184    }
185
186    /// Make a RowInsertRequests request to the database with hints.
187    pub async fn row_inserts_with_hints(
188        &self,
189        requests: RowInsertRequests,
190        hints: &[(&str, &str)],
191    ) -> Result<u32> {
192        let mut client = make_database_client(&self.client)?.inner;
193        let request = self.to_rpc_request(Request::RowInserts(requests));
194
195        let mut request = tonic::Request::new(request);
196        let metadata = request.metadata_mut();
197        Self::put_hints(metadata, hints)?;
198
199        let response = client.handle(request).await?.into_inner();
200        from_grpc_response(response)
201    }
202
203    fn put_hints(metadata: &mut MetadataMap, hints: &[(&str, &str)]) -> Result<()> {
204        let Some(value) = hints
205            .iter()
206            .map(|(k, v)| format!("{}={}", k, v))
207            .reduce(|a, b| format!("{},{}", a, b))
208        else {
209            return Ok(());
210        };
211
212        let key = AsciiMetadataKey::from_static("x-greptime-hints");
213        let value = AsciiMetadataValue::from_str(&value).context(InvalidTonicMetadataValueSnafu)?;
214        metadata.insert(key, value);
215        Ok(())
216    }
217
218    /// Make a request to the database.
219    pub async fn handle(&self, request: Request) -> Result<u32> {
220        let mut client = make_database_client(&self.client)?.inner;
221        let request = self.to_rpc_request(request);
222        let response = client.handle(request).await?.into_inner();
223        from_grpc_response(response)
224    }
225
226    /// Retry if connection fails, max_retries is the max number of retries, so the total wait time
227    /// is `max_retries * GRPC_CONN_TIMEOUT`
228    pub async fn handle_with_retry(
229        &self,
230        request: Request,
231        max_retries: u32,
232        hints: &[(&str, &str)],
233    ) -> Result<u32> {
234        let mut client = make_database_client(&self.client)?.inner;
235        let mut retries = 0;
236
237        let request = self.to_rpc_request(request);
238
239        loop {
240            let mut tonic_request = tonic::Request::new(request.clone());
241            let metadata = tonic_request.metadata_mut();
242            Self::put_hints(metadata, hints)?;
243            let raw_response = client.handle(tonic_request).await;
244            match (raw_response, retries < max_retries) {
245                (Ok(resp), _) => return from_grpc_response(resp.into_inner()),
246                (Err(err), true) => {
247                    // determine if the error is retryable
248                    if is_grpc_retryable(&err) {
249                        // retry
250                        retries += 1;
251                        warn!("Retrying {} times with error = {:?}", retries, err);
252                        continue;
253                    } else {
254                        error!(
255                            err; "Failed to send request to grpc handle, retries = {}, not retryable error, aborting",
256                            retries
257                        );
258                        return Err(err.into());
259                    }
260                }
261                (Err(err), false) => {
262                    error!(
263                        err; "Failed to send request to grpc handle after {} retries",
264                        retries,
265                    );
266                    return Err(err.into());
267                }
268            }
269        }
270    }
271
272    #[inline]
273    fn to_rpc_request(&self, request: Request) -> GreptimeRequest {
274        GreptimeRequest {
275            header: Some(RequestHeader {
276                catalog: self.catalog.clone(),
277                schema: self.schema.clone(),
278                authorization: self.ctx.auth_header.clone(),
279                dbname: self.dbname.clone(),
280                timezone: self.timezone.clone(),
281                // TODO(Taylor-lagrange): add client grpc tracing
282                tracing_context: W3cTrace::new(),
283            }),
284            request: Some(request),
285        }
286    }
287
288    /// Executes a SQL query without any hints.
289    pub async fn sql<S>(&self, sql: S) -> Result<Output>
290    where
291        S: AsRef<str>,
292    {
293        self.sql_with_hint(sql, &[]).await
294    }
295
296    /// Executes a SQL query with optional hints for query optimization.
297    pub async fn sql_with_hint<S>(&self, sql: S, hints: &[(&str, &str)]) -> Result<Output>
298    where
299        S: AsRef<str>,
300    {
301        let request = Request::Query(QueryRequest {
302            query: Some(Query::Sql(sql.as_ref().to_string())),
303        });
304        self.do_get(request, hints).await
305    }
306
307    /// Executes a logical plan directly without SQL parsing.
308    pub async fn logical_plan(&self, logical_plan: Vec<u8>) -> Result<Output> {
309        let request = Request::Query(QueryRequest {
310            query: Some(Query::LogicalPlan(logical_plan)),
311        });
312        self.do_get(request, &[]).await
313    }
314
315    /// Creates a new table using the provided table expression.
316    pub async fn create(&self, expr: CreateTableExpr) -> Result<Output> {
317        let request = Request::Ddl(DdlRequest {
318            expr: Some(DdlExpr::CreateTable(expr)),
319        });
320        self.do_get(request, &[]).await
321    }
322
323    /// Alters an existing table using the provided alter expression.
324    pub async fn alter(&self, expr: AlterTableExpr) -> Result<Output> {
325        let request = Request::Ddl(DdlRequest {
326            expr: Some(DdlExpr::AlterTable(expr)),
327        });
328        self.do_get(request, &[]).await
329    }
330
    /// Sends `request` through the Flight `do_get` call and turns the reply into an [`Output`].
    ///
    /// The request is wrapped with this database's header, protobuf-encoded into a Flight
    /// [`Ticket`], and sent with `hints` attached as request metadata. The first decoded
    /// Flight message decides the shape of the result:
    ///
    /// - `AffectedRows`: must be the only message; yields an affected-rows output.
    /// - `Schema`: the remaining messages are exposed as a record batch stream.
    /// - `RecordBatch` / `Metrics`: rejected as an illegal first message.
    ///
    /// # Errors
    ///
    /// Returns an error when the hints cannot be encoded, the Flight client cannot be
    /// created, the `do_get` call fails, the response stream is empty, a message cannot be
    /// decoded, or the message sequence violates the rules above.
    async fn do_get(&self, request: Request, hints: &[(&str, &str)]) -> Result<Output> {
        // The whole `GreptimeRequest` travels as the opaque bytes of the Flight ticket.
        let request = self.to_rpc_request(request);
        let request = Ticket {
            ticket: request.encode_to_vec().into(),
        };

        let mut request = tonic::Request::new(request);
        Self::put_hints(request.metadata_mut(), hints)?;

        let mut client = self.client.make_flight_client(false, false)?;

        // On failure, log the peer address and tonic code, then wrap the converted error
        // into a `FlightGet` error carrying the same address and code.
        let response = client.mut_inner().do_get(request).await.or_else(|e| {
            let tonic_code = e.code();
            let e: Error = e.into();
            error!(
                "Failed to do Flight get, addr: {}, code: {}, source: {:?}",
                client.addr(),
                tonic_code,
                e
            );
            let error = Err(BoxedError::new(e)).with_context(|_| FlightGetSnafu {
                addr: client.addr().to_string(),
                tonic_code,
            });
            error
        })?;

        let flight_data_stream = response.into_inner();
        let mut decoder = FlightDecoder::default();

        // Decode each `FlightData` frame with the shared decoder; a frame that decodes to
        // `None` is reported as an illegal message.
        let mut flight_message_stream = flight_data_stream.map(move |flight_data| {
            flight_data
                .map_err(Error::from)
                .and_then(|data| decoder.try_decode(&data).context(ConvertFlightDataSnafu))?
                .context(IllegalFlightMessagesSnafu {
                    reason: "none message",
                })
        });

        // The first message determines how the rest of the stream is interpreted.
        let Some(first_flight_message) = flight_message_stream.next().await else {
            return IllegalFlightMessagesSnafu {
                reason: "Expect the response not to be empty",
            }
            .fail();
        };

        let first_flight_message = first_flight_message?;

        match first_flight_message {
            FlightMessage::AffectedRows(rows) => {
                // Nothing may follow an `AffectedRows` message.
                ensure!(
                    flight_message_stream.next().await.is_none(),
                    IllegalFlightMessagesSnafu {
                        reason: "Expect 'AffectedRows' Flight messages to be the one and the only!"
                    }
                );
                Ok(Output::new_with_affected_rows(rows))
            }
            FlightMessage::RecordBatch(_) | FlightMessage::Metrics(_) => {
                IllegalFlightMessagesSnafu {
                    reason: "The first flight message cannot be a RecordBatch or Metrics message",
                }
                .fail()
            }
            FlightMessage::Schema(schema) => {
                let schema = Arc::new(
                    datatypes::schema::Schema::try_from(schema)
                        .context(error::ConvertSchemaSnafu)?,
                );
                let schema_cloned = schema.clone();
                // Yield one `RecordBatch` per record batch message until the Flight stream
                // ends; decoding errors end the stream through `?`.
                let stream = Box::pin(stream!({
                    while let Some(flight_message) = flight_message_stream.next().await {
                        let flight_message = flight_message
                            .map_err(BoxedError::new)
                            .context(ExternalSnafu)?;
                        match flight_message {
                            FlightMessage::RecordBatch(arrow_batch) => {
                                yield RecordBatch::try_from_df_record_batch(
                                    schema_cloned.clone(),
                                    arrow_batch,
                                )
                            }
                            // Metrics messages are skipped here.
                            FlightMessage::Metrics(_) => {}
                            // Any other message after the schema is a protocol violation:
                            // surface it as an error item and stop reading.
                            FlightMessage::AffectedRows(_) | FlightMessage::Schema(_) => {
                                yield IllegalFlightMessagesSnafu {reason: format!("A Schema message must be succeeded exclusively by a set of RecordBatch messages, flight_message: {:?}", flight_message)}
                                        .fail()
                                        .map_err(BoxedError::new)
                                        .context(ExternalSnafu);
                                break;
                            }
                        }
                    }
                }));
                let record_batch_stream = RecordBatchStreamWrapper {
                    schema,
                    stream,
                    output_ordering: None,
                    metrics: Default::default(),
                };
                Ok(Output::new_with_stream(Box::pin(record_batch_stream)))
            }
        }
    }
434
435    /// Ingest a stream of [RecordBatch]es that belong to a table, using Arrow Flight's "`DoPut`"
436    /// method. The return value is also a stream, produces [DoPutResponse]s.
437    pub async fn do_put(&self, stream: FlightDataStream) -> Result<DoPutResponseStream> {
438        let mut request = tonic::Request::new(stream);
439
440        if let Some(AuthHeader {
441            auth_scheme: Some(AuthScheme::Basic(Basic { username, password })),
442        }) = &self.ctx.auth_header
443        {
444            let encoded = BASE64_STANDARD.encode(format!("{username}:{password}"));
445            let value =
446                MetadataValue::from_str(&encoded).context(InvalidTonicMetadataValueSnafu)?;
447            request.metadata_mut().insert("x-greptime-auth", value);
448        }
449
450        let db_to_put = if !self.dbname.is_empty() {
451            &self.dbname
452        } else {
453            &build_db_string(self.catalog_or_default(), self.schema_or_default())
454        };
455        request.metadata_mut().insert(
456            "x-greptime-db-name",
457            MetadataValue::from_str(db_to_put).context(InvalidTonicMetadataValueSnafu)?,
458        );
459
460        let mut client = self.client.make_flight_client(false, false)?;
461        let response = client.mut_inner().do_put(request).await?;
462        let response = response
463            .into_inner()
464            .map_err(Into::into)
465            .and_then(|x| future::ready(DoPutResponse::try_from(x).context(ConvertFlightDataSnafu)))
466            .boxed();
467        Ok(response)
468    }
469}
470
471/// by grpc standard, only `Unavailable` is retryable, see: https://github.com/grpc/grpc/blob/master/doc/statuscodes.md#status-codes-and-their-use-in-grpc
472pub fn is_grpc_retryable(err: &tonic::Status) -> bool {
473    matches!(err.code(), tonic::Code::Unavailable)
474}
475
/// Per-database request context. Defaults to having no auth header.
#[derive(Default, Debug, Clone)]
struct FlightContext {
    // Auth header attached to outgoing requests; `None` until `Database::set_auth` is called.
    auth_header: Option<AuthHeader>,
}
480
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use api::v1::auth_header::AuthScheme;
    use api::v1::{AuthHeader, Basic};
    use common_error::status_code::StatusCode;
    use tonic::{Code, Status};

    use super::*;
    use crate::error::TonicSnafu;

    /// A default context has no auth header; an assigned Basic scheme is kept as-is.
    #[test]
    fn test_flight_ctx() {
        let mut ctx = FlightContext::default();
        assert!(ctx.auth_header.is_none());

        let credentials = Basic {
            username: String::from("u"),
            password: String::from("p"),
        };
        ctx.auth_header = Some(AuthHeader {
            auth_scheme: Some(AuthScheme::Basic(credentials)),
        });

        assert_matches!(
            ctx.auth_header,
            Some(AuthHeader {
                auth_scheme: Some(AuthScheme::Basic(_)),
            })
        )
    }

    /// Converting a tonic `Status` renders the same message as a hand-built `Tonic` error.
    #[test]
    fn test_from_tonic_status() {
        let message = "blabla";
        let actual: Error = Status::new(Code::Internal, message).into();

        let expected = TonicSnafu {
            code: StatusCode::Internal,
            msg: message.to_string(),
            tonic_code: Code::Internal,
        }
        .build();

        assert_eq!(expected.to_string(), actual.to_string());
    }
}