client/
database.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::pin::Pin;
16use std::str::FromStr;
17
18use api::v1::auth_header::AuthScheme;
19use api::v1::ddl_request::Expr as DdlExpr;
20use api::v1::greptime_database_client::GreptimeDatabaseClient;
21use api::v1::greptime_request::Request;
22use api::v1::query_request::Query;
23use api::v1::{
24    AlterTableExpr, AuthHeader, Basic, CreateTableExpr, DdlRequest, GreptimeRequest,
25    InsertRequests, QueryRequest, RequestHeader,
26};
27use arrow_flight::{FlightData, Ticket};
28use async_stream::stream;
29use base64::prelude::BASE64_STANDARD;
30use base64::Engine;
31use common_catalog::build_db_string;
32use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
33use common_error::ext::{BoxedError, ErrorExt};
34use common_grpc::flight::do_put::DoPutResponse;
35use common_grpc::flight::{FlightDecoder, FlightMessage};
36use common_query::Output;
37use common_recordbatch::error::ExternalSnafu;
38use common_recordbatch::RecordBatchStreamWrapper;
39use common_telemetry::tracing_context::W3cTrace;
40use common_telemetry::{error, warn};
41use futures::future;
42use futures_util::{Stream, StreamExt, TryStreamExt};
43use prost::Message;
44use snafu::{ensure, ResultExt};
45use tonic::metadata::{AsciiMetadataKey, MetadataValue};
46use tonic::transport::Channel;
47
48use crate::error::{
49    ConvertFlightDataSnafu, Error, FlightGetSnafu, IllegalFlightMessagesSnafu, InvalidAsciiSnafu,
50    InvalidTonicMetadataValueSnafu, ServerSnafu,
51};
52use crate::{from_grpc_response, Client, Result};
53
/// A pinned, boxed, `Send` stream of [`FlightData`] messages; this is the input type
/// accepted by [`Database::do_put`].
type FlightDataStream = Pin<Box<dyn Stream<Item = FlightData> + Send>>;
55
/// A pinned, boxed stream of [`DoPutResponse`] results; this is the output type
/// returned by [`Database::do_put`]. Note that, unlike [`FlightDataStream`], it carries
/// no `Send` bound.
type DoPutResponseStream = Pin<Box<dyn Stream<Item = Result<DoPutResponse>>>>;
57
/// Client-side handle used to send requests (inserts, queries, DDL, Flight `DoPut`) to a
/// database, together with the naming/timezone/auth context attached to those requests.
#[derive(Clone, Debug, Default)]
pub struct Database {
    // The "catalog" and "schema" to be used in processing the requests at the server side.
    // They are the "hint" or "context", just like how the "database" in "USE" statement is treated in MySQL.
    // They will be carried in the request header.
    catalog: String,
    schema: String,
    // The dbname follows the same naming rule as our MySQL, PostgreSQL and HTTP
    // protocols. The server treats dbname with priority over catalog/schema.
    dbname: String,
    // The time zone indicates the time zone where the user is located.
    // Some queries need to be aware of the user's time zone to perform some specific actions.
    timezone: String,

    // The underlying client from which gRPC and Flight clients are created.
    client: Client,
    // Extra request context; currently only the optional auth header.
    ctx: FlightContext,
}
75
/// Wrapper around a [`GreptimeDatabaseClient`] bound to a tonic [`Channel`].
pub struct DatabaseClient {
    /// The wrapped gRPC client; exposed so callers can invoke its methods directly.
    pub inner: GreptimeDatabaseClient<Channel>,
}
79
80fn make_database_client(client: &Client) -> Result<DatabaseClient> {
81    let (_, channel) = client.find_channel()?;
82    Ok(DatabaseClient {
83        inner: GreptimeDatabaseClient::new(channel)
84            .max_decoding_message_size(client.max_grpc_recv_message_size())
85            .max_encoding_message_size(client.max_grpc_send_message_size()),
86    })
87}
88
89impl Database {
90    /// Create database service client using catalog and schema
91    pub fn new(catalog: impl Into<String>, schema: impl Into<String>, client: Client) -> Self {
92        Self {
93            catalog: catalog.into(),
94            schema: schema.into(),
95            dbname: String::default(),
96            timezone: String::default(),
97            client,
98            ctx: FlightContext::default(),
99        }
100    }
101
102    /// Create database service client using dbname.
103    ///
104    /// This API is designed for external usage. `dbname` is:
105    ///
106    /// - the name of database when using GreptimeDB standalone or cluster
107    /// - the name provided by GreptimeCloud or other multi-tenant GreptimeDB
108    ///   environment
109    pub fn new_with_dbname(dbname: impl Into<String>, client: Client) -> Self {
110        Self {
111            catalog: String::default(),
112            schema: String::default(),
113            timezone: String::default(),
114            dbname: dbname.into(),
115            client,
116            ctx: FlightContext::default(),
117        }
118    }
119
120    pub fn set_catalog(&mut self, catalog: impl Into<String>) {
121        self.catalog = catalog.into();
122    }
123
124    fn catalog_or_default(&self) -> &str {
125        if self.catalog.is_empty() {
126            DEFAULT_CATALOG_NAME
127        } else {
128            &self.catalog
129        }
130    }
131
132    pub fn set_schema(&mut self, schema: impl Into<String>) {
133        self.schema = schema.into();
134    }
135
136    fn schema_or_default(&self) -> &str {
137        if self.schema.is_empty() {
138            DEFAULT_SCHEMA_NAME
139        } else {
140            &self.schema
141        }
142    }
143
144    pub fn set_timezone(&mut self, timezone: impl Into<String>) {
145        self.timezone = timezone.into();
146    }
147
148    pub fn set_auth(&mut self, auth: AuthScheme) {
149        self.ctx.auth_header = Some(AuthHeader {
150            auth_scheme: Some(auth),
151        });
152    }
153
154    pub async fn insert(&self, requests: InsertRequests) -> Result<u32> {
155        self.handle(Request::Inserts(requests)).await
156    }
157
158    pub async fn insert_with_hints(
159        &self,
160        requests: InsertRequests,
161        hints: &[(&str, &str)],
162    ) -> Result<u32> {
163        let mut client = make_database_client(&self.client)?.inner;
164        let request = self.to_rpc_request(Request::Inserts(requests));
165
166        let mut request = tonic::Request::new(request);
167        let metadata = request.metadata_mut();
168        for (key, value) in hints {
169            let key = AsciiMetadataKey::from_bytes(format!("x-greptime-hint-{}", key).as_bytes())
170                .map_err(|_| {
171                InvalidAsciiSnafu {
172                    value: key.to_string(),
173                }
174                .build()
175            })?;
176            let value = value.parse().map_err(|_| {
177                InvalidAsciiSnafu {
178                    value: value.to_string(),
179                }
180                .build()
181            })?;
182            metadata.insert(key, value);
183        }
184        let response = client.handle(request).await?.into_inner();
185        from_grpc_response(response)
186    }
187
188    pub async fn handle(&self, request: Request) -> Result<u32> {
189        let mut client = make_database_client(&self.client)?.inner;
190        let request = self.to_rpc_request(request);
191        let response = client.handle(request).await?.into_inner();
192        from_grpc_response(response)
193    }
194
195    /// Retry if connection fails, max_retries is the max number of retries, so the total wait time
196    /// is `max_retries * GRPC_CONN_TIMEOUT`
197    pub async fn handle_with_retry(&self, request: Request, max_retries: u32) -> Result<u32> {
198        let mut client = make_database_client(&self.client)?.inner;
199        let mut retries = 0;
200        let request = self.to_rpc_request(request);
201        loop {
202            let raw_response = client.handle(request.clone()).await;
203            match (raw_response, retries < max_retries) {
204                (Ok(resp), _) => return from_grpc_response(resp.into_inner()),
205                (Err(err), true) => {
206                    // determine if the error is retryable
207                    if is_grpc_retryable(&err) {
208                        // retry
209                        retries += 1;
210                        warn!("Retrying {} times with error = {:?}", retries, err);
211                        continue;
212                    }
213                }
214                (Err(err), false) => {
215                    error!(
216                        "Failed to send request to grpc handle after {} retries, error = {:?}",
217                        retries, err
218                    );
219                    return Err(err.into());
220                }
221            }
222        }
223    }
224
225    #[inline]
226    fn to_rpc_request(&self, request: Request) -> GreptimeRequest {
227        GreptimeRequest {
228            header: Some(RequestHeader {
229                catalog: self.catalog.clone(),
230                schema: self.schema.clone(),
231                authorization: self.ctx.auth_header.clone(),
232                dbname: self.dbname.clone(),
233                timezone: self.timezone.clone(),
234                // TODO(Taylor-lagrange): add client grpc tracing
235                tracing_context: W3cTrace::new(),
236            }),
237            request: Some(request),
238        }
239    }
240
241    pub async fn sql<S>(&self, sql: S) -> Result<Output>
242    where
243        S: AsRef<str>,
244    {
245        self.do_get(Request::Query(QueryRequest {
246            query: Some(Query::Sql(sql.as_ref().to_string())),
247        }))
248        .await
249    }
250
251    pub async fn logical_plan(&self, logical_plan: Vec<u8>) -> Result<Output> {
252        self.do_get(Request::Query(QueryRequest {
253            query: Some(Query::LogicalPlan(logical_plan)),
254        }))
255        .await
256    }
257
258    pub async fn create(&self, expr: CreateTableExpr) -> Result<Output> {
259        self.do_get(Request::Ddl(DdlRequest {
260            expr: Some(DdlExpr::CreateTable(expr)),
261        }))
262        .await
263    }
264
265    pub async fn alter(&self, expr: AlterTableExpr) -> Result<Output> {
266        self.do_get(Request::Ddl(DdlRequest {
267            expr: Some(DdlExpr::AlterTable(expr)),
268        }))
269        .await
270    }
271
    /// Sends `request` through Arrow Flight's `DoGet` and turns the returned Flight
    /// message stream into an [`Output`].
    ///
    /// The first decoded message decides the shape of the result:
    ///
    /// - `AffectedRows`: must be the only message; yields an affected-rows output.
    /// - `Schema`: the remaining messages are exposed as a record batch stream.
    /// - `Recordbatch` / `Metrics`: rejected as an illegal first message.
    ///
    /// # Errors
    ///
    /// Fails when the Flight client cannot be created, when the `DoGet` call itself fails
    /// (reported with the server address and tonic status code), when a message cannot be
    /// decoded, or when the message sequence is illegal (including an empty response).
    async fn do_get(&self, request: Request) -> Result<Output> {
        // The whole request (header included) is protobuf-encoded into the Flight ticket.
        let request = self.to_rpc_request(request);
        let request = Ticket {
            ticket: request.encode_to_vec().into(),
        };

        let mut client = self.client.make_flight_client()?;

        let response = client.mut_inner().do_get(request).await.or_else(|e| {
            // Keep the raw tonic code before the status is converted into our error type.
            let tonic_code = e.code();
            let e: Error = e.into();
            let code = e.status_code();
            let msg = e.to_string();
            // Wrap the server-side error with the address and tonic code for diagnosis.
            let error =
                Err(BoxedError::new(ServerSnafu { code, msg }.build())).with_context(|_| {
                    FlightGetSnafu {
                        addr: client.addr().to_string(),
                        tonic_code,
                    }
                });
            error!(
                "Failed to do Flight get, addr: {}, code: {}, source: {:?}",
                client.addr(),
                tonic_code,
                error
            );
            error
        })?;

        let flight_data_stream = response.into_inner();
        // The decoder is moved into the mapping closure below, so the same decoder
        // instance is used for every message of this stream.
        let mut decoder = FlightDecoder::default();

        let mut flight_message_stream = flight_data_stream.map(move |flight_data| {
            flight_data
                .map_err(Error::from)
                .and_then(|data| decoder.try_decode(data).context(ConvertFlightDataSnafu))
        });

        // An empty response is illegal: at least one message is required.
        let Some(first_flight_message) = flight_message_stream.next().await else {
            return IllegalFlightMessagesSnafu {
                reason: "Expect the response not to be empty",
            }
            .fail();
        };

        let first_flight_message = first_flight_message?;

        match first_flight_message {
            FlightMessage::AffectedRows(rows) => {
                // Nothing may follow an `AffectedRows` message.
                ensure!(
                    flight_message_stream.next().await.is_none(),
                    IllegalFlightMessagesSnafu {
                        reason: "Expect 'AffectedRows' Flight messages to be the one and the only!"
                    }
                );
                Ok(Output::new_with_affected_rows(rows))
            }
            FlightMessage::Recordbatch(_) | FlightMessage::Metrics(_) => {
                IllegalFlightMessagesSnafu {
                    reason: "The first flight message cannot be a RecordBatch or Metrics message",
                }
                .fail()
            }
            FlightMessage::Schema(schema) => {
                // Lazily forward the remaining messages as record batches.
                let stream = Box::pin(stream!({
                    while let Some(flight_message) = flight_message_stream.next().await {
                        // Decode failures are boxed into the record batch `External` error.
                        // NOTE(review): confirm how `?` surfaces inside `stream!`
                        // (as opposed to `try_stream!`).
                        let flight_message = flight_message
                            .map_err(BoxedError::new)
                            .context(ExternalSnafu)?;
                        match flight_message {
                            FlightMessage::Recordbatch(record_batch) => yield Ok(record_batch),
                            // Metrics messages are skipped; they are not surfaced to the caller here.
                            FlightMessage::Metrics(_) => {}
                            FlightMessage::AffectedRows(_) | FlightMessage::Schema(_) => {
                                // Any other message after the schema is a protocol violation:
                                // emit one error item and stop the stream.
                                yield IllegalFlightMessagesSnafu {reason: format!("A Schema message must be succeeded exclusively by a set of RecordBatch messages, flight_message: {:?}", flight_message)}
                                        .fail()
                                        .map_err(BoxedError::new)
                                        .context(ExternalSnafu);
                                break;
                            }
                        }
                    }
                }));
                let record_batch_stream = RecordBatchStreamWrapper {
                    schema,
                    stream,
                    output_ordering: None,
                    metrics: Default::default(),
                };
                Ok(Output::new_with_stream(Box::pin(record_batch_stream)))
            }
        }
    }
364
365    /// Ingest a stream of [RecordBatch]es that belong to a table, using Arrow Flight's "`DoPut`"
366    /// method. The return value is also a stream, produces [DoPutResponse]s.
367    pub async fn do_put(&self, stream: FlightDataStream) -> Result<DoPutResponseStream> {
368        let mut request = tonic::Request::new(stream);
369
370        if let Some(AuthHeader {
371            auth_scheme: Some(AuthScheme::Basic(Basic { username, password })),
372        }) = &self.ctx.auth_header
373        {
374            let encoded = BASE64_STANDARD.encode(format!("{username}:{password}"));
375            let value =
376                MetadataValue::from_str(&encoded).context(InvalidTonicMetadataValueSnafu)?;
377            request.metadata_mut().insert("x-greptime-auth", value);
378        }
379
380        let db_to_put = if !self.dbname.is_empty() {
381            &self.dbname
382        } else {
383            &build_db_string(self.catalog_or_default(), self.schema_or_default())
384        };
385        request.metadata_mut().insert(
386            "x-greptime-db-name",
387            MetadataValue::from_str(db_to_put).context(InvalidTonicMetadataValueSnafu)?,
388        );
389
390        let mut client = self.client.make_flight_client()?;
391        let response = client.mut_inner().do_put(request).await?;
392        let response = response
393            .into_inner()
394            .map_err(Into::into)
395            .and_then(|x| future::ready(DoPutResponse::try_from(x).context(ConvertFlightDataSnafu)))
396            .boxed();
397        Ok(response)
398    }
399}
400
401/// by grpc standard, only `Unavailable` is retryable, see: https://github.com/grpc/grpc/blob/master/doc/statuscodes.md#status-codes-and-their-use-in-grpc
402pub fn is_grpc_retryable(err: &tonic::Status) -> bool {
403    matches!(err.code(), tonic::Code::Unavailable)
404}
405
/// Extra context attached to outgoing requests; currently only the optional auth header.
#[derive(Default, Debug, Clone)]
struct FlightContext {
    /// Authentication header to send, or `None` when no auth has been configured.
    auth_header: Option<AuthHeader>,
}
410
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use api::v1::auth_header::AuthScheme;
    use api::v1::{AuthHeader, Basic};

    use super::*;

    /// A default context carries no auth header; once a Basic scheme is assigned,
    /// the header reflects it.
    #[test]
    fn test_flight_ctx() {
        let mut ctx = FlightContext::default();
        assert!(ctx.auth_header.is_none());

        let credentials = Basic {
            username: "u".to_string(),
            password: "p".to_string(),
        };
        ctx.auth_header = Some(AuthHeader {
            auth_scheme: Some(AuthScheme::Basic(credentials)),
        });

        assert_matches!(
            ctx.auth_header,
            Some(AuthHeader {
                auth_scheme: Some(AuthScheme::Basic(_)),
            })
        );
    }
}
441}