servers/grpc/flight/
stream.rs1use std::pin::Pin;
16use std::task::{Context, Poll};
17
18use arrow_flight::FlightData;
19use common_error::ext::ErrorExt;
20use common_grpc::flight::{FlightEncoder, FlightMessage};
21use common_recordbatch::SendableRecordBatchStream;
22use common_telemetry::tracing::{info_span, Instrument};
23use common_telemetry::tracing_context::{FutureExt, TracingContext};
24use common_telemetry::{error, warn};
25use futures::channel::mpsc;
26use futures::channel::mpsc::Sender;
27use futures::{SinkExt, Stream, StreamExt};
28use pin_project::{pin_project, pinned_drop};
29use snafu::ResultExt;
30use tokio::task::JoinHandle;
31
32use crate::error;
33use crate::grpc::flight::TonicResult;
34use crate::grpc::FlightCompression;
35
36#[pin_project(PinnedDrop)]
37pub struct FlightRecordBatchStream {
38 #[pin]
39 rx: mpsc::Receiver<Result<FlightMessage, tonic::Status>>,
40 join_handle: JoinHandle<()>,
41 done: bool,
42 encoder: FlightEncoder,
43}
44
45impl FlightRecordBatchStream {
46 pub fn new(
47 recordbatches: SendableRecordBatchStream,
48 tracing_context: TracingContext,
49 compression: FlightCompression,
50 ) -> Self {
51 let (tx, rx) = mpsc::channel::<TonicResult<FlightMessage>>(1);
52 let join_handle = common_runtime::spawn_global(async move {
53 Self::flight_data_stream(recordbatches, tx)
54 .trace(tracing_context.attach(info_span!("flight_data_stream")))
55 .await
56 });
57 let encoder = if compression.arrow_compression() {
58 FlightEncoder::default()
59 } else {
60 FlightEncoder::with_compression_disabled()
61 };
62 Self {
63 rx,
64 join_handle,
65 done: false,
66 encoder,
67 }
68 }
69
70 async fn flight_data_stream(
71 mut recordbatches: SendableRecordBatchStream,
72 mut tx: Sender<TonicResult<FlightMessage>>,
73 ) {
74 let schema = recordbatches.schema().arrow_schema().clone();
75 if let Err(e) = tx.send(Ok(FlightMessage::Schema(schema))).await {
76 warn!(e; "stop sending Flight data");
77 return;
78 }
79
80 while let Some(batch_or_err) = recordbatches.next().in_current_span().await {
81 match batch_or_err {
82 Ok(recordbatch) => {
83 if let Err(e) = tx
84 .send(Ok(FlightMessage::RecordBatch(
85 recordbatch.into_df_record_batch(),
86 )))
87 .await
88 {
89 warn!(e; "stop sending Flight data");
90 return;
91 }
92 }
93 Err(e) => {
94 if e.status_code().should_log_error() {
95 error!("{e:?}");
96 }
97
98 let e = Err(e).context(error::CollectRecordbatchSnafu);
99 if let Err(e) = tx.send(e.map_err(|x| x.into())).await {
100 warn!(e; "stop sending Flight data");
101 }
102 return;
103 }
104 }
105 }
106 if let Some(metrics_str) = recordbatches
108 .metrics()
109 .and_then(|m| serde_json::to_string(&m).ok())
110 {
111 let _ = tx.send(Ok(FlightMessage::Metrics(metrics_str))).await;
112 }
113 }
114}
115
116#[pinned_drop]
117impl PinnedDrop for FlightRecordBatchStream {
118 fn drop(self: Pin<&mut Self>) {
119 self.join_handle.abort();
120 }
121}
122
123impl Stream for FlightRecordBatchStream {
124 type Item = TonicResult<FlightData>;
125
126 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
127 let this = self.project();
128 if *this.done {
129 Poll::Ready(None)
130 } else {
131 match this.rx.poll_next(cx) {
132 Poll::Ready(None) => {
133 *this.done = true;
134 Poll::Ready(None)
135 }
136 Poll::Ready(Some(result)) => match result {
137 Ok(flight_message) => {
138 let flight_data = this.encoder.encode(flight_message);
139 Poll::Ready(Some(Ok(flight_data)))
140 }
141 Err(e) => {
142 *this.done = true;
143 Poll::Ready(Some(Err(e)))
144 }
145 },
146 Poll::Pending => Poll::Pending,
147 }
148 }
149 }
150}
151
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use common_grpc::flight::{FlightDecoder, FlightMessage};
    use common_recordbatch::{RecordBatch, RecordBatches};
    use datatypes::prelude::*;
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::Int32Vector;
    use futures::StreamExt;

    use super::*;

    /// A single-batch input must come out as exactly two frames, a schema followed by the
    /// record batch, after which the stream ends and is marked done.
    #[tokio::test]
    async fn test_flight_record_batch_stream() {
        let column = ColumnSchema::new("a", ConcreteDataType::int32_datatype(), false);
        let schema = Arc::new(Schema::new(vec![column]));

        let values: VectorRef = Arc::new(Int32Vector::from_slice([1, 2]));
        let recordbatch = RecordBatch::new(schema.clone(), vec![values]).unwrap();

        let input = RecordBatches::try_new(schema.clone(), vec![recordbatch.clone()])
            .unwrap()
            .as_stream();
        let mut stream = FlightRecordBatchStream::new(
            input,
            TracingContext::default(),
            FlightCompression::default(),
        );

        let schema_frame = stream.next().await.unwrap().unwrap();
        let batch_frame = stream.next().await.unwrap().unwrap();
        assert!(stream.next().await.is_none());
        assert!(stream.done);

        let mut decoder = FlightDecoder::default();
        let decoded: Vec<FlightMessage> = vec![schema_frame, batch_frame]
            .into_iter()
            .map(|frame| decoder.try_decode(&frame).unwrap())
            .collect();
        assert_eq!(decoded.len(), 2);

        let mut decoded = decoded.into_iter();

        match decoded.next() {
            Some(FlightMessage::Schema(actual_schema)) => {
                assert_eq!(&actual_schema, schema.arrow_schema());
            }
            _ => unreachable!(),
        }

        match decoded.next() {
            Some(FlightMessage::RecordBatch(actual_recordbatch)) => {
                assert_eq!(&actual_recordbatch, recordbatch.df_record_batch());
            }
            _ => unreachable!(),
        }
    }
}