1mod stream;
16
17use std::pin::Pin;
18use std::sync::Arc;
19use std::task::{Context, Poll};
20
21use api::v1::GreptimeRequest;
22use arrow_flight::flight_service_server::FlightService;
23use arrow_flight::{
24 Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
25 HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket,
26};
27use async_trait::async_trait;
28use bytes::{self, Bytes};
29use common_error::ext::ErrorExt;
30use common_grpc::flight::do_put::{DoPutMetadata, DoPutResponse};
31use common_grpc::flight::{
32 FLOW_EXTENSIONS_METADATA_KEY, FlightDecoder, FlightEncoder, FlightMessage,
33};
34use common_memory_manager::MemoryGuard;
35use common_query::{Output, OutputData};
36use common_recordbatch::DfRecordBatch;
37use common_telemetry::debug;
38use common_telemetry::tracing::info_span;
39use common_telemetry::tracing_context::{FutureExt, TracingContext};
40use datatypes::arrow::datatypes::SchemaRef;
41use futures::{Stream, future, ready};
42use futures_util::{StreamExt, TryStreamExt};
43use prost::Message;
44use query::metrics::terminal_recordbatch_metrics_from_plan_if_requested;
45use query::options::FlowQueryExtensions;
46use session::context::{Channel, QueryContextRef};
47use snafu::{IntoError, ResultExt, ensure};
48use table::table_name::TableName;
49use tokio::sync::mpsc;
50use tokio_stream::wrappers::ReceiverStream;
51use tonic::{Request, Response, Status, Streaming};
52
53use crate::error::{InvalidParameterSnafu, Result, ToJsonSnafu};
54pub use crate::grpc::flight::stream::FlightRecordBatchStream;
55use crate::grpc::greptime_handler::{
56 GreptimeRequestHandler, create_query_context, get_request_type,
57};
58use crate::grpc::{FlightCompression, TonicResult, context_auth};
59use crate::request_memory_limiter::ServerMemoryLimiter;
60use crate::request_memory_metrics::RequestMemoryMetrics;
61use crate::{error, hint_headers};
62
/// Boxed, type-erased stream of tonic results.
///
/// Used for every streaming associated type of the `FlightService`
/// implementation in this module, and as the response body of `FlightCraft`.
pub type TonicStream<T> = Pin<Box<dyn Stream<Item = TonicResult<T>> + Send + 'static>>;
64
65#[async_trait]
67pub trait FlightCraft: Send + Sync + 'static {
68 async fn do_get(
69 &self,
70 request: Request<Ticket>,
71 ) -> TonicResult<Response<TonicStream<FlightData>>>;
72
73 async fn do_put(
74 &self,
75 request: Request<Streaming<FlightData>>,
76 ) -> TonicResult<Response<TonicStream<PutResult>>> {
77 let _ = request;
78 Err(Status::unimplemented("Not yet implemented"))
79 }
80}
81
/// Shared, dynamically dispatched [`FlightCraft`] handle.
pub type FlightCraftRef = Arc<dyn FlightCraft>;
83
/// Adapter that exposes any [`FlightCraft`] as a tonic [`FlightService`].
///
/// `do_get` and `do_put` are forwarded to the wrapped value; every other
/// Flight RPC answers with an `Unimplemented` status.
pub struct FlightCraftWrapper<T: FlightCraft>(pub T);
85
86impl<T: FlightCraft> From<T> for FlightCraftWrapper<T> {
87 fn from(t: T) -> Self {
88 Self(t)
89 }
90}
91
92#[async_trait]
93impl FlightCraft for FlightCraftRef {
94 async fn do_get(
95 &self,
96 request: Request<Ticket>,
97 ) -> TonicResult<Response<TonicStream<FlightData>>> {
98 (**self).do_get(request).await
99 }
100
101 async fn do_put(
102 &self,
103 request: Request<Streaming<FlightData>>,
104 ) -> TonicResult<Response<TonicStream<PutResult>>> {
105 self.as_ref().do_put(request).await
106 }
107}
108
109#[async_trait]
110impl<T: FlightCraft> FlightService for FlightCraftWrapper<T> {
111 type HandshakeStream = TonicStream<HandshakeResponse>;
112
113 async fn handshake(
114 &self,
115 _: Request<Streaming<HandshakeRequest>>,
116 ) -> TonicResult<Response<Self::HandshakeStream>> {
117 Err(Status::unimplemented("Not yet implemented"))
118 }
119
120 type ListFlightsStream = TonicStream<FlightInfo>;
121
122 async fn list_flights(
123 &self,
124 _: Request<Criteria>,
125 ) -> TonicResult<Response<Self::ListFlightsStream>> {
126 Err(Status::unimplemented("Not yet implemented"))
127 }
128
129 async fn get_flight_info(
130 &self,
131 _: Request<FlightDescriptor>,
132 ) -> TonicResult<Response<FlightInfo>> {
133 Err(Status::unimplemented("Not yet implemented"))
134 }
135
136 async fn poll_flight_info(
137 &self,
138 _: Request<FlightDescriptor>,
139 ) -> TonicResult<Response<PollInfo>> {
140 Err(Status::unimplemented("Not yet implemented"))
141 }
142
143 async fn get_schema(
144 &self,
145 _: Request<FlightDescriptor>,
146 ) -> TonicResult<Response<SchemaResult>> {
147 Err(Status::unimplemented("Not yet implemented"))
148 }
149
150 type DoGetStream = TonicStream<FlightData>;
151
152 async fn do_get(&self, request: Request<Ticket>) -> TonicResult<Response<Self::DoGetStream>> {
153 self.0.do_get(request).await
154 }
155
156 type DoPutStream = TonicStream<PutResult>;
157
158 async fn do_put(
159 &self,
160 request: Request<Streaming<FlightData>>,
161 ) -> TonicResult<Response<Self::DoPutStream>> {
162 self.0.do_put(request).await
163 }
164
165 type DoExchangeStream = TonicStream<FlightData>;
166
167 async fn do_exchange(
168 &self,
169 _: Request<Streaming<FlightData>>,
170 ) -> TonicResult<Response<Self::DoExchangeStream>> {
171 Err(Status::unimplemented("Not yet implemented"))
172 }
173
174 type DoActionStream = TonicStream<arrow_flight::Result>;
175
176 async fn do_action(&self, _: Request<Action>) -> TonicResult<Response<Self::DoActionStream>> {
177 Err(Status::unimplemented("Not yet implemented"))
178 }
179
180 type ListActionsStream = TonicStream<ActionType>;
181
182 async fn list_actions(
183 &self,
184 _: Request<Empty>,
185 ) -> TonicResult<Response<Self::ListActionsStream>> {
186 Err(Status::unimplemented("Not yet implemented"))
187 }
188}
189
/// Flight `DoGet` / `DoPut` entry points for GreptimeDB gRPC clients.
#[async_trait]
impl FlightCraft for GreptimeRequestHandler {
    /// Executes the `GreptimeRequest` encoded in the Flight ticket and streams
    /// its output back as `FlightData`.
    ///
    /// Hints come from the request headers and are extended with the flow
    /// extensions found under `FLOW_EXTENSIONS_METADATA_KEY`. Whether terminal
    /// metrics may be emitted is derived from the parsed flow extensions
    /// (`should_collect_region_watermark`).
    ///
    /// # Errors
    ///
    /// Returns an error status when the flow-extensions metadata is malformed,
    /// the ticket does not decode as a `GreptimeRequest`, the query context
    /// cannot be created, the flow extensions fail to parse
    /// (`InvalidArgument`), or handling the request fails.
    async fn do_get(
        &self,
        request: Request<Ticket>,
    ) -> TonicResult<Response<TonicStream<FlightData>>> {
        let mut hints = hint_headers::extract_hints(request.metadata());
        hints.extend(extract_flow_extensions(request.metadata())?);

        let ticket = request.into_inner().ticket;
        let request =
            GreptimeRequest::decode(ticket.as_ref()).context(error::InvalidFlightTicketSnafu)?;
        // `hints` is cloned because it is passed to `handle_request` again below.
        let query_ctx =
            create_query_context(Channel::Grpc, request.header.as_ref(), hints.clone())?;
        let flow_extensions = FlowQueryExtensions::parse_flow_extensions(&query_ctx.extensions())
            .map_err(|e| Status::invalid_argument(e.output_msg()))?;
        let should_emit_terminal_metrics = flow_extensions
            .as_ref()
            .is_some_and(|extensions| extensions.should_collect_region_watermark());

        let span = info_span!(
            "GreptimeRequestHandler::do_get",
            protocol = "grpc",
            request_type = get_request_type(&request)
        );
        let flight_compression = self.flight_compression;
        // The handling future is run under `span` (via `.trace`), so
        // `TracingContext::from_current_span()` is presumably captured from it
        // and carried into the output stream.
        async {
            let output = self.handle_request(request, hints).await?;
            let stream = to_flight_data_stream(
                output,
                TracingContext::from_current_span(),
                flight_compression,
                query_ctx,
                should_emit_terminal_metrics,
            );
            Ok(Response::new(stream))
        }
        .trace(span)
        .await
    }

    /// Accepts a Flight `DoPut` upload of record batches for a single table.
    ///
    /// The caller is authenticated from the gRPC headers before anything is
    /// read from the upload. The returned stream yields each `DoPutResponse`
    /// serialized as JSON in `PutResult::app_metadata`. The first response,
    /// `DoPutResponse::new(0, 0, 0.0)`, is queued before any batch is
    /// processed (presumably as an acknowledgement).
    ///
    /// # Errors
    ///
    /// Returns an error status when the query context cannot be built from
    /// the headers or authentication fails. Per-response serialization
    /// failures surface as error items in the returned stream.
    async fn do_put(
        &self,
        request: Request<Streaming<FlightData>>,
    ) -> TonicResult<Response<TonicStream<PutResult>>> {
        let (headers, extensions, stream) = request.into_parts();

        // Optional memory limiter read from the request extensions. When it is
        // absent, no memory is reserved per decoded batch.
        let limiter = extensions.get::<ServerMemoryLimiter>().cloned();

        let query_ctx = context_auth::create_query_context_from_grpc_metadata(&headers)?;
        context_auth::check_auth(self.user_provider.clone(), &headers, query_ctx.clone()).await?;

        // Bound on responses waiting to be forwarded to the client.
        const MAX_PENDING_RESPONSES: usize = 32;
        let (tx, rx) = mpsc::channel::<TonicResult<DoPutResponse>>(MAX_PENDING_RESPONSES);

        let stream = PutRecordBatchRequestStream::new(
            stream,
            query_ctx.current_catalog().to_string(),
            query_ctx.current_schema(),
            limiter,
        )
        .await?;
        // Initial response. The send result is ignored; `rx` is still held
        // here, so the channel is open.
        let _ = tx.send(Ok(DoPutResponse::new(0, 0, 0.0))).await;
        // NOTE(review): presumably `put_record_batches` hands the stream off to
        // a background task. If it instead consumed the whole upload before
        // returning, more than MAX_PENDING_RESPONSES responses could fill the
        // channel before `rx` is drained below. Confirm against its definition.
        self.put_record_batches(stream, tx, query_ctx).await;

        // Serialize every response to JSON and ship it as `app_metadata`.
        let response = ReceiverStream::new(rx)
            .and_then(|response| {
                future::ready({
                    serde_json::to_vec(&response)
                        .context(ToJsonSnafu)
                        .map(|x| PutResult {
                            app_metadata: Bytes::from(x),
                        })
                        .map_err(Into::into)
                })
            })
            .boxed();
        Ok(Response::new(response))
    }
}
275
/// One decoded record batch from a Flight `DoPut` upload, together with the
/// context needed to write it.
pub struct PutRecordBatchRequest {
    /// Target table, qualified with the catalog and schema of the upload.
    pub table_name: TableName,
    /// Request id taken from the message's `DoPutMetadata`; `0` when absent.
    pub request_id: i64,
    /// Optional timestamp range taken from the message's `DoPutMetadata`.
    pub timestamp_range: Option<(i64, i64)>,
    /// The decoded Arrow record batch.
    pub record_batch: DfRecordBatch,
    /// Schema bytes captured by the decoder from the upload's first message.
    pub schema_bytes: Bytes,
    /// The raw Flight message the batch was decoded from.
    pub flight_data: FlightData,
    /// Memory reservation for `flight_data`, held as long as this request
    /// lives (presumably released on drop; confirm in `MemoryGuard`). It is
    /// `None` when no limiter is configured.
    pub(crate) _guard: Option<MemoryGuard<RequestMemoryMetrics>>,
}
285
286impl PutRecordBatchRequest {
287 fn try_new(
288 table_name: TableName,
289 record_batch: DfRecordBatch,
290 request_id: i64,
291 timestamp_range: Option<(i64, i64)>,
292 schema_bytes: Bytes,
293 flight_data: FlightData,
294 limiter: Option<&ServerMemoryLimiter>,
295 ) -> Result<Self> {
296 let memory_usage = flight_data.data_body.len()
297 + flight_data.app_metadata.len()
298 + flight_data.data_header.len();
299
300 let _guard = if let Some(limiter) = limiter {
301 let guard = limiter.try_acquire(memory_usage as u64).ok_or_else(|| {
302 let inner_err = common_memory_manager::Error::MemoryLimitExceeded {
303 requested_bytes: memory_usage as u64,
304 limit_bytes: limiter.limit_bytes(),
305 };
306 error::MemoryLimitExceededSnafu.into_error(inner_err)
307 })?;
308 Some(guard)
309 } else {
310 None
311 };
312
313 Ok(Self {
314 table_name,
315 request_id,
316 timestamp_range,
317 record_batch,
318 schema_bytes,
319 flight_data,
320 _guard,
321 })
322 }
323}
324
/// Stream adapter that turns a Flight `DoPut` upload into
/// [`PutRecordBatchRequest`]s.
///
/// The first message must name the target table in its flight descriptor and
/// decode to an Arrow schema. Each following message is decoded as a record
/// batch of that table.
pub struct PutRecordBatchRequestStream {
    /// Raw Flight messages received from the client.
    flight_data_stream: Streaming<FlightData>,
    /// Catalog used to qualify the table name.
    catalog: String,
    /// Schema name used to qualify the table name.
    schema_name: String,
    /// Optional limiter used to reserve memory for each decoded batch.
    limiter: Option<ServerMemoryLimiter>,
    /// Whether the table/schema message has been accepted yet.
    state: StreamState,
}
334
/// Progress of a [`PutRecordBatchRequestStream`].
enum StreamState {
    /// Waiting for the first message (table descriptor plus schema).
    Init,
    /// Schema accepted; subsequent messages are decoded as record batches.
    Ready {
        /// Fully-qualified target table.
        table_name: TableName,
        /// Arrow schema decoded from the first message.
        schema: SchemaRef,
        /// Schema bytes reported by the decoder after decoding the schema.
        schema_bytes: Bytes,
        /// Decoder that decoded the schema, reused for every later message.
        decoder: FlightDecoder,
    },
}
344
345impl PutRecordBatchRequestStream {
346 pub async fn new(
349 flight_data_stream: Streaming<FlightData>,
350 catalog: String,
351 schema: String,
352 limiter: Option<ServerMemoryLimiter>,
353 ) -> TonicResult<Self> {
354 Ok(Self {
355 flight_data_stream,
356 catalog,
357 schema_name: schema,
358 limiter,
359 state: StreamState::Init,
360 })
361 }
362
363 pub fn table_name(&self) -> Option<&TableName> {
366 match &self.state {
367 StreamState::Init => None,
368 StreamState::Ready { table_name, .. } => Some(table_name),
369 }
370 }
371
372 pub fn schema(&self) -> Option<&SchemaRef> {
375 match &self.state {
376 StreamState::Init => None,
377 StreamState::Ready { schema, .. } => Some(schema),
378 }
379 }
380
381 pub fn schema_bytes(&self) -> Option<&Bytes> {
384 match &self.state {
385 StreamState::Init => None,
386 StreamState::Ready { schema_bytes, .. } => Some(schema_bytes),
387 }
388 }
389
390 fn extract_table_name(mut descriptor: FlightDescriptor) -> Result<String> {
391 ensure!(
392 descriptor.r#type == arrow_flight::flight_descriptor::DescriptorType::Path as i32,
393 InvalidParameterSnafu {
394 reason: "expect FlightDescriptor::type == 'Path' only",
395 }
396 );
397 ensure!(
398 descriptor.path.len() == 1,
399 InvalidParameterSnafu {
400 reason: "expect FlightDescriptor::path has only one table name",
401 }
402 );
403 Ok(descriptor.path.remove(0))
404 }
405}
406
407impl Stream for PutRecordBatchRequestStream {
408 type Item = TonicResult<PutRecordBatchRequest>;
409
410 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
411 loop {
412 let poll = ready!(self.flight_data_stream.poll_next_unpin(cx));
413
414 match poll {
415 Some(Ok(flight_data)) => {
416 let limiter = self.limiter.clone();
417
418 match &mut self.state {
419 StreamState::Init => {
420 let flight_descriptor = match flight_data.flight_descriptor.as_ref() {
422 Some(descriptor) => descriptor.clone(),
423 None => {
424 return Poll::Ready(Some(Err(Status::failed_precondition(
425 "table to put is not found in flight descriptor",
426 ))));
427 }
428 };
429
430 let table_name_str = match Self::extract_table_name(flight_descriptor) {
431 Ok(name) => name,
432 Err(e) => {
433 return Poll::Ready(Some(Err(Status::invalid_argument(
434 e.to_string(),
435 ))));
436 }
437 };
438 let table_name = TableName::new(
439 self.catalog.clone(),
440 self.schema_name.clone(),
441 table_name_str,
442 );
443
444 let mut decoder = FlightDecoder::default();
446 let schema_message = decoder.try_decode(&flight_data).map_err(|e| {
447 Status::invalid_argument(format!("Failed to decode schema: {}", e))
448 })?;
449
450 match schema_message {
451 Some(FlightMessage::Schema(schema)) => {
452 let schema_bytes = decoder.schema_bytes().ok_or_else(|| {
453 Status::internal(
454 "decoder should have schema bytes after decoding schema",
455 )
456 })?;
457
458 self.state = StreamState::Ready {
460 table_name,
461 schema,
462 schema_bytes,
463 decoder,
464 };
465 continue;
467 }
468 _ => {
469 return Poll::Ready(Some(Err(Status::failed_precondition(
470 "first message must be a Schema message",
471 ))));
472 }
473 }
474 }
475 StreamState::Ready {
476 table_name,
477 schema: _,
478 schema_bytes,
479 decoder,
480 } => {
481 let metadata = if !flight_data.app_metadata.is_empty() {
483 serde_json::from_slice::<DoPutMetadata>(&flight_data.app_metadata)
484 .ok()
485 } else {
486 None
487 };
488 let request_id = metadata
489 .as_ref()
490 .map(|meta| meta.request_id())
491 .unwrap_or_default();
492 let timestamp_range = metadata.and_then(|meta| meta.timestamp_range());
493
494 match decoder.try_decode(&flight_data) {
496 Ok(Some(FlightMessage::RecordBatch(record_batch))) => {
497 let table_name = table_name.clone();
498 let schema_bytes = schema_bytes.clone();
499 return Poll::Ready(Some(
500 PutRecordBatchRequest::try_new(
501 table_name,
502 record_batch,
503 request_id,
504 timestamp_range,
505 schema_bytes,
506 flight_data,
507 limiter.as_ref(),
508 )
509 .map_err(|e| Status::invalid_argument(e.to_string())),
510 ));
511 }
512 Ok(Some(other)) => {
513 debug!("Unexpected flight message: {:?}", other);
514 return Poll::Ready(Some(Err(Status::invalid_argument(
515 "Expected RecordBatch message, got other message type",
516 ))));
517 }
518 Ok(None) => {
519 continue;
521 }
522 Err(e) => {
523 return Poll::Ready(Some(Err(Status::invalid_argument(
524 format!("Failed to decode RecordBatch: {}", e),
525 ))));
526 }
527 }
528 }
529 }
530 }
531 Some(Err(e)) => {
532 return Poll::Ready(Some(Err(e)));
533 }
534 None => {
535 return Poll::Ready(None);
536 }
537 }
538 }
539 }
540}
541
542fn extract_flow_extensions(
543 metadata: &tonic::metadata::MetadataMap,
544) -> TonicResult<Vec<(String, String)>> {
545 let Some(value) = metadata.get(FLOW_EXTENSIONS_METADATA_KEY) else {
546 return Ok(vec![]);
547 };
548
549 let value = value.to_str().map_err(|e| {
550 Status::invalid_argument(format!(
551 "Invalid {FLOW_EXTENSIONS_METADATA_KEY} metadata value: {e}"
552 ))
553 })?;
554
555 serde_json::from_str::<Vec<(String, String)>>(value).map_err(|e| {
556 Status::invalid_argument(format!(
557 "Invalid {FLOW_EXTENSIONS_METADATA_KEY} metadata JSON: {e}"
558 ))
559 })
560}
561
562fn to_flight_data_stream(
563 output: Output,
564 tracing_context: TracingContext,
565 flight_compression: FlightCompression,
566 query_ctx: QueryContextRef,
567 should_emit_terminal_metrics: bool,
568) -> TonicStream<FlightData> {
569 match output.data {
570 OutputData::Stream(stream) => {
571 let stream = FlightRecordBatchStream::new(
572 stream,
573 tracing_context,
574 flight_compression,
575 query_ctx,
576 );
577 Box::pin(stream) as _
578 }
579 OutputData::RecordBatches(x) => {
580 let stream = FlightRecordBatchStream::new(
581 x.as_stream(),
582 tracing_context,
583 flight_compression,
584 query_ctx,
585 );
586 Box::pin(stream) as _
587 }
588 OutputData::AffectedRows(rows) => {
589 let terminal_metrics = match terminal_recordbatch_metrics_from_plan_if_requested(
590 output.meta.plan,
591 should_emit_terminal_metrics,
592 ) {
593 Some(metrics) => match serde_json::to_string(&metrics) {
594 Ok(metrics) => Some(metrics),
595 Err(e) => {
596 let stream = tokio_stream::once(Err(Status::internal(format!(
597 "Failed to serialize terminal metrics: {e}"
598 ))));
599 return Box::pin(stream) as _;
600 }
601 },
602 None => None,
603 };
604 let affected_rows = FlightEncoder::default().encode(FlightMessage::AffectedRows {
605 rows,
606 metrics: terminal_metrics,
607 });
608 let stream = tokio_stream::iter(affected_rows.into_iter().map(Ok));
609 Box::pin(stream) as _
610 }
611 }
612}
613
#[cfg(test)]
mod tests {
    use tonic::metadata::{AsciiMetadataValue, MetadataMap};

    use super::*;

    /// Builds request metadata whose flow-extensions entry holds `raw`.
    fn flow_extensions_metadata(raw: &str) -> MetadataMap {
        let value = AsciiMetadataValue::try_from(raw).unwrap();
        let mut metadata = MetadataMap::new();
        metadata.insert(FLOW_EXTENSIONS_METADATA_KEY, value);
        metadata
    }

    #[test]
    fn test_extract_flow_extensions_preserves_comma_bearing_values() {
        // A value that itself contains commas (a JSON object) must come back
        // intact rather than being split apart.
        let metadata = flow_extensions_metadata(
            r#"[["flow.return_region_seq","true"],["flow.incremental_after_seqs","{\"1\":10,\"2\":20}"]]"#,
        );

        let expected = vec![
            ("flow.return_region_seq".to_string(), "true".to_string()),
            (
                "flow.incremental_after_seqs".to_string(),
                r#"{"1":10,"2":20}"#.to_string(),
            ),
        ];
        assert_eq!(extract_flow_extensions(&metadata).unwrap(), expected);
    }

    #[test]
    fn test_extract_flow_extensions_rejects_invalid_json() {
        let metadata = flow_extensions_metadata("not-json");

        let err = extract_flow_extensions(&metadata).unwrap_err();
        assert_eq!(err.code(), tonic::Code::InvalidArgument);
    }
}