mito2/memtable/bulk/part_reader.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::VecDeque;
use std::sync::Arc;

use bytes::Bytes;
use datatypes::arrow::array::BooleanArray;
use datatypes::arrow::record_batch::RecordBatch;
use parquet::arrow::ProjectionMask;
use parquet::arrow::arrow_reader::ParquetRecordBatchReader;
use parquet::file::metadata::ParquetMetaData;
use snafu::ResultExt;
use store_api::storage::SequenceRange;

use crate::error::{self, ComputeArrowSnafu, DecodeArrowRowGroupSnafu};
use crate::memtable::bulk::context::{BulkIterContext, BulkIterContextRef};
use crate::memtable::bulk::row_group_reader::MemtableRowGroupReaderBuilder;
use crate::sst::parquet::file_range::PreFilterMode;
use crate::sst::parquet::flat_format::sequence_column_index;
use crate::sst::parquet::reader::RowGroupReaderContext;

/// Iterator for reading data inside a bulk part.
pub struct EncodedBulkPartIter {
    /// Context of the iteration, including the read format and filters.
    context: BulkIterContextRef,
    /// Indices of the row groups that remain to be read.
    row_groups_to_read: VecDeque<usize>,
    /// Reader of the row group currently being read, if any.
    current_reader: Option<ParquetRecordBatchReader>,
    /// Builder that creates a reader for each row group.
    builder: MemtableRowGroupReaderBuilder,
    /// Sequence number filter.
    sequence: Option<SequenceRange>,
    /// Cached `skip_fields` flag for the current row group.
    current_skip_fields: bool,
}

impl EncodedBulkPartIter {
    /// Creates a new [EncodedBulkPartIter].
    pub(crate) fn try_new(
        context: BulkIterContextRef,
        mut row_groups_to_read: VecDeque<usize>,
        parquet_meta: Arc<ParquetMetaData>,
        data: Bytes,
        sequence: Option<SequenceRange>,
    ) -> error::Result<Self> {
        assert!(context.read_format().as_flat().is_some());

        let projection_mask = ProjectionMask::roots(
            parquet_meta.file_metadata().schema_descr(),
            context.read_format().projection_indices().iter().copied(),
        );
        let builder =
            MemtableRowGroupReaderBuilder::try_new(&context, projection_mask, parquet_meta, data)?;

        let (init_reader, current_skip_fields) = match row_groups_to_read.pop_front() {
            Some(first_row_group) => {
                let skip_fields = builder.compute_skip_fields(&context, first_row_group);
                let reader = builder.build_row_group_reader(first_row_group, None)?;
                (Some(reader), skip_fields)
            }
            None => (None, false),
        };

        Ok(Self {
            context,
            row_groups_to_read,
            current_reader: init_reader,
            builder,
            sequence,
            current_skip_fields,
        })
    }

    /// Fetches the next non-empty record batch.
    pub(crate) fn next_record_batch(&mut self) -> error::Result<Option<RecordBatch>> {
        let Some(current) = &mut self.current_reader else {
            // All row groups are exhausted.
            return Ok(None);
        };

        for batch in current {
            let batch = batch.context(DecodeArrowRowGroupSnafu)?;
            if let Some(batch) = apply_combined_filters(
                &self.context,
                &self.sequence,
                batch,
                self.current_skip_fields,
            )? {
                return Ok(Some(batch));
            }
        }

        // The previous row group is exhausted, read the next row group.
        while let Some(next_row_group) = self.row_groups_to_read.pop_front() {
            // Computes `skip_fields` for this row group.
            self.current_skip_fields = self
                .builder
                .compute_skip_fields(&self.context, next_row_group);

            let next_reader = self.builder.build_row_group_reader(next_row_group, None)?;
            let current = self.current_reader.insert(next_reader);

            for batch in current {
                let batch = batch.context(DecodeArrowRowGroupSnafu)?;
                if let Some(batch) = apply_combined_filters(
                    &self.context,
                    &self.sequence,
                    batch,
                    self.current_skip_fields,
                )? {
                    return Ok(Some(batch));
                }
            }
        }

        Ok(None)
    }
}

impl Iterator for EncodedBulkPartIter {
    type Item = error::Result<RecordBatch>;

    fn next(&mut self) -> Option<Self::Item> {
        self.next_record_batch().transpose()
    }
}
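
// Illustrative usage sketch (added for clarity; not part of the original
// source). `context`, `row_groups`, `parquet_meta`, and `data` are assumed to
// be prepared by the surrounding bulk memtable code:
//
// let mut iter =
//     EncodedBulkPartIter::try_new(context, row_groups, parquet_meta, data, None)?;
// while let Some(batch) = iter.next().transpose()? {
//     // Each yielded `RecordBatch` is already in the flat format, with
//     // predicate and sequence filters applied.
// }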

/// Iterator over a single record batch in a bulk part.
pub struct BulkPartRecordBatchIter {
    /// The [RecordBatch] to read from; taken on the first call to `next`.
    record_batch: Option<RecordBatch>,
    /// Iterator context for filtering.
    context: BulkIterContextRef,
    /// Sequence number filter.
    sequence: Option<SequenceRange>,
}

impl BulkPartRecordBatchIter {
    /// Creates a new [BulkPartRecordBatchIter] from a [RecordBatch].
    pub fn new(
        record_batch: RecordBatch,
        context: BulkIterContextRef,
        sequence: Option<SequenceRange>,
    ) -> Self {
        assert!(context.read_format().as_flat().is_some());

        Self {
            record_batch: Some(record_batch),
            context,
            sequence,
        }
    }

    /// Applies the projection to the [RecordBatch] if needed.
    fn apply_projection(&self, record_batch: RecordBatch) -> error::Result<RecordBatch> {
        let projection_indices = self.context.read_format().projection_indices();
        if projection_indices.len() == record_batch.num_columns() {
            return Ok(record_batch);
        }

        record_batch
            .project(projection_indices)
            .context(ComputeArrowSnafu)
    }

    fn process_batch(&mut self, record_batch: RecordBatch) -> error::Result<Option<RecordBatch>> {
        // Apply the projection first.
        let projected_batch = self.apply_projection(record_batch)?;
        // Apply combined filtering (both predicate and sequence filters).
        // [BulkPartRecordBatchIter] has no row group information, so
        // `SkipFieldsOnDelete` behaves like `SkipFields` here.
        let skip_fields = match self.context.pre_filter_mode() {
            PreFilterMode::All => false,
            PreFilterMode::SkipFields | PreFilterMode::SkipFieldsOnDelete => true,
        };
        let Some(filtered_batch) =
            apply_combined_filters(&self.context, &self.sequence, projected_batch, skip_fields)?
        else {
            return Ok(None);
        };

        Ok(Some(filtered_batch))
    }
}

impl Iterator for BulkPartRecordBatchIter {
    type Item = error::Result<RecordBatch>;

    fn next(&mut self) -> Option<Self::Item> {
        let record_batch = self.record_batch.take()?;

        self.process_batch(record_batch).transpose()
    }
}
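
// Illustrative usage sketch (added for clarity; not part of the original
// source). The iterator yields at most one filtered batch; `record_batch` and
// `context` are assumed to be prepared as in the tests below:
//
// let iter = BulkPartRecordBatchIter::new(
//     record_batch,
//     context,
//     Some(SequenceRange::LtEq { max: 2 }),
// );
// for batch in iter {
//     let batch = batch?;
//     // `batch` keeps only rows whose sequence is <= 2 and that pass the
//     // predicate filters.
// }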

/// Applies both predicate filtering and sequence filtering in a single pass.
/// Returns `None` if the filtered batch is empty.
///
/// # Panics
/// Panics if the format is not flat.
fn apply_combined_filters(
    context: &BulkIterContext,
    sequence: &Option<SequenceRange>,
    record_batch: RecordBatch,
    skip_fields: bool,
) -> error::Result<Option<RecordBatch>> {
    // Converts the record batch to the flat format first.
    let format = context.read_format().as_flat().unwrap();
    let record_batch = format.convert_batch(record_batch, None)?;

    let num_rows = record_batch.num_rows();
    let mut combined_filter = None;

    // First, apply predicate filters using the shared method.
    if !context.base.filters.is_empty() {
        let predicate_mask = context
            .base
            .compute_filter_mask_flat(&record_batch, skip_fields)?;
        // If the predicate filters out the entire batch, return None early.
        let Some(mask) = predicate_mask else {
            return Ok(None);
        };
        combined_filter = Some(BooleanArray::from(mask));
    }

    // Filters rows by the given `sequence` range. Only preserves rows whose
    // sequence is selected by the range.
    if let Some(sequence) = sequence {
        let sequence_column =
            record_batch.column(sequence_column_index(record_batch.num_columns()));
        let sequence_filter = sequence
            .filter(&sequence_column)
            .context(ComputeArrowSnafu)?;
        // Combine with the existing filter using an AND operation.
        combined_filter = match combined_filter {
            None => Some(sequence_filter),
            Some(existing_filter) => {
                let and_result = datatypes::arrow::compute::and(&existing_filter, &sequence_filter)
                    .context(ComputeArrowSnafu)?;
                Some(and_result)
            }
        };
    }

    // Apply the combined filter if any filter was built.
    let Some(filter_array) = combined_filter else {
        // No filters applied, return the original batch.
        return Ok(Some(record_batch));
    };
    let select_count = filter_array.true_count();
    if select_count == 0 {
        return Ok(None);
    }
    if select_count == num_rows {
        return Ok(Some(record_batch));
    }
    let filtered_batch =
        datatypes::arrow::compute::filter_record_batch(&record_batch, &filter_array)
            .context(ComputeArrowSnafu)?;

    Ok(Some(filtered_batch))
}
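
// Worked example for `apply_combined_filters` (illustrative; not part of the
// original source): given a predicate mask `[true, true, false]` and a
// sequence mask `[false, true, true]`, the AND-combined filter is
// `[false, true, false]`. Its `true_count` is 1, which is neither 0 nor
// `num_rows`, so `filter_record_batch` keeps only the middle row. The sketch
// test at the bottom of this file runs this exact combination.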

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use api::v1::SemanticType;
    use datafusion_expr::{col, lit};
    use datatypes::arrow::array::{ArrayRef, Int64Array, StringArray, UInt8Array, UInt64Array};
    use datatypes::arrow::datatypes::{DataType, Field, Schema};
    use datatypes::data_type::ConcreteDataType;
    use datatypes::schema::ColumnSchema;
    use store_api::metadata::{ColumnMetadata, RegionMetadataBuilder};
    use store_api::storage::RegionId;
    use table::predicate::Predicate;

    use super::*;
    use crate::memtable::bulk::context::BulkIterContext;

    #[test]
    fn test_bulk_part_record_batch_iter() {
        // Creates a simple schema.
        let schema = Arc::new(Schema::new(vec![
            Field::new("key1", DataType::Utf8, false),
            Field::new("field1", DataType::Int64, false),
            Field::new(
                "timestamp",
                DataType::Timestamp(datatypes::arrow::datatypes::TimeUnit::Millisecond, None),
                false,
            ),
            Field::new(
                "__primary_key",
                DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Binary)),
                false,
            ),
            Field::new("__sequence", DataType::UInt64, false),
            Field::new("__op_type", DataType::UInt8, false),
        ]));

        // Creates test data.
        let key1 = Arc::new(StringArray::from_iter_values(["key1", "key2", "key3"]));
        let field1 = Arc::new(Int64Array::from(vec![11, 12, 13]));
        let timestamp = Arc::new(datatypes::arrow::array::TimestampMillisecondArray::from(
            vec![1000, 2000, 3000],
        ));

        // Creates the primary key dictionary array.
        use datatypes::arrow::array::{BinaryArray, DictionaryArray, UInt32Array};
        let values = Arc::new(BinaryArray::from_iter_values([b"key1", b"key2", b"key3"]));
        let keys = UInt32Array::from(vec![0, 1, 2]);
        let primary_key = Arc::new(DictionaryArray::new(keys, values));

        let sequence = Arc::new(UInt64Array::from(vec![1, 2, 3]));
        let op_type = Arc::new(UInt8Array::from(vec![1, 1, 1])); // PUT operations

        let record_batch = RecordBatch::try_new(
            schema,
            vec![
                key1,
                field1,
                timestamp,
                primary_key.clone(),
                sequence,
                op_type,
            ],
        )
        .unwrap();

        // Creates a minimal region metadata for testing.
        let mut builder = RegionMetadataBuilder::new(RegionId::new(1, 1));
        builder
            .push_column_metadata(ColumnMetadata {
                column_schema: ColumnSchema::new(
                    "key1",
                    ConcreteDataType::string_datatype(),
                    false,
                ),
                semantic_type: SemanticType::Tag,
                column_id: 0,
            })
            .push_column_metadata(ColumnMetadata {
                column_schema: ColumnSchema::new(
                    "field1",
                    ConcreteDataType::int64_datatype(),
                    false,
                ),
                semantic_type: SemanticType::Field,
                column_id: 1,
            })
            .push_column_metadata(ColumnMetadata {
                column_schema: ColumnSchema::new(
                    "timestamp",
                    ConcreteDataType::timestamp_millisecond_datatype(),
                    false,
                ),
                semantic_type: SemanticType::Timestamp,
                column_id: 2,
            })
            .primary_key(vec![0]);

        let region_metadata = builder.build().unwrap();

        // Creates the context.
        let context = Arc::new(
            BulkIterContext::new(
                Arc::new(region_metadata.clone()),
                None, // No projection
                None, // No predicate
                false,
            )
            .unwrap(),
        );
        // Iterates all rows.
        let iter = BulkPartRecordBatchIter::new(record_batch.clone(), context.clone(), None);
        let result: Vec<_> = iter.map(|rb| rb.unwrap()).collect();
        assert_eq!(1, result.len());
        assert_eq!(3, result[0].num_rows());
        assert_eq!(6, result[0].num_columns());

        // Creates an iter with a sequence filter (only includes sequences <= 2).
        let iter = BulkPartRecordBatchIter::new(
            record_batch.clone(),
            context,
            Some(SequenceRange::LtEq { max: 2 }),
        );
        let result: Vec<_> = iter.map(|rb| rb.unwrap()).collect();
        assert_eq!(1, result.len());
        let expect_sequence = Arc::new(UInt64Array::from(vec![1, 2])) as ArrayRef;
        assert_eq!(
            &expect_sequence,
            result[0].column(result[0].num_columns() - 2)
        );
        assert_eq!(6, result[0].num_columns());

        let context = Arc::new(
            BulkIterContext::new(
                Arc::new(region_metadata),
                Some(&[0, 2]),
                Some(Predicate::new(vec![col("key1").eq(lit("key2"))])),
                false,
            )
            .unwrap(),
        );
        // Creates an iter with a projection and a predicate.
        let iter = BulkPartRecordBatchIter::new(record_batch.clone(), context.clone(), None);
        let result: Vec<_> = iter.map(|rb| rb.unwrap()).collect();
        assert_eq!(1, result.len());
        assert_eq!(1, result[0].num_rows());
        assert_eq!(5, result[0].num_columns());
        let expect_sequence = Arc::new(UInt64Array::from(vec![2])) as ArrayRef;
        assert_eq!(
            &expect_sequence,
            result[0].column(result[0].num_columns() - 2)
        );
    }
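
    /// Minimal sketch test (added for illustration; not part of the original
    /// suite). Demonstrates, in isolation, the mask-combination step used by
    /// `apply_combined_filters`: AND two boolean masks, then apply the result
    /// with `filter_record_batch`.
    #[test]
    fn test_combined_mask_sketch() {
        use datatypes::arrow::compute::{and, filter_record_batch};

        let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int64, false)]));
        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(Int64Array::from(vec![1, 2, 3])) as ArrayRef],
        )
        .unwrap();

        // Predicate mask keeps rows 0 and 1; sequence mask keeps rows 1 and 2.
        let predicate_mask = BooleanArray::from(vec![true, true, false]);
        let sequence_mask = BooleanArray::from(vec![false, true, true]);
        let combined = and(&predicate_mask, &sequence_mask).unwrap();
        assert_eq!(1, combined.true_count());

        // Only the middle row survives the combined filter.
        let filtered = filter_record_batch(&batch, &combined).unwrap();
        assert_eq!(1, filtered.num_rows());
    }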
}