1mod stream;
16
17use std::pin::Pin;
18use std::sync::Arc;
19use std::task::{Context, Poll};
20
21use api::v1::GreptimeRequest;
22use arrow_flight::flight_service_server::FlightService;
23use arrow_flight::{
24 Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
25 HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket,
26};
27use async_trait::async_trait;
28use bytes;
29use bytes::Bytes;
30use common_grpc::flight::do_put::{DoPutMetadata, DoPutResponse};
31use common_grpc::flight::{FlightDecoder, FlightEncoder, FlightMessage};
32use common_memory_manager::MemoryGuard;
33use common_query::{Output, OutputData};
34use common_recordbatch::DfRecordBatch;
35use common_telemetry::debug;
36use common_telemetry::tracing::info_span;
37use common_telemetry::tracing_context::{FutureExt, TracingContext};
38use datatypes::arrow::datatypes::SchemaRef;
39use futures::{Stream, future, ready};
40use futures_util::{StreamExt, TryStreamExt};
41use prost::Message;
42use session::context::{QueryContext, QueryContextRef};
43use snafu::{IntoError, ResultExt, ensure};
44use table::table_name::TableName;
45use tokio::sync::mpsc;
46use tokio_stream::wrappers::ReceiverStream;
47use tonic::{Request, Response, Status, Streaming};
48
49use crate::error::{InvalidParameterSnafu, Result, ToJsonSnafu};
50pub use crate::grpc::flight::stream::FlightRecordBatchStream;
51use crate::grpc::greptime_handler::{GreptimeRequestHandler, get_request_type};
52use crate::grpc::{FlightCompression, TonicResult, context_auth};
53use crate::request_memory_limiter::ServerMemoryLimiter;
54use crate::request_memory_metrics::RequestMemoryMetrics;
55use crate::{error, hint_headers};
56
/// A pinned, heap-allocated, `Send + 'static` stream whose items are `TonicResult<T>`.
///
/// Used in this file as the stream type for every streaming Flight response.
pub type TonicStream<T> = Pin<Box<dyn Stream<Item = TonicResult<T>> + Send + 'static>>;
58
#[async_trait]
/// The Flight operations an implementor provides; [`FlightCraftWrapper`] exposes them as a
/// full `FlightService`.
///
/// `do_get` is required. `do_put` is optional: its default implementation rejects the call
/// with an `unimplemented` status.
pub trait FlightCraft: Send + Sync + 'static {
    /// Handles a Flight `DoGet` call for the given ticket and returns a stream of
    /// `FlightData` messages.
    async fn do_get(
        &self,
        request: Request<Ticket>,
    ) -> TonicResult<Response<TonicStream<FlightData>>>;

    /// Handles a Flight `DoPut` call over a stream of incoming `FlightData` messages and
    /// returns a stream of `PutResult` messages.
    ///
    /// The default implementation ignores the request and returns
    /// `Status::unimplemented("Not yet implemented")`.
    async fn do_put(
        &self,
        request: Request<Streaming<FlightData>>,
    ) -> TonicResult<Response<TonicStream<PutResult>>> {
        // The default implementation has no use for the request; this marks it as
        // intentionally unused while keeping the parameter name.
        let _ = request;
        Err(Status::unimplemented("Not yet implemented"))
    }
}
75
/// A shared, reference-counted handle to a [`FlightCraft`] trait object.
pub type FlightCraftRef = Arc<dyn FlightCraft>;
77
/// Newtype around a [`FlightCraft`] implementor.
///
/// This file implements `FlightService` for the wrapper by forwarding `do_get` and `do_put`
/// to the wrapped value and answering every other RPC with an `unimplemented` status.
pub struct FlightCraftWrapper<T: FlightCraft>(pub T);
79
80impl<T: FlightCraft> From<T> for FlightCraftWrapper<T> {
81 fn from(t: T) -> Self {
82 Self(t)
83 }
84}
85
86#[async_trait]
87impl FlightCraft for FlightCraftRef {
88 async fn do_get(
89 &self,
90 request: Request<Ticket>,
91 ) -> TonicResult<Response<TonicStream<FlightData>>> {
92 (**self).do_get(request).await
93 }
94
95 async fn do_put(
96 &self,
97 request: Request<Streaming<FlightData>>,
98 ) -> TonicResult<Response<TonicStream<PutResult>>> {
99 self.as_ref().do_put(request).await
100 }
101}
102
103#[async_trait]
104impl<T: FlightCraft> FlightService for FlightCraftWrapper<T> {
105 type HandshakeStream = TonicStream<HandshakeResponse>;
106
107 async fn handshake(
108 &self,
109 _: Request<Streaming<HandshakeRequest>>,
110 ) -> TonicResult<Response<Self::HandshakeStream>> {
111 Err(Status::unimplemented("Not yet implemented"))
112 }
113
114 type ListFlightsStream = TonicStream<FlightInfo>;
115
116 async fn list_flights(
117 &self,
118 _: Request<Criteria>,
119 ) -> TonicResult<Response<Self::ListFlightsStream>> {
120 Err(Status::unimplemented("Not yet implemented"))
121 }
122
123 async fn get_flight_info(
124 &self,
125 _: Request<FlightDescriptor>,
126 ) -> TonicResult<Response<FlightInfo>> {
127 Err(Status::unimplemented("Not yet implemented"))
128 }
129
130 async fn poll_flight_info(
131 &self,
132 _: Request<FlightDescriptor>,
133 ) -> TonicResult<Response<PollInfo>> {
134 Err(Status::unimplemented("Not yet implemented"))
135 }
136
137 async fn get_schema(
138 &self,
139 _: Request<FlightDescriptor>,
140 ) -> TonicResult<Response<SchemaResult>> {
141 Err(Status::unimplemented("Not yet implemented"))
142 }
143
144 type DoGetStream = TonicStream<FlightData>;
145
146 async fn do_get(&self, request: Request<Ticket>) -> TonicResult<Response<Self::DoGetStream>> {
147 self.0.do_get(request).await
148 }
149
150 type DoPutStream = TonicStream<PutResult>;
151
152 async fn do_put(
153 &self,
154 request: Request<Streaming<FlightData>>,
155 ) -> TonicResult<Response<Self::DoPutStream>> {
156 self.0.do_put(request).await
157 }
158
159 type DoExchangeStream = TonicStream<FlightData>;
160
161 async fn do_exchange(
162 &self,
163 _: Request<Streaming<FlightData>>,
164 ) -> TonicResult<Response<Self::DoExchangeStream>> {
165 Err(Status::unimplemented("Not yet implemented"))
166 }
167
168 type DoActionStream = TonicStream<arrow_flight::Result>;
169
170 async fn do_action(&self, _: Request<Action>) -> TonicResult<Response<Self::DoActionStream>> {
171 Err(Status::unimplemented("Not yet implemented"))
172 }
173
174 type ListActionsStream = TonicStream<ActionType>;
175
176 async fn list_actions(
177 &self,
178 _: Request<Empty>,
179 ) -> TonicResult<Response<Self::ListActionsStream>> {
180 Err(Status::unimplemented("Not yet implemented"))
181 }
182}
183
184#[async_trait]
185impl FlightCraft for GreptimeRequestHandler {
186 async fn do_get(
187 &self,
188 request: Request<Ticket>,
189 ) -> TonicResult<Response<TonicStream<FlightData>>> {
190 let hints = hint_headers::extract_hints(request.metadata());
191
192 let ticket = request.into_inner().ticket;
193 let request =
194 GreptimeRequest::decode(ticket.as_ref()).context(error::InvalidFlightTicketSnafu)?;
195
196 let span = info_span!(
198 "GreptimeRequestHandler::do_get",
199 protocol = "grpc",
200 request_type = get_request_type(&request)
201 );
202 let flight_compression = self.flight_compression;
203 async {
204 let output = self.handle_request(request, hints).await?;
205 let stream = to_flight_data_stream(
206 output,
207 TracingContext::from_current_span(),
208 flight_compression,
209 QueryContext::arc(),
210 );
211 Ok(Response::new(stream))
212 }
213 .trace(span)
214 .await
215 }
216
217 async fn do_put(
218 &self,
219 request: Request<Streaming<FlightData>>,
220 ) -> TonicResult<Response<TonicStream<PutResult>>> {
221 let (headers, extensions, stream) = request.into_parts();
222
223 let limiter = extensions.get::<ServerMemoryLimiter>().cloned();
224
225 let query_ctx = context_auth::create_query_context_from_grpc_metadata(&headers)?;
226 context_auth::check_auth(self.user_provider.clone(), &headers, query_ctx.clone()).await?;
227
228 const MAX_PENDING_RESPONSES: usize = 32;
229 let (tx, rx) = mpsc::channel::<TonicResult<DoPutResponse>>(MAX_PENDING_RESPONSES);
230
231 let stream = PutRecordBatchRequestStream::new(
232 stream,
233 query_ctx.current_catalog().to_string(),
234 query_ctx.current_schema(),
235 limiter,
236 )
237 .await?;
238 let _ = tx.send(Ok(DoPutResponse::new(0, 0, 0.0))).await;
240 self.put_record_batches(stream, tx, query_ctx).await;
241
242 let response = ReceiverStream::new(rx)
243 .and_then(|response| {
244 future::ready({
245 serde_json::to_vec(&response)
246 .context(ToJsonSnafu)
247 .map(|x| PutResult {
248 app_metadata: Bytes::from(x),
249 })
250 .map_err(Into::into)
251 })
252 })
253 .boxed();
254 Ok(Response::new(response))
255 }
256}
257
/// One decoded record batch from a Flight `DoPut` stream, together with the table it
/// targets and the raw message it was decoded from.
pub struct PutRecordBatchRequest {
    /// Table the batch is written to.
    pub table_name: TableName,
    /// Request id taken from the message's `app_metadata`; `0` when the metadata is empty
    /// or cannot be parsed as `DoPutMetadata`.
    pub request_id: i64,
    /// The decoded record batch.
    pub record_batch: DfRecordBatch,
    /// Schema bytes reported by the decoder after it decoded the stream's schema message.
    pub schema_bytes: Bytes,
    /// The original Flight message the record batch was decoded from.
    pub flight_data: FlightData,
    /// Memory reservation acquired in `try_new` when a limiter was supplied; `None`
    /// otherwise. Stored so it lives as long as the request (presumably released on drop;
    /// confirm against `MemoryGuard`).
    pub(crate) _guard: Option<MemoryGuard<RequestMemoryMetrics>>,
}
266
267impl PutRecordBatchRequest {
268 fn try_new(
269 table_name: TableName,
270 record_batch: DfRecordBatch,
271 request_id: i64,
272 schema_bytes: Bytes,
273 flight_data: FlightData,
274 limiter: Option<&ServerMemoryLimiter>,
275 ) -> Result<Self> {
276 let memory_usage = flight_data.data_body.len()
277 + flight_data.app_metadata.len()
278 + flight_data.data_header.len();
279
280 let _guard = if let Some(limiter) = limiter {
281 let guard = limiter.try_acquire(memory_usage as u64).ok_or_else(|| {
282 let inner_err = common_memory_manager::Error::MemoryLimitExceeded {
283 requested_bytes: memory_usage as u64,
284 limit_bytes: limiter.limit_bytes(),
285 };
286 error::MemoryLimitExceededSnafu.into_error(inner_err)
287 })?;
288 Some(guard)
289 } else {
290 None
291 };
292
293 Ok(Self {
294 table_name,
295 request_id,
296 record_batch,
297 schema_bytes,
298 flight_data,
299 _guard,
300 })
301 }
302}
303
/// Adapts a stream of raw Flight messages into a stream of [`PutRecordBatchRequest`]s.
///
/// The first message is expected to carry the flight descriptor (table name) and the
/// schema; every later message is decoded as a record batch.
pub struct PutRecordBatchRequestStream {
    /// The incoming Flight messages.
    flight_data_stream: Streaming<FlightData>,
    /// Catalog used to qualify the table name.
    catalog: String,
    /// Schema (database) name used to qualify the table name.
    schema_name: String,
    /// Optional limiter passed to each request so it can reserve memory for its message.
    limiter: Option<ServerMemoryLimiter>,
    /// Whether the schema message has been processed yet, and what it yielded.
    state: StreamState,
}
313
/// Decoding state of a [`PutRecordBatchRequestStream`].
enum StreamState {
    /// No schema message has been decoded yet.
    Init,
    /// The schema message has been decoded; later messages are decoded as record batches.
    Ready {
        /// Fully qualified table name built from the first message's flight descriptor.
        table_name: TableName,
        /// Schema decoded from the first message.
        schema: SchemaRef,
        /// Schema bytes obtained from the decoder after it decoded the schema message.
        schema_bytes: Bytes,
        /// Decoder that decoded the schema; reused for every following message.
        decoder: FlightDecoder,
    },
}
323
324impl PutRecordBatchRequestStream {
325 pub async fn new(
328 flight_data_stream: Streaming<FlightData>,
329 catalog: String,
330 schema: String,
331 limiter: Option<ServerMemoryLimiter>,
332 ) -> TonicResult<Self> {
333 Ok(Self {
334 flight_data_stream,
335 catalog,
336 schema_name: schema,
337 limiter,
338 state: StreamState::Init,
339 })
340 }
341
342 pub fn table_name(&self) -> Option<&TableName> {
345 match &self.state {
346 StreamState::Init => None,
347 StreamState::Ready { table_name, .. } => Some(table_name),
348 }
349 }
350
351 pub fn schema(&self) -> Option<&SchemaRef> {
354 match &self.state {
355 StreamState::Init => None,
356 StreamState::Ready { schema, .. } => Some(schema),
357 }
358 }
359
360 pub fn schema_bytes(&self) -> Option<&Bytes> {
363 match &self.state {
364 StreamState::Init => None,
365 StreamState::Ready { schema_bytes, .. } => Some(schema_bytes),
366 }
367 }
368
369 fn extract_table_name(mut descriptor: FlightDescriptor) -> Result<String> {
370 ensure!(
371 descriptor.r#type == arrow_flight::flight_descriptor::DescriptorType::Path as i32,
372 InvalidParameterSnafu {
373 reason: "expect FlightDescriptor::type == 'Path' only",
374 }
375 );
376 ensure!(
377 descriptor.path.len() == 1,
378 InvalidParameterSnafu {
379 reason: "expect FlightDescriptor::path has only one table name",
380 }
381 );
382 Ok(descriptor.path.remove(0))
383 }
384}
385
386impl Stream for PutRecordBatchRequestStream {
387 type Item = TonicResult<PutRecordBatchRequest>;
388
389 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
390 loop {
391 let poll = ready!(self.flight_data_stream.poll_next_unpin(cx));
392
393 match poll {
394 Some(Ok(flight_data)) => {
395 let limiter = self.limiter.clone();
396
397 match &mut self.state {
398 StreamState::Init => {
399 let flight_descriptor = match flight_data.flight_descriptor.as_ref() {
401 Some(descriptor) => descriptor.clone(),
402 None => {
403 return Poll::Ready(Some(Err(Status::failed_precondition(
404 "table to put is not found in flight descriptor",
405 ))));
406 }
407 };
408
409 let table_name_str = match Self::extract_table_name(flight_descriptor) {
410 Ok(name) => name,
411 Err(e) => {
412 return Poll::Ready(Some(Err(Status::invalid_argument(
413 e.to_string(),
414 ))));
415 }
416 };
417 let table_name = TableName::new(
418 self.catalog.clone(),
419 self.schema_name.clone(),
420 table_name_str,
421 );
422
423 let mut decoder = FlightDecoder::default();
425 let schema_message = decoder.try_decode(&flight_data).map_err(|e| {
426 Status::invalid_argument(format!("Failed to decode schema: {}", e))
427 })?;
428
429 match schema_message {
430 Some(FlightMessage::Schema(schema)) => {
431 let schema_bytes = decoder.schema_bytes().ok_or_else(|| {
432 Status::internal(
433 "decoder should have schema bytes after decoding schema",
434 )
435 })?;
436
437 self.state = StreamState::Ready {
439 table_name,
440 schema,
441 schema_bytes,
442 decoder,
443 };
444 continue;
446 }
447 _ => {
448 return Poll::Ready(Some(Err(Status::failed_precondition(
449 "first message must be a Schema message",
450 ))));
451 }
452 }
453 }
454 StreamState::Ready {
455 table_name,
456 schema: _,
457 schema_bytes,
458 decoder,
459 } => {
460 let request_id = if !flight_data.app_metadata.is_empty() {
462 serde_json::from_slice::<DoPutMetadata>(&flight_data.app_metadata)
463 .map(|meta| meta.request_id())
464 .unwrap_or_default()
465 } else {
466 0
467 };
468
469 match decoder.try_decode(&flight_data) {
471 Ok(Some(FlightMessage::RecordBatch(record_batch))) => {
472 let table_name = table_name.clone();
473 let schema_bytes = schema_bytes.clone();
474 return Poll::Ready(Some(
475 PutRecordBatchRequest::try_new(
476 table_name,
477 record_batch,
478 request_id,
479 schema_bytes,
480 flight_data,
481 limiter.as_ref(),
482 )
483 .map_err(|e| Status::invalid_argument(e.to_string())),
484 ));
485 }
486 Ok(Some(other)) => {
487 debug!("Unexpected flight message: {:?}", other);
488 return Poll::Ready(Some(Err(Status::invalid_argument(
489 "Expected RecordBatch message, got other message type",
490 ))));
491 }
492 Ok(None) => {
493 continue;
495 }
496 Err(e) => {
497 return Poll::Ready(Some(Err(Status::invalid_argument(
498 format!("Failed to decode RecordBatch: {}", e),
499 ))));
500 }
501 }
502 }
503 }
504 }
505 Some(Err(e)) => {
506 return Poll::Ready(Some(Err(e)));
507 }
508 None => {
509 return Poll::Ready(None);
510 }
511 }
512 }
513 }
514}
515
516fn to_flight_data_stream(
517 output: Output,
518 tracing_context: TracingContext,
519 flight_compression: FlightCompression,
520 query_ctx: QueryContextRef,
521) -> TonicStream<FlightData> {
522 match output.data {
523 OutputData::Stream(stream) => {
524 let stream = FlightRecordBatchStream::new(
525 stream,
526 tracing_context,
527 flight_compression,
528 query_ctx,
529 );
530 Box::pin(stream) as _
531 }
532 OutputData::RecordBatches(x) => {
533 let stream = FlightRecordBatchStream::new(
534 x.as_stream(),
535 tracing_context,
536 flight_compression,
537 query_ctx,
538 );
539 Box::pin(stream) as _
540 }
541 OutputData::AffectedRows(rows) => {
542 let stream = tokio_stream::iter(
543 FlightEncoder::default()
544 .encode(FlightMessage::AffectedRows(rows))
545 .into_iter()
546 .map(Ok),
547 );
548 Box::pin(stream) as _
549 }
550 }
551}