1mod stream;
16
17use std::pin::Pin;
18use std::sync::Arc;
19use std::task::{Context, Poll};
20
21use api::v1::GreptimeRequest;
22use arrow_flight::flight_service_server::FlightService;
23use arrow_flight::{
24 Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
25 HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket,
26};
27use async_trait::async_trait;
28use bytes::Bytes;
29use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
30use common_catalog::parse_catalog_and_schema_from_db_string;
31use common_grpc::flight::do_put::{DoPutMetadata, DoPutResponse};
32use common_grpc::flight::{FlightEncoder, FlightMessage};
33use common_query::{Output, OutputData};
34use common_telemetry::tracing::info_span;
35use common_telemetry::tracing_context::{FutureExt, TracingContext};
36use futures::{future, ready, Stream};
37use futures_util::{StreamExt, TryStreamExt};
38use prost::Message;
39use snafu::{ensure, ResultExt};
40use table::table_name::TableName;
41use tokio::sync::mpsc;
42use tokio_stream::wrappers::ReceiverStream;
43use tonic::{Request, Response, Status, Streaming};
44
45use crate::error::{InvalidParameterSnafu, ParseJsonSnafu, Result, ToJsonSnafu};
46pub use crate::grpc::flight::stream::FlightRecordBatchStream;
47use crate::grpc::greptime_handler::{get_request_type, GreptimeRequestHandler};
48use crate::grpc::TonicResult;
49use crate::http::header::constants::GREPTIME_DB_HEADER_NAME;
50use crate::http::AUTHORIZATION_HEADER;
51use crate::{error, hint_headers};
52
53pub type TonicStream<T> = Pin<Box<dyn Stream<Item = TonicResult<T>> + Send + 'static>>;
54
55#[async_trait]
57pub trait FlightCraft: Send + Sync + 'static {
58 async fn do_get(
59 &self,
60 request: Request<Ticket>,
61 ) -> TonicResult<Response<TonicStream<FlightData>>>;
62
63 async fn do_put(
64 &self,
65 request: Request<Streaming<FlightData>>,
66 ) -> TonicResult<Response<TonicStream<PutResult>>> {
67 let _ = request;
68 Err(Status::unimplemented("Not yet implemented"))
69 }
70}
71
72pub type FlightCraftRef = Arc<dyn FlightCraft>;
73
74pub struct FlightCraftWrapper<T: FlightCraft>(pub T);
75
76impl<T: FlightCraft> From<T> for FlightCraftWrapper<T> {
77 fn from(t: T) -> Self {
78 Self(t)
79 }
80}
81
82#[async_trait]
83impl FlightCraft for FlightCraftRef {
84 async fn do_get(
85 &self,
86 request: Request<Ticket>,
87 ) -> TonicResult<Response<TonicStream<FlightData>>> {
88 (**self).do_get(request).await
89 }
90
91 async fn do_put(
92 &self,
93 request: Request<Streaming<FlightData>>,
94 ) -> TonicResult<Response<TonicStream<PutResult>>> {
95 self.as_ref().do_put(request).await
96 }
97}
98
99#[async_trait]
100impl<T: FlightCraft> FlightService for FlightCraftWrapper<T> {
101 type HandshakeStream = TonicStream<HandshakeResponse>;
102
103 async fn handshake(
104 &self,
105 _: Request<Streaming<HandshakeRequest>>,
106 ) -> TonicResult<Response<Self::HandshakeStream>> {
107 Err(Status::unimplemented("Not yet implemented"))
108 }
109
110 type ListFlightsStream = TonicStream<FlightInfo>;
111
112 async fn list_flights(
113 &self,
114 _: Request<Criteria>,
115 ) -> TonicResult<Response<Self::ListFlightsStream>> {
116 Err(Status::unimplemented("Not yet implemented"))
117 }
118
119 async fn get_flight_info(
120 &self,
121 _: Request<FlightDescriptor>,
122 ) -> TonicResult<Response<FlightInfo>> {
123 Err(Status::unimplemented("Not yet implemented"))
124 }
125
126 async fn poll_flight_info(
127 &self,
128 _: Request<FlightDescriptor>,
129 ) -> TonicResult<Response<PollInfo>> {
130 Err(Status::unimplemented("Not yet implemented"))
131 }
132
133 async fn get_schema(
134 &self,
135 _: Request<FlightDescriptor>,
136 ) -> TonicResult<Response<SchemaResult>> {
137 Err(Status::unimplemented("Not yet implemented"))
138 }
139
140 type DoGetStream = TonicStream<FlightData>;
141
142 async fn do_get(&self, request: Request<Ticket>) -> TonicResult<Response<Self::DoGetStream>> {
143 self.0.do_get(request).await
144 }
145
146 type DoPutStream = TonicStream<PutResult>;
147
148 async fn do_put(
149 &self,
150 request: Request<Streaming<FlightData>>,
151 ) -> TonicResult<Response<Self::DoPutStream>> {
152 self.0.do_put(request).await
153 }
154
155 type DoExchangeStream = TonicStream<FlightData>;
156
157 async fn do_exchange(
158 &self,
159 _: Request<Streaming<FlightData>>,
160 ) -> TonicResult<Response<Self::DoExchangeStream>> {
161 Err(Status::unimplemented("Not yet implemented"))
162 }
163
164 type DoActionStream = TonicStream<arrow_flight::Result>;
165
166 async fn do_action(&self, _: Request<Action>) -> TonicResult<Response<Self::DoActionStream>> {
167 Err(Status::unimplemented("Not yet implemented"))
168 }
169
170 type ListActionsStream = TonicStream<ActionType>;
171
172 async fn list_actions(
173 &self,
174 _: Request<Empty>,
175 ) -> TonicResult<Response<Self::ListActionsStream>> {
176 Err(Status::unimplemented("Not yet implemented"))
177 }
178}
179
180#[async_trait]
181impl FlightCraft for GreptimeRequestHandler {
182 async fn do_get(
183 &self,
184 request: Request<Ticket>,
185 ) -> TonicResult<Response<TonicStream<FlightData>>> {
186 let hints = hint_headers::extract_hints(request.metadata());
187
188 let ticket = request.into_inner().ticket;
189 let request =
190 GreptimeRequest::decode(ticket.as_ref()).context(error::InvalidFlightTicketSnafu)?;
191
192 let span = info_span!(
194 "GreptimeRequestHandler::do_get",
195 protocol = "grpc",
196 request_type = get_request_type(&request)
197 );
198 async {
199 let output = self.handle_request(request, hints).await?;
200 let stream = to_flight_data_stream(output, TracingContext::from_current_span());
201 Ok(Response::new(stream))
202 }
203 .trace(span)
204 .await
205 }
206
207 async fn do_put(
208 &self,
209 request: Request<Streaming<FlightData>>,
210 ) -> TonicResult<Response<TonicStream<PutResult>>> {
211 let (headers, _, stream) = request.into_parts();
212
213 let header = |key: &str| -> TonicResult<Option<&str>> {
214 let Some(v) = headers.get(key) else {
215 return Ok(None);
216 };
217 let Ok(v) = std::str::from_utf8(v.as_bytes()) else {
218 return Err(InvalidParameterSnafu {
219 reason: "expect valid UTF-8 value",
220 }
221 .build()
222 .into());
223 };
224 Ok(Some(v))
225 };
226
227 let username_and_password = header(AUTHORIZATION_HEADER)?;
228 let db = header(GREPTIME_DB_HEADER_NAME)?;
229 if !self.validate_auth(username_and_password, db).await? {
230 return Err(Status::unauthenticated("auth failed"));
231 }
232
233 const MAX_PENDING_RESPONSES: usize = 32;
234 let (tx, rx) = mpsc::channel::<TonicResult<DoPutResponse>>(MAX_PENDING_RESPONSES);
235
236 let stream = PutRecordBatchRequestStream {
237 flight_data_stream: stream,
238 state: PutRecordBatchRequestStreamState::Init(db.map(ToString::to_string)),
239 };
240 self.put_record_batches(stream, tx).await;
241
242 let response = ReceiverStream::new(rx)
243 .and_then(|response| {
244 future::ready({
245 serde_json::to_vec(&response)
246 .context(ToJsonSnafu)
247 .map(|x| PutResult {
248 app_metadata: Bytes::from(x),
249 })
250 .map_err(Into::into)
251 })
252 })
253 .boxed();
254 Ok(Response::new(response))
255 }
256}
257
258pub(crate) struct PutRecordBatchRequest {
259 pub(crate) table_name: TableName,
260 pub(crate) request_id: i64,
261 pub(crate) data: FlightData,
262}
263
264impl PutRecordBatchRequest {
265 fn try_new(table_name: TableName, flight_data: FlightData) -> Result<Self> {
266 let request_id = if !flight_data.app_metadata.is_empty() {
267 let metadata: DoPutMetadata =
268 serde_json::from_slice(&flight_data.app_metadata).context(ParseJsonSnafu)?;
269 metadata.request_id()
270 } else {
271 0
272 };
273 Ok(Self {
274 table_name,
275 request_id,
276 data: flight_data,
277 })
278 }
279}
280
281pub(crate) struct PutRecordBatchRequestStream {
282 flight_data_stream: Streaming<FlightData>,
283 state: PutRecordBatchRequestStreamState,
284}
285
286enum PutRecordBatchRequestStreamState {
287 Init(Option<String>),
288 Started(TableName),
289}
290
291impl Stream for PutRecordBatchRequestStream {
292 type Item = TonicResult<PutRecordBatchRequest>;
293
294 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
295 fn extract_table_name(mut descriptor: FlightDescriptor) -> Result<String> {
296 ensure!(
297 descriptor.r#type == arrow_flight::flight_descriptor::DescriptorType::Path as i32,
298 InvalidParameterSnafu {
299 reason: "expect FlightDescriptor::type == 'Path' only",
300 }
301 );
302 ensure!(
303 descriptor.path.len() == 1,
304 InvalidParameterSnafu {
305 reason: "expect FlightDescriptor::path has only one table name",
306 }
307 );
308 Ok(descriptor.path.remove(0))
309 }
310
311 let poll = ready!(self.flight_data_stream.poll_next_unpin(cx));
312
313 let result = match &mut self.state {
314 PutRecordBatchRequestStreamState::Init(db) => match poll {
315 Some(Ok(mut flight_data)) => {
316 let flight_descriptor = flight_data.flight_descriptor.take();
317 let result = if let Some(descriptor) = flight_descriptor {
318 let table_name = extract_table_name(descriptor).map(|x| {
319 let (catalog, schema) = if let Some(db) = db {
320 parse_catalog_and_schema_from_db_string(db)
321 } else {
322 (
323 DEFAULT_CATALOG_NAME.to_string(),
324 DEFAULT_SCHEMA_NAME.to_string(),
325 )
326 };
327 TableName::new(catalog, schema, x)
328 });
329 let table_name = match table_name {
330 Ok(table_name) => table_name,
331 Err(e) => return Poll::Ready(Some(Err(e.into()))),
332 };
333
334 let request =
335 PutRecordBatchRequest::try_new(table_name.clone(), flight_data);
336 let request = match request {
337 Ok(request) => request,
338 Err(e) => return Poll::Ready(Some(Err(e.into()))),
339 };
340
341 self.state = PutRecordBatchRequestStreamState::Started(table_name);
342
343 Ok(request)
344 } else {
345 Err(Status::failed_precondition(
346 "table to put is not found in flight descriptor",
347 ))
348 };
349 Some(result)
350 }
351 Some(Err(e)) => Some(Err(e)),
352 None => None,
353 },
354 PutRecordBatchRequestStreamState::Started(table_name) => poll.map(|x| {
355 x.and_then(|flight_data| {
356 PutRecordBatchRequest::try_new(table_name.clone(), flight_data)
357 .map_err(Into::into)
358 })
359 }),
360 };
361 Poll::Ready(result)
362 }
363}
364
365fn to_flight_data_stream(
366 output: Output,
367 tracing_context: TracingContext,
368) -> TonicStream<FlightData> {
369 match output.data {
370 OutputData::Stream(stream) => {
371 let stream = FlightRecordBatchStream::new(stream, tracing_context);
372 Box::pin(stream) as _
373 }
374 OutputData::RecordBatches(x) => {
375 let stream = FlightRecordBatchStream::new(x.as_stream(), tracing_context);
376 Box::pin(stream) as _
377 }
378 OutputData::AffectedRows(rows) => {
379 let stream = tokio_stream::once(Ok(
380 FlightEncoder::default().encode(FlightMessage::AffectedRows(rows))
381 ));
382 Box::pin(stream) as _
383 }
384 }
385}