common_recordbatch/
cursor.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use futures::StreamExt;
16use tokio::sync::Mutex;
17
18use crate::error::Result;
19use crate::recordbatch::merge_record_batches;
20use crate::{RecordBatch, SendableRecordBatchStream};
21
22struct Inner {
23    stream: SendableRecordBatchStream,
24    current_row_index: usize,
25    current_batch: Option<RecordBatch>,
26    total_rows_in_current_batch: usize,
27}
28
29/// A cursor on RecordBatchStream that fetches data batch by batch
30pub struct RecordBatchStreamCursor {
31    inner: Mutex<Inner>,
32}
33
34impl RecordBatchStreamCursor {
35    pub fn new(stream: SendableRecordBatchStream) -> RecordBatchStreamCursor {
36        Self {
37            inner: Mutex::new(Inner {
38                stream,
39                current_row_index: 0,
40                current_batch: None,
41                total_rows_in_current_batch: 0,
42            }),
43        }
44    }
45
46    /// Take `size` of row from the `RecordBatchStream` and create a new
47    /// `RecordBatch` for these rows.
48    pub async fn take(&self, size: usize) -> Result<RecordBatch> {
49        let mut remaining_rows_to_take = size;
50        let mut accumulated_rows = Vec::new();
51
52        let mut inner = self.inner.lock().await;
53
54        while remaining_rows_to_take > 0 {
55            // Ensure we have a current batch or fetch the next one
56            if inner.current_batch.is_none()
57                || inner.current_row_index >= inner.total_rows_in_current_batch
58            {
59                match inner.stream.next().await {
60                    Some(Ok(batch)) => {
61                        inner.total_rows_in_current_batch = batch.num_rows();
62                        inner.current_batch = Some(batch);
63                        inner.current_row_index = 0;
64                    }
65                    Some(Err(e)) => return Err(e),
66                    None => {
67                        // Stream is exhausted
68                        break;
69                    }
70                }
71            }
72
73            // If we still have no batch after attempting to fetch
74            let current_batch = match &inner.current_batch {
75                Some(batch) => batch,
76                None => break,
77            };
78
79            // Calculate how many rows we can take from this batch
80            let rows_to_take_from_batch = remaining_rows_to_take
81                .min(inner.total_rows_in_current_batch - inner.current_row_index);
82
83            // Slice the current batch to get the desired rows
84            let taken_batch =
85                current_batch.slice(inner.current_row_index, rows_to_take_from_batch)?;
86
87            // Add the taken batch to accumulated rows
88            accumulated_rows.push(taken_batch);
89
90            // Update cursor and remaining rows
91            inner.current_row_index += rows_to_take_from_batch;
92            remaining_rows_to_take -= rows_to_take_from_batch;
93        }
94
95        // If no rows were accumulated, return empty
96        if accumulated_rows.is_empty() {
97            return Ok(RecordBatch::new_empty(inner.stream.schema()));
98        }
99
100        // If only one batch was accumulated, return it directly
101        if accumulated_rows.len() == 1 {
102            return Ok(accumulated_rows.remove(0));
103        }
104
105        // Merge multiple batches
106        merge_record_batches(inner.stream.schema(), &accumulated_rows)
107    }
108}
109
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::prelude::ConcreteDataType;
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::StringVector;

    use super::*;
    use crate::RecordBatches;

    /// Builds a two-row batch (`"hello"`, `"world"`) for the single-column `schema`.
    fn hello_world_batch(schema: &Arc<Schema>) -> RecordBatch {
        RecordBatch::new(
            schema.clone(),
            vec![Arc::new(StringVector::from(vec!["hello", "world"])) as _],
        )
        .unwrap()
    }

    /// Issues one `take(request)` per step, in order, and checks that each call
    /// returns exactly `expected` rows.
    async fn assert_take_sequence(cursor: &RecordBatchStreamCursor, steps: &[(usize, usize)]) {
        for &(request, expected) in steps {
            let batch = cursor.take(request).await.expect("take from cursor failed");
            assert_eq!(batch.num_rows(), expected);
        }
    }

    #[tokio::test]
    async fn test_cursor() {
        let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
            "a",
            ConcreteDataType::string_datatype(),
            false,
        )]));

        // One two-row batch: a row per call, then nothing left.
        let single = RecordBatches::try_from_columns(
            schema.clone(),
            vec![Arc::new(StringVector::from(vec!["hello", "world"])) as _],
        )
        .unwrap();
        let cursor = RecordBatchStreamCursor::new(single.as_stream());
        assert_take_sequence(&cursor, &[(1, 1), (1, 1), (1, 0)]).await;

        // Three two-row batches (six rows); requests straddle batch boundaries.
        let triple =
            RecordBatches::try_new(schema.clone(), vec![hello_world_batch(&schema); 3]).unwrap();
        let cursor = RecordBatchStreamCursor::new(triple.as_stream());
        assert_take_sequence(&cursor, &[(3, 3), (2, 2), (2, 1), (2, 0)]).await;

        // Requesting more rows than exist yields everything in one batch.
        let triple =
            RecordBatches::try_new(schema.clone(), vec![hello_world_batch(&schema); 3]).unwrap();
        let cursor = RecordBatchStreamCursor::new(triple.as_stream());
        assert_take_sequence(&cursor, &[(10, 6)]).await;
    }
}
173}