client/
database.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::pin::Pin;
16use std::str::FromStr;
17use std::sync::Arc;
18
19use api::v1::auth_header::AuthScheme;
20use api::v1::ddl_request::Expr as DdlExpr;
21use api::v1::greptime_database_client::GreptimeDatabaseClient;
22use api::v1::greptime_request::Request;
23use api::v1::query_request::Query;
24use api::v1::{
25    AlterTableExpr, AuthHeader, Basic, CreateTableExpr, DdlRequest, GreptimeRequest,
26    InsertRequests, QueryRequest, RequestHeader,
27};
28use arrow_flight::{FlightData, Ticket};
29use async_stream::stream;
30use base64::prelude::BASE64_STANDARD;
31use base64::Engine;
32use common_catalog::build_db_string;
33use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
34use common_error::ext::BoxedError;
35use common_grpc::flight::do_put::DoPutResponse;
36use common_grpc::flight::{FlightDecoder, FlightMessage};
37use common_query::Output;
38use common_recordbatch::error::ExternalSnafu;
39use common_recordbatch::{RecordBatch, RecordBatchStreamWrapper};
40use common_telemetry::tracing_context::W3cTrace;
41use common_telemetry::{error, warn};
42use futures::future;
43use futures_util::{Stream, StreamExt, TryStreamExt};
44use prost::Message;
45use snafu::{ensure, ResultExt};
46use tonic::metadata::{AsciiMetadataKey, AsciiMetadataValue, MetadataMap, MetadataValue};
47use tonic::transport::Channel;
48
49use crate::error::{
50    ConvertFlightDataSnafu, Error, FlightGetSnafu, IllegalFlightMessagesSnafu,
51    InvalidTonicMetadataValueSnafu,
52};
53use crate::{error, from_grpc_response, Client, Result};
54
55type FlightDataStream = Pin<Box<dyn Stream<Item = FlightData> + Send>>;
56
57type DoPutResponseStream = Pin<Box<dyn Stream<Item = Result<DoPutResponse>>>>;
58
/// Client handle for issuing requests to a GreptimeDB database.
///
/// It carries the target database context (`catalog`/`schema` or `dbname`), the
/// time zone, and the auth context. These are attached to every outgoing request.
#[derive(Clone, Debug, Default)]
pub struct Database {
    // The "catalog" and "schema" to be used in processing the requests at the server side.
    // They are the "hint" or "context", just like how the "database" in "USE" statement is treated in MySQL.
    // They will be carried in the request header.
    catalog: String,
    schema: String,
    // The dbname follows the same naming rule as our MySQL, PostgreSQL and HTTP
    // protocols. The server treats dbname with priority over catalog/schema.
    dbname: String,
    // The time zone indicates the time zone where the user is located.
    // Some queries need to be aware of the user's time zone to perform some specific actions.
    timezone: String,

    // Connection source used to build gRPC / Flight clients per request.
    client: Client,
    // Auth context; its header is copied into request headers (see `set_auth`).
    ctx: FlightContext,
}
76
/// Wrapper around a `GreptimeDatabaseClient` bound to a tonic [`Channel`].
///
/// Instances are built by `make_database_client` in this module, which also applies
/// the client's configured message size limits.
pub struct DatabaseClient {
    pub inner: GreptimeDatabaseClient<Channel>,
}
80
81fn make_database_client(client: &Client) -> Result<DatabaseClient> {
82    let (_, channel) = client.find_channel()?;
83    Ok(DatabaseClient {
84        inner: GreptimeDatabaseClient::new(channel)
85            .max_decoding_message_size(client.max_grpc_recv_message_size())
86            .max_encoding_message_size(client.max_grpc_send_message_size()),
87    })
88}
89
90impl Database {
91    /// Create database service client using catalog and schema
92    pub fn new(catalog: impl Into<String>, schema: impl Into<String>, client: Client) -> Self {
93        Self {
94            catalog: catalog.into(),
95            schema: schema.into(),
96            dbname: String::default(),
97            timezone: String::default(),
98            client,
99            ctx: FlightContext::default(),
100        }
101    }
102
103    /// Create database service client using dbname.
104    ///
105    /// This API is designed for external usage. `dbname` is:
106    ///
107    /// - the name of database when using GreptimeDB standalone or cluster
108    /// - the name provided by GreptimeCloud or other multi-tenant GreptimeDB
109    ///   environment
110    pub fn new_with_dbname(dbname: impl Into<String>, client: Client) -> Self {
111        Self {
112            catalog: String::default(),
113            schema: String::default(),
114            timezone: String::default(),
115            dbname: dbname.into(),
116            client,
117            ctx: FlightContext::default(),
118        }
119    }
120
121    pub fn set_catalog(&mut self, catalog: impl Into<String>) {
122        self.catalog = catalog.into();
123    }
124
125    fn catalog_or_default(&self) -> &str {
126        if self.catalog.is_empty() {
127            DEFAULT_CATALOG_NAME
128        } else {
129            &self.catalog
130        }
131    }
132
133    pub fn set_schema(&mut self, schema: impl Into<String>) {
134        self.schema = schema.into();
135    }
136
137    fn schema_or_default(&self) -> &str {
138        if self.schema.is_empty() {
139            DEFAULT_SCHEMA_NAME
140        } else {
141            &self.schema
142        }
143    }
144
145    pub fn set_timezone(&mut self, timezone: impl Into<String>) {
146        self.timezone = timezone.into();
147    }
148
149    pub fn set_auth(&mut self, auth: AuthScheme) {
150        self.ctx.auth_header = Some(AuthHeader {
151            auth_scheme: Some(auth),
152        });
153    }
154
155    pub async fn insert(&self, requests: InsertRequests) -> Result<u32> {
156        self.handle(Request::Inserts(requests)).await
157    }
158
159    pub async fn insert_with_hints(
160        &self,
161        requests: InsertRequests,
162        hints: &[(&str, &str)],
163    ) -> Result<u32> {
164        let mut client = make_database_client(&self.client)?.inner;
165        let request = self.to_rpc_request(Request::Inserts(requests));
166
167        let mut request = tonic::Request::new(request);
168        let metadata = request.metadata_mut();
169        Self::put_hints(metadata, hints)?;
170
171        let response = client.handle(request).await?.into_inner();
172        from_grpc_response(response)
173    }
174
175    fn put_hints(metadata: &mut MetadataMap, hints: &[(&str, &str)]) -> Result<()> {
176        let Some(value) = hints
177            .iter()
178            .map(|(k, v)| format!("{}={}", k, v))
179            .reduce(|a, b| format!("{},{}", a, b))
180        else {
181            return Ok(());
182        };
183
184        let key = AsciiMetadataKey::from_static("x-greptime-hints");
185        let value = AsciiMetadataValue::from_str(&value).context(InvalidTonicMetadataValueSnafu)?;
186        metadata.insert(key, value);
187        Ok(())
188    }
189
190    pub async fn handle(&self, request: Request) -> Result<u32> {
191        let mut client = make_database_client(&self.client)?.inner;
192        let request = self.to_rpc_request(request);
193        let response = client.handle(request).await?.into_inner();
194        from_grpc_response(response)
195    }
196
197    /// Retry if connection fails, max_retries is the max number of retries, so the total wait time
198    /// is `max_retries * GRPC_CONN_TIMEOUT`
199    pub async fn handle_with_retry(
200        &self,
201        request: Request,
202        max_retries: u32,
203        hints: &[(&str, &str)],
204    ) -> Result<u32> {
205        let mut client = make_database_client(&self.client)?.inner;
206        let mut retries = 0;
207
208        let request = self.to_rpc_request(request);
209
210        loop {
211            let mut tonic_request = tonic::Request::new(request.clone());
212            let metadata = tonic_request.metadata_mut();
213            Self::put_hints(metadata, hints)?;
214            let raw_response = client.handle(tonic_request).await;
215            match (raw_response, retries < max_retries) {
216                (Ok(resp), _) => return from_grpc_response(resp.into_inner()),
217                (Err(err), true) => {
218                    // determine if the error is retryable
219                    if is_grpc_retryable(&err) {
220                        // retry
221                        retries += 1;
222                        warn!("Retrying {} times with error = {:?}", retries, err);
223                        continue;
224                    }
225                }
226                (Err(err), false) => {
227                    error!(
228                        "Failed to send request to grpc handle after {} retries, error = {:?}",
229                        retries, err
230                    );
231                    return Err(err.into());
232                }
233            }
234        }
235    }
236
237    #[inline]
238    fn to_rpc_request(&self, request: Request) -> GreptimeRequest {
239        GreptimeRequest {
240            header: Some(RequestHeader {
241                catalog: self.catalog.clone(),
242                schema: self.schema.clone(),
243                authorization: self.ctx.auth_header.clone(),
244                dbname: self.dbname.clone(),
245                timezone: self.timezone.clone(),
246                // TODO(Taylor-lagrange): add client grpc tracing
247                tracing_context: W3cTrace::new(),
248            }),
249            request: Some(request),
250        }
251    }
252
253    pub async fn sql<S>(&self, sql: S) -> Result<Output>
254    where
255        S: AsRef<str>,
256    {
257        self.sql_with_hint(sql, &[]).await
258    }
259
260    pub async fn sql_with_hint<S>(&self, sql: S, hints: &[(&str, &str)]) -> Result<Output>
261    where
262        S: AsRef<str>,
263    {
264        let request = Request::Query(QueryRequest {
265            query: Some(Query::Sql(sql.as_ref().to_string())),
266        });
267        self.do_get(request, hints).await
268    }
269
270    pub async fn logical_plan(&self, logical_plan: Vec<u8>) -> Result<Output> {
271        let request = Request::Query(QueryRequest {
272            query: Some(Query::LogicalPlan(logical_plan)),
273        });
274        self.do_get(request, &[]).await
275    }
276
277    pub async fn create(&self, expr: CreateTableExpr) -> Result<Output> {
278        let request = Request::Ddl(DdlRequest {
279            expr: Some(DdlExpr::CreateTable(expr)),
280        });
281        self.do_get(request, &[]).await
282    }
283
284    pub async fn alter(&self, expr: AlterTableExpr) -> Result<Output> {
285        let request = Request::Ddl(DdlRequest {
286            expr: Some(DdlExpr::AlterTable(expr)),
287        });
288        self.do_get(request, &[]).await
289    }
290
291    async fn do_get(&self, request: Request, hints: &[(&str, &str)]) -> Result<Output> {
292        let request = self.to_rpc_request(request);
293        let request = Ticket {
294            ticket: request.encode_to_vec().into(),
295        };
296
297        let mut request = tonic::Request::new(request);
298        Self::put_hints(request.metadata_mut(), hints)?;
299
300        let mut client = self.client.make_flight_client(false, false)?;
301
302        let response = client.mut_inner().do_get(request).await.or_else(|e| {
303            let tonic_code = e.code();
304            let e: Error = e.into();
305            error!(
306                "Failed to do Flight get, addr: {}, code: {}, source: {:?}",
307                client.addr(),
308                tonic_code,
309                e
310            );
311            let error = Err(BoxedError::new(e)).with_context(|_| FlightGetSnafu {
312                addr: client.addr().to_string(),
313                tonic_code,
314            });
315            error
316        })?;
317
318        let flight_data_stream = response.into_inner();
319        let mut decoder = FlightDecoder::default();
320
321        let mut flight_message_stream = flight_data_stream.map(move |flight_data| {
322            flight_data
323                .map_err(Error::from)
324                .and_then(|data| decoder.try_decode(&data).context(ConvertFlightDataSnafu))
325        });
326
327        let Some(first_flight_message) = flight_message_stream.next().await else {
328            return IllegalFlightMessagesSnafu {
329                reason: "Expect the response not to be empty",
330            }
331            .fail();
332        };
333
334        let first_flight_message = first_flight_message?;
335
336        match first_flight_message {
337            FlightMessage::AffectedRows(rows) => {
338                ensure!(
339                    flight_message_stream.next().await.is_none(),
340                    IllegalFlightMessagesSnafu {
341                        reason: "Expect 'AffectedRows' Flight messages to be the one and the only!"
342                    }
343                );
344                Ok(Output::new_with_affected_rows(rows))
345            }
346            FlightMessage::RecordBatch(_) | FlightMessage::Metrics(_) => {
347                IllegalFlightMessagesSnafu {
348                    reason: "The first flight message cannot be a RecordBatch or Metrics message",
349                }
350                .fail()
351            }
352            FlightMessage::Schema(schema) => {
353                let schema = Arc::new(
354                    datatypes::schema::Schema::try_from(schema)
355                        .context(error::ConvertSchemaSnafu)?,
356                );
357                let schema_cloned = schema.clone();
358                let stream = Box::pin(stream!({
359                    while let Some(flight_message) = flight_message_stream.next().await {
360                        let flight_message = flight_message
361                            .map_err(BoxedError::new)
362                            .context(ExternalSnafu)?;
363                        match flight_message {
364                            FlightMessage::RecordBatch(arrow_batch) => {
365                                yield RecordBatch::try_from_df_record_batch(
366                                    schema_cloned.clone(),
367                                    arrow_batch,
368                                )
369                            }
370                            FlightMessage::Metrics(_) => {}
371                            FlightMessage::AffectedRows(_) | FlightMessage::Schema(_) => {
372                                yield IllegalFlightMessagesSnafu {reason: format!("A Schema message must be succeeded exclusively by a set of RecordBatch messages, flight_message: {:?}", flight_message)}
373                                        .fail()
374                                        .map_err(BoxedError::new)
375                                        .context(ExternalSnafu);
376                                break;
377                            }
378                        }
379                    }
380                }));
381                let record_batch_stream = RecordBatchStreamWrapper {
382                    schema,
383                    stream,
384                    output_ordering: None,
385                    metrics: Default::default(),
386                };
387                Ok(Output::new_with_stream(Box::pin(record_batch_stream)))
388            }
389        }
390    }
391
392    /// Ingest a stream of [RecordBatch]es that belong to a table, using Arrow Flight's "`DoPut`"
393    /// method. The return value is also a stream, produces [DoPutResponse]s.
394    pub async fn do_put(&self, stream: FlightDataStream) -> Result<DoPutResponseStream> {
395        let mut request = tonic::Request::new(stream);
396
397        if let Some(AuthHeader {
398            auth_scheme: Some(AuthScheme::Basic(Basic { username, password })),
399        }) = &self.ctx.auth_header
400        {
401            let encoded = BASE64_STANDARD.encode(format!("{username}:{password}"));
402            let value =
403                MetadataValue::from_str(&encoded).context(InvalidTonicMetadataValueSnafu)?;
404            request.metadata_mut().insert("x-greptime-auth", value);
405        }
406
407        let db_to_put = if !self.dbname.is_empty() {
408            &self.dbname
409        } else {
410            &build_db_string(self.catalog_or_default(), self.schema_or_default())
411        };
412        request.metadata_mut().insert(
413            "x-greptime-db-name",
414            MetadataValue::from_str(db_to_put).context(InvalidTonicMetadataValueSnafu)?,
415        );
416
417        let mut client = self.client.make_flight_client(false, false)?;
418        let response = client.mut_inner().do_put(request).await?;
419        let response = response
420            .into_inner()
421            .map_err(Into::into)
422            .and_then(|x| future::ready(DoPutResponse::try_from(x).context(ConvertFlightDataSnafu)))
423            .boxed();
424        Ok(response)
425    }
426}
427
428/// by grpc standard, only `Unavailable` is retryable, see: https://github.com/grpc/grpc/blob/master/doc/statuscodes.md#status-codes-and-their-use-in-grpc
429pub fn is_grpc_retryable(err: &tonic::Status) -> bool {
430    matches!(err.code(), tonic::Code::Unavailable)
431}
432
433#[derive(Default, Debug, Clone)]
434struct FlightContext {
435    auth_header: Option<AuthHeader>,
436}
437
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use api::v1::auth_header::AuthScheme;
    use api::v1::{AuthHeader, Basic};
    use common_error::status_code::StatusCode;
    use tonic::{Code, Status};

    use super::*;
    use crate::error::TonicSnafu;

    /// A default `FlightContext` has no auth header. After installing Basic
    /// credentials, the header holds a `Basic` scheme.
    #[test]
    fn test_flight_ctx() {
        let mut flight_ctx = FlightContext::default();
        assert!(flight_ctx.auth_header.is_none());

        let credentials = Basic {
            username: "u".to_string(),
            password: "p".to_string(),
        };
        flight_ctx.auth_header = Some(AuthHeader {
            auth_scheme: Some(AuthScheme::Basic(credentials)),
        });

        assert_matches!(
            flight_ctx.auth_header,
            Some(AuthHeader {
                auth_scheme: Some(AuthScheme::Basic(_)),
            })
        );
    }

    /// Converting a tonic `Status` into the crate `Error` renders the same message
    /// as building the error through `TonicSnafu` directly.
    #[test]
    fn test_from_tonic_status() {
        let actual: Error = Status::new(Code::Internal, "blabla").into();
        let expected = TonicSnafu {
            code: StatusCode::Internal,
            msg: "blabla".to_string(),
            tonic_code: Code::Internal,
        }
        .build();

        assert_eq!(expected.to_string(), actual.to_string());
    }
}