1mod stream;
16
17use std::pin::Pin;
18use std::sync::Arc;
19use std::task::{Context, Poll};
20
21use api::v1::GreptimeRequest;
22use arrow_flight::flight_service_server::FlightService;
23use arrow_flight::{
24 Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
25 HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket,
26};
27use async_trait::async_trait;
28use bytes::Bytes;
29use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
30use common_catalog::parse_catalog_and_schema_from_db_string;
31use common_grpc::flight::do_put::{DoPutMetadata, DoPutResponse};
32use common_grpc::flight::{FlightEncoder, FlightMessage};
33use common_query::{Output, OutputData};
34use common_telemetry::tracing::info_span;
35use common_telemetry::tracing_context::{FutureExt, TracingContext};
36use futures::{future, ready, Stream};
37use futures_util::{StreamExt, TryStreamExt};
38use prost::Message;
39use snafu::{ensure, ResultExt};
40use table::table_name::TableName;
41use tokio::sync::mpsc;
42use tokio_stream::wrappers::ReceiverStream;
43use tonic::{Request, Response, Status, Streaming};
44
45use crate::error::{InvalidParameterSnafu, ParseJsonSnafu, Result, ToJsonSnafu};
46pub use crate::grpc::flight::stream::FlightRecordBatchStream;
47use crate::grpc::greptime_handler::{get_request_type, GreptimeRequestHandler};
48use crate::grpc::{FlightCompression, TonicResult};
49use crate::http::header::constants::GREPTIME_DB_HEADER_NAME;
50use crate::http::AUTHORIZATION_HEADER;
51use crate::{error, hint_headers};
52
53pub type TonicStream<T> = Pin<Box<dyn Stream<Item = TonicResult<T>> + Send + 'static>>;
54
55#[async_trait]
57pub trait FlightCraft: Send + Sync + 'static {
58 async fn do_get(
59 &self,
60 request: Request<Ticket>,
61 ) -> TonicResult<Response<TonicStream<FlightData>>>;
62
63 async fn do_put(
64 &self,
65 request: Request<Streaming<FlightData>>,
66 ) -> TonicResult<Response<TonicStream<PutResult>>> {
67 let _ = request;
68 Err(Status::unimplemented("Not yet implemented"))
69 }
70}
71
72pub type FlightCraftRef = Arc<dyn FlightCraft>;
73
74pub struct FlightCraftWrapper<T: FlightCraft>(pub T);
75
76impl<T: FlightCraft> From<T> for FlightCraftWrapper<T> {
77 fn from(t: T) -> Self {
78 Self(t)
79 }
80}
81
82#[async_trait]
83impl FlightCraft for FlightCraftRef {
84 async fn do_get(
85 &self,
86 request: Request<Ticket>,
87 ) -> TonicResult<Response<TonicStream<FlightData>>> {
88 (**self).do_get(request).await
89 }
90
91 async fn do_put(
92 &self,
93 request: Request<Streaming<FlightData>>,
94 ) -> TonicResult<Response<TonicStream<PutResult>>> {
95 self.as_ref().do_put(request).await
96 }
97}
98
99#[async_trait]
100impl<T: FlightCraft> FlightService for FlightCraftWrapper<T> {
101 type HandshakeStream = TonicStream<HandshakeResponse>;
102
103 async fn handshake(
104 &self,
105 _: Request<Streaming<HandshakeRequest>>,
106 ) -> TonicResult<Response<Self::HandshakeStream>> {
107 Err(Status::unimplemented("Not yet implemented"))
108 }
109
110 type ListFlightsStream = TonicStream<FlightInfo>;
111
112 async fn list_flights(
113 &self,
114 _: Request<Criteria>,
115 ) -> TonicResult<Response<Self::ListFlightsStream>> {
116 Err(Status::unimplemented("Not yet implemented"))
117 }
118
119 async fn get_flight_info(
120 &self,
121 _: Request<FlightDescriptor>,
122 ) -> TonicResult<Response<FlightInfo>> {
123 Err(Status::unimplemented("Not yet implemented"))
124 }
125
126 async fn poll_flight_info(
127 &self,
128 _: Request<FlightDescriptor>,
129 ) -> TonicResult<Response<PollInfo>> {
130 Err(Status::unimplemented("Not yet implemented"))
131 }
132
133 async fn get_schema(
134 &self,
135 _: Request<FlightDescriptor>,
136 ) -> TonicResult<Response<SchemaResult>> {
137 Err(Status::unimplemented("Not yet implemented"))
138 }
139
140 type DoGetStream = TonicStream<FlightData>;
141
142 async fn do_get(&self, request: Request<Ticket>) -> TonicResult<Response<Self::DoGetStream>> {
143 self.0.do_get(request).await
144 }
145
146 type DoPutStream = TonicStream<PutResult>;
147
148 async fn do_put(
149 &self,
150 request: Request<Streaming<FlightData>>,
151 ) -> TonicResult<Response<Self::DoPutStream>> {
152 self.0.do_put(request).await
153 }
154
155 type DoExchangeStream = TonicStream<FlightData>;
156
157 async fn do_exchange(
158 &self,
159 _: Request<Streaming<FlightData>>,
160 ) -> TonicResult<Response<Self::DoExchangeStream>> {
161 Err(Status::unimplemented("Not yet implemented"))
162 }
163
164 type DoActionStream = TonicStream<arrow_flight::Result>;
165
166 async fn do_action(&self, _: Request<Action>) -> TonicResult<Response<Self::DoActionStream>> {
167 Err(Status::unimplemented("Not yet implemented"))
168 }
169
170 type ListActionsStream = TonicStream<ActionType>;
171
172 async fn list_actions(
173 &self,
174 _: Request<Empty>,
175 ) -> TonicResult<Response<Self::ListActionsStream>> {
176 Err(Status::unimplemented("Not yet implemented"))
177 }
178}
179
/// Arrow Flight entry points backed by the [`GreptimeRequestHandler`].
#[async_trait]
impl FlightCraft for GreptimeRequestHandler {
    /// Handles a Flight `DoGet`.
    ///
    /// The ticket bytes are decoded as a protobuf [`GreptimeRequest`]. The
    /// request is executed, and its output is returned as a [`FlightData`]
    /// stream encoded with this handler's `flight_compression` setting.
    ///
    /// A ticket that fails to decode yields an `InvalidFlightTicket` error.
    /// Errors from `handle_request` are propagated with `?`.
    async fn do_get(
        &self,
        request: Request<Ticket>,
    ) -> TonicResult<Response<TonicStream<FlightData>>> {
        // Read hints from the gRPC metadata before `into_inner` consumes the request.
        let hints = hint_headers::extract_hints(request.metadata());

        let ticket = request.into_inner().ticket;
        let request =
            GreptimeRequest::decode(ticket.as_ref()).context(error::InvalidFlightTicketSnafu)?;

        // Per-request span, tagged with the protocol and the kind of request decoded.
        let span = info_span!(
            "GreptimeRequestHandler::do_get",
            protocol = "grpc",
            request_type = get_request_type(&request)
        );
        let flight_compression = self.flight_compression;
        async {
            let output = self.handle_request(request, hints).await?;
            // Captured while the future is being traced, so the context presumably
            // reflects `span` (depends on `FutureExt::trace` semantics — confirm).
            let stream = to_flight_data_stream(
                output,
                TracingContext::from_current_span(),
                flight_compression,
            );
            Ok(Response::new(stream))
        }
        .trace(span)
        .await
    }

    /// Handles a Flight `DoPut`.
    ///
    /// The caller is checked with `validate_auth` using the authorization header
    /// and the database header. If either header is present, it must be valid
    /// UTF-8. Failed validation returns `Status::unauthenticated("auth failed")`.
    ///
    /// Incoming messages are wrapped in a [`PutRecordBatchRequestStream`] and
    /// handed to `put_record_batches`. Each [`DoPutResponse`] it sends back is
    /// JSON-encoded into a [`PutResult`]'s `app_metadata`.
    async fn do_put(
        &self,
        request: Request<Streaming<FlightData>>,
    ) -> TonicResult<Response<TonicStream<PutResult>>> {
        let (headers, _, stream) = request.into_parts();

        // Looks up a header. Yields `Ok(None)` when absent and an
        // `InvalidParameter` error when the value is not valid UTF-8.
        let header = |key: &str| -> TonicResult<Option<&str>> {
            let Some(v) = headers.get(key) else {
                return Ok(None);
            };
            let Ok(v) = std::str::from_utf8(v.as_bytes()) else {
                return Err(InvalidParameterSnafu {
                    reason: "expect valid UTF-8 value",
                }
                .build()
                .into());
            };
            Ok(Some(v))
        };

        let username_and_password = header(AUTHORIZATION_HEADER)?;
        let db = header(GREPTIME_DB_HEADER_NAME)?;
        if !self.validate_auth(username_and_password, db).await? {
            return Err(Status::unauthenticated("auth failed"));
        }

        // Bounded channel: at most 32 responses are buffered before the sender waits.
        const MAX_PENDING_RESPONSES: usize = 32;
        let (tx, rx) = mpsc::channel::<TonicResult<DoPutResponse>>(MAX_PENDING_RESPONSES);

        // The database header, if any, decides the catalog/schema of the target table.
        let stream = PutRecordBatchRequestStream {
            flight_data_stream: stream,
            state: PutRecordBatchRequestStreamState::Init(db.map(ToString::to_string)),
        };
        // NOTE(review): this is awaited before `rx` is handed back to the client.
        // It assumes `put_record_batches` hands the work off (e.g. spawns a task)
        // rather than draining the stream inline. Otherwise, more than
        // MAX_PENDING_RESPONSES responses would block here. Confirm against its
        // implementation.
        self.put_record_batches(stream, tx).await;

        // Serialize each response to JSON and carry it in `PutResult::app_metadata`.
        let response = ReceiverStream::new(rx)
            .and_then(|response| {
                future::ready({
                    serde_json::to_vec(&response)
                        .context(ToJsonSnafu)
                        .map(|x| PutResult {
                            app_metadata: Bytes::from(x),
                        })
                        .map_err(Into::into)
                })
            })
            .boxed();
        Ok(Response::new(response))
    }
}
262
263pub(crate) struct PutRecordBatchRequest {
264 pub(crate) table_name: TableName,
265 pub(crate) request_id: i64,
266 pub(crate) data: FlightData,
267}
268
269impl PutRecordBatchRequest {
270 fn try_new(table_name: TableName, flight_data: FlightData) -> Result<Self> {
271 let request_id = if !flight_data.app_metadata.is_empty() {
272 let metadata: DoPutMetadata =
273 serde_json::from_slice(&flight_data.app_metadata).context(ParseJsonSnafu)?;
274 metadata.request_id()
275 } else {
276 0
277 };
278 Ok(Self {
279 table_name,
280 request_id,
281 data: flight_data,
282 })
283 }
284}
285
286pub(crate) struct PutRecordBatchRequestStream {
287 flight_data_stream: Streaming<FlightData>,
288 state: PutRecordBatchRequestStreamState,
289}
290
291enum PutRecordBatchRequestStreamState {
292 Init(Option<String>),
293 Started(TableName),
294}
295
296impl Stream for PutRecordBatchRequestStream {
297 type Item = TonicResult<PutRecordBatchRequest>;
298
299 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
300 fn extract_table_name(mut descriptor: FlightDescriptor) -> Result<String> {
301 ensure!(
302 descriptor.r#type == arrow_flight::flight_descriptor::DescriptorType::Path as i32,
303 InvalidParameterSnafu {
304 reason: "expect FlightDescriptor::type == 'Path' only",
305 }
306 );
307 ensure!(
308 descriptor.path.len() == 1,
309 InvalidParameterSnafu {
310 reason: "expect FlightDescriptor::path has only one table name",
311 }
312 );
313 Ok(descriptor.path.remove(0))
314 }
315
316 let poll = ready!(self.flight_data_stream.poll_next_unpin(cx));
317
318 let result = match &mut self.state {
319 PutRecordBatchRequestStreamState::Init(db) => match poll {
320 Some(Ok(mut flight_data)) => {
321 let flight_descriptor = flight_data.flight_descriptor.take();
322 let result = if let Some(descriptor) = flight_descriptor {
323 let table_name = extract_table_name(descriptor).map(|x| {
324 let (catalog, schema) = if let Some(db) = db {
325 parse_catalog_and_schema_from_db_string(db)
326 } else {
327 (
328 DEFAULT_CATALOG_NAME.to_string(),
329 DEFAULT_SCHEMA_NAME.to_string(),
330 )
331 };
332 TableName::new(catalog, schema, x)
333 });
334 let table_name = match table_name {
335 Ok(table_name) => table_name,
336 Err(e) => return Poll::Ready(Some(Err(e.into()))),
337 };
338
339 let request =
340 PutRecordBatchRequest::try_new(table_name.clone(), flight_data);
341 let request = match request {
342 Ok(request) => request,
343 Err(e) => return Poll::Ready(Some(Err(e.into()))),
344 };
345
346 self.state = PutRecordBatchRequestStreamState::Started(table_name);
347
348 Ok(request)
349 } else {
350 Err(Status::failed_precondition(
351 "table to put is not found in flight descriptor",
352 ))
353 };
354 Some(result)
355 }
356 Some(Err(e)) => Some(Err(e)),
357 None => None,
358 },
359 PutRecordBatchRequestStreamState::Started(table_name) => poll.map(|x| {
360 x.and_then(|flight_data| {
361 PutRecordBatchRequest::try_new(table_name.clone(), flight_data)
362 .map_err(Into::into)
363 })
364 }),
365 };
366 Poll::Ready(result)
367 }
368}
369
370fn to_flight_data_stream(
371 output: Output,
372 tracing_context: TracingContext,
373 flight_compression: FlightCompression,
374) -> TonicStream<FlightData> {
375 match output.data {
376 OutputData::Stream(stream) => {
377 let stream = FlightRecordBatchStream::new(stream, tracing_context, flight_compression);
378 Box::pin(stream) as _
379 }
380 OutputData::RecordBatches(x) => {
381 let stream =
382 FlightRecordBatchStream::new(x.as_stream(), tracing_context, flight_compression);
383 Box::pin(stream) as _
384 }
385 OutputData::AffectedRows(rows) => {
386 let stream = tokio_stream::once(Ok(
387 FlightEncoder::default().encode(FlightMessage::AffectedRows(rows))
388 ));
389 Box::pin(stream) as _
390 }
391 }
392}