Skip to main content

servers/grpc/
flight.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15mod stream;
16
17use std::pin::Pin;
18use std::sync::Arc;
19use std::task::{Context, Poll};
20
21use api::v1::GreptimeRequest;
22use arrow_flight::flight_service_server::FlightService;
23use arrow_flight::{
24    Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
25    HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket,
26};
27use async_trait::async_trait;
28use bytes::{self, Bytes};
29use common_error::ext::ErrorExt;
30use common_grpc::flight::do_put::{DoPutMetadata, DoPutResponse};
31use common_grpc::flight::{
32    FLOW_EXTENSIONS_METADATA_KEY, FlightDecoder, FlightEncoder, FlightMessage,
33};
34use common_memory_manager::MemoryGuard;
35use common_query::{Output, OutputData};
36use common_recordbatch::DfRecordBatch;
37use common_telemetry::debug;
38use common_telemetry::tracing::info_span;
39use common_telemetry::tracing_context::{FutureExt, TracingContext};
40use datatypes::arrow::datatypes::SchemaRef;
41use futures::{Stream, future, ready};
42use futures_util::{StreamExt, TryStreamExt};
43use prost::Message;
44use query::metrics::terminal_recordbatch_metrics_from_plan_if_requested;
45use query::options::FlowQueryExtensions;
46use session::context::{Channel, QueryContextRef};
47use snafu::{IntoError, ResultExt, ensure};
48use table::table_name::TableName;
49use tokio::sync::mpsc;
50use tokio_stream::wrappers::ReceiverStream;
51use tonic::{Request, Response, Status, Streaming};
52
53use crate::error::{InvalidParameterSnafu, Result, ToJsonSnafu};
54pub use crate::grpc::flight::stream::FlightRecordBatchStream;
55use crate::grpc::greptime_handler::{
56    GreptimeRequestHandler, create_query_context, get_request_type,
57};
58use crate::grpc::{FlightCompression, TonicResult, context_auth};
59use crate::request_memory_limiter::ServerMemoryLimiter;
60use crate::request_memory_metrics::RequestMemoryMetrics;
61use crate::{error, hint_headers};
62
63pub type TonicStream<T> = Pin<Box<dyn Stream<Item = TonicResult<T>> + Send + 'static>>;
64
65/// A subset of [FlightService]
66#[async_trait]
67pub trait FlightCraft: Send + Sync + 'static {
68    async fn do_get(
69        &self,
70        request: Request<Ticket>,
71    ) -> TonicResult<Response<TonicStream<FlightData>>>;
72
73    async fn do_put(
74        &self,
75        request: Request<Streaming<FlightData>>,
76    ) -> TonicResult<Response<TonicStream<PutResult>>> {
77        let _ = request;
78        Err(Status::unimplemented("Not yet implemented"))
79    }
80}
81
82pub type FlightCraftRef = Arc<dyn FlightCraft>;
83
84pub struct FlightCraftWrapper<T: FlightCraft>(pub T);
85
86impl<T: FlightCraft> From<T> for FlightCraftWrapper<T> {
87    fn from(t: T) -> Self {
88        Self(t)
89    }
90}
91
92#[async_trait]
93impl FlightCraft for FlightCraftRef {
94    async fn do_get(
95        &self,
96        request: Request<Ticket>,
97    ) -> TonicResult<Response<TonicStream<FlightData>>> {
98        (**self).do_get(request).await
99    }
100
101    async fn do_put(
102        &self,
103        request: Request<Streaming<FlightData>>,
104    ) -> TonicResult<Response<TonicStream<PutResult>>> {
105        self.as_ref().do_put(request).await
106    }
107}
108
109#[async_trait]
110impl<T: FlightCraft> FlightService for FlightCraftWrapper<T> {
111    type HandshakeStream = TonicStream<HandshakeResponse>;
112
113    async fn handshake(
114        &self,
115        _: Request<Streaming<HandshakeRequest>>,
116    ) -> TonicResult<Response<Self::HandshakeStream>> {
117        Err(Status::unimplemented("Not yet implemented"))
118    }
119
120    type ListFlightsStream = TonicStream<FlightInfo>;
121
122    async fn list_flights(
123        &self,
124        _: Request<Criteria>,
125    ) -> TonicResult<Response<Self::ListFlightsStream>> {
126        Err(Status::unimplemented("Not yet implemented"))
127    }
128
129    async fn get_flight_info(
130        &self,
131        _: Request<FlightDescriptor>,
132    ) -> TonicResult<Response<FlightInfo>> {
133        Err(Status::unimplemented("Not yet implemented"))
134    }
135
136    async fn poll_flight_info(
137        &self,
138        _: Request<FlightDescriptor>,
139    ) -> TonicResult<Response<PollInfo>> {
140        Err(Status::unimplemented("Not yet implemented"))
141    }
142
143    async fn get_schema(
144        &self,
145        _: Request<FlightDescriptor>,
146    ) -> TonicResult<Response<SchemaResult>> {
147        Err(Status::unimplemented("Not yet implemented"))
148    }
149
150    type DoGetStream = TonicStream<FlightData>;
151
152    async fn do_get(&self, request: Request<Ticket>) -> TonicResult<Response<Self::DoGetStream>> {
153        self.0.do_get(request).await
154    }
155
156    type DoPutStream = TonicStream<PutResult>;
157
158    async fn do_put(
159        &self,
160        request: Request<Streaming<FlightData>>,
161    ) -> TonicResult<Response<Self::DoPutStream>> {
162        self.0.do_put(request).await
163    }
164
165    type DoExchangeStream = TonicStream<FlightData>;
166
167    async fn do_exchange(
168        &self,
169        _: Request<Streaming<FlightData>>,
170    ) -> TonicResult<Response<Self::DoExchangeStream>> {
171        Err(Status::unimplemented("Not yet implemented"))
172    }
173
174    type DoActionStream = TonicStream<arrow_flight::Result>;
175
176    async fn do_action(&self, _: Request<Action>) -> TonicResult<Response<Self::DoActionStream>> {
177        Err(Status::unimplemented("Not yet implemented"))
178    }
179
180    type ListActionsStream = TonicStream<ActionType>;
181
182    async fn list_actions(
183        &self,
184        _: Request<Empty>,
185    ) -> TonicResult<Response<Self::ListActionsStream>> {
186        Err(Status::unimplemented("Not yet implemented"))
187    }
188}
189
190#[async_trait]
191impl FlightCraft for GreptimeRequestHandler {
192    async fn do_get(
193        &self,
194        request: Request<Ticket>,
195    ) -> TonicResult<Response<TonicStream<FlightData>>> {
196        let mut hints = hint_headers::extract_hints(request.metadata());
197        hints.extend(extract_flow_extensions(request.metadata())?);
198
199        let ticket = request.into_inner().ticket;
200        let request =
201            GreptimeRequest::decode(ticket.as_ref()).context(error::InvalidFlightTicketSnafu)?;
202        let query_ctx =
203            create_query_context(Channel::Grpc, request.header.as_ref(), hints.clone())?;
204        // Validate flow hint syntax at the transport boundary before dispatching the request.
205        // This does not authorize or execute anything; `handle_request()` below still performs
206        // the normal frontend handling and auth checks before query execution.
207        let flow_extensions = FlowQueryExtensions::parse_flow_extensions(&query_ctx.extensions())
208            .map_err(|e| Status::invalid_argument(e.output_msg()))?;
209        let should_emit_terminal_metrics = flow_extensions
210            .as_ref()
211            .is_some_and(|extensions| extensions.should_collect_region_watermark());
212
213        // The Grpc protocol pass query by Flight. It needs to be wrapped under a span, in order to record stream
214        let span = info_span!(
215            "GreptimeRequestHandler::do_get",
216            protocol = "grpc",
217            request_type = get_request_type(&request)
218        );
219        let flight_compression = self.flight_compression;
220        async {
221            let output = self.handle_request(request, hints).await?;
222            let stream = to_flight_data_stream(
223                output,
224                TracingContext::from_current_span(),
225                flight_compression,
226                query_ctx,
227                should_emit_terminal_metrics,
228            );
229            Ok(Response::new(stream))
230        }
231        .trace(span)
232        .await
233    }
234
235    async fn do_put(
236        &self,
237        request: Request<Streaming<FlightData>>,
238    ) -> TonicResult<Response<TonicStream<PutResult>>> {
239        let (headers, extensions, stream) = request.into_parts();
240
241        let limiter = extensions.get::<ServerMemoryLimiter>().cloned();
242
243        let query_ctx = context_auth::create_query_context_from_grpc_metadata(&headers)?;
244        context_auth::check_auth(self.user_provider.clone(), &headers, query_ctx.clone()).await?;
245
246        const MAX_PENDING_RESPONSES: usize = 32;
247        let (tx, rx) = mpsc::channel::<TonicResult<DoPutResponse>>(MAX_PENDING_RESPONSES);
248
249        let stream = PutRecordBatchRequestStream::new(
250            stream,
251            query_ctx.current_catalog().to_string(),
252            query_ctx.current_schema(),
253            limiter,
254        )
255        .await?;
256        // Ack immediately when stream is created successfully (in Init state)
257        let _ = tx.send(Ok(DoPutResponse::new(0, 0, 0.0))).await;
258        self.put_record_batches(stream, tx, query_ctx).await;
259
260        let response = ReceiverStream::new(rx)
261            .and_then(|response| {
262                future::ready({
263                    serde_json::to_vec(&response)
264                        .context(ToJsonSnafu)
265                        .map(|x| PutResult {
266                            app_metadata: Bytes::from(x),
267                        })
268                        .map_err(Into::into)
269                })
270            })
271            .boxed();
272        Ok(Response::new(response))
273    }
274}
275
276pub struct PutRecordBatchRequest {
277    pub table_name: TableName,
278    pub request_id: i64,
279    pub timestamp_range: Option<(i64, i64)>,
280    pub record_batch: DfRecordBatch,
281    pub schema_bytes: Bytes,
282    pub flight_data: FlightData,
283    pub(crate) _guard: Option<MemoryGuard<RequestMemoryMetrics>>,
284}
285
286impl PutRecordBatchRequest {
287    fn try_new(
288        table_name: TableName,
289        record_batch: DfRecordBatch,
290        request_id: i64,
291        timestamp_range: Option<(i64, i64)>,
292        schema_bytes: Bytes,
293        flight_data: FlightData,
294        limiter: Option<&ServerMemoryLimiter>,
295    ) -> Result<Self> {
296        let memory_usage = flight_data.data_body.len()
297            + flight_data.app_metadata.len()
298            + flight_data.data_header.len();
299
300        let _guard = if let Some(limiter) = limiter {
301            let guard = limiter.try_acquire(memory_usage as u64).ok_or_else(|| {
302                let inner_err = common_memory_manager::Error::MemoryLimitExceeded {
303                    requested_bytes: memory_usage as u64,
304                    limit_bytes: limiter.limit_bytes(),
305                };
306                error::MemoryLimitExceededSnafu.into_error(inner_err)
307            })?;
308            Some(guard)
309        } else {
310            None
311        };
312
313        Ok(Self {
314            table_name,
315            request_id,
316            timestamp_range,
317            record_batch,
318            schema_bytes,
319            flight_data,
320            _guard,
321        })
322    }
323}
324
325pub struct PutRecordBatchRequestStream {
326    flight_data_stream: Streaming<FlightData>,
327    catalog: String,
328    schema_name: String,
329    limiter: Option<ServerMemoryLimiter>,
330    // Client now lazily sends schema data so we cannot eagerly wait for it.
331    // Instead, we need to decode while receiving record batches.
332    state: StreamState,
333}
334
335enum StreamState {
336    Init,
337    Ready {
338        table_name: TableName,
339        schema: SchemaRef,
340        schema_bytes: Bytes,
341        decoder: FlightDecoder,
342    },
343}
344
345impl PutRecordBatchRequestStream {
346    /// Creates a new `PutRecordBatchRequestStream` in Init state.
347    /// The stream will transition to Ready state when it receives the schema message.
348    pub async fn new(
349        flight_data_stream: Streaming<FlightData>,
350        catalog: String,
351        schema: String,
352        limiter: Option<ServerMemoryLimiter>,
353    ) -> TonicResult<Self> {
354        Ok(Self {
355            flight_data_stream,
356            catalog,
357            schema_name: schema,
358            limiter,
359            state: StreamState::Init,
360        })
361    }
362
363    /// Returns the table name extracted from the flight descriptor.
364    /// Returns None if the stream is still in Init state.
365    pub fn table_name(&self) -> Option<&TableName> {
366        match &self.state {
367            StreamState::Init => None,
368            StreamState::Ready { table_name, .. } => Some(table_name),
369        }
370    }
371
372    /// Returns the Arrow schema decoded from the first flight message.
373    /// Returns None if the stream is still in Init state.
374    pub fn schema(&self) -> Option<&SchemaRef> {
375        match &self.state {
376            StreamState::Init => None,
377            StreamState::Ready { schema, .. } => Some(schema),
378        }
379    }
380
381    /// Returns the raw schema bytes in IPC format.
382    /// Returns None if the stream is still in Init state.
383    pub fn schema_bytes(&self) -> Option<&Bytes> {
384        match &self.state {
385            StreamState::Init => None,
386            StreamState::Ready { schema_bytes, .. } => Some(schema_bytes),
387        }
388    }
389
390    fn extract_table_name(mut descriptor: FlightDescriptor) -> Result<String> {
391        ensure!(
392            descriptor.r#type == arrow_flight::flight_descriptor::DescriptorType::Path as i32,
393            InvalidParameterSnafu {
394                reason: "expect FlightDescriptor::type == 'Path' only",
395            }
396        );
397        ensure!(
398            descriptor.path.len() == 1,
399            InvalidParameterSnafu {
400                reason: "expect FlightDescriptor::path has only one table name",
401            }
402        );
403        Ok(descriptor.path.remove(0))
404    }
405}
406
407impl Stream for PutRecordBatchRequestStream {
408    type Item = TonicResult<PutRecordBatchRequest>;
409
410    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
411        loop {
412            let poll = ready!(self.flight_data_stream.poll_next_unpin(cx));
413
414            match poll {
415                Some(Ok(flight_data)) => {
416                    let limiter = self.limiter.clone();
417
418                    match &mut self.state {
419                        StreamState::Init => {
420                            // First message - expecting schema
421                            let flight_descriptor = match flight_data.flight_descriptor.as_ref() {
422                                Some(descriptor) => descriptor.clone(),
423                                None => {
424                                    return Poll::Ready(Some(Err(Status::failed_precondition(
425                                        "table to put is not found in flight descriptor",
426                                    ))));
427                                }
428                            };
429
430                            let table_name_str = match Self::extract_table_name(flight_descriptor) {
431                                Ok(name) => name,
432                                Err(e) => {
433                                    return Poll::Ready(Some(Err(Status::invalid_argument(
434                                        e.to_string(),
435                                    ))));
436                                }
437                            };
438                            let table_name = TableName::new(
439                                self.catalog.clone(),
440                                self.schema_name.clone(),
441                                table_name_str,
442                            );
443
444                            // Decode the schema
445                            let mut decoder = FlightDecoder::default();
446                            let schema_message = decoder.try_decode(&flight_data).map_err(|e| {
447                                Status::invalid_argument(format!("Failed to decode schema: {}", e))
448                            })?;
449
450                            match schema_message {
451                                Some(FlightMessage::Schema(schema)) => {
452                                    let schema_bytes = decoder.schema_bytes().ok_or_else(|| {
453                                        Status::internal(
454                                            "decoder should have schema bytes after decoding schema",
455                                        )
456                                    })?;
457
458                                    // Transition to Ready state with all necessary data
459                                    self.state = StreamState::Ready {
460                                        table_name,
461                                        schema,
462                                        schema_bytes,
463                                        decoder,
464                                    };
465                                    // Continue to next iteration to process RecordBatch messages
466                                    continue;
467                                }
468                                _ => {
469                                    return Poll::Ready(Some(Err(Status::failed_precondition(
470                                        "first message must be a Schema message",
471                                    ))));
472                                }
473                            }
474                        }
475                        StreamState::Ready {
476                            table_name,
477                            schema: _,
478                            schema_bytes,
479                            decoder,
480                        } => {
481                            // Extract request_id and time range from FlightData before decoding
482                            let metadata = if !flight_data.app_metadata.is_empty() {
483                                serde_json::from_slice::<DoPutMetadata>(&flight_data.app_metadata)
484                                    .ok()
485                            } else {
486                                None
487                            };
488                            let request_id = metadata
489                                .as_ref()
490                                .map(|meta| meta.request_id())
491                                .unwrap_or_default();
492                            let timestamp_range = metadata.and_then(|meta| meta.timestamp_range());
493
494                            // Decode FlightData to RecordBatch
495                            match decoder.try_decode(&flight_data) {
496                                Ok(Some(FlightMessage::RecordBatch(record_batch))) => {
497                                    let table_name = table_name.clone();
498                                    let schema_bytes = schema_bytes.clone();
499                                    return Poll::Ready(Some(
500                                        PutRecordBatchRequest::try_new(
501                                            table_name,
502                                            record_batch,
503                                            request_id,
504                                            timestamp_range,
505                                            schema_bytes,
506                                            flight_data,
507                                            limiter.as_ref(),
508                                        )
509                                        .map_err(|e| Status::invalid_argument(e.to_string())),
510                                    ));
511                                }
512                                Ok(Some(other)) => {
513                                    debug!("Unexpected flight message: {:?}", other);
514                                    return Poll::Ready(Some(Err(Status::invalid_argument(
515                                        "Expected RecordBatch message, got other message type",
516                                    ))));
517                                }
518                                Ok(None) => {
519                                    // Dictionary batch - processed internally by decoder, continue polling
520                                    continue;
521                                }
522                                Err(e) => {
523                                    return Poll::Ready(Some(Err(Status::invalid_argument(
524                                        format!("Failed to decode RecordBatch: {}", e),
525                                    ))));
526                                }
527                            }
528                        }
529                    }
530                }
531                Some(Err(e)) => {
532                    return Poll::Ready(Some(Err(e)));
533                }
534                None => {
535                    return Poll::Ready(None);
536                }
537            }
538        }
539    }
540}
541
542fn extract_flow_extensions(
543    metadata: &tonic::metadata::MetadataMap,
544) -> TonicResult<Vec<(String, String)>> {
545    let Some(value) = metadata.get(FLOW_EXTENSIONS_METADATA_KEY) else {
546        return Ok(vec![]);
547    };
548
549    let value = value.to_str().map_err(|e| {
550        Status::invalid_argument(format!(
551            "Invalid {FLOW_EXTENSIONS_METADATA_KEY} metadata value: {e}"
552        ))
553    })?;
554
555    serde_json::from_str::<Vec<(String, String)>>(value).map_err(|e| {
556        Status::invalid_argument(format!(
557            "Invalid {FLOW_EXTENSIONS_METADATA_KEY} metadata JSON: {e}"
558        ))
559    })
560}
561
562fn to_flight_data_stream(
563    output: Output,
564    tracing_context: TracingContext,
565    flight_compression: FlightCompression,
566    query_ctx: QueryContextRef,
567    should_emit_terminal_metrics: bool,
568) -> TonicStream<FlightData> {
569    match output.data {
570        OutputData::Stream(stream) => {
571            let stream = FlightRecordBatchStream::new(
572                stream,
573                tracing_context,
574                flight_compression,
575                query_ctx,
576            );
577            Box::pin(stream) as _
578        }
579        OutputData::RecordBatches(x) => {
580            let stream = FlightRecordBatchStream::new(
581                x.as_stream(),
582                tracing_context,
583                flight_compression,
584                query_ctx,
585            );
586            Box::pin(stream) as _
587        }
588        OutputData::AffectedRows(rows) => {
589            let terminal_metrics = match terminal_recordbatch_metrics_from_plan_if_requested(
590                output.meta.plan,
591                should_emit_terminal_metrics,
592            ) {
593                Some(metrics) => match serde_json::to_string(&metrics) {
594                    Ok(metrics) => Some(metrics),
595                    Err(e) => {
596                        let stream = tokio_stream::once(Err(Status::internal(format!(
597                            "Failed to serialize terminal metrics: {e}"
598                        ))));
599                        return Box::pin(stream) as _;
600                    }
601                },
602                None => None,
603            };
604            let affected_rows = FlightEncoder::default().encode(FlightMessage::AffectedRows {
605                rows,
606                metrics: terminal_metrics,
607            });
608            let stream = tokio_stream::iter(affected_rows.into_iter().map(Ok));
609            Box::pin(stream) as _
610        }
611    }
612}
613
#[cfg(test)]
mod tests {
    use tonic::metadata::{AsciiMetadataValue, MetadataMap};

    use super::*;

    /// Builds request metadata whose flow-extensions entry is `raw`.
    fn metadata_with_flow_extensions(raw: &str) -> MetadataMap {
        let value = AsciiMetadataValue::try_from(raw).unwrap();
        let mut metadata = MetadataMap::new();
        metadata.insert(FLOW_EXTENSIONS_METADATA_KEY, value);
        metadata
    }

    #[test]
    fn test_extract_flow_extensions_preserves_comma_bearing_values() {
        // The second value is itself JSON containing commas; it must come
        // back as one intact string.
        let metadata = metadata_with_flow_extensions(
            r#"[["flow.return_region_seq","true"],["flow.incremental_after_seqs","{\"1\":10,\"2\":20}"]]"#,
        );

        let expected = vec![
            ("flow.return_region_seq".to_string(), "true".to_string()),
            (
                "flow.incremental_after_seqs".to_string(),
                r#"{"1":10,"2":20}"#.to_string(),
            ),
        ];
        let extensions = extract_flow_extensions(&metadata).unwrap();
        assert_eq!(extensions, expected);
    }

    #[test]
    fn test_extract_flow_extensions_rejects_invalid_json() {
        let metadata = metadata_with_flow_extensions("not-json");

        let err = extract_flow_extensions(&metadata).unwrap_err();
        assert_eq!(err.code(), tonic::Code::InvalidArgument);
    }
}