1mod stream;
16
17use std::collections::HashMap;
18use std::pin::Pin;
19use std::sync::Arc;
20use std::task::{Context, Poll};
21
22use api::v1::GreptimeRequest;
23use arrow_flight::flight_service_server::FlightService;
24use arrow_flight::{
25 Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
26 HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket,
27};
28use async_trait::async_trait;
29use bytes::{self, Bytes};
30use common_error::ext::ErrorExt;
31use common_grpc::flight::do_put::{DoPutMetadata, DoPutResponse};
32use common_grpc::flight::{
33 FLOW_EXTENSIONS_METADATA_KEY, FlightDecoder, FlightEncoder, FlightMessage,
34 SNAPSHOT_SEQS_METADATA_KEY,
35};
36use common_memory_manager::MemoryGuard;
37use common_query::{Output, OutputData};
38use common_recordbatch::DfRecordBatch;
39use common_telemetry::debug;
40use common_telemetry::tracing::info_span;
41use common_telemetry::tracing_context::{FutureExt, TracingContext};
42use datatypes::arrow::datatypes::SchemaRef;
43use futures::{Stream, future, ready};
44use futures_util::{StreamExt, TryStreamExt};
45use prost::Message;
46use query::metrics::terminal_recordbatch_metrics_from_plan_if_requested;
47use query::options::FlowQueryExtensions;
48use session::context::{Channel, QueryContextRef};
49use snafu::{IntoError, ResultExt, ensure};
50use table::table_name::TableName;
51use tokio::sync::mpsc;
52use tokio_stream::wrappers::ReceiverStream;
53use tonic::{Request, Response, Status, Streaming};
54
55use crate::error::{InvalidParameterSnafu, Result, ToJsonSnafu};
56pub use crate::grpc::flight::stream::FlightRecordBatchStream;
57use crate::grpc::greptime_handler::{
58 GreptimeRequestHandler, create_query_context, get_request_type,
59};
60use crate::grpc::{FlightCompression, TonicResult, context_auth};
61use crate::request_memory_limiter::ServerMemoryLimiter;
62use crate::request_memory_metrics::RequestMemoryMetrics;
63use crate::{error, hint_headers};
64
65pub type TonicStream<T> = Pin<Box<dyn Stream<Item = TonicResult<T>> + Send + 'static>>;
66
67#[async_trait]
69pub trait FlightCraft: Send + Sync + 'static {
70 async fn do_get(
71 &self,
72 request: Request<Ticket>,
73 ) -> TonicResult<Response<TonicStream<FlightData>>>;
74
75 async fn do_put(
76 &self,
77 request: Request<Streaming<FlightData>>,
78 ) -> TonicResult<Response<TonicStream<PutResult>>> {
79 let _ = request;
80 Err(Status::unimplemented("Not yet implemented"))
81 }
82}
83
84pub type FlightCraftRef = Arc<dyn FlightCraft>;
85
86pub struct FlightCraftWrapper<T: FlightCraft>(pub T);
87
88impl<T: FlightCraft> From<T> for FlightCraftWrapper<T> {
89 fn from(t: T) -> Self {
90 Self(t)
91 }
92}
93
94#[async_trait]
95impl FlightCraft for FlightCraftRef {
96 async fn do_get(
97 &self,
98 request: Request<Ticket>,
99 ) -> TonicResult<Response<TonicStream<FlightData>>> {
100 (**self).do_get(request).await
101 }
102
103 async fn do_put(
104 &self,
105 request: Request<Streaming<FlightData>>,
106 ) -> TonicResult<Response<TonicStream<PutResult>>> {
107 self.as_ref().do_put(request).await
108 }
109}
110
111#[async_trait]
112impl<T: FlightCraft> FlightService for FlightCraftWrapper<T> {
113 type HandshakeStream = TonicStream<HandshakeResponse>;
114
115 async fn handshake(
116 &self,
117 _: Request<Streaming<HandshakeRequest>>,
118 ) -> TonicResult<Response<Self::HandshakeStream>> {
119 Err(Status::unimplemented("Not yet implemented"))
120 }
121
122 type ListFlightsStream = TonicStream<FlightInfo>;
123
124 async fn list_flights(
125 &self,
126 _: Request<Criteria>,
127 ) -> TonicResult<Response<Self::ListFlightsStream>> {
128 Err(Status::unimplemented("Not yet implemented"))
129 }
130
131 async fn get_flight_info(
132 &self,
133 _: Request<FlightDescriptor>,
134 ) -> TonicResult<Response<FlightInfo>> {
135 Err(Status::unimplemented("Not yet implemented"))
136 }
137
138 async fn poll_flight_info(
139 &self,
140 _: Request<FlightDescriptor>,
141 ) -> TonicResult<Response<PollInfo>> {
142 Err(Status::unimplemented("Not yet implemented"))
143 }
144
145 async fn get_schema(
146 &self,
147 _: Request<FlightDescriptor>,
148 ) -> TonicResult<Response<SchemaResult>> {
149 Err(Status::unimplemented("Not yet implemented"))
150 }
151
152 type DoGetStream = TonicStream<FlightData>;
153
154 async fn do_get(&self, request: Request<Ticket>) -> TonicResult<Response<Self::DoGetStream>> {
155 self.0.do_get(request).await
156 }
157
158 type DoPutStream = TonicStream<PutResult>;
159
160 async fn do_put(
161 &self,
162 request: Request<Streaming<FlightData>>,
163 ) -> TonicResult<Response<Self::DoPutStream>> {
164 self.0.do_put(request).await
165 }
166
167 type DoExchangeStream = TonicStream<FlightData>;
168
169 async fn do_exchange(
170 &self,
171 _: Request<Streaming<FlightData>>,
172 ) -> TonicResult<Response<Self::DoExchangeStream>> {
173 Err(Status::unimplemented("Not yet implemented"))
174 }
175
176 type DoActionStream = TonicStream<arrow_flight::Result>;
177
178 async fn do_action(&self, _: Request<Action>) -> TonicResult<Response<Self::DoActionStream>> {
179 Err(Status::unimplemented("Not yet implemented"))
180 }
181
182 type ListActionsStream = TonicStream<ActionType>;
183
184 async fn list_actions(
185 &self,
186 _: Request<Empty>,
187 ) -> TonicResult<Response<Self::ListActionsStream>> {
188 Err(Status::unimplemented("Not yet implemented"))
189 }
190}
191
#[async_trait]
impl FlightCraft for GreptimeRequestHandler {
    /// Handles a Flight `DoGet` call whose ticket is a protobuf-encoded [`GreptimeRequest`].
    ///
    /// Query hints are read from the gRPC metadata, extended with the JSON-encoded flow
    /// extensions, and combined with the optional snapshot sequences and the request header
    /// into a query context. The request is then executed inside a
    /// `GreptimeRequestHandler::do_get` tracing span. Its output is converted to a
    /// `FlightData` stream by `to_flight_data_stream`.
    ///
    /// # Errors
    /// Returns `InvalidArgument` when the flow-extension / snapshot-sequence metadata, or the
    /// flow extensions in the resulting query context, cannot be parsed. Errors from ticket
    /// decoding, query-context creation and request execution are propagated.
    async fn do_get(
        &self,
        request: Request<Ticket>,
    ) -> TonicResult<Response<TonicStream<FlightData>>> {
        let mut hints = hint_headers::extract_hints(request.metadata());
        hints.extend(extract_flow_extensions(request.metadata())?);
        let snapshot_seqs = extract_snapshot_seqs(request.metadata())?;

        let ticket = request.into_inner().ticket;
        let request =
            GreptimeRequest::decode(ticket.as_ref()).context(error::InvalidFlightTicketSnafu)?;
        let query_ctx =
            create_query_context(Channel::Grpc, request.header.as_ref(), hints, snapshot_seqs)?;
        let flow_extensions = FlowQueryExtensions::parse_flow_extensions(&query_ctx.extensions())
            .map_err(|e| Status::invalid_argument(e.output_msg()))?;
        // Terminal metrics are only requested when the flow extensions ask for region
        // watermarks to be collected.
        let should_emit_terminal_metrics = flow_extensions
            .as_ref()
            .is_some_and(|extensions| extensions.should_collect_region_watermark());

        let span = info_span!(
            "GreptimeRequestHandler::do_get",
            protocol = "grpc",
            request_type = get_request_type(&request)
        );
        let flight_compression = self.flight_compression;
        async {
            let output = self
                .handle_request_with_query_ctx(request, query_ctx.clone())
                .await?;
            // Captured inside the future wrapped by `.trace(span)` below, so the tracing
            // context handed to the output stream is taken while that future is being polled.
            let stream = to_flight_data_stream(
                output,
                TracingContext::from_current_span(),
                flight_compression,
                query_ctx,
                should_emit_terminal_metrics,
            );
            Ok(Response::new(stream))
        }
        .trace(span)
        .await
    }

    /// Handles a Flight `DoPut` bulk-insert call.
    ///
    /// The query context is built from the gRPC metadata and authenticated before any data is
    /// read. The incoming `FlightData` is wrapped in a [`PutRecordBatchRequestStream`]. That
    /// stream is bound to the context's current catalog/schema and to the
    /// [`ServerMemoryLimiter`] found in the request extensions, if any. It is then handed to
    /// `put_record_batches` together with the sending half of a bounded response channel.
    /// Every `DoPutResponse` received on that channel is serialized to JSON and returned as
    /// the `app_metadata` of a `PutResult`.
    ///
    /// # Errors
    /// Propagates failures from query-context creation and authentication. JSON
    /// serialization failures of individual responses surface as error items in the
    /// returned stream.
    async fn do_put(
        &self,
        request: Request<Streaming<FlightData>>,
    ) -> TonicResult<Response<TonicStream<PutResult>>> {
        let (headers, extensions, stream) = request.into_parts();

        let limiter = extensions.get::<ServerMemoryLimiter>().cloned();

        let query_ctx = context_auth::create_query_context_from_grpc_metadata(&headers)?;
        context_auth::check_auth(self.user_provider.clone(), &headers, query_ctx.clone()).await?;

        const MAX_PENDING_RESPONSES: usize = 32;
        let (tx, rx) = mpsc::channel::<TonicResult<DoPutResponse>>(MAX_PENDING_RESPONSES);

        let stream = PutRecordBatchRequestStream::new(
            stream,
            query_ctx.current_catalog().to_string(),
            query_ctx.current_schema(),
            limiter,
        )
        .await?;
        // An initial all-zero response is queued before any batch is processed; a failed
        // send (receiver already gone) is deliberately ignored.
        let _ = tx.send(Ok(DoPutResponse::new(0, 0, 0.0))).await;
        // NOTE(review): this is awaited before the response stream is returned. Presumably
        // `put_record_batches` spawns the actual ingestion; if it processed batches inline,
        // more than MAX_PENDING_RESPONSES - 1 acknowledgements would block. Confirm.
        self.put_record_batches(stream, tx, query_ctx).await;

        let response = ReceiverStream::new(rx)
            .and_then(|response| {
                future::ready({
                    serde_json::to_vec(&response)
                        .context(ToJsonSnafu)
                        .map(|x| PutResult {
                            app_metadata: Bytes::from(x),
                        })
                        .map_err(Into::into)
                })
            })
            .boxed();
        Ok(Response::new(response))
    }
}
280
/// One decoded record batch from a `DoPut` stream, addressed to `table_name`.
pub struct PutRecordBatchRequest {
    /// Target table, qualified with the catalog/schema the stream was created with.
    pub table_name: TableName,
    /// Request id taken from the message's JSON `app_metadata`; `0` when the metadata is
    /// absent or cannot be parsed.
    pub request_id: i64,
    /// Optional timestamp range taken from the message's JSON `app_metadata`.
    pub timestamp_range: Option<(i64, i64)>,
    /// The decoded Arrow record batch.
    pub record_batch: DfRecordBatch,
    /// Encoded schema bytes captured when the stream's schema message was decoded.
    pub schema_bytes: Bytes,
    /// The raw Flight message the batch was decoded from.
    pub flight_data: FlightData,
    /// Memory reservation sized to `flight_data` (header + body + app metadata); `None` when
    /// no limiter is configured. Kept as the last field so it is dropped after the data.
    /// Presumably the reservation is released when the guard is dropped; confirm in
    /// `common_memory_manager`.
    pub(crate) _guard: Option<MemoryGuard<RequestMemoryMetrics>>,
}
290
291impl PutRecordBatchRequest {
292 fn try_new(
293 table_name: TableName,
294 record_batch: DfRecordBatch,
295 request_id: i64,
296 timestamp_range: Option<(i64, i64)>,
297 schema_bytes: Bytes,
298 flight_data: FlightData,
299 limiter: Option<&ServerMemoryLimiter>,
300 ) -> Result<Self> {
301 let memory_usage = flight_data.data_body.len()
302 + flight_data.app_metadata.len()
303 + flight_data.data_header.len();
304
305 let _guard = if let Some(limiter) = limiter {
306 let guard = limiter.try_acquire(memory_usage as u64).ok_or_else(|| {
307 let inner_err = common_memory_manager::Error::MemoryLimitExceeded {
308 requested_bytes: memory_usage as u64,
309 limit_bytes: limiter.limit_bytes(),
310 };
311 error::MemoryLimitExceededSnafu.into_error(inner_err)
312 })?;
313 Some(guard)
314 } else {
315 None
316 };
317
318 Ok(Self {
319 table_name,
320 request_id,
321 timestamp_range,
322 record_batch,
323 schema_bytes,
324 flight_data,
325 _guard,
326 })
327 }
328}
329
/// Adapter that turns the raw `FlightData` of a `DoPut` call into
/// [`PutRecordBatchRequest`]s.
pub struct PutRecordBatchRequestStream {
    /// Messages received from the client.
    flight_data_stream: Streaming<FlightData>,
    /// Catalog used to qualify the table named in the flight descriptor.
    catalog: String,
    /// Schema used to qualify the table named in the flight descriptor.
    schema_name: String,
    /// Optional limiter charged with the `FlightData` size of every yielded request.
    limiter: Option<ServerMemoryLimiter>,
    /// Whether the leading schema message has been consumed yet.
    state: StreamState,
}
339
/// Decoding progress of a [`PutRecordBatchRequestStream`].
enum StreamState {
    /// Nothing consumed yet: the next message must name the table in its flight descriptor
    /// and decode to a schema message.
    Init,
    /// The schema message has been decoded; subsequent messages are decoded as record batches.
    Ready {
        /// Target table resolved from the first message's flight descriptor.
        table_name: TableName,
        /// Arrow schema decoded from the first message.
        schema: SchemaRef,
        /// Encoded schema bytes taken from the decoder and attached to every request.
        schema_bytes: Bytes,
        /// Decoder that already saw the schema, reused for every following message.
        decoder: FlightDecoder,
    },
}
349
350impl PutRecordBatchRequestStream {
351 pub async fn new(
354 flight_data_stream: Streaming<FlightData>,
355 catalog: String,
356 schema: String,
357 limiter: Option<ServerMemoryLimiter>,
358 ) -> TonicResult<Self> {
359 Ok(Self {
360 flight_data_stream,
361 catalog,
362 schema_name: schema,
363 limiter,
364 state: StreamState::Init,
365 })
366 }
367
368 pub fn table_name(&self) -> Option<&TableName> {
371 match &self.state {
372 StreamState::Init => None,
373 StreamState::Ready { table_name, .. } => Some(table_name),
374 }
375 }
376
377 pub fn schema(&self) -> Option<&SchemaRef> {
380 match &self.state {
381 StreamState::Init => None,
382 StreamState::Ready { schema, .. } => Some(schema),
383 }
384 }
385
386 pub fn schema_bytes(&self) -> Option<&Bytes> {
389 match &self.state {
390 StreamState::Init => None,
391 StreamState::Ready { schema_bytes, .. } => Some(schema_bytes),
392 }
393 }
394
395 fn extract_table_name(mut descriptor: FlightDescriptor) -> Result<String> {
396 ensure!(
397 descriptor.r#type == arrow_flight::flight_descriptor::DescriptorType::Path as i32,
398 InvalidParameterSnafu {
399 reason: "expect FlightDescriptor::type == 'Path' only",
400 }
401 );
402 ensure!(
403 descriptor.path.len() == 1,
404 InvalidParameterSnafu {
405 reason: "expect FlightDescriptor::path has only one table name",
406 }
407 );
408 Ok(descriptor.path.remove(0))
409 }
410}
411
412impl Stream for PutRecordBatchRequestStream {
413 type Item = TonicResult<PutRecordBatchRequest>;
414
415 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
416 loop {
417 let poll = ready!(self.flight_data_stream.poll_next_unpin(cx));
418
419 match poll {
420 Some(Ok(flight_data)) => {
421 let limiter = self.limiter.clone();
422
423 match &mut self.state {
424 StreamState::Init => {
425 let flight_descriptor = match flight_data.flight_descriptor.as_ref() {
427 Some(descriptor) => descriptor.clone(),
428 None => {
429 return Poll::Ready(Some(Err(Status::failed_precondition(
430 "table to put is not found in flight descriptor",
431 ))));
432 }
433 };
434
435 let table_name_str = match Self::extract_table_name(flight_descriptor) {
436 Ok(name) => name,
437 Err(e) => {
438 return Poll::Ready(Some(Err(Status::invalid_argument(
439 e.to_string(),
440 ))));
441 }
442 };
443 let table_name = TableName::new(
444 self.catalog.clone(),
445 self.schema_name.clone(),
446 table_name_str,
447 );
448
449 let mut decoder = FlightDecoder::default();
451 let schema_message = decoder.try_decode(&flight_data).map_err(|e| {
452 Status::invalid_argument(format!("Failed to decode schema: {}", e))
453 })?;
454
455 match schema_message {
456 Some(FlightMessage::Schema(schema)) => {
457 let schema_bytes = decoder.schema_bytes().ok_or_else(|| {
458 Status::internal(
459 "decoder should have schema bytes after decoding schema",
460 )
461 })?;
462
463 self.state = StreamState::Ready {
465 table_name,
466 schema,
467 schema_bytes,
468 decoder,
469 };
470 continue;
472 }
473 _ => {
474 return Poll::Ready(Some(Err(Status::failed_precondition(
475 "first message must be a Schema message",
476 ))));
477 }
478 }
479 }
480 StreamState::Ready {
481 table_name,
482 schema: _,
483 schema_bytes,
484 decoder,
485 } => {
486 let metadata = if !flight_data.app_metadata.is_empty() {
488 serde_json::from_slice::<DoPutMetadata>(&flight_data.app_metadata)
489 .ok()
490 } else {
491 None
492 };
493 let request_id = metadata
494 .as_ref()
495 .map(|meta| meta.request_id())
496 .unwrap_or_default();
497 let timestamp_range = metadata.and_then(|meta| meta.timestamp_range());
498
499 match decoder.try_decode(&flight_data) {
501 Ok(Some(FlightMessage::RecordBatch(record_batch))) => {
502 let table_name = table_name.clone();
503 let schema_bytes = schema_bytes.clone();
504 return Poll::Ready(Some(
505 PutRecordBatchRequest::try_new(
506 table_name,
507 record_batch,
508 request_id,
509 timestamp_range,
510 schema_bytes,
511 flight_data,
512 limiter.as_ref(),
513 )
514 .map_err(|e| Status::invalid_argument(e.to_string())),
515 ));
516 }
517 Ok(Some(other)) => {
518 debug!("Unexpected flight message: {:?}", other);
519 return Poll::Ready(Some(Err(Status::invalid_argument(
520 "Expected RecordBatch message, got other message type",
521 ))));
522 }
523 Ok(None) => {
524 continue;
526 }
527 Err(e) => {
528 return Poll::Ready(Some(Err(Status::invalid_argument(
529 format!("Failed to decode RecordBatch: {}", e),
530 ))));
531 }
532 }
533 }
534 }
535 }
536 Some(Err(e)) => {
537 return Poll::Ready(Some(Err(e)));
538 }
539 None => {
540 return Poll::Ready(None);
541 }
542 }
543 }
544 }
545}
546
547fn extract_flow_extensions(
548 metadata: &tonic::metadata::MetadataMap,
549) -> TonicResult<Vec<(String, String)>> {
550 Ok(extract_json_metadata(metadata, FLOW_EXTENSIONS_METADATA_KEY)?.unwrap_or_default())
551}
552
553fn extract_snapshot_seqs(
554 metadata: &tonic::metadata::MetadataMap,
555) -> TonicResult<HashMap<u64, u64>> {
556 Ok(extract_json_metadata(metadata, SNAPSHOT_SEQS_METADATA_KEY)?.unwrap_or_default())
557}
558
559fn extract_json_metadata<T: serde::de::DeserializeOwned>(
560 metadata: &tonic::metadata::MetadataMap,
561 key: &'static str,
562) -> TonicResult<Option<T>> {
563 let Some(value) = metadata.get(key) else {
564 return Ok(None);
565 };
566
567 let value = value
568 .to_str()
569 .map_err(|e| Status::invalid_argument(format!("Invalid {key} metadata value: {e}")))?;
570
571 let parsed = serde_json::from_str::<T>(value)
572 .map_err(|e| Status::invalid_argument(format!("Invalid {key} metadata JSON: {e}")))?;
573 Ok(Some(parsed))
574}
575
576fn to_flight_data_stream(
577 output: Output,
578 tracing_context: TracingContext,
579 flight_compression: FlightCompression,
580 query_ctx: QueryContextRef,
581 should_emit_terminal_metrics: bool,
582) -> TonicStream<FlightData> {
583 match output.data {
584 OutputData::Stream(stream) => {
585 let stream = FlightRecordBatchStream::new(
586 stream,
587 tracing_context,
588 flight_compression,
589 query_ctx,
590 );
591 Box::pin(stream) as _
592 }
593 OutputData::RecordBatches(x) => {
594 let stream = FlightRecordBatchStream::new(
595 x.as_stream(),
596 tracing_context,
597 flight_compression,
598 query_ctx,
599 );
600 Box::pin(stream) as _
601 }
602 OutputData::AffectedRows(rows) => {
603 let terminal_metrics = match terminal_recordbatch_metrics_from_plan_if_requested(
604 output.meta.plan,
605 should_emit_terminal_metrics,
606 ) {
607 Some(metrics) => match serde_json::to_string(&metrics) {
608 Ok(metrics) => Some(metrics),
609 Err(e) => {
610 let stream = tokio_stream::once(Err(Status::internal(format!(
611 "Failed to serialize terminal metrics: {e}"
612 ))));
613 return Box::pin(stream) as _;
614 }
615 },
616 None => None,
617 };
618 let affected_rows = FlightEncoder::default().encode(FlightMessage::AffectedRows {
619 rows,
620 metrics: terminal_metrics,
621 });
622 let stream = tokio_stream::iter(affected_rows.into_iter().map(Ok));
623 Box::pin(stream) as _
624 }
625 }
626}
627
#[cfg(test)]
mod tests {
    use query::options::FLOW_SCHEDULED_TIME_MILLIS;
    use tonic::metadata::{AsciiMetadataValue, MetadataMap};

    use super::*;

    /// Builds a metadata map whose flow-extension entry holds `json` verbatim.
    fn flow_extensions_metadata(json: &str) -> MetadataMap {
        let value = AsciiMetadataValue::try_from(json).unwrap();
        let mut metadata = MetadataMap::new();
        metadata.insert(FLOW_EXTENSIONS_METADATA_KEY, value);
        metadata
    }

    #[test]
    fn test_extract_flow_extensions_preserves_comma_bearing_values() {
        let metadata = flow_extensions_metadata(
            r#"[["flow.return_region_seq","true"],["flow.incremental_after_seqs","{\"1\":10,\"2\":20}"]]"#,
        );
        let expected = vec![
            ("flow.return_region_seq".to_string(), "true".to_string()),
            (
                "flow.incremental_after_seqs".to_string(),
                r#"{"1":10,"2":20}"#.to_string(),
            ),
        ];

        assert_eq!(extract_flow_extensions(&metadata).unwrap(), expected);
    }

    #[test]
    fn test_flow_extensions_can_carry_scheduled_time() {
        let metadata =
            flow_extensions_metadata(r#"[["flow.scheduled_time_millis","1700000000000"]]"#);

        let extensions = extract_flow_extensions(&metadata).unwrap();
        let query_ctx =
            create_query_context(Channel::Grpc, None, extensions, HashMap::new()).unwrap();

        assert_eq!(
            query_ctx.extension(FLOW_SCHEDULED_TIME_MILLIS),
            Some("1700000000000")
        );
    }

    #[test]
    fn test_extract_flow_extensions_rejects_invalid_json() {
        let metadata = flow_extensions_metadata("not-json");

        let status = extract_flow_extensions(&metadata).unwrap_err();
        assert_eq!(status.code(), tonic::Code::InvalidArgument);
    }
}
689}