common_recordbatch/
cursor.rs1use futures::StreamExt;
16use tokio::sync::Mutex;
17
18use crate::error::Result;
19use crate::recordbatch::merge_record_batches;
20use crate::{RecordBatch, SendableRecordBatchStream};
21
22struct Inner {
23 stream: SendableRecordBatchStream,
24 current_row_index: usize,
25 current_batch: Option<RecordBatch>,
26 total_rows_in_current_batch: usize,
27}
28
29pub struct RecordBatchStreamCursor {
31 inner: Mutex<Inner>,
32}
33
34impl RecordBatchStreamCursor {
35 pub fn new(stream: SendableRecordBatchStream) -> RecordBatchStreamCursor {
36 Self {
37 inner: Mutex::new(Inner {
38 stream,
39 current_row_index: 0,
40 current_batch: None,
41 total_rows_in_current_batch: 0,
42 }),
43 }
44 }
45
46 pub async fn take(&self, size: usize) -> Result<RecordBatch> {
49 let mut remaining_rows_to_take = size;
50 let mut accumulated_rows = Vec::new();
51
52 let mut inner = self.inner.lock().await;
53
54 while remaining_rows_to_take > 0 {
55 if inner.current_batch.is_none()
57 || inner.current_row_index >= inner.total_rows_in_current_batch
58 {
59 match inner.stream.next().await {
60 Some(Ok(batch)) => {
61 inner.total_rows_in_current_batch = batch.num_rows();
62 inner.current_batch = Some(batch);
63 inner.current_row_index = 0;
64 }
65 Some(Err(e)) => return Err(e),
66 None => {
67 break;
69 }
70 }
71 }
72
73 let current_batch = match &inner.current_batch {
75 Some(batch) => batch,
76 None => break,
77 };
78
79 let rows_to_take_from_batch = remaining_rows_to_take
81 .min(inner.total_rows_in_current_batch - inner.current_row_index);
82
83 let taken_batch =
85 current_batch.slice(inner.current_row_index, rows_to_take_from_batch)?;
86
87 accumulated_rows.push(taken_batch);
89
90 inner.current_row_index += rows_to_take_from_batch;
92 remaining_rows_to_take -= rows_to_take_from_batch;
93 }
94
95 if accumulated_rows.is_empty() {
97 return Ok(RecordBatch::new_empty(inner.stream.schema()));
98 }
99
100 if accumulated_rows.len() == 1 {
102 return Ok(accumulated_rows.remove(0));
103 }
104
105 merge_record_batches(inner.stream.schema(), &accumulated_rows)
107 }
108}
109
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::prelude::ConcreteDataType;
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::StringVector;

    use super::*;
    use crate::RecordBatches;

    /// Runs each `(size, expected_rows)` step against `cursor` in order and
    /// checks the number of rows returned by every `take`.
    async fn assert_takes(cursor: &RecordBatchStreamCursor, steps: &[(usize, usize)]) {
        for &(size, expected_rows) in steps {
            let result_rb = cursor.take(size).await.expect("take from cursor failed");
            assert_eq!(result_rb.num_rows(), expected_rows);
        }
    }

    #[tokio::test]
    async fn test_cursor() {
        let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
            "a",
            ConcreteDataType::string_datatype(),
            false,
        )]));

        // Builds a collection of three identical two-row batches (six rows).
        let three_batches = || {
            let rb = RecordBatch::new(
                schema.clone(),
                vec![Arc::new(StringVector::from(vec!["hello", "world"])) as _],
            )
            .unwrap();
            RecordBatches::try_new(schema.clone(), vec![rb.clone(), rb.clone(), rb]).unwrap()
        };

        // One two-row batch, consumed a row at a time until nothing is left.
        let rbs = RecordBatches::try_from_columns(
            schema.clone(),
            vec![Arc::new(StringVector::from(vec!["hello", "world"])) as _],
        )
        .unwrap();
        let cursor = RecordBatchStreamCursor::new(rbs.as_stream());
        assert_takes(&cursor, &[(1, 1), (1, 1), (1, 0)]).await;

        // Takes that straddle batch boundaries, then run past the end.
        let rbs2 = three_batches();
        let cursor = RecordBatchStreamCursor::new(rbs2.as_stream());
        assert_takes(&cursor, &[(3, 3), (2, 2), (2, 1), (2, 0)]).await;

        // Asking for more rows than exist returns everything that is there.
        let rbs3 = three_batches();
        let cursor = RecordBatchStreamCursor::new(rbs3.as_stream());
        assert_takes(&cursor, &[(10, 6)]).await;
    }
}