servers/grpc/flight/
stream.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::pin::Pin;
16use std::task::{Context, Poll};
17
18use arrow_flight::FlightData;
19use common_error::ext::ErrorExt;
20use common_grpc::flight::{FlightEncoder, FlightMessage};
21use common_recordbatch::SendableRecordBatchStream;
22use common_telemetry::tracing::{info_span, Instrument};
23use common_telemetry::tracing_context::{FutureExt, TracingContext};
24use common_telemetry::{error, warn};
25use futures::channel::mpsc;
26use futures::channel::mpsc::Sender;
27use futures::{SinkExt, Stream, StreamExt};
28use pin_project::{pin_project, pinned_drop};
29use snafu::ResultExt;
30use tokio::task::JoinHandle;
31
32use crate::error;
33use crate::grpc::flight::TonicResult;
34use crate::grpc::FlightCompression;
35
36#[pin_project(PinnedDrop)]
37pub struct FlightRecordBatchStream {
38    #[pin]
39    rx: mpsc::Receiver<Result<FlightMessage, tonic::Status>>,
40    join_handle: JoinHandle<()>,
41    done: bool,
42    encoder: FlightEncoder,
43}
44
45impl FlightRecordBatchStream {
46    pub fn new(
47        recordbatches: SendableRecordBatchStream,
48        tracing_context: TracingContext,
49        compression: FlightCompression,
50    ) -> Self {
51        let (tx, rx) = mpsc::channel::<TonicResult<FlightMessage>>(1);
52        let join_handle = common_runtime::spawn_global(async move {
53            Self::flight_data_stream(recordbatches, tx)
54                .trace(tracing_context.attach(info_span!("flight_data_stream")))
55                .await
56        });
57        let encoder = if compression.arrow_compression() {
58            FlightEncoder::default()
59        } else {
60            FlightEncoder::with_compression_disabled()
61        };
62        Self {
63            rx,
64            join_handle,
65            done: false,
66            encoder,
67        }
68    }
69
70    async fn flight_data_stream(
71        mut recordbatches: SendableRecordBatchStream,
72        mut tx: Sender<TonicResult<FlightMessage>>,
73    ) {
74        let schema = recordbatches.schema().arrow_schema().clone();
75        if let Err(e) = tx.send(Ok(FlightMessage::Schema(schema))).await {
76            warn!(e; "stop sending Flight data");
77            return;
78        }
79
80        while let Some(batch_or_err) = recordbatches.next().in_current_span().await {
81            match batch_or_err {
82                Ok(recordbatch) => {
83                    if let Err(e) = tx
84                        .send(Ok(FlightMessage::RecordBatch(
85                            recordbatch.into_df_record_batch(),
86                        )))
87                        .await
88                    {
89                        warn!(e; "stop sending Flight data");
90                        return;
91                    }
92                }
93                Err(e) => {
94                    if e.status_code().should_log_error() {
95                        error!("{e:?}");
96                    }
97
98                    let e = Err(e).context(error::CollectRecordbatchSnafu);
99                    if let Err(e) = tx.send(e.map_err(|x| x.into())).await {
100                        warn!(e; "stop sending Flight data");
101                    }
102                    return;
103                }
104            }
105        }
106        // make last package to pass metrics
107        if let Some(metrics_str) = recordbatches
108            .metrics()
109            .and_then(|m| serde_json::to_string(&m).ok())
110        {
111            let _ = tx.send(Ok(FlightMessage::Metrics(metrics_str))).await;
112        }
113    }
114}
115
116#[pinned_drop]
117impl PinnedDrop for FlightRecordBatchStream {
118    fn drop(self: Pin<&mut Self>) {
119        self.join_handle.abort();
120    }
121}
122
123impl Stream for FlightRecordBatchStream {
124    type Item = TonicResult<FlightData>;
125
126    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
127        let this = self.project();
128        if *this.done {
129            Poll::Ready(None)
130        } else {
131            match this.rx.poll_next(cx) {
132                Poll::Ready(None) => {
133                    *this.done = true;
134                    Poll::Ready(None)
135                }
136                Poll::Ready(Some(result)) => match result {
137                    Ok(flight_message) => {
138                        let flight_data = this.encoder.encode(flight_message);
139                        Poll::Ready(Some(Ok(flight_data)))
140                    }
141                    Err(e) => {
142                        *this.done = true;
143                        Poll::Ready(Some(Err(e)))
144                    }
145                },
146                Poll::Pending => Poll::Pending,
147            }
148        }
149    }
150}
151
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use common_grpc::flight::{FlightDecoder, FlightMessage};
    use common_recordbatch::{RecordBatch, RecordBatches};
    use datatypes::prelude::*;
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::Int32Vector;
    use futures::StreamExt;

    use super::*;

    /// A single-batch input must yield exactly a schema message followed by a
    /// record batch message, after which the stream is finished.
    #[tokio::test]
    async fn test_flight_record_batch_stream() {
        let column = ColumnSchema::new("a", ConcreteDataType::int32_datatype(), false);
        let schema = Arc::new(Schema::new(vec![column]));

        let values: VectorRef = Arc::new(Int32Vector::from_slice([1, 2]));
        let recordbatch = RecordBatch::new(schema.clone(), vec![values]).unwrap();

        let recordbatches = RecordBatches::try_new(schema.clone(), vec![recordbatch.clone()])
            .unwrap()
            .as_stream();
        let mut stream = FlightRecordBatchStream::new(
            recordbatches,
            TracingContext::default(),
            FlightCompression::default(),
        );

        // Two items are expected: the schema and the only record batch.
        let mut raw_data = Vec::with_capacity(2);
        for _ in 0..2 {
            raw_data.push(stream.next().await.unwrap().unwrap());
        }
        assert!(stream.next().await.is_none());
        assert!(stream.done);

        let decoder = &mut FlightDecoder::default();
        let flight_messages: Vec<FlightMessage> = raw_data
            .iter()
            .map(|data| decoder.try_decode(data).unwrap())
            .collect();
        assert_eq!(flight_messages.len(), 2);

        let mut messages = flight_messages.into_iter();

        match messages.next() {
            Some(FlightMessage::Schema(actual_schema)) => {
                assert_eq!(&actual_schema, schema.arrow_schema());
            }
            _ => unreachable!(),
        }

        match messages.next() {
            Some(FlightMessage::RecordBatch(actual_recordbatch)) => {
                assert_eq!(&actual_recordbatch, recordbatch.df_record_batch());
            }
            _ => unreachable!(),
        }
    }
}