Skip to main content

servers/grpc/
flight.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15mod stream;
16
17use std::collections::HashMap;
18use std::pin::Pin;
19use std::sync::Arc;
20use std::task::{Context, Poll};
21
22use api::v1::GreptimeRequest;
23use arrow_flight::flight_service_server::FlightService;
24use arrow_flight::{
25    Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
26    HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket,
27};
28use async_trait::async_trait;
29use bytes::{self, Bytes};
30use common_error::ext::ErrorExt;
31use common_grpc::flight::do_put::{DoPutMetadata, DoPutResponse};
32use common_grpc::flight::{
33    FLOW_EXTENSIONS_METADATA_KEY, FlightDecoder, FlightEncoder, FlightMessage,
34    SNAPSHOT_SEQS_METADATA_KEY,
35};
36use common_memory_manager::MemoryGuard;
37use common_query::{Output, OutputData};
38use common_recordbatch::DfRecordBatch;
39use common_telemetry::debug;
40use common_telemetry::tracing::info_span;
41use common_telemetry::tracing_context::{FutureExt, TracingContext};
42use datatypes::arrow::datatypes::SchemaRef;
43use futures::{Stream, future, ready};
44use futures_util::{StreamExt, TryStreamExt};
45use prost::Message;
46use query::metrics::terminal_recordbatch_metrics_from_plan_if_requested;
47use query::options::FlowQueryExtensions;
48use session::context::{Channel, QueryContextRef};
49use snafu::{IntoError, ResultExt, ensure};
50use table::table_name::TableName;
51use tokio::sync::mpsc;
52use tokio_stream::wrappers::ReceiverStream;
53use tonic::{Request, Response, Status, Streaming};
54
55use crate::error::{InvalidParameterSnafu, Result, ToJsonSnafu};
56pub use crate::grpc::flight::stream::FlightRecordBatchStream;
57use crate::grpc::greptime_handler::{
58    GreptimeRequestHandler, create_query_context, get_request_type,
59};
60use crate::grpc::{FlightCompression, TonicResult, context_auth};
61use crate::request_memory_limiter::ServerMemoryLimiter;
62use crate::request_memory_metrics::RequestMemoryMetrics;
63use crate::{error, hint_headers};
64
65pub type TonicStream<T> = Pin<Box<dyn Stream<Item = TonicResult<T>> + Send + 'static>>;
66
67/// A subset of [FlightService]
68#[async_trait]
69pub trait FlightCraft: Send + Sync + 'static {
70    async fn do_get(
71        &self,
72        request: Request<Ticket>,
73    ) -> TonicResult<Response<TonicStream<FlightData>>>;
74
75    async fn do_put(
76        &self,
77        request: Request<Streaming<FlightData>>,
78    ) -> TonicResult<Response<TonicStream<PutResult>>> {
79        let _ = request;
80        Err(Status::unimplemented("Not yet implemented"))
81    }
82}
83
84pub type FlightCraftRef = Arc<dyn FlightCraft>;
85
86pub struct FlightCraftWrapper<T: FlightCraft>(pub T);
87
88impl<T: FlightCraft> From<T> for FlightCraftWrapper<T> {
89    fn from(t: T) -> Self {
90        Self(t)
91    }
92}
93
94#[async_trait]
95impl FlightCraft for FlightCraftRef {
96    async fn do_get(
97        &self,
98        request: Request<Ticket>,
99    ) -> TonicResult<Response<TonicStream<FlightData>>> {
100        (**self).do_get(request).await
101    }
102
103    async fn do_put(
104        &self,
105        request: Request<Streaming<FlightData>>,
106    ) -> TonicResult<Response<TonicStream<PutResult>>> {
107        self.as_ref().do_put(request).await
108    }
109}
110
111#[async_trait]
112impl<T: FlightCraft> FlightService for FlightCraftWrapper<T> {
113    type HandshakeStream = TonicStream<HandshakeResponse>;
114
115    async fn handshake(
116        &self,
117        _: Request<Streaming<HandshakeRequest>>,
118    ) -> TonicResult<Response<Self::HandshakeStream>> {
119        Err(Status::unimplemented("Not yet implemented"))
120    }
121
122    type ListFlightsStream = TonicStream<FlightInfo>;
123
124    async fn list_flights(
125        &self,
126        _: Request<Criteria>,
127    ) -> TonicResult<Response<Self::ListFlightsStream>> {
128        Err(Status::unimplemented("Not yet implemented"))
129    }
130
131    async fn get_flight_info(
132        &self,
133        _: Request<FlightDescriptor>,
134    ) -> TonicResult<Response<FlightInfo>> {
135        Err(Status::unimplemented("Not yet implemented"))
136    }
137
138    async fn poll_flight_info(
139        &self,
140        _: Request<FlightDescriptor>,
141    ) -> TonicResult<Response<PollInfo>> {
142        Err(Status::unimplemented("Not yet implemented"))
143    }
144
145    async fn get_schema(
146        &self,
147        _: Request<FlightDescriptor>,
148    ) -> TonicResult<Response<SchemaResult>> {
149        Err(Status::unimplemented("Not yet implemented"))
150    }
151
152    type DoGetStream = TonicStream<FlightData>;
153
154    async fn do_get(&self, request: Request<Ticket>) -> TonicResult<Response<Self::DoGetStream>> {
155        self.0.do_get(request).await
156    }
157
158    type DoPutStream = TonicStream<PutResult>;
159
160    async fn do_put(
161        &self,
162        request: Request<Streaming<FlightData>>,
163    ) -> TonicResult<Response<Self::DoPutStream>> {
164        self.0.do_put(request).await
165    }
166
167    type DoExchangeStream = TonicStream<FlightData>;
168
169    async fn do_exchange(
170        &self,
171        _: Request<Streaming<FlightData>>,
172    ) -> TonicResult<Response<Self::DoExchangeStream>> {
173        Err(Status::unimplemented("Not yet implemented"))
174    }
175
176    type DoActionStream = TonicStream<arrow_flight::Result>;
177
178    async fn do_action(&self, _: Request<Action>) -> TonicResult<Response<Self::DoActionStream>> {
179        Err(Status::unimplemented("Not yet implemented"))
180    }
181
182    type ListActionsStream = TonicStream<ActionType>;
183
184    async fn list_actions(
185        &self,
186        _: Request<Empty>,
187    ) -> TonicResult<Response<Self::ListActionsStream>> {
188        Err(Status::unimplemented("Not yet implemented"))
189    }
190}
191
192#[async_trait]
193impl FlightCraft for GreptimeRequestHandler {
194    async fn do_get(
195        &self,
196        request: Request<Ticket>,
197    ) -> TonicResult<Response<TonicStream<FlightData>>> {
198        let mut hints = hint_headers::extract_hints(request.metadata());
199        hints.extend(extract_flow_extensions(request.metadata())?);
200        let snapshot_seqs = extract_snapshot_seqs(request.metadata())?;
201
202        let ticket = request.into_inner().ticket;
203        let request =
204            GreptimeRequest::decode(ticket.as_ref()).context(error::InvalidFlightTicketSnafu)?;
205        let query_ctx =
206            create_query_context(Channel::Grpc, request.header.as_ref(), hints, snapshot_seqs)?;
207        // Validate flow hint syntax at the transport boundary before dispatching the request.
208        // This does not authorize or execute anything; `handle_request()` below still performs
209        // the normal frontend handling and auth checks before query execution.
210        let flow_extensions = FlowQueryExtensions::parse_flow_extensions(&query_ctx.extensions())
211            .map_err(|e| Status::invalid_argument(e.output_msg()))?;
212        let should_emit_terminal_metrics = flow_extensions
213            .as_ref()
214            .is_some_and(|extensions| extensions.should_collect_region_watermark());
215
216        // The Grpc protocol pass query by Flight. It needs to be wrapped under a span, in order to record stream
217        let span = info_span!(
218            "GreptimeRequestHandler::do_get",
219            protocol = "grpc",
220            request_type = get_request_type(&request)
221        );
222        let flight_compression = self.flight_compression;
223        async {
224            let output = self
225                .handle_request_with_query_ctx(request, query_ctx.clone())
226                .await?;
227            let stream = to_flight_data_stream(
228                output,
229                TracingContext::from_current_span(),
230                flight_compression,
231                query_ctx,
232                should_emit_terminal_metrics,
233            );
234            Ok(Response::new(stream))
235        }
236        .trace(span)
237        .await
238    }
239
240    async fn do_put(
241        &self,
242        request: Request<Streaming<FlightData>>,
243    ) -> TonicResult<Response<TonicStream<PutResult>>> {
244        let (headers, extensions, stream) = request.into_parts();
245
246        let limiter = extensions.get::<ServerMemoryLimiter>().cloned();
247
248        let query_ctx = context_auth::create_query_context_from_grpc_metadata(&headers)?;
249        context_auth::check_auth(self.user_provider.clone(), &headers, query_ctx.clone()).await?;
250
251        const MAX_PENDING_RESPONSES: usize = 32;
252        let (tx, rx) = mpsc::channel::<TonicResult<DoPutResponse>>(MAX_PENDING_RESPONSES);
253
254        let stream = PutRecordBatchRequestStream::new(
255            stream,
256            query_ctx.current_catalog().to_string(),
257            query_ctx.current_schema(),
258            limiter,
259        )
260        .await?;
261        // Ack immediately when stream is created successfully (in Init state)
262        let _ = tx.send(Ok(DoPutResponse::new(0, 0, 0.0))).await;
263        self.put_record_batches(stream, tx, query_ctx).await;
264
265        let response = ReceiverStream::new(rx)
266            .and_then(|response| {
267                future::ready({
268                    serde_json::to_vec(&response)
269                        .context(ToJsonSnafu)
270                        .map(|x| PutResult {
271                            app_metadata: Bytes::from(x),
272                        })
273                        .map_err(Into::into)
274                })
275            })
276            .boxed();
277        Ok(Response::new(response))
278    }
279}
280
281pub struct PutRecordBatchRequest {
282    pub table_name: TableName,
283    pub request_id: i64,
284    pub timestamp_range: Option<(i64, i64)>,
285    pub record_batch: DfRecordBatch,
286    pub schema_bytes: Bytes,
287    pub flight_data: FlightData,
288    pub(crate) _guard: Option<MemoryGuard<RequestMemoryMetrics>>,
289}
290
291impl PutRecordBatchRequest {
292    fn try_new(
293        table_name: TableName,
294        record_batch: DfRecordBatch,
295        request_id: i64,
296        timestamp_range: Option<(i64, i64)>,
297        schema_bytes: Bytes,
298        flight_data: FlightData,
299        limiter: Option<&ServerMemoryLimiter>,
300    ) -> Result<Self> {
301        let memory_usage = flight_data.data_body.len()
302            + flight_data.app_metadata.len()
303            + flight_data.data_header.len();
304
305        let _guard = if let Some(limiter) = limiter {
306            let guard = limiter.try_acquire(memory_usage as u64).ok_or_else(|| {
307                let inner_err = common_memory_manager::Error::MemoryLimitExceeded {
308                    requested_bytes: memory_usage as u64,
309                    limit_bytes: limiter.limit_bytes(),
310                };
311                error::MemoryLimitExceededSnafu.into_error(inner_err)
312            })?;
313            Some(guard)
314        } else {
315            None
316        };
317
318        Ok(Self {
319            table_name,
320            request_id,
321            timestamp_range,
322            record_batch,
323            schema_bytes,
324            flight_data,
325            _guard,
326        })
327    }
328}
329
330pub struct PutRecordBatchRequestStream {
331    flight_data_stream: Streaming<FlightData>,
332    catalog: String,
333    schema_name: String,
334    limiter: Option<ServerMemoryLimiter>,
335    // Client now lazily sends schema data so we cannot eagerly wait for it.
336    // Instead, we need to decode while receiving record batches.
337    state: StreamState,
338}
339
340enum StreamState {
341    Init,
342    Ready {
343        table_name: TableName,
344        schema: SchemaRef,
345        schema_bytes: Bytes,
346        decoder: FlightDecoder,
347    },
348}
349
350impl PutRecordBatchRequestStream {
351    /// Creates a new `PutRecordBatchRequestStream` in Init state.
352    /// The stream will transition to Ready state when it receives the schema message.
353    pub async fn new(
354        flight_data_stream: Streaming<FlightData>,
355        catalog: String,
356        schema: String,
357        limiter: Option<ServerMemoryLimiter>,
358    ) -> TonicResult<Self> {
359        Ok(Self {
360            flight_data_stream,
361            catalog,
362            schema_name: schema,
363            limiter,
364            state: StreamState::Init,
365        })
366    }
367
368    /// Returns the table name extracted from the flight descriptor.
369    /// Returns None if the stream is still in Init state.
370    pub fn table_name(&self) -> Option<&TableName> {
371        match &self.state {
372            StreamState::Init => None,
373            StreamState::Ready { table_name, .. } => Some(table_name),
374        }
375    }
376
377    /// Returns the Arrow schema decoded from the first flight message.
378    /// Returns None if the stream is still in Init state.
379    pub fn schema(&self) -> Option<&SchemaRef> {
380        match &self.state {
381            StreamState::Init => None,
382            StreamState::Ready { schema, .. } => Some(schema),
383        }
384    }
385
386    /// Returns the raw schema bytes in IPC format.
387    /// Returns None if the stream is still in Init state.
388    pub fn schema_bytes(&self) -> Option<&Bytes> {
389        match &self.state {
390            StreamState::Init => None,
391            StreamState::Ready { schema_bytes, .. } => Some(schema_bytes),
392        }
393    }
394
395    fn extract_table_name(mut descriptor: FlightDescriptor) -> Result<String> {
396        ensure!(
397            descriptor.r#type == arrow_flight::flight_descriptor::DescriptorType::Path as i32,
398            InvalidParameterSnafu {
399                reason: "expect FlightDescriptor::type == 'Path' only",
400            }
401        );
402        ensure!(
403            descriptor.path.len() == 1,
404            InvalidParameterSnafu {
405                reason: "expect FlightDescriptor::path has only one table name",
406            }
407        );
408        Ok(descriptor.path.remove(0))
409    }
410}
411
412impl Stream for PutRecordBatchRequestStream {
413    type Item = TonicResult<PutRecordBatchRequest>;
414
415    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
416        loop {
417            let poll = ready!(self.flight_data_stream.poll_next_unpin(cx));
418
419            match poll {
420                Some(Ok(flight_data)) => {
421                    let limiter = self.limiter.clone();
422
423                    match &mut self.state {
424                        StreamState::Init => {
425                            // First message - expecting schema
426                            let flight_descriptor = match flight_data.flight_descriptor.as_ref() {
427                                Some(descriptor) => descriptor.clone(),
428                                None => {
429                                    return Poll::Ready(Some(Err(Status::failed_precondition(
430                                        "table to put is not found in flight descriptor",
431                                    ))));
432                                }
433                            };
434
435                            let table_name_str = match Self::extract_table_name(flight_descriptor) {
436                                Ok(name) => name,
437                                Err(e) => {
438                                    return Poll::Ready(Some(Err(Status::invalid_argument(
439                                        e.to_string(),
440                                    ))));
441                                }
442                            };
443                            let table_name = TableName::new(
444                                self.catalog.clone(),
445                                self.schema_name.clone(),
446                                table_name_str,
447                            );
448
449                            // Decode the schema
450                            let mut decoder = FlightDecoder::default();
451                            let schema_message = decoder.try_decode(&flight_data).map_err(|e| {
452                                Status::invalid_argument(format!("Failed to decode schema: {}", e))
453                            })?;
454
455                            match schema_message {
456                                Some(FlightMessage::Schema(schema)) => {
457                                    let schema_bytes = decoder.schema_bytes().ok_or_else(|| {
458                                        Status::internal(
459                                            "decoder should have schema bytes after decoding schema",
460                                        )
461                                    })?;
462
463                                    // Transition to Ready state with all necessary data
464                                    self.state = StreamState::Ready {
465                                        table_name,
466                                        schema,
467                                        schema_bytes,
468                                        decoder,
469                                    };
470                                    // Continue to next iteration to process RecordBatch messages
471                                    continue;
472                                }
473                                _ => {
474                                    return Poll::Ready(Some(Err(Status::failed_precondition(
475                                        "first message must be a Schema message",
476                                    ))));
477                                }
478                            }
479                        }
480                        StreamState::Ready {
481                            table_name,
482                            schema: _,
483                            schema_bytes,
484                            decoder,
485                        } => {
486                            // Extract request_id and time range from FlightData before decoding
487                            let metadata = if !flight_data.app_metadata.is_empty() {
488                                serde_json::from_slice::<DoPutMetadata>(&flight_data.app_metadata)
489                                    .ok()
490                            } else {
491                                None
492                            };
493                            let request_id = metadata
494                                .as_ref()
495                                .map(|meta| meta.request_id())
496                                .unwrap_or_default();
497                            let timestamp_range = metadata.and_then(|meta| meta.timestamp_range());
498
499                            // Decode FlightData to RecordBatch
500                            match decoder.try_decode(&flight_data) {
501                                Ok(Some(FlightMessage::RecordBatch(record_batch))) => {
502                                    let table_name = table_name.clone();
503                                    let schema_bytes = schema_bytes.clone();
504                                    return Poll::Ready(Some(
505                                        PutRecordBatchRequest::try_new(
506                                            table_name,
507                                            record_batch,
508                                            request_id,
509                                            timestamp_range,
510                                            schema_bytes,
511                                            flight_data,
512                                            limiter.as_ref(),
513                                        )
514                                        .map_err(|e| Status::invalid_argument(e.to_string())),
515                                    ));
516                                }
517                                Ok(Some(other)) => {
518                                    debug!("Unexpected flight message: {:?}", other);
519                                    return Poll::Ready(Some(Err(Status::invalid_argument(
520                                        "Expected RecordBatch message, got other message type",
521                                    ))));
522                                }
523                                Ok(None) => {
524                                    // Dictionary batch - processed internally by decoder, continue polling
525                                    continue;
526                                }
527                                Err(e) => {
528                                    return Poll::Ready(Some(Err(Status::invalid_argument(
529                                        format!("Failed to decode RecordBatch: {}", e),
530                                    ))));
531                                }
532                            }
533                        }
534                    }
535                }
536                Some(Err(e)) => {
537                    return Poll::Ready(Some(Err(e)));
538                }
539                None => {
540                    return Poll::Ready(None);
541                }
542            }
543        }
544    }
545}
546
547fn extract_flow_extensions(
548    metadata: &tonic::metadata::MetadataMap,
549) -> TonicResult<Vec<(String, String)>> {
550    Ok(extract_json_metadata(metadata, FLOW_EXTENSIONS_METADATA_KEY)?.unwrap_or_default())
551}
552
553fn extract_snapshot_seqs(
554    metadata: &tonic::metadata::MetadataMap,
555) -> TonicResult<HashMap<u64, u64>> {
556    Ok(extract_json_metadata(metadata, SNAPSHOT_SEQS_METADATA_KEY)?.unwrap_or_default())
557}
558
559fn extract_json_metadata<T: serde::de::DeserializeOwned>(
560    metadata: &tonic::metadata::MetadataMap,
561    key: &'static str,
562) -> TonicResult<Option<T>> {
563    let Some(value) = metadata.get(key) else {
564        return Ok(None);
565    };
566
567    let value = value
568        .to_str()
569        .map_err(|e| Status::invalid_argument(format!("Invalid {key} metadata value: {e}")))?;
570
571    let parsed = serde_json::from_str::<T>(value)
572        .map_err(|e| Status::invalid_argument(format!("Invalid {key} metadata JSON: {e}")))?;
573    Ok(Some(parsed))
574}
575
576fn to_flight_data_stream(
577    output: Output,
578    tracing_context: TracingContext,
579    flight_compression: FlightCompression,
580    query_ctx: QueryContextRef,
581    should_emit_terminal_metrics: bool,
582) -> TonicStream<FlightData> {
583    match output.data {
584        OutputData::Stream(stream) => {
585            let stream = FlightRecordBatchStream::new(
586                stream,
587                tracing_context,
588                flight_compression,
589                query_ctx,
590            );
591            Box::pin(stream) as _
592        }
593        OutputData::RecordBatches(x) => {
594            let stream = FlightRecordBatchStream::new(
595                x.as_stream(),
596                tracing_context,
597                flight_compression,
598                query_ctx,
599            );
600            Box::pin(stream) as _
601        }
602        OutputData::AffectedRows(rows) => {
603            let terminal_metrics = match terminal_recordbatch_metrics_from_plan_if_requested(
604                output.meta.plan,
605                should_emit_terminal_metrics,
606            ) {
607                Some(metrics) => match serde_json::to_string(&metrics) {
608                    Ok(metrics) => Some(metrics),
609                    Err(e) => {
610                        let stream = tokio_stream::once(Err(Status::internal(format!(
611                            "Failed to serialize terminal metrics: {e}"
612                        ))));
613                        return Box::pin(stream) as _;
614                    }
615                },
616                None => None,
617            };
618            let affected_rows = FlightEncoder::default().encode(FlightMessage::AffectedRows {
619                rows,
620                metrics: terminal_metrics,
621            });
622            let stream = tokio_stream::iter(affected_rows.into_iter().map(Ok));
623            Box::pin(stream) as _
624        }
625    }
626}
627
#[cfg(test)]
mod tests {
    use query::options::FLOW_SCHEDULED_TIME_MILLIS;
    use tonic::metadata::{AsciiMetadataValue, MetadataMap};

    use super::*;

    /// Builds a metadata map whose only entry is the flow-extensions header set to `raw`.
    fn metadata_with_flow_extensions(raw: &str) -> MetadataMap {
        let mut metadata = MetadataMap::new();
        metadata.insert(
            FLOW_EXTENSIONS_METADATA_KEY,
            AsciiMetadataValue::try_from(raw).unwrap(),
        );
        metadata
    }

    /// Values that themselves contain commas (here: embedded JSON) must survive extraction.
    #[test]
    fn test_extract_flow_extensions_preserves_comma_bearing_values() {
        let metadata = metadata_with_flow_extensions(
            r#"[["flow.return_region_seq","true"],["flow.incremental_after_seqs","{\"1\":10,\"2\":20}"]]"#,
        );

        let expected = vec![
            ("flow.return_region_seq".to_string(), "true".to_string()),
            (
                "flow.incremental_after_seqs".to_string(),
                r#"{"1":10,"2":20}"#.to_string(),
            ),
        ];
        assert_eq!(extract_flow_extensions(&metadata).unwrap(), expected);
    }

    /// A scheduled time passed as a flow extension must be visible on the query context.
    #[test]
    fn test_flow_extensions_can_carry_scheduled_time() {
        let metadata =
            metadata_with_flow_extensions(r#"[["flow.scheduled_time_millis","1700000000000"]]"#);

        let flow_extensions = extract_flow_extensions(&metadata).unwrap();
        let query_ctx =
            create_query_context(Channel::Grpc, None, flow_extensions, HashMap::new()).unwrap();

        assert_eq!(
            query_ctx.extension(FLOW_SCHEDULED_TIME_MILLIS),
            Some("1700000000000")
        );
    }

    /// A header value that is not JSON must be rejected with `INVALID_ARGUMENT`.
    #[test]
    fn test_extract_flow_extensions_rejects_invalid_json() {
        let metadata = metadata_with_flow_extensions("not-json");

        let status = extract_flow_extensions(&metadata).unwrap_err();
        assert_eq!(status.code(), tonic::Code::InvalidArgument);
    }
}