servers/grpc/
flight.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15mod stream;
16
17use std::pin::Pin;
18use std::sync::Arc;
19use std::task::{Context, Poll};
20
21use api::v1::GreptimeRequest;
22use arrow_flight::flight_service_server::FlightService;
23use arrow_flight::{
24    Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
25    HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket,
26};
27use async_trait::async_trait;
28use bytes::Bytes;
29use common_grpc::flight::do_put::{DoPutMetadata, DoPutResponse};
30use common_grpc::flight::{FlightEncoder, FlightMessage};
31use common_query::{Output, OutputData};
32use common_telemetry::tracing::info_span;
33use common_telemetry::tracing_context::{FutureExt, TracingContext};
34use futures::{Stream, future, ready};
35use futures_util::{StreamExt, TryStreamExt};
36use prost::Message;
37use session::context::{QueryContext, QueryContextRef};
38use snafu::{ResultExt, ensure};
39use table::table_name::TableName;
40use tokio::sync::mpsc;
41use tokio_stream::wrappers::ReceiverStream;
42use tonic::{Request, Response, Status, Streaming};
43
44use crate::error::{InvalidParameterSnafu, ParseJsonSnafu, Result, ToJsonSnafu};
45pub use crate::grpc::flight::stream::FlightRecordBatchStream;
46use crate::grpc::greptime_handler::{GreptimeRequestHandler, get_request_type};
47use crate::grpc::{FlightCompression, TonicResult, context_auth};
48use crate::metrics::{METRIC_GRPC_MEMORY_USAGE_BYTES, METRIC_GRPC_REQUESTS_REJECTED_TOTAL};
49use crate::request_limiter::{RequestMemoryGuard, RequestMemoryLimiter};
50use crate::{error, hint_headers};
51
/// A heap-allocated, pinned, `Send + 'static` stream whose items are [TonicResult]s.
///
/// Used as the streaming response type of the Flight RPC methods in this module.
pub type TonicStream<T> = Pin<Box<dyn Stream<Item = TonicResult<T>> + Send + 'static>>;
53
54/// A subset of [FlightService]
55#[async_trait]
56pub trait FlightCraft: Send + Sync + 'static {
57    async fn do_get(
58        &self,
59        request: Request<Ticket>,
60    ) -> TonicResult<Response<TonicStream<FlightData>>>;
61
62    async fn do_put(
63        &self,
64        request: Request<Streaming<FlightData>>,
65    ) -> TonicResult<Response<TonicStream<PutResult>>> {
66        let _ = request;
67        Err(Status::unimplemented("Not yet implemented"))
68    }
69}
70
/// A shared, reference-counted handle to a [FlightCraft] trait object.
pub type FlightCraftRef = Arc<dyn FlightCraft>;
72
/// Newtype around a [FlightCraft] implementation.
///
/// [FlightService] is implemented for this wrapper: `do_get` and `do_put` are forwarded to
/// the wrapped value, and every other Flight RPC answers with an `unimplemented` status.
pub struct FlightCraftWrapper<T: FlightCraft>(pub T);
74
75impl<T: FlightCraft> From<T> for FlightCraftWrapper<T> {
76    fn from(t: T) -> Self {
77        Self(t)
78    }
79}
80
81#[async_trait]
82impl FlightCraft for FlightCraftRef {
83    async fn do_get(
84        &self,
85        request: Request<Ticket>,
86    ) -> TonicResult<Response<TonicStream<FlightData>>> {
87        (**self).do_get(request).await
88    }
89
90    async fn do_put(
91        &self,
92        request: Request<Streaming<FlightData>>,
93    ) -> TonicResult<Response<TonicStream<PutResult>>> {
94        self.as_ref().do_put(request).await
95    }
96}
97
98#[async_trait]
99impl<T: FlightCraft> FlightService for FlightCraftWrapper<T> {
100    type HandshakeStream = TonicStream<HandshakeResponse>;
101
102    async fn handshake(
103        &self,
104        _: Request<Streaming<HandshakeRequest>>,
105    ) -> TonicResult<Response<Self::HandshakeStream>> {
106        Err(Status::unimplemented("Not yet implemented"))
107    }
108
109    type ListFlightsStream = TonicStream<FlightInfo>;
110
111    async fn list_flights(
112        &self,
113        _: Request<Criteria>,
114    ) -> TonicResult<Response<Self::ListFlightsStream>> {
115        Err(Status::unimplemented("Not yet implemented"))
116    }
117
118    async fn get_flight_info(
119        &self,
120        _: Request<FlightDescriptor>,
121    ) -> TonicResult<Response<FlightInfo>> {
122        Err(Status::unimplemented("Not yet implemented"))
123    }
124
125    async fn poll_flight_info(
126        &self,
127        _: Request<FlightDescriptor>,
128    ) -> TonicResult<Response<PollInfo>> {
129        Err(Status::unimplemented("Not yet implemented"))
130    }
131
132    async fn get_schema(
133        &self,
134        _: Request<FlightDescriptor>,
135    ) -> TonicResult<Response<SchemaResult>> {
136        Err(Status::unimplemented("Not yet implemented"))
137    }
138
139    type DoGetStream = TonicStream<FlightData>;
140
141    async fn do_get(&self, request: Request<Ticket>) -> TonicResult<Response<Self::DoGetStream>> {
142        self.0.do_get(request).await
143    }
144
145    type DoPutStream = TonicStream<PutResult>;
146
147    async fn do_put(
148        &self,
149        request: Request<Streaming<FlightData>>,
150    ) -> TonicResult<Response<Self::DoPutStream>> {
151        self.0.do_put(request).await
152    }
153
154    type DoExchangeStream = TonicStream<FlightData>;
155
156    async fn do_exchange(
157        &self,
158        _: Request<Streaming<FlightData>>,
159    ) -> TonicResult<Response<Self::DoExchangeStream>> {
160        Err(Status::unimplemented("Not yet implemented"))
161    }
162
163    type DoActionStream = TonicStream<arrow_flight::Result>;
164
165    async fn do_action(&self, _: Request<Action>) -> TonicResult<Response<Self::DoActionStream>> {
166        Err(Status::unimplemented("Not yet implemented"))
167    }
168
169    type ListActionsStream = TonicStream<ActionType>;
170
171    async fn list_actions(
172        &self,
173        _: Request<Empty>,
174    ) -> TonicResult<Response<Self::ListActionsStream>> {
175        Err(Status::unimplemented("Not yet implemented"))
176    }
177}
178
179#[async_trait]
180impl FlightCraft for GreptimeRequestHandler {
181    async fn do_get(
182        &self,
183        request: Request<Ticket>,
184    ) -> TonicResult<Response<TonicStream<FlightData>>> {
185        let hints = hint_headers::extract_hints(request.metadata());
186
187        let ticket = request.into_inner().ticket;
188        let request =
189            GreptimeRequest::decode(ticket.as_ref()).context(error::InvalidFlightTicketSnafu)?;
190
191        // The Grpc protocol pass query by Flight. It needs to be wrapped under a span, in order to record stream
192        let span = info_span!(
193            "GreptimeRequestHandler::do_get",
194            protocol = "grpc",
195            request_type = get_request_type(&request)
196        );
197        let flight_compression = self.flight_compression;
198        async {
199            let output = self.handle_request(request, hints).await?;
200            let stream = to_flight_data_stream(
201                output,
202                TracingContext::from_current_span(),
203                flight_compression,
204                QueryContext::arc(),
205            );
206            Ok(Response::new(stream))
207        }
208        .trace(span)
209        .await
210    }
211
212    async fn do_put(
213        &self,
214        request: Request<Streaming<FlightData>>,
215    ) -> TonicResult<Response<TonicStream<PutResult>>> {
216        let (headers, extensions, stream) = request.into_parts();
217
218        let limiter = extensions.get::<RequestMemoryLimiter>().cloned();
219
220        let query_ctx = context_auth::create_query_context_from_grpc_metadata(&headers)?;
221        context_auth::check_auth(self.user_provider.clone(), &headers, query_ctx.clone()).await?;
222
223        const MAX_PENDING_RESPONSES: usize = 32;
224        let (tx, rx) = mpsc::channel::<TonicResult<DoPutResponse>>(MAX_PENDING_RESPONSES);
225
226        let stream = PutRecordBatchRequestStream {
227            flight_data_stream: stream,
228            state: PutRecordBatchRequestStreamState::Init(
229                query_ctx.current_catalog().to_string(),
230                query_ctx.current_schema(),
231            ),
232            limiter,
233        };
234        self.put_record_batches(stream, tx, query_ctx).await;
235
236        let response = ReceiverStream::new(rx)
237            .and_then(|response| {
238                future::ready({
239                    serde_json::to_vec(&response)
240                        .context(ToJsonSnafu)
241                        .map(|x| PutResult {
242                            app_metadata: Bytes::from(x),
243                        })
244                        .map_err(Into::into)
245                })
246            })
247            .boxed();
248        Ok(Response::new(response))
249    }
250}
251
/// One message of a Flight `DoPut` upload, paired with the table it targets.
pub(crate) struct PutRecordBatchRequest {
    /// The table the data is written to.
    pub(crate) table_name: TableName,
    /// Request id taken from the message's JSON `app_metadata`; `0` when the metadata is empty.
    pub(crate) request_id: i64,
    /// The Flight message itself.
    pub(crate) data: FlightData,
    /// Guard obtained from the request memory limiter, if one is present and enabled.
    /// It is stored here so that it lives as long as the request does.
    pub(crate) _guard: Option<RequestMemoryGuard>,
}
258
259impl PutRecordBatchRequest {
260    fn try_new(
261        table_name: TableName,
262        flight_data: FlightData,
263        limiter: Option<&RequestMemoryLimiter>,
264    ) -> Result<Self> {
265        let request_id = if !flight_data.app_metadata.is_empty() {
266            let metadata: DoPutMetadata =
267                serde_json::from_slice(&flight_data.app_metadata).context(ParseJsonSnafu)?;
268            metadata.request_id()
269        } else {
270            0
271        };
272
273        let _guard = limiter
274            .filter(|limiter| limiter.is_enabled())
275            .map(|limiter| {
276                let message_size = flight_data.encoded_len();
277                limiter
278                    .try_acquire(message_size)
279                    .map(|guard| {
280                        guard.inspect(|g| {
281                            METRIC_GRPC_MEMORY_USAGE_BYTES.set(g.current_usage() as i64);
282                        })
283                    })
284                    .inspect_err(|_| {
285                        METRIC_GRPC_REQUESTS_REJECTED_TOTAL.inc();
286                    })
287            })
288            .transpose()?
289            .flatten();
290
291        Ok(Self {
292            table_name,
293            request_id,
294            data: flight_data,
295            _guard,
296        })
297    }
298}
299
/// Adapts the raw [FlightData] stream of a `DoPut` call into a stream of
/// [PutRecordBatchRequest]s.
pub(crate) struct PutRecordBatchRequestStream {
    /// The incoming Flight messages.
    flight_data_stream: Streaming<FlightData>,
    /// Whether the target table has been resolved yet.
    state: PutRecordBatchRequestStreamState,
    /// Optional memory limiter passed to every request that is built.
    limiter: Option<RequestMemoryLimiter>,
}
305
/// Progress of a [PutRecordBatchRequestStream].
enum PutRecordBatchRequestStreamState {
    /// No request has been produced yet; holds the catalog and schema (in that order) used
    /// to qualify the table name found in the first message's flight descriptor.
    Init(String, String),
    /// The target table has been resolved; it is reused for every following message.
    Started(TableName),
}
310
311impl Stream for PutRecordBatchRequestStream {
312    type Item = TonicResult<PutRecordBatchRequest>;
313
314    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
315        fn extract_table_name(mut descriptor: FlightDescriptor) -> Result<String> {
316            ensure!(
317                descriptor.r#type == arrow_flight::flight_descriptor::DescriptorType::Path as i32,
318                InvalidParameterSnafu {
319                    reason: "expect FlightDescriptor::type == 'Path' only",
320                }
321            );
322            ensure!(
323                descriptor.path.len() == 1,
324                InvalidParameterSnafu {
325                    reason: "expect FlightDescriptor::path has only one table name",
326                }
327            );
328            Ok(descriptor.path.remove(0))
329        }
330
331        let poll = ready!(self.flight_data_stream.poll_next_unpin(cx));
332        let limiter = self.limiter.clone();
333
334        let result = match &mut self.state {
335            PutRecordBatchRequestStreamState::Init(catalog, schema) => match poll {
336                Some(Ok(mut flight_data)) => {
337                    let flight_descriptor = flight_data.flight_descriptor.take();
338                    let result = if let Some(descriptor) = flight_descriptor {
339                        let table_name = extract_table_name(descriptor)
340                            .map(|x| TableName::new(catalog.clone(), schema.clone(), x));
341                        let table_name = match table_name {
342                            Ok(table_name) => table_name,
343                            Err(e) => return Poll::Ready(Some(Err(e.into()))),
344                        };
345
346                        let request = PutRecordBatchRequest::try_new(
347                            table_name.clone(),
348                            flight_data,
349                            limiter.as_ref(),
350                        );
351                        let request = match request {
352                            Ok(request) => request,
353                            Err(e) => return Poll::Ready(Some(Err(e.into()))),
354                        };
355
356                        self.state = PutRecordBatchRequestStreamState::Started(table_name);
357
358                        Ok(request)
359                    } else {
360                        Err(Status::failed_precondition(
361                            "table to put is not found in flight descriptor",
362                        ))
363                    };
364                    Some(result)
365                }
366                Some(Err(e)) => Some(Err(e)),
367                None => None,
368            },
369            PutRecordBatchRequestStreamState::Started(table_name) => poll.map(|x| {
370                x.and_then(|flight_data| {
371                    PutRecordBatchRequest::try_new(
372                        table_name.clone(),
373                        flight_data,
374                        limiter.as_ref(),
375                    )
376                    .map_err(Into::into)
377                })
378            }),
379        };
380        Poll::Ready(result)
381    }
382}
383
384fn to_flight_data_stream(
385    output: Output,
386    tracing_context: TracingContext,
387    flight_compression: FlightCompression,
388    query_ctx: QueryContextRef,
389) -> TonicStream<FlightData> {
390    match output.data {
391        OutputData::Stream(stream) => {
392            let stream = FlightRecordBatchStream::new(
393                stream,
394                tracing_context,
395                flight_compression,
396                query_ctx,
397            );
398            Box::pin(stream) as _
399        }
400        OutputData::RecordBatches(x) => {
401            let stream = FlightRecordBatchStream::new(
402                x.as_stream(),
403                tracing_context,
404                flight_compression,
405                query_ctx,
406            );
407            Box::pin(stream) as _
408        }
409        OutputData::AffectedRows(rows) => {
410            let stream = tokio_stream::iter(
411                FlightEncoder::default()
412                    .encode(FlightMessage::AffectedRows(rows))
413                    .into_iter()
414                    .map(Ok),
415            );
416            Box::pin(stream) as _
417        }
418    }
419}