servers/grpc/
flight.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15mod stream;
16
17use std::pin::Pin;
18use std::sync::Arc;
19use std::task::{Context, Poll};
20
21use api::v1::GreptimeRequest;
22use arrow_flight::flight_service_server::FlightService;
23use arrow_flight::{
24    Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
25    HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket,
26};
27use async_trait::async_trait;
28use bytes;
29use bytes::Bytes;
30use common_grpc::flight::do_put::{DoPutMetadata, DoPutResponse};
31use common_grpc::flight::{FlightDecoder, FlightEncoder, FlightMessage};
32use common_memory_manager::MemoryGuard;
33use common_query::{Output, OutputData};
34use common_recordbatch::DfRecordBatch;
35use common_telemetry::debug;
36use common_telemetry::tracing::info_span;
37use common_telemetry::tracing_context::{FutureExt, TracingContext};
38use datatypes::arrow::datatypes::SchemaRef;
39use futures::{Stream, future, ready};
40use futures_util::{StreamExt, TryStreamExt};
41use prost::Message;
42use session::context::{QueryContext, QueryContextRef};
43use snafu::{IntoError, ResultExt, ensure};
44use table::table_name::TableName;
45use tokio::sync::mpsc;
46use tokio_stream::wrappers::ReceiverStream;
47use tonic::{Request, Response, Status, Streaming};
48
49use crate::error::{InvalidParameterSnafu, Result, ToJsonSnafu};
50pub use crate::grpc::flight::stream::FlightRecordBatchStream;
51use crate::grpc::greptime_handler::{GreptimeRequestHandler, get_request_type};
52use crate::grpc::{FlightCompression, TonicResult, context_auth};
53use crate::request_memory_limiter::ServerMemoryLimiter;
54use crate::request_memory_metrics::RequestMemoryMetrics;
55use crate::{error, hint_headers};
56
57pub type TonicStream<T> = Pin<Box<dyn Stream<Item = TonicResult<T>> + Send + 'static>>;
58
59/// A subset of [FlightService]
60#[async_trait]
61pub trait FlightCraft: Send + Sync + 'static {
62    async fn do_get(
63        &self,
64        request: Request<Ticket>,
65    ) -> TonicResult<Response<TonicStream<FlightData>>>;
66
67    async fn do_put(
68        &self,
69        request: Request<Streaming<FlightData>>,
70    ) -> TonicResult<Response<TonicStream<PutResult>>> {
71        let _ = request;
72        Err(Status::unimplemented("Not yet implemented"))
73    }
74}
75
76pub type FlightCraftRef = Arc<dyn FlightCraft>;
77
78pub struct FlightCraftWrapper<T: FlightCraft>(pub T);
79
80impl<T: FlightCraft> From<T> for FlightCraftWrapper<T> {
81    fn from(t: T) -> Self {
82        Self(t)
83    }
84}
85
86#[async_trait]
87impl FlightCraft for FlightCraftRef {
88    async fn do_get(
89        &self,
90        request: Request<Ticket>,
91    ) -> TonicResult<Response<TonicStream<FlightData>>> {
92        (**self).do_get(request).await
93    }
94
95    async fn do_put(
96        &self,
97        request: Request<Streaming<FlightData>>,
98    ) -> TonicResult<Response<TonicStream<PutResult>>> {
99        self.as_ref().do_put(request).await
100    }
101}
102
103#[async_trait]
104impl<T: FlightCraft> FlightService for FlightCraftWrapper<T> {
105    type HandshakeStream = TonicStream<HandshakeResponse>;
106
107    async fn handshake(
108        &self,
109        _: Request<Streaming<HandshakeRequest>>,
110    ) -> TonicResult<Response<Self::HandshakeStream>> {
111        Err(Status::unimplemented("Not yet implemented"))
112    }
113
114    type ListFlightsStream = TonicStream<FlightInfo>;
115
116    async fn list_flights(
117        &self,
118        _: Request<Criteria>,
119    ) -> TonicResult<Response<Self::ListFlightsStream>> {
120        Err(Status::unimplemented("Not yet implemented"))
121    }
122
123    async fn get_flight_info(
124        &self,
125        _: Request<FlightDescriptor>,
126    ) -> TonicResult<Response<FlightInfo>> {
127        Err(Status::unimplemented("Not yet implemented"))
128    }
129
130    async fn poll_flight_info(
131        &self,
132        _: Request<FlightDescriptor>,
133    ) -> TonicResult<Response<PollInfo>> {
134        Err(Status::unimplemented("Not yet implemented"))
135    }
136
137    async fn get_schema(
138        &self,
139        _: Request<FlightDescriptor>,
140    ) -> TonicResult<Response<SchemaResult>> {
141        Err(Status::unimplemented("Not yet implemented"))
142    }
143
144    type DoGetStream = TonicStream<FlightData>;
145
146    async fn do_get(&self, request: Request<Ticket>) -> TonicResult<Response<Self::DoGetStream>> {
147        self.0.do_get(request).await
148    }
149
150    type DoPutStream = TonicStream<PutResult>;
151
152    async fn do_put(
153        &self,
154        request: Request<Streaming<FlightData>>,
155    ) -> TonicResult<Response<Self::DoPutStream>> {
156        self.0.do_put(request).await
157    }
158
159    type DoExchangeStream = TonicStream<FlightData>;
160
161    async fn do_exchange(
162        &self,
163        _: Request<Streaming<FlightData>>,
164    ) -> TonicResult<Response<Self::DoExchangeStream>> {
165        Err(Status::unimplemented("Not yet implemented"))
166    }
167
168    type DoActionStream = TonicStream<arrow_flight::Result>;
169
170    async fn do_action(&self, _: Request<Action>) -> TonicResult<Response<Self::DoActionStream>> {
171        Err(Status::unimplemented("Not yet implemented"))
172    }
173
174    type ListActionsStream = TonicStream<ActionType>;
175
176    async fn list_actions(
177        &self,
178        _: Request<Empty>,
179    ) -> TonicResult<Response<Self::ListActionsStream>> {
180        Err(Status::unimplemented("Not yet implemented"))
181    }
182}
183
184#[async_trait]
185impl FlightCraft for GreptimeRequestHandler {
186    async fn do_get(
187        &self,
188        request: Request<Ticket>,
189    ) -> TonicResult<Response<TonicStream<FlightData>>> {
190        let hints = hint_headers::extract_hints(request.metadata());
191
192        let ticket = request.into_inner().ticket;
193        let request =
194            GreptimeRequest::decode(ticket.as_ref()).context(error::InvalidFlightTicketSnafu)?;
195
196        // The Grpc protocol pass query by Flight. It needs to be wrapped under a span, in order to record stream
197        let span = info_span!(
198            "GreptimeRequestHandler::do_get",
199            protocol = "grpc",
200            request_type = get_request_type(&request)
201        );
202        let flight_compression = self.flight_compression;
203        async {
204            let output = self.handle_request(request, hints).await?;
205            let stream = to_flight_data_stream(
206                output,
207                TracingContext::from_current_span(),
208                flight_compression,
209                QueryContext::arc(),
210            );
211            Ok(Response::new(stream))
212        }
213        .trace(span)
214        .await
215    }
216
217    async fn do_put(
218        &self,
219        request: Request<Streaming<FlightData>>,
220    ) -> TonicResult<Response<TonicStream<PutResult>>> {
221        let (headers, extensions, stream) = request.into_parts();
222
223        let limiter = extensions.get::<ServerMemoryLimiter>().cloned();
224
225        let query_ctx = context_auth::create_query_context_from_grpc_metadata(&headers)?;
226        context_auth::check_auth(self.user_provider.clone(), &headers, query_ctx.clone()).await?;
227
228        const MAX_PENDING_RESPONSES: usize = 32;
229        let (tx, rx) = mpsc::channel::<TonicResult<DoPutResponse>>(MAX_PENDING_RESPONSES);
230
231        let stream = PutRecordBatchRequestStream::new(
232            stream,
233            query_ctx.current_catalog().to_string(),
234            query_ctx.current_schema(),
235            limiter,
236        )
237        .await?;
238        // Ack immediately when stream is created successfully (in Init state)
239        let _ = tx.send(Ok(DoPutResponse::new(0, 0, 0.0))).await;
240        self.put_record_batches(stream, tx, query_ctx).await;
241
242        let response = ReceiverStream::new(rx)
243            .and_then(|response| {
244                future::ready({
245                    serde_json::to_vec(&response)
246                        .context(ToJsonSnafu)
247                        .map(|x| PutResult {
248                            app_metadata: Bytes::from(x),
249                        })
250                        .map_err(Into::into)
251                })
252            })
253            .boxed();
254        Ok(Response::new(response))
255    }
256}
257
258pub struct PutRecordBatchRequest {
259    pub table_name: TableName,
260    pub request_id: i64,
261    pub record_batch: DfRecordBatch,
262    pub schema_bytes: Bytes,
263    pub flight_data: FlightData,
264    pub(crate) _guard: Option<MemoryGuard<RequestMemoryMetrics>>,
265}
266
267impl PutRecordBatchRequest {
268    fn try_new(
269        table_name: TableName,
270        record_batch: DfRecordBatch,
271        request_id: i64,
272        schema_bytes: Bytes,
273        flight_data: FlightData,
274        limiter: Option<&ServerMemoryLimiter>,
275    ) -> Result<Self> {
276        let memory_usage = flight_data.data_body.len()
277            + flight_data.app_metadata.len()
278            + flight_data.data_header.len();
279
280        let _guard = if let Some(limiter) = limiter {
281            let guard = limiter.try_acquire(memory_usage as u64).ok_or_else(|| {
282                let inner_err = common_memory_manager::Error::MemoryLimitExceeded {
283                    requested_bytes: memory_usage as u64,
284                    limit_bytes: limiter.limit_bytes(),
285                };
286                error::MemoryLimitExceededSnafu.into_error(inner_err)
287            })?;
288            Some(guard)
289        } else {
290            None
291        };
292
293        Ok(Self {
294            table_name,
295            request_id,
296            record_batch,
297            schema_bytes,
298            flight_data,
299            _guard,
300        })
301    }
302}
303
304pub struct PutRecordBatchRequestStream {
305    flight_data_stream: Streaming<FlightData>,
306    catalog: String,
307    schema_name: String,
308    limiter: Option<ServerMemoryLimiter>,
309    // Client now lazily sends schema data so we cannot eagerly wait for it.
310    // Instead, we need to decode while receiving record batches.
311    state: StreamState,
312}
313
314enum StreamState {
315    Init,
316    Ready {
317        table_name: TableName,
318        schema: SchemaRef,
319        schema_bytes: Bytes,
320        decoder: FlightDecoder,
321    },
322}
323
324impl PutRecordBatchRequestStream {
325    /// Creates a new `PutRecordBatchRequestStream` in Init state.
326    /// The stream will transition to Ready state when it receives the schema message.
327    pub async fn new(
328        flight_data_stream: Streaming<FlightData>,
329        catalog: String,
330        schema: String,
331        limiter: Option<ServerMemoryLimiter>,
332    ) -> TonicResult<Self> {
333        Ok(Self {
334            flight_data_stream,
335            catalog,
336            schema_name: schema,
337            limiter,
338            state: StreamState::Init,
339        })
340    }
341
342    /// Returns the table name extracted from the flight descriptor.
343    /// Returns None if the stream is still in Init state.
344    pub fn table_name(&self) -> Option<&TableName> {
345        match &self.state {
346            StreamState::Init => None,
347            StreamState::Ready { table_name, .. } => Some(table_name),
348        }
349    }
350
351    /// Returns the Arrow schema decoded from the first flight message.
352    /// Returns None if the stream is still in Init state.
353    pub fn schema(&self) -> Option<&SchemaRef> {
354        match &self.state {
355            StreamState::Init => None,
356            StreamState::Ready { schema, .. } => Some(schema),
357        }
358    }
359
360    /// Returns the raw schema bytes in IPC format.
361    /// Returns None if the stream is still in Init state.
362    pub fn schema_bytes(&self) -> Option<&Bytes> {
363        match &self.state {
364            StreamState::Init => None,
365            StreamState::Ready { schema_bytes, .. } => Some(schema_bytes),
366        }
367    }
368
369    fn extract_table_name(mut descriptor: FlightDescriptor) -> Result<String> {
370        ensure!(
371            descriptor.r#type == arrow_flight::flight_descriptor::DescriptorType::Path as i32,
372            InvalidParameterSnafu {
373                reason: "expect FlightDescriptor::type == 'Path' only",
374            }
375        );
376        ensure!(
377            descriptor.path.len() == 1,
378            InvalidParameterSnafu {
379                reason: "expect FlightDescriptor::path has only one table name",
380            }
381        );
382        Ok(descriptor.path.remove(0))
383    }
384}
385
386impl Stream for PutRecordBatchRequestStream {
387    type Item = TonicResult<PutRecordBatchRequest>;
388
389    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
390        loop {
391            let poll = ready!(self.flight_data_stream.poll_next_unpin(cx));
392
393            match poll {
394                Some(Ok(flight_data)) => {
395                    let limiter = self.limiter.clone();
396
397                    match &mut self.state {
398                        StreamState::Init => {
399                            // First message - expecting schema
400                            let flight_descriptor = match flight_data.flight_descriptor.as_ref() {
401                                Some(descriptor) => descriptor.clone(),
402                                None => {
403                                    return Poll::Ready(Some(Err(Status::failed_precondition(
404                                        "table to put is not found in flight descriptor",
405                                    ))));
406                                }
407                            };
408
409                            let table_name_str = match Self::extract_table_name(flight_descriptor) {
410                                Ok(name) => name,
411                                Err(e) => {
412                                    return Poll::Ready(Some(Err(Status::invalid_argument(
413                                        e.to_string(),
414                                    ))));
415                                }
416                            };
417                            let table_name = TableName::new(
418                                self.catalog.clone(),
419                                self.schema_name.clone(),
420                                table_name_str,
421                            );
422
423                            // Decode the schema
424                            let mut decoder = FlightDecoder::default();
425                            let schema_message = decoder.try_decode(&flight_data).map_err(|e| {
426                                Status::invalid_argument(format!("Failed to decode schema: {}", e))
427                            })?;
428
429                            match schema_message {
430                                Some(FlightMessage::Schema(schema)) => {
431                                    let schema_bytes = decoder.schema_bytes().ok_or_else(|| {
432                                        Status::internal(
433                                            "decoder should have schema bytes after decoding schema",
434                                        )
435                                    })?;
436
437                                    // Transition to Ready state with all necessary data
438                                    self.state = StreamState::Ready {
439                                        table_name,
440                                        schema,
441                                        schema_bytes,
442                                        decoder,
443                                    };
444                                    // Continue to next iteration to process RecordBatch messages
445                                    continue;
446                                }
447                                _ => {
448                                    return Poll::Ready(Some(Err(Status::failed_precondition(
449                                        "first message must be a Schema message",
450                                    ))));
451                                }
452                            }
453                        }
454                        StreamState::Ready {
455                            table_name,
456                            schema: _,
457                            schema_bytes,
458                            decoder,
459                        } => {
460                            // Extract request_id and body_size from FlightData before decoding
461                            let request_id = if !flight_data.app_metadata.is_empty() {
462                                serde_json::from_slice::<DoPutMetadata>(&flight_data.app_metadata)
463                                    .map(|meta| meta.request_id())
464                                    .unwrap_or_default()
465                            } else {
466                                0
467                            };
468
469                            // Decode FlightData to RecordBatch
470                            match decoder.try_decode(&flight_data) {
471                                Ok(Some(FlightMessage::RecordBatch(record_batch))) => {
472                                    let table_name = table_name.clone();
473                                    let schema_bytes = schema_bytes.clone();
474                                    return Poll::Ready(Some(
475                                        PutRecordBatchRequest::try_new(
476                                            table_name,
477                                            record_batch,
478                                            request_id,
479                                            schema_bytes,
480                                            flight_data,
481                                            limiter.as_ref(),
482                                        )
483                                        .map_err(|e| Status::invalid_argument(e.to_string())),
484                                    ));
485                                }
486                                Ok(Some(other)) => {
487                                    debug!("Unexpected flight message: {:?}", other);
488                                    return Poll::Ready(Some(Err(Status::invalid_argument(
489                                        "Expected RecordBatch message, got other message type",
490                                    ))));
491                                }
492                                Ok(None) => {
493                                    // Dictionary batch - processed internally by decoder, continue polling
494                                    continue;
495                                }
496                                Err(e) => {
497                                    return Poll::Ready(Some(Err(Status::invalid_argument(
498                                        format!("Failed to decode RecordBatch: {}", e),
499                                    ))));
500                                }
501                            }
502                        }
503                    }
504                }
505                Some(Err(e)) => {
506                    return Poll::Ready(Some(Err(e)));
507                }
508                None => {
509                    return Poll::Ready(None);
510                }
511            }
512        }
513    }
514}
515
516fn to_flight_data_stream(
517    output: Output,
518    tracing_context: TracingContext,
519    flight_compression: FlightCompression,
520    query_ctx: QueryContextRef,
521) -> TonicStream<FlightData> {
522    match output.data {
523        OutputData::Stream(stream) => {
524            let stream = FlightRecordBatchStream::new(
525                stream,
526                tracing_context,
527                flight_compression,
528                query_ctx,
529            );
530            Box::pin(stream) as _
531        }
532        OutputData::RecordBatches(x) => {
533            let stream = FlightRecordBatchStream::new(
534                x.as_stream(),
535                tracing_context,
536                flight_compression,
537                query_ctx,
538            );
539            Box::pin(stream) as _
540        }
541        OutputData::AffectedRows(rows) => {
542            let stream = tokio_stream::iter(
543                FlightEncoder::default()
544                    .encode(FlightMessage::AffectedRows(rows))
545                    .into_iter()
546                    .map(Ok),
547            );
548            Box::pin(stream) as _
549        }
550    }
551}