1use std::pin::Pin;
16use std::str::FromStr;
17use std::sync::Arc;
18
19use api::v1::auth_header::AuthScheme;
20use api::v1::ddl_request::Expr as DdlExpr;
21use api::v1::greptime_database_client::GreptimeDatabaseClient;
22use api::v1::greptime_request::Request;
23use api::v1::query_request::Query;
24use api::v1::{
25 AlterTableExpr, AuthHeader, Basic, CreateTableExpr, DdlRequest, GreptimeRequest,
26 InsertRequests, QueryRequest, RequestHeader, RowInsertRequests,
27};
28use arrow_flight::{FlightData, Ticket};
29use async_stream::stream;
30use base64::prelude::BASE64_STANDARD;
31use base64::Engine;
32use common_catalog::build_db_string;
33use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
34use common_error::ext::BoxedError;
35use common_grpc::flight::do_put::DoPutResponse;
36use common_grpc::flight::{FlightDecoder, FlightMessage};
37use common_query::Output;
38use common_recordbatch::error::ExternalSnafu;
39use common_recordbatch::{RecordBatch, RecordBatchStreamWrapper};
40use common_telemetry::tracing_context::W3cTrace;
41use common_telemetry::{error, warn};
42use futures::future;
43use futures_util::{Stream, StreamExt, TryStreamExt};
44use prost::Message;
45use snafu::{ensure, OptionExt, ResultExt};
46use tonic::metadata::{AsciiMetadataKey, AsciiMetadataValue, MetadataMap, MetadataValue};
47use tonic::transport::Channel;
48
49use crate::error::{
50 ConvertFlightDataSnafu, Error, FlightGetSnafu, IllegalFlightMessagesSnafu,
51 InvalidTonicMetadataValueSnafu,
52};
53use crate::{error, from_grpc_response, Client, Result};
54
/// Pinned, boxed, `Send` stream of Arrow Flight `FlightData` frames.
/// This is the input type accepted by `Database::do_put`.
type FlightDataStream = Pin<Box<dyn Stream<Item = FlightData> + Send>>;
56
/// Pinned, boxed stream of decoded `DoPutResponse` results, as returned by
/// `Database::do_put`. Note that the trait object carries no `Send` bound.
type DoPutResponseStream = Pin<Box<dyn Stream<Item = Result<DoPutResponse>>>>;
58
/// Client-side handle for issuing requests against one database.
///
/// The string fields are copied verbatim into the `RequestHeader` of every
/// request built by `to_rpc_request`; an empty string simply means "not set".
#[derive(Clone, Debug, Default)]
pub struct Database {
    /// Catalog name sent in the request header. `do_put` falls back to
    /// `DEFAULT_CATALOG_NAME` when this is empty.
    catalog: String,
    /// Schema name sent in the request header. `do_put` falls back to
    /// `DEFAULT_SCHEMA_NAME` when this is empty.
    schema: String,
    /// Database name sent in the request header. In `do_put` a non-empty
    /// `dbname` is used instead of the catalog/schema pair.
    // NOTE(review): presumably the server also prefers `dbname` over
    // catalog/schema for the other request kinds — confirm on the server side.
    dbname: String,
    /// Timezone string sent in the request header.
    timezone: String,

    /// Underlying client used to obtain gRPC channels and Flight clients.
    client: Client,
    /// Per-database context; currently only holds the optional auth header.
    ctx: FlightContext,
}
76
/// A gRPC database client bound to one channel, together with the address
/// that was returned alongside that channel (see `make_database_client`).
pub struct DatabaseClient {
    /// Peer address; included in the error logs produced by `inspect_err`.
    pub addr: String,
    /// The generated gRPC client that actually sends the requests.
    pub inner: GreptimeDatabaseClient<Channel>,
}
81
82impl DatabaseClient {
83 pub fn inspect_err<'a>(&'a self, context: &'a str) -> impl Fn(&tonic::Status) + 'a {
85 let addr = &self.addr;
86 move |status| {
87 error!("Failed to {context} request, peer: {addr}, status: {status:?}");
88 }
89 }
90}
91
92fn make_database_client(client: &Client) -> Result<DatabaseClient> {
93 let (addr, channel) = client.find_channel()?;
94 Ok(DatabaseClient {
95 addr,
96 inner: GreptimeDatabaseClient::new(channel)
97 .max_decoding_message_size(client.max_grpc_recv_message_size())
98 .max_encoding_message_size(client.max_grpc_send_message_size()),
99 })
100}
101
102impl Database {
103 pub fn new(catalog: impl Into<String>, schema: impl Into<String>, client: Client) -> Self {
105 Self {
106 catalog: catalog.into(),
107 schema: schema.into(),
108 dbname: String::default(),
109 timezone: String::default(),
110 client,
111 ctx: FlightContext::default(),
112 }
113 }
114
115 pub fn new_with_dbname(dbname: impl Into<String>, client: Client) -> Self {
123 Self {
124 catalog: String::default(),
125 schema: String::default(),
126 timezone: String::default(),
127 dbname: dbname.into(),
128 client,
129 ctx: FlightContext::default(),
130 }
131 }
132
133 pub fn set_catalog(&mut self, catalog: impl Into<String>) {
135 self.catalog = catalog.into();
136 }
137
138 fn catalog_or_default(&self) -> &str {
139 if self.catalog.is_empty() {
140 DEFAULT_CATALOG_NAME
141 } else {
142 &self.catalog
143 }
144 }
145
146 pub fn set_schema(&mut self, schema: impl Into<String>) {
148 self.schema = schema.into();
149 }
150
151 fn schema_or_default(&self) -> &str {
152 if self.schema.is_empty() {
153 DEFAULT_SCHEMA_NAME
154 } else {
155 &self.schema
156 }
157 }
158
159 pub fn set_timezone(&mut self, timezone: impl Into<String>) {
161 self.timezone = timezone.into();
162 }
163
164 pub fn set_auth(&mut self, auth: AuthScheme) {
166 self.ctx.auth_header = Some(AuthHeader {
167 auth_scheme: Some(auth),
168 });
169 }
170
171 pub async fn insert(&self, requests: InsertRequests) -> Result<u32> {
173 self.handle(Request::Inserts(requests)).await
174 }
175
176 pub async fn insert_with_hints(
178 &self,
179 requests: InsertRequests,
180 hints: &[(&str, &str)],
181 ) -> Result<u32> {
182 let mut client = make_database_client(&self.client)?;
183 let request = self.to_rpc_request(Request::Inserts(requests));
184
185 let mut request = tonic::Request::new(request);
186 let metadata = request.metadata_mut();
187 Self::put_hints(metadata, hints)?;
188
189 let response = client
190 .inner
191 .handle(request)
192 .await
193 .inspect_err(client.inspect_err("insert_with_hints"))?
194 .into_inner();
195 from_grpc_response(response)
196 }
197
198 pub async fn row_inserts(&self, requests: RowInsertRequests) -> Result<u32> {
200 self.handle(Request::RowInserts(requests)).await
201 }
202
203 pub async fn row_inserts_with_hints(
205 &self,
206 requests: RowInsertRequests,
207 hints: &[(&str, &str)],
208 ) -> Result<u32> {
209 let mut client = make_database_client(&self.client)?;
210 let request = self.to_rpc_request(Request::RowInserts(requests));
211
212 let mut request = tonic::Request::new(request);
213 let metadata = request.metadata_mut();
214 Self::put_hints(metadata, hints)?;
215
216 let response = client
217 .inner
218 .handle(request)
219 .await
220 .inspect_err(client.inspect_err("row_inserts_with_hints"))?
221 .into_inner();
222 from_grpc_response(response)
223 }
224
225 fn put_hints(metadata: &mut MetadataMap, hints: &[(&str, &str)]) -> Result<()> {
226 let Some(value) = hints
227 .iter()
228 .map(|(k, v)| format!("{}={}", k, v))
229 .reduce(|a, b| format!("{},{}", a, b))
230 else {
231 return Ok(());
232 };
233
234 let key = AsciiMetadataKey::from_static("x-greptime-hints");
235 let value = AsciiMetadataValue::from_str(&value).context(InvalidTonicMetadataValueSnafu)?;
236 metadata.insert(key, value);
237 Ok(())
238 }
239
240 pub async fn handle(&self, request: Request) -> Result<u32> {
242 let mut client = make_database_client(&self.client)?;
243 let request = self.to_rpc_request(request);
244 let response = client
245 .inner
246 .handle(request)
247 .await
248 .inspect_err(client.inspect_err("handle"))?
249 .into_inner();
250 from_grpc_response(response)
251 }
252
253 pub async fn handle_with_retry(
256 &self,
257 request: Request,
258 max_retries: u32,
259 hints: &[(&str, &str)],
260 ) -> Result<u32> {
261 let mut client = make_database_client(&self.client)?;
262 let mut retries = 0;
263
264 let request = self.to_rpc_request(request);
265
266 loop {
267 let mut tonic_request = tonic::Request::new(request.clone());
268 let metadata = tonic_request.metadata_mut();
269 Self::put_hints(metadata, hints)?;
270 let raw_response = client
271 .inner
272 .handle(tonic_request)
273 .await
274 .inspect_err(client.inspect_err("handle"));
275 match (raw_response, retries < max_retries) {
276 (Ok(resp), _) => return from_grpc_response(resp.into_inner()),
277 (Err(err), true) => {
278 if is_grpc_retryable(&err) {
280 retries += 1;
282 warn!("Retrying {} times with error = {:?}", retries, err);
283 continue;
284 } else {
285 error!(
286 err; "Failed to send request to grpc handle, retries = {}, not retryable error, aborting",
287 retries
288 );
289 return Err(err.into());
290 }
291 }
292 (Err(err), false) => {
293 error!(
294 err; "Failed to send request to grpc handle after {} retries",
295 retries,
296 );
297 return Err(err.into());
298 }
299 }
300 }
301 }
302
303 #[inline]
304 fn to_rpc_request(&self, request: Request) -> GreptimeRequest {
305 GreptimeRequest {
306 header: Some(RequestHeader {
307 catalog: self.catalog.clone(),
308 schema: self.schema.clone(),
309 authorization: self.ctx.auth_header.clone(),
310 dbname: self.dbname.clone(),
311 timezone: self.timezone.clone(),
312 tracing_context: W3cTrace::new(),
314 }),
315 request: Some(request),
316 }
317 }
318
319 pub async fn sql<S>(&self, sql: S) -> Result<Output>
321 where
322 S: AsRef<str>,
323 {
324 self.sql_with_hint(sql, &[]).await
325 }
326
327 pub async fn sql_with_hint<S>(&self, sql: S, hints: &[(&str, &str)]) -> Result<Output>
329 where
330 S: AsRef<str>,
331 {
332 let request = Request::Query(QueryRequest {
333 query: Some(Query::Sql(sql.as_ref().to_string())),
334 });
335 self.do_get(request, hints).await
336 }
337
338 pub async fn logical_plan(&self, logical_plan: Vec<u8>) -> Result<Output> {
340 let request = Request::Query(QueryRequest {
341 query: Some(Query::LogicalPlan(logical_plan)),
342 });
343 self.do_get(request, &[]).await
344 }
345
346 pub async fn create(&self, expr: CreateTableExpr) -> Result<Output> {
348 let request = Request::Ddl(DdlRequest {
349 expr: Some(DdlExpr::CreateTable(expr)),
350 });
351 self.do_get(request, &[]).await
352 }
353
354 pub async fn alter(&self, expr: AlterTableExpr) -> Result<Output> {
356 let request = Request::Ddl(DdlRequest {
357 expr: Some(DdlExpr::AlterTable(expr)),
358 });
359 self.do_get(request, &[]).await
360 }
361
362 async fn do_get(&self, request: Request, hints: &[(&str, &str)]) -> Result<Output> {
363 let request = self.to_rpc_request(request);
364 let request = Ticket {
365 ticket: request.encode_to_vec().into(),
366 };
367
368 let mut request = tonic::Request::new(request);
369 Self::put_hints(request.metadata_mut(), hints)?;
370
371 let mut client = self.client.make_flight_client(false, false)?;
372
373 let response = client.mut_inner().do_get(request).await.or_else(|e| {
374 let tonic_code = e.code();
375 let e: Error = e.into();
376 error!(
377 "Failed to do Flight get, addr: {}, code: {}, source: {:?}",
378 client.addr(),
379 tonic_code,
380 e
381 );
382 let error = Err(BoxedError::new(e)).with_context(|_| FlightGetSnafu {
383 addr: client.addr().to_string(),
384 tonic_code,
385 });
386 error
387 })?;
388
389 let flight_data_stream = response.into_inner();
390 let mut decoder = FlightDecoder::default();
391
392 let mut flight_message_stream = flight_data_stream.map(move |flight_data| {
393 flight_data
394 .map_err(Error::from)
395 .and_then(|data| decoder.try_decode(&data).context(ConvertFlightDataSnafu))?
396 .context(IllegalFlightMessagesSnafu {
397 reason: "none message",
398 })
399 });
400
401 let Some(first_flight_message) = flight_message_stream.next().await else {
402 return IllegalFlightMessagesSnafu {
403 reason: "Expect the response not to be empty",
404 }
405 .fail();
406 };
407
408 let first_flight_message = first_flight_message?;
409
410 match first_flight_message {
411 FlightMessage::AffectedRows(rows) => {
412 ensure!(
413 flight_message_stream.next().await.is_none(),
414 IllegalFlightMessagesSnafu {
415 reason: "Expect 'AffectedRows' Flight messages to be the one and the only!"
416 }
417 );
418 Ok(Output::new_with_affected_rows(rows))
419 }
420 FlightMessage::RecordBatch(_) | FlightMessage::Metrics(_) => {
421 IllegalFlightMessagesSnafu {
422 reason: "The first flight message cannot be a RecordBatch or Metrics message",
423 }
424 .fail()
425 }
426 FlightMessage::Schema(schema) => {
427 let schema = Arc::new(
428 datatypes::schema::Schema::try_from(schema)
429 .context(error::ConvertSchemaSnafu)?,
430 );
431 let schema_cloned = schema.clone();
432 let stream = Box::pin(stream!({
433 while let Some(flight_message) = flight_message_stream.next().await {
434 let flight_message = flight_message
435 .map_err(BoxedError::new)
436 .context(ExternalSnafu)?;
437 match flight_message {
438 FlightMessage::RecordBatch(arrow_batch) => {
439 yield RecordBatch::try_from_df_record_batch(
440 schema_cloned.clone(),
441 arrow_batch,
442 )
443 }
444 FlightMessage::Metrics(_) => {}
445 FlightMessage::AffectedRows(_) | FlightMessage::Schema(_) => {
446 yield IllegalFlightMessagesSnafu {reason: format!("A Schema message must be succeeded exclusively by a set of RecordBatch messages, flight_message: {:?}", flight_message)}
447 .fail()
448 .map_err(BoxedError::new)
449 .context(ExternalSnafu);
450 break;
451 }
452 }
453 }
454 }));
455 let record_batch_stream = RecordBatchStreamWrapper {
456 schema,
457 stream,
458 output_ordering: None,
459 metrics: Default::default(),
460 };
461 Ok(Output::new_with_stream(Box::pin(record_batch_stream)))
462 }
463 }
464 }
465
466 pub async fn do_put(&self, stream: FlightDataStream) -> Result<DoPutResponseStream> {
469 let mut request = tonic::Request::new(stream);
470
471 if let Some(AuthHeader {
472 auth_scheme: Some(AuthScheme::Basic(Basic { username, password })),
473 }) = &self.ctx.auth_header
474 {
475 let encoded = BASE64_STANDARD.encode(format!("{username}:{password}"));
476 let value =
477 MetadataValue::from_str(&encoded).context(InvalidTonicMetadataValueSnafu)?;
478 request.metadata_mut().insert("x-greptime-auth", value);
479 }
480
481 let db_to_put = if !self.dbname.is_empty() {
482 &self.dbname
483 } else {
484 &build_db_string(self.catalog_or_default(), self.schema_or_default())
485 };
486 request.metadata_mut().insert(
487 "x-greptime-db-name",
488 MetadataValue::from_str(db_to_put).context(InvalidTonicMetadataValueSnafu)?,
489 );
490
491 let mut client = self.client.make_flight_client(false, false)?;
492 let response = client.mut_inner().do_put(request).await?;
493 let response = response
494 .into_inner()
495 .map_err(Into::into)
496 .and_then(|x| future::ready(DoPutResponse::try_from(x).context(ConvertFlightDataSnafu)))
497 .boxed();
498 Ok(response)
499 }
500}
501
502pub fn is_grpc_retryable(err: &tonic::Status) -> bool {
504 matches!(err.code(), tonic::Code::Unavailable)
505}
506
/// Per-database request context.
#[derive(Default, Debug, Clone)]
struct FlightContext {
    /// Authorization header attached to outgoing requests; `None` until
    /// `Database::set_auth` is called.
    auth_header: Option<AuthHeader>,
}
511
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use api::v1::auth_header::AuthScheme;
    use api::v1::{AuthHeader, Basic};
    use common_error::status_code::StatusCode;
    use tonic::{Code, Status};

    use super::*;
    use crate::error::TonicSnafu;

    /// A default context has no auth header; after assigning one, the basic
    /// scheme is stored.
    #[test]
    fn test_flight_ctx() {
        let mut flight_ctx = FlightContext::default();
        assert!(flight_ctx.auth_header.is_none());

        flight_ctx.auth_header = Some(AuthHeader {
            auth_scheme: Some(AuthScheme::Basic(Basic {
                username: "u".to_string(),
                password: "p".to_string(),
            })),
        });

        assert_matches!(
            flight_ctx.auth_header,
            Some(AuthHeader {
                auth_scheme: Some(AuthScheme::Basic(_)),
            })
        );
    }

    /// Converting a `tonic::Status` into the crate error renders the same
    /// message as building the `Tonic` error variant directly.
    #[test]
    fn test_from_tonic_status() {
        let actual: Error = Status::new(Code::Internal, "blabla").into();

        let expected = TonicSnafu {
            code: StatusCode::Internal,
            msg: "blabla".to_string(),
            tonic_code: Code::Internal,
        }
        .build();

        assert_eq!(expected.to_string(), actual.to_string());
    }
}