// servers/grpc/flight/stream.rs
use std::pin::Pin;
16use std::task::{Context, Poll};
17
18use arrow_flight::FlightData;
19use common_grpc::flight::{FlightEncoder, FlightMessage};
20use common_recordbatch::SendableRecordBatchStream;
21use common_telemetry::tracing::{info_span, Instrument};
22use common_telemetry::tracing_context::{FutureExt, TracingContext};
23use common_telemetry::warn;
24use futures::channel::mpsc;
25use futures::channel::mpsc::Sender;
26use futures::{SinkExt, Stream, StreamExt};
27use pin_project::{pin_project, pinned_drop};
28use snafu::ResultExt;
29use tokio::task::JoinHandle;
30
31use crate::error;
32use crate::grpc::flight::TonicResult;
33
/// A stream of Arrow Flight [`FlightData`] built from a record batch stream.
///
/// A background task created in `new` pulls the record batches and hands them over
/// through a channel. This type encodes each received message as it is polled.
#[pin_project(PinnedDrop)]
pub struct FlightRecordBatchStream {
    /// Receiving end of the channel fed by the background producer task.
    #[pin]
    rx: mpsc::Receiver<Result<FlightMessage, tonic::Status>>,
    /// Handle of the producer task; the task is aborted when this stream is dropped.
    join_handle: JoinHandle<()>,
    /// Set once the channel is exhausted or an error has been yielded, so that
    /// later polls return `Ready(None)` without touching the channel.
    done: bool,
    /// Turns each received `FlightMessage` into `FlightData`.
    encoder: FlightEncoder,
}
42
43impl FlightRecordBatchStream {
44 pub fn new(recordbatches: SendableRecordBatchStream, tracing_context: TracingContext) -> Self {
45 let (tx, rx) = mpsc::channel::<TonicResult<FlightMessage>>(1);
46 let join_handle = common_runtime::spawn_global(async move {
47 Self::flight_data_stream(recordbatches, tx)
48 .trace(tracing_context.attach(info_span!("flight_data_stream")))
49 .await
50 });
51 Self {
52 rx,
53 join_handle,
54 done: false,
55 encoder: FlightEncoder::default(),
56 }
57 }
58
59 async fn flight_data_stream(
60 mut recordbatches: SendableRecordBatchStream,
61 mut tx: Sender<TonicResult<FlightMessage>>,
62 ) {
63 let schema = recordbatches.schema();
64 if let Err(e) = tx.send(Ok(FlightMessage::Schema(schema))).await {
65 warn!(e; "stop sending Flight data");
66 return;
67 }
68
69 while let Some(batch_or_err) = recordbatches.next().in_current_span().await {
70 match batch_or_err {
71 Ok(recordbatch) => {
72 if let Err(e) = tx.send(Ok(FlightMessage::Recordbatch(recordbatch))).await {
73 warn!(e; "stop sending Flight data");
74 return;
75 }
76 }
77 Err(e) => {
78 let e = Err(e).context(error::CollectRecordbatchSnafu);
79 if let Err(e) = tx.send(e.map_err(|x| x.into())).await {
80 warn!(e; "stop sending Flight data");
81 }
82 return;
83 }
84 }
85 }
86 if let Some(metrics_str) = recordbatches
88 .metrics()
89 .and_then(|m| serde_json::to_string(&m).ok())
90 {
91 let _ = tx.send(Ok(FlightMessage::Metrics(metrics_str))).await;
92 }
93 }
94}
95
96#[pinned_drop]
97impl PinnedDrop for FlightRecordBatchStream {
98 fn drop(self: Pin<&mut Self>) {
99 self.join_handle.abort();
100 }
101}
102
103impl Stream for FlightRecordBatchStream {
104 type Item = TonicResult<FlightData>;
105
106 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
107 let this = self.project();
108 if *this.done {
109 Poll::Ready(None)
110 } else {
111 match this.rx.poll_next(cx) {
112 Poll::Ready(None) => {
113 *this.done = true;
114 Poll::Ready(None)
115 }
116 Poll::Ready(Some(result)) => match result {
117 Ok(flight_message) => {
118 let flight_data = this.encoder.encode(flight_message);
119 Poll::Ready(Some(Ok(flight_data)))
120 }
121 Err(e) => {
122 *this.done = true;
123 Poll::Ready(Some(Err(e)))
124 }
125 },
126 Poll::Pending => Poll::Pending,
127 }
128 }
129 }
130}
131
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use common_grpc::flight::{FlightDecoder, FlightMessage};
    use common_recordbatch::{RecordBatch, RecordBatches};
    use datatypes::prelude::*;
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::Int32Vector;
    use futures::StreamExt;

    use super::*;

    /// The encoded stream yields exactly two messages: the schema, then the single
    /// record batch. It then terminates and marks itself as done.
    #[tokio::test]
    async fn test_flight_record_batch_stream() {
        let column = ColumnSchema::new("a", ConcreteDataType::int32_datatype(), false);
        let schema = Arc::new(Schema::new(vec![column]));

        let values: VectorRef = Arc::new(Int32Vector::from_slice([1, 2]));
        let expected_batch = RecordBatch::new(schema.clone(), vec![values]).unwrap();

        let source = RecordBatches::try_new(schema.clone(), vec![expected_batch.clone()])
            .unwrap()
            .as_stream();
        let mut stream = FlightRecordBatchStream::new(source, TracingContext::default());

        let first = stream.next().await.unwrap().unwrap();
        let second = stream.next().await.unwrap().unwrap();
        assert!(stream.next().await.is_none());
        assert!(stream.done);

        let mut decoder = FlightDecoder::default();
        let decoded: Vec<FlightMessage> = [first, second]
            .iter()
            .map(|data| decoder.try_decode(data).unwrap())
            .collect();
        assert_eq!(decoded.len(), 2);

        let mut messages = decoded.into_iter();

        match messages.next() {
            Some(FlightMessage::Schema(actual_schema)) => assert_eq!(actual_schema, schema),
            _ => unreachable!(),
        }

        match messages.next() {
            Some(FlightMessage::Recordbatch(actual_batch)) => {
                assert_eq!(actual_batch, expected_batch)
            }
            _ => unreachable!(),
        }
    }
}