1mod stream;
16
17use std::pin::Pin;
18use std::sync::Arc;
19use std::task::{Context, Poll};
20
21use api::v1::GreptimeRequest;
22use arrow_flight::flight_service_server::FlightService;
23use arrow_flight::{
24 Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
25 HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket,
26};
27use async_trait::async_trait;
28use bytes::Bytes;
29use common_grpc::flight::do_put::{DoPutMetadata, DoPutResponse};
30use common_grpc::flight::{FlightEncoder, FlightMessage};
31use common_query::{Output, OutputData};
32use common_telemetry::tracing::info_span;
33use common_telemetry::tracing_context::{FutureExt, TracingContext};
34use futures::{Stream, future, ready};
35use futures_util::{StreamExt, TryStreamExt};
36use prost::Message;
37use session::context::{QueryContext, QueryContextRef};
38use snafu::{ResultExt, ensure};
39use table::table_name::TableName;
40use tokio::sync::mpsc;
41use tokio_stream::wrappers::ReceiverStream;
42use tonic::{Request, Response, Status, Streaming};
43
44use crate::error::{InvalidParameterSnafu, ParseJsonSnafu, Result, ToJsonSnafu};
45pub use crate::grpc::flight::stream::FlightRecordBatchStream;
46use crate::grpc::greptime_handler::{GreptimeRequestHandler, get_request_type};
47use crate::grpc::{FlightCompression, TonicResult, context_auth};
48use crate::{error, hint_headers};
49
50pub type TonicStream<T> = Pin<Box<dyn Stream<Item = TonicResult<T>> + Send + 'static>>;
51
52#[async_trait]
54pub trait FlightCraft: Send + Sync + 'static {
55 async fn do_get(
56 &self,
57 request: Request<Ticket>,
58 ) -> TonicResult<Response<TonicStream<FlightData>>>;
59
60 async fn do_put(
61 &self,
62 request: Request<Streaming<FlightData>>,
63 ) -> TonicResult<Response<TonicStream<PutResult>>> {
64 let _ = request;
65 Err(Status::unimplemented("Not yet implemented"))
66 }
67}
68
69pub type FlightCraftRef = Arc<dyn FlightCraft>;
70
71pub struct FlightCraftWrapper<T: FlightCraft>(pub T);
72
73impl<T: FlightCraft> From<T> for FlightCraftWrapper<T> {
74 fn from(t: T) -> Self {
75 Self(t)
76 }
77}
78
79#[async_trait]
80impl FlightCraft for FlightCraftRef {
81 async fn do_get(
82 &self,
83 request: Request<Ticket>,
84 ) -> TonicResult<Response<TonicStream<FlightData>>> {
85 (**self).do_get(request).await
86 }
87
88 async fn do_put(
89 &self,
90 request: Request<Streaming<FlightData>>,
91 ) -> TonicResult<Response<TonicStream<PutResult>>> {
92 self.as_ref().do_put(request).await
93 }
94}
95
96#[async_trait]
97impl<T: FlightCraft> FlightService for FlightCraftWrapper<T> {
98 type HandshakeStream = TonicStream<HandshakeResponse>;
99
100 async fn handshake(
101 &self,
102 _: Request<Streaming<HandshakeRequest>>,
103 ) -> TonicResult<Response<Self::HandshakeStream>> {
104 Err(Status::unimplemented("Not yet implemented"))
105 }
106
107 type ListFlightsStream = TonicStream<FlightInfo>;
108
109 async fn list_flights(
110 &self,
111 _: Request<Criteria>,
112 ) -> TonicResult<Response<Self::ListFlightsStream>> {
113 Err(Status::unimplemented("Not yet implemented"))
114 }
115
116 async fn get_flight_info(
117 &self,
118 _: Request<FlightDescriptor>,
119 ) -> TonicResult<Response<FlightInfo>> {
120 Err(Status::unimplemented("Not yet implemented"))
121 }
122
123 async fn poll_flight_info(
124 &self,
125 _: Request<FlightDescriptor>,
126 ) -> TonicResult<Response<PollInfo>> {
127 Err(Status::unimplemented("Not yet implemented"))
128 }
129
130 async fn get_schema(
131 &self,
132 _: Request<FlightDescriptor>,
133 ) -> TonicResult<Response<SchemaResult>> {
134 Err(Status::unimplemented("Not yet implemented"))
135 }
136
137 type DoGetStream = TonicStream<FlightData>;
138
139 async fn do_get(&self, request: Request<Ticket>) -> TonicResult<Response<Self::DoGetStream>> {
140 self.0.do_get(request).await
141 }
142
143 type DoPutStream = TonicStream<PutResult>;
144
145 async fn do_put(
146 &self,
147 request: Request<Streaming<FlightData>>,
148 ) -> TonicResult<Response<Self::DoPutStream>> {
149 self.0.do_put(request).await
150 }
151
152 type DoExchangeStream = TonicStream<FlightData>;
153
154 async fn do_exchange(
155 &self,
156 _: Request<Streaming<FlightData>>,
157 ) -> TonicResult<Response<Self::DoExchangeStream>> {
158 Err(Status::unimplemented("Not yet implemented"))
159 }
160
161 type DoActionStream = TonicStream<arrow_flight::Result>;
162
163 async fn do_action(&self, _: Request<Action>) -> TonicResult<Response<Self::DoActionStream>> {
164 Err(Status::unimplemented("Not yet implemented"))
165 }
166
167 type ListActionsStream = TonicStream<ActionType>;
168
169 async fn list_actions(
170 &self,
171 _: Request<Empty>,
172 ) -> TonicResult<Response<Self::ListActionsStream>> {
173 Err(Status::unimplemented("Not yet implemented"))
174 }
175}
176
177#[async_trait]
178impl FlightCraft for GreptimeRequestHandler {
179 async fn do_get(
180 &self,
181 request: Request<Ticket>,
182 ) -> TonicResult<Response<TonicStream<FlightData>>> {
183 let hints = hint_headers::extract_hints(request.metadata());
184
185 let ticket = request.into_inner().ticket;
186 let request =
187 GreptimeRequest::decode(ticket.as_ref()).context(error::InvalidFlightTicketSnafu)?;
188
189 let span = info_span!(
191 "GreptimeRequestHandler::do_get",
192 protocol = "grpc",
193 request_type = get_request_type(&request)
194 );
195 let flight_compression = self.flight_compression;
196 async {
197 let output = self.handle_request(request, hints).await?;
198 let stream = to_flight_data_stream(
199 output,
200 TracingContext::from_current_span(),
201 flight_compression,
202 QueryContext::arc(),
203 );
204 Ok(Response::new(stream))
205 }
206 .trace(span)
207 .await
208 }
209
210 async fn do_put(
211 &self,
212 request: Request<Streaming<FlightData>>,
213 ) -> TonicResult<Response<TonicStream<PutResult>>> {
214 let (headers, _, stream) = request.into_parts();
215
216 let query_ctx = context_auth::create_query_context_from_grpc_metadata(&headers)?;
217 context_auth::check_auth(self.user_provider.clone(), &headers, query_ctx.clone()).await?;
218
219 const MAX_PENDING_RESPONSES: usize = 32;
220 let (tx, rx) = mpsc::channel::<TonicResult<DoPutResponse>>(MAX_PENDING_RESPONSES);
221
222 let stream = PutRecordBatchRequestStream {
223 flight_data_stream: stream,
224 state: PutRecordBatchRequestStreamState::Init(
225 query_ctx.current_catalog().to_string(),
226 query_ctx.current_schema(),
227 ),
228 };
229 self.put_record_batches(stream, tx, query_ctx).await;
230
231 let response = ReceiverStream::new(rx)
232 .and_then(|response| {
233 future::ready({
234 serde_json::to_vec(&response)
235 .context(ToJsonSnafu)
236 .map(|x| PutResult {
237 app_metadata: Bytes::from(x),
238 })
239 .map_err(Into::into)
240 })
241 })
242 .boxed();
243 Ok(Response::new(response))
244 }
245}
246
247pub(crate) struct PutRecordBatchRequest {
248 pub(crate) table_name: TableName,
249 pub(crate) request_id: i64,
250 pub(crate) data: FlightData,
251}
252
253impl PutRecordBatchRequest {
254 fn try_new(table_name: TableName, flight_data: FlightData) -> Result<Self> {
255 let request_id = if !flight_data.app_metadata.is_empty() {
256 let metadata: DoPutMetadata =
257 serde_json::from_slice(&flight_data.app_metadata).context(ParseJsonSnafu)?;
258 metadata.request_id()
259 } else {
260 0
261 };
262 Ok(Self {
263 table_name,
264 request_id,
265 data: flight_data,
266 })
267 }
268}
269
270pub(crate) struct PutRecordBatchRequestStream {
271 flight_data_stream: Streaming<FlightData>,
272 state: PutRecordBatchRequestStreamState,
273}
274
275enum PutRecordBatchRequestStreamState {
276 Init(String, String),
277 Started(TableName),
278}
279
280impl Stream for PutRecordBatchRequestStream {
281 type Item = TonicResult<PutRecordBatchRequest>;
282
283 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
284 fn extract_table_name(mut descriptor: FlightDescriptor) -> Result<String> {
285 ensure!(
286 descriptor.r#type == arrow_flight::flight_descriptor::DescriptorType::Path as i32,
287 InvalidParameterSnafu {
288 reason: "expect FlightDescriptor::type == 'Path' only",
289 }
290 );
291 ensure!(
292 descriptor.path.len() == 1,
293 InvalidParameterSnafu {
294 reason: "expect FlightDescriptor::path has only one table name",
295 }
296 );
297 Ok(descriptor.path.remove(0))
298 }
299
300 let poll = ready!(self.flight_data_stream.poll_next_unpin(cx));
301
302 let result = match &mut self.state {
303 PutRecordBatchRequestStreamState::Init(catalog, schema) => match poll {
304 Some(Ok(mut flight_data)) => {
305 let flight_descriptor = flight_data.flight_descriptor.take();
306 let result = if let Some(descriptor) = flight_descriptor {
307 let table_name = extract_table_name(descriptor)
308 .map(|x| TableName::new(catalog.clone(), schema.clone(), x));
309 let table_name = match table_name {
310 Ok(table_name) => table_name,
311 Err(e) => return Poll::Ready(Some(Err(e.into()))),
312 };
313
314 let request =
315 PutRecordBatchRequest::try_new(table_name.clone(), flight_data);
316 let request = match request {
317 Ok(request) => request,
318 Err(e) => return Poll::Ready(Some(Err(e.into()))),
319 };
320
321 self.state = PutRecordBatchRequestStreamState::Started(table_name);
322
323 Ok(request)
324 } else {
325 Err(Status::failed_precondition(
326 "table to put is not found in flight descriptor",
327 ))
328 };
329 Some(result)
330 }
331 Some(Err(e)) => Some(Err(e)),
332 None => None,
333 },
334 PutRecordBatchRequestStreamState::Started(table_name) => poll.map(|x| {
335 x.and_then(|flight_data| {
336 PutRecordBatchRequest::try_new(table_name.clone(), flight_data)
337 .map_err(Into::into)
338 })
339 }),
340 };
341 Poll::Ready(result)
342 }
343}
344
345fn to_flight_data_stream(
346 output: Output,
347 tracing_context: TracingContext,
348 flight_compression: FlightCompression,
349 query_ctx: QueryContextRef,
350) -> TonicStream<FlightData> {
351 match output.data {
352 OutputData::Stream(stream) => {
353 let stream = FlightRecordBatchStream::new(
354 stream,
355 tracing_context,
356 flight_compression,
357 query_ctx,
358 );
359 Box::pin(stream) as _
360 }
361 OutputData::RecordBatches(x) => {
362 let stream = FlightRecordBatchStream::new(
363 x.as_stream(),
364 tracing_context,
365 flight_compression,
366 query_ctx,
367 );
368 Box::pin(stream) as _
369 }
370 OutputData::AffectedRows(rows) => {
371 let stream = tokio_stream::iter(
372 FlightEncoder::default()
373 .encode(FlightMessage::AffectedRows(rows))
374 .into_iter()
375 .map(Ok),
376 );
377 Box::pin(stream) as _
378 }
379 }
380}