1use std::pin::Pin;
16use std::str::FromStr;
17use std::sync::Arc;
18
19use api::v1::auth_header::AuthScheme;
20use api::v1::ddl_request::Expr as DdlExpr;
21use api::v1::greptime_database_client::GreptimeDatabaseClient;
22use api::v1::greptime_request::Request;
23use api::v1::query_request::Query;
24use api::v1::{
25 AlterTableExpr, AuthHeader, Basic, CreateTableExpr, DdlRequest, GreptimeRequest,
26 InsertRequests, QueryRequest, RequestHeader,
27};
28use arrow_flight::{FlightData, Ticket};
29use async_stream::stream;
30use base64::prelude::BASE64_STANDARD;
31use base64::Engine;
32use common_catalog::build_db_string;
33use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
34use common_error::ext::{BoxedError, ErrorExt};
35use common_grpc::flight::do_put::DoPutResponse;
36use common_grpc::flight::{FlightDecoder, FlightMessage};
37use common_query::Output;
38use common_recordbatch::error::ExternalSnafu;
39use common_recordbatch::{RecordBatch, RecordBatchStreamWrapper};
40use common_telemetry::tracing_context::W3cTrace;
41use common_telemetry::{error, warn};
42use futures::future;
43use futures_util::{Stream, StreamExt, TryStreamExt};
44use prost::Message;
45use snafu::{ensure, ResultExt};
46use tonic::metadata::{AsciiMetadataKey, AsciiMetadataValue, MetadataMap, MetadataValue};
47use tonic::transport::Channel;
48
49use crate::error::{
50 ConvertFlightDataSnafu, Error, FlightGetSnafu, IllegalFlightMessagesSnafu,
51 InvalidTonicMetadataValueSnafu, ServerSnafu,
52};
53use crate::{error, from_grpc_response, Client, Result};
54
55type FlightDataStream = Pin<Box<dyn Stream<Item = FlightData> + Send>>;
56
57type DoPutResponseStream = Pin<Box<dyn Stream<Item = Result<DoPutResponse>>>>;
58
/// Client-side handle for one GreptimeDB database: wraps a [`Client`] together with the
/// addressing (catalog/schema or dbname), timezone and auth context sent with requests.
#[derive(Clone, Debug, Default)]
pub struct Database {
    /// Catalog name sent in request headers; `do_put` substitutes
    /// `DEFAULT_CATALOG_NAME` when this is empty.
    catalog: String,
    /// Schema name sent in request headers; `do_put` substitutes
    /// `DEFAULT_SCHEMA_NAME` when this is empty.
    schema: String,
    /// Database name sent in request headers; when non-empty, `do_put` uses it
    /// instead of building one from catalog/schema.
    dbname: String,
    /// Timezone forwarded verbatim in every request header (empty unless set).
    timezone: String,

    /// Underlying connection source used to build gRPC / Flight clients.
    client: Client,
    /// Per-database request context (currently only the optional auth header).
    ctx: FlightContext,
}
76
/// Thin wrapper around the generated `GreptimeDatabaseClient` running over a tonic
/// [`Channel`]; instances are produced by `make_database_client`.
pub struct DatabaseClient {
    /// The configured gRPC stub (message-size limits applied by `make_database_client`).
    pub inner: GreptimeDatabaseClient<Channel>,
}
80
81fn make_database_client(client: &Client) -> Result<DatabaseClient> {
82 let (_, channel) = client.find_channel()?;
83 Ok(DatabaseClient {
84 inner: GreptimeDatabaseClient::new(channel)
85 .max_decoding_message_size(client.max_grpc_recv_message_size())
86 .max_encoding_message_size(client.max_grpc_send_message_size()),
87 })
88}
89
90impl Database {
91 pub fn new(catalog: impl Into<String>, schema: impl Into<String>, client: Client) -> Self {
93 Self {
94 catalog: catalog.into(),
95 schema: schema.into(),
96 dbname: String::default(),
97 timezone: String::default(),
98 client,
99 ctx: FlightContext::default(),
100 }
101 }
102
103 pub fn new_with_dbname(dbname: impl Into<String>, client: Client) -> Self {
111 Self {
112 catalog: String::default(),
113 schema: String::default(),
114 timezone: String::default(),
115 dbname: dbname.into(),
116 client,
117 ctx: FlightContext::default(),
118 }
119 }
120
121 pub fn set_catalog(&mut self, catalog: impl Into<String>) {
122 self.catalog = catalog.into();
123 }
124
125 fn catalog_or_default(&self) -> &str {
126 if self.catalog.is_empty() {
127 DEFAULT_CATALOG_NAME
128 } else {
129 &self.catalog
130 }
131 }
132
133 pub fn set_schema(&mut self, schema: impl Into<String>) {
134 self.schema = schema.into();
135 }
136
137 fn schema_or_default(&self) -> &str {
138 if self.schema.is_empty() {
139 DEFAULT_SCHEMA_NAME
140 } else {
141 &self.schema
142 }
143 }
144
145 pub fn set_timezone(&mut self, timezone: impl Into<String>) {
146 self.timezone = timezone.into();
147 }
148
149 pub fn set_auth(&mut self, auth: AuthScheme) {
150 self.ctx.auth_header = Some(AuthHeader {
151 auth_scheme: Some(auth),
152 });
153 }
154
155 pub async fn insert(&self, requests: InsertRequests) -> Result<u32> {
156 self.handle(Request::Inserts(requests)).await
157 }
158
159 pub async fn insert_with_hints(
160 &self,
161 requests: InsertRequests,
162 hints: &[(&str, &str)],
163 ) -> Result<u32> {
164 let mut client = make_database_client(&self.client)?.inner;
165 let request = self.to_rpc_request(Request::Inserts(requests));
166
167 let mut request = tonic::Request::new(request);
168 let metadata = request.metadata_mut();
169 Self::put_hints(metadata, hints)?;
170
171 let response = client.handle(request).await?.into_inner();
172 from_grpc_response(response)
173 }
174
175 fn put_hints(metadata: &mut MetadataMap, hints: &[(&str, &str)]) -> Result<()> {
176 let Some(value) = hints
177 .iter()
178 .map(|(k, v)| format!("{}={}", k, v))
179 .reduce(|a, b| format!("{},{}", a, b))
180 else {
181 return Ok(());
182 };
183
184 let key = AsciiMetadataKey::from_static("x-greptime-hints");
185 let value = AsciiMetadataValue::from_str(&value).context(InvalidTonicMetadataValueSnafu)?;
186 metadata.insert(key, value);
187 Ok(())
188 }
189
190 pub async fn handle(&self, request: Request) -> Result<u32> {
191 let mut client = make_database_client(&self.client)?.inner;
192 let request = self.to_rpc_request(request);
193 let response = client.handle(request).await?.into_inner();
194 from_grpc_response(response)
195 }
196
197 pub async fn handle_with_retry(&self, request: Request, max_retries: u32) -> Result<u32> {
200 let mut client = make_database_client(&self.client)?.inner;
201 let mut retries = 0;
202 let request = self.to_rpc_request(request);
203 loop {
204 let raw_response = client.handle(request.clone()).await;
205 match (raw_response, retries < max_retries) {
206 (Ok(resp), _) => return from_grpc_response(resp.into_inner()),
207 (Err(err), true) => {
208 if is_grpc_retryable(&err) {
210 retries += 1;
212 warn!("Retrying {} times with error = {:?}", retries, err);
213 continue;
214 }
215 }
216 (Err(err), false) => {
217 error!(
218 "Failed to send request to grpc handle after {} retries, error = {:?}",
219 retries, err
220 );
221 return Err(err.into());
222 }
223 }
224 }
225 }
226
227 #[inline]
228 fn to_rpc_request(&self, request: Request) -> GreptimeRequest {
229 GreptimeRequest {
230 header: Some(RequestHeader {
231 catalog: self.catalog.clone(),
232 schema: self.schema.clone(),
233 authorization: self.ctx.auth_header.clone(),
234 dbname: self.dbname.clone(),
235 timezone: self.timezone.clone(),
236 tracing_context: W3cTrace::new(),
238 }),
239 request: Some(request),
240 }
241 }
242
243 pub async fn sql<S>(&self, sql: S) -> Result<Output>
244 where
245 S: AsRef<str>,
246 {
247 self.sql_with_hint(sql, &[]).await
248 }
249
250 pub async fn sql_with_hint<S>(&self, sql: S, hints: &[(&str, &str)]) -> Result<Output>
251 where
252 S: AsRef<str>,
253 {
254 let request = Request::Query(QueryRequest {
255 query: Some(Query::Sql(sql.as_ref().to_string())),
256 });
257 self.do_get(request, hints).await
258 }
259
260 pub async fn logical_plan(&self, logical_plan: Vec<u8>) -> Result<Output> {
261 let request = Request::Query(QueryRequest {
262 query: Some(Query::LogicalPlan(logical_plan)),
263 });
264 self.do_get(request, &[]).await
265 }
266
267 pub async fn create(&self, expr: CreateTableExpr) -> Result<Output> {
268 let request = Request::Ddl(DdlRequest {
269 expr: Some(DdlExpr::CreateTable(expr)),
270 });
271 self.do_get(request, &[]).await
272 }
273
274 pub async fn alter(&self, expr: AlterTableExpr) -> Result<Output> {
275 let request = Request::Ddl(DdlRequest {
276 expr: Some(DdlExpr::AlterTable(expr)),
277 });
278 self.do_get(request, &[]).await
279 }
280
281 async fn do_get(&self, request: Request, hints: &[(&str, &str)]) -> Result<Output> {
282 let request = self.to_rpc_request(request);
283 let request = Ticket {
284 ticket: request.encode_to_vec().into(),
285 };
286
287 let mut request = tonic::Request::new(request);
288 Self::put_hints(request.metadata_mut(), hints)?;
289
290 let mut client = self.client.make_flight_client(false, false)?;
291
292 let response = client.mut_inner().do_get(request).await.or_else(|e| {
293 let tonic_code = e.code();
294 let e: Error = e.into();
295 let code = e.status_code();
296 let msg = e.to_string();
297 let error =
298 Err(BoxedError::new(ServerSnafu { code, msg }.build())).with_context(|_| {
299 FlightGetSnafu {
300 addr: client.addr().to_string(),
301 tonic_code,
302 }
303 });
304 error!(
305 "Failed to do Flight get, addr: {}, code: {}, source: {:?}",
306 client.addr(),
307 tonic_code,
308 error
309 );
310 error
311 })?;
312
313 let flight_data_stream = response.into_inner();
314 let mut decoder = FlightDecoder::default();
315
316 let mut flight_message_stream = flight_data_stream.map(move |flight_data| {
317 flight_data
318 .map_err(Error::from)
319 .and_then(|data| decoder.try_decode(&data).context(ConvertFlightDataSnafu))
320 });
321
322 let Some(first_flight_message) = flight_message_stream.next().await else {
323 return IllegalFlightMessagesSnafu {
324 reason: "Expect the response not to be empty",
325 }
326 .fail();
327 };
328
329 let first_flight_message = first_flight_message?;
330
331 match first_flight_message {
332 FlightMessage::AffectedRows(rows) => {
333 ensure!(
334 flight_message_stream.next().await.is_none(),
335 IllegalFlightMessagesSnafu {
336 reason: "Expect 'AffectedRows' Flight messages to be the one and the only!"
337 }
338 );
339 Ok(Output::new_with_affected_rows(rows))
340 }
341 FlightMessage::RecordBatch(_) | FlightMessage::Metrics(_) => {
342 IllegalFlightMessagesSnafu {
343 reason: "The first flight message cannot be a RecordBatch or Metrics message",
344 }
345 .fail()
346 }
347 FlightMessage::Schema(schema) => {
348 let schema = Arc::new(
349 datatypes::schema::Schema::try_from(schema)
350 .context(error::ConvertSchemaSnafu)?,
351 );
352 let schema_cloned = schema.clone();
353 let stream = Box::pin(stream!({
354 while let Some(flight_message) = flight_message_stream.next().await {
355 let flight_message = flight_message
356 .map_err(BoxedError::new)
357 .context(ExternalSnafu)?;
358 match flight_message {
359 FlightMessage::RecordBatch(arrow_batch) => {
360 yield RecordBatch::try_from_df_record_batch(
361 schema_cloned.clone(),
362 arrow_batch,
363 )
364 }
365 FlightMessage::Metrics(_) => {}
366 FlightMessage::AffectedRows(_) | FlightMessage::Schema(_) => {
367 yield IllegalFlightMessagesSnafu {reason: format!("A Schema message must be succeeded exclusively by a set of RecordBatch messages, flight_message: {:?}", flight_message)}
368 .fail()
369 .map_err(BoxedError::new)
370 .context(ExternalSnafu);
371 break;
372 }
373 }
374 }
375 }));
376 let record_batch_stream = RecordBatchStreamWrapper {
377 schema,
378 stream,
379 output_ordering: None,
380 metrics: Default::default(),
381 };
382 Ok(Output::new_with_stream(Box::pin(record_batch_stream)))
383 }
384 }
385 }
386
387 pub async fn do_put(&self, stream: FlightDataStream) -> Result<DoPutResponseStream> {
390 let mut request = tonic::Request::new(stream);
391
392 if let Some(AuthHeader {
393 auth_scheme: Some(AuthScheme::Basic(Basic { username, password })),
394 }) = &self.ctx.auth_header
395 {
396 let encoded = BASE64_STANDARD.encode(format!("{username}:{password}"));
397 let value =
398 MetadataValue::from_str(&encoded).context(InvalidTonicMetadataValueSnafu)?;
399 request.metadata_mut().insert("x-greptime-auth", value);
400 }
401
402 let db_to_put = if !self.dbname.is_empty() {
403 &self.dbname
404 } else {
405 &build_db_string(self.catalog_or_default(), self.schema_or_default())
406 };
407 request.metadata_mut().insert(
408 "x-greptime-db-name",
409 MetadataValue::from_str(db_to_put).context(InvalidTonicMetadataValueSnafu)?,
410 );
411
412 let mut client = self.client.make_flight_client(false, false)?;
413 let response = client.mut_inner().do_put(request).await?;
414 let response = response
415 .into_inner()
416 .map_err(Into::into)
417 .and_then(|x| future::ready(DoPutResponse::try_from(x).context(ConvertFlightDataSnafu)))
418 .boxed();
419 Ok(response)
420 }
421}
422
423pub fn is_grpc_retryable(err: &tonic::Status) -> bool {
425 matches!(err.code(), tonic::Code::Unavailable)
426}
427
/// Per-`Database` request context; currently it only holds the optional auth header.
#[derive(Default, Debug, Clone)]
struct FlightContext {
    /// Authorization copied into every `RequestHeader` by `to_rpc_request`; `do_put`
    /// additionally forwards it as `x-greptime-auth` metadata, but only for `Basic` auth.
    auth_header: Option<AuthHeader>,
}
432
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use api::v1::auth_header::AuthScheme;
    use api::v1::{AuthHeader, Basic};

    use super::*;

    /// A default `FlightContext` carries no auth header, and a Basic scheme stored in
    /// it can be pattern-matched back out.
    #[test]
    fn test_flight_ctx() {
        let mut ctx = FlightContext::default();
        assert!(ctx.auth_header.is_none());

        let credentials = Basic {
            username: String::from("u"),
            password: String::from("p"),
        };
        ctx.auth_header = Some(AuthHeader {
            auth_scheme: Some(AuthScheme::Basic(credentials)),
        });

        assert_matches!(
            ctx.auth_header,
            Some(AuthHeader {
                auth_scheme: Some(AuthScheme::Basic(_)),
            })
        )
    }
}