1mod stream;
16
17use std::pin::Pin;
18use std::sync::Arc;
19use std::task::{Context, Poll};
20
21use api::v1::GreptimeRequest;
22use arrow_flight::flight_service_server::FlightService;
23use arrow_flight::{
24 Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
25 HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket,
26};
27use async_trait::async_trait;
28use bytes::Bytes;
29use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
30use common_catalog::parse_catalog_and_schema_from_db_string;
31use common_grpc::flight::do_put::{DoPutMetadata, DoPutResponse};
32use common_grpc::flight::{FlightEncoder, FlightMessage};
33use common_query::{Output, OutputData};
34use common_telemetry::tracing::info_span;
35use common_telemetry::tracing_context::{FutureExt, TracingContext};
36use futures::{future, ready, Stream};
37use futures_util::{StreamExt, TryStreamExt};
38use prost::Message;
39use session::context::{QueryContext, QueryContextRef};
40use snafu::{ensure, ResultExt};
41use table::table_name::TableName;
42use tokio::sync::mpsc;
43use tokio_stream::wrappers::ReceiverStream;
44use tonic::{Request, Response, Status, Streaming};
45
46use crate::error::{InvalidParameterSnafu, ParseJsonSnafu, Result, ToJsonSnafu};
47pub use crate::grpc::flight::stream::FlightRecordBatchStream;
48use crate::grpc::greptime_handler::{get_request_type, GreptimeRequestHandler};
49use crate::grpc::{FlightCompression, TonicResult};
50use crate::http::header::constants::GREPTIME_DB_HEADER_NAME;
51use crate::http::AUTHORIZATION_HEADER;
52use crate::{error, hint_headers};
53
/// Pinned, boxed, `Send + 'static` stream of tonic results.
///
/// Every streaming Flight response type in this module (do_get, do_put, and the unimplemented
/// RPCs) is expressed as a `TonicStream`.
pub type TonicStream<T> = Pin<Box<dyn Stream<Item = TonicResult<T>> + Send + 'static>>;
55
56#[async_trait]
58pub trait FlightCraft: Send + Sync + 'static {
59 async fn do_get(
60 &self,
61 request: Request<Ticket>,
62 ) -> TonicResult<Response<TonicStream<FlightData>>>;
63
64 async fn do_put(
65 &self,
66 request: Request<Streaming<FlightData>>,
67 ) -> TonicResult<Response<TonicStream<PutResult>>> {
68 let _ = request;
69 Err(Status::unimplemented("Not yet implemented"))
70 }
71}
72
/// Shared, dynamically dispatched handle to a [`FlightCraft`] implementation.
pub type FlightCraftRef = Arc<dyn FlightCraft>;
74
75pub struct FlightCraftWrapper<T: FlightCraft>(pub T);
76
77impl<T: FlightCraft> From<T> for FlightCraftWrapper<T> {
78 fn from(t: T) -> Self {
79 Self(t)
80 }
81}
82
83#[async_trait]
84impl FlightCraft for FlightCraftRef {
85 async fn do_get(
86 &self,
87 request: Request<Ticket>,
88 ) -> TonicResult<Response<TonicStream<FlightData>>> {
89 (**self).do_get(request).await
90 }
91
92 async fn do_put(
93 &self,
94 request: Request<Streaming<FlightData>>,
95 ) -> TonicResult<Response<TonicStream<PutResult>>> {
96 self.as_ref().do_put(request).await
97 }
98}
99
100#[async_trait]
101impl<T: FlightCraft> FlightService for FlightCraftWrapper<T> {
102 type HandshakeStream = TonicStream<HandshakeResponse>;
103
104 async fn handshake(
105 &self,
106 _: Request<Streaming<HandshakeRequest>>,
107 ) -> TonicResult<Response<Self::HandshakeStream>> {
108 Err(Status::unimplemented("Not yet implemented"))
109 }
110
111 type ListFlightsStream = TonicStream<FlightInfo>;
112
113 async fn list_flights(
114 &self,
115 _: Request<Criteria>,
116 ) -> TonicResult<Response<Self::ListFlightsStream>> {
117 Err(Status::unimplemented("Not yet implemented"))
118 }
119
120 async fn get_flight_info(
121 &self,
122 _: Request<FlightDescriptor>,
123 ) -> TonicResult<Response<FlightInfo>> {
124 Err(Status::unimplemented("Not yet implemented"))
125 }
126
127 async fn poll_flight_info(
128 &self,
129 _: Request<FlightDescriptor>,
130 ) -> TonicResult<Response<PollInfo>> {
131 Err(Status::unimplemented("Not yet implemented"))
132 }
133
134 async fn get_schema(
135 &self,
136 _: Request<FlightDescriptor>,
137 ) -> TonicResult<Response<SchemaResult>> {
138 Err(Status::unimplemented("Not yet implemented"))
139 }
140
141 type DoGetStream = TonicStream<FlightData>;
142
143 async fn do_get(&self, request: Request<Ticket>) -> TonicResult<Response<Self::DoGetStream>> {
144 self.0.do_get(request).await
145 }
146
147 type DoPutStream = TonicStream<PutResult>;
148
149 async fn do_put(
150 &self,
151 request: Request<Streaming<FlightData>>,
152 ) -> TonicResult<Response<Self::DoPutStream>> {
153 self.0.do_put(request).await
154 }
155
156 type DoExchangeStream = TonicStream<FlightData>;
157
158 async fn do_exchange(
159 &self,
160 _: Request<Streaming<FlightData>>,
161 ) -> TonicResult<Response<Self::DoExchangeStream>> {
162 Err(Status::unimplemented("Not yet implemented"))
163 }
164
165 type DoActionStream = TonicStream<arrow_flight::Result>;
166
167 async fn do_action(&self, _: Request<Action>) -> TonicResult<Response<Self::DoActionStream>> {
168 Err(Status::unimplemented("Not yet implemented"))
169 }
170
171 type ListActionsStream = TonicStream<ActionType>;
172
173 async fn list_actions(
174 &self,
175 _: Request<Empty>,
176 ) -> TonicResult<Response<Self::ListActionsStream>> {
177 Err(Status::unimplemented("Not yet implemented"))
178 }
179}
180
181#[async_trait]
182impl FlightCraft for GreptimeRequestHandler {
183 async fn do_get(
184 &self,
185 request: Request<Ticket>,
186 ) -> TonicResult<Response<TonicStream<FlightData>>> {
187 let hints = hint_headers::extract_hints(request.metadata());
188
189 let ticket = request.into_inner().ticket;
190 let request =
191 GreptimeRequest::decode(ticket.as_ref()).context(error::InvalidFlightTicketSnafu)?;
192 let query_ctx = QueryContext::arc();
193
194 let span = info_span!(
196 "GreptimeRequestHandler::do_get",
197 protocol = "grpc",
198 request_type = get_request_type(&request)
199 );
200 let flight_compression = self.flight_compression;
201 async {
202 let output = self.handle_request(request, hints).await?;
203 let stream = to_flight_data_stream(
204 output,
205 TracingContext::from_current_span(),
206 flight_compression,
207 query_ctx,
208 );
209 Ok(Response::new(stream))
210 }
211 .trace(span)
212 .await
213 }
214
215 async fn do_put(
216 &self,
217 request: Request<Streaming<FlightData>>,
218 ) -> TonicResult<Response<TonicStream<PutResult>>> {
219 let (headers, _, stream) = request.into_parts();
220
221 let header = |key: &str| -> TonicResult<Option<&str>> {
222 let Some(v) = headers.get(key) else {
223 return Ok(None);
224 };
225 let Ok(v) = std::str::from_utf8(v.as_bytes()) else {
226 return Err(InvalidParameterSnafu {
227 reason: "expect valid UTF-8 value",
228 }
229 .build()
230 .into());
231 };
232 Ok(Some(v))
233 };
234
235 let username_and_password = header(AUTHORIZATION_HEADER)?;
236 let db = header(GREPTIME_DB_HEADER_NAME)?;
237 if !self.validate_auth(username_and_password, db).await? {
238 return Err(Status::unauthenticated("auth failed"));
239 }
240
241 const MAX_PENDING_RESPONSES: usize = 32;
242 let (tx, rx) = mpsc::channel::<TonicResult<DoPutResponse>>(MAX_PENDING_RESPONSES);
243
244 let stream = PutRecordBatchRequestStream {
245 flight_data_stream: stream,
246 state: PutRecordBatchRequestStreamState::Init(db.map(ToString::to_string)),
247 };
248 self.put_record_batches(stream, tx).await;
249
250 let response = ReceiverStream::new(rx)
251 .and_then(|response| {
252 future::ready({
253 serde_json::to_vec(&response)
254 .context(ToJsonSnafu)
255 .map(|x| PutResult {
256 app_metadata: Bytes::from(x),
257 })
258 .map_err(Into::into)
259 })
260 })
261 .boxed();
262 Ok(Response::new(response))
263 }
264}
265
/// One Flight data message from a `DoPut` call, paired with the table it targets.
pub(crate) struct PutRecordBatchRequest {
    /// Table to write into, qualified with catalog and schema.
    pub(crate) table_name: TableName,
    /// Client-supplied id read from the message's `DoPutMetadata`; `0` when `app_metadata` is empty.
    pub(crate) request_id: i64,
    /// The Flight data message itself (the first message of a call has its descriptor taken out).
    pub(crate) data: FlightData,
}
271
272impl PutRecordBatchRequest {
273 fn try_new(table_name: TableName, flight_data: FlightData) -> Result<Self> {
274 let request_id = if !flight_data.app_metadata.is_empty() {
275 let metadata: DoPutMetadata =
276 serde_json::from_slice(&flight_data.app_metadata).context(ParseJsonSnafu)?;
277 metadata.request_id()
278 } else {
279 0
280 };
281 Ok(Self {
282 table_name,
283 request_id,
284 data: flight_data,
285 })
286 }
287}
288
/// Adapts the raw `DoPut` Flight data stream into a stream of [`PutRecordBatchRequest`]s.
pub(crate) struct PutRecordBatchRequestStream {
    /// Incoming gRPC stream of Flight data messages.
    flight_data_stream: Streaming<FlightData>,
    /// Whether the target table has been resolved yet.
    state: PutRecordBatchRequestStreamState,
}

/// Progress of a [`PutRecordBatchRequestStream`].
enum PutRecordBatchRequestStreamState {
    /// No message has been accepted yet. Holds the database header value (if any), used to
    /// qualify the table name found in the first message's descriptor.
    Init(Option<String>),
    /// The table was resolved from the first accepted message and applies to all later ones.
    Started(TableName),
}
298
299impl Stream for PutRecordBatchRequestStream {
300 type Item = TonicResult<PutRecordBatchRequest>;
301
302 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
303 fn extract_table_name(mut descriptor: FlightDescriptor) -> Result<String> {
304 ensure!(
305 descriptor.r#type == arrow_flight::flight_descriptor::DescriptorType::Path as i32,
306 InvalidParameterSnafu {
307 reason: "expect FlightDescriptor::type == 'Path' only",
308 }
309 );
310 ensure!(
311 descriptor.path.len() == 1,
312 InvalidParameterSnafu {
313 reason: "expect FlightDescriptor::path has only one table name",
314 }
315 );
316 Ok(descriptor.path.remove(0))
317 }
318
319 let poll = ready!(self.flight_data_stream.poll_next_unpin(cx));
320
321 let result = match &mut self.state {
322 PutRecordBatchRequestStreamState::Init(db) => match poll {
323 Some(Ok(mut flight_data)) => {
324 let flight_descriptor = flight_data.flight_descriptor.take();
325 let result = if let Some(descriptor) = flight_descriptor {
326 let table_name = extract_table_name(descriptor).map(|x| {
327 let (catalog, schema) = if let Some(db) = db {
328 parse_catalog_and_schema_from_db_string(db)
329 } else {
330 (
331 DEFAULT_CATALOG_NAME.to_string(),
332 DEFAULT_SCHEMA_NAME.to_string(),
333 )
334 };
335 TableName::new(catalog, schema, x)
336 });
337 let table_name = match table_name {
338 Ok(table_name) => table_name,
339 Err(e) => return Poll::Ready(Some(Err(e.into()))),
340 };
341
342 let request =
343 PutRecordBatchRequest::try_new(table_name.clone(), flight_data);
344 let request = match request {
345 Ok(request) => request,
346 Err(e) => return Poll::Ready(Some(Err(e.into()))),
347 };
348
349 self.state = PutRecordBatchRequestStreamState::Started(table_name);
350
351 Ok(request)
352 } else {
353 Err(Status::failed_precondition(
354 "table to put is not found in flight descriptor",
355 ))
356 };
357 Some(result)
358 }
359 Some(Err(e)) => Some(Err(e)),
360 None => None,
361 },
362 PutRecordBatchRequestStreamState::Started(table_name) => poll.map(|x| {
363 x.and_then(|flight_data| {
364 PutRecordBatchRequest::try_new(table_name.clone(), flight_data)
365 .map_err(Into::into)
366 })
367 }),
368 };
369 Poll::Ready(result)
370 }
371}
372
373fn to_flight_data_stream(
374 output: Output,
375 tracing_context: TracingContext,
376 flight_compression: FlightCompression,
377 query_ctx: QueryContextRef,
378) -> TonicStream<FlightData> {
379 match output.data {
380 OutputData::Stream(stream) => {
381 let stream = FlightRecordBatchStream::new(
382 stream,
383 tracing_context,
384 flight_compression,
385 query_ctx,
386 );
387 Box::pin(stream) as _
388 }
389 OutputData::RecordBatches(x) => {
390 let stream = FlightRecordBatchStream::new(
391 x.as_stream(),
392 tracing_context,
393 flight_compression,
394 query_ctx,
395 );
396 Box::pin(stream) as _
397 }
398 OutputData::AffectedRows(rows) => {
399 let stream = tokio_stream::iter(
400 FlightEncoder::default()
401 .encode(FlightMessage::AffectedRows(rows))
402 .into_iter()
403 .map(Ok),
404 );
405 Box::pin(stream) as _
406 }
407 }
408}