client/
database.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::pin::Pin;
16use std::str::FromStr;
17use std::sync::Arc;
18
19use api::v1::auth_header::AuthScheme;
20use api::v1::ddl_request::Expr as DdlExpr;
21use api::v1::greptime_database_client::GreptimeDatabaseClient;
22use api::v1::greptime_request::Request;
23use api::v1::query_request::Query;
24use api::v1::{
25    AlterTableExpr, AuthHeader, Basic, CreateTableExpr, DdlRequest, GreptimeRequest,
26    InsertRequests, QueryRequest, RequestHeader,
27};
28use arrow_flight::{FlightData, Ticket};
29use async_stream::stream;
30use base64::prelude::BASE64_STANDARD;
31use base64::Engine;
32use common_catalog::build_db_string;
33use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
34use common_error::ext::{BoxedError, ErrorExt};
35use common_grpc::flight::do_put::DoPutResponse;
36use common_grpc::flight::{FlightDecoder, FlightMessage};
37use common_query::Output;
38use common_recordbatch::error::ExternalSnafu;
39use common_recordbatch::{RecordBatch, RecordBatchStreamWrapper};
40use common_telemetry::tracing_context::W3cTrace;
41use common_telemetry::{error, warn};
42use futures::future;
43use futures_util::{Stream, StreamExt, TryStreamExt};
44use prost::Message;
45use snafu::{ensure, ResultExt};
46use tonic::metadata::{AsciiMetadataKey, AsciiMetadataValue, MetadataMap, MetadataValue};
47use tonic::transport::Channel;
48
49use crate::error::{
50    ConvertFlightDataSnafu, Error, FlightGetSnafu, IllegalFlightMessagesSnafu,
51    InvalidTonicMetadataValueSnafu, ServerSnafu,
52};
53use crate::{error, from_grpc_response, Client, Result};
54
55type FlightDataStream = Pin<Box<dyn Stream<Item = FlightData> + Send>>;
56
57type DoPutResponseStream = Pin<Box<dyn Stream<Item = Result<DoPutResponse>>>>;
58
59#[derive(Clone, Debug, Default)]
60pub struct Database {
61    // The "catalog" and "schema" to be used in processing the requests at the server side.
62    // They are the "hint" or "context", just like how the "database" in "USE" statement is treated in MySQL.
63    // They will be carried in the request header.
64    catalog: String,
65    schema: String,
66    // The dbname follows naming rule as out mysql, postgres and http
67    // protocol. The server treat dbname in priority of catalog/schema.
68    dbname: String,
69    // The time zone indicates the time zone where the user is located.
70    // Some queries need to be aware of the user's time zone to perform some specific actions.
71    timezone: String,
72
73    client: Client,
74    ctx: FlightContext,
75}
76
77pub struct DatabaseClient {
78    pub inner: GreptimeDatabaseClient<Channel>,
79}
80
81fn make_database_client(client: &Client) -> Result<DatabaseClient> {
82    let (_, channel) = client.find_channel()?;
83    Ok(DatabaseClient {
84        inner: GreptimeDatabaseClient::new(channel)
85            .max_decoding_message_size(client.max_grpc_recv_message_size())
86            .max_encoding_message_size(client.max_grpc_send_message_size()),
87    })
88}
89
90impl Database {
91    /// Create database service client using catalog and schema
92    pub fn new(catalog: impl Into<String>, schema: impl Into<String>, client: Client) -> Self {
93        Self {
94            catalog: catalog.into(),
95            schema: schema.into(),
96            dbname: String::default(),
97            timezone: String::default(),
98            client,
99            ctx: FlightContext::default(),
100        }
101    }
102
103    /// Create database service client using dbname.
104    ///
105    /// This API is designed for external usage. `dbname` is:
106    ///
107    /// - the name of database when using GreptimeDB standalone or cluster
108    /// - the name provided by GreptimeCloud or other multi-tenant GreptimeDB
109    ///   environment
110    pub fn new_with_dbname(dbname: impl Into<String>, client: Client) -> Self {
111        Self {
112            catalog: String::default(),
113            schema: String::default(),
114            timezone: String::default(),
115            dbname: dbname.into(),
116            client,
117            ctx: FlightContext::default(),
118        }
119    }
120
121    pub fn set_catalog(&mut self, catalog: impl Into<String>) {
122        self.catalog = catalog.into();
123    }
124
125    fn catalog_or_default(&self) -> &str {
126        if self.catalog.is_empty() {
127            DEFAULT_CATALOG_NAME
128        } else {
129            &self.catalog
130        }
131    }
132
133    pub fn set_schema(&mut self, schema: impl Into<String>) {
134        self.schema = schema.into();
135    }
136
137    fn schema_or_default(&self) -> &str {
138        if self.schema.is_empty() {
139            DEFAULT_SCHEMA_NAME
140        } else {
141            &self.schema
142        }
143    }
144
145    pub fn set_timezone(&mut self, timezone: impl Into<String>) {
146        self.timezone = timezone.into();
147    }
148
149    pub fn set_auth(&mut self, auth: AuthScheme) {
150        self.ctx.auth_header = Some(AuthHeader {
151            auth_scheme: Some(auth),
152        });
153    }
154
155    pub async fn insert(&self, requests: InsertRequests) -> Result<u32> {
156        self.handle(Request::Inserts(requests)).await
157    }
158
159    pub async fn insert_with_hints(
160        &self,
161        requests: InsertRequests,
162        hints: &[(&str, &str)],
163    ) -> Result<u32> {
164        let mut client = make_database_client(&self.client)?.inner;
165        let request = self.to_rpc_request(Request::Inserts(requests));
166
167        let mut request = tonic::Request::new(request);
168        let metadata = request.metadata_mut();
169        Self::put_hints(metadata, hints)?;
170
171        let response = client.handle(request).await?.into_inner();
172        from_grpc_response(response)
173    }
174
175    fn put_hints(metadata: &mut MetadataMap, hints: &[(&str, &str)]) -> Result<()> {
176        let Some(value) = hints
177            .iter()
178            .map(|(k, v)| format!("{}={}", k, v))
179            .reduce(|a, b| format!("{},{}", a, b))
180        else {
181            return Ok(());
182        };
183
184        let key = AsciiMetadataKey::from_static("x-greptime-hints");
185        let value = AsciiMetadataValue::from_str(&value).context(InvalidTonicMetadataValueSnafu)?;
186        metadata.insert(key, value);
187        Ok(())
188    }
189
190    pub async fn handle(&self, request: Request) -> Result<u32> {
191        let mut client = make_database_client(&self.client)?.inner;
192        let request = self.to_rpc_request(request);
193        let response = client.handle(request).await?.into_inner();
194        from_grpc_response(response)
195    }
196
197    /// Retry if connection fails, max_retries is the max number of retries, so the total wait time
198    /// is `max_retries * GRPC_CONN_TIMEOUT`
199    pub async fn handle_with_retry(&self, request: Request, max_retries: u32) -> Result<u32> {
200        let mut client = make_database_client(&self.client)?.inner;
201        let mut retries = 0;
202        let request = self.to_rpc_request(request);
203        loop {
204            let raw_response = client.handle(request.clone()).await;
205            match (raw_response, retries < max_retries) {
206                (Ok(resp), _) => return from_grpc_response(resp.into_inner()),
207                (Err(err), true) => {
208                    // determine if the error is retryable
209                    if is_grpc_retryable(&err) {
210                        // retry
211                        retries += 1;
212                        warn!("Retrying {} times with error = {:?}", retries, err);
213                        continue;
214                    }
215                }
216                (Err(err), false) => {
217                    error!(
218                        "Failed to send request to grpc handle after {} retries, error = {:?}",
219                        retries, err
220                    );
221                    return Err(err.into());
222                }
223            }
224        }
225    }
226
227    #[inline]
228    fn to_rpc_request(&self, request: Request) -> GreptimeRequest {
229        GreptimeRequest {
230            header: Some(RequestHeader {
231                catalog: self.catalog.clone(),
232                schema: self.schema.clone(),
233                authorization: self.ctx.auth_header.clone(),
234                dbname: self.dbname.clone(),
235                timezone: self.timezone.clone(),
236                // TODO(Taylor-lagrange): add client grpc tracing
237                tracing_context: W3cTrace::new(),
238            }),
239            request: Some(request),
240        }
241    }
242
243    pub async fn sql<S>(&self, sql: S) -> Result<Output>
244    where
245        S: AsRef<str>,
246    {
247        self.sql_with_hint(sql, &[]).await
248    }
249
250    pub async fn sql_with_hint<S>(&self, sql: S, hints: &[(&str, &str)]) -> Result<Output>
251    where
252        S: AsRef<str>,
253    {
254        let request = Request::Query(QueryRequest {
255            query: Some(Query::Sql(sql.as_ref().to_string())),
256        });
257        self.do_get(request, hints).await
258    }
259
260    pub async fn logical_plan(&self, logical_plan: Vec<u8>) -> Result<Output> {
261        let request = Request::Query(QueryRequest {
262            query: Some(Query::LogicalPlan(logical_plan)),
263        });
264        self.do_get(request, &[]).await
265    }
266
267    pub async fn create(&self, expr: CreateTableExpr) -> Result<Output> {
268        let request = Request::Ddl(DdlRequest {
269            expr: Some(DdlExpr::CreateTable(expr)),
270        });
271        self.do_get(request, &[]).await
272    }
273
274    pub async fn alter(&self, expr: AlterTableExpr) -> Result<Output> {
275        let request = Request::Ddl(DdlRequest {
276            expr: Some(DdlExpr::AlterTable(expr)),
277        });
278        self.do_get(request, &[]).await
279    }
280
281    async fn do_get(&self, request: Request, hints: &[(&str, &str)]) -> Result<Output> {
282        let request = self.to_rpc_request(request);
283        let request = Ticket {
284            ticket: request.encode_to_vec().into(),
285        };
286
287        let mut request = tonic::Request::new(request);
288        Self::put_hints(request.metadata_mut(), hints)?;
289
290        let mut client = self.client.make_flight_client(false, false)?;
291
292        let response = client.mut_inner().do_get(request).await.or_else(|e| {
293            let tonic_code = e.code();
294            let e: Error = e.into();
295            let code = e.status_code();
296            let msg = e.to_string();
297            let error =
298                Err(BoxedError::new(ServerSnafu { code, msg }.build())).with_context(|_| {
299                    FlightGetSnafu {
300                        addr: client.addr().to_string(),
301                        tonic_code,
302                    }
303                });
304            error!(
305                "Failed to do Flight get, addr: {}, code: {}, source: {:?}",
306                client.addr(),
307                tonic_code,
308                error
309            );
310            error
311        })?;
312
313        let flight_data_stream = response.into_inner();
314        let mut decoder = FlightDecoder::default();
315
316        let mut flight_message_stream = flight_data_stream.map(move |flight_data| {
317            flight_data
318                .map_err(Error::from)
319                .and_then(|data| decoder.try_decode(&data).context(ConvertFlightDataSnafu))
320        });
321
322        let Some(first_flight_message) = flight_message_stream.next().await else {
323            return IllegalFlightMessagesSnafu {
324                reason: "Expect the response not to be empty",
325            }
326            .fail();
327        };
328
329        let first_flight_message = first_flight_message?;
330
331        match first_flight_message {
332            FlightMessage::AffectedRows(rows) => {
333                ensure!(
334                    flight_message_stream.next().await.is_none(),
335                    IllegalFlightMessagesSnafu {
336                        reason: "Expect 'AffectedRows' Flight messages to be the one and the only!"
337                    }
338                );
339                Ok(Output::new_with_affected_rows(rows))
340            }
341            FlightMessage::RecordBatch(_) | FlightMessage::Metrics(_) => {
342                IllegalFlightMessagesSnafu {
343                    reason: "The first flight message cannot be a RecordBatch or Metrics message",
344                }
345                .fail()
346            }
347            FlightMessage::Schema(schema) => {
348                let schema = Arc::new(
349                    datatypes::schema::Schema::try_from(schema)
350                        .context(error::ConvertSchemaSnafu)?,
351                );
352                let schema_cloned = schema.clone();
353                let stream = Box::pin(stream!({
354                    while let Some(flight_message) = flight_message_stream.next().await {
355                        let flight_message = flight_message
356                            .map_err(BoxedError::new)
357                            .context(ExternalSnafu)?;
358                        match flight_message {
359                            FlightMessage::RecordBatch(arrow_batch) => {
360                                yield RecordBatch::try_from_df_record_batch(
361                                    schema_cloned.clone(),
362                                    arrow_batch,
363                                )
364                            }
365                            FlightMessage::Metrics(_) => {}
366                            FlightMessage::AffectedRows(_) | FlightMessage::Schema(_) => {
367                                yield IllegalFlightMessagesSnafu {reason: format!("A Schema message must be succeeded exclusively by a set of RecordBatch messages, flight_message: {:?}", flight_message)}
368                                        .fail()
369                                        .map_err(BoxedError::new)
370                                        .context(ExternalSnafu);
371                                break;
372                            }
373                        }
374                    }
375                }));
376                let record_batch_stream = RecordBatchStreamWrapper {
377                    schema,
378                    stream,
379                    output_ordering: None,
380                    metrics: Default::default(),
381                };
382                Ok(Output::new_with_stream(Box::pin(record_batch_stream)))
383            }
384        }
385    }
386
387    /// Ingest a stream of [RecordBatch]es that belong to a table, using Arrow Flight's "`DoPut`"
388    /// method. The return value is also a stream, produces [DoPutResponse]s.
389    pub async fn do_put(&self, stream: FlightDataStream) -> Result<DoPutResponseStream> {
390        let mut request = tonic::Request::new(stream);
391
392        if let Some(AuthHeader {
393            auth_scheme: Some(AuthScheme::Basic(Basic { username, password })),
394        }) = &self.ctx.auth_header
395        {
396            let encoded = BASE64_STANDARD.encode(format!("{username}:{password}"));
397            let value =
398                MetadataValue::from_str(&encoded).context(InvalidTonicMetadataValueSnafu)?;
399            request.metadata_mut().insert("x-greptime-auth", value);
400        }
401
402        let db_to_put = if !self.dbname.is_empty() {
403            &self.dbname
404        } else {
405            &build_db_string(self.catalog_or_default(), self.schema_or_default())
406        };
407        request.metadata_mut().insert(
408            "x-greptime-db-name",
409            MetadataValue::from_str(db_to_put).context(InvalidTonicMetadataValueSnafu)?,
410        );
411
412        let mut client = self.client.make_flight_client(false, false)?;
413        let response = client.mut_inner().do_put(request).await?;
414        let response = response
415            .into_inner()
416            .map_err(Into::into)
417            .and_then(|x| future::ready(DoPutResponse::try_from(x).context(ConvertFlightDataSnafu)))
418            .boxed();
419        Ok(response)
420    }
421}
422
423/// by grpc standard, only `Unavailable` is retryable, see: https://github.com/grpc/grpc/blob/master/doc/statuscodes.md#status-codes-and-their-use-in-grpc
424pub fn is_grpc_retryable(err: &tonic::Status) -> bool {
425    matches!(err.code(), tonic::Code::Unavailable)
426}
427
428#[derive(Default, Debug, Clone)]
429struct FlightContext {
430    auth_header: Option<AuthHeader>,
431}
432
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use api::v1::auth_header::AuthScheme;
    use api::v1::{AuthHeader, Basic};

    use super::*;

    /// A default context carries no auth header; assigning a Basic scheme is reflected back.
    #[test]
    fn test_flight_ctx() {
        let mut flight_ctx = FlightContext::default();
        assert!(flight_ctx.auth_header.is_none());

        let scheme = AuthScheme::Basic(Basic {
            username: "u".to_string(),
            password: "p".to_string(),
        });
        flight_ctx.auth_header = Some(AuthHeader {
            auth_scheme: Some(scheme),
        });

        assert_matches!(
            flight_ctx.auth_header,
            Some(AuthHeader {
                auth_scheme: Some(AuthScheme::Basic(_)),
            })
        );
    }

    /// Hints are joined as `k=v` pairs separated by commas; no hints means no metadata entry.
    #[test]
    fn test_put_hints() {
        let mut metadata = MetadataMap::new();
        Database::put_hints(&mut metadata, &[]).unwrap();
        assert!(metadata.get("x-greptime-hints").is_none());

        Database::put_hints(&mut metadata, &[("a", "1")]).unwrap();
        assert_eq!(
            metadata.get("x-greptime-hints").unwrap().to_str().unwrap(),
            "a=1"
        );

        Database::put_hints(&mut metadata, &[("a", "1"), ("b", "2")]).unwrap();
        assert_eq!(
            metadata.get("x-greptime-hints").unwrap().to_str().unwrap(),
            "a=1,b=2"
        );
    }

    /// Only `Unavailable` is retryable.
    #[test]
    fn test_is_grpc_retryable() {
        assert!(is_grpc_retryable(&tonic::Status::unavailable("down")));
        assert!(!is_grpc_retryable(&tonic::Status::invalid_argument("bad")));
        assert!(!is_grpc_retryable(&tonic::Status::internal("boom")));
    }
}