1use std::collections::VecDeque;
16use std::pin::Pin;
17use std::task::{Context, Poll};
18use std::time::{Duration, Instant};
19
20use arrow_flight::FlightData;
21use common_error::ext::ErrorExt;
22use common_grpc::flight::{FlightEncoder, FlightMessage};
23use common_recordbatch::SendableRecordBatchStream;
24use common_telemetry::tracing::{info_span, Instrument};
25use common_telemetry::tracing_context::{FutureExt, TracingContext};
26use common_telemetry::{error, info, warn};
27use futures::channel::mpsc;
28use futures::channel::mpsc::Sender;
29use futures::{SinkExt, Stream, StreamExt};
30use pin_project::{pin_project, pinned_drop};
31use session::context::QueryContextRef;
32use snafu::ResultExt;
33use tokio::task::JoinHandle;
34
35use crate::error;
36use crate::grpc::flight::TonicResult;
37use crate::grpc::FlightCompression;
38
/// Timing and volume counters collected while a record batch stream is
/// forwarded as Flight messages.
///
/// The collected values are reported by the `Drop` implementation of this
/// type when `should_log` is set.
struct StreamMetrics {
    /// Time spent sending the schema message into the channel.
    send_schema_duration: Duration,
    /// Accumulated time spent sending record batch messages into the channel.
    send_record_batch_duration: Duration,
    /// Accumulated time spent sending metrics messages into the channel.
    send_metrics_duration: Duration,
    /// Accumulated time spent waiting for the next item of the source stream.
    fetch_content_duration: Duration,
    /// Number of record batches received from the source stream.
    record_batch_count: usize,
    /// Number of metrics messages produced for sending.
    metrics_count: usize,
    /// Sum of the row counts of all received record batches.
    total_rows: usize,
    /// Sum of the array memory sizes reported by all received record batches.
    total_bytes: usize,
    /// Whether a summary is logged when this value is dropped.
    should_log: bool,
}

impl StreamMetrics {
    /// Creates an empty set of metrics.
    ///
    /// Every duration starts at [`Duration::ZERO`] and every counter at `0`;
    /// `should_log` is stored unchanged.
    fn new(should_log: bool) -> Self {
        // Only the logging switch is caller-controlled; all measurements start empty.
        let no_time = Duration::ZERO;
        Self {
            should_log,
            send_schema_duration: no_time,
            send_record_batch_duration: no_time,
            send_metrics_duration: no_time,
            fetch_content_duration: no_time,
            record_batch_count: 0,
            metrics_count: 0,
            total_rows: 0,
            total_bytes: 0,
        }
    }
}
67
68impl Drop for StreamMetrics {
69 fn drop(&mut self) {
70 if self.should_log {
71 info!(
72 "flight_data_stream finished: \
73 send_schema_duration={:?}, \
74 send_record_batch_duration={:?}, \
75 send_metrics_duration={:?}, \
76 fetch_content_duration={:?}, \
77 record_batch_count={}, \
78 metrics_count={}, \
79 total_rows={}, \
80 total_bytes={}",
81 self.send_schema_duration,
82 self.send_record_batch_duration,
83 self.send_metrics_duration,
84 self.fetch_content_duration,
85 self.record_batch_count,
86 self.metrics_count,
87 self.total_rows,
88 self.total_bytes
89 );
90 }
91 }
92}
93
94#[pin_project(PinnedDrop)]
95pub struct FlightRecordBatchStream {
96 #[pin]
97 rx: mpsc::Receiver<Result<FlightMessage, tonic::Status>>,
98 join_handle: JoinHandle<()>,
99 done: bool,
100 encoder: FlightEncoder,
101 buffer: VecDeque<FlightData>,
102}
103
104impl FlightRecordBatchStream {
105 pub fn new(
106 recordbatches: SendableRecordBatchStream,
107 tracing_context: TracingContext,
108 compression: FlightCompression,
109 query_ctx: QueryContextRef,
110 ) -> Self {
111 let should_send_partial_metrics = query_ctx.explain_verbose();
112 let (tx, rx) = mpsc::channel::<TonicResult<FlightMessage>>(1);
113 let join_handle = common_runtime::spawn_global(async move {
114 Self::flight_data_stream(recordbatches, tx, should_send_partial_metrics)
115 .trace(tracing_context.attach(info_span!("flight_data_stream")))
116 .await
117 });
118 let encoder = if compression.arrow_compression() {
119 FlightEncoder::default()
120 } else {
121 FlightEncoder::with_compression_disabled()
122 };
123 Self {
124 rx,
125 join_handle,
126 done: false,
127 encoder,
128 buffer: VecDeque::new(),
129 }
130 }
131
132 async fn flight_data_stream(
133 mut recordbatches: SendableRecordBatchStream,
134 mut tx: Sender<TonicResult<FlightMessage>>,
135 should_send_partial_metrics: bool,
136 ) {
137 let mut metrics = StreamMetrics::new(should_send_partial_metrics);
138
139 let schema = recordbatches.schema().arrow_schema().clone();
140 let start = Instant::now();
141 if let Err(e) = tx.send(Ok(FlightMessage::Schema(schema))).await {
142 warn!(e; "stop sending Flight data");
143 return;
144 }
145 metrics.send_schema_duration += start.elapsed();
146
147 while let Some(batch_or_err) = {
148 let start = Instant::now();
149 let result = recordbatches.next().in_current_span().await;
150 metrics.fetch_content_duration += start.elapsed();
151 result
152 } {
153 match batch_or_err {
154 Ok(recordbatch) => {
155 metrics.total_rows += recordbatch.num_rows();
156 metrics.record_batch_count += 1;
157 metrics.total_bytes += recordbatch.df_record_batch().get_array_memory_size();
158
159 let start = Instant::now();
160 if let Err(e) = tx
161 .send(Ok(FlightMessage::RecordBatch(
162 recordbatch.into_df_record_batch(),
163 )))
164 .await
165 {
166 warn!(e; "stop sending Flight data");
167 return;
168 }
169 metrics.send_record_batch_duration += start.elapsed();
170
171 if should_send_partial_metrics {
172 if let Some(metrics_str) = recordbatches
173 .metrics()
174 .and_then(|m| serde_json::to_string(&m).ok())
175 {
176 metrics.metrics_count += 1;
177 let start = Instant::now();
178 if let Err(e) = tx.send(Ok(FlightMessage::Metrics(metrics_str))).await {
179 warn!(e; "stop sending Flight data");
180 return;
181 }
182 metrics.send_metrics_duration += start.elapsed();
183 }
184 }
185 }
186 Err(e) => {
187 if e.status_code().should_log_error() {
188 error!("{e:?}");
189 }
190
191 let e = Err(e).context(error::CollectRecordbatchSnafu);
192 if let Err(e) = tx.send(e.map_err(|x| x.into())).await {
193 warn!(e; "stop sending Flight data");
194 }
195 return;
196 }
197 }
198 }
199 if let Some(metrics_str) = recordbatches
201 .metrics()
202 .and_then(|m| serde_json::to_string(&m).ok())
203 {
204 metrics.metrics_count += 1;
205 let start = Instant::now();
206 let _ = tx.send(Ok(FlightMessage::Metrics(metrics_str))).await;
207 metrics.send_metrics_duration += start.elapsed();
208 }
209 }
210}
211
212#[pinned_drop]
213impl PinnedDrop for FlightRecordBatchStream {
214 fn drop(self: Pin<&mut Self>) {
215 self.join_handle.abort();
216 }
217}
218
219impl Stream for FlightRecordBatchStream {
220 type Item = TonicResult<FlightData>;
221
222 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
223 let this = self.project();
224 if *this.done {
225 Poll::Ready(None)
226 } else {
227 if let Some(x) = this.buffer.pop_front() {
228 return Poll::Ready(Some(Ok(x)));
229 }
230 match this.rx.poll_next(cx) {
231 Poll::Ready(None) => {
232 *this.done = true;
233 Poll::Ready(None)
234 }
235 Poll::Ready(Some(result)) => match result {
236 Ok(flight_message) => {
237 let mut iter = this.encoder.encode(flight_message).into_iter();
238 let Some(first) = iter.next() else {
239 unreachable!()
242 };
243 this.buffer.extend(iter);
244 Poll::Ready(Some(Ok(first)))
245 }
246 Err(e) => {
247 *this.done = true;
248 Poll::Ready(Some(Err(e)))
249 }
250 },
251 Poll::Pending => Poll::Pending,
252 }
253 }
254 }
255}
256
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use common_grpc::flight::{FlightDecoder, FlightMessage};
    use common_recordbatch::{RecordBatch, RecordBatches};
    use datatypes::prelude::*;
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::Int32Vector;
    use futures::StreamExt;
    use session::context::QueryContext;

    use super::*;

    /// A single-batch stream must yield exactly two frames, which decode
    /// back to the original schema and the original record batch.
    #[tokio::test]
    async fn test_flight_record_batch_stream() {
        let column = ColumnSchema::new("a", ConcreteDataType::int32_datatype(), false);
        let schema = Arc::new(Schema::new(vec![column]));

        let values: VectorRef = Arc::new(Int32Vector::from_slice([1, 2]));
        let recordbatch = RecordBatch::new(schema.clone(), vec![values]).unwrap();

        let source = RecordBatches::try_new(schema.clone(), vec![recordbatch.clone()])
            .unwrap()
            .as_stream();
        let mut stream = FlightRecordBatchStream::new(
            source,
            TracingContext::default(),
            FlightCompression::default(),
            QueryContext::arc(),
        );

        // Two frames are expected, then the stream must report its end.
        let schema_frame = stream.next().await.unwrap().unwrap();
        let batch_frame = stream.next().await.unwrap().unwrap();
        assert!(stream.next().await.is_none());
        assert!(stream.done);

        let decoder = &mut FlightDecoder::default();
        let flight_messages: Vec<FlightMessage> = vec![schema_frame, batch_frame]
            .into_iter()
            .map(|frame| decoder.try_decode(&frame).unwrap().unwrap())
            .collect();
        assert_eq!(flight_messages.len(), 2);

        let mut decoded = flight_messages.into_iter();

        let Some(FlightMessage::Schema(actual_schema)) = decoded.next() else {
            unreachable!()
        };
        assert_eq!(&actual_schema, schema.arrow_schema());

        let Some(FlightMessage::RecordBatch(actual_recordbatch)) = decoded.next() else {
            unreachable!()
        };
        assert_eq!(&actual_recordbatch, recordbatch.df_record_batch());
    }
}