servers/grpc/
flight.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15mod stream;
16
17use std::pin::Pin;
18use std::sync::Arc;
19use std::task::{Context, Poll};
20
21use api::v1::GreptimeRequest;
22use arrow_flight::flight_service_server::FlightService;
23use arrow_flight::{
24    Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
25    HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket,
26};
27use async_trait::async_trait;
28use bytes::Bytes;
29use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
30use common_catalog::parse_catalog_and_schema_from_db_string;
31use common_grpc::flight::do_put::{DoPutMetadata, DoPutResponse};
32use common_grpc::flight::{FlightEncoder, FlightMessage};
33use common_query::{Output, OutputData};
34use common_telemetry::tracing::info_span;
35use common_telemetry::tracing_context::{FutureExt, TracingContext};
36use futures::{future, ready, Stream};
37use futures_util::{StreamExt, TryStreamExt};
38use prost::Message;
39use snafu::{ensure, ResultExt};
40use table::table_name::TableName;
41use tokio::sync::mpsc;
42use tokio_stream::wrappers::ReceiverStream;
43use tonic::{Request, Response, Status, Streaming};
44
45use crate::error::{InvalidParameterSnafu, ParseJsonSnafu, Result, ToJsonSnafu};
46pub use crate::grpc::flight::stream::FlightRecordBatchStream;
47use crate::grpc::greptime_handler::{get_request_type, GreptimeRequestHandler};
48use crate::grpc::{FlightCompression, TonicResult};
49use crate::http::header::constants::GREPTIME_DB_HEADER_NAME;
50use crate::http::AUTHORIZATION_HEADER;
51use crate::{error, hint_headers};
52
53pub type TonicStream<T> = Pin<Box<dyn Stream<Item = TonicResult<T>> + Send + 'static>>;
54
55/// A subset of [FlightService]
56#[async_trait]
57pub trait FlightCraft: Send + Sync + 'static {
58    async fn do_get(
59        &self,
60        request: Request<Ticket>,
61    ) -> TonicResult<Response<TonicStream<FlightData>>>;
62
63    async fn do_put(
64        &self,
65        request: Request<Streaming<FlightData>>,
66    ) -> TonicResult<Response<TonicStream<PutResult>>> {
67        let _ = request;
68        Err(Status::unimplemented("Not yet implemented"))
69    }
70}
71
72pub type FlightCraftRef = Arc<dyn FlightCraft>;
73
74pub struct FlightCraftWrapper<T: FlightCraft>(pub T);
75
76impl<T: FlightCraft> From<T> for FlightCraftWrapper<T> {
77    fn from(t: T) -> Self {
78        Self(t)
79    }
80}
81
82#[async_trait]
83impl FlightCraft for FlightCraftRef {
84    async fn do_get(
85        &self,
86        request: Request<Ticket>,
87    ) -> TonicResult<Response<TonicStream<FlightData>>> {
88        (**self).do_get(request).await
89    }
90
91    async fn do_put(
92        &self,
93        request: Request<Streaming<FlightData>>,
94    ) -> TonicResult<Response<TonicStream<PutResult>>> {
95        self.as_ref().do_put(request).await
96    }
97}
98
99#[async_trait]
100impl<T: FlightCraft> FlightService for FlightCraftWrapper<T> {
101    type HandshakeStream = TonicStream<HandshakeResponse>;
102
103    async fn handshake(
104        &self,
105        _: Request<Streaming<HandshakeRequest>>,
106    ) -> TonicResult<Response<Self::HandshakeStream>> {
107        Err(Status::unimplemented("Not yet implemented"))
108    }
109
110    type ListFlightsStream = TonicStream<FlightInfo>;
111
112    async fn list_flights(
113        &self,
114        _: Request<Criteria>,
115    ) -> TonicResult<Response<Self::ListFlightsStream>> {
116        Err(Status::unimplemented("Not yet implemented"))
117    }
118
119    async fn get_flight_info(
120        &self,
121        _: Request<FlightDescriptor>,
122    ) -> TonicResult<Response<FlightInfo>> {
123        Err(Status::unimplemented("Not yet implemented"))
124    }
125
126    async fn poll_flight_info(
127        &self,
128        _: Request<FlightDescriptor>,
129    ) -> TonicResult<Response<PollInfo>> {
130        Err(Status::unimplemented("Not yet implemented"))
131    }
132
133    async fn get_schema(
134        &self,
135        _: Request<FlightDescriptor>,
136    ) -> TonicResult<Response<SchemaResult>> {
137        Err(Status::unimplemented("Not yet implemented"))
138    }
139
140    type DoGetStream = TonicStream<FlightData>;
141
142    async fn do_get(&self, request: Request<Ticket>) -> TonicResult<Response<Self::DoGetStream>> {
143        self.0.do_get(request).await
144    }
145
146    type DoPutStream = TonicStream<PutResult>;
147
148    async fn do_put(
149        &self,
150        request: Request<Streaming<FlightData>>,
151    ) -> TonicResult<Response<Self::DoPutStream>> {
152        self.0.do_put(request).await
153    }
154
155    type DoExchangeStream = TonicStream<FlightData>;
156
157    async fn do_exchange(
158        &self,
159        _: Request<Streaming<FlightData>>,
160    ) -> TonicResult<Response<Self::DoExchangeStream>> {
161        Err(Status::unimplemented("Not yet implemented"))
162    }
163
164    type DoActionStream = TonicStream<arrow_flight::Result>;
165
166    async fn do_action(&self, _: Request<Action>) -> TonicResult<Response<Self::DoActionStream>> {
167        Err(Status::unimplemented("Not yet implemented"))
168    }
169
170    type ListActionsStream = TonicStream<ActionType>;
171
172    async fn list_actions(
173        &self,
174        _: Request<Empty>,
175    ) -> TonicResult<Response<Self::ListActionsStream>> {
176        Err(Status::unimplemented("Not yet implemented"))
177    }
178}
179
180#[async_trait]
181impl FlightCraft for GreptimeRequestHandler {
182    async fn do_get(
183        &self,
184        request: Request<Ticket>,
185    ) -> TonicResult<Response<TonicStream<FlightData>>> {
186        let hints = hint_headers::extract_hints(request.metadata());
187
188        let ticket = request.into_inner().ticket;
189        let request =
190            GreptimeRequest::decode(ticket.as_ref()).context(error::InvalidFlightTicketSnafu)?;
191
192        // The Grpc protocol pass query by Flight. It needs to be wrapped under a span, in order to record stream
193        let span = info_span!(
194            "GreptimeRequestHandler::do_get",
195            protocol = "grpc",
196            request_type = get_request_type(&request)
197        );
198        let flight_compression = self.flight_compression;
199        async {
200            let output = self.handle_request(request, hints).await?;
201            let stream = to_flight_data_stream(
202                output,
203                TracingContext::from_current_span(),
204                flight_compression,
205            );
206            Ok(Response::new(stream))
207        }
208        .trace(span)
209        .await
210    }
211
212    async fn do_put(
213        &self,
214        request: Request<Streaming<FlightData>>,
215    ) -> TonicResult<Response<TonicStream<PutResult>>> {
216        let (headers, _, stream) = request.into_parts();
217
218        let header = |key: &str| -> TonicResult<Option<&str>> {
219            let Some(v) = headers.get(key) else {
220                return Ok(None);
221            };
222            let Ok(v) = std::str::from_utf8(v.as_bytes()) else {
223                return Err(InvalidParameterSnafu {
224                    reason: "expect valid UTF-8 value",
225                }
226                .build()
227                .into());
228            };
229            Ok(Some(v))
230        };
231
232        let username_and_password = header(AUTHORIZATION_HEADER)?;
233        let db = header(GREPTIME_DB_HEADER_NAME)?;
234        if !self.validate_auth(username_and_password, db).await? {
235            return Err(Status::unauthenticated("auth failed"));
236        }
237
238        const MAX_PENDING_RESPONSES: usize = 32;
239        let (tx, rx) = mpsc::channel::<TonicResult<DoPutResponse>>(MAX_PENDING_RESPONSES);
240
241        let stream = PutRecordBatchRequestStream {
242            flight_data_stream: stream,
243            state: PutRecordBatchRequestStreamState::Init(db.map(ToString::to_string)),
244        };
245        self.put_record_batches(stream, tx).await;
246
247        let response = ReceiverStream::new(rx)
248            .and_then(|response| {
249                future::ready({
250                    serde_json::to_vec(&response)
251                        .context(ToJsonSnafu)
252                        .map(|x| PutResult {
253                            app_metadata: Bytes::from(x),
254                        })
255                        .map_err(Into::into)
256                })
257            })
258            .boxed();
259        Ok(Response::new(response))
260    }
261}
262
263pub(crate) struct PutRecordBatchRequest {
264    pub(crate) table_name: TableName,
265    pub(crate) request_id: i64,
266    pub(crate) data: FlightData,
267}
268
269impl PutRecordBatchRequest {
270    fn try_new(table_name: TableName, flight_data: FlightData) -> Result<Self> {
271        let request_id = if !flight_data.app_metadata.is_empty() {
272            let metadata: DoPutMetadata =
273                serde_json::from_slice(&flight_data.app_metadata).context(ParseJsonSnafu)?;
274            metadata.request_id()
275        } else {
276            0
277        };
278        Ok(Self {
279            table_name,
280            request_id,
281            data: flight_data,
282        })
283    }
284}
285
286pub(crate) struct PutRecordBatchRequestStream {
287    flight_data_stream: Streaming<FlightData>,
288    state: PutRecordBatchRequestStreamState,
289}
290
291enum PutRecordBatchRequestStreamState {
292    Init(Option<String>),
293    Started(TableName),
294}
295
296impl Stream for PutRecordBatchRequestStream {
297    type Item = TonicResult<PutRecordBatchRequest>;
298
299    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
300        fn extract_table_name(mut descriptor: FlightDescriptor) -> Result<String> {
301            ensure!(
302                descriptor.r#type == arrow_flight::flight_descriptor::DescriptorType::Path as i32,
303                InvalidParameterSnafu {
304                    reason: "expect FlightDescriptor::type == 'Path' only",
305                }
306            );
307            ensure!(
308                descriptor.path.len() == 1,
309                InvalidParameterSnafu {
310                    reason: "expect FlightDescriptor::path has only one table name",
311                }
312            );
313            Ok(descriptor.path.remove(0))
314        }
315
316        let poll = ready!(self.flight_data_stream.poll_next_unpin(cx));
317
318        let result = match &mut self.state {
319            PutRecordBatchRequestStreamState::Init(db) => match poll {
320                Some(Ok(mut flight_data)) => {
321                    let flight_descriptor = flight_data.flight_descriptor.take();
322                    let result = if let Some(descriptor) = flight_descriptor {
323                        let table_name = extract_table_name(descriptor).map(|x| {
324                            let (catalog, schema) = if let Some(db) = db {
325                                parse_catalog_and_schema_from_db_string(db)
326                            } else {
327                                (
328                                    DEFAULT_CATALOG_NAME.to_string(),
329                                    DEFAULT_SCHEMA_NAME.to_string(),
330                                )
331                            };
332                            TableName::new(catalog, schema, x)
333                        });
334                        let table_name = match table_name {
335                            Ok(table_name) => table_name,
336                            Err(e) => return Poll::Ready(Some(Err(e.into()))),
337                        };
338
339                        let request =
340                            PutRecordBatchRequest::try_new(table_name.clone(), flight_data);
341                        let request = match request {
342                            Ok(request) => request,
343                            Err(e) => return Poll::Ready(Some(Err(e.into()))),
344                        };
345
346                        self.state = PutRecordBatchRequestStreamState::Started(table_name);
347
348                        Ok(request)
349                    } else {
350                        Err(Status::failed_precondition(
351                            "table to put is not found in flight descriptor",
352                        ))
353                    };
354                    Some(result)
355                }
356                Some(Err(e)) => Some(Err(e)),
357                None => None,
358            },
359            PutRecordBatchRequestStreamState::Started(table_name) => poll.map(|x| {
360                x.and_then(|flight_data| {
361                    PutRecordBatchRequest::try_new(table_name.clone(), flight_data)
362                        .map_err(Into::into)
363                })
364            }),
365        };
366        Poll::Ready(result)
367    }
368}
369
370fn to_flight_data_stream(
371    output: Output,
372    tracing_context: TracingContext,
373    flight_compression: FlightCompression,
374) -> TonicStream<FlightData> {
375    match output.data {
376        OutputData::Stream(stream) => {
377            let stream = FlightRecordBatchStream::new(stream, tracing_context, flight_compression);
378            Box::pin(stream) as _
379        }
380        OutputData::RecordBatches(x) => {
381            let stream =
382                FlightRecordBatchStream::new(x.as_stream(), tracing_context, flight_compression);
383            Box::pin(stream) as _
384        }
385        OutputData::AffectedRows(rows) => {
386            let stream = tokio_stream::once(Ok(
387                FlightEncoder::default().encode(FlightMessage::AffectedRows(rows))
388            ));
389            Box::pin(stream) as _
390        }
391    }
392}