servers/grpc/flight/
stream.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::VecDeque;
16use std::pin::Pin;
17use std::task::{Context, Poll};
18use std::time::{Duration, Instant};
19
20use arrow_flight::FlightData;
21use common_error::ext::ErrorExt;
22use common_grpc::flight::{FlightEncoder, FlightMessage};
23use common_recordbatch::SendableRecordBatchStream;
24use common_telemetry::tracing::{info_span, Instrument};
25use common_telemetry::tracing_context::{FutureExt, TracingContext};
26use common_telemetry::{error, info, warn};
27use futures::channel::mpsc;
28use futures::channel::mpsc::Sender;
29use futures::{SinkExt, Stream, StreamExt};
30use pin_project::{pin_project, pinned_drop};
31use session::context::QueryContextRef;
32use snafu::ResultExt;
33use tokio::task::JoinHandle;
34
35use crate::error;
36use crate::grpc::flight::TonicResult;
37use crate::grpc::FlightCompression;
38
/// Counters and timings gathered while a Flight stream is produced.
///
/// The summary is written to the log from the `Drop` implementation, so it is
/// emitted on every exit path of the producer, but only when `should_log` is set.
struct StreamMetrics {
    /// Time spent sending the schema message to the channel.
    send_schema_duration: Duration,
    /// Accumulated time spent sending record batch messages to the channel.
    send_record_batch_duration: Duration,
    /// Accumulated time spent sending metrics messages to the channel.
    send_metrics_duration: Duration,
    /// Accumulated time spent waiting for the next item of the source stream.
    fetch_content_duration: Duration,
    /// Number of record batches received from the source stream.
    record_batch_count: usize,
    /// Number of metrics messages prepared for sending.
    metrics_count: usize,
    /// Sum of the row counts of all received record batches.
    total_rows: usize,
    /// Sum of the array memory sizes reported by all received record batches.
    total_bytes: usize,
    /// Whether the summary line is logged when the value is dropped.
    should_log: bool,
}
51
52impl StreamMetrics {
53    fn new(should_log: bool) -> Self {
54        Self {
55            send_schema_duration: Duration::ZERO,
56            send_record_batch_duration: Duration::ZERO,
57            send_metrics_duration: Duration::ZERO,
58            fetch_content_duration: Duration::ZERO,
59            record_batch_count: 0,
60            metrics_count: 0,
61            total_rows: 0,
62            total_bytes: 0,
63            should_log,
64        }
65    }
66}
67
68impl Drop for StreamMetrics {
69    fn drop(&mut self) {
70        if self.should_log {
71            info!(
72                "flight_data_stream finished: \
73                send_schema_duration={:?}, \
74                send_record_batch_duration={:?}, \
75                send_metrics_duration={:?}, \
76                fetch_content_duration={:?}, \
77                record_batch_count={}, \
78                metrics_count={}, \
79                total_rows={}, \
80                total_bytes={}",
81                self.send_schema_duration,
82                self.send_record_batch_duration,
83                self.send_metrics_duration,
84                self.fetch_content_duration,
85                self.record_batch_count,
86                self.metrics_count,
87                self.total_rows,
88                self.total_bytes
89            );
90        }
91    }
92}
93
94#[pin_project(PinnedDrop)]
95pub struct FlightRecordBatchStream {
96    #[pin]
97    rx: mpsc::Receiver<Result<FlightMessage, tonic::Status>>,
98    join_handle: JoinHandle<()>,
99    done: bool,
100    encoder: FlightEncoder,
101    buffer: VecDeque<FlightData>,
102}
103
104impl FlightRecordBatchStream {
105    pub fn new(
106        recordbatches: SendableRecordBatchStream,
107        tracing_context: TracingContext,
108        compression: FlightCompression,
109        query_ctx: QueryContextRef,
110    ) -> Self {
111        let should_send_partial_metrics = query_ctx.explain_verbose();
112        let (tx, rx) = mpsc::channel::<TonicResult<FlightMessage>>(1);
113        let join_handle = common_runtime::spawn_global(async move {
114            Self::flight_data_stream(recordbatches, tx, should_send_partial_metrics)
115                .trace(tracing_context.attach(info_span!("flight_data_stream")))
116                .await
117        });
118        let encoder = if compression.arrow_compression() {
119            FlightEncoder::default()
120        } else {
121            FlightEncoder::with_compression_disabled()
122        };
123        Self {
124            rx,
125            join_handle,
126            done: false,
127            encoder,
128            buffer: VecDeque::new(),
129        }
130    }
131
132    async fn flight_data_stream(
133        mut recordbatches: SendableRecordBatchStream,
134        mut tx: Sender<TonicResult<FlightMessage>>,
135        should_send_partial_metrics: bool,
136    ) {
137        let mut metrics = StreamMetrics::new(should_send_partial_metrics);
138
139        let schema = recordbatches.schema().arrow_schema().clone();
140        let start = Instant::now();
141        if let Err(e) = tx.send(Ok(FlightMessage::Schema(schema))).await {
142            warn!(e; "stop sending Flight data");
143            return;
144        }
145        metrics.send_schema_duration += start.elapsed();
146
147        while let Some(batch_or_err) = {
148            let start = Instant::now();
149            let result = recordbatches.next().in_current_span().await;
150            metrics.fetch_content_duration += start.elapsed();
151            result
152        } {
153            match batch_or_err {
154                Ok(recordbatch) => {
155                    metrics.total_rows += recordbatch.num_rows();
156                    metrics.record_batch_count += 1;
157                    metrics.total_bytes += recordbatch.df_record_batch().get_array_memory_size();
158
159                    let start = Instant::now();
160                    if let Err(e) = tx
161                        .send(Ok(FlightMessage::RecordBatch(
162                            recordbatch.into_df_record_batch(),
163                        )))
164                        .await
165                    {
166                        warn!(e; "stop sending Flight data");
167                        return;
168                    }
169                    metrics.send_record_batch_duration += start.elapsed();
170
171                    if should_send_partial_metrics {
172                        if let Some(metrics_str) = recordbatches
173                            .metrics()
174                            .and_then(|m| serde_json::to_string(&m).ok())
175                        {
176                            metrics.metrics_count += 1;
177                            let start = Instant::now();
178                            if let Err(e) = tx.send(Ok(FlightMessage::Metrics(metrics_str))).await {
179                                warn!(e; "stop sending Flight data");
180                                return;
181                            }
182                            metrics.send_metrics_duration += start.elapsed();
183                        }
184                    }
185                }
186                Err(e) => {
187                    if e.status_code().should_log_error() {
188                        error!("{e:?}");
189                    }
190
191                    let e = Err(e).context(error::CollectRecordbatchSnafu);
192                    if let Err(e) = tx.send(e.map_err(|x| x.into())).await {
193                        warn!(e; "stop sending Flight data");
194                    }
195                    return;
196                }
197            }
198        }
199        // make last package to pass metrics
200        if let Some(metrics_str) = recordbatches
201            .metrics()
202            .and_then(|m| serde_json::to_string(&m).ok())
203        {
204            metrics.metrics_count += 1;
205            let start = Instant::now();
206            let _ = tx.send(Ok(FlightMessage::Metrics(metrics_str))).await;
207            metrics.send_metrics_duration += start.elapsed();
208        }
209    }
210}
211
212#[pinned_drop]
213impl PinnedDrop for FlightRecordBatchStream {
214    fn drop(self: Pin<&mut Self>) {
215        self.join_handle.abort();
216    }
217}
218
219impl Stream for FlightRecordBatchStream {
220    type Item = TonicResult<FlightData>;
221
222    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
223        let this = self.project();
224        if *this.done {
225            Poll::Ready(None)
226        } else {
227            if let Some(x) = this.buffer.pop_front() {
228                return Poll::Ready(Some(Ok(x)));
229            }
230            match this.rx.poll_next(cx) {
231                Poll::Ready(None) => {
232                    *this.done = true;
233                    Poll::Ready(None)
234                }
235                Poll::Ready(Some(result)) => match result {
236                    Ok(flight_message) => {
237                        let mut iter = this.encoder.encode(flight_message).into_iter();
238                        let Some(first) = iter.next() else {
239                            // Safety: `iter` on a type of `Vec1`, which is guaranteed to have
240                            // at least one element.
241                            unreachable!()
242                        };
243                        this.buffer.extend(iter);
244                        Poll::Ready(Some(Ok(first)))
245                    }
246                    Err(e) => {
247                        *this.done = true;
248                        Poll::Ready(Some(Err(e)))
249                    }
250                },
251                Poll::Pending => Poll::Pending,
252            }
253        }
254    }
255}
256
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use common_grpc::flight::{FlightDecoder, FlightMessage};
    use common_recordbatch::{RecordBatch, RecordBatches};
    use datatypes::prelude::*;
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::Int32Vector;
    use futures::StreamExt;
    use session::context::QueryContext;

    use super::*;

    /// Streams a single record batch and checks that exactly two Flight items
    /// come out: the schema followed by the record batch.
    #[tokio::test]
    async fn test_flight_record_batch_stream() {
        let column = ColumnSchema::new("a", ConcreteDataType::int32_datatype(), false);
        let schema = Arc::new(Schema::new(vec![column]));

        let values: VectorRef = Arc::new(Int32Vector::from_slice([1, 2]));
        let recordbatch = RecordBatch::new(schema.clone(), vec![values]).unwrap();

        let source = RecordBatches::try_new(schema.clone(), vec![recordbatch.clone()])
            .unwrap()
            .as_stream();
        let mut stream = FlightRecordBatchStream::new(
            source,
            TracingContext::default(),
            FlightCompression::default(),
            QueryContext::arc(),
        );

        let mut raw_data = Vec::with_capacity(2);
        for _ in 0..2 {
            raw_data.push(stream.next().await.unwrap().unwrap());
        }
        assert!(stream.next().await.is_none());
        assert!(stream.done);

        let mut decoder = FlightDecoder::default();
        let flight_messages: Vec<FlightMessage> = raw_data
            .iter()
            .map(|frame| decoder.try_decode(frame).unwrap().unwrap())
            .collect();
        assert_eq!(flight_messages.len(), 2);

        let mut remaining = flight_messages.into_iter();

        let Some(FlightMessage::Schema(actual_schema)) = remaining.next() else {
            unreachable!()
        };
        assert_eq!(&actual_schema, schema.arrow_schema());

        let Some(FlightMessage::RecordBatch(actual_recordbatch)) = remaining.next() else {
            unreachable!()
        };
        assert_eq!(&actual_recordbatch, recordbatch.df_record_batch());
    }
}