1use std::pin::Pin;
16use std::str::FromStr;
17use std::sync::Arc;
18
19use api::v1::auth_header::AuthScheme;
20use api::v1::ddl_request::Expr as DdlExpr;
21use api::v1::greptime_database_client::GreptimeDatabaseClient;
22use api::v1::greptime_request::Request;
23use api::v1::query_request::Query;
24use api::v1::{
25 AlterTableExpr, AuthHeader, Basic, CreateTableExpr, DdlRequest, GreptimeRequest,
26 InsertRequests, QueryRequest, RequestHeader, RowInsertRequests,
27};
28use arrow_flight::{FlightData, Ticket};
29use async_stream::stream;
30use base64::Engine;
31use base64::prelude::BASE64_STANDARD;
32use common_catalog::build_db_string;
33use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
34use common_error::ext::BoxedError;
35use common_grpc::flight::do_put::DoPutResponse;
36use common_grpc::flight::{FlightDecoder, FlightMessage};
37use common_query::Output;
38use common_recordbatch::error::ExternalSnafu;
39use common_recordbatch::{RecordBatch, RecordBatchStreamWrapper};
40use common_telemetry::tracing::Span;
41use common_telemetry::tracing_context::W3cTrace;
42use common_telemetry::{error, warn};
43use futures::future;
44use futures_util::{Stream, StreamExt, TryStreamExt};
45use prost::Message;
46use snafu::{OptionExt, ResultExt, ensure};
47use tonic::metadata::{AsciiMetadataKey, AsciiMetadataValue, MetadataMap, MetadataValue};
48use tonic::transport::Channel;
49
50use crate::error::{
51 ConvertFlightDataSnafu, Error, FlightGetSnafu, IllegalFlightMessagesSnafu,
52 InvalidTonicMetadataValueSnafu,
53};
54use crate::{Client, Result, error, from_grpc_response};
55
56type FlightDataStream = Pin<Box<dyn Stream<Item = FlightData> + Send>>;
57
58type DoPutResponseStream = Pin<Box<dyn Stream<Item = Result<DoPutResponse>>>>;
59
60#[derive(Clone, Debug, Default)]
61pub struct Database {
62 catalog: String,
66 schema: String,
67 dbname: String,
70 timezone: String,
73
74 client: Client,
75 ctx: FlightContext,
76}
77
78pub struct DatabaseClient {
79 pub addr: String,
80 pub inner: GreptimeDatabaseClient<Channel>,
81}
82
83impl DatabaseClient {
84 pub fn inspect_err<'a>(&'a self, context: &'a str) -> impl Fn(&tonic::Status) + 'a {
86 let addr = &self.addr;
87 move |status| {
88 error!("Failed to {context} request, peer: {addr}, status: {status:?}");
89 }
90 }
91}
92
93fn make_database_client(client: &Client) -> Result<DatabaseClient> {
94 let (addr, channel) = client.find_channel()?;
95 Ok(DatabaseClient {
96 addr,
97 inner: GreptimeDatabaseClient::new(channel)
98 .max_decoding_message_size(client.max_grpc_recv_message_size())
99 .max_encoding_message_size(client.max_grpc_send_message_size()),
100 })
101}
102
103impl Database {
104 pub fn new(catalog: impl Into<String>, schema: impl Into<String>, client: Client) -> Self {
106 Self {
107 catalog: catalog.into(),
108 schema: schema.into(),
109 dbname: String::default(),
110 timezone: String::default(),
111 client,
112 ctx: FlightContext::default(),
113 }
114 }
115
116 pub fn new_with_dbname(dbname: impl Into<String>, client: Client) -> Self {
124 Self {
125 catalog: String::default(),
126 schema: String::default(),
127 timezone: String::default(),
128 dbname: dbname.into(),
129 client,
130 ctx: FlightContext::default(),
131 }
132 }
133
134 pub fn set_catalog(&mut self, catalog: impl Into<String>) {
136 self.catalog = catalog.into();
137 }
138
139 fn catalog_or_default(&self) -> &str {
140 if self.catalog.is_empty() {
141 DEFAULT_CATALOG_NAME
142 } else {
143 &self.catalog
144 }
145 }
146
147 pub fn set_schema(&mut self, schema: impl Into<String>) {
149 self.schema = schema.into();
150 }
151
152 fn schema_or_default(&self) -> &str {
153 if self.schema.is_empty() {
154 DEFAULT_SCHEMA_NAME
155 } else {
156 &self.schema
157 }
158 }
159
160 pub fn set_timezone(&mut self, timezone: impl Into<String>) {
162 self.timezone = timezone.into();
163 }
164
165 pub fn set_auth(&mut self, auth: AuthScheme) {
167 self.ctx.auth_header = Some(AuthHeader {
168 auth_scheme: Some(auth),
169 });
170 }
171
172 pub async fn insert(&self, requests: InsertRequests) -> Result<u32> {
174 self.handle(Request::Inserts(requests)).await
175 }
176
177 pub async fn insert_with_hints(
179 &self,
180 requests: InsertRequests,
181 hints: &[(&str, &str)],
182 ) -> Result<u32> {
183 let mut client = make_database_client(&self.client)?;
184 let request = self.to_rpc_request(Request::Inserts(requests));
185
186 let mut request = tonic::Request::new(request);
187 let metadata = request.metadata_mut();
188 Self::put_hints(metadata, hints)?;
189
190 let response = client
191 .inner
192 .handle(request)
193 .await
194 .inspect_err(client.inspect_err("insert_with_hints"))?
195 .into_inner();
196 from_grpc_response(response)
197 }
198
199 pub async fn row_inserts(&self, requests: RowInsertRequests) -> Result<u32> {
201 self.handle(Request::RowInserts(requests)).await
202 }
203
204 pub async fn row_inserts_with_hints(
206 &self,
207 requests: RowInsertRequests,
208 hints: &[(&str, &str)],
209 ) -> Result<u32> {
210 let mut client = make_database_client(&self.client)?;
211 let request = self.to_rpc_request(Request::RowInserts(requests));
212
213 let mut request = tonic::Request::new(request);
214 let metadata = request.metadata_mut();
215 Self::put_hints(metadata, hints)?;
216
217 let response = client
218 .inner
219 .handle(request)
220 .await
221 .inspect_err(client.inspect_err("row_inserts_with_hints"))?
222 .into_inner();
223 from_grpc_response(response)
224 }
225
226 fn put_hints(metadata: &mut MetadataMap, hints: &[(&str, &str)]) -> Result<()> {
227 let Some(value) = hints
228 .iter()
229 .map(|(k, v)| format!("{}={}", k, v))
230 .reduce(|a, b| format!("{},{}", a, b))
231 else {
232 return Ok(());
233 };
234
235 let key = AsciiMetadataKey::from_static("x-greptime-hints");
236 let value = AsciiMetadataValue::from_str(&value).context(InvalidTonicMetadataValueSnafu)?;
237 metadata.insert(key, value);
238 Ok(())
239 }
240
241 pub async fn handle(&self, request: Request) -> Result<u32> {
243 let mut client = make_database_client(&self.client)?;
244 let request = self.to_rpc_request(request);
245 let response = client
246 .inner
247 .handle(request)
248 .await
249 .inspect_err(client.inspect_err("handle"))?
250 .into_inner();
251 from_grpc_response(response)
252 }
253
254 pub async fn handle_with_retry(
257 &self,
258 request: Request,
259 max_retries: u32,
260 hints: &[(&str, &str)],
261 ) -> Result<u32> {
262 let mut client = make_database_client(&self.client)?;
263 let mut retries = 0;
264
265 let request = self.to_rpc_request(request);
266
267 loop {
268 let mut tonic_request = tonic::Request::new(request.clone());
269 let metadata = tonic_request.metadata_mut();
270 Self::put_hints(metadata, hints)?;
271 let raw_response = client
272 .inner
273 .handle(tonic_request)
274 .await
275 .inspect_err(client.inspect_err("handle"));
276 match (raw_response, retries < max_retries) {
277 (Ok(resp), _) => return from_grpc_response(resp.into_inner()),
278 (Err(err), true) => {
279 if is_grpc_retryable(&err) {
281 retries += 1;
283 warn!("Retrying {} times with error = {:?}", retries, err);
284 continue;
285 } else {
286 error!(
287 err; "Failed to send request to grpc handle, retries = {}, not retryable error, aborting",
288 retries
289 );
290 return Err(err.into());
291 }
292 }
293 (Err(err), false) => {
294 error!(
295 err; "Failed to send request to grpc handle after {} retries",
296 retries,
297 );
298 return Err(err.into());
299 }
300 }
301 }
302 }
303
304 #[inline]
305 fn to_rpc_request(&self, request: Request) -> GreptimeRequest {
306 GreptimeRequest {
307 header: Some(RequestHeader {
308 catalog: self.catalog.clone(),
309 schema: self.schema.clone(),
310 authorization: self.ctx.auth_header.clone(),
311 dbname: self.dbname.clone(),
312 timezone: self.timezone.clone(),
313 tracing_context: W3cTrace::new(),
315 }),
316 request: Some(request),
317 }
318 }
319
320 pub async fn sql<S>(&self, sql: S) -> Result<Output>
322 where
323 S: AsRef<str>,
324 {
325 self.sql_with_hint(sql, &[]).await
326 }
327
328 pub async fn sql_with_hint<S>(&self, sql: S, hints: &[(&str, &str)]) -> Result<Output>
330 where
331 S: AsRef<str>,
332 {
333 let request = Request::Query(QueryRequest {
334 query: Some(Query::Sql(sql.as_ref().to_string())),
335 });
336 self.do_get(request, hints).await
337 }
338
339 pub async fn logical_plan(&self, logical_plan: Vec<u8>) -> Result<Output> {
341 let request = Request::Query(QueryRequest {
342 query: Some(Query::LogicalPlan(logical_plan)),
343 });
344 self.do_get(request, &[]).await
345 }
346
347 pub async fn create(&self, expr: CreateTableExpr) -> Result<Output> {
349 let request = Request::Ddl(DdlRequest {
350 expr: Some(DdlExpr::CreateTable(expr)),
351 });
352 self.do_get(request, &[]).await
353 }
354
355 pub async fn alter(&self, expr: AlterTableExpr) -> Result<Output> {
357 let request = Request::Ddl(DdlRequest {
358 expr: Some(DdlExpr::AlterTable(expr)),
359 });
360 self.do_get(request, &[]).await
361 }
362
363 async fn do_get(&self, request: Request, hints: &[(&str, &str)]) -> Result<Output> {
364 let request = self.to_rpc_request(request);
365 let request = Ticket {
366 ticket: request.encode_to_vec().into(),
367 };
368
369 let mut request = tonic::Request::new(request);
370 Self::put_hints(request.metadata_mut(), hints)?;
371
372 let mut client = self.client.make_flight_client(false, false)?;
373
374 let response = client.mut_inner().do_get(request).await.or_else(|e| {
375 let tonic_code = e.code();
376 let e: Error = e.into();
377 error!(
378 "Failed to do Flight get, addr: {}, code: {}, source: {:?}",
379 client.addr(),
380 tonic_code,
381 e
382 );
383 Err(BoxedError::new(e)).with_context(|_| FlightGetSnafu {
384 addr: client.addr().to_string(),
385 tonic_code,
386 })
387 })?;
388
389 let flight_data_stream = response.into_inner();
390 let mut decoder = FlightDecoder::default();
391
392 let mut flight_message_stream = flight_data_stream.map(move |flight_data| {
393 flight_data
394 .map_err(Error::from)
395 .and_then(|data| decoder.try_decode(&data).context(ConvertFlightDataSnafu))?
396 .context(IllegalFlightMessagesSnafu {
397 reason: "none message",
398 })
399 });
400
401 let Some(first_flight_message) = flight_message_stream.next().await else {
402 return IllegalFlightMessagesSnafu {
403 reason: "Expect the response not to be empty",
404 }
405 .fail();
406 };
407
408 let first_flight_message = first_flight_message?;
409
410 match first_flight_message {
411 FlightMessage::AffectedRows(rows) => {
412 ensure!(
413 flight_message_stream.next().await.is_none(),
414 IllegalFlightMessagesSnafu {
415 reason: "Expect 'AffectedRows' Flight messages to be the one and the only!"
416 }
417 );
418 Ok(Output::new_with_affected_rows(rows))
419 }
420 FlightMessage::RecordBatch(_) | FlightMessage::Metrics(_) => {
421 IllegalFlightMessagesSnafu {
422 reason: "The first flight message cannot be a RecordBatch or Metrics message",
423 }
424 .fail()
425 }
426 FlightMessage::Schema(schema) => {
427 let schema = Arc::new(
428 datatypes::schema::Schema::try_from(schema)
429 .context(error::ConvertSchemaSnafu)?,
430 );
431 let schema_cloned = schema.clone();
432 let stream = Box::pin(stream!({
433 while let Some(flight_message) = flight_message_stream.next().await {
434 let flight_message = flight_message
435 .map_err(BoxedError::new)
436 .context(ExternalSnafu)?;
437 match flight_message {
438 FlightMessage::RecordBatch(arrow_batch) => {
439 yield Ok(RecordBatch::from_df_record_batch(
440 schema_cloned.clone(),
441 arrow_batch,
442 ))
443 }
444 FlightMessage::Metrics(_) => {}
445 FlightMessage::AffectedRows(_) | FlightMessage::Schema(_) => {
446 yield IllegalFlightMessagesSnafu {reason: format!("A Schema message must be succeeded exclusively by a set of RecordBatch messages, flight_message: {:?}", flight_message)}
447 .fail()
448 .map_err(BoxedError::new)
449 .context(ExternalSnafu);
450 break;
451 }
452 }
453 }
454 }));
455 let record_batch_stream = RecordBatchStreamWrapper {
456 schema,
457 stream,
458 output_ordering: None,
459 metrics: Default::default(),
460 span: Span::current(),
461 };
462 Ok(Output::new_with_stream(Box::pin(record_batch_stream)))
463 }
464 }
465 }
466
467 pub async fn do_put(&self, stream: FlightDataStream) -> Result<DoPutResponseStream> {
470 let mut request = tonic::Request::new(stream);
471
472 if let Some(AuthHeader {
473 auth_scheme: Some(AuthScheme::Basic(Basic { username, password })),
474 }) = &self.ctx.auth_header
475 {
476 let encoded = BASE64_STANDARD.encode(format!("{username}:{password}"));
477 let value = MetadataValue::from_str(&format!("Basic {encoded}"))
478 .context(InvalidTonicMetadataValueSnafu)?;
479 request.metadata_mut().insert("x-greptime-auth", value);
480 }
481
482 let db_to_put = if !self.dbname.is_empty() {
483 &self.dbname
484 } else {
485 &build_db_string(self.catalog_or_default(), self.schema_or_default())
486 };
487 request.metadata_mut().insert(
488 "x-greptime-db-name",
489 MetadataValue::from_str(db_to_put).context(InvalidTonicMetadataValueSnafu)?,
490 );
491
492 let mut client = self.client.make_flight_client(false, false)?;
493 let response = client.mut_inner().do_put(request).await?;
494 let response = response
495 .into_inner()
496 .map_err(Into::into)
497 .and_then(|x| future::ready(DoPutResponse::try_from(x).context(ConvertFlightDataSnafu)))
498 .boxed();
499 Ok(response)
500 }
501}
502
503pub fn is_grpc_retryable(err: &tonic::Status) -> bool {
505 matches!(err.code(), tonic::Code::Unavailable)
506}
507
508#[derive(Default, Debug, Clone)]
509struct FlightContext {
510 auth_header: Option<AuthHeader>,
511}
512
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use api::v1::auth_header::AuthScheme;
    use api::v1::{AuthHeader, Basic};
    use common_error::status_code::StatusCode;
    use tonic::{Code, Status};

    use super::*;
    use crate::error::TonicSnafu;

    /// A default context carries no auth; a Basic scheme stored in it is kept.
    #[test]
    fn test_flight_ctx() {
        let mut ctx = FlightContext::default();
        assert!(ctx.auth_header.is_none());

        let credentials = Basic {
            username: "u".to_string(),
            password: "p".to_string(),
        };
        ctx.auth_header = Some(AuthHeader {
            auth_scheme: Some(AuthScheme::Basic(credentials)),
        });

        assert_matches!(
            ctx.auth_header,
            Some(AuthHeader {
                auth_scheme: Some(AuthScheme::Basic(_)),
            })
        )
    }

    /// Converting a tonic status renders the same message as a `Tonic` error
    /// built by hand with the matching codes.
    #[test]
    fn test_from_tonic_status() {
        let actual: Error = Status::new(Code::Internal, "blabla").into();

        let expected = TonicSnafu {
            code: StatusCode::Internal,
            msg: "blabla".to_string(),
            tonic_code: Code::Internal,
        }
        .build();

        assert_eq!(expected.to_string(), actual.to_string());
    }
}