1use std::collections::VecDeque;
16use std::pin::Pin;
17use std::task::{Context, Poll};
18use std::time::{Duration, Instant};
19
20use arrow_flight::FlightData;
21use common_error::ext::ErrorExt;
22use common_grpc::flight::{FlightEncoder, FlightMessage};
23use common_recordbatch::SendableRecordBatchStream;
24use common_telemetry::tracing::{Instrument, info_span};
25use common_telemetry::tracing_context::{FutureExt, TracingContext};
26use common_telemetry::{error, info, warn};
27use futures::channel::mpsc;
28use futures::channel::mpsc::Sender;
29use futures::{SinkExt, Stream, StreamExt};
30use pin_project::{pin_project, pinned_drop};
31use session::context::QueryContextRef;
32use snafu::ResultExt;
33use tokio::task::JoinHandle;
34
35use crate::error;
36use crate::grpc::FlightCompression;
37use crate::grpc::flight::TonicResult;
38
/// Timing and volume measurements gathered while a record batch stream is forwarded as
/// Flight messages.
///
/// The `Drop` implementation of this type writes the values out as a single log line when
/// `should_log` is set.
struct StreamMetrics {
    /// Whether the summary line is logged when the value is dropped.
    should_log: bool,

    /// Time spent sending the schema message.
    send_schema_duration: Duration,
    /// Accumulated time spent sending record batch messages.
    send_record_batch_duration: Duration,
    /// Accumulated time spent sending metrics messages.
    send_metrics_duration: Duration,
    /// Accumulated time spent waiting for the next item of the source stream.
    fetch_content_duration: Duration,

    /// Number of record batches pulled from the source stream.
    record_batch_count: usize,
    /// Number of metrics messages handed to the channel.
    metrics_count: usize,
    /// Sum of the row counts of all record batches.
    total_rows: usize,
    /// Sum of the array memory sizes reported by the record batches.
    total_bytes: usize,
}

impl StreamMetrics {
    /// Creates a collector whose timers and counters all start at zero.
    ///
    /// `should_log` is stored as-is and decides whether the summary is logged on drop.
    fn new(should_log: bool) -> Self {
        let idle = Duration::ZERO;
        Self {
            should_log,
            send_schema_duration: idle,
            send_record_batch_duration: idle,
            send_metrics_duration: idle,
            fetch_content_duration: idle,
            record_batch_count: 0,
            metrics_count: 0,
            total_rows: 0,
            total_bytes: 0,
        }
    }
}
67
68impl Drop for StreamMetrics {
69 fn drop(&mut self) {
70 if self.should_log {
71 info!(
72 "flight_data_stream finished: \
73 send_schema_duration={:?}, \
74 send_record_batch_duration={:?}, \
75 send_metrics_duration={:?}, \
76 fetch_content_duration={:?}, \
77 record_batch_count={}, \
78 metrics_count={}, \
79 total_rows={}, \
80 total_bytes={}",
81 self.send_schema_duration,
82 self.send_record_batch_duration,
83 self.send_metrics_duration,
84 self.fetch_content_duration,
85 self.record_batch_count,
86 self.metrics_count,
87 self.total_rows,
88 self.total_bytes
89 );
90 }
91 }
92}
93
/// A stream of encoded `FlightData` backed by a spawned producer task.
///
/// `FlightRecordBatchStream::new` spawns a task that pushes `FlightMessage`s into a channel;
/// polling this stream receives those messages from `rx` and encodes them with `encoder`.
/// Dropping the stream aborts the producer task.
#[pin_project(PinnedDrop)]
pub struct FlightRecordBatchStream {
    /// Receiving half of the channel whose sender is handed to the producer task.
    #[pin]
    rx: mpsc::Receiver<Result<FlightMessage, tonic::Status>>,
    /// Handle of the producer task; aborted when the stream is dropped.
    join_handle: JoinHandle<()>,
    /// Set once the channel is exhausted or an error has been yielded; afterwards the stream
    /// only returns `None`.
    done: bool,
    /// Turns each received `FlightMessage` into `FlightData` items.
    encoder: FlightEncoder,
    /// `FlightData` items left over from the last encoded message; drained before `rx` is
    /// polled again.
    buffer: VecDeque<FlightData>,
}
103
104impl FlightRecordBatchStream {
105 pub fn new(
106 recordbatches: SendableRecordBatchStream,
107 tracing_context: TracingContext,
108 compression: FlightCompression,
109 query_ctx: QueryContextRef,
110 ) -> Self {
111 let should_send_partial_metrics = query_ctx.explain_verbose();
112 let (tx, rx) = mpsc::channel::<TonicResult<FlightMessage>>(1);
113 let join_handle = common_runtime::spawn_global(async move {
114 Self::flight_data_stream(recordbatches, tx, should_send_partial_metrics)
115 .trace(tracing_context.attach(info_span!("flight_data_stream")))
116 .await
117 });
118 let encoder = if compression.arrow_compression() {
119 FlightEncoder::default()
120 } else {
121 FlightEncoder::with_compression_disabled()
122 };
123 Self {
124 rx,
125 join_handle,
126 done: false,
127 encoder,
128 buffer: VecDeque::new(),
129 }
130 }
131
132 async fn flight_data_stream(
133 mut recordbatches: SendableRecordBatchStream,
134 mut tx: Sender<TonicResult<FlightMessage>>,
135 should_send_partial_metrics: bool,
136 ) {
137 let mut metrics = StreamMetrics::new(should_send_partial_metrics);
138
139 let schema = recordbatches.schema().arrow_schema().clone();
140 let start = Instant::now();
141 if let Err(e) = tx.send(Ok(FlightMessage::Schema(schema))).await {
142 warn!(e; "stop sending Flight data");
143 return;
144 }
145 metrics.send_schema_duration += start.elapsed();
146
147 while let Some(batch_or_err) = {
148 let start = Instant::now();
149 let result = recordbatches.next().in_current_span().await;
150 metrics.fetch_content_duration += start.elapsed();
151 result
152 } {
153 match batch_or_err {
154 Ok(recordbatch) => {
155 metrics.total_rows += recordbatch.num_rows();
156 metrics.record_batch_count += 1;
157 metrics.total_bytes += recordbatch.df_record_batch().get_array_memory_size();
158
159 let start = Instant::now();
160 if let Err(e) = tx
161 .send(Ok(FlightMessage::RecordBatch(
162 recordbatch.into_df_record_batch(),
163 )))
164 .await
165 {
166 warn!(e; "stop sending Flight data");
167 return;
168 }
169 metrics.send_record_batch_duration += start.elapsed();
170
171 if should_send_partial_metrics
172 && let Some(metrics_str) = recordbatches
173 .metrics()
174 .and_then(|m| serde_json::to_string(&m).ok())
175 {
176 metrics.metrics_count += 1;
177 let start = Instant::now();
178 if let Err(e) = tx.send(Ok(FlightMessage::Metrics(metrics_str))).await {
179 warn!(e; "stop sending Flight data");
180 return;
181 }
182 metrics.send_metrics_duration += start.elapsed();
183 }
184 }
185 Err(e) => {
186 if e.status_code().should_log_error() {
187 error!("{e:?}");
188 }
189
190 let e = Err(e).context(error::CollectRecordbatchSnafu);
191 if let Err(e) = tx.send(e.map_err(|x| x.into())).await {
192 warn!(e; "stop sending Flight data");
193 }
194 return;
195 }
196 }
197 }
198 if let Some(metrics_str) = recordbatches
200 .metrics()
201 .and_then(|m| serde_json::to_string(&m).ok())
202 {
203 metrics.metrics_count += 1;
204 let start = Instant::now();
205 let _ = tx.send(Ok(FlightMessage::Metrics(metrics_str))).await;
206 metrics.send_metrics_duration += start.elapsed();
207 }
208 }
209}
210
211#[pinned_drop]
212impl PinnedDrop for FlightRecordBatchStream {
213 fn drop(self: Pin<&mut Self>) {
214 self.join_handle.abort();
215 }
216}
217
218impl Stream for FlightRecordBatchStream {
219 type Item = TonicResult<FlightData>;
220
221 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
222 let this = self.project();
223 if *this.done {
224 Poll::Ready(None)
225 } else {
226 if let Some(x) = this.buffer.pop_front() {
227 return Poll::Ready(Some(Ok(x)));
228 }
229 match this.rx.poll_next(cx) {
230 Poll::Ready(None) => {
231 *this.done = true;
232 Poll::Ready(None)
233 }
234 Poll::Ready(Some(result)) => match result {
235 Ok(flight_message) => {
236 let mut iter = this.encoder.encode(flight_message).into_iter();
237 let Some(first) = iter.next() else {
238 unreachable!()
241 };
242 this.buffer.extend(iter);
243 Poll::Ready(Some(Ok(first)))
244 }
245 Err(e) => {
246 *this.done = true;
247 Poll::Ready(Some(Err(e)))
248 }
249 },
250 Poll::Pending => Poll::Pending,
251 }
252 }
253 }
254}
255
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use common_grpc::flight::{FlightDecoder, FlightMessage};
    use common_recordbatch::{RecordBatch, RecordBatches};
    use datatypes::prelude::*;
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::Int32Vector;
    use futures::StreamExt;
    use session::context::QueryContext;

    use super::*;

    /// A one-batch source must yield exactly two frames, a schema followed by a record batch,
    /// and both must decode back to the values they were built from.
    #[tokio::test]
    async fn test_flight_record_batch_stream() {
        let column = ColumnSchema::new("a", ConcreteDataType::int32_datatype(), false);
        let schema = Arc::new(Schema::new(vec![column]));

        let values: VectorRef = Arc::new(Int32Vector::from_slice([1, 2]));
        let recordbatch = RecordBatch::new(schema.clone(), vec![values]).unwrap();

        let recordbatches = RecordBatches::try_new(schema.clone(), vec![recordbatch.clone()])
            .unwrap()
            .as_stream();
        let mut stream = FlightRecordBatchStream::new(
            recordbatches,
            TracingContext::default(),
            FlightCompression::default(),
            QueryContext::arc(),
        );

        let schema_frame = stream.next().await.unwrap().unwrap();
        let batch_frame = stream.next().await.unwrap().unwrap();
        assert!(stream.next().await.is_none());
        assert!(stream.done);

        let decoder = &mut FlightDecoder::default();
        let flight_messages = [schema_frame, batch_frame]
            .into_iter()
            .map(|frame| decoder.try_decode(&frame).unwrap().unwrap())
            .collect::<Vec<FlightMessage>>();
        assert_eq!(flight_messages.len(), 2);

        let mut decoded = flight_messages.into_iter();

        let Some(FlightMessage::Schema(actual_schema)) = decoded.next() else {
            unreachable!()
        };
        assert_eq!(&actual_schema, schema.arrow_schema());

        let Some(FlightMessage::RecordBatch(actual_recordbatch)) = decoded.next() else {
            unreachable!()
        };
        assert_eq!(&actual_recordbatch, recordbatch.df_record_batch());
    }
}