// common_recordbatch/cursor.rs
use futures::StreamExt;
use tokio::sync::Mutex;
use crate::error::Result;
use crate::recordbatch::merge_record_batches;
use crate::{RecordBatch, SendableRecordBatchStream};
/// Mutable state of a [`RecordBatchStreamCursor`], kept behind its mutex.
struct Inner {
    /// Source of record batches; polled only when the current batch is used up.
    stream: SendableRecordBatchStream,
    /// Index of the next row in `current_batch` that has not been handed out yet.
    current_row_index: usize,
    /// The batch rows are currently being sliced from; `None` before the first
    /// fetch and after the stream has ended.
    current_batch: Option<RecordBatch>,
    /// Cached `num_rows()` of `current_batch`.
    total_rows_in_current_batch: usize,
    /// Set once `stream` has yielded `None`, so that it is never polled again.
    exhausted: bool,
}

/// A cursor that hands out rows from a record batch stream in caller-sized
/// chunks, remembering its position between calls.
pub struct RecordBatchStreamCursor {
    inner: Mutex<Inner>,
}

impl RecordBatchStreamCursor {
    /// Creates a cursor positioned before the first row of `stream`.
    ///
    /// The stream is not polled until the first call to [`Self::take`].
    pub fn new(stream: SendableRecordBatchStream) -> RecordBatchStreamCursor {
        Self {
            inner: Mutex::new(Inner {
                stream,
                current_row_index: 0,
                current_batch: None,
                total_rows_in_current_batch: 0,
                exhausted: false,
            }),
        }
    }

    /// Takes up to `size` rows from the cursor and returns them as one batch.
    ///
    /// Rows are sliced from the current batch first; further batches are pulled
    /// from the stream as needed. If the stream ends before `size` rows were
    /// collected, the rows gathered so far are returned, which yields an empty
    /// batch with the stream's schema when nothing is left (or `size` is 0).
    ///
    /// The internal lock is held for the whole call, including while awaiting
    /// the stream, so concurrent calls are served one after another.
    ///
    /// # Errors
    ///
    /// Returns an error if the stream yields an error, or if slicing or merging
    /// the collected batches fails. Rows collected earlier in the same call are
    /// not returned in that case.
    pub async fn take(&self, size: usize) -> Result<RecordBatch> {
        let mut remaining_rows_to_take = size;
        let mut accumulated_rows = Vec::new();

        let mut inner = self.inner.lock().await;

        while remaining_rows_to_take > 0 {
            // Fetch a new batch when there is none or the current one is used up.
            if inner.current_batch.is_none()
                || inner.current_row_index >= inner.total_rows_in_current_batch
            {
                // A finished stream must not be polled again: the `Stream`
                // contract leaves the result of doing so unspecified.
                if inner.exhausted {
                    break;
                }

                match inner.stream.next().await {
                    Some(Ok(batch)) => {
                        inner.total_rows_in_current_batch = batch.num_rows();
                        inner.current_batch = Some(batch);
                        inner.current_row_index = 0;
                    }
                    Some(Err(e)) => return Err(e),
                    None => {
                        inner.exhausted = true;
                        // The last batch is fully consumed; release it.
                        inner.current_batch = None;
                        break;
                    }
                }
            }

            let current_batch = match &inner.current_batch {
                Some(batch) => batch,
                None => break,
            };

            // Never read past the end of the current batch.
            let rows_to_take_from_batch = remaining_rows_to_take
                .min(inner.total_rows_in_current_batch - inner.current_row_index);

            let taken_batch =
                current_batch.slice(inner.current_row_index, rows_to_take_from_batch)?;
            accumulated_rows.push(taken_batch);

            inner.current_row_index += rows_to_take_from_batch;
            remaining_rows_to_take -= rows_to_take_from_batch;
        }

        match accumulated_rows.len() {
            0 => Ok(RecordBatch::new_empty(inner.stream.schema())),
            // A single slice can be returned as-is without merging.
            1 => Ok(accumulated_rows.swap_remove(0)),
            _ => merge_record_batches(inner.stream.schema(), &accumulated_rows),
        }
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::prelude::ConcreteDataType;
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::StringVector;

    use super::*;
    use crate::RecordBatches;

    /// Builds a batch with a single string column holding "hello" and "world".
    fn two_row_batch(schema: Arc<Schema>) -> RecordBatch {
        RecordBatch::new(
            schema,
            vec![Arc::new(StringVector::from(vec!["hello", "world"])) as _],
        )
        .unwrap()
    }

    /// Runs `take(requested)` for every `(requested, expected)` pair, in order,
    /// and checks that the returned batch has `expected` rows.
    async fn assert_takes(cursor: &RecordBatchStreamCursor, steps: &[(usize, usize)]) {
        for &(requested, expected) in steps {
            let taken = cursor
                .take(requested)
                .await
                .expect("take from cursor failed");
            assert_eq!(taken.num_rows(), expected);
        }
    }

    #[tokio::test]
    async fn test_cursor() {
        let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
            "a",
            ConcreteDataType::string_datatype(),
            false,
        )]));

        // One two-row batch, read one row at a time until nothing is left.
        let single = RecordBatches::try_from_columns(
            schema.clone(),
            vec![Arc::new(StringVector::from(vec!["hello", "world"])) as _],
        )
        .unwrap();
        let cursor = RecordBatchStreamCursor::new(single.as_stream());
        assert_takes(&cursor, &[(1, 1), (1, 1), (1, 0)]).await;

        // Three two-row batches, read with sizes that straddle batch boundaries.
        let straddling =
            RecordBatches::try_new(schema.clone(), vec![two_row_batch(schema.clone()); 3])
                .unwrap();
        let cursor = RecordBatchStreamCursor::new(straddling.as_stream());
        assert_takes(&cursor, &[(3, 3), (2, 2), (2, 1), (2, 0)]).await;

        // Three two-row batches, with a request larger than the whole stream.
        let oversized =
            RecordBatches::try_new(schema.clone(), vec![two_row_batch(schema.clone()); 3])
                .unwrap();
        let cursor = RecordBatchStreamCursor::new(oversized.as_stream());
        assert_takes(&cursor, &[(10, 6)]).await;
    }
}