1mod stream;
16
17use std::pin::Pin;
18use std::sync::Arc;
19use std::task::{Context, Poll};
20
21use api::v1::GreptimeRequest;
22use arrow_flight::flight_service_server::FlightService;
23use arrow_flight::{
24 Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
25 HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket,
26};
27use async_trait::async_trait;
28use bytes::Bytes;
29use common_grpc::flight::do_put::{DoPutMetadata, DoPutResponse};
30use common_grpc::flight::{FlightEncoder, FlightMessage};
31use common_query::{Output, OutputData};
32use common_telemetry::tracing::info_span;
33use common_telemetry::tracing_context::{FutureExt, TracingContext};
34use futures::{Stream, future, ready};
35use futures_util::{StreamExt, TryStreamExt};
36use prost::Message;
37use session::context::{QueryContext, QueryContextRef};
38use snafu::{ResultExt, ensure};
39use table::table_name::TableName;
40use tokio::sync::mpsc;
41use tokio_stream::wrappers::ReceiverStream;
42use tonic::{Request, Response, Status, Streaming};
43
44use crate::error::{InvalidParameterSnafu, ParseJsonSnafu, Result, ToJsonSnafu};
45pub use crate::grpc::flight::stream::FlightRecordBatchStream;
46use crate::grpc::greptime_handler::{GreptimeRequestHandler, get_request_type};
47use crate::grpc::{FlightCompression, TonicResult, context_auth};
48use crate::metrics::{METRIC_GRPC_MEMORY_USAGE_BYTES, METRIC_GRPC_REQUESTS_REJECTED_TOTAL};
49use crate::request_limiter::{RequestMemoryGuard, RequestMemoryLimiter};
50use crate::{error, hint_headers};
51
/// Boxed, pinned, `Send + 'static` stream of `TonicResult<T>` items; the stream type returned
/// by the Flight RPC handlers in this module.
pub type TonicStream<T> = Pin<Box<dyn Stream<Item = TonicResult<T>> + Send + 'static>>;
53
54#[async_trait]
56pub trait FlightCraft: Send + Sync + 'static {
57 async fn do_get(
58 &self,
59 request: Request<Ticket>,
60 ) -> TonicResult<Response<TonicStream<FlightData>>>;
61
62 async fn do_put(
63 &self,
64 request: Request<Streaming<FlightData>>,
65 ) -> TonicResult<Response<TonicStream<PutResult>>> {
66 let _ = request;
67 Err(Status::unimplemented("Not yet implemented"))
68 }
69}
70
/// Shared, type-erased [`FlightCraft`] handler.
pub type FlightCraftRef = Arc<dyn FlightCraft>;
72
73pub struct FlightCraftWrapper<T: FlightCraft>(pub T);
74
75impl<T: FlightCraft> From<T> for FlightCraftWrapper<T> {
76 fn from(t: T) -> Self {
77 Self(t)
78 }
79}
80
81#[async_trait]
82impl FlightCraft for FlightCraftRef {
83 async fn do_get(
84 &self,
85 request: Request<Ticket>,
86 ) -> TonicResult<Response<TonicStream<FlightData>>> {
87 (**self).do_get(request).await
88 }
89
90 async fn do_put(
91 &self,
92 request: Request<Streaming<FlightData>>,
93 ) -> TonicResult<Response<TonicStream<PutResult>>> {
94 self.as_ref().do_put(request).await
95 }
96}
97
98#[async_trait]
99impl<T: FlightCraft> FlightService for FlightCraftWrapper<T> {
100 type HandshakeStream = TonicStream<HandshakeResponse>;
101
102 async fn handshake(
103 &self,
104 _: Request<Streaming<HandshakeRequest>>,
105 ) -> TonicResult<Response<Self::HandshakeStream>> {
106 Err(Status::unimplemented("Not yet implemented"))
107 }
108
109 type ListFlightsStream = TonicStream<FlightInfo>;
110
111 async fn list_flights(
112 &self,
113 _: Request<Criteria>,
114 ) -> TonicResult<Response<Self::ListFlightsStream>> {
115 Err(Status::unimplemented("Not yet implemented"))
116 }
117
118 async fn get_flight_info(
119 &self,
120 _: Request<FlightDescriptor>,
121 ) -> TonicResult<Response<FlightInfo>> {
122 Err(Status::unimplemented("Not yet implemented"))
123 }
124
125 async fn poll_flight_info(
126 &self,
127 _: Request<FlightDescriptor>,
128 ) -> TonicResult<Response<PollInfo>> {
129 Err(Status::unimplemented("Not yet implemented"))
130 }
131
132 async fn get_schema(
133 &self,
134 _: Request<FlightDescriptor>,
135 ) -> TonicResult<Response<SchemaResult>> {
136 Err(Status::unimplemented("Not yet implemented"))
137 }
138
139 type DoGetStream = TonicStream<FlightData>;
140
141 async fn do_get(&self, request: Request<Ticket>) -> TonicResult<Response<Self::DoGetStream>> {
142 self.0.do_get(request).await
143 }
144
145 type DoPutStream = TonicStream<PutResult>;
146
147 async fn do_put(
148 &self,
149 request: Request<Streaming<FlightData>>,
150 ) -> TonicResult<Response<Self::DoPutStream>> {
151 self.0.do_put(request).await
152 }
153
154 type DoExchangeStream = TonicStream<FlightData>;
155
156 async fn do_exchange(
157 &self,
158 _: Request<Streaming<FlightData>>,
159 ) -> TonicResult<Response<Self::DoExchangeStream>> {
160 Err(Status::unimplemented("Not yet implemented"))
161 }
162
163 type DoActionStream = TonicStream<arrow_flight::Result>;
164
165 async fn do_action(&self, _: Request<Action>) -> TonicResult<Response<Self::DoActionStream>> {
166 Err(Status::unimplemented("Not yet implemented"))
167 }
168
169 type ListActionsStream = TonicStream<ActionType>;
170
171 async fn list_actions(
172 &self,
173 _: Request<Empty>,
174 ) -> TonicResult<Response<Self::ListActionsStream>> {
175 Err(Status::unimplemented("Not yet implemented"))
176 }
177}
178
#[async_trait]
impl FlightCraft for GreptimeRequestHandler {
    /// Handles an Arrow Flight `DoGet` call.
    ///
    /// The ticket bytes are decoded as a protobuf [`GreptimeRequest`] and executed through
    /// `handle_request`, together with hints read from the gRPC metadata. The resulting
    /// [`Output`] is then turned into a [`FlightData`] stream by `to_flight_data_stream`.
    ///
    /// # Errors
    /// Returns an error status if the ticket is not a valid `GreptimeRequest` or if
    /// `handle_request` fails.
    async fn do_get(
        &self,
        request: Request<Ticket>,
    ) -> TonicResult<Response<TonicStream<FlightData>>> {
        // Read the hints before `request` is consumed by `into_inner` below.
        let hints = hint_headers::extract_hints(request.metadata());

        let ticket = request.into_inner().ticket;
        let request =
            GreptimeRequest::decode(ticket.as_ref()).context(error::InvalidFlightTicketSnafu)?;

        // Span wrapping the execution of the request and the creation of its output stream.
        let span = info_span!(
            "GreptimeRequestHandler::do_get",
            protocol = "grpc",
            request_type = get_request_type(&request)
        );
        let flight_compression = self.flight_compression;
        async {
            let output = self.handle_request(request, hints).await?;
            let stream = to_flight_data_stream(
                output,
                // Evaluated inside the future traced with `span`, so presumably this captures
                // the `do_get` span for the output stream.
                TracingContext::from_current_span(),
                flight_compression,
                // NOTE(review): a fresh default `QueryContext` is used here rather than one
                // built from the request metadata — confirm this is intended.
                QueryContext::arc(),
            );
            Ok(Response::new(stream))
        }
        .trace(span)
        .await
    }

    /// Handles an Arrow Flight `DoPut` call, which streams record batches in for ingestion.
    ///
    /// The caller is authenticated from the gRPC metadata before any data is read. Incoming
    /// messages are wrapped in a [`PutRecordBatchRequestStream`] and handed to
    /// `put_record_batches`. Each [`DoPutResponse`] sent back on the channel is serialized to
    /// JSON and returned as the `app_metadata` of a [`PutResult`].
    ///
    /// # Errors
    /// Fails before streaming starts if the query context cannot be built from the metadata
    /// or if authentication fails. A response that cannot be serialized to JSON becomes an
    /// error item in the returned stream.
    async fn do_put(
        &self,
        request: Request<Streaming<FlightData>>,
    ) -> TonicResult<Response<TonicStream<PutResult>>> {
        let (headers, extensions, stream) = request.into_parts();

        // Optional memory limiter; `None` when none was attached to the request extensions.
        let limiter = extensions.get::<RequestMemoryLimiter>().cloned();

        let query_ctx = context_auth::create_query_context_from_grpc_metadata(&headers)?;
        context_auth::check_auth(self.user_provider.clone(), &headers, query_ctx.clone()).await?;

        // Capacity of the bounded response channel; senders wait once it is full.
        const MAX_PENDING_RESPONSES: usize = 32;
        let (tx, rx) = mpsc::channel::<TonicResult<DoPutResponse>>(MAX_PENDING_RESPONSES);

        // Table names in the flight descriptor are resolved against the session's current
        // catalog and schema.
        let stream = PutRecordBatchRequestStream {
            flight_data_stream: stream,
            state: PutRecordBatchRequestStreamState::Init(
                query_ctx.current_catalog().to_string(),
                query_ctx.current_schema(),
            ),
            limiter,
        };
        // NOTE(review): `rx` is only drained after this call returns and the response is handed
        // to tonic. `put_record_batches` therefore presumably spawns the ingestion work rather
        // than running it to completion here. Confirm this; otherwise, more than
        // MAX_PENDING_RESPONSES responses would block.
        self.put_record_batches(stream, tx, query_ctx).await;

        // Each `DoPutResponse` is sent to the client as JSON in `PutResult::app_metadata`.
        let response = ReceiverStream::new(rx)
            .and_then(|response| {
                future::ready({
                    serde_json::to_vec(&response)
                        .context(ToJsonSnafu)
                        .map(|x| PutResult {
                            app_metadata: Bytes::from(x),
                        })
                        .map_err(Into::into)
                })
            })
            .boxed();
        Ok(Response::new(response))
    }
}
251
/// One `DoPut` message, together with the table it targets.
pub(crate) struct PutRecordBatchRequest {
    /// Target table, taken from the flight descriptor of the stream's first message.
    pub(crate) table_name: TableName,
    /// Request id read from the message's JSON `app_metadata` (`DoPutMetadata`), or `0` when
    /// the metadata is empty.
    pub(crate) request_id: i64,
    /// The raw flight message.
    pub(crate) data: FlightData,
    /// Memory reservation obtained from the request limiter, kept alive with this request.
    /// It is `None` when no limiter is enabled or the limiter handed back no guard.
    /// Presumably it releases the reservation on drop — confirm in `RequestMemoryGuard`.
    pub(crate) _guard: Option<RequestMemoryGuard>,
}
258
259impl PutRecordBatchRequest {
260 fn try_new(
261 table_name: TableName,
262 flight_data: FlightData,
263 limiter: Option<&RequestMemoryLimiter>,
264 ) -> Result<Self> {
265 let request_id = if !flight_data.app_metadata.is_empty() {
266 let metadata: DoPutMetadata =
267 serde_json::from_slice(&flight_data.app_metadata).context(ParseJsonSnafu)?;
268 metadata.request_id()
269 } else {
270 0
271 };
272
273 let _guard = limiter
274 .filter(|limiter| limiter.is_enabled())
275 .map(|limiter| {
276 let message_size = flight_data.encoded_len();
277 limiter
278 .try_acquire(message_size)
279 .map(|guard| {
280 guard.inspect(|g| {
281 METRIC_GRPC_MEMORY_USAGE_BYTES.set(g.current_usage() as i64);
282 })
283 })
284 .inspect_err(|_| {
285 METRIC_GRPC_REQUESTS_REJECTED_TOTAL.inc();
286 })
287 })
288 .transpose()?
289 .flatten();
290
291 Ok(Self {
292 table_name,
293 request_id,
294 data: flight_data,
295 _guard,
296 })
297 }
298}
299
/// Adapts the raw `DoPut` [`FlightData`] stream into a stream of [`PutRecordBatchRequest`]s.
pub(crate) struct PutRecordBatchRequestStream {
    /// Messages sent by the client.
    flight_data_stream: Streaming<FlightData>,
    /// Whether the target table has been resolved yet.
    state: PutRecordBatchRequestStreamState,
    /// Optional limiter used to reserve memory for each message.
    limiter: Option<RequestMemoryLimiter>,
}

/// Progress of a [`PutRecordBatchRequestStream`].
enum PutRecordBatchRequestStreamState {
    /// No message accepted yet. Holds the catalog and schema used to qualify the table name
    /// found in the first message's flight descriptor.
    Init(String, String),
    /// The first message was accepted; every later message targets this table.
    Started(TableName),
}
310
311impl Stream for PutRecordBatchRequestStream {
312 type Item = TonicResult<PutRecordBatchRequest>;
313
314 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
315 fn extract_table_name(mut descriptor: FlightDescriptor) -> Result<String> {
316 ensure!(
317 descriptor.r#type == arrow_flight::flight_descriptor::DescriptorType::Path as i32,
318 InvalidParameterSnafu {
319 reason: "expect FlightDescriptor::type == 'Path' only",
320 }
321 );
322 ensure!(
323 descriptor.path.len() == 1,
324 InvalidParameterSnafu {
325 reason: "expect FlightDescriptor::path has only one table name",
326 }
327 );
328 Ok(descriptor.path.remove(0))
329 }
330
331 let poll = ready!(self.flight_data_stream.poll_next_unpin(cx));
332 let limiter = self.limiter.clone();
333
334 let result = match &mut self.state {
335 PutRecordBatchRequestStreamState::Init(catalog, schema) => match poll {
336 Some(Ok(mut flight_data)) => {
337 let flight_descriptor = flight_data.flight_descriptor.take();
338 let result = if let Some(descriptor) = flight_descriptor {
339 let table_name = extract_table_name(descriptor)
340 .map(|x| TableName::new(catalog.clone(), schema.clone(), x));
341 let table_name = match table_name {
342 Ok(table_name) => table_name,
343 Err(e) => return Poll::Ready(Some(Err(e.into()))),
344 };
345
346 let request = PutRecordBatchRequest::try_new(
347 table_name.clone(),
348 flight_data,
349 limiter.as_ref(),
350 );
351 let request = match request {
352 Ok(request) => request,
353 Err(e) => return Poll::Ready(Some(Err(e.into()))),
354 };
355
356 self.state = PutRecordBatchRequestStreamState::Started(table_name);
357
358 Ok(request)
359 } else {
360 Err(Status::failed_precondition(
361 "table to put is not found in flight descriptor",
362 ))
363 };
364 Some(result)
365 }
366 Some(Err(e)) => Some(Err(e)),
367 None => None,
368 },
369 PutRecordBatchRequestStreamState::Started(table_name) => poll.map(|x| {
370 x.and_then(|flight_data| {
371 PutRecordBatchRequest::try_new(
372 table_name.clone(),
373 flight_data,
374 limiter.as_ref(),
375 )
376 .map_err(Into::into)
377 })
378 }),
379 };
380 Poll::Ready(result)
381 }
382}
383
384fn to_flight_data_stream(
385 output: Output,
386 tracing_context: TracingContext,
387 flight_compression: FlightCompression,
388 query_ctx: QueryContextRef,
389) -> TonicStream<FlightData> {
390 match output.data {
391 OutputData::Stream(stream) => {
392 let stream = FlightRecordBatchStream::new(
393 stream,
394 tracing_context,
395 flight_compression,
396 query_ctx,
397 );
398 Box::pin(stream) as _
399 }
400 OutputData::RecordBatches(x) => {
401 let stream = FlightRecordBatchStream::new(
402 x.as_stream(),
403 tracing_context,
404 flight_compression,
405 query_ctx,
406 );
407 Box::pin(stream) as _
408 }
409 OutputData::AffectedRows(rows) => {
410 let stream = tokio_stream::iter(
411 FlightEncoder::default()
412 .encode(FlightMessage::AffectedRows(rows))
413 .into_iter()
414 .map(Ok),
415 );
416 Box::pin(stream) as _
417 }
418 }
419}