1use std::pin::Pin;
16use std::str::FromStr;
17
18use api::v1::auth_header::AuthScheme;
19use api::v1::ddl_request::Expr as DdlExpr;
20use api::v1::greptime_database_client::GreptimeDatabaseClient;
21use api::v1::greptime_request::Request;
22use api::v1::query_request::Query;
23use api::v1::{
24 AlterTableExpr, AuthHeader, Basic, CreateTableExpr, DdlRequest, GreptimeRequest,
25 InsertRequests, QueryRequest, RequestHeader,
26};
27use arrow_flight::{FlightData, Ticket};
28use async_stream::stream;
29use base64::prelude::BASE64_STANDARD;
30use base64::Engine;
31use common_catalog::build_db_string;
32use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
33use common_error::ext::{BoxedError, ErrorExt};
34use common_grpc::flight::do_put::DoPutResponse;
35use common_grpc::flight::{FlightDecoder, FlightMessage};
36use common_query::Output;
37use common_recordbatch::error::ExternalSnafu;
38use common_recordbatch::RecordBatchStreamWrapper;
39use common_telemetry::tracing_context::W3cTrace;
40use common_telemetry::{error, warn};
41use futures::future;
42use futures_util::{Stream, StreamExt, TryStreamExt};
43use prost::Message;
44use snafu::{ensure, ResultExt};
45use tonic::metadata::{AsciiMetadataKey, AsciiMetadataValue, MetadataMap, MetadataValue};
46use tonic::transport::Channel;
47
48use crate::error::{
49 ConvertFlightDataSnafu, Error, FlightGetSnafu, IllegalFlightMessagesSnafu,
50 InvalidTonicMetadataValueSnafu, ServerSnafu,
51};
52use crate::{from_grpc_response, Client, Result};
53
/// Boxed, pinned, sendable stream of [`FlightData`] frames; the input type of
/// [`Database::do_put`].
type FlightDataStream = Pin<Box<dyn Stream<Item = FlightData> + Send>>;

/// Boxed, pinned stream of decoded [`DoPutResponse`] results; the output type of
/// [`Database::do_put`].
type DoPutResponseStream = Pin<Box<dyn Stream<Item = Result<DoPutResponse>>>>;
57
/// Client-side handle used to send gRPC and Arrow Flight requests on behalf of one
/// database.
#[derive(Clone, Debug, Default)]
pub struct Database {
    /// Catalog name sent in every request header. When empty, `do_put` falls back to
    /// `DEFAULT_CATALOG_NAME` while building the database name.
    catalog: String,
    /// Schema name sent in every request header. When empty, `do_put` falls back to
    /// `DEFAULT_SCHEMA_NAME` while building the database name.
    schema: String,
    /// Database name sent in every request header. When non-empty, `do_put` uses it
    /// instead of a name built from `catalog` and `schema`.
    dbname: String,
    /// Timezone string sent in every request header.
    timezone: String,

    /// Underlying client used to obtain channels and Flight clients.
    client: Client,
    /// Per-database context; currently carries the optional auth header.
    ctx: FlightContext,
}
75
/// Thin wrapper around the generated [`GreptimeDatabaseClient`] bound to a tonic
/// [`Channel`].
pub struct DatabaseClient {
    /// The generated gRPC client that performs the actual calls.
    pub inner: GreptimeDatabaseClient<Channel>,
}
79
80fn make_database_client(client: &Client) -> Result<DatabaseClient> {
81 let (_, channel) = client.find_channel()?;
82 Ok(DatabaseClient {
83 inner: GreptimeDatabaseClient::new(channel)
84 .max_decoding_message_size(client.max_grpc_recv_message_size())
85 .max_encoding_message_size(client.max_grpc_send_message_size()),
86 })
87}
88
89impl Database {
90 pub fn new(catalog: impl Into<String>, schema: impl Into<String>, client: Client) -> Self {
92 Self {
93 catalog: catalog.into(),
94 schema: schema.into(),
95 dbname: String::default(),
96 timezone: String::default(),
97 client,
98 ctx: FlightContext::default(),
99 }
100 }
101
102 pub fn new_with_dbname(dbname: impl Into<String>, client: Client) -> Self {
110 Self {
111 catalog: String::default(),
112 schema: String::default(),
113 timezone: String::default(),
114 dbname: dbname.into(),
115 client,
116 ctx: FlightContext::default(),
117 }
118 }
119
120 pub fn set_catalog(&mut self, catalog: impl Into<String>) {
121 self.catalog = catalog.into();
122 }
123
124 fn catalog_or_default(&self) -> &str {
125 if self.catalog.is_empty() {
126 DEFAULT_CATALOG_NAME
127 } else {
128 &self.catalog
129 }
130 }
131
132 pub fn set_schema(&mut self, schema: impl Into<String>) {
133 self.schema = schema.into();
134 }
135
136 fn schema_or_default(&self) -> &str {
137 if self.schema.is_empty() {
138 DEFAULT_SCHEMA_NAME
139 } else {
140 &self.schema
141 }
142 }
143
144 pub fn set_timezone(&mut self, timezone: impl Into<String>) {
145 self.timezone = timezone.into();
146 }
147
148 pub fn set_auth(&mut self, auth: AuthScheme) {
149 self.ctx.auth_header = Some(AuthHeader {
150 auth_scheme: Some(auth),
151 });
152 }
153
154 pub async fn insert(&self, requests: InsertRequests) -> Result<u32> {
155 self.handle(Request::Inserts(requests)).await
156 }
157
158 pub async fn insert_with_hints(
159 &self,
160 requests: InsertRequests,
161 hints: &[(&str, &str)],
162 ) -> Result<u32> {
163 let mut client = make_database_client(&self.client)?.inner;
164 let request = self.to_rpc_request(Request::Inserts(requests));
165
166 let mut request = tonic::Request::new(request);
167 let metadata = request.metadata_mut();
168 Self::put_hints(metadata, hints)?;
169
170 let response = client.handle(request).await?.into_inner();
171 from_grpc_response(response)
172 }
173
174 fn put_hints(metadata: &mut MetadataMap, hints: &[(&str, &str)]) -> Result<()> {
175 let Some(value) = hints
176 .iter()
177 .map(|(k, v)| format!("{}={}", k, v))
178 .reduce(|a, b| format!("{},{}", a, b))
179 else {
180 return Ok(());
181 };
182
183 let key = AsciiMetadataKey::from_static("x-greptime-hints");
184 let value = AsciiMetadataValue::from_str(&value).context(InvalidTonicMetadataValueSnafu)?;
185 metadata.insert(key, value);
186 Ok(())
187 }
188
189 pub async fn handle(&self, request: Request) -> Result<u32> {
190 let mut client = make_database_client(&self.client)?.inner;
191 let request = self.to_rpc_request(request);
192 let response = client.handle(request).await?.into_inner();
193 from_grpc_response(response)
194 }
195
196 pub async fn handle_with_retry(&self, request: Request, max_retries: u32) -> Result<u32> {
199 let mut client = make_database_client(&self.client)?.inner;
200 let mut retries = 0;
201 let request = self.to_rpc_request(request);
202 loop {
203 let raw_response = client.handle(request.clone()).await;
204 match (raw_response, retries < max_retries) {
205 (Ok(resp), _) => return from_grpc_response(resp.into_inner()),
206 (Err(err), true) => {
207 if is_grpc_retryable(&err) {
209 retries += 1;
211 warn!("Retrying {} times with error = {:?}", retries, err);
212 continue;
213 }
214 }
215 (Err(err), false) => {
216 error!(
217 "Failed to send request to grpc handle after {} retries, error = {:?}",
218 retries, err
219 );
220 return Err(err.into());
221 }
222 }
223 }
224 }
225
226 #[inline]
227 fn to_rpc_request(&self, request: Request) -> GreptimeRequest {
228 GreptimeRequest {
229 header: Some(RequestHeader {
230 catalog: self.catalog.clone(),
231 schema: self.schema.clone(),
232 authorization: self.ctx.auth_header.clone(),
233 dbname: self.dbname.clone(),
234 timezone: self.timezone.clone(),
235 tracing_context: W3cTrace::new(),
237 }),
238 request: Some(request),
239 }
240 }
241
242 pub async fn sql<S>(&self, sql: S) -> Result<Output>
243 where
244 S: AsRef<str>,
245 {
246 self.sql_with_hint(sql, &[]).await
247 }
248
249 pub async fn sql_with_hint<S>(&self, sql: S, hints: &[(&str, &str)]) -> Result<Output>
250 where
251 S: AsRef<str>,
252 {
253 let request = Request::Query(QueryRequest {
254 query: Some(Query::Sql(sql.as_ref().to_string())),
255 });
256 self.do_get(request, hints).await
257 }
258
259 pub async fn logical_plan(&self, logical_plan: Vec<u8>) -> Result<Output> {
260 let request = Request::Query(QueryRequest {
261 query: Some(Query::LogicalPlan(logical_plan)),
262 });
263 self.do_get(request, &[]).await
264 }
265
266 pub async fn create(&self, expr: CreateTableExpr) -> Result<Output> {
267 let request = Request::Ddl(DdlRequest {
268 expr: Some(DdlExpr::CreateTable(expr)),
269 });
270 self.do_get(request, &[]).await
271 }
272
273 pub async fn alter(&self, expr: AlterTableExpr) -> Result<Output> {
274 let request = Request::Ddl(DdlRequest {
275 expr: Some(DdlExpr::AlterTable(expr)),
276 });
277 self.do_get(request, &[]).await
278 }
279
280 async fn do_get(&self, request: Request, hints: &[(&str, &str)]) -> Result<Output> {
281 let request = self.to_rpc_request(request);
282 let request = Ticket {
283 ticket: request.encode_to_vec().into(),
284 };
285
286 let mut request = tonic::Request::new(request);
287 Self::put_hints(request.metadata_mut(), hints)?;
288
289 let mut client = self.client.make_flight_client()?;
290
291 let response = client.mut_inner().do_get(request).await.or_else(|e| {
292 let tonic_code = e.code();
293 let e: Error = e.into();
294 let code = e.status_code();
295 let msg = e.to_string();
296 let error =
297 Err(BoxedError::new(ServerSnafu { code, msg }.build())).with_context(|_| {
298 FlightGetSnafu {
299 addr: client.addr().to_string(),
300 tonic_code,
301 }
302 });
303 error!(
304 "Failed to do Flight get, addr: {}, code: {}, source: {:?}",
305 client.addr(),
306 tonic_code,
307 error
308 );
309 error
310 })?;
311
312 let flight_data_stream = response.into_inner();
313 let mut decoder = FlightDecoder::default();
314
315 let mut flight_message_stream = flight_data_stream.map(move |flight_data| {
316 flight_data
317 .map_err(Error::from)
318 .and_then(|data| decoder.try_decode(&data).context(ConvertFlightDataSnafu))
319 });
320
321 let Some(first_flight_message) = flight_message_stream.next().await else {
322 return IllegalFlightMessagesSnafu {
323 reason: "Expect the response not to be empty",
324 }
325 .fail();
326 };
327
328 let first_flight_message = first_flight_message?;
329
330 match first_flight_message {
331 FlightMessage::AffectedRows(rows) => {
332 ensure!(
333 flight_message_stream.next().await.is_none(),
334 IllegalFlightMessagesSnafu {
335 reason: "Expect 'AffectedRows' Flight messages to be the one and the only!"
336 }
337 );
338 Ok(Output::new_with_affected_rows(rows))
339 }
340 FlightMessage::Recordbatch(_) | FlightMessage::Metrics(_) => {
341 IllegalFlightMessagesSnafu {
342 reason: "The first flight message cannot be a RecordBatch or Metrics message",
343 }
344 .fail()
345 }
346 FlightMessage::Schema(schema) => {
347 let stream = Box::pin(stream!({
348 while let Some(flight_message) = flight_message_stream.next().await {
349 let flight_message = flight_message
350 .map_err(BoxedError::new)
351 .context(ExternalSnafu)?;
352 match flight_message {
353 FlightMessage::Recordbatch(record_batch) => yield Ok(record_batch),
354 FlightMessage::Metrics(_) => {}
355 FlightMessage::AffectedRows(_) | FlightMessage::Schema(_) => {
356 yield IllegalFlightMessagesSnafu {reason: format!("A Schema message must be succeeded exclusively by a set of RecordBatch messages, flight_message: {:?}", flight_message)}
357 .fail()
358 .map_err(BoxedError::new)
359 .context(ExternalSnafu);
360 break;
361 }
362 }
363 }
364 }));
365 let record_batch_stream = RecordBatchStreamWrapper {
366 schema,
367 stream,
368 output_ordering: None,
369 metrics: Default::default(),
370 };
371 Ok(Output::new_with_stream(Box::pin(record_batch_stream)))
372 }
373 }
374 }
375
376 pub async fn do_put(&self, stream: FlightDataStream) -> Result<DoPutResponseStream> {
379 let mut request = tonic::Request::new(stream);
380
381 if let Some(AuthHeader {
382 auth_scheme: Some(AuthScheme::Basic(Basic { username, password })),
383 }) = &self.ctx.auth_header
384 {
385 let encoded = BASE64_STANDARD.encode(format!("{username}:{password}"));
386 let value =
387 MetadataValue::from_str(&encoded).context(InvalidTonicMetadataValueSnafu)?;
388 request.metadata_mut().insert("x-greptime-auth", value);
389 }
390
391 let db_to_put = if !self.dbname.is_empty() {
392 &self.dbname
393 } else {
394 &build_db_string(self.catalog_or_default(), self.schema_or_default())
395 };
396 request.metadata_mut().insert(
397 "x-greptime-db-name",
398 MetadataValue::from_str(db_to_put).context(InvalidTonicMetadataValueSnafu)?,
399 );
400
401 let mut client = self.client.make_flight_client()?;
402 let response = client.mut_inner().do_put(request).await?;
403 let response = response
404 .into_inner()
405 .map_err(Into::into)
406 .and_then(|x| future::ready(DoPutResponse::try_from(x).context(ConvertFlightDataSnafu)))
407 .boxed();
408 Ok(response)
409 }
410}
411
412pub fn is_grpc_retryable(err: &tonic::Status) -> bool {
414 matches!(err.code(), tonic::Code::Unavailable)
415}
416
/// Per-[`Database`] request context.
#[derive(Default, Debug, Clone)]
struct FlightContext {
    /// Authorization header attached to outgoing requests; `None` until
    /// `Database::set_auth` is called.
    auth_header: Option<AuthHeader>,
}
421
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use api::v1::auth_header::AuthScheme;
    use api::v1::{AuthHeader, Basic};

    use super::*;

    /// A default context has no auth header; once assigned, a `Basic` scheme is
    /// stored as-is.
    #[test]
    fn test_flight_ctx() {
        let mut ctx = FlightContext::default();
        assert!(ctx.auth_header.is_none());

        let credentials = Basic {
            username: "u".to_string(),
            password: "p".to_string(),
        };
        let header = AuthHeader {
            auth_scheme: Some(AuthScheme::Basic(credentials)),
        };
        ctx.auth_header = Some(header);

        assert_matches!(
            ctx.auth_header,
            Some(AuthHeader {
                auth_scheme: Some(AuthScheme::Basic(_)),
            })
        );
    }
}