servers/grpc/flight/
stream.rs1use std::pin::Pin;
16use std::task::{Context, Poll};
17
18use arrow_flight::FlightData;
19use common_grpc::flight::{FlightEncoder, FlightMessage};
20use common_recordbatch::SendableRecordBatchStream;
21use common_telemetry::tracing::{info_span, Instrument};
22use common_telemetry::tracing_context::{FutureExt, TracingContext};
23use common_telemetry::warn;
24use futures::channel::mpsc;
25use futures::channel::mpsc::Sender;
26use futures::{SinkExt, Stream, StreamExt};
27use pin_project::{pin_project, pinned_drop};
28use snafu::ResultExt;
29use tokio::task::JoinHandle;
30
31use crate::error;
32use crate::grpc::flight::TonicResult;
33
/// A [`Stream`] of Arrow Flight data built from a record batch stream.
///
/// Flight messages are produced by a background task (tracked by `join_handle`)
/// and delivered over `rx`; each received message is encoded into `FlightData`
/// when this stream is polled.
#[pin_project(PinnedDrop)]
pub struct FlightRecordBatchStream {
    /// Receives Flight messages, or a `tonic::Status` error, from the producer task.
    #[pin]
    rx: mpsc::Receiver<Result<FlightMessage, tonic::Status>>,
    /// Handle of the spawned producer task.
    join_handle: JoinHandle<()>,
    /// Set once the channel has closed or an error has been yielded; from then on
    /// polling returns `None`.
    done: bool,
    /// Turns each received `FlightMessage` into `FlightData`.
    encoder: FlightEncoder,
}
42
43impl FlightRecordBatchStream {
44 pub fn new(recordbatches: SendableRecordBatchStream, tracing_context: TracingContext) -> Self {
45 let (tx, rx) = mpsc::channel::<TonicResult<FlightMessage>>(1);
46 let join_handle = common_runtime::spawn_global(async move {
47 Self::flight_data_stream(recordbatches, tx)
48 .trace(tracing_context.attach(info_span!("flight_data_stream")))
49 .await
50 });
51 Self {
52 rx,
53 join_handle,
54 done: false,
55 encoder: FlightEncoder::default(),
56 }
57 }
58
59 async fn flight_data_stream(
60 mut recordbatches: SendableRecordBatchStream,
61 mut tx: Sender<TonicResult<FlightMessage>>,
62 ) {
63 let schema = recordbatches.schema().arrow_schema().clone();
64 if let Err(e) = tx.send(Ok(FlightMessage::Schema(schema))).await {
65 warn!(e; "stop sending Flight data");
66 return;
67 }
68
69 while let Some(batch_or_err) = recordbatches.next().in_current_span().await {
70 match batch_or_err {
71 Ok(recordbatch) => {
72 if let Err(e) = tx
73 .send(Ok(FlightMessage::RecordBatch(
74 recordbatch.into_df_record_batch(),
75 )))
76 .await
77 {
78 warn!(e; "stop sending Flight data");
79 return;
80 }
81 }
82 Err(e) => {
83 let e = Err(e).context(error::CollectRecordbatchSnafu);
84 if let Err(e) = tx.send(e.map_err(|x| x.into())).await {
85 warn!(e; "stop sending Flight data");
86 }
87 return;
88 }
89 }
90 }
91 if let Some(metrics_str) = recordbatches
93 .metrics()
94 .and_then(|m| serde_json::to_string(&m).ok())
95 {
96 let _ = tx.send(Ok(FlightMessage::Metrics(metrics_str))).await;
97 }
98 }
99}
100
101#[pinned_drop]
102impl PinnedDrop for FlightRecordBatchStream {
103 fn drop(self: Pin<&mut Self>) {
104 self.join_handle.abort();
105 }
106}
107
108impl Stream for FlightRecordBatchStream {
109 type Item = TonicResult<FlightData>;
110
111 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
112 let this = self.project();
113 if *this.done {
114 Poll::Ready(None)
115 } else {
116 match this.rx.poll_next(cx) {
117 Poll::Ready(None) => {
118 *this.done = true;
119 Poll::Ready(None)
120 }
121 Poll::Ready(Some(result)) => match result {
122 Ok(flight_message) => {
123 let flight_data = this.encoder.encode(flight_message);
124 Poll::Ready(Some(Ok(flight_data)))
125 }
126 Err(e) => {
127 *this.done = true;
128 Poll::Ready(Some(Err(e)))
129 }
130 },
131 Poll::Pending => Poll::Pending,
132 }
133 }
134 }
135}
136
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use common_grpc::flight::{FlightDecoder, FlightMessage};
    use common_recordbatch::{RecordBatch, RecordBatches};
    use datatypes::prelude::*;
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::Int32Vector;
    use futures::StreamExt;

    use super::*;

    /// A single-batch input yields exactly a schema message followed by the
    /// batch, after which the stream reports completion.
    #[tokio::test]
    async fn test_flight_record_batch_stream() {
        let column = ColumnSchema::new("a", ConcreteDataType::int32_datatype(), false);
        let schema = Arc::new(Schema::new(vec![column]));

        let values: VectorRef = Arc::new(Int32Vector::from_slice([1, 2]));
        let expected_batch = RecordBatch::new(schema.clone(), vec![values]).unwrap();

        let source = RecordBatches::try_new(schema.clone(), vec![expected_batch.clone()])
            .unwrap()
            .as_stream();
        let mut stream = FlightRecordBatchStream::new(source, TracingContext::default());

        // Exactly two items are produced, then the stream is exhausted.
        let first = stream.next().await.unwrap().unwrap();
        let second = stream.next().await.unwrap().unwrap();
        assert!(stream.next().await.is_none());
        assert!(stream.done);

        let mut decoder = FlightDecoder::default();
        let decoded: Vec<FlightMessage> = [first, second]
            .iter()
            .map(|data| decoder.try_decode(data).unwrap())
            .collect();
        assert_eq!(decoded.len(), 2);

        let mut messages = decoded.into_iter();
        match messages.next() {
            Some(FlightMessage::Schema(actual_schema)) => {
                assert_eq!(&actual_schema, schema.arrow_schema());
            }
            _ => unreachable!(),
        }
        match messages.next() {
            Some(FlightMessage::RecordBatch(actual_batch)) => {
                assert_eq!(&actual_batch, expected_batch.df_record_batch());
            }
            _ => unreachable!(),
        }
    }
}