common_recordbatch/recordbatch.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;
use std::slice;
use std::sync::Arc;

use datafusion::arrow::util::pretty::pretty_format_batches;
use datatypes::arrow::array::RecordBatchOptions;
use datatypes::prelude::DataType;
use datatypes::schema::SchemaRef;
use datatypes::value::Value;
use datatypes::vectors::{Helper, VectorRef};
use serde::ser::{Error, SerializeStruct};
use serde::{Serialize, Serializer};
use snafu::{ensure, OptionExt, ResultExt};

use crate::error::{
    self, CastVectorSnafu, ColumnNotExistsSnafu, DataTypesSnafu, ProjectArrowRecordBatchSnafu,
    Result,
};
use crate::DfRecordBatch;

/// A two-dimensional batch of column-oriented data with a defined schema.
#[derive(Clone, Debug, PartialEq)]
pub struct RecordBatch {
    pub schema: SchemaRef,
    pub columns: Vec<VectorRef>,
    df_record_batch: DfRecordBatch,
}

impl RecordBatch {
    /// Create a new [`RecordBatch`] from `schema` and `columns`.
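    ///
    /// A minimal usage sketch (hedged: the values are hypothetical, and the
    /// schema is assumed to describe a single non-nullable `UInt32` column):
    ///
    /// ```ignore
    /// let column = Arc::new(UInt32Vector::from_slice([1, 2, 3])) as VectorRef;
    /// let batch = RecordBatch::new(schema, vec![column])?;
    /// assert_eq!(3, batch.num_rows());
    /// ```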
    pub fn new<I: IntoIterator<Item = VectorRef>>(
        schema: SchemaRef,
        columns: I,
    ) -> Result<RecordBatch> {
        let columns: Vec<_> = columns.into_iter().collect();
        let arrow_arrays = columns.iter().map(|v| v.to_arrow_array()).collect();

        let df_record_batch = DfRecordBatch::try_new(schema.arrow_schema().clone(), arrow_arrays)
            .context(error::NewDfRecordBatchSnafu)?;

        Ok(RecordBatch {
            schema,
            columns,
            df_record_batch,
        })
    }

    /// Create an empty [`RecordBatch`] from `schema`.
    pub fn new_empty(schema: SchemaRef) -> RecordBatch {
        let df_record_batch = DfRecordBatch::new_empty(schema.arrow_schema().clone());
        let columns = schema
            .column_schemas()
            .iter()
            .map(|col| col.data_type.create_mutable_vector(0).to_vector())
            .collect();
        RecordBatch {
            schema,
            columns,
            df_record_batch,
        }
    }

    /// Create a [`RecordBatch`] from `schema` that holds no columns but reports `num_rows` rows.
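    ///
    /// A hedged sketch of when this is useful (e.g. a result that only needs a
    /// row count; the empty schema here is an assumption for illustration):
    ///
    /// ```ignore
    /// let schema = Arc::new(Schema::new(vec![]));
    /// let batch = RecordBatch::new_with_count(schema, 42)?;
    /// assert_eq!(42, batch.num_rows());
    /// assert_eq!(0, batch.num_columns());
    /// ```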
    pub fn new_with_count(schema: SchemaRef, num_rows: usize) -> Result<Self> {
        let df_record_batch = DfRecordBatch::try_new_with_options(
            schema.arrow_schema().clone(),
            vec![],
            &RecordBatchOptions::new().with_row_count(Some(num_rows)),
        )
        .context(error::NewDfRecordBatchSnafu)?;
        Ok(RecordBatch {
            schema,
            columns: vec![],
            df_record_batch,
        })
    }

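    /// Project this [`RecordBatch`] onto the columns selected by `indices`,
    /// keeping the schema, the vectors, and the underlying arrow batch in sync.
    ///
    /// A hedged sketch (the column indices are hypothetical):
    ///
    /// ```ignore
    /// // Keep only the first and third columns.
    /// let projected = batch.try_project(&[0, 2])?;
    /// assert_eq!(2, projected.num_columns());
    /// ```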
    pub fn try_project(&self, indices: &[usize]) -> Result<Self> {
        let schema = Arc::new(self.schema.try_project(indices).context(DataTypesSnafu)?);
        let mut columns = Vec::with_capacity(indices.len());
        for index in indices {
            columns.push(self.columns[*index].clone());
        }
        let df_record_batch = self.df_record_batch.project(indices).with_context(|_| {
            ProjectArrowRecordBatchSnafu {
                schema: self.schema.clone(),
                projection: indices.to_vec(),
            }
        })?;

        Ok(Self {
            schema,
            columns,
            df_record_batch,
        })
    }

    /// Create a new [`RecordBatch`] from `schema` and `df_record_batch`.
    ///
    /// This method doesn't check that `df_record_batch` conforms to `schema`.
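    ///
    /// A hedged round-trip sketch (assumes `schema` matches the arrow batch,
    /// which this constructor does not verify):
    ///
    /// ```ignore
    /// let df_batch = batch.df_record_batch().clone();
    /// let converted = RecordBatch::try_from_df_record_batch(schema, df_batch)?;
    /// assert_eq!(batch, converted);
    /// ```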
    pub fn try_from_df_record_batch(
        schema: SchemaRef,
        df_record_batch: DfRecordBatch,
    ) -> Result<RecordBatch> {
        let columns = df_record_batch
            .columns()
            .iter()
            .map(|c| Helper::try_into_vector(c.clone()).context(error::DataTypesSnafu))
            .collect::<Result<Vec<_>>>()?;

        Ok(RecordBatch {
            schema,
            columns,
            df_record_batch,
        })
    }

    #[inline]
    pub fn df_record_batch(&self) -> &DfRecordBatch {
        &self.df_record_batch
    }

    #[inline]
    pub fn into_df_record_batch(self) -> DfRecordBatch {
        self.df_record_batch
    }

    #[inline]
    pub fn columns(&self) -> &[VectorRef] {
        &self.columns
    }

    #[inline]
    pub fn column(&self, idx: usize) -> &VectorRef {
        &self.columns[idx]
    }

    pub fn column_by_name(&self, name: &str) -> Option<&VectorRef> {
        let idx = self.schema.column_index_by_name(name)?;
        Some(&self.columns[idx])
    }

    #[inline]
    pub fn num_columns(&self) -> usize {
        self.columns.len()
    }

    #[inline]
    pub fn num_rows(&self) -> usize {
        self.df_record_batch.num_rows()
    }

    /// Create an iterator that traverses the data row by row.
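    ///
    /// A hedged sketch of row-wise traversal (each item is a `Vec<Value>`,
    /// one `Value` per column; the batch contents are hypothetical):
    ///
    /// ```ignore
    /// for row in batch.rows() {
    ///     println!("{row:?}");
    /// }
    /// ```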
    pub fn rows(&self) -> RecordBatchRowIterator<'_> {
        RecordBatchRowIterator::new(self)
    }

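    /// Collect this batch's columns into a map keyed by column name, casting
    /// each vector to the type declared for that column in `table_schema`
    /// whenever the types differ.
    ///
    /// A hedged usage sketch (`"my_table"` and the `"ts"` column are assumptions):
    ///
    /// ```ignore
    /// let vectors = batch.column_vectors("my_table", table_schema)?;
    /// let ts = vectors.get("ts").expect("ts column");
    /// ```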
    pub fn column_vectors(
        &self,
        table_name: &str,
        table_schema: SchemaRef,
    ) -> Result<HashMap<String, VectorRef>> {
        let mut vectors = HashMap::with_capacity(self.num_columns());

        // The column schemas in a recordbatch must match its vectors; otherwise the batch is corrupted.
        for (vector_schema, vector) in self.schema.column_schemas().iter().zip(self.columns.iter())
        {
            let column_name = &vector_schema.name;
            let column_schema =
                table_schema
                    .column_schema_by_name(column_name)
                    .context(ColumnNotExistsSnafu {
                        table_name,
                        column_name,
                    })?;
            let vector = if vector_schema.data_type != column_schema.data_type {
                vector
                    .cast(&column_schema.data_type)
                    .with_context(|_| CastVectorSnafu {
                        from_type: vector.data_type(),
                        to_type: column_schema.data_type.clone(),
                    })?
            } else {
                vector.clone()
            };

            let _ = vectors.insert(column_name.clone(), vector);
        }

        Ok(vectors)
    }

    /// Pretty-print this record batch as a table-formatted string.
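    ///
    /// A hedged sketch of the output shape (arrow's `pretty_format_batches`
    /// style; the column names and values are hypothetical):
    ///
    /// ```text
    /// +----+-------+
    /// | c1 | c2    |
    /// +----+-------+
    /// | 1  | hello |
    /// | 2  | world |
    /// +----+-------+
    /// ```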
    pub fn pretty_print(&self) -> String {
        pretty_format_batches(slice::from_ref(&self.df_record_batch))
            .map(|t| t.to_string())
            .unwrap_or_else(|_| "failed to pretty display a record batch".to_string())
    }

    /// Return a sliced [`RecordBatch`] that starts at `offset` and contains `len` rows.
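    ///
    /// A hedged sketch (assumes the batch has at least three rows):
    ///
    /// ```ignore
    /// // Rows 1 and 2 of the original batch (zero-based).
    /// let sliced = batch.slice(1, 2)?;
    /// assert_eq!(2, sliced.num_rows());
    /// ```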
    pub fn slice(&self, offset: usize, len: usize) -> Result<RecordBatch> {
        ensure!(
            offset + len <= self.num_rows(),
            error::RecordBatchSliceIndexOverflowSnafu {
                size: self.num_rows(),
                visit_index: offset + len
            }
        );
        let columns = self.columns.iter().map(|vector| vector.slice(offset, len));
        RecordBatch::new(self.schema.clone(), columns)
    }
}

impl Serialize for RecordBatch {
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // TODO(yingwen): arrow and arrow2's schemas have different fields, so
        // it might be better to use our `RawSchema` as serialized field.
        let mut s = serializer.serialize_struct("record", 2)?;
        s.serialize_field("schema", &**self.schema.arrow_schema())?;

        let vec = self
            .columns
            .iter()
            .map(|c| c.serialize_to_json())
            .collect::<std::result::Result<Vec<_>, _>>()
            .map_err(S::Error::custom)?;

        s.serialize_field("columns", &vec)?;
        s.end()
    }
}

pub struct RecordBatchRowIterator<'a> {
    record_batch: &'a RecordBatch,
    rows: usize,
    columns: usize,
    row_cursor: usize,
}

impl<'a> RecordBatchRowIterator<'a> {
    fn new(record_batch: &'a RecordBatch) -> RecordBatchRowIterator<'a> {
        RecordBatchRowIterator {
            record_batch,
            rows: record_batch.df_record_batch.num_rows(),
            columns: record_batch.df_record_batch.num_columns(),
            row_cursor: 0,
        }
    }
}

impl Iterator for RecordBatchRowIterator<'_> {
    type Item = Vec<Value>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.row_cursor == self.rows {
            None
        } else {
            let mut row = Vec::with_capacity(self.columns);

            for col in 0..self.columns {
                let column = self.record_batch.column(col);
                row.push(column.get(self.row_cursor));
            }

            self.row_cursor += 1;
            Some(row)
        }
    }
}

/// Merge multiple [`RecordBatch`]es into a single one.
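///
/// A hedged sketch (assumes both batches share `schema`):
///
/// ```ignore
/// let total = batch_a.num_rows() + batch_b.num_rows();
/// let merged = merge_record_batches(schema, &[batch_a, batch_b])?;
/// assert_eq!(total, merged.num_rows());
/// ```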
pub fn merge_record_batches(schema: SchemaRef, batches: &[RecordBatch]) -> Result<RecordBatch> {
    if batches.is_empty() {
        return Ok(RecordBatch::new_empty(schema));
    }

    let n_rows = batches.iter().map(|b| b.num_rows()).sum();
    let n_columns = schema.num_columns();
    // Merge column by column: for each column, accumulate the corresponding
    // vector of every batch into one mutable vector.
    let mut merged_columns = Vec::with_capacity(n_columns);

    for col_idx in 0..n_columns {
        let mut acc = schema.column_schemas()[col_idx]
            .data_type
            .create_mutable_vector(n_rows);

        for batch in batches {
            let column = batch.column(col_idx);
            acc.extend_slice_of(column.as_ref(), 0, column.len())
                .context(error::DataTypesSnafu)?;
        }

        merged_columns.push(acc.to_vector());
    }

    // Create a new RecordBatch with the merged columns.
    RecordBatch::new(schema, merged_columns)
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::arrow::datatypes::{DataType, Field, Schema as ArrowSchema};
    use datatypes::data_type::ConcreteDataType;
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::{StringVector, UInt32Vector};

    use super::*;

    #[test]
    fn test_record_batch() {
        let arrow_schema = Arc::new(ArrowSchema::new(vec![
            Field::new("c1", DataType::UInt32, false),
            Field::new("c2", DataType::UInt32, false),
        ]));
        let schema = Arc::new(Schema::try_from(arrow_schema).unwrap());

        let c1 = Arc::new(UInt32Vector::from_slice([1, 2, 3]));
        let c2 = Arc::new(UInt32Vector::from_slice([4, 5, 6]));
        let columns: Vec<VectorRef> = vec![c1, c2];

        let batch = RecordBatch::new(schema.clone(), columns.clone()).unwrap();
        assert_eq!(3, batch.num_rows());
        assert_eq!(&columns, batch.columns());
        for (i, expect) in columns.iter().enumerate().take(batch.num_columns()) {
            let column = batch.column(i);
            assert_eq!(expect, column);
        }
        assert_eq!(schema, batch.schema);

        assert_eq!(columns[0], *batch.column_by_name("c1").unwrap());
        assert_eq!(columns[1], *batch.column_by_name("c2").unwrap());
        assert!(batch.column_by_name("c3").is_none());

        let converted =
            RecordBatch::try_from_df_record_batch(schema, batch.df_record_batch().clone()).unwrap();
        assert_eq!(batch, converted);
        assert_eq!(*batch.df_record_batch(), converted.into_df_record_batch());
    }

    #[test]
    pub fn test_serialize_recordbatch() {
        let column_schemas = vec![ColumnSchema::new(
            "number",
            ConcreteDataType::uint32_datatype(),
            false,
        )];
        let schema = Arc::new(Schema::try_new(column_schemas).unwrap());

        let numbers: Vec<u32> = (0..10).collect();
        let columns = vec![Arc::new(UInt32Vector::from_slice(numbers)) as VectorRef];
        let batch = RecordBatch::new(schema, columns).unwrap();

        let output = serde_json::to_string(&batch).unwrap();
        assert_eq!(
            r#"{"schema":{"fields":[{"name":"number","data_type":"UInt32","nullable":false,"dict_id":0,"dict_is_ordered":false,"metadata":{}}],"metadata":{"greptime:version":"0"}},"columns":[[0,1,2,3,4,5,6,7,8,9]]}"#,
            output
        );
    }

    #[test]
    fn test_record_batch_visitor() {
        let column_schemas = vec![
            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
        ];
        let schema = Arc::new(Schema::new(column_schemas));
        let columns: Vec<VectorRef> = vec![
            Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
            Arc::new(StringVector::from(vec![
                None,
                Some("hello"),
                Some("greptime"),
                None,
            ])),
        ];
        let recordbatch = RecordBatch::new(schema, columns).unwrap();

        let mut record_batch_iter = recordbatch.rows();
        assert_eq!(
            vec![Value::UInt32(1), Value::Null],
            record_batch_iter
                .next()
                .unwrap()
                .into_iter()
                .collect::<Vec<Value>>()
        );

        assert_eq!(
            vec![Value::UInt32(2), Value::String("hello".into())],
            record_batch_iter
                .next()
                .unwrap()
                .into_iter()
                .collect::<Vec<Value>>()
        );

        assert_eq!(
            vec![Value::UInt32(3), Value::String("greptime".into())],
            record_batch_iter
                .next()
                .unwrap()
                .into_iter()
                .collect::<Vec<Value>>()
        );

        assert_eq!(
            vec![Value::UInt32(4), Value::Null],
            record_batch_iter
                .next()
                .unwrap()
                .into_iter()
                .collect::<Vec<Value>>()
        );

        assert!(record_batch_iter.next().is_none());
    }

    #[test]
    fn test_record_batch_slice() {
        let column_schemas = vec![
            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
        ];
        let schema = Arc::new(Schema::new(column_schemas));
        let columns: Vec<VectorRef> = vec![
            Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
            Arc::new(StringVector::from(vec![
                None,
                Some("hello"),
                Some("greptime"),
                None,
            ])),
        ];
        let recordbatch = RecordBatch::new(schema, columns).unwrap();
        let recordbatch = recordbatch.slice(1, 2).expect("recordbatch slice");
        let mut record_batch_iter = recordbatch.rows();
        assert_eq!(
            vec![Value::UInt32(2), Value::String("hello".into())],
            record_batch_iter
                .next()
                .unwrap()
                .into_iter()
                .collect::<Vec<Value>>()
        );

        assert_eq!(
            vec![Value::UInt32(3), Value::String("greptime".into())],
            record_batch_iter
                .next()
                .unwrap()
                .into_iter()
                .collect::<Vec<Value>>()
        );

        assert!(record_batch_iter.next().is_none());

        assert!(recordbatch.slice(1, 5).is_err());
    }

    #[test]
    fn test_merge_record_batch() {
        let column_schemas = vec![
            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
        ];
        let schema = Arc::new(Schema::new(column_schemas));
        let columns: Vec<VectorRef> = vec![
            Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
            Arc::new(StringVector::from(vec![
                None,
                Some("hello"),
                Some("greptime"),
                None,
            ])),
        ];
        let recordbatch = RecordBatch::new(schema.clone(), columns).unwrap();

        let columns: Vec<VectorRef> = vec![
            Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
            Arc::new(StringVector::from(vec![
                None,
                Some("hello"),
                Some("greptime"),
                None,
            ])),
        ];
        let recordbatch2 = RecordBatch::new(schema.clone(), columns).unwrap();

        let merged = merge_record_batches(schema.clone(), &[recordbatch, recordbatch2])
            .expect("merge recordbatch");
        assert_eq!(merged.num_rows(), 8);
    }
}