common_recordbatch/
util.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::pin::Pin;
16use std::sync::Arc;
17use std::task::{Context, Poll};
18
19use arc_swap::ArcSwapOption;
20use datatypes::schema::SchemaRef;
21use futures::{Stream, StreamExt, TryStreamExt};
22use snafu::ensure;
23
24use crate::adapter::RecordBatchMetrics;
25use crate::error::{EmptyStreamSnafu, Result, SchemaNotMatchSnafu};
26use crate::{
27    OrderOption, RecordBatch, RecordBatchStream, RecordBatches, SendableRecordBatchStream,
28};
29
30/// Collect all the items from the stream into a vector of [`RecordBatch`].
31pub async fn collect(stream: SendableRecordBatchStream) -> Result<Vec<RecordBatch>> {
32    stream.try_collect::<Vec<_>>().await
33}
34
35/// Collect all the items from the stream into [RecordBatches].
36pub async fn collect_batches(stream: SendableRecordBatchStream) -> Result<RecordBatches> {
37    let schema = stream.schema();
38    let batches = stream.try_collect::<Vec<_>>().await?;
39    RecordBatches::try_new(schema, batches)
40}
41
42/// A stream that chains multiple streams into a single stream.
43pub struct ChainedRecordBatchStream {
44    inputs: Vec<SendableRecordBatchStream>,
45    curr_index: usize,
46    schema: SchemaRef,
47    metrics: Arc<ArcSwapOption<RecordBatchMetrics>>,
48}
49
50impl ChainedRecordBatchStream {
51    pub fn new(inputs: Vec<SendableRecordBatchStream>) -> Result<Self> {
52        // check length
53        ensure!(!inputs.is_empty(), EmptyStreamSnafu);
54
55        // check schema
56        let first_schema = inputs[0].schema();
57        for input in inputs.iter().skip(1) {
58            let schema = input.schema();
59            ensure!(
60                first_schema == schema,
61                SchemaNotMatchSnafu {
62                    left: first_schema,
63                    right: schema
64                }
65            );
66        }
67
68        Ok(Self {
69            inputs,
70            curr_index: 0,
71            schema: first_schema,
72            metrics: Default::default(),
73        })
74    }
75
76    fn sequence_poll(
77        mut self: Pin<&mut Self>,
78        ctx: &mut Context<'_>,
79    ) -> Poll<Option<Result<RecordBatch>>> {
80        if self.curr_index >= self.inputs.len() {
81            return Poll::Ready(None);
82        }
83
84        let curr_index = self.curr_index;
85        match self.inputs[curr_index].poll_next_unpin(ctx) {
86            Poll::Ready(Some(Ok(batch))) => Poll::Ready(Some(Ok(batch))),
87            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
88            Poll::Ready(None) => {
89                self.curr_index += 1;
90                if self.curr_index < self.inputs.len() {
91                    self.sequence_poll(ctx)
92                } else {
93                    Poll::Ready(None)
94                }
95            }
96            Poll::Pending => Poll::Pending,
97        }
98    }
99}
100
101impl RecordBatchStream for ChainedRecordBatchStream {
102    fn name(&self) -> &str {
103        "ChainedRecordBatchStream"
104    }
105
106    fn schema(&self) -> SchemaRef {
107        self.schema.clone()
108    }
109
110    fn output_ordering(&self) -> Option<&[OrderOption]> {
111        None
112    }
113
114    fn metrics(&self) -> Option<RecordBatchMetrics> {
115        self.metrics.load().as_ref().map(|m| m.as_ref().clone())
116    }
117}
118
119impl Stream for ChainedRecordBatchStream {
120    type Item = Result<RecordBatch>;
121
122    fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
123        self.sequence_poll(ctx)
124    }
125}
126
#[cfg(test)]
mod tests {
    use std::pin::Pin;
    use std::sync::Arc;

    use datatypes::prelude::*;
    use datatypes::schema::{ColumnSchema, Schema, SchemaRef};
    use datatypes::vectors::UInt32Vector;
    use futures::task::{Context, Poll};
    use futures::Stream;

    use super::*;
    use crate::adapter::RecordBatchMetrics;
    use crate::{OrderOption, RecordBatchStream};

    /// Test stream that yields at most one batch and then ends.
    struct MockRecordBatchStream {
        batch: Option<RecordBatch>,
        schema: SchemaRef,
    }

    impl MockRecordBatchStream {
        /// Boxes a mock holding `batch` so it can be handed to the collectors.
        fn boxed(schema: &SchemaRef, batch: Option<RecordBatch>) -> SendableRecordBatchStream {
            Box::pin(Self {
                batch,
                schema: schema.clone(),
            })
        }
    }

    impl RecordBatchStream for MockRecordBatchStream {
        fn schema(&self) -> SchemaRef {
            Arc::clone(&self.schema)
        }

        fn output_ordering(&self) -> Option<&[OrderOption]> {
            None
        }

        fn metrics(&self) -> Option<RecordBatchMetrics> {
            None
        }
    }

    impl Stream for MockRecordBatchStream {
        type Item = Result<RecordBatch>;

        fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            // `take` leaves `None` behind, so the poll after the batch reports end of stream.
            Poll::Ready(self.batch.take().map(Ok))
        }
    }

    #[tokio::test]
    async fn test_collect() {
        let schema: SchemaRef = Arc::new(
            Schema::try_new(vec![ColumnSchema::new(
                "number",
                ConcreteDataType::uint32_datatype(),
                false,
            )])
            .unwrap(),
        );

        // A stream without any batch collects into an empty vector.
        let stream = MockRecordBatchStream::boxed(&schema, None);
        let collected = collect(stream).await.unwrap();
        assert_eq!(0, collected.len());

        let numbers: Vec<u32> = (0..10).collect();
        let columns = [Arc::new(UInt32Vector::from_vec(numbers)) as _];
        let batch = RecordBatch::new(schema.clone(), columns).unwrap();

        // A single batch comes back unchanged from `collect`.
        let stream = MockRecordBatchStream::boxed(&schema, Some(batch.clone()));
        let collected = collect(stream).await.unwrap();
        assert_eq!(1, collected.len());
        assert_eq!(batch, collected[0]);

        // `collect_batches` returns the same batch wrapped with the stream's schema.
        let stream = MockRecordBatchStream::boxed(&schema, Some(batch.clone()));
        let collected = collect_batches(stream).await.unwrap();
        let expected = RecordBatches::try_new(schema.clone(), vec![batch]).unwrap();
        assert_eq!(expected, collected);
    }
}