Skip to main content

servers/grpc/
flight.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15mod stream;
16
17use std::pin::Pin;
18use std::sync::Arc;
19use std::task::{Context, Poll};
20
21use api::v1::GreptimeRequest;
22use arrow_flight::flight_service_server::FlightService;
23use arrow_flight::{
24    Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
25    HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket,
26};
27use async_trait::async_trait;
28use bytes::{self, Bytes};
29use common_grpc::flight::do_put::{DoPutMetadata, DoPutResponse};
30use common_grpc::flight::{FlightDecoder, FlightEncoder, FlightMessage};
31use common_memory_manager::MemoryGuard;
32use common_query::{Output, OutputData};
33use common_recordbatch::DfRecordBatch;
34use common_telemetry::debug;
35use common_telemetry::tracing::info_span;
36use common_telemetry::tracing_context::{FutureExt, TracingContext};
37use datatypes::arrow::datatypes::SchemaRef;
38use futures::{Stream, future, ready};
39use futures_util::{StreamExt, TryStreamExt};
40use prost::Message;
41use session::context::{QueryContext, QueryContextRef};
42use snafu::{IntoError, ResultExt, ensure};
43use table::table_name::TableName;
44use tokio::sync::mpsc;
45use tokio_stream::wrappers::ReceiverStream;
46use tonic::{Request, Response, Status, Streaming};
47
48use crate::error::{InvalidParameterSnafu, Result, ToJsonSnafu};
49pub use crate::grpc::flight::stream::FlightRecordBatchStream;
50use crate::grpc::greptime_handler::{GreptimeRequestHandler, get_request_type};
51use crate::grpc::{FlightCompression, TonicResult, context_auth};
52use crate::request_memory_limiter::ServerMemoryLimiter;
53use crate::request_memory_metrics::RequestMemoryMetrics;
54use crate::{error, hint_headers};
55
/// A pinned, boxed, `Send + 'static` stream whose items are [`TonicResult<T>`].
///
/// This is the concrete stream type used for every streaming Flight response in this module.
pub type TonicStream<T> = Pin<Box<dyn Stream<Item = TonicResult<T>> + Send + 'static>>;
57
/// A subset of [FlightService]
///
/// Implementors must provide `do_get`. `do_put` is optional: its default body rejects the
/// call with an `Unimplemented` status.
#[async_trait]
pub trait FlightCraft: Send + Sync + 'static {
    /// Handles a Flight `DoGet` call for the given [`Ticket`] and returns the response
    /// stream of [`FlightData`] messages.
    async fn do_get(
        &self,
        request: Request<Ticket>,
    ) -> TonicResult<Response<TonicStream<FlightData>>>;

    /// Handles a Flight `DoPut` call, consuming the client's [`FlightData`] stream and
    /// returning a stream of [`PutResult`] acknowledgements.
    ///
    /// The default implementation ignores the request and returns
    /// `Status::unimplemented("Not yet implemented")`.
    async fn do_put(
        &self,
        request: Request<Streaming<FlightData>>,
    ) -> TonicResult<Response<TonicStream<PutResult>>> {
        // The request is intentionally unused by the default implementation.
        let _ = request;
        Err(Status::unimplemented("Not yet implemented"))
    }
}
74
/// A shared, reference-counted handle to a dynamically dispatched [`FlightCraft`].
pub type FlightCraftRef = Arc<dyn FlightCraft>;
76
77pub struct FlightCraftWrapper<T: FlightCraft>(pub T);
78
79impl<T: FlightCraft> From<T> for FlightCraftWrapper<T> {
80    fn from(t: T) -> Self {
81        Self(t)
82    }
83}
84
85#[async_trait]
86impl FlightCraft for FlightCraftRef {
87    async fn do_get(
88        &self,
89        request: Request<Ticket>,
90    ) -> TonicResult<Response<TonicStream<FlightData>>> {
91        (**self).do_get(request).await
92    }
93
94    async fn do_put(
95        &self,
96        request: Request<Streaming<FlightData>>,
97    ) -> TonicResult<Response<TonicStream<PutResult>>> {
98        self.as_ref().do_put(request).await
99    }
100}
101
102#[async_trait]
103impl<T: FlightCraft> FlightService for FlightCraftWrapper<T> {
104    type HandshakeStream = TonicStream<HandshakeResponse>;
105
106    async fn handshake(
107        &self,
108        _: Request<Streaming<HandshakeRequest>>,
109    ) -> TonicResult<Response<Self::HandshakeStream>> {
110        Err(Status::unimplemented("Not yet implemented"))
111    }
112
113    type ListFlightsStream = TonicStream<FlightInfo>;
114
115    async fn list_flights(
116        &self,
117        _: Request<Criteria>,
118    ) -> TonicResult<Response<Self::ListFlightsStream>> {
119        Err(Status::unimplemented("Not yet implemented"))
120    }
121
122    async fn get_flight_info(
123        &self,
124        _: Request<FlightDescriptor>,
125    ) -> TonicResult<Response<FlightInfo>> {
126        Err(Status::unimplemented("Not yet implemented"))
127    }
128
129    async fn poll_flight_info(
130        &self,
131        _: Request<FlightDescriptor>,
132    ) -> TonicResult<Response<PollInfo>> {
133        Err(Status::unimplemented("Not yet implemented"))
134    }
135
136    async fn get_schema(
137        &self,
138        _: Request<FlightDescriptor>,
139    ) -> TonicResult<Response<SchemaResult>> {
140        Err(Status::unimplemented("Not yet implemented"))
141    }
142
143    type DoGetStream = TonicStream<FlightData>;
144
145    async fn do_get(&self, request: Request<Ticket>) -> TonicResult<Response<Self::DoGetStream>> {
146        self.0.do_get(request).await
147    }
148
149    type DoPutStream = TonicStream<PutResult>;
150
151    async fn do_put(
152        &self,
153        request: Request<Streaming<FlightData>>,
154    ) -> TonicResult<Response<Self::DoPutStream>> {
155        self.0.do_put(request).await
156    }
157
158    type DoExchangeStream = TonicStream<FlightData>;
159
160    async fn do_exchange(
161        &self,
162        _: Request<Streaming<FlightData>>,
163    ) -> TonicResult<Response<Self::DoExchangeStream>> {
164        Err(Status::unimplemented("Not yet implemented"))
165    }
166
167    type DoActionStream = TonicStream<arrow_flight::Result>;
168
169    async fn do_action(&self, _: Request<Action>) -> TonicResult<Response<Self::DoActionStream>> {
170        Err(Status::unimplemented("Not yet implemented"))
171    }
172
173    type ListActionsStream = TonicStream<ActionType>;
174
175    async fn list_actions(
176        &self,
177        _: Request<Empty>,
178    ) -> TonicResult<Response<Self::ListActionsStream>> {
179        Err(Status::unimplemented("Not yet implemented"))
180    }
181}
182
183#[async_trait]
184impl FlightCraft for GreptimeRequestHandler {
185    async fn do_get(
186        &self,
187        request: Request<Ticket>,
188    ) -> TonicResult<Response<TonicStream<FlightData>>> {
189        let hints = hint_headers::extract_hints(request.metadata());
190
191        let ticket = request.into_inner().ticket;
192        let request =
193            GreptimeRequest::decode(ticket.as_ref()).context(error::InvalidFlightTicketSnafu)?;
194
195        // The Grpc protocol pass query by Flight. It needs to be wrapped under a span, in order to record stream
196        let span = info_span!(
197            "GreptimeRequestHandler::do_get",
198            protocol = "grpc",
199            request_type = get_request_type(&request)
200        );
201        let flight_compression = self.flight_compression;
202        async {
203            let output = self.handle_request(request, hints).await?;
204            let stream = to_flight_data_stream(
205                output,
206                TracingContext::from_current_span(),
207                flight_compression,
208                QueryContext::arc(),
209            );
210            Ok(Response::new(stream))
211        }
212        .trace(span)
213        .await
214    }
215
216    async fn do_put(
217        &self,
218        request: Request<Streaming<FlightData>>,
219    ) -> TonicResult<Response<TonicStream<PutResult>>> {
220        let (headers, extensions, stream) = request.into_parts();
221
222        let limiter = extensions.get::<ServerMemoryLimiter>().cloned();
223
224        let query_ctx = context_auth::create_query_context_from_grpc_metadata(&headers)?;
225        context_auth::check_auth(self.user_provider.clone(), &headers, query_ctx.clone()).await?;
226
227        const MAX_PENDING_RESPONSES: usize = 32;
228        let (tx, rx) = mpsc::channel::<TonicResult<DoPutResponse>>(MAX_PENDING_RESPONSES);
229
230        let stream = PutRecordBatchRequestStream::new(
231            stream,
232            query_ctx.current_catalog().to_string(),
233            query_ctx.current_schema(),
234            limiter,
235        )
236        .await?;
237        // Ack immediately when stream is created successfully (in Init state)
238        let _ = tx.send(Ok(DoPutResponse::new(0, 0, 0.0))).await;
239        self.put_record_batches(stream, tx, query_ctx).await;
240
241        let response = ReceiverStream::new(rx)
242            .and_then(|response| {
243                future::ready({
244                    serde_json::to_vec(&response)
245                        .context(ToJsonSnafu)
246                        .map(|x| PutResult {
247                            app_metadata: Bytes::from(x),
248                        })
249                        .map_err(Into::into)
250                })
251            })
252            .boxed();
253        Ok(Response::new(response))
254    }
255}
256
/// One decoded record batch from a Flight `DoPut` stream, together with the raw message it
/// was decoded from.
pub struct PutRecordBatchRequest {
    /// Fully qualified name of the target table.
    pub table_name: TableName,
    /// Request id read from the message's `app_metadata`; `0` when the metadata is empty or
    /// cannot be parsed as `DoPutMetadata`.
    pub request_id: i64,
    /// The record batch decoded from `flight_data`.
    pub record_batch: DfRecordBatch,
    /// Schema bytes captured by the decoder when the schema message was received.
    pub schema_bytes: Bytes,
    /// The original Flight message the record batch was decoded from.
    pub flight_data: FlightData,
    /// Memory reservation obtained from the limiter, if one was supplied. It is kept with the
    /// request; presumably the reservation is released when the guard is dropped (confirm
    /// against `MemoryGuard`).
    pub(crate) _guard: Option<MemoryGuard<RequestMemoryMetrics>>,
}
265
266impl PutRecordBatchRequest {
267    fn try_new(
268        table_name: TableName,
269        record_batch: DfRecordBatch,
270        request_id: i64,
271        schema_bytes: Bytes,
272        flight_data: FlightData,
273        limiter: Option<&ServerMemoryLimiter>,
274    ) -> Result<Self> {
275        let memory_usage = flight_data.data_body.len()
276            + flight_data.app_metadata.len()
277            + flight_data.data_header.len();
278
279        let _guard = if let Some(limiter) = limiter {
280            let guard = limiter.try_acquire(memory_usage as u64).ok_or_else(|| {
281                let inner_err = common_memory_manager::Error::MemoryLimitExceeded {
282                    requested_bytes: memory_usage as u64,
283                    limit_bytes: limiter.limit_bytes(),
284                };
285                error::MemoryLimitExceededSnafu.into_error(inner_err)
286            })?;
287            Some(guard)
288        } else {
289            None
290        };
291
292        Ok(Self {
293            table_name,
294            request_id,
295            record_batch,
296            schema_bytes,
297            flight_data,
298            _guard,
299        })
300    }
301}
302
/// Adapts the inbound stream of [`FlightData`] messages of a `DoPut` call into a stream of
/// [`PutRecordBatchRequest`]s.
pub struct PutRecordBatchRequestStream {
    /// Raw Flight messages received from the client.
    flight_data_stream: Streaming<FlightData>,
    /// Catalog used to qualify the table name taken from the flight descriptor.
    catalog: String,
    /// Schema name used to qualify the table name taken from the flight descriptor.
    schema_name: String,
    /// Optional memory limiter passed along when each request is constructed.
    limiter: Option<ServerMemoryLimiter>,
    // Client now lazily sends schema data so we cannot eagerly wait for it.
    // Instead, we need to decode while receiving record batches.
    state: StreamState,
}
312
/// Decoding state of a [`PutRecordBatchRequestStream`].
enum StreamState {
    /// No schema message has been decoded yet.
    Init,
    /// The schema message has been decoded; subsequent messages are decoded with `decoder`.
    Ready {
        /// Target table, qualified with the stream's catalog and schema name.
        table_name: TableName,
        /// Arrow schema decoded from the schema message.
        schema: SchemaRef,
        /// Schema bytes reported by the decoder after decoding the schema message.
        schema_bytes: Bytes,
        /// Decoder that processed the schema message and is reused for later messages.
        decoder: FlightDecoder,
    },
}
322
323impl PutRecordBatchRequestStream {
324    /// Creates a new `PutRecordBatchRequestStream` in Init state.
325    /// The stream will transition to Ready state when it receives the schema message.
326    pub async fn new(
327        flight_data_stream: Streaming<FlightData>,
328        catalog: String,
329        schema: String,
330        limiter: Option<ServerMemoryLimiter>,
331    ) -> TonicResult<Self> {
332        Ok(Self {
333            flight_data_stream,
334            catalog,
335            schema_name: schema,
336            limiter,
337            state: StreamState::Init,
338        })
339    }
340
341    /// Returns the table name extracted from the flight descriptor.
342    /// Returns None if the stream is still in Init state.
343    pub fn table_name(&self) -> Option<&TableName> {
344        match &self.state {
345            StreamState::Init => None,
346            StreamState::Ready { table_name, .. } => Some(table_name),
347        }
348    }
349
350    /// Returns the Arrow schema decoded from the first flight message.
351    /// Returns None if the stream is still in Init state.
352    pub fn schema(&self) -> Option<&SchemaRef> {
353        match &self.state {
354            StreamState::Init => None,
355            StreamState::Ready { schema, .. } => Some(schema),
356        }
357    }
358
359    /// Returns the raw schema bytes in IPC format.
360    /// Returns None if the stream is still in Init state.
361    pub fn schema_bytes(&self) -> Option<&Bytes> {
362        match &self.state {
363            StreamState::Init => None,
364            StreamState::Ready { schema_bytes, .. } => Some(schema_bytes),
365        }
366    }
367
368    fn extract_table_name(mut descriptor: FlightDescriptor) -> Result<String> {
369        ensure!(
370            descriptor.r#type == arrow_flight::flight_descriptor::DescriptorType::Path as i32,
371            InvalidParameterSnafu {
372                reason: "expect FlightDescriptor::type == 'Path' only",
373            }
374        );
375        ensure!(
376            descriptor.path.len() == 1,
377            InvalidParameterSnafu {
378                reason: "expect FlightDescriptor::path has only one table name",
379            }
380        );
381        Ok(descriptor.path.remove(0))
382    }
383}
384
385impl Stream for PutRecordBatchRequestStream {
386    type Item = TonicResult<PutRecordBatchRequest>;
387
388    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
389        loop {
390            let poll = ready!(self.flight_data_stream.poll_next_unpin(cx));
391
392            match poll {
393                Some(Ok(flight_data)) => {
394                    let limiter = self.limiter.clone();
395
396                    match &mut self.state {
397                        StreamState::Init => {
398                            // First message - expecting schema
399                            let flight_descriptor = match flight_data.flight_descriptor.as_ref() {
400                                Some(descriptor) => descriptor.clone(),
401                                None => {
402                                    return Poll::Ready(Some(Err(Status::failed_precondition(
403                                        "table to put is not found in flight descriptor",
404                                    ))));
405                                }
406                            };
407
408                            let table_name_str = match Self::extract_table_name(flight_descriptor) {
409                                Ok(name) => name,
410                                Err(e) => {
411                                    return Poll::Ready(Some(Err(Status::invalid_argument(
412                                        e.to_string(),
413                                    ))));
414                                }
415                            };
416                            let table_name = TableName::new(
417                                self.catalog.clone(),
418                                self.schema_name.clone(),
419                                table_name_str,
420                            );
421
422                            // Decode the schema
423                            let mut decoder = FlightDecoder::default();
424                            let schema_message = decoder.try_decode(&flight_data).map_err(|e| {
425                                Status::invalid_argument(format!("Failed to decode schema: {}", e))
426                            })?;
427
428                            match schema_message {
429                                Some(FlightMessage::Schema(schema)) => {
430                                    let schema_bytes = decoder.schema_bytes().ok_or_else(|| {
431                                        Status::internal(
432                                            "decoder should have schema bytes after decoding schema",
433                                        )
434                                    })?;
435
436                                    // Transition to Ready state with all necessary data
437                                    self.state = StreamState::Ready {
438                                        table_name,
439                                        schema,
440                                        schema_bytes,
441                                        decoder,
442                                    };
443                                    // Continue to next iteration to process RecordBatch messages
444                                    continue;
445                                }
446                                _ => {
447                                    return Poll::Ready(Some(Err(Status::failed_precondition(
448                                        "first message must be a Schema message",
449                                    ))));
450                                }
451                            }
452                        }
453                        StreamState::Ready {
454                            table_name,
455                            schema: _,
456                            schema_bytes,
457                            decoder,
458                        } => {
459                            // Extract request_id and body_size from FlightData before decoding
460                            let request_id = if !flight_data.app_metadata.is_empty() {
461                                serde_json::from_slice::<DoPutMetadata>(&flight_data.app_metadata)
462                                    .map(|meta| meta.request_id())
463                                    .unwrap_or_default()
464                            } else {
465                                0
466                            };
467
468                            // Decode FlightData to RecordBatch
469                            match decoder.try_decode(&flight_data) {
470                                Ok(Some(FlightMessage::RecordBatch(record_batch))) => {
471                                    let table_name = table_name.clone();
472                                    let schema_bytes = schema_bytes.clone();
473                                    return Poll::Ready(Some(
474                                        PutRecordBatchRequest::try_new(
475                                            table_name,
476                                            record_batch,
477                                            request_id,
478                                            schema_bytes,
479                                            flight_data,
480                                            limiter.as_ref(),
481                                        )
482                                        .map_err(|e| Status::invalid_argument(e.to_string())),
483                                    ));
484                                }
485                                Ok(Some(other)) => {
486                                    debug!("Unexpected flight message: {:?}", other);
487                                    return Poll::Ready(Some(Err(Status::invalid_argument(
488                                        "Expected RecordBatch message, got other message type",
489                                    ))));
490                                }
491                                Ok(None) => {
492                                    // Dictionary batch - processed internally by decoder, continue polling
493                                    continue;
494                                }
495                                Err(e) => {
496                                    return Poll::Ready(Some(Err(Status::invalid_argument(
497                                        format!("Failed to decode RecordBatch: {}", e),
498                                    ))));
499                                }
500                            }
501                        }
502                    }
503                }
504                Some(Err(e)) => {
505                    return Poll::Ready(Some(Err(e)));
506                }
507                None => {
508                    return Poll::Ready(None);
509                }
510            }
511        }
512    }
513}
514
515fn to_flight_data_stream(
516    output: Output,
517    tracing_context: TracingContext,
518    flight_compression: FlightCompression,
519    query_ctx: QueryContextRef,
520) -> TonicStream<FlightData> {
521    match output.data {
522        OutputData::Stream(stream) => {
523            let stream = FlightRecordBatchStream::new(
524                stream,
525                tracing_context,
526                flight_compression,
527                query_ctx,
528            );
529            Box::pin(stream) as _
530        }
531        OutputData::RecordBatches(x) => {
532            let stream = FlightRecordBatchStream::new(
533                x.as_stream(),
534                tracing_context,
535                flight_compression,
536                query_ctx,
537            );
538            Box::pin(stream) as _
539        }
540        OutputData::AffectedRows(rows) => {
541            let stream = tokio_stream::iter(
542                FlightEncoder::default()
543                    .encode(FlightMessage::AffectedRows(rows))
544                    .into_iter()
545                    .map(Ok),
546            );
547            Box::pin(stream) as _
548        }
549    }
550}