use std::collections::VecDeque;
use std::sync::Arc;

use bytes::Bytes;
use datatypes::arrow::array::BooleanArray;
use datatypes::arrow::record_batch::RecordBatch;
use parquet::arrow::ProjectionMask;
use parquet::arrow::arrow_reader::ParquetRecordBatchReader;
use parquet::file::metadata::ParquetMetaData;
use snafu::ResultExt;
use store_api::storage::SequenceRange;

use crate::error::{self, ComputeArrowSnafu, DecodeArrowRowGroupSnafu};
use crate::memtable::bulk::context::{BulkIterContext, BulkIterContextRef};
use crate::memtable::bulk::row_group_reader::MemtableRowGroupReaderBuilder;
use crate::sst::parquet::file_range::PreFilterMode;
use crate::sst::parquet::flat_format::sequence_column_index;
use crate::sst::parquet::reader::RowGroupReaderContext;

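/// Iterator over an encoded bulk part (Parquet bytes), decoding and filtering
/// one row group at a time.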
pub struct EncodedBulkPartIter {
    /// Context of the iteration.
    context: BulkIterContextRef,
    /// Indices of row groups that remain to be read.
    row_groups_to_read: VecDeque<usize>,
    /// Reader of the row group currently being read.
    current_reader: Option<ParquetRecordBatchReader>,
    /// Builder to create a reader for each row group.
    builder: MemtableRowGroupReaderBuilder,
    /// Sequence range to filter rows by, if any.
    sequence: Option<SequenceRange>,
    /// Whether to skip filters on field columns for the current row group.
    current_skip_fields: bool,
}

impl EncodedBulkPartIter {
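    /// Creates an iterator over the given `row_groups_to_read` of the encoded part.
    ///
    /// # Panics
    /// Panics if the read format of the `context` is not the flat format.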
    pub(crate) fn try_new(
        context: BulkIterContextRef,
        mut row_groups_to_read: VecDeque<usize>,
        parquet_meta: Arc<ParquetMetaData>,
        data: Bytes,
        sequence: Option<SequenceRange>,
    ) -> error::Result<Self> {
        assert!(context.read_format().as_flat().is_some());

        let projection_mask = ProjectionMask::roots(
            parquet_meta.file_metadata().schema_descr(),
            context.read_format().projection_indices().iter().copied(),
        );
        let builder =
            MemtableRowGroupReaderBuilder::try_new(&context, projection_mask, parquet_meta, data)?;

        let (init_reader, current_skip_fields) = match row_groups_to_read.pop_front() {
            Some(first_row_group) => {
                let skip_fields = builder.compute_skip_fields(&context, first_row_group);
                let reader = builder.build_row_group_reader(first_row_group, None)?;
                (Some(reader), skip_fields)
            }
            None => (None, false),
        };

        Ok(Self {
            context,
            row_groups_to_read,
            current_reader: init_reader,
            builder,
            sequence,
            current_skip_fields,
        })
    }

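    /// Returns the next [`RecordBatch`] that passes the filters, advancing to
    /// the next row group when the current one is exhausted.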
    pub(crate) fn next_record_batch(&mut self) -> error::Result<Option<RecordBatch>> {
        // All row groups are exhausted.
        let Some(current) = &mut self.current_reader else {
            return Ok(None);
        };

        // Drains the current row group first.
        for batch in current {
            let batch = batch.context(DecodeArrowRowGroupSnafu)?;
            if let Some(batch) = apply_combined_filters(
                &self.context,
                &self.sequence,
                batch,
                self.current_skip_fields,
            )? {
                return Ok(Some(batch));
            }
        }

        // The current row group is exhausted, advances to the next one.
        while let Some(next_row_group) = self.row_groups_to_read.pop_front() {
            self.current_skip_fields = self
                .builder
                .compute_skip_fields(&self.context, next_row_group);

            let next_reader = self.builder.build_row_group_reader(next_row_group, None)?;
            let current = self.current_reader.insert(next_reader);

            for batch in current {
                let batch = batch.context(DecodeArrowRowGroupSnafu)?;
                if let Some(batch) = apply_combined_filters(
                    &self.context,
                    &self.sequence,
                    batch,
                    self.current_skip_fields,
                )? {
                    return Ok(Some(batch));
                }
            }
        }

        Ok(None)
    }
}

impl Iterator for EncodedBulkPartIter {
    type Item = error::Result<RecordBatch>;

    fn next(&mut self) -> Option<Self::Item> {
        self.next_record_batch().transpose()
    }
}

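/// Iterator that projects and filters a single [`RecordBatch`] of a bulk part,
/// yielding at most one batch.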
pub struct BulkPartRecordBatchIter {
    /// The batch to read, consumed on the first call to `next`.
    record_batch: Option<RecordBatch>,
    /// Context of the iteration.
    context: BulkIterContextRef,
    /// Sequence range to filter rows by, if any.
    sequence: Option<SequenceRange>,
}

impl BulkPartRecordBatchIter {
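    /// Creates a new iterator from `record_batch`.
    ///
    /// # Panics
    /// Panics if the read format of the `context` is not the flat format.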
    pub fn new(
        record_batch: RecordBatch,
        context: BulkIterContextRef,
        sequence: Option<SequenceRange>,
    ) -> Self {
        assert!(context.read_format().as_flat().is_some());

        Self {
            record_batch: Some(record_batch),
            context,
            sequence,
        }
    }

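    /// Projects `record_batch` to the columns selected by the read context.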
    fn apply_projection(&self, record_batch: RecordBatch) -> error::Result<RecordBatch> {
        let projection_indices = self.context.read_format().projection_indices();
        // Skips the projection if every column is already selected.
        if projection_indices.len() == record_batch.num_columns() {
            return Ok(record_batch);
        }

        record_batch
            .project(projection_indices)
            .context(ComputeArrowSnafu)
    }

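    /// Projects and filters `record_batch`, returning `None` if no row remains.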
    fn process_batch(&mut self, record_batch: RecordBatch) -> error::Result<Option<RecordBatch>> {
        let projected_batch = self.apply_projection(record_batch)?;
        // Decides whether filters on field columns can be skipped under the
        // current pre-filter mode.
        let skip_fields = match self.context.pre_filter_mode() {
            PreFilterMode::All => false,
            PreFilterMode::SkipFields => true,
            PreFilterMode::SkipFieldsOnDelete => true,
        };
        let Some(filtered_batch) =
            apply_combined_filters(&self.context, &self.sequence, projected_batch, skip_fields)?
        else {
            return Ok(None);
        };

        Ok(Some(filtered_batch))
    }
}

impl Iterator for BulkPartRecordBatchIter {
    type Item = error::Result<RecordBatch>;

    fn next(&mut self) -> Option<Self::Item> {
        // Yields the batch only once.
        let record_batch = self.record_batch.take()?;

        self.process_batch(record_batch).transpose()
    }
}

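/// Applies the predicate and sequence filters to `record_batch`.
///
/// Converts the batch to the flat read format, combines the predicate mask
/// with the sequence range mask, and filters the batch. Returns `Ok(None)`
/// if no row satisfies the filters.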
fn apply_combined_filters(
    context: &BulkIterContext,
    sequence: &Option<SequenceRange>,
    record_batch: RecordBatch,
    skip_fields: bool,
) -> error::Result<Option<RecordBatch>> {
    // Safety: the iterator constructors assert that the format is flat.
    let format = context.read_format().as_flat().unwrap();
    let record_batch = format.convert_batch(record_batch, None)?;

    let num_rows = record_batch.num_rows();
    let mut combined_filter = None;

    // Evaluates the predicates, if any.
    if !context.base.filters.is_empty() {
        let predicate_mask = context
            .base
            .compute_filter_mask_flat(&record_batch, skip_fields)?;
        // The whole batch is filtered out.
        let Some(mask) = predicate_mask else {
            return Ok(None);
        };
        combined_filter = Some(BooleanArray::from(mask));
    }

    // Filters rows by the sequence range, if provided.
    if let Some(sequence) = sequence {
        let sequence_column =
            record_batch.column(sequence_column_index(record_batch.num_columns()));
        let sequence_filter = sequence
            .filter(&sequence_column)
            .context(ComputeArrowSnafu)?;
        combined_filter = match combined_filter {
            None => Some(sequence_filter),
            Some(existing_filter) => {
                let and_result = datatypes::arrow::compute::and(&existing_filter, &sequence_filter)
                    .context(ComputeArrowSnafu)?;
                Some(and_result)
            }
        };
    }

    // Nothing to filter, returns the batch as-is.
    let Some(filter_array) = combined_filter else {
        return Ok(Some(record_batch));
    };
    let select_count = filter_array.true_count();
    if select_count == 0 {
        return Ok(None);
    }
    if select_count == num_rows {
        return Ok(Some(record_batch));
    }
    let filtered_batch =
        datatypes::arrow::compute::filter_record_batch(&record_batch, &filter_array)
            .context(ComputeArrowSnafu)?;

    Ok(Some(filtered_batch))
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use api::v1::SemanticType;
    use datafusion_expr::{col, lit};
    use datatypes::arrow::array::{ArrayRef, Int64Array, StringArray, UInt8Array, UInt64Array};
    use datatypes::arrow::datatypes::{DataType, Field, Schema};
    use datatypes::data_type::ConcreteDataType;
    use datatypes::schema::ColumnSchema;
    use store_api::metadata::{ColumnMetadata, RegionMetadataBuilder};
    use store_api::storage::RegionId;
    use table::predicate::Predicate;

    use super::*;
    use crate::memtable::bulk::context::BulkIterContext;

    #[test]
    fn test_bulk_part_record_batch_iter() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("key1", DataType::Utf8, false),
            Field::new("field1", DataType::Int64, false),
            Field::new(
                "timestamp",
                DataType::Timestamp(datatypes::arrow::datatypes::TimeUnit::Millisecond, None),
                false,
            ),
            Field::new(
                "__primary_key",
                DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Binary)),
                false,
            ),
            Field::new("__sequence", DataType::UInt64, false),
            Field::new("__op_type", DataType::UInt8, false),
        ]));

        let key1 = Arc::new(StringArray::from_iter_values(["key1", "key2", "key3"]));
        let field1 = Arc::new(Int64Array::from(vec![11, 12, 13]));
        let timestamp = Arc::new(datatypes::arrow::array::TimestampMillisecondArray::from(
            vec![1000, 2000, 3000],
        ));

        use datatypes::arrow::array::{BinaryArray, DictionaryArray, UInt32Array};
        let values = Arc::new(BinaryArray::from_iter_values([b"key1", b"key2", b"key3"]));
        let keys = UInt32Array::from(vec![0, 1, 2]);
        let primary_key = Arc::new(DictionaryArray::new(keys, values));

        let sequence = Arc::new(UInt64Array::from(vec![1, 2, 3]));
        let op_type = Arc::new(UInt8Array::from(vec![1, 1, 1]));

        let record_batch = RecordBatch::try_new(
            schema,
            vec![
                key1,
                field1,
                timestamp,
                primary_key.clone(),
                sequence,
                op_type,
            ],
        )
        .unwrap();

        let mut builder = RegionMetadataBuilder::new(RegionId::new(1, 1));
        builder
            .push_column_metadata(ColumnMetadata {
                column_schema: ColumnSchema::new(
                    "key1",
                    ConcreteDataType::string_datatype(),
                    false,
                ),
                semantic_type: SemanticType::Tag,
                column_id: 0,
            })
            .push_column_metadata(ColumnMetadata {
                column_schema: ColumnSchema::new(
                    "field1",
                    ConcreteDataType::int64_datatype(),
                    false,
                ),
                semantic_type: SemanticType::Field,
                column_id: 1,
            })
            .push_column_metadata(ColumnMetadata {
                column_schema: ColumnSchema::new(
                    "timestamp",
                    ConcreteDataType::timestamp_millisecond_datatype(),
                    false,
                ),
                semantic_type: SemanticType::Timestamp,
                column_id: 2,
            })
            .primary_key(vec![0]);

        let region_metadata = builder.build().unwrap();

        let context = Arc::new(
            BulkIterContext::new(
                Arc::new(region_metadata.clone()),
                None,
                None,
                false,
            )
            .unwrap(),
        );
        // Reads all rows when there is no filter.
        let iter = BulkPartRecordBatchIter::new(record_batch.clone(), context.clone(), None);
        let result: Vec<_> = iter.map(|rb| rb.unwrap()).collect();
        assert_eq!(1, result.len());
        assert_eq!(3, result[0].num_rows());
        assert_eq!(6, result[0].num_columns());

        // Filters rows by the sequence range.
        let iter = BulkPartRecordBatchIter::new(
            record_batch.clone(),
            context,
            Some(SequenceRange::LtEq { max: 2 }),
        );
        let result: Vec<_> = iter.map(|rb| rb.unwrap()).collect();
        assert_eq!(1, result.len());
        let expect_sequence = Arc::new(UInt64Array::from(vec![1, 2])) as ArrayRef;
        assert_eq!(
            &expect_sequence,
            result[0].column(result[0].num_columns() - 2)
        );
        assert_eq!(6, result[0].num_columns());

        // Projects columns and filters rows by the predicate.
        let context = Arc::new(
            BulkIterContext::new(
                Arc::new(region_metadata),
                Some(&[0, 2]),
                Some(Predicate::new(vec![col("key1").eq(lit("key2"))])),
                false,
            )
            .unwrap(),
        );
        let iter = BulkPartRecordBatchIter::new(record_batch.clone(), context.clone(), None);
        let result: Vec<_> = iter.map(|rb| rb.unwrap()).collect();
        assert_eq!(1, result.len());
        assert_eq!(1, result[0].num_rows());
        assert_eq!(5, result[0].num_columns());
        let expect_sequence = Arc::new(UInt64Array::from(vec![2])) as ArrayRef;
        assert_eq!(
            &expect_sequence,
            result[0].column(result[0].num_columns() - 2)
        );
    }
}