common_recordbatch/recordbatch.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;
use std::slice;
use std::sync::Arc;

use datafusion::arrow::util::pretty::pretty_format_batches;
use datafusion_common::arrow::array::ArrayRef;
use datafusion_common::arrow::compute;
use datafusion_common::arrow::datatypes::{DataType as ArrowDataType, SchemaRef as ArrowSchemaRef};
use datatypes::arrow::array::RecordBatchOptions;
use datatypes::prelude::DataType;
use datatypes::schema::SchemaRef;
use datatypes::value::Value;
use datatypes::vectors::{Helper, VectorRef};
use serde::ser::{Error, SerializeStruct};
use serde::{Serialize, Serializer};
use snafu::{OptionExt, ResultExt, ensure};

use crate::DfRecordBatch;
use crate::error::{
    self, ArrowComputeSnafu, CastVectorSnafu, ColumnNotExistsSnafu, DataTypesSnafu,
    ProjectArrowRecordBatchSnafu, Result,
};

/// A two-dimensional batch of column-oriented data with a defined schema.
#[derive(Clone, Debug, PartialEq)]
pub struct RecordBatch {
    pub schema: SchemaRef,
    pub columns: Vec<VectorRef>,
    df_record_batch: DfRecordBatch,
}

impl RecordBatch {
    /// Create a new [`RecordBatch`] from `schema` and `columns`.
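    ///
    /// A minimal usage sketch (mirroring the construction in the tests below):
    ///
    /// ```ignore
    /// use std::sync::Arc;
    ///
    /// use datatypes::data_type::ConcreteDataType;
    /// use datatypes::schema::{ColumnSchema, Schema};
    /// use datatypes::vectors::{UInt32Vector, VectorRef};
    ///
    /// let column_schemas = vec![ColumnSchema::new(
    ///     "number",
    ///     ConcreteDataType::uint32_datatype(),
    ///     false,
    /// )];
    /// let schema = Arc::new(Schema::new(column_schemas));
    /// let columns: Vec<VectorRef> = vec![Arc::new(UInt32Vector::from_slice([1, 2, 3]))];
    /// let batch = RecordBatch::new(schema, columns).unwrap();
    /// assert_eq!(3, batch.num_rows());
    /// ```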
    pub fn new<I: IntoIterator<Item = VectorRef>>(
        schema: SchemaRef,
        columns: I,
    ) -> Result<RecordBatch> {
        let columns: Vec<_> = columns.into_iter().collect();
        let arrow_arrays = columns.iter().map(|v| v.to_arrow_array()).collect();

        // Casting the arrays here to match the schema is a temporary solution to support Arrow's
        // view array types (`StringViewArray` and `BinaryViewArray`).
        // As to "support": the arrays here are created from vectors, which do not have types
        // corresponding to view arrays; all we can do is cast them.
        // As to "temporary": we are planning to use Arrow's RecordBatch directly in the read
        // path, so the casting here will be removed in the end.
        // TODO(LFC): Remove the casting here once `Batch` is no longer used.
        let arrow_arrays = Self::cast_view_arrays(schema.arrow_schema(), arrow_arrays)?;

        let df_record_batch = DfRecordBatch::try_new(schema.arrow_schema().clone(), arrow_arrays)
            .context(error::NewDfRecordBatchSnafu)?;

        Ok(RecordBatch {
            schema,
            columns,
            df_record_batch,
        })
    }

    fn cast_view_arrays(
        schema: &ArrowSchemaRef,
        mut arrays: Vec<ArrayRef>,
    ) -> Result<Vec<ArrayRef>> {
        for (f, a) in schema.fields().iter().zip(arrays.iter_mut()) {
            let expected = f.data_type();
            let actual = a.data_type();
            if matches!(
                (expected, actual),
                (ArrowDataType::Utf8View, ArrowDataType::Utf8)
                    | (ArrowDataType::BinaryView, ArrowDataType::Binary)
            ) {
                *a = compute::cast(a, expected).context(ArrowComputeSnafu)?;
            }
        }
        Ok(arrays)
    }

    /// Create an empty [`RecordBatch`] from `schema`.
    pub fn new_empty(schema: SchemaRef) -> RecordBatch {
        let df_record_batch = DfRecordBatch::new_empty(schema.arrow_schema().clone());
        let columns = schema
            .column_schemas()
            .iter()
            .map(|col| col.data_type.create_mutable_vector(0).to_vector())
            .collect();
        RecordBatch {
            schema,
            columns,
            df_record_batch,
        }
    }

    /// Create a [`RecordBatch`] from `schema` with no columns, carrying only a row count of `num_rows`.
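    ///
    /// A sketch (assuming an empty schema, since no columns are attached):
    ///
    /// ```ignore
    /// let schema = Arc::new(Schema::new(vec![]));
    /// let batch = RecordBatch::new_with_count(schema, 42).unwrap();
    /// assert_eq!(42, batch.num_rows());
    /// assert_eq!(0, batch.num_columns());
    /// ```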
    pub fn new_with_count(schema: SchemaRef, num_rows: usize) -> Result<Self> {
        let df_record_batch = DfRecordBatch::try_new_with_options(
            schema.arrow_schema().clone(),
            vec![],
            &RecordBatchOptions::new().with_row_count(Some(num_rows)),
        )
        .context(error::NewDfRecordBatchSnafu)?;
        Ok(RecordBatch {
            schema,
            columns: vec![],
            df_record_batch,
        })
    }

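    /// Project this [`RecordBatch`] to the columns selected by `indices`, keeping the
    /// schema, the vectors, and the underlying Arrow record batch in sync.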
    pub fn try_project(&self, indices: &[usize]) -> Result<Self> {
        let schema = Arc::new(self.schema.try_project(indices).context(DataTypesSnafu)?);
        let mut columns = Vec::with_capacity(indices.len());
        for index in indices {
            columns.push(self.columns[*index].clone());
        }
        let df_record_batch = self.df_record_batch.project(indices).with_context(|_| {
            ProjectArrowRecordBatchSnafu {
                schema: self.schema.clone(),
                projection: indices.to_vec(),
            }
        })?;

        Ok(Self {
            schema,
            columns,
            df_record_batch,
        })
    }

    /// Create a new [`RecordBatch`] from `schema` and `df_record_batch`.
    ///
    /// This method doesn't check that `schema` matches `df_record_batch`'s schema.
    pub fn try_from_df_record_batch(
        schema: SchemaRef,
        df_record_batch: DfRecordBatch,
    ) -> Result<RecordBatch> {
        let columns = df_record_batch
            .columns()
            .iter()
            .map(|c| Helper::try_into_vector(c.clone()).context(error::DataTypesSnafu))
            .collect::<Result<Vec<_>>>()?;

        Ok(RecordBatch {
            schema,
            columns,
            df_record_batch,
        })
    }

    #[inline]
    pub fn df_record_batch(&self) -> &DfRecordBatch {
        &self.df_record_batch
    }

    #[inline]
    pub fn into_df_record_batch(self) -> DfRecordBatch {
        self.df_record_batch
    }

    #[inline]
    pub fn columns(&self) -> &[VectorRef] {
        &self.columns
    }

    #[inline]
    pub fn column(&self, idx: usize) -> &VectorRef {
        &self.columns[idx]
    }

    pub fn column_by_name(&self, name: &str) -> Option<&VectorRef> {
        let idx = self.schema.column_index_by_name(name)?;
        Some(&self.columns[idx])
    }

    #[inline]
    pub fn num_columns(&self) -> usize {
        self.columns.len()
    }

    #[inline]
    pub fn num_rows(&self) -> usize {
        self.df_record_batch.num_rows()
    }

    /// Create an iterator that traverses the data row by row.
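    ///
    /// A sketch of row-wise traversal:
    ///
    /// ```ignore
    /// for row in batch.rows() {
    ///     // Each `row` is a `Vec<Value>`, one `Value` per column.
    ///     println!("{row:?}");
    /// }
    /// ```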
    pub fn rows(&self) -> RecordBatchRowIterator<'_> {
        RecordBatchRowIterator::new(self)
    }

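    /// Collect this batch's columns into a map keyed by column name, casting each
    /// vector to the data type declared for it in `table_schema`. Returns an error
    /// if a column does not exist in `table_schema` (reported against `table_name`)
    /// or if a cast fails.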
    pub fn column_vectors(
        &self,
        table_name: &str,
        table_schema: SchemaRef,
    ) -> Result<HashMap<String, VectorRef>> {
        let mut vectors = HashMap::with_capacity(self.num_columns());

        // Column schemas in the record batch must match its vectors; otherwise the batch is corrupted.
        for (vector_schema, vector) in self.schema.column_schemas().iter().zip(self.columns.iter())
        {
            let column_name = &vector_schema.name;
            let column_schema =
                table_schema
                    .column_schema_by_name(column_name)
                    .context(ColumnNotExistsSnafu {
                        table_name,
                        column_name,
                    })?;
            let vector = if vector_schema.data_type != column_schema.data_type {
                vector
                    .cast(&column_schema.data_type)
                    .with_context(|_| CastVectorSnafu {
                        from_type: vector.data_type(),
                        to_type: column_schema.data_type.clone(),
                    })?
            } else {
                vector.clone()
            };

            let _ = vectors.insert(column_name.clone(), vector);
        }

        Ok(vectors)
    }

    /// Pretty-print this record batch like a table.
    pub fn pretty_print(&self) -> String {
        pretty_format_batches(slice::from_ref(&self.df_record_batch))
            .map(|t| t.to_string())
            .unwrap_or_else(|_| "failed to pretty display a record batch".to_string())
    }

    /// Return a sliced record batch that starts from `offset` and contains `len` rows.
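    ///
    /// A sketch (taking rows 1 and 2 of a 4-row batch, as in the tests below):
    ///
    /// ```ignore
    /// let sliced = batch.slice(1, 2).unwrap();
    /// assert_eq!(2, sliced.num_rows());
    /// assert!(batch.slice(1, 5).is_err()); // offset + len exceeds num_rows
    /// ```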
    pub fn slice(&self, offset: usize, len: usize) -> Result<RecordBatch> {
        ensure!(
            offset + len <= self.num_rows(),
            error::RecordBatchSliceIndexOverflowSnafu {
                size: self.num_rows(),
                visit_index: offset + len
            }
        );
        let columns = self.columns.iter().map(|vector| vector.slice(offset, len));
        RecordBatch::new(self.schema.clone(), columns)
    }
}

impl Serialize for RecordBatch {
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // TODO(yingwen): arrow and arrow2's schemas have different fields, so
        // it might be better to use our `RawSchema` as the serialized field.
        let mut s = serializer.serialize_struct("record", 2)?;
        s.serialize_field("schema", &**self.schema.arrow_schema())?;

        let vec = self
            .columns
            .iter()
            .map(|c| c.serialize_to_json())
            .collect::<std::result::Result<Vec<_>, _>>()
            .map_err(S::Error::custom)?;

        s.serialize_field("columns", &vec)?;
        s.end()
    }
}

pub struct RecordBatchRowIterator<'a> {
    record_batch: &'a RecordBatch,
    rows: usize,
    columns: usize,
    row_cursor: usize,
}

impl<'a> RecordBatchRowIterator<'a> {
    fn new(record_batch: &'a RecordBatch) -> RecordBatchRowIterator<'a> {
        RecordBatchRowIterator {
            record_batch,
            rows: record_batch.df_record_batch.num_rows(),
            columns: record_batch.df_record_batch.num_columns(),
            row_cursor: 0,
        }
    }
}

impl Iterator for RecordBatchRowIterator<'_> {
    type Item = Vec<Value>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.row_cursor == self.rows {
            None
        } else {
            let mut row = Vec::with_capacity(self.columns);

            for col in 0..self.columns {
                let column = self.record_batch.column(col);
                row.push(column.get(self.row_cursor));
            }

            self.row_cursor += 1;
            Some(row)
        }
    }
}

/// Merge multiple record batches into a single one.
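///
/// A sketch, mirroring the test below (two 4-row batches merge into an 8-row batch):
///
/// ```ignore
/// let merged = merge_record_batches(schema, &[batch1, batch2]).unwrap();
/// assert_eq!(8, merged.num_rows());
/// ```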
pub fn merge_record_batches(schema: SchemaRef, batches: &[RecordBatch]) -> Result<RecordBatch> {
    if batches.is_empty() {
        return Ok(RecordBatch::new_empty(schema));
    }

    let n_rows = batches.iter().map(|b| b.num_rows()).sum();
    let n_columns = schema.num_columns();
    // Merge the batches column by column, extending a mutable vector per column.
    let mut merged_columns = Vec::with_capacity(n_columns);

    for col_idx in 0..n_columns {
        let mut acc = schema.column_schemas()[col_idx]
            .data_type
            .create_mutable_vector(n_rows);

        for batch in batches {
            let column = batch.column(col_idx);
            acc.extend_slice_of(column.as_ref(), 0, column.len())
                .context(error::DataTypesSnafu)?;
        }

        merged_columns.push(acc.to_vector());
    }

    // Create a new RecordBatch from the merged columns.
    RecordBatch::new(schema, merged_columns)
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::arrow::datatypes::{DataType, Field, Schema as ArrowSchema};
    use datatypes::data_type::ConcreteDataType;
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::{StringVector, UInt32Vector};

    use super::*;

    #[test]
    fn test_record_batch() {
        let arrow_schema = Arc::new(ArrowSchema::new(vec![
            Field::new("c1", DataType::UInt32, false),
            Field::new("c2", DataType::UInt32, false),
        ]));
        let schema = Arc::new(Schema::try_from(arrow_schema).unwrap());

        let c1 = Arc::new(UInt32Vector::from_slice([1, 2, 3]));
        let c2 = Arc::new(UInt32Vector::from_slice([4, 5, 6]));
        let columns: Vec<VectorRef> = vec![c1, c2];

        let batch = RecordBatch::new(schema.clone(), columns.clone()).unwrap();
        assert_eq!(3, batch.num_rows());
        assert_eq!(&columns, batch.columns());
        for (i, expect) in columns.iter().enumerate().take(batch.num_columns()) {
            let column = batch.column(i);
            assert_eq!(expect, column);
        }
        assert_eq!(schema, batch.schema);

        assert_eq!(columns[0], *batch.column_by_name("c1").unwrap());
        assert_eq!(columns[1], *batch.column_by_name("c2").unwrap());
        assert!(batch.column_by_name("c3").is_none());

        let converted =
            RecordBatch::try_from_df_record_batch(schema, batch.df_record_batch().clone()).unwrap();
        assert_eq!(batch, converted);
        assert_eq!(*batch.df_record_batch(), converted.into_df_record_batch());
    }

    #[test]
    pub fn test_serialize_recordbatch() {
        let column_schemas = vec![ColumnSchema::new(
            "number",
            ConcreteDataType::uint32_datatype(),
            false,
        )];
        let schema = Arc::new(Schema::try_new(column_schemas).unwrap());

        let numbers: Vec<u32> = (0..10).collect();
        let columns = vec![Arc::new(UInt32Vector::from_slice(numbers)) as VectorRef];
        let batch = RecordBatch::new(schema, columns).unwrap();

        let output = serde_json::to_string(&batch).unwrap();
        assert_eq!(
            r#"{"schema":{"fields":[{"name":"number","data_type":"UInt32","nullable":false,"dict_id":0,"dict_is_ordered":false,"metadata":{}}],"metadata":{"greptime:version":"0"}},"columns":[[0,1,2,3,4,5,6,7,8,9]]}"#,
            output
        );
    }

    #[test]
    fn test_record_batch_visitor() {
        let column_schemas = vec![
            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
        ];
        let schema = Arc::new(Schema::new(column_schemas));
        let columns: Vec<VectorRef> = vec![
            Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
            Arc::new(StringVector::from(vec![
                None,
                Some("hello"),
                Some("greptime"),
                None,
            ])),
        ];
        let recordbatch = RecordBatch::new(schema, columns).unwrap();

        let mut record_batch_iter = recordbatch.rows();
        assert_eq!(
            vec![Value::UInt32(1), Value::Null],
            record_batch_iter
                .next()
                .unwrap()
                .into_iter()
                .collect::<Vec<Value>>()
        );

        assert_eq!(
            vec![Value::UInt32(2), Value::String("hello".into())],
            record_batch_iter
                .next()
                .unwrap()
                .into_iter()
                .collect::<Vec<Value>>()
        );

        assert_eq!(
            vec![Value::UInt32(3), Value::String("greptime".into())],
            record_batch_iter
                .next()
                .unwrap()
                .into_iter()
                .collect::<Vec<Value>>()
        );

        assert_eq!(
            vec![Value::UInt32(4), Value::Null],
            record_batch_iter
                .next()
                .unwrap()
                .into_iter()
                .collect::<Vec<Value>>()
        );

        assert!(record_batch_iter.next().is_none());
    }

    #[test]
    fn test_record_batch_slice() {
        let column_schemas = vec![
            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
        ];
        let schema = Arc::new(Schema::new(column_schemas));
        let columns: Vec<VectorRef> = vec![
            Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
            Arc::new(StringVector::from(vec![
                None,
                Some("hello"),
                Some("greptime"),
                None,
            ])),
        ];
        let recordbatch = RecordBatch::new(schema, columns).unwrap();
        let recordbatch = recordbatch.slice(1, 2).expect("recordbatch slice");
        let mut record_batch_iter = recordbatch.rows();
        assert_eq!(
            vec![Value::UInt32(2), Value::String("hello".into())],
            record_batch_iter
                .next()
                .unwrap()
                .into_iter()
                .collect::<Vec<Value>>()
        );

        assert_eq!(
            vec![Value::UInt32(3), Value::String("greptime".into())],
            record_batch_iter
                .next()
                .unwrap()
                .into_iter()
                .collect::<Vec<Value>>()
        );

        assert!(record_batch_iter.next().is_none());

        assert!(recordbatch.slice(1, 5).is_err());
    }

    #[test]
    fn test_merge_record_batch() {
        let column_schemas = vec![
            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
        ];
        let schema = Arc::new(Schema::new(column_schemas));
        let columns: Vec<VectorRef> = vec![
            Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
            Arc::new(StringVector::from(vec![
                None,
                Some("hello"),
                Some("greptime"),
                None,
            ])),
        ];
        let recordbatch = RecordBatch::new(schema.clone(), columns).unwrap();

        let columns: Vec<VectorRef> = vec![
            Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
            Arc::new(StringVector::from(vec![
                None,
                Some("hello"),
                Some("greptime"),
                None,
            ])),
        ];
        let recordbatch2 = RecordBatch::new(schema.clone(), columns).unwrap();

        let merged = merge_record_batches(schema.clone(), &[recordbatch, recordbatch2])
            .expect("merge recordbatch");
        assert_eq!(merged.num_rows(), 8);
    }
}