1use std::pin::Pin;
16use std::str::FromStr;
17use std::sync::Arc;
18
19use api::v1::auth_header::AuthScheme;
20use api::v1::ddl_request::Expr as DdlExpr;
21use api::v1::greptime_database_client::GreptimeDatabaseClient;
22use api::v1::greptime_request::Request;
23use api::v1::query_request::Query;
24use api::v1::{
25 AlterTableExpr, AuthHeader, Basic, CreateTableExpr, DdlRequest, GreptimeRequest,
26 InsertRequests, QueryRequest, RequestHeader, RowInsertRequests,
27};
28use arrow_flight::{FlightData, Ticket};
29use async_stream::stream;
30use base64::prelude::BASE64_STANDARD;
31use base64::Engine;
32use common_catalog::build_db_string;
33use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
34use common_error::ext::BoxedError;
35use common_grpc::flight::do_put::DoPutResponse;
36use common_grpc::flight::{FlightDecoder, FlightMessage};
37use common_query::Output;
38use common_recordbatch::error::ExternalSnafu;
39use common_recordbatch::{RecordBatch, RecordBatchStreamWrapper};
40use common_telemetry::tracing_context::W3cTrace;
41use common_telemetry::{error, warn};
42use futures::future;
43use futures_util::{Stream, StreamExt, TryStreamExt};
44use prost::Message;
45use snafu::{ensure, OptionExt, ResultExt};
46use tonic::metadata::{AsciiMetadataKey, AsciiMetadataValue, MetadataMap, MetadataValue};
47use tonic::transport::Channel;
48
49use crate::error::{
50 ConvertFlightDataSnafu, Error, FlightGetSnafu, IllegalFlightMessagesSnafu,
51 InvalidTonicMetadataValueSnafu,
52};
53use crate::{error, from_grpc_response, Client, Result};
54
55type FlightDataStream = Pin<Box<dyn Stream<Item = FlightData> + Send>>;
56
57type DoPutResponseStream = Pin<Box<dyn Stream<Item = Result<DoPutResponse>>>>;
58
59#[derive(Clone, Debug, Default)]
60pub struct Database {
61 catalog: String,
65 schema: String,
66 dbname: String,
69 timezone: String,
72
73 client: Client,
74 ctx: FlightContext,
75}
76
77pub struct DatabaseClient {
78 pub inner: GreptimeDatabaseClient<Channel>,
79}
80
81fn make_database_client(client: &Client) -> Result<DatabaseClient> {
82 let (_, channel) = client.find_channel()?;
83 Ok(DatabaseClient {
84 inner: GreptimeDatabaseClient::new(channel)
85 .max_decoding_message_size(client.max_grpc_recv_message_size())
86 .max_encoding_message_size(client.max_grpc_send_message_size()),
87 })
88}
89
90impl Database {
91 pub fn new(catalog: impl Into<String>, schema: impl Into<String>, client: Client) -> Self {
93 Self {
94 catalog: catalog.into(),
95 schema: schema.into(),
96 dbname: String::default(),
97 timezone: String::default(),
98 client,
99 ctx: FlightContext::default(),
100 }
101 }
102
103 pub fn new_with_dbname(dbname: impl Into<String>, client: Client) -> Self {
111 Self {
112 catalog: String::default(),
113 schema: String::default(),
114 timezone: String::default(),
115 dbname: dbname.into(),
116 client,
117 ctx: FlightContext::default(),
118 }
119 }
120
121 pub fn set_catalog(&mut self, catalog: impl Into<String>) {
123 self.catalog = catalog.into();
124 }
125
126 fn catalog_or_default(&self) -> &str {
127 if self.catalog.is_empty() {
128 DEFAULT_CATALOG_NAME
129 } else {
130 &self.catalog
131 }
132 }
133
134 pub fn set_schema(&mut self, schema: impl Into<String>) {
136 self.schema = schema.into();
137 }
138
139 fn schema_or_default(&self) -> &str {
140 if self.schema.is_empty() {
141 DEFAULT_SCHEMA_NAME
142 } else {
143 &self.schema
144 }
145 }
146
147 pub fn set_timezone(&mut self, timezone: impl Into<String>) {
149 self.timezone = timezone.into();
150 }
151
152 pub fn set_auth(&mut self, auth: AuthScheme) {
154 self.ctx.auth_header = Some(AuthHeader {
155 auth_scheme: Some(auth),
156 });
157 }
158
159 pub async fn insert(&self, requests: InsertRequests) -> Result<u32> {
161 self.handle(Request::Inserts(requests)).await
162 }
163
164 pub async fn insert_with_hints(
166 &self,
167 requests: InsertRequests,
168 hints: &[(&str, &str)],
169 ) -> Result<u32> {
170 let mut client = make_database_client(&self.client)?.inner;
171 let request = self.to_rpc_request(Request::Inserts(requests));
172
173 let mut request = tonic::Request::new(request);
174 let metadata = request.metadata_mut();
175 Self::put_hints(metadata, hints)?;
176
177 let response = client.handle(request).await?.into_inner();
178 from_grpc_response(response)
179 }
180
181 pub async fn row_inserts(&self, requests: RowInsertRequests) -> Result<u32> {
183 self.handle(Request::RowInserts(requests)).await
184 }
185
186 pub async fn row_inserts_with_hints(
188 &self,
189 requests: RowInsertRequests,
190 hints: &[(&str, &str)],
191 ) -> Result<u32> {
192 let mut client = make_database_client(&self.client)?.inner;
193 let request = self.to_rpc_request(Request::RowInserts(requests));
194
195 let mut request = tonic::Request::new(request);
196 let metadata = request.metadata_mut();
197 Self::put_hints(metadata, hints)?;
198
199 let response = client.handle(request).await?.into_inner();
200 from_grpc_response(response)
201 }
202
203 fn put_hints(metadata: &mut MetadataMap, hints: &[(&str, &str)]) -> Result<()> {
204 let Some(value) = hints
205 .iter()
206 .map(|(k, v)| format!("{}={}", k, v))
207 .reduce(|a, b| format!("{},{}", a, b))
208 else {
209 return Ok(());
210 };
211
212 let key = AsciiMetadataKey::from_static("x-greptime-hints");
213 let value = AsciiMetadataValue::from_str(&value).context(InvalidTonicMetadataValueSnafu)?;
214 metadata.insert(key, value);
215 Ok(())
216 }
217
218 pub async fn handle(&self, request: Request) -> Result<u32> {
220 let mut client = make_database_client(&self.client)?.inner;
221 let request = self.to_rpc_request(request);
222 let response = client.handle(request).await?.into_inner();
223 from_grpc_response(response)
224 }
225
226 pub async fn handle_with_retry(
229 &self,
230 request: Request,
231 max_retries: u32,
232 hints: &[(&str, &str)],
233 ) -> Result<u32> {
234 let mut client = make_database_client(&self.client)?.inner;
235 let mut retries = 0;
236
237 let request = self.to_rpc_request(request);
238
239 loop {
240 let mut tonic_request = tonic::Request::new(request.clone());
241 let metadata = tonic_request.metadata_mut();
242 Self::put_hints(metadata, hints)?;
243 let raw_response = client.handle(tonic_request).await;
244 match (raw_response, retries < max_retries) {
245 (Ok(resp), _) => return from_grpc_response(resp.into_inner()),
246 (Err(err), true) => {
247 if is_grpc_retryable(&err) {
249 retries += 1;
251 warn!("Retrying {} times with error = {:?}", retries, err);
252 continue;
253 } else {
254 error!(
255 err; "Failed to send request to grpc handle, retries = {}, not retryable error, aborting",
256 retries
257 );
258 return Err(err.into());
259 }
260 }
261 (Err(err), false) => {
262 error!(
263 err; "Failed to send request to grpc handle after {} retries",
264 retries,
265 );
266 return Err(err.into());
267 }
268 }
269 }
270 }
271
272 #[inline]
273 fn to_rpc_request(&self, request: Request) -> GreptimeRequest {
274 GreptimeRequest {
275 header: Some(RequestHeader {
276 catalog: self.catalog.clone(),
277 schema: self.schema.clone(),
278 authorization: self.ctx.auth_header.clone(),
279 dbname: self.dbname.clone(),
280 timezone: self.timezone.clone(),
281 tracing_context: W3cTrace::new(),
283 }),
284 request: Some(request),
285 }
286 }
287
288 pub async fn sql<S>(&self, sql: S) -> Result<Output>
290 where
291 S: AsRef<str>,
292 {
293 self.sql_with_hint(sql, &[]).await
294 }
295
296 pub async fn sql_with_hint<S>(&self, sql: S, hints: &[(&str, &str)]) -> Result<Output>
298 where
299 S: AsRef<str>,
300 {
301 let request = Request::Query(QueryRequest {
302 query: Some(Query::Sql(sql.as_ref().to_string())),
303 });
304 self.do_get(request, hints).await
305 }
306
307 pub async fn logical_plan(&self, logical_plan: Vec<u8>) -> Result<Output> {
309 let request = Request::Query(QueryRequest {
310 query: Some(Query::LogicalPlan(logical_plan)),
311 });
312 self.do_get(request, &[]).await
313 }
314
315 pub async fn create(&self, expr: CreateTableExpr) -> Result<Output> {
317 let request = Request::Ddl(DdlRequest {
318 expr: Some(DdlExpr::CreateTable(expr)),
319 });
320 self.do_get(request, &[]).await
321 }
322
323 pub async fn alter(&self, expr: AlterTableExpr) -> Result<Output> {
325 let request = Request::Ddl(DdlRequest {
326 expr: Some(DdlExpr::AlterTable(expr)),
327 });
328 self.do_get(request, &[]).await
329 }
330
331 async fn do_get(&self, request: Request, hints: &[(&str, &str)]) -> Result<Output> {
332 let request = self.to_rpc_request(request);
333 let request = Ticket {
334 ticket: request.encode_to_vec().into(),
335 };
336
337 let mut request = tonic::Request::new(request);
338 Self::put_hints(request.metadata_mut(), hints)?;
339
340 let mut client = self.client.make_flight_client(false, false)?;
341
342 let response = client.mut_inner().do_get(request).await.or_else(|e| {
343 let tonic_code = e.code();
344 let e: Error = e.into();
345 error!(
346 "Failed to do Flight get, addr: {}, code: {}, source: {:?}",
347 client.addr(),
348 tonic_code,
349 e
350 );
351 let error = Err(BoxedError::new(e)).with_context(|_| FlightGetSnafu {
352 addr: client.addr().to_string(),
353 tonic_code,
354 });
355 error
356 })?;
357
358 let flight_data_stream = response.into_inner();
359 let mut decoder = FlightDecoder::default();
360
361 let mut flight_message_stream = flight_data_stream.map(move |flight_data| {
362 flight_data
363 .map_err(Error::from)
364 .and_then(|data| decoder.try_decode(&data).context(ConvertFlightDataSnafu))?
365 .context(IllegalFlightMessagesSnafu {
366 reason: "none message",
367 })
368 });
369
370 let Some(first_flight_message) = flight_message_stream.next().await else {
371 return IllegalFlightMessagesSnafu {
372 reason: "Expect the response not to be empty",
373 }
374 .fail();
375 };
376
377 let first_flight_message = first_flight_message?;
378
379 match first_flight_message {
380 FlightMessage::AffectedRows(rows) => {
381 ensure!(
382 flight_message_stream.next().await.is_none(),
383 IllegalFlightMessagesSnafu {
384 reason: "Expect 'AffectedRows' Flight messages to be the one and the only!"
385 }
386 );
387 Ok(Output::new_with_affected_rows(rows))
388 }
389 FlightMessage::RecordBatch(_) | FlightMessage::Metrics(_) => {
390 IllegalFlightMessagesSnafu {
391 reason: "The first flight message cannot be a RecordBatch or Metrics message",
392 }
393 .fail()
394 }
395 FlightMessage::Schema(schema) => {
396 let schema = Arc::new(
397 datatypes::schema::Schema::try_from(schema)
398 .context(error::ConvertSchemaSnafu)?,
399 );
400 let schema_cloned = schema.clone();
401 let stream = Box::pin(stream!({
402 while let Some(flight_message) = flight_message_stream.next().await {
403 let flight_message = flight_message
404 .map_err(BoxedError::new)
405 .context(ExternalSnafu)?;
406 match flight_message {
407 FlightMessage::RecordBatch(arrow_batch) => {
408 yield RecordBatch::try_from_df_record_batch(
409 schema_cloned.clone(),
410 arrow_batch,
411 )
412 }
413 FlightMessage::Metrics(_) => {}
414 FlightMessage::AffectedRows(_) | FlightMessage::Schema(_) => {
415 yield IllegalFlightMessagesSnafu {reason: format!("A Schema message must be succeeded exclusively by a set of RecordBatch messages, flight_message: {:?}", flight_message)}
416 .fail()
417 .map_err(BoxedError::new)
418 .context(ExternalSnafu);
419 break;
420 }
421 }
422 }
423 }));
424 let record_batch_stream = RecordBatchStreamWrapper {
425 schema,
426 stream,
427 output_ordering: None,
428 metrics: Default::default(),
429 };
430 Ok(Output::new_with_stream(Box::pin(record_batch_stream)))
431 }
432 }
433 }
434
435 pub async fn do_put(&self, stream: FlightDataStream) -> Result<DoPutResponseStream> {
438 let mut request = tonic::Request::new(stream);
439
440 if let Some(AuthHeader {
441 auth_scheme: Some(AuthScheme::Basic(Basic { username, password })),
442 }) = &self.ctx.auth_header
443 {
444 let encoded = BASE64_STANDARD.encode(format!("{username}:{password}"));
445 let value =
446 MetadataValue::from_str(&encoded).context(InvalidTonicMetadataValueSnafu)?;
447 request.metadata_mut().insert("x-greptime-auth", value);
448 }
449
450 let db_to_put = if !self.dbname.is_empty() {
451 &self.dbname
452 } else {
453 &build_db_string(self.catalog_or_default(), self.schema_or_default())
454 };
455 request.metadata_mut().insert(
456 "x-greptime-db-name",
457 MetadataValue::from_str(db_to_put).context(InvalidTonicMetadataValueSnafu)?,
458 );
459
460 let mut client = self.client.make_flight_client(false, false)?;
461 let response = client.mut_inner().do_put(request).await?;
462 let response = response
463 .into_inner()
464 .map_err(Into::into)
465 .and_then(|x| future::ready(DoPutResponse::try_from(x).context(ConvertFlightDataSnafu)))
466 .boxed();
467 Ok(response)
468 }
469}
470
471pub fn is_grpc_retryable(err: &tonic::Status) -> bool {
473 matches!(err.code(), tonic::Code::Unavailable)
474}
475
476#[derive(Default, Debug, Clone)]
477struct FlightContext {
478 auth_header: Option<AuthHeader>,
479}
480
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use api::v1::auth_header::AuthScheme;
    use api::v1::{AuthHeader, Basic};
    use common_error::status_code::StatusCode;
    use tonic::{Code, Status};

    use super::*;
    use crate::error::TonicSnafu;

    /// A default context carries no auth; assigning a Basic scheme is kept as-is.
    #[test]
    fn test_flight_ctx() {
        let mut flight_ctx = FlightContext::default();
        assert!(flight_ctx.auth_header.is_none());

        let credentials = Basic {
            username: "u".to_string(),
            password: "p".to_string(),
        };
        flight_ctx.auth_header = Some(AuthHeader {
            auth_scheme: Some(AuthScheme::Basic(credentials)),
        });

        assert_matches!(
            flight_ctx.auth_header,
            Some(AuthHeader {
                auth_scheme: Some(AuthScheme::Basic(_)),
            })
        );
    }

    /// Converting a tonic `Status` yields the same message as the `Tonic` error variant.
    #[test]
    fn test_from_tonic_status() {
        let status = Status::new(Code::Internal, "blabla");
        let converted: Error = status.into();

        let expected = TonicSnafu {
            code: StatusCode::Internal,
            msg: "blabla".to_string(),
            tonic_code: Code::Internal,
        }
        .build();

        assert_eq!(expected.to_string(), converted.to_string());
    }
}