1use std::pin::Pin;
16use std::str::FromStr;
17
18use api::v1::auth_header::AuthScheme;
19use api::v1::ddl_request::Expr as DdlExpr;
20use api::v1::greptime_database_client::GreptimeDatabaseClient;
21use api::v1::greptime_request::Request;
22use api::v1::query_request::Query;
23use api::v1::{
24 AlterTableExpr, AuthHeader, Basic, CreateTableExpr, DdlRequest, GreptimeRequest,
25 InsertRequests, QueryRequest, RequestHeader,
26};
27use arrow_flight::{FlightData, Ticket};
28use async_stream::stream;
29use base64::prelude::BASE64_STANDARD;
30use base64::Engine;
31use common_catalog::build_db_string;
32use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
33use common_error::ext::{BoxedError, ErrorExt};
34use common_grpc::flight::do_put::DoPutResponse;
35use common_grpc::flight::{FlightDecoder, FlightMessage};
36use common_query::Output;
37use common_recordbatch::error::ExternalSnafu;
38use common_recordbatch::RecordBatchStreamWrapper;
39use common_telemetry::tracing_context::W3cTrace;
40use common_telemetry::{error, warn};
41use futures::future;
42use futures_util::{Stream, StreamExt, TryStreamExt};
43use prost::Message;
44use snafu::{ensure, ResultExt};
45use tonic::metadata::{AsciiMetadataKey, MetadataValue};
46use tonic::transport::Channel;
47
48use crate::error::{
49 ConvertFlightDataSnafu, Error, FlightGetSnafu, IllegalFlightMessagesSnafu, InvalidAsciiSnafu,
50 InvalidTonicMetadataValueSnafu, ServerSnafu,
51};
52use crate::{from_grpc_response, Client, Result};
53
54type FlightDataStream = Pin<Box<dyn Stream<Item = FlightData> + Send>>;
55
56type DoPutResponseStream = Pin<Box<dyn Stream<Item = Result<DoPutResponse>>>>;
57
/// Handle for issuing requests against one Greptime database through a [`Client`].
///
/// The target is addressed either by `catalog` + `schema` or by a combined
/// `dbname`; both are forwarded in every request header, and `do_put` prefers
/// `dbname` when it is non-empty.
#[derive(Clone, Debug, Default)]
pub struct Database {
    /// Catalog name sent verbatim in request headers; when empty, `do_put` falls
    /// back to `DEFAULT_CATALOG_NAME`.
    catalog: String,
    /// Schema name sent verbatim in request headers; when empty, `do_put` falls
    /// back to `DEFAULT_SCHEMA_NAME`.
    schema: String,
    /// Combined database name; when non-empty, `do_put` uses it instead of the
    /// catalog/schema pair.
    dbname: String,
    /// Timezone string forwarded in every request header (empty when unset).
    timezone: String,

    /// Connection source used to build gRPC and Flight clients per call.
    client: Client,
    /// Per-handle state (currently the auth header) attached to requests.
    ctx: FlightContext,
}
75
/// Thin wrapper around the generated gRPC [`GreptimeDatabaseClient`], as produced
/// by `make_database_client`.
pub struct DatabaseClient {
    /// The underlying tonic client over a [`Channel`].
    pub inner: GreptimeDatabaseClient<Channel>,
}
79
80fn make_database_client(client: &Client) -> Result<DatabaseClient> {
81 let (_, channel) = client.find_channel()?;
82 Ok(DatabaseClient {
83 inner: GreptimeDatabaseClient::new(channel)
84 .max_decoding_message_size(client.max_grpc_recv_message_size())
85 .max_encoding_message_size(client.max_grpc_send_message_size()),
86 })
87}
88
89impl Database {
90 pub fn new(catalog: impl Into<String>, schema: impl Into<String>, client: Client) -> Self {
92 Self {
93 catalog: catalog.into(),
94 schema: schema.into(),
95 dbname: String::default(),
96 timezone: String::default(),
97 client,
98 ctx: FlightContext::default(),
99 }
100 }
101
102 pub fn new_with_dbname(dbname: impl Into<String>, client: Client) -> Self {
110 Self {
111 catalog: String::default(),
112 schema: String::default(),
113 timezone: String::default(),
114 dbname: dbname.into(),
115 client,
116 ctx: FlightContext::default(),
117 }
118 }
119
120 pub fn set_catalog(&mut self, catalog: impl Into<String>) {
121 self.catalog = catalog.into();
122 }
123
124 fn catalog_or_default(&self) -> &str {
125 if self.catalog.is_empty() {
126 DEFAULT_CATALOG_NAME
127 } else {
128 &self.catalog
129 }
130 }
131
132 pub fn set_schema(&mut self, schema: impl Into<String>) {
133 self.schema = schema.into();
134 }
135
136 fn schema_or_default(&self) -> &str {
137 if self.schema.is_empty() {
138 DEFAULT_SCHEMA_NAME
139 } else {
140 &self.schema
141 }
142 }
143
144 pub fn set_timezone(&mut self, timezone: impl Into<String>) {
145 self.timezone = timezone.into();
146 }
147
148 pub fn set_auth(&mut self, auth: AuthScheme) {
149 self.ctx.auth_header = Some(AuthHeader {
150 auth_scheme: Some(auth),
151 });
152 }
153
154 pub async fn insert(&self, requests: InsertRequests) -> Result<u32> {
155 self.handle(Request::Inserts(requests)).await
156 }
157
158 pub async fn insert_with_hints(
159 &self,
160 requests: InsertRequests,
161 hints: &[(&str, &str)],
162 ) -> Result<u32> {
163 let mut client = make_database_client(&self.client)?.inner;
164 let request = self.to_rpc_request(Request::Inserts(requests));
165
166 let mut request = tonic::Request::new(request);
167 let metadata = request.metadata_mut();
168 for (key, value) in hints {
169 let key = AsciiMetadataKey::from_bytes(format!("x-greptime-hint-{}", key).as_bytes())
170 .map_err(|_| {
171 InvalidAsciiSnafu {
172 value: key.to_string(),
173 }
174 .build()
175 })?;
176 let value = value.parse().map_err(|_| {
177 InvalidAsciiSnafu {
178 value: value.to_string(),
179 }
180 .build()
181 })?;
182 metadata.insert(key, value);
183 }
184 let response = client.handle(request).await?.into_inner();
185 from_grpc_response(response)
186 }
187
188 pub async fn handle(&self, request: Request) -> Result<u32> {
189 let mut client = make_database_client(&self.client)?.inner;
190 let request = self.to_rpc_request(request);
191 let response = client.handle(request).await?.into_inner();
192 from_grpc_response(response)
193 }
194
195 pub async fn handle_with_retry(&self, request: Request, max_retries: u32) -> Result<u32> {
198 let mut client = make_database_client(&self.client)?.inner;
199 let mut retries = 0;
200 let request = self.to_rpc_request(request);
201 loop {
202 let raw_response = client.handle(request.clone()).await;
203 match (raw_response, retries < max_retries) {
204 (Ok(resp), _) => return from_grpc_response(resp.into_inner()),
205 (Err(err), true) => {
206 if is_grpc_retryable(&err) {
208 retries += 1;
210 warn!("Retrying {} times with error = {:?}", retries, err);
211 continue;
212 }
213 }
214 (Err(err), false) => {
215 error!(
216 "Failed to send request to grpc handle after {} retries, error = {:?}",
217 retries, err
218 );
219 return Err(err.into());
220 }
221 }
222 }
223 }
224
225 #[inline]
226 fn to_rpc_request(&self, request: Request) -> GreptimeRequest {
227 GreptimeRequest {
228 header: Some(RequestHeader {
229 catalog: self.catalog.clone(),
230 schema: self.schema.clone(),
231 authorization: self.ctx.auth_header.clone(),
232 dbname: self.dbname.clone(),
233 timezone: self.timezone.clone(),
234 tracing_context: W3cTrace::new(),
236 }),
237 request: Some(request),
238 }
239 }
240
241 pub async fn sql<S>(&self, sql: S) -> Result<Output>
242 where
243 S: AsRef<str>,
244 {
245 self.do_get(Request::Query(QueryRequest {
246 query: Some(Query::Sql(sql.as_ref().to_string())),
247 }))
248 .await
249 }
250
251 pub async fn logical_plan(&self, logical_plan: Vec<u8>) -> Result<Output> {
252 self.do_get(Request::Query(QueryRequest {
253 query: Some(Query::LogicalPlan(logical_plan)),
254 }))
255 .await
256 }
257
258 pub async fn create(&self, expr: CreateTableExpr) -> Result<Output> {
259 self.do_get(Request::Ddl(DdlRequest {
260 expr: Some(DdlExpr::CreateTable(expr)),
261 }))
262 .await
263 }
264
265 pub async fn alter(&self, expr: AlterTableExpr) -> Result<Output> {
266 self.do_get(Request::Ddl(DdlRequest {
267 expr: Some(DdlExpr::AlterTable(expr)),
268 }))
269 .await
270 }
271
272 async fn do_get(&self, request: Request) -> Result<Output> {
273 let request = self.to_rpc_request(request);
274 let request = Ticket {
275 ticket: request.encode_to_vec().into(),
276 };
277
278 let mut client = self.client.make_flight_client()?;
279
280 let response = client.mut_inner().do_get(request).await.or_else(|e| {
281 let tonic_code = e.code();
282 let e: Error = e.into();
283 let code = e.status_code();
284 let msg = e.to_string();
285 let error =
286 Err(BoxedError::new(ServerSnafu { code, msg }.build())).with_context(|_| {
287 FlightGetSnafu {
288 addr: client.addr().to_string(),
289 tonic_code,
290 }
291 });
292 error!(
293 "Failed to do Flight get, addr: {}, code: {}, source: {:?}",
294 client.addr(),
295 tonic_code,
296 error
297 );
298 error
299 })?;
300
301 let flight_data_stream = response.into_inner();
302 let mut decoder = FlightDecoder::default();
303
304 let mut flight_message_stream = flight_data_stream.map(move |flight_data| {
305 flight_data
306 .map_err(Error::from)
307 .and_then(|data| decoder.try_decode(data).context(ConvertFlightDataSnafu))
308 });
309
310 let Some(first_flight_message) = flight_message_stream.next().await else {
311 return IllegalFlightMessagesSnafu {
312 reason: "Expect the response not to be empty",
313 }
314 .fail();
315 };
316
317 let first_flight_message = first_flight_message?;
318
319 match first_flight_message {
320 FlightMessage::AffectedRows(rows) => {
321 ensure!(
322 flight_message_stream.next().await.is_none(),
323 IllegalFlightMessagesSnafu {
324 reason: "Expect 'AffectedRows' Flight messages to be the one and the only!"
325 }
326 );
327 Ok(Output::new_with_affected_rows(rows))
328 }
329 FlightMessage::Recordbatch(_) | FlightMessage::Metrics(_) => {
330 IllegalFlightMessagesSnafu {
331 reason: "The first flight message cannot be a RecordBatch or Metrics message",
332 }
333 .fail()
334 }
335 FlightMessage::Schema(schema) => {
336 let stream = Box::pin(stream!({
337 while let Some(flight_message) = flight_message_stream.next().await {
338 let flight_message = flight_message
339 .map_err(BoxedError::new)
340 .context(ExternalSnafu)?;
341 match flight_message {
342 FlightMessage::Recordbatch(record_batch) => yield Ok(record_batch),
343 FlightMessage::Metrics(_) => {}
344 FlightMessage::AffectedRows(_) | FlightMessage::Schema(_) => {
345 yield IllegalFlightMessagesSnafu {reason: format!("A Schema message must be succeeded exclusively by a set of RecordBatch messages, flight_message: {:?}", flight_message)}
346 .fail()
347 .map_err(BoxedError::new)
348 .context(ExternalSnafu);
349 break;
350 }
351 }
352 }
353 }));
354 let record_batch_stream = RecordBatchStreamWrapper {
355 schema,
356 stream,
357 output_ordering: None,
358 metrics: Default::default(),
359 };
360 Ok(Output::new_with_stream(Box::pin(record_batch_stream)))
361 }
362 }
363 }
364
365 pub async fn do_put(&self, stream: FlightDataStream) -> Result<DoPutResponseStream> {
368 let mut request = tonic::Request::new(stream);
369
370 if let Some(AuthHeader {
371 auth_scheme: Some(AuthScheme::Basic(Basic { username, password })),
372 }) = &self.ctx.auth_header
373 {
374 let encoded = BASE64_STANDARD.encode(format!("{username}:{password}"));
375 let value =
376 MetadataValue::from_str(&encoded).context(InvalidTonicMetadataValueSnafu)?;
377 request.metadata_mut().insert("x-greptime-auth", value);
378 }
379
380 let db_to_put = if !self.dbname.is_empty() {
381 &self.dbname
382 } else {
383 &build_db_string(self.catalog_or_default(), self.schema_or_default())
384 };
385 request.metadata_mut().insert(
386 "x-greptime-db-name",
387 MetadataValue::from_str(db_to_put).context(InvalidTonicMetadataValueSnafu)?,
388 );
389
390 let mut client = self.client.make_flight_client()?;
391 let response = client.mut_inner().do_put(request).await?;
392 let response = response
393 .into_inner()
394 .map_err(Into::into)
395 .and_then(|x| future::ready(DoPutResponse::try_from(x).context(ConvertFlightDataSnafu)))
396 .boxed();
397 Ok(response)
398 }
399}
400
401pub fn is_grpc_retryable(err: &tonic::Status) -> bool {
403 matches!(err.code(), tonic::Code::Unavailable)
404}
405
/// Per-`Database` state attached to outgoing requests.
#[derive(Default, Debug, Clone)]
struct FlightContext {
    /// Authentication header copied into every `RequestHeader`. For a `Basic`
    /// scheme, `do_put` also sends it base64-encoded as `x-greptime-auth` metadata.
    auth_header: Option<AuthHeader>,
}
410
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use api::v1::auth_header::AuthScheme;
    use api::v1::{AuthHeader, Basic};

    use super::*;

    /// A fresh context has no auth header. After a Basic scheme is stored, it
    /// reads back as `AuthScheme::Basic`.
    #[test]
    fn test_flight_ctx() {
        let mut flight_ctx = FlightContext::default();
        assert!(flight_ctx.auth_header.is_none());

        let credentials = Basic {
            username: "u".to_string(),
            password: "p".to_string(),
        };
        flight_ctx.auth_header = Some(AuthHeader {
            auth_scheme: Some(AuthScheme::Basic(credentials)),
        });

        assert_matches!(
            flight_ctx.auth_header,
            Some(AuthHeader {
                auth_scheme: Some(AuthScheme::Basic(_)),
            })
        )
    }
}