servers/grpc/flight/
stream.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::pin::Pin;
16use std::task::{Context, Poll};
17
18use arrow_flight::FlightData;
19use common_grpc::flight::{FlightEncoder, FlightMessage};
20use common_recordbatch::SendableRecordBatchStream;
21use common_telemetry::tracing::{info_span, Instrument};
22use common_telemetry::tracing_context::{FutureExt, TracingContext};
23use common_telemetry::warn;
24use futures::channel::mpsc;
25use futures::channel::mpsc::Sender;
26use futures::{SinkExt, Stream, StreamExt};
27use pin_project::{pin_project, pinned_drop};
28use snafu::ResultExt;
29use tokio::task::JoinHandle;
30
31use crate::error;
32use crate::grpc::flight::TonicResult;
33
34#[pin_project(PinnedDrop)]
35pub struct FlightRecordBatchStream {
36    #[pin]
37    rx: mpsc::Receiver<Result<FlightMessage, tonic::Status>>,
38    join_handle: JoinHandle<()>,
39    done: bool,
40    encoder: FlightEncoder,
41}
42
43impl FlightRecordBatchStream {
44    pub fn new(recordbatches: SendableRecordBatchStream, tracing_context: TracingContext) -> Self {
45        let (tx, rx) = mpsc::channel::<TonicResult<FlightMessage>>(1);
46        let join_handle = common_runtime::spawn_global(async move {
47            Self::flight_data_stream(recordbatches, tx)
48                .trace(tracing_context.attach(info_span!("flight_data_stream")))
49                .await
50        });
51        Self {
52            rx,
53            join_handle,
54            done: false,
55            encoder: FlightEncoder::default(),
56        }
57    }
58
59    async fn flight_data_stream(
60        mut recordbatches: SendableRecordBatchStream,
61        mut tx: Sender<TonicResult<FlightMessage>>,
62    ) {
63        let schema = recordbatches.schema().arrow_schema().clone();
64        if let Err(e) = tx.send(Ok(FlightMessage::Schema(schema))).await {
65            warn!(e; "stop sending Flight data");
66            return;
67        }
68
69        while let Some(batch_or_err) = recordbatches.next().in_current_span().await {
70            match batch_or_err {
71                Ok(recordbatch) => {
72                    if let Err(e) = tx
73                        .send(Ok(FlightMessage::RecordBatch(
74                            recordbatch.into_df_record_batch(),
75                        )))
76                        .await
77                    {
78                        warn!(e; "stop sending Flight data");
79                        return;
80                    }
81                }
82                Err(e) => {
83                    let e = Err(e).context(error::CollectRecordbatchSnafu);
84                    if let Err(e) = tx.send(e.map_err(|x| x.into())).await {
85                        warn!(e; "stop sending Flight data");
86                    }
87                    return;
88                }
89            }
90        }
91        // make last package to pass metrics
92        if let Some(metrics_str) = recordbatches
93            .metrics()
94            .and_then(|m| serde_json::to_string(&m).ok())
95        {
96            let _ = tx.send(Ok(FlightMessage::Metrics(metrics_str))).await;
97        }
98    }
99}
100
101#[pinned_drop]
102impl PinnedDrop for FlightRecordBatchStream {
103    fn drop(self: Pin<&mut Self>) {
104        self.join_handle.abort();
105    }
106}
107
108impl Stream for FlightRecordBatchStream {
109    type Item = TonicResult<FlightData>;
110
111    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
112        let this = self.project();
113        if *this.done {
114            Poll::Ready(None)
115        } else {
116            match this.rx.poll_next(cx) {
117                Poll::Ready(None) => {
118                    *this.done = true;
119                    Poll::Ready(None)
120                }
121                Poll::Ready(Some(result)) => match result {
122                    Ok(flight_message) => {
123                        let flight_data = this.encoder.encode(flight_message);
124                        Poll::Ready(Some(Ok(flight_data)))
125                    }
126                    Err(e) => {
127                        *this.done = true;
128                        Poll::Ready(Some(Err(e)))
129                    }
130                },
131                Poll::Pending => Poll::Pending,
132            }
133        }
134    }
135}
136
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use common_grpc::flight::{FlightDecoder, FlightMessage};
    use common_recordbatch::{RecordBatch, RecordBatches};
    use datatypes::prelude::*;
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::Int32Vector;
    use futures::StreamExt;

    use super::*;

    /// A single-batch input must come out as exactly two Flight messages, the
    /// schema followed by the record batch, after which the stream is done.
    #[tokio::test]
    async fn test_flight_record_batch_stream() {
        let column_schema = ColumnSchema::new("a", ConcreteDataType::int32_datatype(), false);
        let schema = Arc::new(Schema::new(vec![column_schema]));

        let column: VectorRef = Arc::new(Int32Vector::from_slice([1, 2]));
        let recordbatch = RecordBatch::new(schema.clone(), vec![column]).unwrap();

        let recordbatches = RecordBatches::try_new(schema.clone(), vec![recordbatch.clone()])
            .unwrap()
            .as_stream();
        let mut stream = FlightRecordBatchStream::new(recordbatches, TracingContext::default());

        // Two payloads are expected before the stream terminates.
        let mut raw_data = Vec::with_capacity(2);
        for _ in 0..2 {
            raw_data.push(stream.next().await.unwrap().unwrap());
        }
        assert!(stream.next().await.is_none());
        assert!(stream.done);

        let mut decoder = FlightDecoder::default();
        let flight_messages = raw_data
            .iter()
            .map(|flight_data| decoder.try_decode(flight_data).unwrap())
            .collect::<Vec<FlightMessage>>();
        assert_eq!(flight_messages.len(), 2);

        let mut decoded = flight_messages.into_iter();

        match decoded.next() {
            Some(FlightMessage::Schema(actual_schema)) => {
                assert_eq!(&actual_schema, schema.arrow_schema());
            }
            _ => unreachable!(),
        }

        match decoded.next() {
            Some(FlightMessage::RecordBatch(actual_recordbatch)) => {
                assert_eq!(&actual_recordbatch, recordbatch.df_record_batch());
            }
            _ => unreachable!(),
        }
    }
}