servers/grpc/
flight.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15mod stream;
16
17use std::pin::Pin;
18use std::sync::Arc;
19use std::task::{Context, Poll};
20
21use api::v1::GreptimeRequest;
22use arrow_flight::flight_service_server::FlightService;
23use arrow_flight::{
24    Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
25    HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket,
26};
27use async_trait::async_trait;
28use bytes::Bytes;
29use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
30use common_catalog::parse_catalog_and_schema_from_db_string;
31use common_grpc::flight::do_put::{DoPutMetadata, DoPutResponse};
32use common_grpc::flight::{FlightEncoder, FlightMessage};
33use common_query::{Output, OutputData};
34use common_telemetry::tracing::info_span;
35use common_telemetry::tracing_context::{FutureExt, TracingContext};
36use futures::{future, ready, Stream};
37use futures_util::{StreamExt, TryStreamExt};
38use prost::Message;
39use session::context::{QueryContext, QueryContextRef};
40use snafu::{ensure, ResultExt};
41use table::table_name::TableName;
42use tokio::sync::mpsc;
43use tokio_stream::wrappers::ReceiverStream;
44use tonic::{Request, Response, Status, Streaming};
45
46use crate::error::{InvalidParameterSnafu, ParseJsonSnafu, Result, ToJsonSnafu};
47pub use crate::grpc::flight::stream::FlightRecordBatchStream;
48use crate::grpc::greptime_handler::{get_request_type, GreptimeRequestHandler};
49use crate::grpc::{FlightCompression, TonicResult};
50use crate::http::header::constants::GREPTIME_DB_HEADER_NAME;
51use crate::http::AUTHORIZATION_HEADER;
52use crate::{error, hint_headers};
53
54pub type TonicStream<T> = Pin<Box<dyn Stream<Item = TonicResult<T>> + Send + 'static>>;
55
56/// A subset of [FlightService]
57#[async_trait]
58pub trait FlightCraft: Send + Sync + 'static {
59    async fn do_get(
60        &self,
61        request: Request<Ticket>,
62    ) -> TonicResult<Response<TonicStream<FlightData>>>;
63
64    async fn do_put(
65        &self,
66        request: Request<Streaming<FlightData>>,
67    ) -> TonicResult<Response<TonicStream<PutResult>>> {
68        let _ = request;
69        Err(Status::unimplemented("Not yet implemented"))
70    }
71}
72
73pub type FlightCraftRef = Arc<dyn FlightCraft>;
74
75pub struct FlightCraftWrapper<T: FlightCraft>(pub T);
76
77impl<T: FlightCraft> From<T> for FlightCraftWrapper<T> {
78    fn from(t: T) -> Self {
79        Self(t)
80    }
81}
82
83#[async_trait]
84impl FlightCraft for FlightCraftRef {
85    async fn do_get(
86        &self,
87        request: Request<Ticket>,
88    ) -> TonicResult<Response<TonicStream<FlightData>>> {
89        (**self).do_get(request).await
90    }
91
92    async fn do_put(
93        &self,
94        request: Request<Streaming<FlightData>>,
95    ) -> TonicResult<Response<TonicStream<PutResult>>> {
96        self.as_ref().do_put(request).await
97    }
98}
99
100#[async_trait]
101impl<T: FlightCraft> FlightService for FlightCraftWrapper<T> {
102    type HandshakeStream = TonicStream<HandshakeResponse>;
103
104    async fn handshake(
105        &self,
106        _: Request<Streaming<HandshakeRequest>>,
107    ) -> TonicResult<Response<Self::HandshakeStream>> {
108        Err(Status::unimplemented("Not yet implemented"))
109    }
110
111    type ListFlightsStream = TonicStream<FlightInfo>;
112
113    async fn list_flights(
114        &self,
115        _: Request<Criteria>,
116    ) -> TonicResult<Response<Self::ListFlightsStream>> {
117        Err(Status::unimplemented("Not yet implemented"))
118    }
119
120    async fn get_flight_info(
121        &self,
122        _: Request<FlightDescriptor>,
123    ) -> TonicResult<Response<FlightInfo>> {
124        Err(Status::unimplemented("Not yet implemented"))
125    }
126
127    async fn poll_flight_info(
128        &self,
129        _: Request<FlightDescriptor>,
130    ) -> TonicResult<Response<PollInfo>> {
131        Err(Status::unimplemented("Not yet implemented"))
132    }
133
134    async fn get_schema(
135        &self,
136        _: Request<FlightDescriptor>,
137    ) -> TonicResult<Response<SchemaResult>> {
138        Err(Status::unimplemented("Not yet implemented"))
139    }
140
141    type DoGetStream = TonicStream<FlightData>;
142
143    async fn do_get(&self, request: Request<Ticket>) -> TonicResult<Response<Self::DoGetStream>> {
144        self.0.do_get(request).await
145    }
146
147    type DoPutStream = TonicStream<PutResult>;
148
149    async fn do_put(
150        &self,
151        request: Request<Streaming<FlightData>>,
152    ) -> TonicResult<Response<Self::DoPutStream>> {
153        self.0.do_put(request).await
154    }
155
156    type DoExchangeStream = TonicStream<FlightData>;
157
158    async fn do_exchange(
159        &self,
160        _: Request<Streaming<FlightData>>,
161    ) -> TonicResult<Response<Self::DoExchangeStream>> {
162        Err(Status::unimplemented("Not yet implemented"))
163    }
164
165    type DoActionStream = TonicStream<arrow_flight::Result>;
166
167    async fn do_action(&self, _: Request<Action>) -> TonicResult<Response<Self::DoActionStream>> {
168        Err(Status::unimplemented("Not yet implemented"))
169    }
170
171    type ListActionsStream = TonicStream<ActionType>;
172
173    async fn list_actions(
174        &self,
175        _: Request<Empty>,
176    ) -> TonicResult<Response<Self::ListActionsStream>> {
177        Err(Status::unimplemented("Not yet implemented"))
178    }
179}
180
181#[async_trait]
182impl FlightCraft for GreptimeRequestHandler {
183    async fn do_get(
184        &self,
185        request: Request<Ticket>,
186    ) -> TonicResult<Response<TonicStream<FlightData>>> {
187        let hints = hint_headers::extract_hints(request.metadata());
188
189        let ticket = request.into_inner().ticket;
190        let request =
191            GreptimeRequest::decode(ticket.as_ref()).context(error::InvalidFlightTicketSnafu)?;
192        let query_ctx = QueryContext::arc();
193
194        // The Grpc protocol pass query by Flight. It needs to be wrapped under a span, in order to record stream
195        let span = info_span!(
196            "GreptimeRequestHandler::do_get",
197            protocol = "grpc",
198            request_type = get_request_type(&request)
199        );
200        let flight_compression = self.flight_compression;
201        async {
202            let output = self.handle_request(request, hints).await?;
203            let stream = to_flight_data_stream(
204                output,
205                TracingContext::from_current_span(),
206                flight_compression,
207                query_ctx,
208            );
209            Ok(Response::new(stream))
210        }
211        .trace(span)
212        .await
213    }
214
215    async fn do_put(
216        &self,
217        request: Request<Streaming<FlightData>>,
218    ) -> TonicResult<Response<TonicStream<PutResult>>> {
219        let (headers, _, stream) = request.into_parts();
220
221        let header = |key: &str| -> TonicResult<Option<&str>> {
222            let Some(v) = headers.get(key) else {
223                return Ok(None);
224            };
225            let Ok(v) = std::str::from_utf8(v.as_bytes()) else {
226                return Err(InvalidParameterSnafu {
227                    reason: "expect valid UTF-8 value",
228                }
229                .build()
230                .into());
231            };
232            Ok(Some(v))
233        };
234
235        let username_and_password = header(AUTHORIZATION_HEADER)?;
236        let db = header(GREPTIME_DB_HEADER_NAME)?;
237        if !self.validate_auth(username_and_password, db).await? {
238            return Err(Status::unauthenticated("auth failed"));
239        }
240
241        const MAX_PENDING_RESPONSES: usize = 32;
242        let (tx, rx) = mpsc::channel::<TonicResult<DoPutResponse>>(MAX_PENDING_RESPONSES);
243
244        let stream = PutRecordBatchRequestStream {
245            flight_data_stream: stream,
246            state: PutRecordBatchRequestStreamState::Init(db.map(ToString::to_string)),
247        };
248        self.put_record_batches(stream, tx).await;
249
250        let response = ReceiverStream::new(rx)
251            .and_then(|response| {
252                future::ready({
253                    serde_json::to_vec(&response)
254                        .context(ToJsonSnafu)
255                        .map(|x| PutResult {
256                            app_metadata: Bytes::from(x),
257                        })
258                        .map_err(Into::into)
259                })
260            })
261            .boxed();
262        Ok(Response::new(response))
263    }
264}
265
266pub(crate) struct PutRecordBatchRequest {
267    pub(crate) table_name: TableName,
268    pub(crate) request_id: i64,
269    pub(crate) data: FlightData,
270}
271
272impl PutRecordBatchRequest {
273    fn try_new(table_name: TableName, flight_data: FlightData) -> Result<Self> {
274        let request_id = if !flight_data.app_metadata.is_empty() {
275            let metadata: DoPutMetadata =
276                serde_json::from_slice(&flight_data.app_metadata).context(ParseJsonSnafu)?;
277            metadata.request_id()
278        } else {
279            0
280        };
281        Ok(Self {
282            table_name,
283            request_id,
284            data: flight_data,
285        })
286    }
287}
288
289pub(crate) struct PutRecordBatchRequestStream {
290    flight_data_stream: Streaming<FlightData>,
291    state: PutRecordBatchRequestStreamState,
292}
293
294enum PutRecordBatchRequestStreamState {
295    Init(Option<String>),
296    Started(TableName),
297}
298
299impl Stream for PutRecordBatchRequestStream {
300    type Item = TonicResult<PutRecordBatchRequest>;
301
302    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
303        fn extract_table_name(mut descriptor: FlightDescriptor) -> Result<String> {
304            ensure!(
305                descriptor.r#type == arrow_flight::flight_descriptor::DescriptorType::Path as i32,
306                InvalidParameterSnafu {
307                    reason: "expect FlightDescriptor::type == 'Path' only",
308                }
309            );
310            ensure!(
311                descriptor.path.len() == 1,
312                InvalidParameterSnafu {
313                    reason: "expect FlightDescriptor::path has only one table name",
314                }
315            );
316            Ok(descriptor.path.remove(0))
317        }
318
319        let poll = ready!(self.flight_data_stream.poll_next_unpin(cx));
320
321        let result = match &mut self.state {
322            PutRecordBatchRequestStreamState::Init(db) => match poll {
323                Some(Ok(mut flight_data)) => {
324                    let flight_descriptor = flight_data.flight_descriptor.take();
325                    let result = if let Some(descriptor) = flight_descriptor {
326                        let table_name = extract_table_name(descriptor).map(|x| {
327                            let (catalog, schema) = if let Some(db) = db {
328                                parse_catalog_and_schema_from_db_string(db)
329                            } else {
330                                (
331                                    DEFAULT_CATALOG_NAME.to_string(),
332                                    DEFAULT_SCHEMA_NAME.to_string(),
333                                )
334                            };
335                            TableName::new(catalog, schema, x)
336                        });
337                        let table_name = match table_name {
338                            Ok(table_name) => table_name,
339                            Err(e) => return Poll::Ready(Some(Err(e.into()))),
340                        };
341
342                        let request =
343                            PutRecordBatchRequest::try_new(table_name.clone(), flight_data);
344                        let request = match request {
345                            Ok(request) => request,
346                            Err(e) => return Poll::Ready(Some(Err(e.into()))),
347                        };
348
349                        self.state = PutRecordBatchRequestStreamState::Started(table_name);
350
351                        Ok(request)
352                    } else {
353                        Err(Status::failed_precondition(
354                            "table to put is not found in flight descriptor",
355                        ))
356                    };
357                    Some(result)
358                }
359                Some(Err(e)) => Some(Err(e)),
360                None => None,
361            },
362            PutRecordBatchRequestStreamState::Started(table_name) => poll.map(|x| {
363                x.and_then(|flight_data| {
364                    PutRecordBatchRequest::try_new(table_name.clone(), flight_data)
365                        .map_err(Into::into)
366                })
367            }),
368        };
369        Poll::Ready(result)
370    }
371}
372
373fn to_flight_data_stream(
374    output: Output,
375    tracing_context: TracingContext,
376    flight_compression: FlightCompression,
377    query_ctx: QueryContextRef,
378) -> TonicStream<FlightData> {
379    match output.data {
380        OutputData::Stream(stream) => {
381            let stream = FlightRecordBatchStream::new(
382                stream,
383                tracing_context,
384                flight_compression,
385                query_ctx,
386            );
387            Box::pin(stream) as _
388        }
389        OutputData::RecordBatches(x) => {
390            let stream = FlightRecordBatchStream::new(
391                x.as_stream(),
392                tracing_context,
393                flight_compression,
394                query_ctx,
395            );
396            Box::pin(stream) as _
397        }
398        OutputData::AffectedRows(rows) => {
399            let stream = tokio_stream::iter(
400                FlightEncoder::default()
401                    .encode(FlightMessage::AffectedRows(rows))
402                    .into_iter()
403                    .map(Ok),
404            );
405            Box::pin(stream) as _
406        }
407    }
408}