1use std::collections::HashMap;
18use std::sync::Arc;
19
20use api::v1::OpType;
21use datatypes::arrow::array::{ArrayRef, BooleanArray, UInt64Array, UInt8Array};
22use datatypes::arrow::compute::filter_record_batch;
23use datatypes::arrow::datatypes::SchemaRef;
24use datatypes::arrow::record_batch::RecordBatch;
25use snafu::{OptionExt, ResultExt};
26use store_api::metadata::{ColumnMetadata, RegionMetadata};
27use store_api::storage::{RegionId, SequenceNumber};
28
29use crate::error::{
30 ComputeArrowSnafu, CreateDefaultSnafu, InvalidRequestSnafu, NewRecordBatchSnafu, Result,
31 UnexpectedSnafu,
32};
33
/// Number of trailing fixed-position internal columns (sequence and op type)
/// in a plain record batch.
pub(crate) const PLAIN_FIXED_POS_COLUMN_NUM: usize = 2;
38
/// A batch of rows in the plain format: user columns followed by the
/// internal sequence and op type columns at the end.
#[derive(Debug)]
pub struct PlainBatch {
    // Underlying arrow batch; `PlainBatch::new` asserts it has at least the
    // two trailing internal columns.
    record_batch: RecordBatch,
}
50
51impl PlainBatch {
52 pub fn new(record_batch: RecordBatch) -> Self {
54 assert!(
55 record_batch.num_columns() >= 2,
56 "record batch missing internal columns, num_columns: {}",
57 record_batch.num_columns()
58 );
59
60 Self { record_batch }
61 }
62
63 pub fn with_new_columns(&self, columns: Vec<ArrayRef>) -> Result<Self> {
65 let record_batch = RecordBatch::try_new(self.record_batch.schema(), columns)
66 .context(NewRecordBatchSnafu)?;
67 Ok(Self::new(record_batch))
68 }
69
70 pub fn num_columns(&self) -> usize {
72 self.record_batch.num_columns()
73 }
74
75 pub fn num_rows(&self) -> usize {
77 self.record_batch.num_rows()
78 }
79
80 pub fn is_empty(&self) -> bool {
82 self.num_rows() == 0
83 }
84
85 pub fn columns(&self) -> &[ArrayRef] {
87 self.record_batch.columns()
88 }
89
90 pub fn column(&self, idx: usize) -> &ArrayRef {
92 self.record_batch.column(idx)
93 }
94
95 pub fn internal_columns(&self) -> &[ArrayRef] {
97 &self.record_batch.columns()[self.record_batch.num_columns() - PLAIN_FIXED_POS_COLUMN_NUM..]
98 }
99
100 pub fn as_record_batch(&self) -> &RecordBatch {
102 &self.record_batch
103 }
104
105 pub fn into_record_batch(self) -> RecordBatch {
107 self.record_batch
108 }
109
110 pub fn filter(&self, predicate: &BooleanArray) -> Result<Self> {
112 let record_batch =
113 filter_record_batch(&self.record_batch, predicate).context(ComputeArrowSnafu)?;
114 Ok(Self::new(record_batch))
115 }
116
117 #[allow(dead_code)]
119 pub(crate) fn sequence_column_index(&self) -> usize {
120 self.record_batch.num_columns() - PLAIN_FIXED_POS_COLUMN_NUM
121 }
122}
123
/// Fills missing columns of input record batches so they match the region's
/// full plain schema (user columns plus internal columns).
pub struct ColumnFiller<'a> {
    /// Metadata of the region that defines the expected columns.
    metadata: &'a RegionMetadata,
    /// Output arrow schema; expected to include the internal columns
    /// (checked by a debug assertion in `new`).
    schema: SchemaRef,
    /// Maps input column names to their indices in the input batch.
    name_to_index: HashMap<String, usize>,
}
133
134impl<'a> ColumnFiller<'a> {
135 pub fn new(
138 metadata: &'a RegionMetadata,
139 schema: SchemaRef,
140 record_batch: &RecordBatch,
141 ) -> Self {
142 debug_assert_eq!(metadata.column_metadatas.len() + 2, schema.fields().len());
143
144 let name_to_index: HashMap<_, _> = record_batch
146 .schema()
147 .fields()
148 .iter()
149 .enumerate()
150 .map(|(i, field)| (field.name().clone(), i))
151 .collect();
152
153 Self {
154 metadata,
155 schema,
156 name_to_index,
157 }
158 }
159
160 pub fn fill_missing_columns(
162 &self,
163 record_batch: &RecordBatch,
164 sequence: SequenceNumber,
165 op_type: OpType,
166 ) -> Result<RecordBatch> {
167 let num_rows = record_batch.num_rows();
168 let mut new_columns =
169 Vec::with_capacity(record_batch.num_columns() + PLAIN_FIXED_POS_COLUMN_NUM);
170
171 for column in &self.metadata.column_metadatas {
174 let array = match self.name_to_index.get(&column.column_schema.name) {
175 Some(index) => record_batch.column(*index).clone(),
176 None => match op_type {
177 OpType::Put => {
178 fill_column_put_default(self.metadata.region_id, column, num_rows)?
180 }
181 OpType::Delete => {
182 fill_column_delete_default(column, num_rows)?
184 }
185 },
186 };
187
188 new_columns.push(array);
189 }
190
191 let sequence_array = Arc::new(UInt64Array::from(vec![sequence; num_rows]));
194 let op_type_array = Arc::new(UInt8Array::from(vec![op_type as u8; num_rows]));
196 new_columns.push(sequence_array);
197 new_columns.push(op_type_array);
198
199 RecordBatch::try_new(self.schema.clone(), new_columns).context(NewRecordBatchSnafu)
200 }
201}
202
203fn fill_column_put_default(
204 region_id: RegionId,
205 column: &ColumnMetadata,
206 num_rows: usize,
207) -> Result<ArrayRef> {
208 if column.column_schema.is_default_impure() {
209 return UnexpectedSnafu {
210 reason: format!(
211 "unexpected impure default value with region_id: {}, column: {}, default_value: {:?}",
212 region_id,
213 column.column_schema.name,
214 column.column_schema.default_constraint(),
215 ),
216 }
217 .fail();
218 }
219 let vector = column
220 .column_schema
221 .create_default_vector(num_rows)
222 .context(CreateDefaultSnafu {
223 region_id,
224 column: &column.column_schema.name,
225 })?
226 .with_context(|| InvalidRequestSnafu {
228 region_id,
229 reason: format!(
230 "column {} does not have default value",
231 column.column_schema.name
232 ),
233 })?;
234 Ok(vector.to_arrow_array())
235}
236
237fn fill_column_delete_default(column: &ColumnMetadata, num_rows: usize) -> Result<ArrayRef> {
238 let vector = column
240 .column_schema
241 .create_default_vector_for_padding(num_rows);
242 Ok(vector.to_arrow_array())
243}
244
#[cfg(test)]
mod tests {
    use api::v1::SemanticType;
    use datatypes::arrow::array::{
        Float64Array, Int32Array, StringArray, TimestampMillisecondArray,
    };
    use datatypes::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
    use datatypes::schema::constraint::ColumnDefaultConstraint;
    use datatypes::schema::ColumnSchema;
    use datatypes::value::Value;
    use store_api::metadata::{ColumnMetadata, RegionMetadataBuilder};
    use store_api::storage::consts::{OP_TYPE_COLUMN_NAME, SEQUENCE_COLUMN_NAME};
    use store_api::storage::{ConcreteDataType, RegionId};

    use super::*;
    use crate::sst::to_plain_sst_arrow_schema;

    /// Builds a region with tag `k0`, time index `ts` and a nullable float
    /// field `v1` whose default value is `42.0`.
    fn create_test_region_metadata() -> RegionMetadata {
        let mut builder = RegionMetadataBuilder::new(RegionId::new(100, 200));
        builder
            .push_column_metadata(ColumnMetadata {
                column_schema: ColumnSchema::new("k0", ConcreteDataType::string_datatype(), false)
                    .with_default_constraint(None)
                    .unwrap(),
                semantic_type: SemanticType::Tag,
                column_id: 0,
            })
            .push_column_metadata(ColumnMetadata {
                column_schema: ColumnSchema::new(
                    "ts",
                    ConcreteDataType::timestamp_millisecond_datatype(),
                    false,
                )
                .with_time_index(true)
                .with_default_constraint(None)
                .unwrap(),
                semantic_type: SemanticType::Timestamp,
                column_id: 1,
            })
            .push_column_metadata(ColumnMetadata {
                column_schema: ColumnSchema::new("v1", ConcreteDataType::float64_datatype(), true)
                    .with_default_constraint(Some(ColumnDefaultConstraint::Value(Value::Float64(
                        datatypes::value::OrderedFloat::from(42.0),
                    ))))
                    .unwrap(),
                semantic_type: SemanticType::Field,
                column_id: 2,
            })
            .primary_key(vec![0]);

        builder.build().unwrap()
    }

    #[test]
    fn test_column_filler_put() {
        let region_metadata = create_test_region_metadata();
        // Fixed mojibake: `®ion_metadata` was a corrupted `&region_metadata`.
        let output_schema = to_plain_sst_arrow_schema(&region_metadata);

        // Input batch omits `v1`; the filler must synthesize its default.
        let input_schema = Arc::new(Schema::new(vec![
            Field::new("k0", DataType::Utf8, false),
            Field::new(
                "ts",
                DataType::Timestamp(TimeUnit::Millisecond, None),
                false,
            ),
        ]));

        let k0_values: ArrayRef = Arc::new(StringArray::from(vec!["key1", "key2"]));
        let ts_values: ArrayRef = Arc::new(TimestampMillisecondArray::from(vec![1000, 2000]));

        let input_batch =
            RecordBatch::try_new(input_schema, vec![k0_values.clone(), ts_values.clone()]).unwrap();

        let filler = ColumnFiller::new(&region_metadata, output_schema.clone(), &input_batch);

        let result = filler
            .fill_missing_columns(&input_batch, 100, OpType::Put)
            .unwrap();

        // `v1` is filled with its default (42.0) on Put.
        let expected_columns = vec![
            k0_values.clone(),
            ts_values.clone(),
            Arc::new(Float64Array::from(vec![42.0, 42.0])),
            Arc::new(UInt64Array::from(vec![100, 100])),
            Arc::new(UInt8Array::from(vec![OpType::Put as u8, OpType::Put as u8])),
        ];
        let expected_batch = RecordBatch::try_new(output_schema.clone(), expected_columns).unwrap();
        assert_eq!(expected_batch, result);
    }

    #[test]
    fn test_column_filler_delete() {
        let region_metadata = create_test_region_metadata();
        // Fixed mojibake: `®ion_metadata` was a corrupted `&region_metadata`.
        let output_schema = to_plain_sst_arrow_schema(&region_metadata);

        let input_schema = Arc::new(Schema::new(vec![
            Field::new("k0", DataType::Utf8, false),
            Field::new(
                "ts",
                DataType::Timestamp(TimeUnit::Millisecond, None),
                false,
            ),
        ]));

        let k0_values: ArrayRef = Arc::new(StringArray::from(vec!["key1", "key2"]));
        let ts_values: ArrayRef = Arc::new(TimestampMillisecondArray::from(vec![1000, 2000]));

        let input_batch =
            RecordBatch::try_new(input_schema, vec![k0_values.clone(), ts_values.clone()]).unwrap();

        let filler = ColumnFiller::new(&region_metadata, output_schema.clone(), &input_batch);

        let result = filler
            .fill_missing_columns(&input_batch, 200, OpType::Delete)
            .unwrap();

        // On Delete the missing field is padded with nulls, not its default.
        let v1_default = Arc::new(Float64Array::from(vec![None, None]));
        let expected_columns = vec![
            k0_values.clone(),
            ts_values.clone(),
            v1_default,
            Arc::new(UInt64Array::from(vec![200, 200])),
            Arc::new(UInt8Array::from(vec![
                OpType::Delete as u8,
                OpType::Delete as u8,
            ])),
        ];
        let expected_batch = RecordBatch::try_new(output_schema.clone(), expected_columns).unwrap();
        assert_eq!(expected_batch, result);
    }

    /// Builds a 3-row plain batch with two user columns plus the internal
    /// sequence and op type columns.
    fn create_test_record_batch() -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("col1", DataType::Int32, false),
            Field::new("col2", DataType::Utf8, false),
            Field::new(SEQUENCE_COLUMN_NAME, DataType::UInt64, false),
            Field::new(OP_TYPE_COLUMN_NAME, DataType::UInt8, false),
        ]));

        let col1 = Arc::new(Int32Array::from(vec![1, 2, 3]));
        let col2 = Arc::new(StringArray::from(vec!["a", "b", "c"]));
        let sequence = Arc::new(UInt64Array::from(vec![100, 101, 102]));
        let op_type = Arc::new(UInt8Array::from(vec![1, 1, 1]));

        RecordBatch::try_new(schema, vec![col1, col2, sequence, op_type]).unwrap()
    }

    #[test]
    fn test_plain_batch_basic_methods() {
        let record_batch = create_test_record_batch();
        let plain_batch = PlainBatch::new(record_batch.clone());

        assert_eq!(plain_batch.num_columns(), 4);
        assert_eq!(plain_batch.num_rows(), 3);
        assert!(!plain_batch.is_empty());
        assert_eq!(plain_batch.columns().len(), 4);

        let internal_columns = plain_batch.internal_columns();
        assert_eq!(internal_columns.len(), PLAIN_FIXED_POS_COLUMN_NUM);
        assert_eq!(internal_columns[0].len(), 3);
        assert_eq!(internal_columns[1].len(), 3);

        let col1 = plain_batch.column(0);
        assert_eq!(col1.len(), 3);
        assert_eq!(
            col1.as_any().downcast_ref::<Int32Array>().unwrap().value(0),
            1
        );

        assert_eq!(plain_batch.sequence_column_index(), 2);

        assert_eq!(record_batch, *plain_batch.as_record_batch());
        assert_eq!(record_batch, plain_batch.into_record_batch());
    }

    #[test]
    fn test_with_new_columns() {
        let record_batch = create_test_record_batch();
        let plain_batch = PlainBatch::new(record_batch);

        let col1 = Arc::new(Int32Array::from(vec![10, 20, 30]));
        let col2 = Arc::new(StringArray::from(vec!["x", "y", "z"]));
        let sequence = Arc::new(UInt64Array::from(vec![200, 201, 202]));
        let op_type = Arc::new(UInt8Array::from(vec![0, 0, 0]));

        let new_batch = plain_batch
            .with_new_columns(vec![col1, col2, sequence, op_type])
            .unwrap();

        assert_eq!(new_batch.num_columns(), 4);
        assert_eq!(new_batch.num_rows(), 3);
        assert_eq!(
            new_batch
                .column(0)
                .as_any()
                .downcast_ref::<Int32Array>()
                .unwrap()
                .value(0),
            10
        );
        assert_eq!(
            new_batch
                .column(1)
                .as_any()
                .downcast_ref::<StringArray>()
                .unwrap()
                .value(0),
            "x"
        );
    }

    #[test]
    fn test_filter() {
        let record_batch = create_test_record_batch();
        let plain_batch = PlainBatch::new(record_batch);

        let predicate = BooleanArray::from(vec![true, false, true]);

        let filtered_batch = plain_batch.filter(&predicate).unwrap();

        assert_eq!(filtered_batch.num_rows(), 2);
        assert_eq!(
            filtered_batch
                .column(0)
                .as_any()
                .downcast_ref::<Int32Array>()
                .unwrap()
                .value(0),
            1
        );
        assert_eq!(
            filtered_batch
                .column(0)
                .as_any()
                .downcast_ref::<Int32Array>()
                .unwrap()
                .value(1),
            3
        );
    }
}