use std::collections::HashMap;
use std::slice;
use std::sync::Arc;

use datafusion::arrow::util::pretty::pretty_format_batches;
use datatypes::arrow::array::RecordBatchOptions;
use datatypes::prelude::DataType;
use datatypes::schema::SchemaRef;
use datatypes::value::Value;
use datatypes::vectors::{Helper, VectorRef};
use serde::ser::{Error, SerializeStruct};
use serde::{Serialize, Serializer};
use snafu::{ensure, OptionExt, ResultExt};

use crate::error::{
    self, CastVectorSnafu, ColumnNotExistsSnafu, DataTypesSnafu, ProjectArrowRecordBatchSnafu,
    Result,
};
use crate::DfRecordBatch;

/// A two-dimensional batch of column-oriented data with a defined schema.
///
/// The columns are kept both as [`VectorRef`]s and as the equivalent Arrow
/// [`DfRecordBatch`], so either representation is available without conversion.
#[derive(Clone, Debug, PartialEq)]
pub struct RecordBatch {
    pub schema: SchemaRef,
    pub columns: Vec<VectorRef>,
    df_record_batch: DfRecordBatch,
}

impl RecordBatch {
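    /// Creates a new [`RecordBatch`] from `schema` and `columns`, also building
    /// the equivalent Arrow [`DfRecordBatch`]. Fails if the columns don't match
    /// the schema.
    ///
    /// A minimal usage sketch, mirroring the tests at the bottom of this file
    /// (kept out of doctests on purpose):
    ///
    /// ```ignore
    /// let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
    ///     "numbers",
    ///     ConcreteDataType::uint32_datatype(),
    ///     false,
    /// )]));
    /// let columns: Vec<VectorRef> = vec![Arc::new(UInt32Vector::from_slice([1, 2, 3]))];
    /// let batch = RecordBatch::new(schema, columns).unwrap();
    /// assert_eq!(3, batch.num_rows());
    /// ```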
    pub fn new<I: IntoIterator<Item = VectorRef>>(
        schema: SchemaRef,
        columns: I,
    ) -> Result<RecordBatch> {
        let columns: Vec<_> = columns.into_iter().collect();
        let arrow_arrays = columns.iter().map(|v| v.to_arrow_array()).collect();

        let df_record_batch = DfRecordBatch::try_new(schema.arrow_schema().clone(), arrow_arrays)
            .context(error::NewDfRecordBatchSnafu)?;

        Ok(RecordBatch {
            schema,
            columns,
            df_record_batch,
        })
    }

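    /// Creates an empty (zero-row) [`RecordBatch`] from `schema`, with one
    /// empty vector per column.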
    pub fn new_empty(schema: SchemaRef) -> RecordBatch {
        let df_record_batch = DfRecordBatch::new_empty(schema.arrow_schema().clone());
        let columns = schema
            .column_schemas()
            .iter()
            .map(|col| col.data_type.create_mutable_vector(0).to_vector())
            .collect();
        RecordBatch {
            schema,
            columns,
            df_record_batch,
        }
    }

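    /// Creates a [`RecordBatch`] that carries only a row count and no column
    /// data. Since the Arrow batch is created without arrays, `schema` is
    /// expected to declare no columns.
    ///
    /// A minimal sketch (`empty_schema` is a stand-in for a schema with no
    /// columns; not a doctest):
    ///
    /// ```ignore
    /// let batch = RecordBatch::new_with_count(empty_schema, 42).unwrap();
    /// assert_eq!(42, batch.num_rows());
    /// assert_eq!(0, batch.num_columns());
    /// ```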
    pub fn new_with_count(schema: SchemaRef, num_rows: usize) -> Result<Self> {
        let df_record_batch = DfRecordBatch::try_new_with_options(
            schema.arrow_schema().clone(),
            vec![],
            &RecordBatchOptions::new().with_row_count(Some(num_rows)),
        )
        .context(error::NewDfRecordBatchSnafu)?;
        Ok(RecordBatch {
            schema,
            columns: vec![],
            df_record_batch,
        })
    }

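    /// Projects this [`RecordBatch`] onto the columns selected by `indices`,
    /// returning a new batch whose schema and columns follow the order of
    /// `indices`.
    ///
    /// A minimal sketch, keeping only the second column of a two-column batch
    /// (`batch` is any [`RecordBatch`]; not a doctest):
    ///
    /// ```ignore
    /// let projected = batch.try_project(&[1]).unwrap();
    /// assert_eq!(1, projected.num_columns());
    /// ```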
    pub fn try_project(&self, indices: &[usize]) -> Result<Self> {
        let schema = Arc::new(self.schema.try_project(indices).context(DataTypesSnafu)?);
        let mut columns = Vec::with_capacity(indices.len());
        for index in indices {
            columns.push(self.columns[*index].clone());
        }
        let df_record_batch = self.df_record_batch.project(indices).with_context(|_| {
            ProjectArrowRecordBatchSnafu {
                schema: self.schema.clone(),
                projection: indices.to_vec(),
            }
        })?;

        Ok(Self {
            schema,
            columns,
            df_record_batch,
        })
    }

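    /// Creates a [`RecordBatch`] from an existing Arrow [`DfRecordBatch`] by
    /// converting each Arrow array into a [`VectorRef`]. Note that `schema` is
    /// taken as-is and is not checked against the batch.
    ///
    /// A round-trip sketch, mirroring `test_record_batch` below (not a doctest):
    ///
    /// ```ignore
    /// let converted =
    ///     RecordBatch::try_from_df_record_batch(schema, batch.df_record_batch().clone()).unwrap();
    /// assert_eq!(batch, converted);
    /// ```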
    pub fn try_from_df_record_batch(
        schema: SchemaRef,
        df_record_batch: DfRecordBatch,
    ) -> Result<RecordBatch> {
        let columns = df_record_batch
            .columns()
            .iter()
            .map(|c| Helper::try_into_vector(c.clone()).context(error::DataTypesSnafu))
            .collect::<Result<Vec<_>>>()?;

        Ok(RecordBatch {
            schema,
            columns,
            df_record_batch,
        })
    }

    /// Returns the underlying Arrow [`DfRecordBatch`].
    #[inline]
    pub fn df_record_batch(&self) -> &DfRecordBatch {
        &self.df_record_batch
    }

    /// Consumes `self` and returns the underlying Arrow [`DfRecordBatch`].
    #[inline]
    pub fn into_df_record_batch(self) -> DfRecordBatch {
        self.df_record_batch
    }

    /// Returns all columns as a slice of [`VectorRef`]s.
    #[inline]
    pub fn columns(&self) -> &[VectorRef] {
        &self.columns
    }

    /// Returns the column at `idx`. Panics if `idx` is out of bounds.
    #[inline]
    pub fn column(&self, idx: usize) -> &VectorRef {
        &self.columns[idx]
    }

    /// Returns the column with the given `name`, or `None` if it doesn't exist.
    pub fn column_by_name(&self, name: &str) -> Option<&VectorRef> {
        let idx = self.schema.column_index_by_name(name)?;
        Some(&self.columns[idx])
    }

    /// Returns the number of columns.
    #[inline]
    pub fn num_columns(&self) -> usize {
        self.columns.len()
    }

    /// Returns the number of rows.
    #[inline]
    pub fn num_rows(&self) -> usize {
        self.df_record_batch.num_rows()
    }

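    /// Creates an iterator that traverses the batch row by row, yielding each
    /// row as a `Vec<Value>` with one [`Value`] per column.
    ///
    /// A minimal sketch (not a doctest):
    ///
    /// ```ignore
    /// for row in batch.rows() {
    ///     assert_eq!(batch.num_columns(), row.len());
    /// }
    /// ```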
    pub fn rows(&self) -> RecordBatchRowIterator<'_> {
        RecordBatchRowIterator::new(self)
    }

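    /// Returns this batch's columns keyed by column name, casting each column
    /// to the data type that `table_schema` declares for it whenever the types
    /// differ. Fails if a column doesn't exist in `table_schema` or a cast
    /// fails; `table_name` only appears in error messages.
    ///
    /// A minimal sketch (`table_schema` is the destination table's schema; not
    /// a doctest):
    ///
    /// ```ignore
    /// let by_name = batch.column_vectors("my_table", table_schema)?;
    /// let numbers = &by_name["numbers"];
    /// ```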
    pub fn column_vectors(
        &self,
        table_name: &str,
        table_schema: SchemaRef,
    ) -> Result<HashMap<String, VectorRef>> {
        let mut vectors = HashMap::with_capacity(self.num_columns());

        for (vector_schema, vector) in self.schema.column_schemas().iter().zip(self.columns.iter())
        {
            let column_name = &vector_schema.name;
            let column_schema =
                table_schema
                    .column_schema_by_name(column_name)
                    .context(ColumnNotExistsSnafu {
                        table_name,
                        column_name,
                    })?;
            let vector = if vector_schema.data_type != column_schema.data_type {
                vector
                    .cast(&column_schema.data_type)
                    .with_context(|_| CastVectorSnafu {
                        from_type: vector.data_type(),
                        to_type: column_schema.data_type.clone(),
                    })?
            } else {
                vector.clone()
            };

            let _ = vectors.insert(column_name.clone(), vector);
        }

        Ok(vectors)
    }

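    /// Renders this record batch as a human-readable ASCII table, falling back
    /// to an error message string if formatting fails.
    ///
    /// The output looks roughly like the following sketch (not a doctest):
    ///
    /// ```ignore
    /// // +---------+---------+
    /// // | numbers | strings |
    /// // +---------+---------+
    /// // | 1       |         |
    /// // | 2       | hello   |
    /// // +---------+---------+
    /// println!("{}", batch.pretty_print());
    /// ```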
    pub fn pretty_print(&self) -> String {
        pretty_format_batches(slice::from_ref(&self.df_record_batch))
            .map(|t| t.to_string())
            .unwrap_or_else(|_| "failed to pretty display a record batch".to_string())
    }

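    /// Returns a new [`RecordBatch`] with `len` rows starting at `offset`, or
    /// an error if `offset + len` exceeds the total number of rows.
    ///
    /// A minimal sketch, mirroring `test_record_batch_slice` below (`batch` has
    /// four rows; not a doctest):
    ///
    /// ```ignore
    /// let sliced = batch.slice(1, 2).unwrap(); // rows 1 and 2
    /// assert_eq!(2, sliced.num_rows());
    /// assert!(batch.slice(1, 5).is_err()); // 1 + 5 > 4 rows
    /// ```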
    pub fn slice(&self, offset: usize, len: usize) -> Result<RecordBatch> {
        ensure!(
            offset + len <= self.num_rows(),
            error::RecordBatchSliceIndexOverflowSnafu {
                size: self.num_rows(),
                visit_index: offset + len
            }
        );
        let columns = self.columns.iter().map(|vector| vector.slice(offset, len));
        RecordBatch::new(self.schema.clone(), columns)
    }
}

impl Serialize for RecordBatch {
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // Serialize the underlying Arrow schema directly, then each column as
        // its JSON representation.
        let mut s = serializer.serialize_struct("record", 2)?;
        s.serialize_field("schema", &**self.schema.arrow_schema())?;

        let vec = self
            .columns
            .iter()
            .map(|c| c.serialize_to_json())
            .collect::<std::result::Result<Vec<_>, _>>()
            .map_err(S::Error::custom)?;

        s.serialize_field("columns", &vec)?;
        s.end()
    }
}

/// An iterator that traverses a [`RecordBatch`] row by row, yielding each row
/// as a `Vec<Value>`.
pub struct RecordBatchRowIterator<'a> {
    record_batch: &'a RecordBatch,
    rows: usize,
    columns: usize,
    row_cursor: usize,
}

impl<'a> RecordBatchRowIterator<'a> {
    fn new(record_batch: &'a RecordBatch) -> RecordBatchRowIterator<'a> {
        RecordBatchRowIterator {
            record_batch,
            rows: record_batch.df_record_batch.num_rows(),
            columns: record_batch.df_record_batch.num_columns(),
            row_cursor: 0,
        }
    }
}

impl Iterator for RecordBatchRowIterator<'_> {
    type Item = Vec<Value>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.row_cursor == self.rows {
            None
        } else {
            let mut row = Vec::with_capacity(self.columns);

            for col in 0..self.columns {
                let column = self.record_batch.column(col);
                row.push(column.get(self.row_cursor));
            }

            self.row_cursor += 1;
            Some(row)
        }
    }
}

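/// Merges `batches` into a single [`RecordBatch`] by concatenating the batches'
/// columns in order. Every batch is assumed to match `schema`; an empty slice
/// yields an empty batch.
///
/// A minimal sketch, mirroring `test_merge_record_batch` below (`first` and
/// `second` are batches sharing `schema`; not a doctest):
///
/// ```ignore
/// let total = first.num_rows() + second.num_rows();
/// let merged = merge_record_batches(schema, &[first, second]).unwrap();
/// assert_eq!(total, merged.num_rows());
/// ```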
pub fn merge_record_batches(schema: SchemaRef, batches: &[RecordBatch]) -> Result<RecordBatch> {
    if batches.is_empty() {
        return Ok(RecordBatch::new_empty(schema));
    }

    let n_rows = batches.iter().map(|b| b.num_rows()).sum();
    let n_columns = schema.num_columns();
    let mut merged_columns = Vec::with_capacity(n_columns);

    for col_idx in 0..n_columns {
        // Accumulate the col_idx-th column of every batch into one vector,
        // pre-sized to the total row count.
        let mut acc = schema.column_schemas()[col_idx]
            .data_type
            .create_mutable_vector(n_rows);

        for batch in batches {
            let column = batch.column(col_idx);
            acc.extend_slice_of(column.as_ref(), 0, column.len())
                .context(error::DataTypesSnafu)?;
        }

        merged_columns.push(acc.to_vector());
    }

    RecordBatch::new(schema, merged_columns)
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::arrow::datatypes::{DataType, Field, Schema as ArrowSchema};
    use datatypes::data_type::ConcreteDataType;
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::{StringVector, UInt32Vector};

    use super::*;

    #[test]
    fn test_record_batch() {
        let arrow_schema = Arc::new(ArrowSchema::new(vec![
            Field::new("c1", DataType::UInt32, false),
            Field::new("c2", DataType::UInt32, false),
        ]));
        let schema = Arc::new(Schema::try_from(arrow_schema).unwrap());

        let c1 = Arc::new(UInt32Vector::from_slice([1, 2, 3]));
        let c2 = Arc::new(UInt32Vector::from_slice([4, 5, 6]));
        let columns: Vec<VectorRef> = vec![c1, c2];

        let batch = RecordBatch::new(schema.clone(), columns.clone()).unwrap();
        assert_eq!(3, batch.num_rows());
        assert_eq!(&columns, batch.columns());
        for (i, expect) in columns.iter().enumerate().take(batch.num_columns()) {
            let column = batch.column(i);
            assert_eq!(expect, column);
        }
        assert_eq!(schema, batch.schema);

        assert_eq!(columns[0], *batch.column_by_name("c1").unwrap());
        assert_eq!(columns[1], *batch.column_by_name("c2").unwrap());
        assert!(batch.column_by_name("c3").is_none());

        let converted =
            RecordBatch::try_from_df_record_batch(schema, batch.df_record_batch().clone()).unwrap();
        assert_eq!(batch, converted);
        assert_eq!(*batch.df_record_batch(), converted.into_df_record_batch());
    }

    #[test]
    pub fn test_serialize_recordbatch() {
        let column_schemas = vec![ColumnSchema::new(
            "number",
            ConcreteDataType::uint32_datatype(),
            false,
        )];
        let schema = Arc::new(Schema::try_new(column_schemas).unwrap());

        let numbers: Vec<u32> = (0..10).collect();
        let columns = vec![Arc::new(UInt32Vector::from_slice(numbers)) as VectorRef];
        let batch = RecordBatch::new(schema, columns).unwrap();

        let output = serde_json::to_string(&batch).unwrap();
        assert_eq!(
            r#"{"schema":{"fields":[{"name":"number","data_type":"UInt32","nullable":false,"dict_id":0,"dict_is_ordered":false,"metadata":{}}],"metadata":{"greptime:version":"0"}},"columns":[[0,1,2,3,4,5,6,7,8,9]]}"#,
            output
        );
    }

    #[test]
    fn test_record_batch_visitor() {
        let column_schemas = vec![
            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
        ];
        let schema = Arc::new(Schema::new(column_schemas));
        let columns: Vec<VectorRef> = vec![
            Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
            Arc::new(StringVector::from(vec![
                None,
                Some("hello"),
                Some("greptime"),
                None,
            ])),
        ];
        let recordbatch = RecordBatch::new(schema, columns).unwrap();

        let mut record_batch_iter = recordbatch.rows();
        assert_eq!(
            vec![Value::UInt32(1), Value::Null],
            record_batch_iter
                .next()
                .unwrap()
                .into_iter()
                .collect::<Vec<Value>>()
        );

        assert_eq!(
            vec![Value::UInt32(2), Value::String("hello".into())],
            record_batch_iter
                .next()
                .unwrap()
                .into_iter()
                .collect::<Vec<Value>>()
        );

        assert_eq!(
            vec![Value::UInt32(3), Value::String("greptime".into())],
            record_batch_iter
                .next()
                .unwrap()
                .into_iter()
                .collect::<Vec<Value>>()
        );

        assert_eq!(
            vec![Value::UInt32(4), Value::Null],
            record_batch_iter
                .next()
                .unwrap()
                .into_iter()
                .collect::<Vec<Value>>()
        );

        assert!(record_batch_iter.next().is_none());
    }

    #[test]
    fn test_record_batch_slice() {
        let column_schemas = vec![
            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
        ];
        let schema = Arc::new(Schema::new(column_schemas));
        let columns: Vec<VectorRef> = vec![
            Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
            Arc::new(StringVector::from(vec![
                None,
                Some("hello"),
                Some("greptime"),
                None,
            ])),
        ];
        let recordbatch = RecordBatch::new(schema, columns).unwrap();
        let recordbatch = recordbatch.slice(1, 2).expect("recordbatch slice");
        let mut record_batch_iter = recordbatch.rows();
        assert_eq!(
            vec![Value::UInt32(2), Value::String("hello".into())],
            record_batch_iter
                .next()
                .unwrap()
                .into_iter()
                .collect::<Vec<Value>>()
        );

        assert_eq!(
            vec![Value::UInt32(3), Value::String("greptime".into())],
            record_batch_iter
                .next()
                .unwrap()
                .into_iter()
                .collect::<Vec<Value>>()
        );

        assert!(record_batch_iter.next().is_none());

        assert!(recordbatch.slice(1, 5).is_err());
    }

    #[test]
    fn test_merge_record_batch() {
        let column_schemas = vec![
            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
        ];
        let schema = Arc::new(Schema::new(column_schemas));
        let columns: Vec<VectorRef> = vec![
            Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
            Arc::new(StringVector::from(vec![
                None,
                Some("hello"),
                Some("greptime"),
                None,
            ])),
        ];
        let recordbatch = RecordBatch::new(schema.clone(), columns).unwrap();

        let columns: Vec<VectorRef> = vec![
            Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
            Arc::new(StringVector::from(vec![
                None,
                Some("hello"),
                Some("greptime"),
                None,
            ])),
        ];
        let recordbatch2 = RecordBatch::new(schema.clone(), columns).unwrap();

        let merged = merge_record_batches(schema.clone(), &[recordbatch, recordbatch2])
            .expect("merge recordbatch");
        assert_eq!(merged.num_rows(), 8);
    }
}