1use std::pin::Pin;
16use std::str::FromStr;
17use std::sync::Arc;
18
19use api::v1::auth_header::AuthScheme;
20use api::v1::ddl_request::Expr as DdlExpr;
21use api::v1::greptime_database_client::GreptimeDatabaseClient;
22use api::v1::greptime_request::Request;
23use api::v1::query_request::Query;
24use api::v1::{
25 AlterTableExpr, AuthHeader, Basic, CreateTableExpr, DdlRequest, GreptimeRequest,
26 InsertRequests, QueryRequest, RequestHeader,
27};
28use arrow_flight::{FlightData, Ticket};
29use async_stream::stream;
30use base64::prelude::BASE64_STANDARD;
31use base64::Engine;
32use common_catalog::build_db_string;
33use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
34use common_error::ext::BoxedError;
35use common_grpc::flight::do_put::DoPutResponse;
36use common_grpc::flight::{FlightDecoder, FlightMessage};
37use common_query::Output;
38use common_recordbatch::error::ExternalSnafu;
39use common_recordbatch::{RecordBatch, RecordBatchStreamWrapper};
40use common_telemetry::tracing_context::W3cTrace;
41use common_telemetry::{error, warn};
42use futures::future;
43use futures_util::{Stream, StreamExt, TryStreamExt};
44use prost::Message;
45use snafu::{ensure, ResultExt};
46use tonic::metadata::{AsciiMetadataKey, AsciiMetadataValue, MetadataMap, MetadataValue};
47use tonic::transport::Channel;
48
49use crate::error::{
50 ConvertFlightDataSnafu, Error, FlightGetSnafu, IllegalFlightMessagesSnafu,
51 InvalidTonicMetadataValueSnafu,
52};
53use crate::{error, from_grpc_response, Client, Result};
54
55type FlightDataStream = Pin<Box<dyn Stream<Item = FlightData> + Send>>;
56
57type DoPutResponseStream = Pin<Box<dyn Stream<Item = Result<DoPutResponse>>>>;
58
59#[derive(Clone, Debug, Default)]
60pub struct Database {
61 catalog: String,
65 schema: String,
66 dbname: String,
69 timezone: String,
72
73 client: Client,
74 ctx: FlightContext,
75}
76
77pub struct DatabaseClient {
78 pub inner: GreptimeDatabaseClient<Channel>,
79}
80
81fn make_database_client(client: &Client) -> Result<DatabaseClient> {
82 let (_, channel) = client.find_channel()?;
83 Ok(DatabaseClient {
84 inner: GreptimeDatabaseClient::new(channel)
85 .max_decoding_message_size(client.max_grpc_recv_message_size())
86 .max_encoding_message_size(client.max_grpc_send_message_size()),
87 })
88}
89
90impl Database {
91 pub fn new(catalog: impl Into<String>, schema: impl Into<String>, client: Client) -> Self {
93 Self {
94 catalog: catalog.into(),
95 schema: schema.into(),
96 dbname: String::default(),
97 timezone: String::default(),
98 client,
99 ctx: FlightContext::default(),
100 }
101 }
102
103 pub fn new_with_dbname(dbname: impl Into<String>, client: Client) -> Self {
111 Self {
112 catalog: String::default(),
113 schema: String::default(),
114 timezone: String::default(),
115 dbname: dbname.into(),
116 client,
117 ctx: FlightContext::default(),
118 }
119 }
120
121 pub fn set_catalog(&mut self, catalog: impl Into<String>) {
122 self.catalog = catalog.into();
123 }
124
125 fn catalog_or_default(&self) -> &str {
126 if self.catalog.is_empty() {
127 DEFAULT_CATALOG_NAME
128 } else {
129 &self.catalog
130 }
131 }
132
133 pub fn set_schema(&mut self, schema: impl Into<String>) {
134 self.schema = schema.into();
135 }
136
137 fn schema_or_default(&self) -> &str {
138 if self.schema.is_empty() {
139 DEFAULT_SCHEMA_NAME
140 } else {
141 &self.schema
142 }
143 }
144
145 pub fn set_timezone(&mut self, timezone: impl Into<String>) {
146 self.timezone = timezone.into();
147 }
148
149 pub fn set_auth(&mut self, auth: AuthScheme) {
150 self.ctx.auth_header = Some(AuthHeader {
151 auth_scheme: Some(auth),
152 });
153 }
154
155 pub async fn insert(&self, requests: InsertRequests) -> Result<u32> {
156 self.handle(Request::Inserts(requests)).await
157 }
158
159 pub async fn insert_with_hints(
160 &self,
161 requests: InsertRequests,
162 hints: &[(&str, &str)],
163 ) -> Result<u32> {
164 let mut client = make_database_client(&self.client)?.inner;
165 let request = self.to_rpc_request(Request::Inserts(requests));
166
167 let mut request = tonic::Request::new(request);
168 let metadata = request.metadata_mut();
169 Self::put_hints(metadata, hints)?;
170
171 let response = client.handle(request).await?.into_inner();
172 from_grpc_response(response)
173 }
174
175 fn put_hints(metadata: &mut MetadataMap, hints: &[(&str, &str)]) -> Result<()> {
176 let Some(value) = hints
177 .iter()
178 .map(|(k, v)| format!("{}={}", k, v))
179 .reduce(|a, b| format!("{},{}", a, b))
180 else {
181 return Ok(());
182 };
183
184 let key = AsciiMetadataKey::from_static("x-greptime-hints");
185 let value = AsciiMetadataValue::from_str(&value).context(InvalidTonicMetadataValueSnafu)?;
186 metadata.insert(key, value);
187 Ok(())
188 }
189
190 pub async fn handle(&self, request: Request) -> Result<u32> {
191 let mut client = make_database_client(&self.client)?.inner;
192 let request = self.to_rpc_request(request);
193 let response = client.handle(request).await?.into_inner();
194 from_grpc_response(response)
195 }
196
197 pub async fn handle_with_retry(
200 &self,
201 request: Request,
202 max_retries: u32,
203 hints: &[(&str, &str)],
204 ) -> Result<u32> {
205 let mut client = make_database_client(&self.client)?.inner;
206 let mut retries = 0;
207
208 let request = self.to_rpc_request(request);
209
210 loop {
211 let mut tonic_request = tonic::Request::new(request.clone());
212 let metadata = tonic_request.metadata_mut();
213 Self::put_hints(metadata, hints)?;
214 let raw_response = client.handle(tonic_request).await;
215 match (raw_response, retries < max_retries) {
216 (Ok(resp), _) => return from_grpc_response(resp.into_inner()),
217 (Err(err), true) => {
218 if is_grpc_retryable(&err) {
220 retries += 1;
222 warn!("Retrying {} times with error = {:?}", retries, err);
223 continue;
224 }
225 }
226 (Err(err), false) => {
227 error!(
228 "Failed to send request to grpc handle after {} retries, error = {:?}",
229 retries, err
230 );
231 return Err(err.into());
232 }
233 }
234 }
235 }
236
237 #[inline]
238 fn to_rpc_request(&self, request: Request) -> GreptimeRequest {
239 GreptimeRequest {
240 header: Some(RequestHeader {
241 catalog: self.catalog.clone(),
242 schema: self.schema.clone(),
243 authorization: self.ctx.auth_header.clone(),
244 dbname: self.dbname.clone(),
245 timezone: self.timezone.clone(),
246 tracing_context: W3cTrace::new(),
248 }),
249 request: Some(request),
250 }
251 }
252
253 pub async fn sql<S>(&self, sql: S) -> Result<Output>
254 where
255 S: AsRef<str>,
256 {
257 self.sql_with_hint(sql, &[]).await
258 }
259
260 pub async fn sql_with_hint<S>(&self, sql: S, hints: &[(&str, &str)]) -> Result<Output>
261 where
262 S: AsRef<str>,
263 {
264 let request = Request::Query(QueryRequest {
265 query: Some(Query::Sql(sql.as_ref().to_string())),
266 });
267 self.do_get(request, hints).await
268 }
269
270 pub async fn logical_plan(&self, logical_plan: Vec<u8>) -> Result<Output> {
271 let request = Request::Query(QueryRequest {
272 query: Some(Query::LogicalPlan(logical_plan)),
273 });
274 self.do_get(request, &[]).await
275 }
276
277 pub async fn create(&self, expr: CreateTableExpr) -> Result<Output> {
278 let request = Request::Ddl(DdlRequest {
279 expr: Some(DdlExpr::CreateTable(expr)),
280 });
281 self.do_get(request, &[]).await
282 }
283
284 pub async fn alter(&self, expr: AlterTableExpr) -> Result<Output> {
285 let request = Request::Ddl(DdlRequest {
286 expr: Some(DdlExpr::AlterTable(expr)),
287 });
288 self.do_get(request, &[]).await
289 }
290
291 async fn do_get(&self, request: Request, hints: &[(&str, &str)]) -> Result<Output> {
292 let request = self.to_rpc_request(request);
293 let request = Ticket {
294 ticket: request.encode_to_vec().into(),
295 };
296
297 let mut request = tonic::Request::new(request);
298 Self::put_hints(request.metadata_mut(), hints)?;
299
300 let mut client = self.client.make_flight_client(false, false)?;
301
302 let response = client.mut_inner().do_get(request).await.or_else(|e| {
303 let tonic_code = e.code();
304 let e: Error = e.into();
305 error!(
306 "Failed to do Flight get, addr: {}, code: {}, source: {:?}",
307 client.addr(),
308 tonic_code,
309 e
310 );
311 let error = Err(BoxedError::new(e)).with_context(|_| FlightGetSnafu {
312 addr: client.addr().to_string(),
313 tonic_code,
314 });
315 error
316 })?;
317
318 let flight_data_stream = response.into_inner();
319 let mut decoder = FlightDecoder::default();
320
321 let mut flight_message_stream = flight_data_stream.map(move |flight_data| {
322 flight_data
323 .map_err(Error::from)
324 .and_then(|data| decoder.try_decode(&data).context(ConvertFlightDataSnafu))
325 });
326
327 let Some(first_flight_message) = flight_message_stream.next().await else {
328 return IllegalFlightMessagesSnafu {
329 reason: "Expect the response not to be empty",
330 }
331 .fail();
332 };
333
334 let first_flight_message = first_flight_message?;
335
336 match first_flight_message {
337 FlightMessage::AffectedRows(rows) => {
338 ensure!(
339 flight_message_stream.next().await.is_none(),
340 IllegalFlightMessagesSnafu {
341 reason: "Expect 'AffectedRows' Flight messages to be the one and the only!"
342 }
343 );
344 Ok(Output::new_with_affected_rows(rows))
345 }
346 FlightMessage::RecordBatch(_) | FlightMessage::Metrics(_) => {
347 IllegalFlightMessagesSnafu {
348 reason: "The first flight message cannot be a RecordBatch or Metrics message",
349 }
350 .fail()
351 }
352 FlightMessage::Schema(schema) => {
353 let schema = Arc::new(
354 datatypes::schema::Schema::try_from(schema)
355 .context(error::ConvertSchemaSnafu)?,
356 );
357 let schema_cloned = schema.clone();
358 let stream = Box::pin(stream!({
359 while let Some(flight_message) = flight_message_stream.next().await {
360 let flight_message = flight_message
361 .map_err(BoxedError::new)
362 .context(ExternalSnafu)?;
363 match flight_message {
364 FlightMessage::RecordBatch(arrow_batch) => {
365 yield RecordBatch::try_from_df_record_batch(
366 schema_cloned.clone(),
367 arrow_batch,
368 )
369 }
370 FlightMessage::Metrics(_) => {}
371 FlightMessage::AffectedRows(_) | FlightMessage::Schema(_) => {
372 yield IllegalFlightMessagesSnafu {reason: format!("A Schema message must be succeeded exclusively by a set of RecordBatch messages, flight_message: {:?}", flight_message)}
373 .fail()
374 .map_err(BoxedError::new)
375 .context(ExternalSnafu);
376 break;
377 }
378 }
379 }
380 }));
381 let record_batch_stream = RecordBatchStreamWrapper {
382 schema,
383 stream,
384 output_ordering: None,
385 metrics: Default::default(),
386 };
387 Ok(Output::new_with_stream(Box::pin(record_batch_stream)))
388 }
389 }
390 }
391
392 pub async fn do_put(&self, stream: FlightDataStream) -> Result<DoPutResponseStream> {
395 let mut request = tonic::Request::new(stream);
396
397 if let Some(AuthHeader {
398 auth_scheme: Some(AuthScheme::Basic(Basic { username, password })),
399 }) = &self.ctx.auth_header
400 {
401 let encoded = BASE64_STANDARD.encode(format!("{username}:{password}"));
402 let value =
403 MetadataValue::from_str(&encoded).context(InvalidTonicMetadataValueSnafu)?;
404 request.metadata_mut().insert("x-greptime-auth", value);
405 }
406
407 let db_to_put = if !self.dbname.is_empty() {
408 &self.dbname
409 } else {
410 &build_db_string(self.catalog_or_default(), self.schema_or_default())
411 };
412 request.metadata_mut().insert(
413 "x-greptime-db-name",
414 MetadataValue::from_str(db_to_put).context(InvalidTonicMetadataValueSnafu)?,
415 );
416
417 let mut client = self.client.make_flight_client(false, false)?;
418 let response = client.mut_inner().do_put(request).await?;
419 let response = response
420 .into_inner()
421 .map_err(Into::into)
422 .and_then(|x| future::ready(DoPutResponse::try_from(x).context(ConvertFlightDataSnafu)))
423 .boxed();
424 Ok(response)
425 }
426}
427
428pub fn is_grpc_retryable(err: &tonic::Status) -> bool {
430 matches!(err.code(), tonic::Code::Unavailable)
431}
432
433#[derive(Default, Debug, Clone)]
434struct FlightContext {
435 auth_header: Option<AuthHeader>,
436}
437
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use api::v1::auth_header::AuthScheme;
    use api::v1::{AuthHeader, Basic};
    use common_error::status_code::StatusCode;
    use tonic::{Code, Status};

    use super::*;
    use crate::error::TonicSnafu;

    /// A fresh context has no auth; after attaching Basic credentials the header
    /// must carry a Basic scheme.
    #[test]
    fn test_flight_ctx() {
        let mut flight_ctx = FlightContext::default();
        assert!(flight_ctx.auth_header.is_none());

        let credentials = Basic {
            username: "u".to_string(),
            password: "p".to_string(),
        };
        flight_ctx.auth_header = Some(AuthHeader {
            auth_scheme: Some(AuthScheme::Basic(credentials)),
        });

        assert_matches!(
            flight_ctx.auth_header,
            Some(AuthHeader {
                auth_scheme: Some(AuthScheme::Basic(_)),
            })
        )
    }

    /// Converting a tonic `Status` must render the same message as a directly built
    /// `Tonic` error.
    #[test]
    fn test_from_tonic_status() {
        let status = Status::new(Code::Internal, "blabla");
        let converted: Error = status.into();

        let expected = TonicSnafu {
            code: StatusCode::Internal,
            msg: "blabla".to_string(),
            tonic_code: Code::Internal,
        }
        .build();

        assert_eq!(expected.to_string(), converted.to_string());
    }
}