mito2/read/
flat_merge.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::cmp::Ordering;
16use std::collections::BinaryHeap;
17use std::sync::Arc;
18
19use async_stream::try_stream;
20use datatypes::arrow::array::{Int64Array, UInt64Array};
21use datatypes::arrow::compute::interleave;
22use datatypes::arrow::datatypes::SchemaRef;
23use datatypes::arrow::record_batch::RecordBatch;
24use datatypes::arrow_array::BinaryArray;
25use datatypes::timestamp::timestamp_array_to_primitive;
26use futures::{Stream, TryStreamExt};
27use snafu::ResultExt;
28use store_api::storage::SequenceNumber;
29
30use crate::error::{ComputeArrowSnafu, Result};
31use crate::memtable::BoxedRecordBatchIterator;
32use crate::read::BoxedRecordBatchStream;
33use crate::sst::parquet::flat_format::{
34    primary_key_column_index, sequence_column_index, time_index_column_index,
35};
36use crate::sst::parquet::format::PrimaryKeyArray;
37
/// Tracks how far a stream has been consumed within its current batch.
#[derive(Clone, Copy, Debug, Default)]
struct BatchCursor {
    /// Position of the stream's current batch inside `BatchBuilder::batches`.
    batch_idx: usize,
    /// Index of the next row to read from that batch.
    row_idx: usize,
}
46
47/// Provides an API to incrementally build a [`RecordBatch`] from partitioned [`RecordBatch`]
48// Ports from https://github.com/apache/datafusion/blob/49.0.0/datafusion/physical-plan/src/sorts/builder.rs
49// Adds the `take_remaining_rows()` method.
50#[derive(Debug)]
51pub struct BatchBuilder {
52    /// The schema of the RecordBatches yielded by this stream
53    schema: SchemaRef,
54
55    /// Maintain a list of [`RecordBatch`] and their corresponding stream
56    batches: Vec<(usize, RecordBatch)>,
57
58    /// The current [`BatchCursor`] for each stream
59    cursors: Vec<BatchCursor>,
60
61    /// The accumulated stream indexes from which to pull rows
62    /// Consists of a tuple of `(batch_idx, row_idx)`
63    indices: Vec<(usize, usize)>,
64}
65
66impl BatchBuilder {
67    /// Create a new [`BatchBuilder`] with the provided `stream_count` and `batch_size`
68    pub fn new(schema: SchemaRef, stream_count: usize, batch_size: usize) -> Self {
69        Self {
70            schema,
71            batches: Vec::with_capacity(stream_count * 2),
72            cursors: vec![BatchCursor::default(); stream_count],
73            indices: Vec::with_capacity(batch_size),
74        }
75    }
76
77    /// Append a new batch in `stream_idx`
78    pub fn push_batch(&mut self, stream_idx: usize, batch: RecordBatch) {
79        let batch_idx = self.batches.len();
80        self.batches.push((stream_idx, batch));
81        self.cursors[stream_idx] = BatchCursor {
82            batch_idx,
83            row_idx: 0,
84        };
85    }
86
87    /// Append the next row from `stream_idx`
88    pub fn push_row(&mut self, stream_idx: usize) {
89        let cursor = &mut self.cursors[stream_idx];
90        let row_idx = cursor.row_idx;
91        cursor.row_idx += 1;
92        self.indices.push((cursor.batch_idx, row_idx));
93    }
94
95    /// Returns the number of in-progress rows in this [`BatchBuilder`]
96    pub fn len(&self) -> usize {
97        self.indices.len()
98    }
99
100    /// Returns `true` if this [`BatchBuilder`] contains no in-progress rows
101    pub fn is_empty(&self) -> bool {
102        self.indices.is_empty()
103    }
104
105    /// Returns the schema of this [`BatchBuilder`]
106    pub fn schema(&self) -> &SchemaRef {
107        &self.schema
108    }
109
110    /// Drains the in_progress row indexes, and builds a new RecordBatch from them
111    ///
112    /// Will then drop any batches for which all rows have been yielded to the output
113    ///
114    /// Returns `None` if no pending rows
115    pub fn build_record_batch(&mut self) -> Result<Option<RecordBatch>> {
116        if self.is_empty() {
117            return Ok(None);
118        }
119
120        let columns = (0..self.schema.fields.len())
121            .map(|column_idx| {
122                let arrays: Vec<_> = self
123                    .batches
124                    .iter()
125                    .map(|(_, batch)| batch.column(column_idx).as_ref())
126                    .collect();
127                interleave(&arrays, &self.indices).context(ComputeArrowSnafu)
128            })
129            .collect::<Result<Vec<_>>>()?;
130
131        self.indices.clear();
132
133        // New cursors are only created once the previous cursor for the stream
134        // is finished. This means all remaining rows from all but the last batch
135        // for each stream have been yielded to the newly created record batch
136        //
137        // We can therefore drop all but the last batch for each stream
138        self.retain_batches();
139
140        RecordBatch::try_new(Arc::clone(&self.schema), columns)
141            .context(ComputeArrowSnafu)
142            .map(Some)
143    }
144
145    /// Slice and take remaining rows from the last batch of `stream_idx` and push
146    /// the next batch if available.
147    pub fn take_remaining_rows(
148        &mut self,
149        stream_idx: usize,
150        next: Option<RecordBatch>,
151    ) -> RecordBatch {
152        let cursor = &mut self.cursors[stream_idx];
153        let batch = &self.batches[cursor.batch_idx];
154        let output = batch
155            .1
156            .slice(cursor.row_idx, batch.1.num_rows() - cursor.row_idx);
157        cursor.row_idx = batch.1.num_rows();
158
159        if let Some(b) = next {
160            self.push_batch(stream_idx, b);
161            self.retain_batches();
162        }
163
164        output
165    }
166
167    fn retain_batches(&mut self) {
168        let mut batch_idx = 0;
169        let mut retained = 0;
170        self.batches.retain(|(stream_idx, _)| {
171            let stream_cursor = &mut self.cursors[*stream_idx];
172            let retain = stream_cursor.batch_idx == batch_idx;
173            batch_idx += 1;
174
175            if retain {
176                stream_cursor.batch_idx = retained;
177                retained += 1;
178            }
179            retain
180        });
181    }
182}
183
/// A node that can be stored in the heaps of [`MergeAlgo`].
trait NodeCmp: Eq + Ord {
    /// Returns true if the node has reached EOF.
    fn is_eof(&self) -> bool;

    /// Returns true if the key range of the current batch in `self` is behind (exclusive)
    /// the current batch in `other`.
    ///
    /// # Panics
    /// Panics if either `self` or `other` is EOF.
    fn is_behind(&self, other: &Self) -> bool;
}

/// Shared algorithm for merging sorted batches from multiple nodes.
struct MergeAlgo<T> {
    /// Nodes whose current batch **does** overlap the merge window.
    /// Each node yields batches from a `source`.
    ///
    /// Nodes in this heap **MUST** not be empty. The `merge window` is the
    /// (primary key, timestamp) range of the **root node** of this heap.
    hot: BinaryHeap<T>,
    /// Nodes whose current batch **doesn't** overlap the merge window.
    ///
    /// Nodes in this heap **MUST** not be empty.
    cold: BinaryHeap<T>,
}

impl<T: NodeCmp> MergeAlgo<T> {
    /// Builds the algorithm state from `nodes`, discarding nodes that are already EOF.
    ///
    /// All nodes must be initialized.
    fn new(mut nodes: Vec<T>) -> Self {
        nodes.retain(|node| !node.is_eof());

        // Everything starts in the cold heap; `refill_hot()` then establishes the
        // initial merge window.
        let mut algo = Self {
            hot: BinaryHeap::with_capacity(nodes.len()),
            cold: BinaryHeap::from(nodes),
        };
        algo.refill_hot();

        algo
    }

    /// Promotes nodes from the `cold` heap to the `hot` heap for as long as the top
    /// of the `cold` heap is not behind the current merge window.
    fn refill_hot(&mut self) {
        while let Some(warmest) = self.cold.peek() {
            let outside_window = match self.hot.peek() {
                Some(merge_window) => warmest.is_behind(merge_window),
                // No merge window yet: the warmest cold node defines it.
                None => false,
            };
            if outside_window {
                // The warmest cold node starts after the merge window, so no other
                // cold node is promoted either.
                break;
            }

            if let Some(node) = self.cold.pop() {
                self.hot.push(node);
            }
        }
    }

    /// Puts a node previously popped from `hot` back into the proper heap.
    ///
    /// EOF nodes are dropped. The `hot` heap is refilled afterwards in every case.
    fn reheap(&mut self, node: T) {
        if !node.is_eof() {
            let node_is_cold = match self.hot.peek() {
                // Behind the hottest node: the node can wait in the cold heap.
                Some(hottest) => node.is_behind(hottest),
                // The hot heap is empty and it is unknown whether this node is still
                // the hottest, so it goes through the cold heap.
                None => true,
            };

            let target = if node_is_cold {
                &mut self.cold
            } else {
                &mut self.hot
            };
            target.push(node);
        }

        // Refill the hot heap whether or not the node was pushed back.
        self.refill_hot();
    }

    /// Removes and returns the hottest node, if any.
    fn pop_hot(&mut self) -> Option<T> {
        self.hot.pop()
    }

    /// Returns true if the hot heap holds at least one node.
    fn has_rows(&self) -> bool {
        self.hot.peek().is_some()
    }

    /// Returns true if a batch can be fetched directly instead of a single row.
    fn can_fetch_batch(&self) -> bool {
        self.hot.len() == 1
    }
}
290
291// TODO(yingwen): Further downcast and store arrays in this struct.
292/// Columns to compare for a [RecordBatch].
293struct SortColumns {
294    primary_key: PrimaryKeyArray,
295    timestamp: Int64Array,
296    sequence: UInt64Array,
297}
298
299impl SortColumns {
300    /// Creates a new [SortColumns] from a [RecordBatch] and the position of the time index column.
301    ///
302    /// # Panics
303    /// Panics if the input batch doesn't have correct internal columns.
304    fn new(batch: &RecordBatch) -> Self {
305        let num_columns = batch.num_columns();
306        let primary_key = batch
307            .column(primary_key_column_index(num_columns))
308            .as_any()
309            .downcast_ref::<PrimaryKeyArray>()
310            .unwrap()
311            .clone();
312        let timestamp = batch.column(time_index_column_index(num_columns));
313        let (timestamp, _unit) = timestamp_array_to_primitive(timestamp).unwrap();
314        let sequence = batch
315            .column(sequence_column_index(num_columns))
316            .as_any()
317            .downcast_ref::<UInt64Array>()
318            .unwrap()
319            .clone();
320
321        Self {
322            primary_key,
323            timestamp,
324            sequence,
325        }
326    }
327
328    fn primary_key_at(&self, index: usize) -> &[u8] {
329        let key = self.primary_key.keys().value(index);
330        let binary_values = self
331            .primary_key
332            .values()
333            .as_any()
334            .downcast_ref::<BinaryArray>()
335            .unwrap();
336        binary_values.value(key as usize)
337    }
338
339    fn timestamp_at(&self, index: usize) -> i64 {
340        self.timestamp.value(index)
341    }
342
343    fn sequence_at(&self, index: usize) -> SequenceNumber {
344        self.sequence.value(index)
345    }
346
347    fn num_rows(&self) -> usize {
348        self.timestamp.len()
349    }
350}
351
352/// Cursor to a row in the [RecordBatch].
353///
354/// It compares batches by rows. During comparison, it ignores op type as sequence is enough to
355/// distinguish different rows.
356struct RowCursor {
357    /// Current row offset.
358    offset: usize,
359    /// Keys of the batch.
360    columns: SortColumns,
361}
362
363impl RowCursor {
364    fn new(columns: SortColumns) -> Self {
365        debug_assert!(columns.num_rows() > 0);
366
367        Self { offset: 0, columns }
368    }
369
370    fn is_finished(&self) -> bool {
371        self.offset >= self.columns.num_rows()
372    }
373
374    fn advance(&mut self) {
375        self.offset += 1;
376    }
377
378    fn first_primary_key(&self) -> &[u8] {
379        self.columns.primary_key_at(self.offset)
380    }
381
382    fn first_timestamp(&self) -> i64 {
383        self.columns.timestamp_at(self.offset)
384    }
385
386    fn first_sequence(&self) -> SequenceNumber {
387        self.columns.sequence_at(self.offset)
388    }
389
390    fn last_primary_key(&self) -> &[u8] {
391        self.columns.primary_key_at(self.columns.num_rows() - 1)
392    }
393
394    fn last_timestamp(&self) -> i64 {
395        self.columns.timestamp_at(self.columns.num_rows() - 1)
396    }
397}
398
399impl PartialEq for RowCursor {
400    fn eq(&self, other: &Self) -> bool {
401        self.first_primary_key() == other.first_primary_key()
402            && self.first_timestamp() == other.first_timestamp()
403            && self.first_sequence() == other.first_sequence()
404    }
405}
406
407impl Eq for RowCursor {}
408
409impl PartialOrd for RowCursor {
410    fn partial_cmp(&self, other: &RowCursor) -> Option<Ordering> {
411        Some(self.cmp(other))
412    }
413}
414
415impl Ord for RowCursor {
416    /// Compares by primary key, time index, sequence desc.
417    fn cmp(&self, other: &RowCursor) -> Ordering {
418        self.first_primary_key()
419            .cmp(other.first_primary_key())
420            .then_with(|| self.first_timestamp().cmp(&other.first_timestamp()))
421            .then_with(|| other.first_sequence().cmp(&self.first_sequence()))
422    }
423}
424
425/// Iterator to merge multiple sorted iterators into a single sorted iterator.
426///
427/// All iterators must be sorted by primary key, time index, sequence desc.
428pub struct FlatMergeIterator {
429    /// The merge algorithm to maintain heaps.
430    algo: MergeAlgo<IterNode>,
431    /// Current buffered rows to output.
432    in_progress: BatchBuilder,
433    /// Non-empty batch to output.
434    output_batch: Option<RecordBatch>,
435    /// Batch size to merge rows.
436    /// This is not a hard limit, the iterator may return smaller batches to avoid concatenating
437    /// rows.
438    batch_size: usize,
439}
440
441impl FlatMergeIterator {
442    /// Creates a new iterator to merge sorted `iters`.
443    pub fn new(
444        schema: SchemaRef,
445        iters: Vec<BoxedRecordBatchIterator>,
446        batch_size: usize,
447    ) -> Result<Self> {
448        let mut in_progress = BatchBuilder::new(schema, iters.len(), batch_size);
449        let mut nodes = Vec::with_capacity(iters.len());
450        // Initialize nodes and the buffer.
451        for (node_index, iter) in iters.into_iter().enumerate() {
452            let mut node = IterNode {
453                node_index,
454                iter,
455                cursor: None,
456            };
457            if let Some(batch) = node.advance_batch()? {
458                in_progress.push_batch(node_index, batch);
459                nodes.push(node);
460            }
461        }
462
463        let algo = MergeAlgo::new(nodes);
464
465        Ok(Self {
466            algo,
467            in_progress,
468            output_batch: None,
469            batch_size,
470        })
471    }
472
473    /// Fetches next sorted batch.
474    pub fn next_batch(&mut self) -> Result<Option<RecordBatch>> {
475        while self.algo.has_rows() && self.output_batch.is_none() {
476            if self.algo.can_fetch_batch() && !self.in_progress.is_empty() {
477                // Only one batch in the hot heap, but we have pending rows, output the pending rows first.
478                self.output_batch = self.in_progress.build_record_batch()?;
479                debug_assert!(self.output_batch.is_some());
480            } else if self.algo.can_fetch_batch() {
481                self.fetch_batch_from_hottest()?;
482            } else {
483                self.fetch_row_from_hottest()?;
484            }
485        }
486
487        if let Some(batch) = self.output_batch.take() {
488            Ok(Some(batch))
489        } else {
490            // No more batches.
491            Ok(None)
492        }
493    }
494
495    /// Fetches a batch from the hottest node.
496    fn fetch_batch_from_hottest(&mut self) -> Result<()> {
497        debug_assert!(self.in_progress.is_empty());
498
499        // Safety: next_batch() ensures the heap is not empty.
500        let mut hottest = self.algo.pop_hot().unwrap();
501        debug_assert!(!hottest.current_cursor().is_finished());
502        let next = hottest.advance_batch()?;
503        // The node is the heap is not empty, so it must have existing rows in the builder.
504        let batch = self
505            .in_progress
506            .take_remaining_rows(hottest.node_index, next);
507        Self::maybe_output_batch(batch, &mut self.output_batch);
508        self.algo.reheap(hottest);
509
510        Ok(())
511    }
512
513    /// Fetches a row from the hottest node.
514    fn fetch_row_from_hottest(&mut self) -> Result<()> {
515        // Safety: next_batch() ensures the heap has more than 1 element.
516        let mut hottest = self.algo.pop_hot().unwrap();
517        debug_assert!(!hottest.current_cursor().is_finished());
518        self.in_progress.push_row(hottest.node_index);
519        if self.in_progress.len() >= self.batch_size {
520            // We buffered enough rows.
521            if let Some(output) = self.in_progress.build_record_batch()? {
522                Self::maybe_output_batch(output, &mut self.output_batch);
523            }
524        }
525
526        if let Some(next) = hottest.advance_row()? {
527            self.in_progress.push_batch(hottest.node_index, next);
528        }
529
530        self.algo.reheap(hottest);
531        Ok(())
532    }
533
534    /// Adds the batch to the output batch if it is not empty.
535    fn maybe_output_batch(batch: RecordBatch, output_batch: &mut Option<RecordBatch>) {
536        debug_assert!(output_batch.is_none());
537        if batch.num_rows() > 0 {
538            *output_batch = Some(batch);
539        }
540    }
541}
542
543impl Iterator for FlatMergeIterator {
544    type Item = Result<RecordBatch>;
545
546    fn next(&mut self) -> Option<Self::Item> {
547        self.next_batch().transpose()
548    }
549}
550
551/// Iterator to merge multiple sorted iterators into a single sorted iterator.
552///
553/// All iterators must be sorted by primary key, time index, sequence desc.
554pub struct FlatMergeReader {
555    /// The merge algorithm to maintain heaps.
556    algo: MergeAlgo<StreamNode>,
557    /// Current buffered rows to output.
558    in_progress: BatchBuilder,
559    /// Non-empty batch to output.
560    output_batch: Option<RecordBatch>,
561    /// Batch size to merge rows.
562    /// This is not a hard limit, the iterator may return smaller batches to avoid concatenating
563    /// rows.
564    batch_size: usize,
565}
566
567impl FlatMergeReader {
568    /// Creates a new iterator to merge sorted `iters`.
569    pub async fn new(
570        schema: SchemaRef,
571        iters: Vec<BoxedRecordBatchStream>,
572        batch_size: usize,
573    ) -> Result<Self> {
574        let mut in_progress = BatchBuilder::new(schema, iters.len(), batch_size);
575        let mut nodes = Vec::with_capacity(iters.len());
576        // Initialize nodes and the buffer.
577        for (node_index, iter) in iters.into_iter().enumerate() {
578            let mut node = StreamNode {
579                node_index,
580                iter,
581                cursor: None,
582            };
583            if let Some(batch) = node.advance_batch().await? {
584                in_progress.push_batch(node_index, batch);
585                nodes.push(node);
586            }
587        }
588
589        let algo = MergeAlgo::new(nodes);
590
591        Ok(Self {
592            algo,
593            in_progress,
594            output_batch: None,
595            batch_size,
596        })
597    }
598
599    /// Fetches next sorted batch.
600    pub async fn next_batch(&mut self) -> Result<Option<RecordBatch>> {
601        while self.algo.has_rows() && self.output_batch.is_none() {
602            if self.algo.can_fetch_batch() && !self.in_progress.is_empty() {
603                // Only one batch in the hot heap, but we have pending rows, output the pending rows first.
604                self.output_batch = self.in_progress.build_record_batch()?;
605                debug_assert!(self.output_batch.is_some());
606            } else if self.algo.can_fetch_batch() {
607                self.fetch_batch_from_hottest().await?;
608            } else {
609                self.fetch_row_from_hottest().await?;
610            }
611        }
612
613        if let Some(batch) = self.output_batch.take() {
614            Ok(Some(batch))
615        } else {
616            // No more batches.
617            Ok(None)
618        }
619    }
620
621    /// Converts the reader into a stream.
622    pub fn into_stream(mut self) -> impl Stream<Item = Result<RecordBatch>> {
623        try_stream! {
624            while let Some(batch) = self.next_batch().await? {
625                yield batch;
626            }
627        }
628    }
629
630    /// Fetches a batch from the hottest node.
631    async fn fetch_batch_from_hottest(&mut self) -> Result<()> {
632        debug_assert!(self.in_progress.is_empty());
633
634        // Safety: next_batch() ensures the heap is not empty.
635        let mut hottest = self.algo.pop_hot().unwrap();
636        debug_assert!(!hottest.current_cursor().is_finished());
637        let next = hottest.advance_batch().await?;
638        // The node is the heap is not empty, so it must have existing rows in the builder.
639        let batch = self
640            .in_progress
641            .take_remaining_rows(hottest.node_index, next);
642        Self::maybe_output_batch(batch, &mut self.output_batch);
643        self.algo.reheap(hottest);
644
645        Ok(())
646    }
647
648    /// Fetches a row from the hottest node.
649    async fn fetch_row_from_hottest(&mut self) -> Result<()> {
650        // Safety: next_batch() ensures the heap has more than 1 element.
651        let mut hottest = self.algo.pop_hot().unwrap();
652        debug_assert!(!hottest.current_cursor().is_finished());
653        self.in_progress.push_row(hottest.node_index);
654        if self.in_progress.len() >= self.batch_size {
655            // We buffered enough rows.
656            if let Some(output) = self.in_progress.build_record_batch()? {
657                Self::maybe_output_batch(output, &mut self.output_batch);
658            }
659        }
660
661        if let Some(next) = hottest.advance_row().await? {
662            self.in_progress.push_batch(hottest.node_index, next);
663        }
664
665        self.algo.reheap(hottest);
666        Ok(())
667    }
668
669    /// Adds the batch to the output batch if it is not empty.
670    fn maybe_output_batch(batch: RecordBatch, output_batch: &mut Option<RecordBatch>) {
671        debug_assert!(output_batch.is_none());
672        if batch.num_rows() > 0 {
673            *output_batch = Some(batch);
674        }
675    }
676}
677
678/// A sync node in the merge iterator.
679struct GenericNode<T> {
680    /// Index of the node.
681    node_index: usize,
682    /// Iterator of this `Node`.
683    iter: T,
684    /// Current batch to be read. The node should ensure the batch is not empty (The
685    /// cursor is not finished).
686    ///
687    /// `None` means the `iter` has reached EOF.
688    cursor: Option<RowCursor>,
689}
690
691impl<T> NodeCmp for GenericNode<T> {
692    fn is_eof(&self) -> bool {
693        self.cursor.is_none()
694    }
695
696    fn is_behind(&self, other: &Self) -> bool {
697        debug_assert!(!self.current_cursor().is_finished());
698        debug_assert!(!other.current_cursor().is_finished());
699
700        // We only compare pk and timestamp so nodes in the cold
701        // heap don't have overlapping timestamps with the hottest node
702        // in the hot heap.
703        self.current_cursor()
704            .first_primary_key()
705            .cmp(other.current_cursor().last_primary_key())
706            .then_with(|| {
707                self.current_cursor()
708                    .first_timestamp()
709                    .cmp(&other.current_cursor().last_timestamp())
710            })
711            == Ordering::Greater
712    }
713}
714
715impl<T> PartialEq for GenericNode<T> {
716    fn eq(&self, other: &GenericNode<T>) -> bool {
717        self.cursor == other.cursor
718    }
719}
720
721impl<T> Eq for GenericNode<T> {}
722
723impl<T> PartialOrd for GenericNode<T> {
724    fn partial_cmp(&self, other: &GenericNode<T>) -> Option<Ordering> {
725        Some(self.cmp(other))
726    }
727}
728
729impl<T> Ord for GenericNode<T> {
730    fn cmp(&self, other: &GenericNode<T>) -> Ordering {
731        // The std binary heap is a max heap, but we want the nodes are ordered in
732        // ascend order, so we compare the nodes in reverse order.
733        other.cursor.cmp(&self.cursor)
734    }
735}
736
737impl<T> GenericNode<T> {
738    /// Returns current cursor.
739    ///
740    /// # Panics
741    /// Panics if the node has reached EOF.
742    fn current_cursor(&self) -> &RowCursor {
743        self.cursor.as_ref().unwrap()
744    }
745}
746
747impl GenericNode<BoxedRecordBatchIterator> {
748    /// Fetches a new batch from the iter and updates the cursor.
749    /// It advances the current batch.
750    /// Returns the fetched new batch.
751    fn advance_batch(&mut self) -> Result<Option<RecordBatch>> {
752        let batch = self.advance_inner_iter()?;
753        let columns = batch.as_ref().map(SortColumns::new);
754        self.cursor = columns.map(RowCursor::new);
755
756        Ok(batch)
757    }
758
759    /// Skips one row.
760    /// Returns the next batch if the current batch is finished.
761    fn advance_row(&mut self) -> Result<Option<RecordBatch>> {
762        let cursor = self.cursor.as_mut().unwrap();
763        cursor.advance();
764        if !cursor.is_finished() {
765            return Ok(None);
766        }
767
768        // Finished current batch, need to fetch a new batch.
769        self.advance_batch()
770    }
771
772    /// Fetches a non-empty batch from the iter.
773    fn advance_inner_iter(&mut self) -> Result<Option<RecordBatch>> {
774        while let Some(batch) = self.iter.next().transpose()? {
775            if batch.num_rows() > 0 {
776                return Ok(Some(batch));
777            }
778        }
779        Ok(None)
780    }
781}
782
783type StreamNode = GenericNode<BoxedRecordBatchStream>;
784type IterNode = GenericNode<BoxedRecordBatchIterator>;
785
786impl GenericNode<BoxedRecordBatchStream> {
787    /// Fetches a new batch from the iter and updates the cursor.
788    /// It advances the current batch.
789    /// Returns the fetched new batch.
790    async fn advance_batch(&mut self) -> Result<Option<RecordBatch>> {
791        let batch = self.advance_inner_iter().await?;
792        let columns = batch.as_ref().map(SortColumns::new);
793        self.cursor = columns.map(RowCursor::new);
794
795        Ok(batch)
796    }
797
798    /// Skips one row.
799    /// Returns the next batch if the current batch is finished.
800    async fn advance_row(&mut self) -> Result<Option<RecordBatch>> {
801        let cursor = self.cursor.as_mut().unwrap();
802        cursor.advance();
803        if !cursor.is_finished() {
804            return Ok(None);
805        }
806
807        // Finished current batch, need to fetch a new batch.
808        self.advance_batch().await
809    }
810
811    /// Fetches a non-empty batch from the iter.
812    async fn advance_inner_iter(&mut self) -> Result<Option<RecordBatch>> {
813        while let Some(batch) = self.iter.try_next().await? {
814            if batch.num_rows() > 0 {
815                return Ok(Some(batch));
816            }
817        }
818        Ok(None)
819    }
820}
821
822#[cfg(test)]
823mod tests {
824    use std::sync::Arc;
825
826    use api::v1::OpType;
827    use datatypes::arrow::array::builder::BinaryDictionaryBuilder;
828    use datatypes::arrow::array::{Int64Array, TimestampMillisecondArray, UInt64Array, UInt8Array};
829    use datatypes::arrow::datatypes::{DataType, Field, Schema, TimeUnit, UInt32Type};
830    use datatypes::arrow::record_batch::RecordBatch;
831
832    use super::*;
833
834    /// Creates a test RecordBatch with the specified data.
835    fn create_test_record_batch(
836        primary_keys: &[&[u8]],
837        timestamps: &[i64],
838        sequences: &[u64],
839        op_types: &[OpType],
840        field_values: &[i64],
841    ) -> RecordBatch {
842        let schema = Arc::new(Schema::new(vec![
843            Field::new("field1", DataType::Int64, false),
844            Field::new(
845                "timestamp",
846                DataType::Timestamp(TimeUnit::Millisecond, None),
847                false,
848            ),
849            Field::new(
850                "__primary_key",
851                DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Binary)),
852                false,
853            ),
854            Field::new("__sequence", DataType::UInt64, false),
855            Field::new("__op_type", DataType::UInt8, false),
856        ]));
857
858        let field1 = Arc::new(Int64Array::from_iter_values(field_values.iter().copied()));
859        let timestamp = Arc::new(TimestampMillisecondArray::from_iter_values(
860            timestamps.iter().copied(),
861        ));
862
863        // Create primary key dictionary array using BinaryDictionaryBuilder
864        let mut builder = BinaryDictionaryBuilder::<UInt32Type>::new();
865        for &key in primary_keys {
866            builder.append(key).unwrap();
867        }
868        let primary_key = Arc::new(builder.finish());
869
870        let sequence = Arc::new(UInt64Array::from_iter_values(sequences.iter().copied()));
871        let op_type = Arc::new(UInt8Array::from_iter_values(
872            op_types.iter().map(|&v| v as u8),
873        ));
874
875        RecordBatch::try_new(
876            schema,
877            vec![field1, timestamp, primary_key, sequence, op_type],
878        )
879        .unwrap()
880    }
881
882    fn new_test_iter(batches: Vec<RecordBatch>) -> BoxedRecordBatchIterator {
883        Box::new(batches.into_iter().map(Ok))
884    }
885
886    /// Helper function to check if two record batches are equivalent.
887    fn assert_record_batches_eq(expected: &[RecordBatch], actual: &[RecordBatch]) {
888        for (exp, act) in expected.iter().zip(actual.iter()) {
889            assert_eq!(exp, act,);
890        }
891    }
892
893    /// Helper function to collect all batches from a FlatMergeIterator.
894    fn collect_merge_iterator_batches(iter: FlatMergeIterator) -> Vec<RecordBatch> {
895        iter.map(|result| result.unwrap()).collect()
896    }
897
898    #[test]
899    fn test_merge_iterator_empty() {
900        let schema = Arc::new(Schema::new(vec![
901            Field::new("field1", DataType::Int64, false),
902            Field::new(
903                "timestamp",
904                DataType::Timestamp(TimeUnit::Millisecond, None),
905                false,
906            ),
907            Field::new(
908                "__primary_key",
909                DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Binary)),
910                false,
911            ),
912            Field::new("__sequence", DataType::UInt64, false),
913            Field::new("__op_type", DataType::UInt8, false),
914        ]));
915
916        let mut merge_iter = FlatMergeIterator::new(schema, vec![], 1024).unwrap();
917        assert!(merge_iter.next_batch().unwrap().is_none());
918    }
919
920    #[test]
921    fn test_merge_iterator_single_batch() {
922        let batch = create_test_record_batch(
923            &[b"k1", b"k1"],
924            &[1000, 2000],
925            &[21, 22],
926            &[OpType::Put, OpType::Put],
927            &[11, 12],
928        );
929
930        let schema = batch.schema();
931        let iter = Box::new(new_test_iter(vec![batch.clone()]));
932
933        let merge_iter = FlatMergeIterator::new(schema, vec![iter], 1024).unwrap();
934        let result = collect_merge_iterator_batches(merge_iter);
935
936        assert_eq!(result.len(), 1);
937        assert_record_batches_eq(&[batch], &result);
938    }
939
940    #[test]
941    fn test_merge_iterator_non_overlapping() {
942        let batch1 = create_test_record_batch(
943            &[b"k1", b"k1"],
944            &[1000, 2000],
945            &[21, 22],
946            &[OpType::Put, OpType::Put],
947            &[11, 12],
948        );
949        let batch2 = create_test_record_batch(
950            &[b"k1", b"k1"],
951            &[4000, 5000],
952            &[24, 25],
953            &[OpType::Put, OpType::Put],
954            &[14, 15],
955        );
956        let batch3 = create_test_record_batch(
957            &[b"k2", b"k2"],
958            &[2000, 3000],
959            &[22, 23],
960            &[OpType::Delete, OpType::Put],
961            &[12, 13],
962        );
963
964        let schema = batch1.schema();
965        let iter1 = Box::new(new_test_iter(vec![batch1.clone(), batch3.clone()]));
966        let iter2 = Box::new(new_test_iter(vec![batch2.clone()]));
967
968        let merge_iter = FlatMergeIterator::new(schema, vec![iter1, iter2], 1024).unwrap();
969        let result = collect_merge_iterator_batches(merge_iter);
970
971        // Results should be sorted by primary key, timestamp, sequence desc
972        let expected = vec![batch1, batch2, batch3];
973        assert_record_batches_eq(&expected, &result);
974    }
975
976    #[test]
977    fn test_merge_iterator_overlapping_timestamps() {
978        // Create batches with overlapping timestamps but different sequences
979        let batch1 = create_test_record_batch(
980            &[b"k1", b"k1"],
981            &[1000, 2000],
982            &[21, 22],
983            &[OpType::Put, OpType::Put],
984            &[11, 12],
985        );
986        let batch2 = create_test_record_batch(
987            &[b"k1", b"k1"],
988            &[1500, 2500],
989            &[31, 32],
990            &[OpType::Put, OpType::Put],
991            &[15, 25],
992        );
993
994        let schema = batch1.schema();
995        let iter1 = Box::new(new_test_iter(vec![batch1]));
996        let iter2 = Box::new(new_test_iter(vec![batch2]));
997
998        let merge_iter = FlatMergeIterator::new(schema, vec![iter1, iter2], 1024).unwrap();
999        let result = collect_merge_iterator_batches(merge_iter);
1000
1001        let expected = vec![
1002            create_test_record_batch(
1003                &[b"k1", b"k1"],
1004                &[1000, 1500],
1005                &[21, 31],
1006                &[OpType::Put, OpType::Put],
1007                &[11, 15],
1008            ),
1009            create_test_record_batch(&[b"k1"], &[2000], &[22], &[OpType::Put], &[12]),
1010            create_test_record_batch(&[b"k1"], &[2500], &[32], &[OpType::Put], &[25]),
1011        ];
1012        assert_record_batches_eq(&expected, &result);
1013    }
1014
1015    #[test]
1016    fn test_merge_iterator_duplicate_keys_sequences() {
1017        // Test with same primary key and timestamp but different sequences
1018        let batch1 = create_test_record_batch(
1019            &[b"k1", b"k1"],
1020            &[1000, 1000],
1021            &[20, 10],
1022            &[OpType::Put, OpType::Put],
1023            &[1, 2],
1024        );
1025        let batch2 = create_test_record_batch(
1026            &[b"k1"],
1027            &[1000],
1028            &[15], // Middle sequence
1029            &[OpType::Put],
1030            &[3],
1031        );
1032
1033        let schema = batch1.schema();
1034        let iter1 = Box::new(new_test_iter(vec![batch1]));
1035        let iter2 = Box::new(new_test_iter(vec![batch2]));
1036
1037        let merge_iter = FlatMergeIterator::new(schema, vec![iter1, iter2], 1024).unwrap();
1038        let result = collect_merge_iterator_batches(merge_iter);
1039
1040        // Should be sorted by sequence descending for same key/timestamp
1041        let expected = vec![
1042            create_test_record_batch(
1043                &[b"k1", b"k1"],
1044                &[1000, 1000],
1045                &[20, 15],
1046                &[OpType::Put, OpType::Put],
1047                &[1, 3],
1048            ),
1049            create_test_record_batch(&[b"k1"], &[1000], &[10], &[OpType::Put], &[2]),
1050        ];
1051        assert_record_batches_eq(&expected, &result);
1052    }
1053
1054    #[test]
1055    fn test_batch_builder_basic() {
1056        let schema = Arc::new(Schema::new(vec![
1057            Field::new("field1", DataType::Int64, false),
1058            Field::new(
1059                "timestamp",
1060                DataType::Timestamp(TimeUnit::Millisecond, None),
1061                false,
1062            ),
1063        ]));
1064
1065        let mut builder = BatchBuilder::new(schema.clone(), 2, 1024);
1066        assert!(builder.is_empty());
1067
1068        let batch = RecordBatch::try_new(
1069            schema,
1070            vec![
1071                Arc::new(Int64Array::from(vec![1, 2])),
1072                Arc::new(TimestampMillisecondArray::from(vec![1000, 2000])),
1073            ],
1074        )
1075        .unwrap();
1076
1077        builder.push_batch(0, batch);
1078        builder.push_row(0);
1079        builder.push_row(0);
1080
1081        assert!(!builder.is_empty());
1082        assert_eq!(builder.len(), 2);
1083
1084        let result_batch = builder.build_record_batch().unwrap().unwrap();
1085        assert_eq!(result_batch.num_rows(), 2);
1086    }
1087
1088    #[test]
1089    fn test_row_cursor_comparison() {
1090        // Create test batches for cursor comparison
1091        let batch1 = create_test_record_batch(
1092            &[b"k1", b"k1"],
1093            &[1000, 2000],
1094            &[22, 21],
1095            &[OpType::Put, OpType::Put],
1096            &[11, 12],
1097        );
1098        let batch2 = create_test_record_batch(
1099            &[b"k1", b"k1"],
1100            &[1000, 2000],
1101            &[23, 20], // Different sequences
1102            &[OpType::Put, OpType::Put],
1103            &[11, 12],
1104        );
1105
1106        let columns1 = SortColumns::new(&batch1);
1107        let columns2 = SortColumns::new(&batch2);
1108
1109        let cursor1 = RowCursor::new(columns1);
1110        let cursor2 = RowCursor::new(columns2);
1111
1112        // cursors with same pk and timestamp should be ordered by sequence desc
1113        // cursor1 has sequence 22, cursor2 has sequence 23, so cursor2 < cursor1 (higher sequence comes first)
1114        assert!(cursor2 < cursor1);
1115    }
1116}