1use std::collections::HashMap;
18use std::sync::Arc;
19
20use api::v1::OpType;
21use datatypes::arrow::array::{ArrayRef, BooleanArray, UInt64Array, UInt8Array};
22use datatypes::arrow::compute::filter_record_batch;
23use datatypes::arrow::datatypes::SchemaRef;
24use datatypes::arrow::record_batch::RecordBatch;
25use snafu::{OptionExt, ResultExt};
26use store_api::metadata::{ColumnMetadata, RegionMetadata};
27use store_api::storage::{RegionId, SequenceNumber};
28
29use crate::error::{
30 ComputeArrowSnafu, CreateDefaultSnafu, InvalidRequestSnafu, NewRecordBatchSnafu, Result,
31 UnexpectedSnafu,
32};
33use crate::sst::parquet::plain_format::PLAIN_FIXED_POS_COLUMN_NUM;
34
/// A batch of rows in the plain (flat) format.
///
/// Wraps an Arrow [`RecordBatch`] whose trailing `PLAIN_FIXED_POS_COLUMN_NUM`
/// columns are the fixed internal columns (sequence and op type, based on
/// [`PlainBatch::sequence_column_index`] and the tests); user columns come first.
#[derive(Debug)]
pub struct PlainBatch {
    // Invariant: has at least the internal columns (checked in `new`).
    record_batch: RecordBatch,
}
46
47impl PlainBatch {
48 pub fn new(record_batch: RecordBatch) -> Self {
50 assert!(
51 record_batch.num_columns() >= 2,
52 "record batch missing internal columns, num_columns: {}",
53 record_batch.num_columns()
54 );
55
56 Self { record_batch }
57 }
58
59 pub fn with_new_columns(&self, columns: Vec<ArrayRef>) -> Result<Self> {
61 let record_batch = RecordBatch::try_new(self.record_batch.schema(), columns)
62 .context(NewRecordBatchSnafu)?;
63 Ok(Self::new(record_batch))
64 }
65
66 pub fn num_columns(&self) -> usize {
68 self.record_batch.num_columns()
69 }
70
71 pub fn num_rows(&self) -> usize {
73 self.record_batch.num_rows()
74 }
75
76 pub fn is_empty(&self) -> bool {
78 self.num_rows() == 0
79 }
80
81 pub fn columns(&self) -> &[ArrayRef] {
83 self.record_batch.columns()
84 }
85
86 pub fn column(&self, idx: usize) -> &ArrayRef {
88 self.record_batch.column(idx)
89 }
90
91 pub fn internal_columns(&self) -> &[ArrayRef] {
93 &self.record_batch.columns()[self.record_batch.num_columns() - PLAIN_FIXED_POS_COLUMN_NUM..]
94 }
95
96 pub fn as_record_batch(&self) -> &RecordBatch {
98 &self.record_batch
99 }
100
101 pub fn into_record_batch(self) -> RecordBatch {
103 self.record_batch
104 }
105
106 pub fn filter(&self, predicate: &BooleanArray) -> Result<Self> {
108 let record_batch =
109 filter_record_batch(&self.record_batch, predicate).context(ComputeArrowSnafu)?;
110 Ok(Self::new(record_batch))
111 }
112
113 #[allow(dead_code)]
115 pub(crate) fn sequence_column_index(&self) -> usize {
116 self.record_batch.num_columns() - PLAIN_FIXED_POS_COLUMN_NUM
117 }
118}
119
/// Fills missing columns of record batches for a region.
pub struct ColumnFiller<'a> {
    /// Metadata of the region the batches belong to.
    metadata: &'a RegionMetadata,
    /// Output schema: one field per column metadata plus the internal columns
    /// (asserted in `ColumnFiller::new`).
    schema: SchemaRef,
    /// Maps input column names to their indices in the input record batch.
    name_to_index: HashMap<String, usize>,
}
129
130impl<'a> ColumnFiller<'a> {
131 pub fn new(
134 metadata: &'a RegionMetadata,
135 schema: SchemaRef,
136 record_batch: &RecordBatch,
137 ) -> Self {
138 debug_assert_eq!(metadata.column_metadatas.len() + 2, schema.fields().len());
139
140 let name_to_index: HashMap<_, _> = record_batch
142 .schema()
143 .fields()
144 .iter()
145 .enumerate()
146 .map(|(i, field)| (field.name().clone(), i))
147 .collect();
148
149 Self {
150 metadata,
151 schema,
152 name_to_index,
153 }
154 }
155
156 pub fn fill_missing_columns(
158 &self,
159 record_batch: &RecordBatch,
160 sequence: SequenceNumber,
161 op_type: OpType,
162 ) -> Result<RecordBatch> {
163 let num_rows = record_batch.num_rows();
164 let mut new_columns =
165 Vec::with_capacity(record_batch.num_columns() + PLAIN_FIXED_POS_COLUMN_NUM);
166
167 for column in &self.metadata.column_metadatas {
170 let array = match self.name_to_index.get(&column.column_schema.name) {
171 Some(index) => record_batch.column(*index).clone(),
172 None => match op_type {
173 OpType::Put => {
174 fill_column_put_default(self.metadata.region_id, column, num_rows)?
176 }
177 OpType::Delete => {
178 fill_column_delete_default(column, num_rows)?
180 }
181 },
182 };
183
184 new_columns.push(array);
185 }
186
187 let sequence_array = Arc::new(UInt64Array::from(vec![sequence; num_rows]));
190 let op_type_array = Arc::new(UInt8Array::from(vec![op_type as u8; num_rows]));
192 new_columns.push(sequence_array);
193 new_columns.push(op_type_array);
194
195 RecordBatch::try_new(self.schema.clone(), new_columns).context(NewRecordBatchSnafu)
196 }
197}
198
199fn fill_column_put_default(
200 region_id: RegionId,
201 column: &ColumnMetadata,
202 num_rows: usize,
203) -> Result<ArrayRef> {
204 if column.column_schema.is_default_impure() {
205 return UnexpectedSnafu {
206 reason: format!(
207 "unexpected impure default value with region_id: {}, column: {}, default_value: {:?}",
208 region_id,
209 column.column_schema.name,
210 column.column_schema.default_constraint(),
211 ),
212 }
213 .fail();
214 }
215 let vector = column
216 .column_schema
217 .create_default_vector(num_rows)
218 .context(CreateDefaultSnafu {
219 region_id,
220 column: &column.column_schema.name,
221 })?
222 .with_context(|| InvalidRequestSnafu {
224 region_id,
225 reason: format!(
226 "column {} does not have default value",
227 column.column_schema.name
228 ),
229 })?;
230 Ok(vector.to_arrow_array())
231}
232
233fn fill_column_delete_default(column: &ColumnMetadata, num_rows: usize) -> Result<ArrayRef> {
234 let vector = column
236 .column_schema
237 .create_default_vector_for_padding(num_rows);
238 Ok(vector.to_arrow_array())
239}
240
#[cfg(test)]
mod tests {
    use api::v1::SemanticType;
    use datatypes::arrow::array::{
        Float64Array, Int32Array, StringArray, TimestampMillisecondArray,
    };
    use datatypes::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
    use datatypes::schema::constraint::ColumnDefaultConstraint;
    use datatypes::schema::ColumnSchema;
    use datatypes::value::Value;
    use store_api::metadata::{ColumnMetadata, RegionMetadataBuilder};
    use store_api::storage::consts::{OP_TYPE_COLUMN_NAME, SEQUENCE_COLUMN_NAME};
    use store_api::storage::{ConcreteDataType, RegionId};

    use super::*;
    use crate::sst::to_plain_sst_arrow_schema;

    /// Builds a region with tag `k0`, time index `ts`, and field `v1` whose
    /// default value is 42.0.
    fn create_test_region_metadata() -> RegionMetadata {
        let mut builder = RegionMetadataBuilder::new(RegionId::new(100, 200));
        builder
            .push_column_metadata(ColumnMetadata {
                column_schema: ColumnSchema::new("k0", ConcreteDataType::string_datatype(), false)
                    .with_default_constraint(None)
                    .unwrap(),
                semantic_type: SemanticType::Tag,
                column_id: 0,
            })
            .push_column_metadata(ColumnMetadata {
                column_schema: ColumnSchema::new(
                    "ts",
                    ConcreteDataType::timestamp_millisecond_datatype(),
                    false,
                )
                .with_time_index(true)
                .with_default_constraint(None)
                .unwrap(),
                semantic_type: SemanticType::Timestamp,
                column_id: 1,
            })
            .push_column_metadata(ColumnMetadata {
                column_schema: ColumnSchema::new("v1", ConcreteDataType::float64_datatype(), true)
                    .with_default_constraint(Some(ColumnDefaultConstraint::Value(Value::Float64(
                        datatypes::value::OrderedFloat::from(42.0),
                    ))))
                    .unwrap(),
                semantic_type: SemanticType::Field,
                column_id: 2,
            })
            .primary_key(vec![0]);

        builder.build().unwrap()
    }

    #[test]
    fn test_column_filler_put() {
        let region_metadata = create_test_region_metadata();
        // NOTE: fixed garbled `®ion_metadata` (mojibake of `&region_metadata`).
        let output_schema = to_plain_sst_arrow_schema(&region_metadata);

        // Input batch intentionally omits `v1` so the filler must fill it.
        let input_schema = Arc::new(Schema::new(vec![
            Field::new("k0", DataType::Utf8, false),
            Field::new(
                "ts",
                DataType::Timestamp(TimeUnit::Millisecond, None),
                false,
            ),
        ]));

        let k0_values: ArrayRef = Arc::new(StringArray::from(vec!["key1", "key2"]));
        let ts_values: ArrayRef = Arc::new(TimestampMillisecondArray::from(vec![1000, 2000]));

        let input_batch =
            RecordBatch::try_new(input_schema, vec![k0_values.clone(), ts_values.clone()]).unwrap();

        let filler = ColumnFiller::new(&region_metadata, output_schema.clone(), &input_batch);

        let result = filler
            .fill_missing_columns(&input_batch, 100, OpType::Put)
            .unwrap();

        // `v1` should be filled with its default (42.0) for put operations.
        let expected_columns = vec![
            k0_values.clone(),
            ts_values.clone(),
            Arc::new(Float64Array::from(vec![42.0, 42.0])),
            Arc::new(UInt64Array::from(vec![100, 100])),
            Arc::new(UInt8Array::from(vec![OpType::Put as u8, OpType::Put as u8])),
        ];
        let expected_batch = RecordBatch::try_new(output_schema.clone(), expected_columns).unwrap();
        assert_eq!(expected_batch, result);
    }

    #[test]
    fn test_column_filler_delete() {
        let region_metadata = create_test_region_metadata();
        let output_schema = to_plain_sst_arrow_schema(&region_metadata);

        let input_schema = Arc::new(Schema::new(vec![
            Field::new("k0", DataType::Utf8, false),
            Field::new(
                "ts",
                DataType::Timestamp(TimeUnit::Millisecond, None),
                false,
            ),
        ]));

        let k0_values: ArrayRef = Arc::new(StringArray::from(vec!["key1", "key2"]));
        let ts_values: ArrayRef = Arc::new(TimestampMillisecondArray::from(vec![1000, 2000]));

        let input_batch =
            RecordBatch::try_new(input_schema, vec![k0_values.clone(), ts_values.clone()]).unwrap();

        let filler = ColumnFiller::new(&region_metadata, output_schema.clone(), &input_batch);

        let result = filler
            .fill_missing_columns(&input_batch, 200, OpType::Delete)
            .unwrap();

        // For delete operations `v1` is padded with nulls, not its default.
        let v1_default = Arc::new(Float64Array::from(vec![None, None]));
        let expected_columns = vec![
            k0_values.clone(),
            ts_values.clone(),
            v1_default,
            Arc::new(UInt64Array::from(vec![200, 200])),
            Arc::new(UInt8Array::from(vec![
                OpType::Delete as u8,
                OpType::Delete as u8,
            ])),
        ];
        let expected_batch = RecordBatch::try_new(output_schema.clone(), expected_columns).unwrap();
        assert_eq!(expected_batch, result);
    }

    /// Builds a 3-row batch with two user columns plus the internal columns.
    fn create_test_record_batch() -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("col1", DataType::Int32, false),
            Field::new("col2", DataType::Utf8, false),
            Field::new(SEQUENCE_COLUMN_NAME, DataType::UInt64, false),
            Field::new(OP_TYPE_COLUMN_NAME, DataType::UInt8, false),
        ]));

        let col1 = Arc::new(Int32Array::from(vec![1, 2, 3]));
        let col2 = Arc::new(StringArray::from(vec!["a", "b", "c"]));
        let sequence = Arc::new(UInt64Array::from(vec![100, 101, 102]));
        let op_type = Arc::new(UInt8Array::from(vec![1, 1, 1]));

        RecordBatch::try_new(schema, vec![col1, col2, sequence, op_type]).unwrap()
    }

    #[test]
    fn test_plain_batch_basic_methods() {
        let record_batch = create_test_record_batch();
        let plain_batch = PlainBatch::new(record_batch.clone());

        assert_eq!(plain_batch.num_columns(), 4);
        assert_eq!(plain_batch.num_rows(), 3);
        assert!(!plain_batch.is_empty());
        assert_eq!(plain_batch.columns().len(), 4);

        let internal_columns = plain_batch.internal_columns();
        assert_eq!(internal_columns.len(), PLAIN_FIXED_POS_COLUMN_NUM);
        assert_eq!(internal_columns[0].len(), 3);
        assert_eq!(internal_columns[1].len(), 3);

        let col1 = plain_batch.column(0);
        assert_eq!(col1.len(), 3);
        assert_eq!(
            col1.as_any().downcast_ref::<Int32Array>().unwrap().value(0),
            1
        );

        assert_eq!(plain_batch.sequence_column_index(), 2);

        assert_eq!(record_batch, *plain_batch.as_record_batch());
        assert_eq!(record_batch, plain_batch.into_record_batch());
    }

    #[test]
    fn test_with_new_columns() {
        let record_batch = create_test_record_batch();
        let plain_batch = PlainBatch::new(record_batch);

        let col1 = Arc::new(Int32Array::from(vec![10, 20, 30]));
        let col2 = Arc::new(StringArray::from(vec!["x", "y", "z"]));
        let sequence = Arc::new(UInt64Array::from(vec![200, 201, 202]));
        let op_type = Arc::new(UInt8Array::from(vec![0, 0, 0]));

        let new_batch = plain_batch
            .with_new_columns(vec![col1, col2, sequence, op_type])
            .unwrap();

        assert_eq!(new_batch.num_columns(), 4);
        assert_eq!(new_batch.num_rows(), 3);
        assert_eq!(
            new_batch
                .column(0)
                .as_any()
                .downcast_ref::<Int32Array>()
                .unwrap()
                .value(0),
            10
        );
        assert_eq!(
            new_batch
                .column(1)
                .as_any()
                .downcast_ref::<StringArray>()
                .unwrap()
                .value(0),
            "x"
        );
    }

    #[test]
    fn test_filter() {
        let record_batch = create_test_record_batch();
        let plain_batch = PlainBatch::new(record_batch);

        // Keep rows 0 and 2.
        let predicate = BooleanArray::from(vec![true, false, true]);

        let filtered_batch = plain_batch.filter(&predicate).unwrap();

        assert_eq!(filtered_batch.num_rows(), 2);
        assert_eq!(
            filtered_batch
                .column(0)
                .as_any()
                .downcast_ref::<Int32Array>()
                .unwrap()
                .value(0),
            1
        );
        assert_eq!(
            filtered_batch
                .column(0)
                .as_any()
                .downcast_ref::<Int32Array>()
                .unwrap()
                .value(1),
            3
        );
    }
}