1mod stream;
16
17use std::pin::Pin;
18use std::sync::Arc;
19use std::task::{Context, Poll};
20
21use api::v1::GreptimeRequest;
22use arrow_flight::flight_service_server::FlightService;
23use arrow_flight::{
24 Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
25 HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket,
26};
27use async_trait::async_trait;
28use bytes::{self, Bytes};
29use common_grpc::flight::do_put::{DoPutMetadata, DoPutResponse};
30use common_grpc::flight::{FlightDecoder, FlightEncoder, FlightMessage};
31use common_memory_manager::MemoryGuard;
32use common_query::{Output, OutputData};
33use common_recordbatch::DfRecordBatch;
34use common_telemetry::debug;
35use common_telemetry::tracing::info_span;
36use common_telemetry::tracing_context::{FutureExt, TracingContext};
37use datatypes::arrow::datatypes::SchemaRef;
38use futures::{Stream, future, ready};
39use futures_util::{StreamExt, TryStreamExt};
40use prost::Message;
41use session::context::{QueryContext, QueryContextRef};
42use snafu::{IntoError, ResultExt, ensure};
43use table::table_name::TableName;
44use tokio::sync::mpsc;
45use tokio_stream::wrappers::ReceiverStream;
46use tonic::{Request, Response, Status, Streaming};
47
48use crate::error::{InvalidParameterSnafu, Result, ToJsonSnafu};
49pub use crate::grpc::flight::stream::FlightRecordBatchStream;
50use crate::grpc::greptime_handler::{GreptimeRequestHandler, get_request_type};
51use crate::grpc::{FlightCompression, TonicResult, context_auth};
52use crate::request_memory_limiter::ServerMemoryLimiter;
53use crate::request_memory_metrics::RequestMemoryMetrics;
54use crate::{error, hint_headers};
55
56pub type TonicStream<T> = Pin<Box<dyn Stream<Item = TonicResult<T>> + Send + 'static>>;
57
58#[async_trait]
60pub trait FlightCraft: Send + Sync + 'static {
61 async fn do_get(
62 &self,
63 request: Request<Ticket>,
64 ) -> TonicResult<Response<TonicStream<FlightData>>>;
65
66 async fn do_put(
67 &self,
68 request: Request<Streaming<FlightData>>,
69 ) -> TonicResult<Response<TonicStream<PutResult>>> {
70 let _ = request;
71 Err(Status::unimplemented("Not yet implemented"))
72 }
73}
74
75pub type FlightCraftRef = Arc<dyn FlightCraft>;
76
77pub struct FlightCraftWrapper<T: FlightCraft>(pub T);
78
79impl<T: FlightCraft> From<T> for FlightCraftWrapper<T> {
80 fn from(t: T) -> Self {
81 Self(t)
82 }
83}
84
85#[async_trait]
86impl FlightCraft for FlightCraftRef {
87 async fn do_get(
88 &self,
89 request: Request<Ticket>,
90 ) -> TonicResult<Response<TonicStream<FlightData>>> {
91 (**self).do_get(request).await
92 }
93
94 async fn do_put(
95 &self,
96 request: Request<Streaming<FlightData>>,
97 ) -> TonicResult<Response<TonicStream<PutResult>>> {
98 self.as_ref().do_put(request).await
99 }
100}
101
102#[async_trait]
103impl<T: FlightCraft> FlightService for FlightCraftWrapper<T> {
104 type HandshakeStream = TonicStream<HandshakeResponse>;
105
106 async fn handshake(
107 &self,
108 _: Request<Streaming<HandshakeRequest>>,
109 ) -> TonicResult<Response<Self::HandshakeStream>> {
110 Err(Status::unimplemented("Not yet implemented"))
111 }
112
113 type ListFlightsStream = TonicStream<FlightInfo>;
114
115 async fn list_flights(
116 &self,
117 _: Request<Criteria>,
118 ) -> TonicResult<Response<Self::ListFlightsStream>> {
119 Err(Status::unimplemented("Not yet implemented"))
120 }
121
122 async fn get_flight_info(
123 &self,
124 _: Request<FlightDescriptor>,
125 ) -> TonicResult<Response<FlightInfo>> {
126 Err(Status::unimplemented("Not yet implemented"))
127 }
128
129 async fn poll_flight_info(
130 &self,
131 _: Request<FlightDescriptor>,
132 ) -> TonicResult<Response<PollInfo>> {
133 Err(Status::unimplemented("Not yet implemented"))
134 }
135
136 async fn get_schema(
137 &self,
138 _: Request<FlightDescriptor>,
139 ) -> TonicResult<Response<SchemaResult>> {
140 Err(Status::unimplemented("Not yet implemented"))
141 }
142
143 type DoGetStream = TonicStream<FlightData>;
144
145 async fn do_get(&self, request: Request<Ticket>) -> TonicResult<Response<Self::DoGetStream>> {
146 self.0.do_get(request).await
147 }
148
149 type DoPutStream = TonicStream<PutResult>;
150
151 async fn do_put(
152 &self,
153 request: Request<Streaming<FlightData>>,
154 ) -> TonicResult<Response<Self::DoPutStream>> {
155 self.0.do_put(request).await
156 }
157
158 type DoExchangeStream = TonicStream<FlightData>;
159
160 async fn do_exchange(
161 &self,
162 _: Request<Streaming<FlightData>>,
163 ) -> TonicResult<Response<Self::DoExchangeStream>> {
164 Err(Status::unimplemented("Not yet implemented"))
165 }
166
167 type DoActionStream = TonicStream<arrow_flight::Result>;
168
169 async fn do_action(&self, _: Request<Action>) -> TonicResult<Response<Self::DoActionStream>> {
170 Err(Status::unimplemented("Not yet implemented"))
171 }
172
173 type ListActionsStream = TonicStream<ActionType>;
174
175 async fn list_actions(
176 &self,
177 _: Request<Empty>,
178 ) -> TonicResult<Response<Self::ListActionsStream>> {
179 Err(Status::unimplemented("Not yet implemented"))
180 }
181}
182
183#[async_trait]
184impl FlightCraft for GreptimeRequestHandler {
185 async fn do_get(
186 &self,
187 request: Request<Ticket>,
188 ) -> TonicResult<Response<TonicStream<FlightData>>> {
189 let hints = hint_headers::extract_hints(request.metadata());
190
191 let ticket = request.into_inner().ticket;
192 let request =
193 GreptimeRequest::decode(ticket.as_ref()).context(error::InvalidFlightTicketSnafu)?;
194
195 let span = info_span!(
197 "GreptimeRequestHandler::do_get",
198 protocol = "grpc",
199 request_type = get_request_type(&request)
200 );
201 let flight_compression = self.flight_compression;
202 async {
203 let output = self.handle_request(request, hints).await?;
204 let stream = to_flight_data_stream(
205 output,
206 TracingContext::from_current_span(),
207 flight_compression,
208 QueryContext::arc(),
209 );
210 Ok(Response::new(stream))
211 }
212 .trace(span)
213 .await
214 }
215
216 async fn do_put(
217 &self,
218 request: Request<Streaming<FlightData>>,
219 ) -> TonicResult<Response<TonicStream<PutResult>>> {
220 let (headers, extensions, stream) = request.into_parts();
221
222 let limiter = extensions.get::<ServerMemoryLimiter>().cloned();
223
224 let query_ctx = context_auth::create_query_context_from_grpc_metadata(&headers)?;
225 context_auth::check_auth(self.user_provider.clone(), &headers, query_ctx.clone()).await?;
226
227 const MAX_PENDING_RESPONSES: usize = 32;
228 let (tx, rx) = mpsc::channel::<TonicResult<DoPutResponse>>(MAX_PENDING_RESPONSES);
229
230 let stream = PutRecordBatchRequestStream::new(
231 stream,
232 query_ctx.current_catalog().to_string(),
233 query_ctx.current_schema(),
234 limiter,
235 )
236 .await?;
237 let _ = tx.send(Ok(DoPutResponse::new(0, 0, 0.0))).await;
239 self.put_record_batches(stream, tx, query_ctx).await;
240
241 let response = ReceiverStream::new(rx)
242 .and_then(|response| {
243 future::ready({
244 serde_json::to_vec(&response)
245 .context(ToJsonSnafu)
246 .map(|x| PutResult {
247 app_metadata: Bytes::from(x),
248 })
249 .map_err(Into::into)
250 })
251 })
252 .boxed();
253 Ok(Response::new(response))
254 }
255}
256
257pub struct PutRecordBatchRequest {
258 pub table_name: TableName,
259 pub request_id: i64,
260 pub record_batch: DfRecordBatch,
261 pub schema_bytes: Bytes,
262 pub flight_data: FlightData,
263 pub(crate) _guard: Option<MemoryGuard<RequestMemoryMetrics>>,
264}
265
266impl PutRecordBatchRequest {
267 fn try_new(
268 table_name: TableName,
269 record_batch: DfRecordBatch,
270 request_id: i64,
271 schema_bytes: Bytes,
272 flight_data: FlightData,
273 limiter: Option<&ServerMemoryLimiter>,
274 ) -> Result<Self> {
275 let memory_usage = flight_data.data_body.len()
276 + flight_data.app_metadata.len()
277 + flight_data.data_header.len();
278
279 let _guard = if let Some(limiter) = limiter {
280 let guard = limiter.try_acquire(memory_usage as u64).ok_or_else(|| {
281 let inner_err = common_memory_manager::Error::MemoryLimitExceeded {
282 requested_bytes: memory_usage as u64,
283 limit_bytes: limiter.limit_bytes(),
284 };
285 error::MemoryLimitExceededSnafu.into_error(inner_err)
286 })?;
287 Some(guard)
288 } else {
289 None
290 };
291
292 Ok(Self {
293 table_name,
294 request_id,
295 record_batch,
296 schema_bytes,
297 flight_data,
298 _guard,
299 })
300 }
301}
302
303pub struct PutRecordBatchRequestStream {
304 flight_data_stream: Streaming<FlightData>,
305 catalog: String,
306 schema_name: String,
307 limiter: Option<ServerMemoryLimiter>,
308 state: StreamState,
311}
312
313enum StreamState {
314 Init,
315 Ready {
316 table_name: TableName,
317 schema: SchemaRef,
318 schema_bytes: Bytes,
319 decoder: FlightDecoder,
320 },
321}
322
323impl PutRecordBatchRequestStream {
324 pub async fn new(
327 flight_data_stream: Streaming<FlightData>,
328 catalog: String,
329 schema: String,
330 limiter: Option<ServerMemoryLimiter>,
331 ) -> TonicResult<Self> {
332 Ok(Self {
333 flight_data_stream,
334 catalog,
335 schema_name: schema,
336 limiter,
337 state: StreamState::Init,
338 })
339 }
340
341 pub fn table_name(&self) -> Option<&TableName> {
344 match &self.state {
345 StreamState::Init => None,
346 StreamState::Ready { table_name, .. } => Some(table_name),
347 }
348 }
349
350 pub fn schema(&self) -> Option<&SchemaRef> {
353 match &self.state {
354 StreamState::Init => None,
355 StreamState::Ready { schema, .. } => Some(schema),
356 }
357 }
358
359 pub fn schema_bytes(&self) -> Option<&Bytes> {
362 match &self.state {
363 StreamState::Init => None,
364 StreamState::Ready { schema_bytes, .. } => Some(schema_bytes),
365 }
366 }
367
368 fn extract_table_name(mut descriptor: FlightDescriptor) -> Result<String> {
369 ensure!(
370 descriptor.r#type == arrow_flight::flight_descriptor::DescriptorType::Path as i32,
371 InvalidParameterSnafu {
372 reason: "expect FlightDescriptor::type == 'Path' only",
373 }
374 );
375 ensure!(
376 descriptor.path.len() == 1,
377 InvalidParameterSnafu {
378 reason: "expect FlightDescriptor::path has only one table name",
379 }
380 );
381 Ok(descriptor.path.remove(0))
382 }
383}
384
385impl Stream for PutRecordBatchRequestStream {
386 type Item = TonicResult<PutRecordBatchRequest>;
387
388 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
389 loop {
390 let poll = ready!(self.flight_data_stream.poll_next_unpin(cx));
391
392 match poll {
393 Some(Ok(flight_data)) => {
394 let limiter = self.limiter.clone();
395
396 match &mut self.state {
397 StreamState::Init => {
398 let flight_descriptor = match flight_data.flight_descriptor.as_ref() {
400 Some(descriptor) => descriptor.clone(),
401 None => {
402 return Poll::Ready(Some(Err(Status::failed_precondition(
403 "table to put is not found in flight descriptor",
404 ))));
405 }
406 };
407
408 let table_name_str = match Self::extract_table_name(flight_descriptor) {
409 Ok(name) => name,
410 Err(e) => {
411 return Poll::Ready(Some(Err(Status::invalid_argument(
412 e.to_string(),
413 ))));
414 }
415 };
416 let table_name = TableName::new(
417 self.catalog.clone(),
418 self.schema_name.clone(),
419 table_name_str,
420 );
421
422 let mut decoder = FlightDecoder::default();
424 let schema_message = decoder.try_decode(&flight_data).map_err(|e| {
425 Status::invalid_argument(format!("Failed to decode schema: {}", e))
426 })?;
427
428 match schema_message {
429 Some(FlightMessage::Schema(schema)) => {
430 let schema_bytes = decoder.schema_bytes().ok_or_else(|| {
431 Status::internal(
432 "decoder should have schema bytes after decoding schema",
433 )
434 })?;
435
436 self.state = StreamState::Ready {
438 table_name,
439 schema,
440 schema_bytes,
441 decoder,
442 };
443 continue;
445 }
446 _ => {
447 return Poll::Ready(Some(Err(Status::failed_precondition(
448 "first message must be a Schema message",
449 ))));
450 }
451 }
452 }
453 StreamState::Ready {
454 table_name,
455 schema: _,
456 schema_bytes,
457 decoder,
458 } => {
459 let request_id = if !flight_data.app_metadata.is_empty() {
461 serde_json::from_slice::<DoPutMetadata>(&flight_data.app_metadata)
462 .map(|meta| meta.request_id())
463 .unwrap_or_default()
464 } else {
465 0
466 };
467
468 match decoder.try_decode(&flight_data) {
470 Ok(Some(FlightMessage::RecordBatch(record_batch))) => {
471 let table_name = table_name.clone();
472 let schema_bytes = schema_bytes.clone();
473 return Poll::Ready(Some(
474 PutRecordBatchRequest::try_new(
475 table_name,
476 record_batch,
477 request_id,
478 schema_bytes,
479 flight_data,
480 limiter.as_ref(),
481 )
482 .map_err(|e| Status::invalid_argument(e.to_string())),
483 ));
484 }
485 Ok(Some(other)) => {
486 debug!("Unexpected flight message: {:?}", other);
487 return Poll::Ready(Some(Err(Status::invalid_argument(
488 "Expected RecordBatch message, got other message type",
489 ))));
490 }
491 Ok(None) => {
492 continue;
494 }
495 Err(e) => {
496 return Poll::Ready(Some(Err(Status::invalid_argument(
497 format!("Failed to decode RecordBatch: {}", e),
498 ))));
499 }
500 }
501 }
502 }
503 }
504 Some(Err(e)) => {
505 return Poll::Ready(Some(Err(e)));
506 }
507 None => {
508 return Poll::Ready(None);
509 }
510 }
511 }
512 }
513}
514
515fn to_flight_data_stream(
516 output: Output,
517 tracing_context: TracingContext,
518 flight_compression: FlightCompression,
519 query_ctx: QueryContextRef,
520) -> TonicStream<FlightData> {
521 match output.data {
522 OutputData::Stream(stream) => {
523 let stream = FlightRecordBatchStream::new(
524 stream,
525 tracing_context,
526 flight_compression,
527 query_ctx,
528 );
529 Box::pin(stream) as _
530 }
531 OutputData::RecordBatches(x) => {
532 let stream = FlightRecordBatchStream::new(
533 x.as_stream(),
534 tracing_context,
535 flight_compression,
536 query_ctx,
537 );
538 Box::pin(stream) as _
539 }
540 OutputData::AffectedRows(rows) => {
541 let stream = tokio_stream::iter(
542 FlightEncoder::default()
543 .encode(FlightMessage::AffectedRows(rows))
544 .into_iter()
545 .map(Ok),
546 );
547 Box::pin(stream) as _
548 }
549 }
550}