servers/grpc/
flight.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15mod stream;
16
17use std::pin::Pin;
18use std::sync::Arc;
19use std::task::{Context, Poll};
20
21use api::v1::GreptimeRequest;
22use arrow_flight::flight_service_server::FlightService;
23use arrow_flight::{
24    Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
25    HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket,
26};
27use async_trait::async_trait;
28use bytes::Bytes;
29use common_grpc::flight::do_put::{DoPutMetadata, DoPutResponse};
30use common_grpc::flight::{FlightEncoder, FlightMessage};
31use common_query::{Output, OutputData};
32use common_telemetry::tracing::info_span;
33use common_telemetry::tracing_context::{FutureExt, TracingContext};
34use futures::{Stream, future, ready};
35use futures_util::{StreamExt, TryStreamExt};
36use prost::Message;
37use session::context::{QueryContext, QueryContextRef};
38use snafu::{ResultExt, ensure};
39use table::table_name::TableName;
40use tokio::sync::mpsc;
41use tokio_stream::wrappers::ReceiverStream;
42use tonic::{Request, Response, Status, Streaming};
43
44use crate::error::{InvalidParameterSnafu, ParseJsonSnafu, Result, ToJsonSnafu};
45pub use crate::grpc::flight::stream::FlightRecordBatchStream;
46use crate::grpc::greptime_handler::{GreptimeRequestHandler, get_request_type};
47use crate::grpc::{FlightCompression, TonicResult, context_auth};
48use crate::{error, hint_headers};
49
/// A pinned, boxed, `Send + 'static` stream whose items are [`TonicResult<T>`]; the
/// streaming response type used by the Flight handlers in this file.
50pub type TonicStream<T> = Pin<Box<dyn Stream<Item = TonicResult<T>> + Send + 'static>>;
51
52/// A subset of [FlightService]
53#[async_trait]
54pub trait FlightCraft: Send + Sync + 'static {
55    async fn do_get(
56        &self,
57        request: Request<Ticket>,
58    ) -> TonicResult<Response<TonicStream<FlightData>>>;
59
60    async fn do_put(
61        &self,
62        request: Request<Streaming<FlightData>>,
63    ) -> TonicResult<Response<TonicStream<PutResult>>> {
64        let _ = request;
65        Err(Status::unimplemented("Not yet implemented"))
66    }
67}
68
/// Reference-counted, dynamically dispatched handle to a [`FlightCraft`] implementation.
69pub type FlightCraftRef = Arc<dyn FlightCraft>;
70
71pub struct FlightCraftWrapper<T: FlightCraft>(pub T);
72
73impl<T: FlightCraft> From<T> for FlightCraftWrapper<T> {
74    fn from(t: T) -> Self {
75        Self(t)
76    }
77}
78
79#[async_trait]
80impl FlightCraft for FlightCraftRef {
81    async fn do_get(
82        &self,
83        request: Request<Ticket>,
84    ) -> TonicResult<Response<TonicStream<FlightData>>> {
85        (**self).do_get(request).await
86    }
87
88    async fn do_put(
89        &self,
90        request: Request<Streaming<FlightData>>,
91    ) -> TonicResult<Response<TonicStream<PutResult>>> {
92        self.as_ref().do_put(request).await
93    }
94}
95
96#[async_trait]
97impl<T: FlightCraft> FlightService for FlightCraftWrapper<T> {
98    type HandshakeStream = TonicStream<HandshakeResponse>;
99
100    async fn handshake(
101        &self,
102        _: Request<Streaming<HandshakeRequest>>,
103    ) -> TonicResult<Response<Self::HandshakeStream>> {
104        Err(Status::unimplemented("Not yet implemented"))
105    }
106
107    type ListFlightsStream = TonicStream<FlightInfo>;
108
109    async fn list_flights(
110        &self,
111        _: Request<Criteria>,
112    ) -> TonicResult<Response<Self::ListFlightsStream>> {
113        Err(Status::unimplemented("Not yet implemented"))
114    }
115
116    async fn get_flight_info(
117        &self,
118        _: Request<FlightDescriptor>,
119    ) -> TonicResult<Response<FlightInfo>> {
120        Err(Status::unimplemented("Not yet implemented"))
121    }
122
123    async fn poll_flight_info(
124        &self,
125        _: Request<FlightDescriptor>,
126    ) -> TonicResult<Response<PollInfo>> {
127        Err(Status::unimplemented("Not yet implemented"))
128    }
129
130    async fn get_schema(
131        &self,
132        _: Request<FlightDescriptor>,
133    ) -> TonicResult<Response<SchemaResult>> {
134        Err(Status::unimplemented("Not yet implemented"))
135    }
136
137    type DoGetStream = TonicStream<FlightData>;
138
139    async fn do_get(&self, request: Request<Ticket>) -> TonicResult<Response<Self::DoGetStream>> {
140        self.0.do_get(request).await
141    }
142
143    type DoPutStream = TonicStream<PutResult>;
144
145    async fn do_put(
146        &self,
147        request: Request<Streaming<FlightData>>,
148    ) -> TonicResult<Response<Self::DoPutStream>> {
149        self.0.do_put(request).await
150    }
151
152    type DoExchangeStream = TonicStream<FlightData>;
153
154    async fn do_exchange(
155        &self,
156        _: Request<Streaming<FlightData>>,
157    ) -> TonicResult<Response<Self::DoExchangeStream>> {
158        Err(Status::unimplemented("Not yet implemented"))
159    }
160
161    type DoActionStream = TonicStream<arrow_flight::Result>;
162
163    async fn do_action(&self, _: Request<Action>) -> TonicResult<Response<Self::DoActionStream>> {
164        Err(Status::unimplemented("Not yet implemented"))
165    }
166
167    type ListActionsStream = TonicStream<ActionType>;
168
169    async fn list_actions(
170        &self,
171        _: Request<Empty>,
172    ) -> TonicResult<Response<Self::ListActionsStream>> {
173        Err(Status::unimplemented("Not yet implemented"))
174    }
175}
176
177#[async_trait]
178impl FlightCraft for GreptimeRequestHandler {
179    async fn do_get(
180        &self,
181        request: Request<Ticket>,
182    ) -> TonicResult<Response<TonicStream<FlightData>>> {
183        let hints = hint_headers::extract_hints(request.metadata());
184
185        let ticket = request.into_inner().ticket;
186        let request =
187            GreptimeRequest::decode(ticket.as_ref()).context(error::InvalidFlightTicketSnafu)?;
188
189        // The Grpc protocol pass query by Flight. It needs to be wrapped under a span, in order to record stream
190        let span = info_span!(
191            "GreptimeRequestHandler::do_get",
192            protocol = "grpc",
193            request_type = get_request_type(&request)
194        );
195        let flight_compression = self.flight_compression;
196        async {
197            let output = self.handle_request(request, hints).await?;
198            let stream = to_flight_data_stream(
199                output,
200                TracingContext::from_current_span(),
201                flight_compression,
202                QueryContext::arc(),
203            );
204            Ok(Response::new(stream))
205        }
206        .trace(span)
207        .await
208    }
209
210    async fn do_put(
211        &self,
212        request: Request<Streaming<FlightData>>,
213    ) -> TonicResult<Response<TonicStream<PutResult>>> {
214        let (headers, _, stream) = request.into_parts();
215
216        let query_ctx = context_auth::create_query_context_from_grpc_metadata(&headers)?;
217        context_auth::check_auth(self.user_provider.clone(), &headers, query_ctx.clone()).await?;
218
219        const MAX_PENDING_RESPONSES: usize = 32;
220        let (tx, rx) = mpsc::channel::<TonicResult<DoPutResponse>>(MAX_PENDING_RESPONSES);
221
222        let stream = PutRecordBatchRequestStream {
223            flight_data_stream: stream,
224            state: PutRecordBatchRequestStreamState::Init(
225                query_ctx.current_catalog().to_string(),
226                query_ctx.current_schema(),
227            ),
228        };
229        self.put_record_batches(stream, tx, query_ctx).await;
230
231        let response = ReceiverStream::new(rx)
232            .and_then(|response| {
233                future::ready({
234                    serde_json::to_vec(&response)
235                        .context(ToJsonSnafu)
236                        .map(|x| PutResult {
237                            app_metadata: Bytes::from(x),
238                        })
239                        .map_err(Into::into)
240                })
241            })
242            .boxed();
243        Ok(Response::new(response))
244    }
245}
246
247pub(crate) struct PutRecordBatchRequest {
248    pub(crate) table_name: TableName,
249    pub(crate) request_id: i64,
250    pub(crate) data: FlightData,
251}
252
253impl PutRecordBatchRequest {
254    fn try_new(table_name: TableName, flight_data: FlightData) -> Result<Self> {
255        let request_id = if !flight_data.app_metadata.is_empty() {
256            let metadata: DoPutMetadata =
257                serde_json::from_slice(&flight_data.app_metadata).context(ParseJsonSnafu)?;
258            metadata.request_id()
259        } else {
260            0
261        };
262        Ok(Self {
263            table_name,
264            request_id,
265            data: flight_data,
266        })
267    }
268}
269
270pub(crate) struct PutRecordBatchRequestStream {
271    flight_data_stream: Streaming<FlightData>,
272    state: PutRecordBatchRequestStreamState,
273}
274
275enum PutRecordBatchRequestStreamState {
276    Init(String, String),
277    Started(TableName),
278}
279
280impl Stream for PutRecordBatchRequestStream {
281    type Item = TonicResult<PutRecordBatchRequest>;
282
283    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
284        fn extract_table_name(mut descriptor: FlightDescriptor) -> Result<String> {
285            ensure!(
286                descriptor.r#type == arrow_flight::flight_descriptor::DescriptorType::Path as i32,
287                InvalidParameterSnafu {
288                    reason: "expect FlightDescriptor::type == 'Path' only",
289                }
290            );
291            ensure!(
292                descriptor.path.len() == 1,
293                InvalidParameterSnafu {
294                    reason: "expect FlightDescriptor::path has only one table name",
295                }
296            );
297            Ok(descriptor.path.remove(0))
298        }
299
300        let poll = ready!(self.flight_data_stream.poll_next_unpin(cx));
301
302        let result = match &mut self.state {
303            PutRecordBatchRequestStreamState::Init(catalog, schema) => match poll {
304                Some(Ok(mut flight_data)) => {
305                    let flight_descriptor = flight_data.flight_descriptor.take();
306                    let result = if let Some(descriptor) = flight_descriptor {
307                        let table_name = extract_table_name(descriptor)
308                            .map(|x| TableName::new(catalog.clone(), schema.clone(), x));
309                        let table_name = match table_name {
310                            Ok(table_name) => table_name,
311                            Err(e) => return Poll::Ready(Some(Err(e.into()))),
312                        };
313
314                        let request =
315                            PutRecordBatchRequest::try_new(table_name.clone(), flight_data);
316                        let request = match request {
317                            Ok(request) => request,
318                            Err(e) => return Poll::Ready(Some(Err(e.into()))),
319                        };
320
321                        self.state = PutRecordBatchRequestStreamState::Started(table_name);
322
323                        Ok(request)
324                    } else {
325                        Err(Status::failed_precondition(
326                            "table to put is not found in flight descriptor",
327                        ))
328                    };
329                    Some(result)
330                }
331                Some(Err(e)) => Some(Err(e)),
332                None => None,
333            },
334            PutRecordBatchRequestStreamState::Started(table_name) => poll.map(|x| {
335                x.and_then(|flight_data| {
336                    PutRecordBatchRequest::try_new(table_name.clone(), flight_data)
337                        .map_err(Into::into)
338                })
339            }),
340        };
341        Poll::Ready(result)
342    }
343}
344
345fn to_flight_data_stream(
346    output: Output,
347    tracing_context: TracingContext,
348    flight_compression: FlightCompression,
349    query_ctx: QueryContextRef,
350) -> TonicStream<FlightData> {
351    match output.data {
352        OutputData::Stream(stream) => {
353            let stream = FlightRecordBatchStream::new(
354                stream,
355                tracing_context,
356                flight_compression,
357                query_ctx,
358            );
359            Box::pin(stream) as _
360        }
361        OutputData::RecordBatches(x) => {
362            let stream = FlightRecordBatchStream::new(
363                x.as_stream(),
364                tracing_context,
365                flight_compression,
366                query_ctx,
367            );
368            Box::pin(stream) as _
369        }
370        OutputData::AffectedRows(rows) => {
371            let stream = tokio_stream::iter(
372                FlightEncoder::default()
373                    .encode(FlightMessage::AffectedRows(rows))
374                    .into_iter()
375                    .map(Ok),
376            );
377            Box::pin(stream) as _
378        }
379    }
380}