servers/grpc/flight/
stream.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::VecDeque;
16use std::pin::Pin;
17use std::task::{Context, Poll};
18use std::time::{Duration, Instant};
19
20use arrow_flight::FlightData;
21use common_error::ext::ErrorExt;
22use common_grpc::flight::{FlightEncoder, FlightMessage};
23use common_recordbatch::SendableRecordBatchStream;
24use common_telemetry::tracing::{Instrument, info_span};
25use common_telemetry::tracing_context::{FutureExt, TracingContext};
26use common_telemetry::{error, info, warn};
27use futures::channel::mpsc;
28use futures::channel::mpsc::Sender;
29use futures::{SinkExt, Stream, StreamExt};
30use pin_project::{pin_project, pinned_drop};
31use session::context::QueryContextRef;
32use snafu::ResultExt;
33use tokio::task::JoinHandle;
34
35use crate::error;
36use crate::grpc::FlightCompression;
37use crate::grpc::flight::TonicResult;
38
/// Timing and volume statistics gathered while a Flight stream is produced.
///
/// When `should_log` is set, the collected values are written to the log at
/// the moment the collector is dropped.
struct StreamMetrics {
    /// Whether a summary line is logged when this value is dropped.
    should_log: bool,

    // Volume counters.
    /// Number of record batches seen.
    record_batch_count: usize,
    /// Number of metrics messages prepared for sending.
    metrics_count: usize,
    /// Sum of the row counts of all record batches seen.
    total_rows: usize,
    /// Sum of the reported array memory sizes of all record batches seen.
    total_bytes: usize,

    // Accumulated wall-clock timings.
    /// Time spent sending the schema message.
    send_schema_duration: Duration,
    /// Time spent sending record batch messages.
    send_record_batch_duration: Duration,
    /// Time spent sending metrics messages.
    send_metrics_duration: Duration,
    /// Time spent waiting for the next item of the source stream.
    fetch_content_duration: Duration,
}
51
52impl StreamMetrics {
53    fn new(should_log: bool) -> Self {
54        Self {
55            send_schema_duration: Duration::ZERO,
56            send_record_batch_duration: Duration::ZERO,
57            send_metrics_duration: Duration::ZERO,
58            fetch_content_duration: Duration::ZERO,
59            record_batch_count: 0,
60            metrics_count: 0,
61            total_rows: 0,
62            total_bytes: 0,
63            should_log,
64        }
65    }
66}
67
68impl Drop for StreamMetrics {
69    fn drop(&mut self) {
70        if self.should_log {
71            info!(
72                "flight_data_stream finished: \
73                send_schema_duration={:?}, \
74                send_record_batch_duration={:?}, \
75                send_metrics_duration={:?}, \
76                fetch_content_duration={:?}, \
77                record_batch_count={}, \
78                metrics_count={}, \
79                total_rows={}, \
80                total_bytes={}",
81                self.send_schema_duration,
82                self.send_record_batch_duration,
83                self.send_metrics_duration,
84                self.fetch_content_duration,
85                self.record_batch_count,
86                self.metrics_count,
87                self.total_rows,
88                self.total_bytes
89            );
90        }
91    }
92}
93
94#[pin_project(PinnedDrop)]
95pub struct FlightRecordBatchStream {
96    #[pin]
97    rx: mpsc::Receiver<Result<FlightMessage, tonic::Status>>,
98    join_handle: JoinHandle<()>,
99    done: bool,
100    encoder: FlightEncoder,
101    buffer: VecDeque<FlightData>,
102}
103
104impl FlightRecordBatchStream {
105    pub fn new(
106        recordbatches: SendableRecordBatchStream,
107        tracing_context: TracingContext,
108        compression: FlightCompression,
109        query_ctx: QueryContextRef,
110    ) -> Self {
111        let should_send_partial_metrics = query_ctx.explain_verbose();
112        let (tx, rx) = mpsc::channel::<TonicResult<FlightMessage>>(1);
113        let join_handle = common_runtime::spawn_global(async move {
114            Self::flight_data_stream(recordbatches, tx, should_send_partial_metrics)
115                .trace(tracing_context.attach(info_span!("flight_data_stream")))
116                .await
117        });
118        let encoder = if compression.arrow_compression() {
119            FlightEncoder::default()
120        } else {
121            FlightEncoder::with_compression_disabled()
122        };
123        Self {
124            rx,
125            join_handle,
126            done: false,
127            encoder,
128            buffer: VecDeque::new(),
129        }
130    }
131
132    async fn flight_data_stream(
133        mut recordbatches: SendableRecordBatchStream,
134        mut tx: Sender<TonicResult<FlightMessage>>,
135        should_send_partial_metrics: bool,
136    ) {
137        let mut metrics = StreamMetrics::new(should_send_partial_metrics);
138
139        let schema = recordbatches.schema().arrow_schema().clone();
140        let start = Instant::now();
141        if let Err(e) = tx.send(Ok(FlightMessage::Schema(schema))).await {
142            warn!(e; "stop sending Flight data");
143            return;
144        }
145        metrics.send_schema_duration += start.elapsed();
146
147        while let Some(batch_or_err) = {
148            let start = Instant::now();
149            let result = recordbatches.next().in_current_span().await;
150            metrics.fetch_content_duration += start.elapsed();
151            result
152        } {
153            match batch_or_err {
154                Ok(recordbatch) => {
155                    metrics.total_rows += recordbatch.num_rows();
156                    metrics.record_batch_count += 1;
157                    metrics.total_bytes += recordbatch.df_record_batch().get_array_memory_size();
158
159                    let start = Instant::now();
160                    if let Err(e) = tx
161                        .send(Ok(FlightMessage::RecordBatch(
162                            recordbatch.into_df_record_batch(),
163                        )))
164                        .await
165                    {
166                        warn!(e; "stop sending Flight data");
167                        return;
168                    }
169                    metrics.send_record_batch_duration += start.elapsed();
170
171                    if should_send_partial_metrics
172                        && let Some(metrics_str) = recordbatches
173                            .metrics()
174                            .and_then(|m| serde_json::to_string(&m).ok())
175                    {
176                        metrics.metrics_count += 1;
177                        let start = Instant::now();
178                        if let Err(e) = tx.send(Ok(FlightMessage::Metrics(metrics_str))).await {
179                            warn!(e; "stop sending Flight data");
180                            return;
181                        }
182                        metrics.send_metrics_duration += start.elapsed();
183                    }
184                }
185                Err(e) => {
186                    if e.status_code().should_log_error() {
187                        error!("{e:?}");
188                    }
189
190                    let e = Err(e).context(error::CollectRecordbatchSnafu);
191                    if let Err(e) = tx.send(e.map_err(|x| x.into())).await {
192                        warn!(e; "stop sending Flight data");
193                    }
194                    return;
195                }
196            }
197        }
198        // make last package to pass metrics
199        if let Some(metrics_str) = recordbatches
200            .metrics()
201            .and_then(|m| serde_json::to_string(&m).ok())
202        {
203            metrics.metrics_count += 1;
204            let start = Instant::now();
205            let _ = tx.send(Ok(FlightMessage::Metrics(metrics_str))).await;
206            metrics.send_metrics_duration += start.elapsed();
207        }
208    }
209}
210
211#[pinned_drop]
212impl PinnedDrop for FlightRecordBatchStream {
213    fn drop(self: Pin<&mut Self>) {
214        self.join_handle.abort();
215    }
216}
217
218impl Stream for FlightRecordBatchStream {
219    type Item = TonicResult<FlightData>;
220
221    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
222        let this = self.project();
223        if *this.done {
224            Poll::Ready(None)
225        } else {
226            if let Some(x) = this.buffer.pop_front() {
227                return Poll::Ready(Some(Ok(x)));
228            }
229            match this.rx.poll_next(cx) {
230                Poll::Ready(None) => {
231                    *this.done = true;
232                    Poll::Ready(None)
233                }
234                Poll::Ready(Some(result)) => match result {
235                    Ok(flight_message) => {
236                        let mut iter = this.encoder.encode(flight_message).into_iter();
237                        let Some(first) = iter.next() else {
238                            // Safety: `iter` on a type of `Vec1`, which is guaranteed to have
239                            // at least one element.
240                            unreachable!()
241                        };
242                        this.buffer.extend(iter);
243                        Poll::Ready(Some(Ok(first)))
244                    }
245                    Err(e) => {
246                        *this.done = true;
247                        Poll::Ready(Some(Err(e)))
248                    }
249                },
250                Poll::Pending => Poll::Pending,
251            }
252        }
253    }
254}
255
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use common_grpc::flight::{FlightDecoder, FlightMessage};
    use common_recordbatch::{RecordBatch, RecordBatches};
    use datatypes::prelude::*;
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::Int32Vector;
    use futures::StreamExt;
    use session::context::QueryContext;

    use super::*;

    /// A single-batch input must come out as exactly two frames, the schema
    /// followed by the record batch, after which the stream reports completion.
    #[tokio::test]
    async fn test_flight_record_batch_stream() {
        let column = ColumnSchema::new("a", ConcreteDataType::int32_datatype(), false);
        let schema = Arc::new(Schema::new(vec![column]));

        let values: VectorRef = Arc::new(Int32Vector::from_slice([1, 2]));
        let recordbatch = RecordBatch::new(schema.clone(), vec![values]).unwrap();

        let recordbatches = RecordBatches::try_new(schema.clone(), vec![recordbatch.clone()])
            .unwrap()
            .as_stream();
        let mut stream = FlightRecordBatchStream::new(
            recordbatches,
            TracingContext::default(),
            FlightCompression::default(),
            QueryContext::arc(),
        );

        // Pull the two expected frames, then make sure the stream has ended.
        let mut raw_data = Vec::with_capacity(2);
        for _ in 0..2 {
            raw_data.push(stream.next().await.unwrap().unwrap());
        }
        assert!(stream.next().await.is_none());
        assert!(stream.done);

        let decoder = &mut FlightDecoder::default();
        let flight_messages: Vec<FlightMessage> = raw_data
            .into_iter()
            .map(|data| decoder.try_decode(&data).unwrap().unwrap())
            .collect();
        assert_eq!(flight_messages.len(), 2);

        let mut decoded = flight_messages.into_iter();

        let FlightMessage::Schema(actual_schema) = decoded.next().unwrap() else {
            unreachable!()
        };
        assert_eq!(&actual_schema, schema.arrow_schema());

        let FlightMessage::RecordBatch(actual_recordbatch) = decoded.next().unwrap() else {
            unreachable!()
        };
        assert_eq!(&actual_recordbatch, recordbatch.df_record_batch());
    }
}