mito2/memtable/bulk/part_reader.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::VecDeque;
use std::ops::BitAnd;
use std::sync::Arc;

use bytes::Bytes;
use datatypes::arrow::array::{BooleanArray, Scalar, UInt64Array};
use datatypes::arrow::buffer::BooleanBuffer;
use datatypes::arrow::record_batch::RecordBatch;
use parquet::arrow::ProjectionMask;
use parquet::arrow::arrow_reader::ParquetRecordBatchReader;
use parquet::file::metadata::ParquetMetaData;
use snafu::ResultExt;
use store_api::storage::SequenceNumber;

use crate::error::{self, ComputeArrowSnafu, DecodeArrowRowGroupSnafu};
use crate::memtable::bulk::context::{BulkIterContext, BulkIterContextRef};
use crate::memtable::bulk::row_group_reader::MemtableRowGroupReaderBuilder;
use crate::sst::parquet::flat_format::sequence_column_index;
use crate::sst::parquet::reader::{MaybeFilter, RowGroupReaderContext};

/// Iterator for reading data inside a bulk part.
pub struct EncodedBulkPartIter {
    /// Context of the iteration, including the read format and filters.
    context: BulkIterContextRef,
    /// Indices of the row groups that remain to be read.
    row_groups_to_read: VecDeque<usize>,
    /// Reader of the row group currently being read, if any.
    current_reader: Option<ParquetRecordBatchReader>,
    /// Builder that creates a reader for each row group.
    builder: MemtableRowGroupReaderBuilder,
    /// Sequence number filter.
    sequence: Option<Scalar<UInt64Array>>,
}

impl EncodedBulkPartIter {
    /// Creates a new [EncodedBulkPartIter].
    pub(crate) fn try_new(
        context: BulkIterContextRef,
        mut row_groups_to_read: VecDeque<usize>,
        parquet_meta: Arc<ParquetMetaData>,
        data: Bytes,
        sequence: Option<SequenceNumber>,
    ) -> error::Result<Self> {
        // Only the flat format is supported by this iterator.
        assert!(context.read_format().as_flat().is_some());

        let sequence = sequence.map(UInt64Array::new_scalar);

        let projection_mask = ProjectionMask::roots(
            parquet_meta.file_metadata().schema_descr(),
            context.read_format().projection_indices().iter().copied(),
        );
        let builder =
            MemtableRowGroupReaderBuilder::try_new(&context, projection_mask, parquet_meta, data)?;

        // Eagerly builds a reader for the first row group so the first call to
        // `next_record_batch()` can return data directly.
        let init_reader = row_groups_to_read
            .pop_front()
            .map(|first_row_group| builder.build_row_group_reader(first_row_group, None))
            .transpose()?;
        Ok(Self {
            context,
            row_groups_to_read,
            current_reader: init_reader,
            builder,
            sequence,
        })
    }

    /// Fetches the next non-empty record batch.
    pub(crate) fn next_record_batch(&mut self) -> error::Result<Option<RecordBatch>> {
        let Some(current) = &mut self.current_reader else {
            // All row groups are exhausted.
            return Ok(None);
        };

        for batch in current {
            let batch = batch.context(DecodeArrowRowGroupSnafu)?;
            if let Some(batch) = apply_combined_filters(&self.context, &self.sequence, batch)? {
                return Ok(Some(batch));
            }
        }

        // The current row group is exhausted; reads the following row groups
        // until one yields a non-empty batch.
        while let Some(next_row_group) = self.row_groups_to_read.pop_front() {
            let next_reader = self.builder.build_row_group_reader(next_row_group, None)?;
            let current = self.current_reader.insert(next_reader);

            for batch in current {
                let batch = batch.context(DecodeArrowRowGroupSnafu)?;
                if let Some(batch) = apply_combined_filters(&self.context, &self.sequence, batch)? {
                    return Ok(Some(batch));
                }
            }
        }

        Ok(None)
    }
}

impl Iterator for EncodedBulkPartIter {
    type Item = error::Result<RecordBatch>;

    fn next(&mut self) -> Option<Self::Item> {
        self.next_record_batch().transpose()
    }
}

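// Usage sketch (hypothetical caller; `context`, `row_groups`, `parquet_meta`
// and `data` are assumed to be prepared by the bulk part that owns the data):
//
//     let iter = EncodedBulkPartIter::try_new(context, row_groups, parquet_meta, data, None)?;
//     for batch in iter {
//         let batch = batch?; // each item is an `error::Result<RecordBatch>`
//     }
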
/// Iterator for a single record batch in a bulk part.
///
/// It yields at most one (possibly filtered) batch and then returns `None`.
pub struct BulkPartRecordBatchIter {
    /// The record batch to read from.
    record_batch: Option<RecordBatch>,
    /// Iterator context for filtering.
    context: BulkIterContextRef,
    /// Sequence number filter.
    sequence: Option<Scalar<UInt64Array>>,
}

impl BulkPartRecordBatchIter {
    /// Creates a new [BulkPartRecordBatchIter] from a [RecordBatch].
    pub fn new(
        record_batch: RecordBatch,
        context: BulkIterContextRef,
        sequence: Option<SequenceNumber>,
    ) -> Self {
        // Only the flat format is supported by this iterator.
        assert!(context.read_format().as_flat().is_some());

        let sequence = sequence.map(UInt64Array::new_scalar);

        Self {
            record_batch: Some(record_batch),
            context,
            sequence,
        }
    }

    /// Applies projection to the [RecordBatch] if needed.
    fn apply_projection(&self, record_batch: RecordBatch) -> error::Result<RecordBatch> {
        let projection_indices = self.context.read_format().projection_indices();
        // If the projection width already matches the batch, assumes the batch
        // is already projected and returns it untouched.
        if projection_indices.len() == record_batch.num_columns() {
            return Ok(record_batch);
        }

        record_batch
            .project(projection_indices)
            .context(ComputeArrowSnafu)
    }

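    // Projection sketch (hypothetical indices): `RecordBatch::project` keeps the
    // requested columns in the given order, e.g. `batch.project(&[0, 2])?` keeps
    // only columns 0 and 2.
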
    fn process_batch(&mut self, record_batch: RecordBatch) -> error::Result<Option<RecordBatch>> {
        // Applies projection first.
        let projected_batch = self.apply_projection(record_batch)?;
        // Then applies the combined filters (both predicate and sequence filters).
        let Some(filtered_batch) =
            apply_combined_filters(&self.context, &self.sequence, projected_batch)?
        else {
            return Ok(None);
        };

        Ok(Some(filtered_batch))
    }
}

impl Iterator for BulkPartRecordBatchIter {
    type Item = error::Result<RecordBatch>;

    fn next(&mut self) -> Option<Self::Item> {
        // Takes the batch so the iterator is exhausted after the first call.
        let record_batch = self.record_batch.take()?;

        self.process_batch(record_batch).transpose()
    }
}

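// Usage sketch (hypothetical `batch` and `context`; see the unit test below for
// a full setup). Passing `Some(2)` keeps only rows with sequence <= 2:
//
//     let iter = BulkPartRecordBatchIter::new(batch, context, Some(2));
//     let batches = iter.collect::<error::Result<Vec<_>>>()?;
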
/// Applies both predicate filtering and sequence filtering in a single pass.
/// Returns `None` if the filtered batch is empty.
///
/// # Panics
/// Panics if the format is not flat.
fn apply_combined_filters(
    context: &BulkIterContext,
    sequence: &Option<Scalar<UInt64Array>>,
    record_batch: RecordBatch,
) -> error::Result<Option<RecordBatch>> {
    // Converts the batch to the flat format first.
    let format = context.read_format().as_flat().unwrap();
    let record_batch = format.convert_batch(record_batch, None)?;

    let num_rows = record_batch.num_rows();
    let mut combined_filter = None;

    // First, applies predicate filters.
    if !context.base.filters.is_empty() {
        let mut mask = BooleanBuffer::new_set(num_rows);

        // Runs the filters one by one and combines their results, similar to
        // `RangeBase::precise_filter`.
        for filter_ctx in &context.base.filters {
            let filter = match filter_ctx.filter() {
                MaybeFilter::Filter(f) => f,
                // The column matches; keeps all rows.
                MaybeFilter::Matched => continue,
                // The column doesn't match; filters out the entire batch.
                MaybeFilter::Pruned => return Ok(None),
            };

            // Safety: the constructors checked that the format is flat.
            let Some(column_index) = context
                .read_format()
                .as_flat()
                .unwrap()
                .projected_index_by_id(filter_ctx.column_id())
            else {
                continue;
            };
            let array = record_batch.column(column_index);
            let result = filter
                .evaluate_array(array)
                .context(crate::error::RecordBatchSnafu)?;

            mask = mask.bitand(&result);
        }
        // Converts the mask to a BooleanArray.
        combined_filter = Some(BooleanArray::from(mask));
    }

    // Filters rows by the given `sequence`. Only preserves rows with a sequence
    // less than or equal to `sequence`, which gives a snapshot-at-sequence view.
    if let Some(sequence) = sequence {
        let sequence_column =
            record_batch.column(sequence_column_index(record_batch.num_columns()));
        let sequence_filter =
            datatypes::arrow::compute::kernels::cmp::lt_eq(sequence_column, sequence)
                .context(ComputeArrowSnafu)?;
        // Combines with the existing filter using AND.
        combined_filter = match combined_filter {
            None => Some(sequence_filter),
            Some(existing_filter) => {
                let and_result = datatypes::arrow::compute::and(&existing_filter, &sequence_filter)
                    .context(ComputeArrowSnafu)?;
                Some(and_result)
            }
        };
    }

    // Applies the combined filter if any filter was built.
    let Some(filter_array) = combined_filter else {
        // No filters were applied; returns the original batch.
        return Ok(Some(record_batch));
    };
    // Fast paths: avoids copying when the filter selects nothing or everything.
    let select_count = filter_array.true_count();
    if select_count == 0 {
        return Ok(None);
    }
    if select_count == num_rows {
        return Ok(Some(record_batch));
    }
    let filtered_batch =
        datatypes::arrow::compute::filter_record_batch(&record_batch, &filter_array)
            .context(ComputeArrowSnafu)?;

    Ok(Some(filtered_batch))
}

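// Filter-combining sketch: predicate results are AND-ed into a single mask that
// starts with every row selected (hypothetical `predicate_buffer`):
//
//     let mask = BooleanBuffer::new_set(num_rows); // all rows selected
//     let mask = mask.bitand(&predicate_buffer);   // AND in one predicate result
//     let selection = BooleanArray::from(mask);    // final selection vector
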
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use api::v1::SemanticType;
    use datafusion_expr::{col, lit};
    use datatypes::arrow::array::{ArrayRef, Int64Array, StringArray, UInt8Array, UInt64Array};
    use datatypes::arrow::datatypes::{DataType, Field, Schema};
    use datatypes::data_type::ConcreteDataType;
    use datatypes::schema::ColumnSchema;
    use store_api::metadata::{ColumnMetadata, RegionMetadataBuilder};
    use store_api::storage::RegionId;
    use table::predicate::Predicate;

    use super::*;
    use crate::memtable::bulk::context::BulkIterContext;

288    fn test_bulk_part_record_batch_iter() {
289        // Create a simple schema
290        let schema = Arc::new(Schema::new(vec![
291            Field::new("key1", DataType::Utf8, false),
292            Field::new("field1", DataType::Int64, false),
293            Field::new(
294                "timestamp",
295                DataType::Timestamp(datatypes::arrow::datatypes::TimeUnit::Millisecond, None),
296                false,
297            ),
298            Field::new(
299                "__primary_key",
300                DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Binary)),
301                false,
302            ),
303            Field::new("__sequence", DataType::UInt64, false),
304            Field::new("__op_type", DataType::UInt8, false),
305        ]));
306
307        // Create test data
308        let key1 = Arc::new(StringArray::from_iter_values(["key1", "key2", "key3"]));
309        let field1 = Arc::new(Int64Array::from(vec![11, 12, 13]));
310        let timestamp = Arc::new(datatypes::arrow::array::TimestampMillisecondArray::from(
311            vec![1000, 2000, 3000],
312        ));
313
314        // Create primary key dictionary array
315        use datatypes::arrow::array::{BinaryArray, DictionaryArray, UInt32Array};
316        let values = Arc::new(BinaryArray::from_iter_values([b"key1", b"key2", b"key3"]));
317        let keys = UInt32Array::from(vec![0, 1, 2]);
318        let primary_key = Arc::new(DictionaryArray::new(keys, values));
319
320        let sequence = Arc::new(UInt64Array::from(vec![1, 2, 3]));
321        let op_type = Arc::new(UInt8Array::from(vec![1, 1, 1])); // PUT operations
322
323        let record_batch = RecordBatch::try_new(
324            schema,
325            vec![
326                key1,
327                field1,
328                timestamp,
329                primary_key.clone(),
330                sequence,
331                op_type,
332            ],
333        )
334        .unwrap();
335
        // Create a minimal region metadata for testing
        let mut builder = RegionMetadataBuilder::new(RegionId::new(1, 1));
        builder
            .push_column_metadata(ColumnMetadata {
                column_schema: ColumnSchema::new(
                    "key1",
                    ConcreteDataType::string_datatype(),
                    false,
                ),
                semantic_type: SemanticType::Tag,
                column_id: 0,
            })
            .push_column_metadata(ColumnMetadata {
                column_schema: ColumnSchema::new(
                    "field1",
                    ConcreteDataType::int64_datatype(),
                    false,
                ),
                semantic_type: SemanticType::Field,
                column_id: 1,
            })
            .push_column_metadata(ColumnMetadata {
                column_schema: ColumnSchema::new(
                    "timestamp",
                    ConcreteDataType::timestamp_millisecond_datatype(),
                    false,
                ),
                semantic_type: SemanticType::Timestamp,
                column_id: 2,
            })
            .primary_key(vec![0]);

        let region_metadata = builder.build().unwrap();

        // Create context
        let context = Arc::new(
            BulkIterContext::new(
                Arc::new(region_metadata.clone()),
                None, // No projection
                None, // No predicate
                false,
            )
            .unwrap(),
        );
        // Iterates all rows.
        let iter = BulkPartRecordBatchIter::new(record_batch.clone(), context.clone(), None);
        let result: Vec<_> = iter.map(|rb| rb.unwrap()).collect();
        assert_eq!(1, result.len());
        assert_eq!(3, result[0].num_rows());
        assert_eq!(6, result[0].num_columns());

        // Creates iter with sequence filter (only include sequences <= 2)
        let iter = BulkPartRecordBatchIter::new(record_batch.clone(), context, Some(2));
        let result: Vec<_> = iter.map(|rb| rb.unwrap()).collect();
        assert_eq!(1, result.len());
        let expect_sequence = Arc::new(UInt64Array::from(vec![1, 2])) as ArrayRef;
        assert_eq!(
            &expect_sequence,
            result[0].column(result[0].num_columns() - 2)
        );
        assert_eq!(6, result[0].num_columns());

        let context = Arc::new(
            BulkIterContext::new(
                Arc::new(region_metadata),
                Some(&[0, 2]),
                Some(Predicate::new(vec![col("key1").eq(lit("key2"))])),
                false,
            )
            .unwrap(),
        );
        // Creates iter with projection and predicate.
        let iter = BulkPartRecordBatchIter::new(record_batch.clone(), context.clone(), None);
        let result: Vec<_> = iter.map(|rb| rb.unwrap()).collect();
        assert_eq!(1, result.len());
        assert_eq!(1, result[0].num_rows());
        assert_eq!(5, result[0].num_columns());
        let expect_sequence = Arc::new(UInt64Array::from(vec![2])) as ArrayRef;
        assert_eq!(
            &expect_sequence,
            result[0].column(result[0].num_columns() - 2)
        );
    }
}