servers/grpc/flight/
stream.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::pin::Pin;
16use std::task::{Context, Poll};
17
18use arrow_flight::FlightData;
19use common_grpc::flight::{FlightEncoder, FlightMessage};
20use common_recordbatch::SendableRecordBatchStream;
21use common_telemetry::tracing::{info_span, Instrument};
22use common_telemetry::tracing_context::{FutureExt, TracingContext};
23use common_telemetry::warn;
24use futures::channel::mpsc;
25use futures::channel::mpsc::Sender;
26use futures::{SinkExt, Stream, StreamExt};
27use pin_project::{pin_project, pinned_drop};
28use snafu::ResultExt;
29use tokio::task::JoinHandle;
30
31use crate::error;
32use crate::grpc::flight::TonicResult;
33
34#[pin_project(PinnedDrop)]
35pub struct FlightRecordBatchStream {
36    #[pin]
37    rx: mpsc::Receiver<Result<FlightMessage, tonic::Status>>,
38    join_handle: JoinHandle<()>,
39    done: bool,
40    encoder: FlightEncoder,
41}
42
43impl FlightRecordBatchStream {
44    pub fn new(recordbatches: SendableRecordBatchStream, tracing_context: TracingContext) -> Self {
45        let (tx, rx) = mpsc::channel::<TonicResult<FlightMessage>>(1);
46        let join_handle = common_runtime::spawn_global(async move {
47            Self::flight_data_stream(recordbatches, tx)
48                .trace(tracing_context.attach(info_span!("flight_data_stream")))
49                .await
50        });
51        Self {
52            rx,
53            join_handle,
54            done: false,
55            encoder: FlightEncoder::default(),
56        }
57    }
58
59    async fn flight_data_stream(
60        mut recordbatches: SendableRecordBatchStream,
61        mut tx: Sender<TonicResult<FlightMessage>>,
62    ) {
63        let schema = recordbatches.schema();
64        if let Err(e) = tx.send(Ok(FlightMessage::Schema(schema))).await {
65            warn!(e; "stop sending Flight data");
66            return;
67        }
68
69        while let Some(batch_or_err) = recordbatches.next().in_current_span().await {
70            match batch_or_err {
71                Ok(recordbatch) => {
72                    if let Err(e) = tx.send(Ok(FlightMessage::Recordbatch(recordbatch))).await {
73                        warn!(e; "stop sending Flight data");
74                        return;
75                    }
76                }
77                Err(e) => {
78                    let e = Err(e).context(error::CollectRecordbatchSnafu);
79                    if let Err(e) = tx.send(e.map_err(|x| x.into())).await {
80                        warn!(e; "stop sending Flight data");
81                    }
82                    return;
83                }
84            }
85        }
86        // make last package to pass metrics
87        if let Some(metrics_str) = recordbatches
88            .metrics()
89            .and_then(|m| serde_json::to_string(&m).ok())
90        {
91            let _ = tx.send(Ok(FlightMessage::Metrics(metrics_str))).await;
92        }
93    }
94}
95
96#[pinned_drop]
97impl PinnedDrop for FlightRecordBatchStream {
98    fn drop(self: Pin<&mut Self>) {
99        self.join_handle.abort();
100    }
101}
102
103impl Stream for FlightRecordBatchStream {
104    type Item = TonicResult<FlightData>;
105
106    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
107        let this = self.project();
108        if *this.done {
109            Poll::Ready(None)
110        } else {
111            match this.rx.poll_next(cx) {
112                Poll::Ready(None) => {
113                    *this.done = true;
114                    Poll::Ready(None)
115                }
116                Poll::Ready(Some(result)) => match result {
117                    Ok(flight_message) => {
118                        let flight_data = this.encoder.encode(flight_message);
119                        Poll::Ready(Some(Ok(flight_data)))
120                    }
121                    Err(e) => {
122                        *this.done = true;
123                        Poll::Ready(Some(Err(e)))
124                    }
125                },
126                Poll::Pending => Poll::Pending,
127            }
128        }
129    }
130}
131
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use common_grpc::flight::{FlightDecoder, FlightMessage};
    use common_recordbatch::{RecordBatch, RecordBatches};
    use datatypes::prelude::*;
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::Int32Vector;
    use futures::StreamExt;

    use super::*;

    /// A single-batch input yields exactly a schema message followed by the
    /// record batch. After that, the stream terminates and stays terminated.
    #[tokio::test]
    async fn test_flight_record_batch_stream() {
        let column = ColumnSchema::new("a", ConcreteDataType::int32_datatype(), false);
        let schema = Arc::new(Schema::new(vec![column]));

        let values: VectorRef = Arc::new(Int32Vector::from_slice([1, 2]));
        let expected_batch = RecordBatch::new(schema.clone(), vec![values]).unwrap();

        let input = RecordBatches::try_new(schema.clone(), vec![expected_batch.clone()])
            .unwrap()
            .as_stream();
        let mut stream = FlightRecordBatchStream::new(input, TracingContext::default());

        // Schema message + one record batch.
        let mut raw_data = Vec::with_capacity(2);
        for _ in 0..2 {
            raw_data.push(stream.next().await.unwrap().unwrap());
        }
        assert!(stream.next().await.is_none());
        assert!(stream.done);

        let mut decoder = FlightDecoder::default();
        let decoded: Vec<FlightMessage> = raw_data
            .iter()
            .map(|data| decoder.try_decode(data).unwrap())
            .collect();
        assert_eq!(decoded.len(), 2);

        let mut messages = decoded.into_iter();

        match messages.next() {
            Some(FlightMessage::Schema(actual_schema)) => assert_eq!(actual_schema, schema),
            _ => unreachable!(),
        }

        match messages.next() {
            Some(FlightMessage::Recordbatch(actual_batch)) => {
                assert_eq!(actual_batch, expected_batch)
            }
            _ => unreachable!(),
        }
    }
}