servers/grpc/
flight.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15mod stream;
16
17use std::pin::Pin;
18use std::sync::Arc;
19use std::task::{Context, Poll};
20
21use api::v1::GreptimeRequest;
22use arrow_flight::flight_service_server::FlightService;
23use arrow_flight::{
24    Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
25    HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket,
26};
27use async_trait::async_trait;
28use bytes::Bytes;
29use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
30use common_catalog::parse_catalog_and_schema_from_db_string;
31use common_grpc::flight::do_put::{DoPutMetadata, DoPutResponse};
32use common_grpc::flight::{FlightEncoder, FlightMessage};
33use common_query::{Output, OutputData};
34use common_telemetry::tracing::info_span;
35use common_telemetry::tracing_context::{FutureExt, TracingContext};
36use futures::{future, ready, Stream};
37use futures_util::{StreamExt, TryStreamExt};
38use prost::Message;
39use snafu::{ensure, ResultExt};
40use table::table_name::TableName;
41use tokio::sync::mpsc;
42use tokio_stream::wrappers::ReceiverStream;
43use tonic::{Request, Response, Status, Streaming};
44
45use crate::error::{InvalidParameterSnafu, ParseJsonSnafu, Result, ToJsonSnafu};
46pub use crate::grpc::flight::stream::FlightRecordBatchStream;
47use crate::grpc::greptime_handler::{get_request_type, GreptimeRequestHandler};
48use crate::grpc::TonicResult;
49use crate::http::header::constants::GREPTIME_DB_HEADER_NAME;
50use crate::http::AUTHORIZATION_HEADER;
51use crate::{error, hint_headers};
52
53pub type TonicStream<T> = Pin<Box<dyn Stream<Item = TonicResult<T>> + Send + 'static>>;
54
55/// A subset of [FlightService]
56#[async_trait]
57pub trait FlightCraft: Send + Sync + 'static {
58    async fn do_get(
59        &self,
60        request: Request<Ticket>,
61    ) -> TonicResult<Response<TonicStream<FlightData>>>;
62
63    async fn do_put(
64        &self,
65        request: Request<Streaming<FlightData>>,
66    ) -> TonicResult<Response<TonicStream<PutResult>>> {
67        let _ = request;
68        Err(Status::unimplemented("Not yet implemented"))
69    }
70}
71
72pub type FlightCraftRef = Arc<dyn FlightCraft>;
73
74pub struct FlightCraftWrapper<T: FlightCraft>(pub T);
75
76impl<T: FlightCraft> From<T> for FlightCraftWrapper<T> {
77    fn from(t: T) -> Self {
78        Self(t)
79    }
80}
81
82#[async_trait]
83impl FlightCraft for FlightCraftRef {
84    async fn do_get(
85        &self,
86        request: Request<Ticket>,
87    ) -> TonicResult<Response<TonicStream<FlightData>>> {
88        (**self).do_get(request).await
89    }
90
91    async fn do_put(
92        &self,
93        request: Request<Streaming<FlightData>>,
94    ) -> TonicResult<Response<TonicStream<PutResult>>> {
95        self.as_ref().do_put(request).await
96    }
97}
98
99#[async_trait]
100impl<T: FlightCraft> FlightService for FlightCraftWrapper<T> {
101    type HandshakeStream = TonicStream<HandshakeResponse>;
102
103    async fn handshake(
104        &self,
105        _: Request<Streaming<HandshakeRequest>>,
106    ) -> TonicResult<Response<Self::HandshakeStream>> {
107        Err(Status::unimplemented("Not yet implemented"))
108    }
109
110    type ListFlightsStream = TonicStream<FlightInfo>;
111
112    async fn list_flights(
113        &self,
114        _: Request<Criteria>,
115    ) -> TonicResult<Response<Self::ListFlightsStream>> {
116        Err(Status::unimplemented("Not yet implemented"))
117    }
118
119    async fn get_flight_info(
120        &self,
121        _: Request<FlightDescriptor>,
122    ) -> TonicResult<Response<FlightInfo>> {
123        Err(Status::unimplemented("Not yet implemented"))
124    }
125
126    async fn poll_flight_info(
127        &self,
128        _: Request<FlightDescriptor>,
129    ) -> TonicResult<Response<PollInfo>> {
130        Err(Status::unimplemented("Not yet implemented"))
131    }
132
133    async fn get_schema(
134        &self,
135        _: Request<FlightDescriptor>,
136    ) -> TonicResult<Response<SchemaResult>> {
137        Err(Status::unimplemented("Not yet implemented"))
138    }
139
140    type DoGetStream = TonicStream<FlightData>;
141
142    async fn do_get(&self, request: Request<Ticket>) -> TonicResult<Response<Self::DoGetStream>> {
143        self.0.do_get(request).await
144    }
145
146    type DoPutStream = TonicStream<PutResult>;
147
148    async fn do_put(
149        &self,
150        request: Request<Streaming<FlightData>>,
151    ) -> TonicResult<Response<Self::DoPutStream>> {
152        self.0.do_put(request).await
153    }
154
155    type DoExchangeStream = TonicStream<FlightData>;
156
157    async fn do_exchange(
158        &self,
159        _: Request<Streaming<FlightData>>,
160    ) -> TonicResult<Response<Self::DoExchangeStream>> {
161        Err(Status::unimplemented("Not yet implemented"))
162    }
163
164    type DoActionStream = TonicStream<arrow_flight::Result>;
165
166    async fn do_action(&self, _: Request<Action>) -> TonicResult<Response<Self::DoActionStream>> {
167        Err(Status::unimplemented("Not yet implemented"))
168    }
169
170    type ListActionsStream = TonicStream<ActionType>;
171
172    async fn list_actions(
173        &self,
174        _: Request<Empty>,
175    ) -> TonicResult<Response<Self::ListActionsStream>> {
176        Err(Status::unimplemented("Not yet implemented"))
177    }
178}
179
180#[async_trait]
181impl FlightCraft for GreptimeRequestHandler {
182    async fn do_get(
183        &self,
184        request: Request<Ticket>,
185    ) -> TonicResult<Response<TonicStream<FlightData>>> {
186        let hints = hint_headers::extract_hints(request.metadata());
187
188        let ticket = request.into_inner().ticket;
189        let request =
190            GreptimeRequest::decode(ticket.as_ref()).context(error::InvalidFlightTicketSnafu)?;
191
192        // The Grpc protocol pass query by Flight. It needs to be wrapped under a span, in order to record stream
193        let span = info_span!(
194            "GreptimeRequestHandler::do_get",
195            protocol = "grpc",
196            request_type = get_request_type(&request)
197        );
198        async {
199            let output = self.handle_request(request, hints).await?;
200            let stream = to_flight_data_stream(output, TracingContext::from_current_span());
201            Ok(Response::new(stream))
202        }
203        .trace(span)
204        .await
205    }
206
207    async fn do_put(
208        &self,
209        request: Request<Streaming<FlightData>>,
210    ) -> TonicResult<Response<TonicStream<PutResult>>> {
211        let (headers, _, stream) = request.into_parts();
212
213        let header = |key: &str| -> TonicResult<Option<&str>> {
214            let Some(v) = headers.get(key) else {
215                return Ok(None);
216            };
217            let Ok(v) = std::str::from_utf8(v.as_bytes()) else {
218                return Err(InvalidParameterSnafu {
219                    reason: "expect valid UTF-8 value",
220                }
221                .build()
222                .into());
223            };
224            Ok(Some(v))
225        };
226
227        let username_and_password = header(AUTHORIZATION_HEADER)?;
228        let db = header(GREPTIME_DB_HEADER_NAME)?;
229        if !self.validate_auth(username_and_password, db).await? {
230            return Err(Status::unauthenticated("auth failed"));
231        }
232
233        const MAX_PENDING_RESPONSES: usize = 32;
234        let (tx, rx) = mpsc::channel::<TonicResult<DoPutResponse>>(MAX_PENDING_RESPONSES);
235
236        let stream = PutRecordBatchRequestStream {
237            flight_data_stream: stream,
238            state: PutRecordBatchRequestStreamState::Init(db.map(ToString::to_string)),
239        };
240        self.put_record_batches(stream, tx).await;
241
242        let response = ReceiverStream::new(rx)
243            .and_then(|response| {
244                future::ready({
245                    serde_json::to_vec(&response)
246                        .context(ToJsonSnafu)
247                        .map(|x| PutResult {
248                            app_metadata: Bytes::from(x),
249                        })
250                        .map_err(Into::into)
251                })
252            })
253            .boxed();
254        Ok(Response::new(response))
255    }
256}
257
258pub(crate) struct PutRecordBatchRequest {
259    pub(crate) table_name: TableName,
260    pub(crate) request_id: i64,
261    pub(crate) data: FlightData,
262}
263
264impl PutRecordBatchRequest {
265    fn try_new(table_name: TableName, flight_data: FlightData) -> Result<Self> {
266        let request_id = if !flight_data.app_metadata.is_empty() {
267            let metadata: DoPutMetadata =
268                serde_json::from_slice(&flight_data.app_metadata).context(ParseJsonSnafu)?;
269            metadata.request_id()
270        } else {
271            0
272        };
273        Ok(Self {
274            table_name,
275            request_id,
276            data: flight_data,
277        })
278    }
279}
280
281pub(crate) struct PutRecordBatchRequestStream {
282    flight_data_stream: Streaming<FlightData>,
283    state: PutRecordBatchRequestStreamState,
284}
285
286enum PutRecordBatchRequestStreamState {
287    Init(Option<String>),
288    Started(TableName),
289}
290
291impl Stream for PutRecordBatchRequestStream {
292    type Item = TonicResult<PutRecordBatchRequest>;
293
294    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
295        fn extract_table_name(mut descriptor: FlightDescriptor) -> Result<String> {
296            ensure!(
297                descriptor.r#type == arrow_flight::flight_descriptor::DescriptorType::Path as i32,
298                InvalidParameterSnafu {
299                    reason: "expect FlightDescriptor::type == 'Path' only",
300                }
301            );
302            ensure!(
303                descriptor.path.len() == 1,
304                InvalidParameterSnafu {
305                    reason: "expect FlightDescriptor::path has only one table name",
306                }
307            );
308            Ok(descriptor.path.remove(0))
309        }
310
311        let poll = ready!(self.flight_data_stream.poll_next_unpin(cx));
312
313        let result = match &mut self.state {
314            PutRecordBatchRequestStreamState::Init(db) => match poll {
315                Some(Ok(mut flight_data)) => {
316                    let flight_descriptor = flight_data.flight_descriptor.take();
317                    let result = if let Some(descriptor) = flight_descriptor {
318                        let table_name = extract_table_name(descriptor).map(|x| {
319                            let (catalog, schema) = if let Some(db) = db {
320                                parse_catalog_and_schema_from_db_string(db)
321                            } else {
322                                (
323                                    DEFAULT_CATALOG_NAME.to_string(),
324                                    DEFAULT_SCHEMA_NAME.to_string(),
325                                )
326                            };
327                            TableName::new(catalog, schema, x)
328                        });
329                        let table_name = match table_name {
330                            Ok(table_name) => table_name,
331                            Err(e) => return Poll::Ready(Some(Err(e.into()))),
332                        };
333
334                        let request =
335                            PutRecordBatchRequest::try_new(table_name.clone(), flight_data);
336                        let request = match request {
337                            Ok(request) => request,
338                            Err(e) => return Poll::Ready(Some(Err(e.into()))),
339                        };
340
341                        self.state = PutRecordBatchRequestStreamState::Started(table_name);
342
343                        Ok(request)
344                    } else {
345                        Err(Status::failed_precondition(
346                            "table to put is not found in flight descriptor",
347                        ))
348                    };
349                    Some(result)
350                }
351                Some(Err(e)) => Some(Err(e)),
352                None => None,
353            },
354            PutRecordBatchRequestStreamState::Started(table_name) => poll.map(|x| {
355                x.and_then(|flight_data| {
356                    PutRecordBatchRequest::try_new(table_name.clone(), flight_data)
357                        .map_err(Into::into)
358                })
359            }),
360        };
361        Poll::Ready(result)
362    }
363}
364
365fn to_flight_data_stream(
366    output: Output,
367    tracing_context: TracingContext,
368) -> TonicStream<FlightData> {
369    match output.data {
370        OutputData::Stream(stream) => {
371            let stream = FlightRecordBatchStream::new(stream, tracing_context);
372            Box::pin(stream) as _
373        }
374        OutputData::RecordBatches(x) => {
375            let stream = FlightRecordBatchStream::new(x.as_stream(), tracing_context);
376            Box::pin(stream) as _
377        }
378        OutputData::AffectedRows(rows) => {
379            let stream = tokio_stream::once(Ok(
380                FlightEncoder::default().encode(FlightMessage::AffectedRows(rows))
381            ));
382            Box::pin(stream) as _
383        }
384    }
385}