common_recordbatch/
util.rs1use std::pin::Pin;
16use std::sync::Arc;
17use std::task::{Context, Poll};
18
19use arc_swap::ArcSwapOption;
20use datatypes::schema::SchemaRef;
21use futures::{Stream, StreamExt, TryStreamExt};
22use snafu::ensure;
23
24use crate::adapter::RecordBatchMetrics;
25use crate::error::{EmptyStreamSnafu, Result, SchemaNotMatchSnafu};
26use crate::{
27 OrderOption, RecordBatch, RecordBatchStream, RecordBatches, SendableRecordBatchStream,
28};
29
30pub async fn collect(stream: SendableRecordBatchStream) -> Result<Vec<RecordBatch>> {
32 stream.try_collect::<Vec<_>>().await
33}
34
35pub async fn collect_batches(stream: SendableRecordBatchStream) -> Result<RecordBatches> {
37 let schema = stream.schema();
38 let batches = stream.try_collect::<Vec<_>>().await?;
39 RecordBatches::try_new(schema, batches)
40}
41
42pub struct ChainedRecordBatchStream {
44 inputs: Vec<SendableRecordBatchStream>,
45 curr_index: usize,
46 schema: SchemaRef,
47 metrics: Arc<ArcSwapOption<RecordBatchMetrics>>,
48}
49
50impl ChainedRecordBatchStream {
51 pub fn new(inputs: Vec<SendableRecordBatchStream>) -> Result<Self> {
52 ensure!(!inputs.is_empty(), EmptyStreamSnafu);
54
55 let first_schema = inputs[0].schema();
57 for input in inputs.iter().skip(1) {
58 let schema = input.schema();
59 ensure!(
60 first_schema == schema,
61 SchemaNotMatchSnafu {
62 left: first_schema,
63 right: schema
64 }
65 );
66 }
67
68 Ok(Self {
69 inputs,
70 curr_index: 0,
71 schema: first_schema,
72 metrics: Default::default(),
73 })
74 }
75
76 fn sequence_poll(
77 mut self: Pin<&mut Self>,
78 ctx: &mut Context<'_>,
79 ) -> Poll<Option<Result<RecordBatch>>> {
80 if self.curr_index >= self.inputs.len() {
81 return Poll::Ready(None);
82 }
83
84 let curr_index = self.curr_index;
85 match self.inputs[curr_index].poll_next_unpin(ctx) {
86 Poll::Ready(Some(Ok(batch))) => Poll::Ready(Some(Ok(batch))),
87 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
88 Poll::Ready(None) => {
89 self.curr_index += 1;
90 if self.curr_index < self.inputs.len() {
91 self.sequence_poll(ctx)
92 } else {
93 Poll::Ready(None)
94 }
95 }
96 Poll::Pending => Poll::Pending,
97 }
98 }
99}
100
101impl RecordBatchStream for ChainedRecordBatchStream {
102 fn name(&self) -> &str {
103 "ChainedRecordBatchStream"
104 }
105
106 fn schema(&self) -> SchemaRef {
107 self.schema.clone()
108 }
109
110 fn output_ordering(&self) -> Option<&[OrderOption]> {
111 None
112 }
113
114 fn metrics(&self) -> Option<RecordBatchMetrics> {
115 self.metrics.load().as_ref().map(|m| m.as_ref().clone())
116 }
117}
118
119impl Stream for ChainedRecordBatchStream {
120 type Item = Result<RecordBatch>;
121
122 fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
123 self.sequence_poll(ctx)
124 }
125}
126
#[cfg(test)]
mod tests {
    use std::pin::Pin;
    use std::sync::Arc;

    use datatypes::prelude::*;
    use datatypes::schema::{ColumnSchema, Schema, SchemaRef};
    use datatypes::vectors::UInt32Vector;
    use futures::task::{Context, Poll};
    use futures::Stream;

    use super::*;
    use crate::adapter::RecordBatchMetrics;
    use crate::{OrderOption, RecordBatchStream};

    /// Test stream that yields its single optional batch once and then ends.
    struct MockRecordBatchStream {
        batch: Option<RecordBatch>,
        schema: SchemaRef,
    }

    impl MockRecordBatchStream {
        /// Builds a boxed mock stream over `schema` holding `batch`.
        fn boxed(schema: &SchemaRef, batch: Option<RecordBatch>) -> SendableRecordBatchStream {
            Box::pin(Self {
                batch,
                schema: schema.clone(),
            })
        }
    }

    impl RecordBatchStream for MockRecordBatchStream {
        fn schema(&self) -> SchemaRef {
            self.schema.clone()
        }

        fn output_ordering(&self) -> Option<&[OrderOption]> {
            None
        }

        fn metrics(&self) -> Option<RecordBatchMetrics> {
            None
        }
    }

    impl Stream for MockRecordBatchStream {
        type Item = Result<RecordBatch>;

        fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            // `take` leaves `None` behind, so the second poll ends the stream.
            Poll::Ready(self.batch.take().map(Ok))
        }
    }

    #[tokio::test]
    async fn test_collect() {
        let column_schemas = vec![ColumnSchema::new(
            "number",
            ConcreteDataType::uint32_datatype(),
            false,
        )];
        let schema = Arc::new(Schema::try_new(column_schemas).unwrap());

        // A stream without any batch collects to an empty vector.
        let collected = collect(MockRecordBatchStream::boxed(&schema, None))
            .await
            .unwrap();
        assert_eq!(0, collected.len());

        let numbers: Vec<u32> = (0..10).collect();
        let columns = [Arc::new(UInt32Vector::from_vec(numbers)) as _];
        let batch = RecordBatch::new(schema.clone(), columns).unwrap();

        // A single-batch stream collects to exactly that batch.
        let collected = collect(MockRecordBatchStream::boxed(&schema, Some(batch.clone())))
            .await
            .unwrap();
        assert_eq!(1, collected.len());
        assert_eq!(batch, collected[0]);

        // `collect_batches` wraps the same batch together with the schema.
        let actual = collect_batches(MockRecordBatchStream::boxed(&schema, Some(batch.clone())))
            .await
            .unwrap();
        let expect_batches = RecordBatches::try_new(schema.clone(), vec![batch]).unwrap();
        assert_eq!(expect_batches, actual);
    }
}