Skip to main content

query/
part_sort.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Module for sorting input data within each [`PartitionRange`].
16//!
17//! This module defines the [`PartSortExec`] execution plan, which sorts each
18//! partition ([`PartitionRange`]) independently based on the provided physical
19//! sort expressions.
20
21use std::any::Any;
22use std::pin::Pin;
23use std::sync::Arc;
24use std::task::{Context, Poll};
25
26use arrow::array::{
27    ArrayRef, AsArray, TimestampMicrosecondArray, TimestampMillisecondArray,
28    TimestampNanosecondArray, TimestampSecondArray,
29};
30use arrow::compute::{concat, concat_batches, take_record_batch};
31use arrow_schema::{Schema, SchemaRef};
32use common_recordbatch::{DfRecordBatch, DfSendableRecordBatchStream};
33use common_telemetry::warn;
34use common_time::Timestamp;
35use common_time::timestamp::TimeUnit;
36use datafusion::common::arrow::compute::sort_to_indices;
37use datafusion::execution::memory_pool::{MemoryConsumer, MemoryReservation};
38use datafusion::execution::{RecordBatchStream, TaskContext};
39use datafusion::physical_plan::execution_plan::CardinalityEffect;
40use datafusion::physical_plan::filter_pushdown::{
41    ChildFilterDescription, FilterDescription, FilterPushdownPhase,
42};
43use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
44use datafusion::physical_plan::{
45    DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, TopK,
46    TopKDynamicFilters,
47};
48use datafusion_common::tree_node::{Transformed, TreeNode};
49use datafusion_common::{DataFusionError, internal_err};
50use datafusion_physical_expr::expressions::{Column, DynamicFilterPhysicalExpr, lit};
51use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
52use futures::{Stream, StreamExt};
53use itertools::Itertools;
54use parking_lot::RwLock;
55use snafu::location;
56use store_api::region_engine::PartitionRange;
57
58use crate::error::Result;
59use crate::window_sort::check_partition_range_monotonicity;
60use crate::{array_iter_helper, downcast_ts_array};
61
62/// Get the primary end of a `PartitionRange` based on sort direction.
63///
64/// - Descending: primary end is `end` (we process highest values first)
65/// - Ascending: primary end is `start` (we process lowest values first)
66fn get_primary_end(range: &PartitionRange, descending: bool) -> Timestamp {
67    if descending { range.end } else { range.start }
68}
69
70/// Group consecutive ranges by their primary end value.
71///
72/// Returns a vector of (primary_end, start_idx_inclusive, end_idx_exclusive) tuples.
73/// Ranges with the same primary end MUST be processed together because they may
74/// overlap and contain values that belong to the same "top-k" result.
75fn group_ranges_by_primary_end(
76    ranges: &[PartitionRange],
77    descending: bool,
78) -> Vec<(Timestamp, usize, usize)> {
79    if ranges.is_empty() {
80        return vec![];
81    }
82
83    let mut groups = Vec::new();
84    let mut group_start = 0;
85    let mut current_primary_end = get_primary_end(&ranges[0], descending);
86
87    for (idx, range) in ranges.iter().enumerate().skip(1) {
88        let primary_end = get_primary_end(range, descending);
89        if primary_end != current_primary_end {
90            // End current group
91            groups.push((current_primary_end, group_start, idx));
92            // Start new group
93            group_start = idx;
94            current_primary_end = primary_end;
95        }
96    }
97    // Push the last group
98    groups.push((current_primary_end, group_start, ranges.len()));
99
100    groups
101}
102
/// Sort input within given PartitionRange
///
/// Input is assumed to be segmented by empty RecordBatch, which indicates a new `PartitionRange` is starting
///
/// and this operator will sort each partition independently within the partition.
#[derive(Debug, Clone)]
pub struct PartSortExec {
    /// Physical sort expressions(that is, sort by timestamp)
    expression: PhysicalSortExpr,
    /// Optional row limit; when set, a `TopK` operator is used instead of a full sort.
    limit: Option<usize>,
    /// The child plan whose output is sorted per partition range.
    input: Arc<dyn ExecutionPlan>,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Per-output-partition lists of `PartitionRange`s, indexed by partition number.
    partition_ranges: Vec<Vec<PartitionRange>>,
    /// Cached plan properties derived from the input plan.
    properties: Arc<PlanProperties>,
    /// Filter matching the state of the sort for dynamic filter pushdown.
    /// If `limit` is `Some`, this will also be set and a TopK operator may be used.
    /// If `limit` is `None`, this will be `None`.
    filter: Option<Arc<RwLock<TopKDynamicFilters>>>,
}
123
124impl PartSortExec {
125    pub fn try_new(
126        expression: PhysicalSortExpr,
127        limit: Option<usize>,
128        partition_ranges: Vec<Vec<PartitionRange>>,
129        input: Arc<dyn ExecutionPlan>,
130    ) -> Result<Self> {
131        check_partition_range_monotonicity(&partition_ranges, expression.options.descending)?;
132
133        let metrics = ExecutionPlanMetricsSet::new();
134        let properties = input.properties();
135        let properties = Arc::new(PlanProperties::new(
136            input.equivalence_properties().clone(),
137            input.output_partitioning().clone(),
138            properties.emission_type,
139            properties.boundedness,
140        ));
141
142        let filter = limit
143            .is_some()
144            .then(|| Self::create_filter(expression.expr.clone()));
145
146        Ok(Self {
147            expression,
148            limit,
149            input,
150            metrics,
151            partition_ranges,
152            properties,
153            filter,
154        })
155    }
156
157    /// Add or reset `self.filter` to a new `TopKDynamicFilters`.
158    fn create_filter(expr: Arc<dyn PhysicalExpr>) -> Arc<RwLock<TopKDynamicFilters>> {
159        Arc::new(RwLock::new(TopKDynamicFilters::new(Arc::new(
160            DynamicFilterPhysicalExpr::new(vec![expr], lit(true)),
161        ))))
162    }
163
164    pub fn to_stream(
165        &self,
166        context: Arc<TaskContext>,
167        partition: usize,
168    ) -> datafusion_common::Result<DfSendableRecordBatchStream> {
169        let input_stream: DfSendableRecordBatchStream =
170            self.input.execute(partition, context.clone())?;
171
172        if partition >= self.partition_ranges.len() {
173            internal_err!(
174                "Partition index out of range: {} >= {} at {}",
175                partition,
176                self.partition_ranges.len(),
177                snafu::location!()
178            )?;
179        }
180
181        let df_stream = Box::pin(PartSortStream::new(
182            context,
183            self,
184            self.limit,
185            input_stream,
186            self.partition_ranges[partition].clone(),
187            partition,
188            self.filter.clone(),
189        )?) as _;
190
191        Ok(df_stream)
192    }
193}
194
195impl DisplayAs for PartSortExec {
196    fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
197        write!(
198            f,
199            "PartSortExec: expr={} num_ranges={}",
200            self.expression,
201            self.partition_ranges.len(),
202        )?;
203        if let Some(limit) = self.limit {
204            write!(f, " limit={}", limit)?;
205        }
206        Ok(())
207    }
208}
209
impl ExecutionPlan for PartSortExec {
    fn name(&self) -> &str {
        "PartSortExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Output schema is identical to the input's: sorting only reorders rows.
    fn schema(&self) -> SchemaRef {
        self.input.schema()
    }

    fn properties(&self) -> &Arc<PlanProperties> {
        &self.properties
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    /// Rebuild this node over a new child via `try_new`, which re-validates
    /// range monotonicity and constructs a fresh dynamic filter.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
        let new_input = if let Some(first) = children.first() {
            first
        } else {
            internal_err!("No children found")?
        };
        // create a new dynamic filter when with_new_children, as the old filter is bound to the old input and cannot be reused
        let new = Self::try_new(
            self.expression.clone(),
            self.limit,
            self.partition_ranges.clone(),
            new_input.clone(),
        )?;
        Ok(Arc::new(new))
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> datafusion_common::Result<DfSendableRecordBatchStream> {
        self.to_stream(context, partition)
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    /// # Explain
    ///
    /// This plan needs to be executed on each partition independently,
    /// and is expected to run directly on storage engine's output
    /// distribution / partition.
    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        vec![false]
    }

    /// Without a limit the row count is preserved; with a limit the output
    /// can only shrink (at most `limit` rows are kept).
    fn cardinality_effect(&self) -> CardinalityEffect {
        if self.limit.is_none() {
            CardinalityEffect::Equal
        } else {
            CardinalityEffect::LowerEqual
        }
    }

    fn gather_filters_for_pushdown(
        &self,
        phase: FilterPushdownPhase,
        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
        _config: &datafusion::config::ConfigOptions,
    ) -> datafusion_common::Result<FilterDescription> {
        // Our dynamic TopK filter is only attached in the Post phase; in any
        // other phase just forward the parent filters to the child.
        if !matches!(phase, FilterPushdownPhase::Post) {
            return FilterDescription::from_children(parent_filters, &self.children());
        }

        let mut child = ChildFilterDescription::from_child(&parent_filters, &self.input)?;

        // Attach our own dynamic filter (present only when `limit` is set) so
        // operators below can prune data using the current TopK threshold.
        if let Some(filter) = &self.filter {
            child = child.with_self_filter(filter.read().expr());
        }

        Ok(FilterDescription::new().with_child(child))
    }

    fn reset_state(self: Arc<Self>) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
        // shared dynamic filter needs to be reset
        let new_filter = self
            .limit
            .is_some()
            .then(|| Self::create_filter(self.expression.expr.clone()));

        Ok(Arc::new(Self {
            expression: self.expression.clone(),
            limit: self.limit,
            input: self.input.clone(),
            metrics: self.metrics.clone(),
            partition_ranges: self.partition_ranges.clone(),
            properties: self.properties.clone(),
            filter: new_filter,
        }))
    }
}
316
/// Buffer holding input rows for the range group currently being sorted.
enum PartSortBuffer {
    /// Collects every input batch; fully sorted in one pass when emitted.
    All(Vec<DfRecordBatch>),
    /// TopK buffer with row count.
    ///
    /// Given this heap only keeps k element, the capacity of this buffer
    /// is not accurate, and is only used for empty check.
    Top(TopK, usize),
}
325
326impl PartSortBuffer {
327    pub fn is_empty(&self) -> bool {
328        match self {
329            PartSortBuffer::All(v) => v.is_empty(),
330            PartSortBuffer::Top(_, cnt) => *cnt == 0,
331        }
332    }
333}
334
/// Stream state for `PartSortExec` over a single output partition.
struct PartSortStream {
    /// Memory pool for this stream
    reservation: MemoryReservation,
    /// Rows buffered for the range group currently being read.
    buffer: PartSortBuffer,
    /// Sort expression copied from the exec node.
    expression: PhysicalSortExpr,
    /// Optional row limit; `Some` puts the buffer into TopK mode.
    limit: Option<usize>,
    /// Upstream input stream producing the batches to sort.
    input: DfSendableRecordBatchStream,
    /// Whether the input stream has been fully drained.
    input_complete: bool,
    /// Output schema (same as the input plan's schema).
    schema: SchemaRef,
    /// Ranges assigned to this partition, in processing order.
    partition_ranges: Vec<PartitionRange>,
    #[allow(dead_code)] // this is used under #[debug_assertions]
    partition: usize,
    /// Index into `partition_ranges` of the range currently being read.
    cur_part_idx: usize,
    /// Remainder of a split batch, re-evaluated on the next poll.
    evaluating_batch: Option<DfRecordBatch>,
    /// Baseline metrics (rows/time) for this partition.
    metrics: BaselineMetrics,
    /// Task context, kept so new `TopK` instances can be created when emitting.
    context: Arc<TaskContext>,
    /// Root metrics set shared with the exec node, required by `TopK::try_new`.
    root_metrics: ExecutionPlanMetricsSet,
    /// Groups of ranges by primary end: (primary_end, start_idx_inclusive, end_idx_exclusive).
    /// Ranges in the same group must be processed together before outputting results.
    range_groups: Vec<(Timestamp, usize, usize)>,
    /// Current group being processed (index into range_groups).
    cur_group_idx: usize,
    /// Dynamic Filter for all TopK instance, notice the `PartSortExec`/`PartSortStream`/`TopK` must share the same filter
    /// so that updates from each `TopK` can be seen by others(and by the table scan operator).
    filter: Option<Arc<RwLock<TopKDynamicFilters>>>,
}
361
impl PartSortStream {
    /// Build the stream for one output partition.
    ///
    /// When `limit` is set, a `TopK` buffer is created and `filter` must be
    /// provided (it is shared with the parent exec and every `TopK` instance
    /// so threshold updates propagate); otherwise all batches are buffered
    /// and fully sorted per range group.
    ///
    /// # Errors
    /// Returns an internal error if `limit` is set without a filter, or if
    /// `TopK` construction fails.
    fn new(
        context: Arc<TaskContext>,
        sort: &PartSortExec,
        limit: Option<usize>,
        input: DfSendableRecordBatchStream,
        partition_ranges: Vec<PartitionRange>,
        partition: usize,
        filter: Option<Arc<RwLock<TopKDynamicFilters>>>,
    ) -> datafusion_common::Result<Self> {
        let buffer = if let Some(limit) = limit {
            // TopK mode requires the shared dynamic filter.
            let Some(filter) = filter.clone() else {
                return internal_err!(
                    "TopKDynamicFilters must be provided when limit is set at {}",
                    snafu::location!()
                );
            };

            PartSortBuffer::Top(
                TopK::try_new(
                    partition,
                    sort.schema().clone(),
                    vec![],
                    [sort.expression.clone()].into(),
                    limit,
                    context.session_config().batch_size(),
                    context.runtime_env(),
                    &sort.metrics,
                    filter.clone(),
                )?,
                0,
            )
        } else {
            PartSortBuffer::All(Vec::new())
        };

        // Compute range groups by primary end
        let descending = sort.expression.options.descending;
        let range_groups = group_ranges_by_primary_end(&partition_ranges, descending);

        Ok(Self {
            reservation: MemoryConsumer::new("PartSortStream".to_string())
                .register(&context.runtime_env().memory_pool),
            buffer,
            expression: sort.expression.clone(),
            limit,
            input,
            input_complete: false,
            schema: sort.input.schema(),
            partition_ranges,
            partition,
            cur_part_idx: 0,
            evaluating_batch: None,
            metrics: BaselineMetrics::new(&sort.metrics, partition),
            context,
            root_metrics: sort.metrics.clone(),
            range_groups,
            cur_group_idx: 0,
            filter,
        })
    }
}
424
/// Downcast `$arr` to a primitive timestamp array of type `$t` and verify that
/// the values at `$min_max_idx` fall inside `$cur_range`.
///
/// Errors if the range's time unit differs from `$unit`, or if the min/max
/// values lie outside the left-inclusive, right-exclusive partition range.
macro_rules! array_check_helper {
    ($t:ty, $unit:expr, $arr:expr, $cur_range:expr, $min_max_idx:expr) => {{
            if $cur_range.start.unit().as_arrow_time_unit() != $unit
            || $cur_range.end.unit().as_arrow_time_unit() != $unit
        {
            internal_err!(
                "PartitionRange unit mismatch, expect {:?}, found {:?}",
                $cur_range.start.unit(),
                $unit
            )?;
        }
        let arr = $arr
            .as_any()
            .downcast_ref::<arrow::array::PrimitiveArray<$t>>()
            .unwrap();

        // The two indices point at the extreme values of the sorted column;
        // normalize their order before comparing against the range bounds.
        let min = arr.value($min_max_idx.0);
        let max = arr.value($min_max_idx.1);
        let (min, max) = if min < max{
            (min, max)
        } else {
            (max, min)
        };
        let cur_min = $cur_range.start.value();
        let cur_max = $cur_range.end.value();
        // note that PartitionRange is left inclusive and right exclusive
        if !(min >= cur_min && max < cur_max) {
            internal_err!(
                "Sort column min/max value out of partition range: sort_column.min_max=[{:?}, {:?}] not in PartitionRange=[{:?}, {:?}]",
                min,
                max,
                cur_min,
                cur_max
            )?;
        }
    }};
}
462
463impl PartSortStream {
    /// check whether the sort column's min/max value is within the current group's effective range.
    /// For group-based processing, data from multiple ranges with the same primary end
    /// is accumulated together, so we check against the union of all ranges in the group.
    fn check_in_range(
        &self,
        sort_column: &ArrayRef,
        min_max_idx: (usize, usize),
    ) -> datafusion_common::Result<()> {
        // Use the group's effective range instead of the current partition range
        let Some(cur_range) = self.get_current_group_effective_range() else {
            internal_err!(
                "No effective range for current group {} at {}",
                self.cur_group_idx,
                snafu::location!()
            )?
        };

        // Dispatch on the concrete arrow timestamp type; the helper macro
        // performs the downcast and the actual bounds validation.
        downcast_ts_array!(
            sort_column.data_type() => (array_check_helper, sort_column, cur_range, min_max_idx),
            _ => internal_err!(
                "Unsupported data type for sort column: {:?}",
                sort_column.data_type()
            )?,
        );

        Ok(())
    }
491
    /// Try find data whose value exceeds the current partition range.
    ///
    /// Returns `None` if no such data is found, and `Some(idx)` where idx points to
    /// the first data that exceeds the current partition range.
    fn try_find_next_range(
        &self,
        sort_column: &ArrayRef,
    ) -> datafusion_common::Result<Option<usize>> {
        if sort_column.is_empty() {
            return Ok(None);
        }

        // check if the current partition index is out of range
        if self.cur_part_idx >= self.partition_ranges.len() {
            internal_err!(
                "Partition index out of range: {} >= {} at {}",
                self.cur_part_idx,
                self.partition_ranges.len(),
                snafu::location!()
            )?;
        }
        let cur_range = self.partition_ranges[self.cur_part_idx];

        // Build an (index, optional value) iterator over the timestamp column,
        // dispatching on the concrete arrow timestamp type via the macro.
        let sort_column_iter = downcast_ts_array!(
            sort_column.data_type() => (array_iter_helper, sort_column),
            _ => internal_err!(
                "Unsupported data type for sort column: {:?}",
                sort_column.data_type()
            )?,
        );

        for (idx, val) in sort_column_iter {
            // ignore vacant time index data
            if let Some(val) = val
                && (val >= cur_range.end.value() || val < cur_range.start.value())
            {
                return Ok(Some(idx));
            }
        }

        Ok(None)
    }
534
535    fn push_buffer(&mut self, batch: DfRecordBatch) -> datafusion_common::Result<()> {
536        match &mut self.buffer {
537            PartSortBuffer::All(v) => v.push(batch),
538            PartSortBuffer::Top(top, cnt) => {
539                *cnt += batch.num_rows();
540                top.insert_batch(batch)?;
541            }
542        }
543
544        Ok(())
545    }
546
    /// Stop read earlier when current group do not overlap with any of those next group
    /// If not overlap, we can stop read further input as current top k is final
    /// Use dynamic filter to evaluate the next group's primary end
    fn can_stop_early(&mut self, schema: &Arc<Schema>) -> datafusion_common::Result<bool> {
        // Early stop only applies in TopK (limit) mode.
        let topk_cnt = match &self.buffer {
            PartSortBuffer::Top(_, cnt) => *cnt,
            _ => return Ok(false),
        };
        // not fulfill topk yet
        if Some(topk_cnt) < self.limit {
            return Ok(false);
        }
        let next_group_primary_end = if self.cur_group_idx + 1 < self.range_groups.len() {
            self.range_groups[self.cur_group_idx + 1].0
        } else {
            // no next group
            return Ok(false);
        };

        // dyn filter is updated based on the last value of topk heap("threshold")
        // it's a max-heap for a ASC TopK operator
        // so can use dyn filter to prune data range
        let filter = self
            .filter
            .as_ref()
            .expect("TopKDynamicFilters must be provided when limit is set");
        let filter = filter.read().expr().current()?;
        let mut ts_index = None;
        // invariant: the filter must contain only the same column expr that's time index column
        let filter = filter
            .transform_down(|c| {
                // rewrite all column's index as 0
                if let Some(column) = c.as_any().downcast_ref::<Column>() {
                    ts_index = Some(column.index());
                    Ok(Transformed::yes(
                        Arc::new(Column::new(column.name(), 0)) as Arc<dyn PhysicalExpr>
                    ))
                } else {
                    Ok(Transformed::no(c))
                }
            })?
            .data;
        // No column found means the filter has not been tightened yet.
        let Some(ts_index) = ts_index else {
            return Ok(false); // dyn filter is still true, cannot decide, continue read
        };
        let field = if schema.fields().len() <= ts_index {
            warn!(
                "Schema mismatch when evaluating dynamic filter for PartSortExec at {}, schema: {:?}, ts_index: {}",
                self.partition, schema, ts_index
            );
            return Ok(false); // schema mismatch, cannot decide, continue read
        } else {
            schema.field(ts_index)
        };
        // Build a single-row, single-column batch holding the next group's
        // primary end, in the matching arrow timestamp unit.
        let schema = Arc::new(Schema::new(vec![field.clone()]));
        // convert next_group_primary_end to array&filter, if eval to false, means no overlap, can stop early
        let primary_end_array = match next_group_primary_end.unit() {
            TimeUnit::Second => Arc::new(TimestampSecondArray::from(vec![
                next_group_primary_end.value(),
            ])) as ArrayRef,
            TimeUnit::Millisecond => Arc::new(TimestampMillisecondArray::from(vec![
                next_group_primary_end.value(),
            ])) as ArrayRef,
            TimeUnit::Microsecond => Arc::new(TimestampMicrosecondArray::from(vec![
                next_group_primary_end.value(),
            ])) as ArrayRef,
            TimeUnit::Nanosecond => Arc::new(TimestampNanosecondArray::from(vec![
                next_group_primary_end.value(),
            ])) as ArrayRef,
        };
        let primary_end_batch = DfRecordBatch::try_new(schema, vec![primary_end_array])?;
        let res = filter.evaluate(&primary_end_batch)?;
        let array = res.into_array(primary_end_batch.num_rows())?;
        let filter = array.as_boolean().clone();
        // Single-row result: `Some(false)` means the next group's primary end
        // cannot pass the current TopK threshold, so it cannot contribute —
        // safe to stop reading input early.
        let overlap = filter.iter().next().flatten();
        if let Some(false) = overlap {
            Ok(true)
        } else {
            Ok(false)
        }
    }
628
629    /// Check if the given partition index is within the current group.
630    fn is_in_current_group(&self, part_idx: usize) -> bool {
631        if self.cur_group_idx >= self.range_groups.len() {
632            return false;
633        }
634        let (_, start, end) = self.range_groups[self.cur_group_idx];
635        part_idx >= start && part_idx < end
636    }
637
638    /// Advance to the next group. Returns true if there is a next group.
639    fn advance_to_next_group(&mut self) -> bool {
640        self.cur_group_idx += 1;
641        self.cur_group_idx < self.range_groups.len()
642    }
643
644    /// Get the effective range for the current group.
645    /// For a group of ranges with the same primary end, the effective range is
646    /// the union of all ranges in the group.
647    fn get_current_group_effective_range(&self) -> Option<PartitionRange> {
648        if self.cur_group_idx >= self.range_groups.len() {
649            return None;
650        }
651        let (_, start_idx, end_idx) = self.range_groups[self.cur_group_idx];
652        if start_idx >= end_idx || start_idx >= self.partition_ranges.len() {
653            return None;
654        }
655
656        let ranges_in_group =
657            &self.partition_ranges[start_idx..end_idx.min(self.partition_ranges.len())];
658        if ranges_in_group.is_empty() {
659            return None;
660        }
661
662        // Compute union of all ranges in the group
663        let mut min_start = ranges_in_group[0].start;
664        let mut max_end = ranges_in_group[0].end;
665        for range in ranges_in_group.iter().skip(1) {
666            if range.start < min_start {
667                min_start = range.start;
668            }
669            if range.end > max_end {
670                max_end = range.end;
671            }
672        }
673
674        Some(PartitionRange {
675            start: min_start,
676            end: max_end,
677            num_rows: 0,   // Not used for validation
678            identifier: 0, // Not used for validation
679        })
680    }
681
682    /// Sort and clear the buffer and return the sorted record batch
683    ///
684    /// this function will return a empty record batch if the buffer is empty
685    fn sort_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
686        match &mut self.buffer {
687            PartSortBuffer::All(_) => self.sort_all_buffer(),
688            PartSortBuffer::Top(_, _) => self.sort_top_buffer(),
689        }
690    }
691
    /// Internal method for sorting `All` buffer (without limit).
    ///
    /// Concatenates the evaluated sort columns, computes sort indices,
    /// validates min/max against the current group's range, then takes the
    /// rows from the concatenated input. Memory for the concat + output is
    /// reserved up-front and released at the end.
    fn sort_all_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
        // Swap the buffer out so batches can be consumed while `self` stays usable.
        let PartSortBuffer::All(buffer) =
            std::mem::replace(&mut self.buffer, PartSortBuffer::All(Vec::new()))
        else {
            unreachable!("buffer type is checked before and should be All variant")
        };

        if buffer.is_empty() {
            return Ok(DfRecordBatch::new_empty(self.schema.clone()));
        }
        let mut sort_columns = Vec::with_capacity(buffer.len());
        let mut opt = None;
        for batch in buffer.iter() {
            let sort_column = self.expression.evaluate_to_sort_column(batch)?;
            // Keep the first non-None sort options seen across batches.
            opt = opt.or(sort_column.options);
            sort_columns.push(sort_column.values);
        }

        let sort_column =
            concat(&sort_columns.iter().map(|a| a.as_ref()).collect_vec()).map_err(|e| {
                DataFusionError::ArrowError(
                    Box::new(e),
                    Some(format!("Fail to concat sort columns at {}", location!())),
                )
            })?;

        let indices = sort_to_indices(&sort_column, opt, self.limit).map_err(|e| {
            DataFusionError::ArrowError(
                Box::new(e),
                Some(format!("Fail to sort to indices at {}", location!())),
            )
        })?;
        if indices.is_empty() {
            return Ok(DfRecordBatch::new_empty(self.schema.clone()));
        }

        // First and last sorted indices point at the extreme values; verify
        // they fall inside the current group's effective range.
        self.check_in_range(
            &sort_column,
            (
                indices.value(0) as usize,
                indices.value(indices.len() - 1) as usize,
            ),
        )
        .inspect_err(|_e| {
            #[cfg(debug_assertions)]
            common_telemetry::error!(
                "Fail to check sort column in range at {}, current_idx: {}, num_rows: {}, err: {}",
                self.partition,
                self.cur_part_idx,
                sort_column.len(),
                _e
            );
        })?;

        // reserve memory for the concat input and sorted output
        let total_mem: usize = buffer.iter().map(|r| r.get_array_memory_size()).sum();
        self.reservation.try_grow(total_mem * 2)?;

        let full_input = concat_batches(&self.schema, &buffer).map_err(|e| {
            DataFusionError::ArrowError(
                Box::new(e),
                Some(format!(
                    "Fail to concat input batches when sorting at {}",
                    location!()
                )),
            )
        })?;

        let sorted = take_record_batch(&full_input, &indices).map_err(|e| {
            DataFusionError::ArrowError(
                Box::new(e),
                Some(format!(
                    "Fail to take result record batch when sorting at {}",
                    location!()
                )),
            )
        })?;

        drop(full_input);
        // here remove both buffer and full_input memory
        self.reservation.shrink(2 * total_mem);
        Ok(sorted)
    }
776
    /// Internal method for sorting `Top` buffer (with limit).
    ///
    /// Swaps in a fresh `TopK` (sharing the same dynamic filter), emits the
    /// old one's results by polling its stream synchronously, and returns the
    /// concatenated batches.
    fn sort_top_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
        let Some(filter) = self.filter.clone() else {
            return internal_err!(
                "TopKDynamicFilters must be provided when sorting with limit at {}",
                snafu::location!()
            );
        };

        // Replacement TopK for the next group; reuses the shared filter so
        // threshold state continues to propagate.
        let new_top_buffer = TopK::try_new(
            self.partition,
            self.schema().clone(),
            vec![],
            [self.expression.clone()].into(),
            self.limit.unwrap(),
            self.context.session_config().batch_size(),
            self.context.runtime_env(),
            &self.root_metrics,
            filter,
        )?;
        let PartSortBuffer::Top(top_k, _) =
            std::mem::replace(&mut self.buffer, PartSortBuffer::Top(new_top_buffer, 0))
        else {
            unreachable!("buffer type is checked before and should be Top variant")
        };

        // Drain the emitted stream with a no-op waker; per the current `TopK`
        // implementation it never returns `Pending`.
        let mut result_stream = top_k.emit()?;
        let mut placeholder_ctx = std::task::Context::from_waker(futures::task::noop_waker_ref());
        let mut results = vec![];
        // according to the current implementation of `TopK`, the result stream will always be ready
        loop {
            match result_stream.poll_next_unpin(&mut placeholder_ctx) {
                Poll::Ready(Some(batch)) => {
                    let batch = batch?;
                    results.push(batch);
                }
                Poll::Pending => {
                    #[cfg(debug_assertions)]
                    unreachable!("TopK result stream should always be ready")
                }
                Poll::Ready(None) => {
                    break;
                }
            }
        }

        let concat_batch = concat_batches(&self.schema, &results).map_err(|e| {
            DataFusionError::ArrowError(
                Box::new(e),
                Some(format!(
                    "Fail to concat top k result record batch when sorting at {}",
                    location!()
                )),
            )
        })?;

        Ok(concat_batch)
    }
835
836    /// Sorts current buffer and returns `None` when there is nothing to emit.
837    fn sorted_buffer_if_non_empty(&mut self) -> datafusion_common::Result<Option<DfRecordBatch>> {
838        if self.buffer.is_empty() {
839            return Ok(None);
840        }
841
842        let sorted = self.sort_buffer()?;
843        if sorted.num_rows() == 0 {
844            Ok(None)
845        } else {
846            Ok(Some(sorted))
847        }
848    }
849
850    /// Try to split the input batch if it contains data that exceeds the current partition range.
851    ///
852    /// When the input batch contains data that exceeds the current partition range, this function
853    /// will split the input batch into two parts, the first part is within the current partition
854    /// range will be merged and sorted with previous buffer, and the second part will be registered
855    /// to `evaluating_batch` for next polling.
856    ///
857    /// **Group-based processing**: Ranges with the same primary end are grouped together.
858    /// We only sort and output when transitioning to a NEW group, not when moving between
859    /// ranges within the same group.
860    ///
861    /// Returns `None` if the input batch is empty or fully within the current partition range
862    /// (or we're still collecting data within the same group), and `Some(batch)` when we've
863    /// completed a group and have sorted output. When operating in TopK (limit) mode, this
864    /// function will not emit intermediate batches; it only prepares state for a single final
865    /// output.
866    fn split_batch(
867        &mut self,
868        batch: DfRecordBatch,
869    ) -> datafusion_common::Result<Option<DfRecordBatch>> {
870        if matches!(self.buffer, PartSortBuffer::Top(_, _)) {
871            self.split_batch_topk(batch)?;
872            return Ok(None);
873        }
874
875        self.split_batch_all(batch)
876    }
877
    /// Specialized splitting logic for TopK (limit) mode.
    ///
    /// We only emit once when the TopK buffer is fulfilled or when input is fully consumed.
    /// When the buffer is fulfilled and we are about to enter a new group, we stop consuming
    /// further ranges.
    fn split_batch_topk(&mut self, batch: DfRecordBatch) -> datafusion_common::Result<()> {
        if batch.num_rows() == 0 {
            return Ok(());
        }

        // Evaluate the sort expression to obtain the column used to locate
        // partition-range boundaries within this batch.
        let sort_column = self
            .expression
            .expr
            .evaluate(&batch)?
            .into_array(batch.num_rows())?;

        // Index of the first row falling outside the current partition range,
        // or `None` if the entire batch is contained in it.
        let next_range_idx = self.try_find_next_range(&sort_column)?;
        let Some(idx) = next_range_idx else {
            self.push_buffer(batch)?;
            // keep polling input for next batch
            return Ok(());
        };

        // Rows [0, idx) belong to the current range; the tail spills into later ranges.
        let this_range = batch.slice(0, idx);
        let remaining_range = batch.slice(idx, batch.num_rows() - idx);
        if this_range.num_rows() != 0 {
            self.push_buffer(this_range)?;
        }

        // Step to next proper PartitionRange
        self.cur_part_idx += 1;

        // If we've processed all partitions, mark completion.
        if self.cur_part_idx >= self.partition_ranges.len() {
            debug_assert!(remaining_range.num_rows() == 0);
            self.input_complete = true;
            return Ok(());
        }

        // Check if we're still in the same group
        let in_same_group = self.is_in_current_group(self.cur_part_idx);

        // When TopK is fulfilled and we are switching to a new group, stop consuming further ranges if possible.
        // read from topk heap and determine whether we can stop earlier.
        if !in_same_group && self.can_stop_early(&batch.schema())? {
            self.input_complete = true;
            // Drop any leftover data: it cannot affect the final TopK result.
            self.evaluating_batch = None;
            return Ok(());
        }

        // Transition to a new group if needed
        if !in_same_group {
            self.advance_to_next_group();
        }

        let next_sort_column = sort_column.slice(idx, batch.num_rows() - idx);
        if self.try_find_next_range(&next_sort_column)?.is_some() {
            // remaining batch still contains data that exceeds the current partition range
            // register the remaining batch for next polling
            self.evaluating_batch = Some(remaining_range);
        } else if remaining_range.num_rows() != 0 {
            // remaining batch is within the current partition range
            // push to the buffer and continue polling
            self.push_buffer(remaining_range)?;
        }

        Ok(())
    }
946
    /// Full-sort counterpart of [`Self::split_batch_topk`]: splits `batch` on the
    /// current partition range, and sorts + emits the buffered data whenever a
    /// group of ranges (those sharing the same primary end) is completed.
    ///
    /// Returns `Ok(None)` while still collecting within the current group and
    /// `Ok(Some(batch))` when a finished group produced non-empty sorted output.
    fn split_batch_all(
        &mut self,
        batch: DfRecordBatch,
    ) -> datafusion_common::Result<Option<DfRecordBatch>> {
        if batch.num_rows() == 0 {
            return Ok(None);
        }

        // Evaluate the sort expression to obtain the column used to locate
        // partition-range boundaries within this batch.
        let sort_column = self
            .expression
            .expr
            .evaluate(&batch)?
            .into_array(batch.num_rows())?;

        // Index of the first row falling outside the current partition range,
        // or `None` if the entire batch is contained in it.
        let next_range_idx = self.try_find_next_range(&sort_column)?;
        let Some(idx) = next_range_idx else {
            self.push_buffer(batch)?;
            // keep polling input for next batch
            return Ok(None);
        };

        // Rows [0, idx) belong to the current range; the tail spills into later ranges.
        let this_range = batch.slice(0, idx);
        let remaining_range = batch.slice(idx, batch.num_rows() - idx);
        if this_range.num_rows() != 0 {
            self.push_buffer(this_range)?;
        }

        // Step to next proper PartitionRange
        self.cur_part_idx += 1;

        // If we've processed all partitions, sort and output
        if self.cur_part_idx >= self.partition_ranges.len() {
            // assert there is no data beyond the last partition range (remaining is empty).
            debug_assert!(remaining_range.num_rows() == 0);

            // Sort and output the final group
            return self.sorted_buffer_if_non_empty();
        }

        // Check if we're still in the same group
        if self.is_in_current_group(self.cur_part_idx) {
            // Same group - don't sort yet, keep collecting
            let next_sort_column = sort_column.slice(idx, batch.num_rows() - idx);
            if self.try_find_next_range(&next_sort_column)?.is_some() {
                // remaining batch still contains data that exceeds the current partition range
                self.evaluating_batch = Some(remaining_range);
            } else {
                // remaining batch is within the current partition range
                if remaining_range.num_rows() != 0 {
                    self.push_buffer(remaining_range)?;
                }
            }
            // Return None to continue collecting within the same group
            return Ok(None);
        }

        // Transitioning to a new group - sort current group and output
        let sorted_batch = self.sorted_buffer_if_non_empty()?;
        self.advance_to_next_group();

        let next_sort_column = sort_column.slice(idx, batch.num_rows() - idx);
        if self.try_find_next_range(&next_sort_column)?.is_some() {
            // remaining batch still contains data that exceeds the current partition range
            // register the remaining batch for next polling
            self.evaluating_batch = Some(remaining_range);
        } else {
            // remaining batch is within the current partition range
            // push to the buffer and continue polling
            if remaining_range.num_rows() != 0 {
                self.push_buffer(remaining_range)?;
            }
        }

        Ok(sorted_batch)
    }
1022
    /// Core polling loop: first flushes when input is complete, then splits any
    /// batch left over from the previous call, and finally pulls new batches from
    /// the input stream until a sorted batch can be emitted (or the stream ends).
    pub fn poll_next_inner(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<datafusion_common::Result<DfRecordBatch>>> {
        loop {
            // Input exhausted: emit whatever is still buffered, then finish.
            if self.input_complete {
                if let Some(sorted_batch) = self.sorted_buffer_if_non_empty()? {
                    return Poll::Ready(Some(Ok(sorted_batch)));
                }
                return Poll::Ready(None);
            }

            // if there is a remaining batch being evaluated from last run,
            // split on it instead of fetching new batch
            if let Some(evaluating_batch) = self.evaluating_batch.take()
                && evaluating_batch.num_rows() != 0
            {
                // Check if we've already processed all partitions
                if self.cur_part_idx >= self.partition_ranges.len() {
                    // All partitions processed, discard remaining data
                    if let Some(sorted_batch) = self.sorted_buffer_if_non_empty()? {
                        return Poll::Ready(Some(Ok(sorted_batch)));
                    }
                    return Poll::Ready(None);
                }

                if let Some(sorted_batch) = self.split_batch(evaluating_batch)? {
                    return Poll::Ready(Some(Ok(sorted_batch)));
                }
                // Nothing emitted yet; loop to re-check state before polling input.
                continue;
            }

            // fetch next batch from input
            let res = self.input.as_mut().poll_next(cx);
            match res {
                Poll::Ready(Some(Ok(batch))) => {
                    if let Some(sorted_batch) = self.split_batch(batch)? {
                        return Poll::Ready(Some(Ok(sorted_batch)));
                    }
                }
                // input stream end, mark and continue
                Poll::Ready(None) => {
                    self.input_complete = true;
                }
                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
                Poll::Pending => return Poll::Pending,
            }
        }
    }
1072}
1073
1074impl Stream for PartSortStream {
1075    type Item = datafusion_common::Result<DfRecordBatch>;
1076
1077    fn poll_next(
1078        mut self: Pin<&mut Self>,
1079        cx: &mut Context<'_>,
1080    ) -> Poll<Option<datafusion_common::Result<DfRecordBatch>>> {
1081        let result = self.as_mut().poll_next_inner(cx);
1082        self.metrics.record_poll(result)
1083    }
1084}
1085
1086impl RecordBatchStream for PartSortStream {
1087    fn schema(&self) -> SchemaRef {
1088        self.schema.clone()
1089    }
1090}
1091
1092#[cfg(test)]
1093mod test {
1094    use std::sync::Arc;
1095
1096    use arrow::array::{
1097        TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
1098        TimestampSecondArray,
1099    };
1100    use arrow::json::ArrayWriter;
1101    use arrow_schema::{DataType, Field, Schema, SortOptions, TimeUnit};
1102    use common_time::Timestamp;
1103    use datafusion_physical_expr::expressions::Column;
1104    use futures::StreamExt;
1105    use store_api::region_engine::PartitionRange;
1106
1107    use super::*;
1108    use crate::test_util::{MockInputExec, new_ts_array};
1109
    /// `can_stop_early` must return `Ok(false)` (and not panic) when the TopK
    /// buffer holds zero rows even though the external row counter reached `limit`.
    #[tokio::test]
    async fn test_can_stop_early_with_empty_topk_buffer() {
        let unit = TimeUnit::Millisecond;
        let schema = Arc::new(Schema::new(vec![Field::new(
            "ts",
            DataType::Timestamp(unit, None),
            false,
        )]));

        // Build a minimal PartSortExec and stream, but inject a dynamic filter that
        // always evaluates to false so TopK will filter out all rows internally.
        let mock_input = Arc::new(MockInputExec::new(vec![vec![]], schema.clone()));
        let exec = PartSortExec::try_new(
            PhysicalSortExpr {
                expr: Arc::new(Column::new("ts", 0)),
                options: SortOptions {
                    descending: true,
                    ..Default::default()
                },
            },
            Some(3),
            vec![vec![]],
            mock_input.clone(),
        )
        .unwrap();

        // `lit(false)` guarantees the dynamic filter rejects every row.
        let filter = Arc::new(RwLock::new(TopKDynamicFilters::new(Arc::new(
            DynamicFilterPhysicalExpr::new(vec![], lit(false)),
        ))));

        let input_stream = mock_input
            .execute(0, Arc::new(TaskContext::default()))
            .unwrap();
        let mut stream = PartSortStream::new(
            Arc::new(TaskContext::default()),
            &exec,
            Some(3),
            input_stream,
            vec![],
            0,
            Some(filter),
        )
        .unwrap();

        // Push 3 rows so the external counter reaches `limit`, while TopK keeps no rows.
        let batch = DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![1, 2, 3])])
            .unwrap();
        stream.push_buffer(batch).unwrap();

        // The TopK result buffer is empty, so we cannot determine early-stop.
        // Ensure this path returns `Ok(false)` (and, importantly, does not panic).
        assert!(!stream.can_stop_early(&schema).unwrap());
    }
1163
1164    #[ignore = "hard to gen expected data correctly here, TODO(discord9): fix it later"]
1165    #[tokio::test]
1166    async fn fuzzy_test() {
1167        let test_cnt = 100;
1168        // bound for total count of PartitionRange
1169        let part_cnt_bound = 100;
1170        // bound for timestamp range size and offset for each PartitionRange
1171        let range_size_bound = 100;
1172        let range_offset_bound = 100;
1173        // bound for batch count and size within each PartitionRange
1174        let batch_cnt_bound = 20;
1175        let batch_size_bound = 100;
1176
1177        let mut rng = fastrand::Rng::new();
1178        rng.seed(1337);
1179
1180        let mut test_cases = Vec::new();
1181
1182        for case_id in 0..test_cnt {
1183            let mut bound_val: Option<i64> = None;
1184            let descending = rng.bool();
1185            let nulls_first = rng.bool();
1186            let opt = SortOptions {
1187                descending,
1188                nulls_first,
1189            };
1190            let limit = if rng.bool() {
1191                Some(rng.usize(1..batch_cnt_bound * batch_size_bound))
1192            } else {
1193                None
1194            };
1195            let unit = match rng.u8(0..3) {
1196                0 => TimeUnit::Second,
1197                1 => TimeUnit::Millisecond,
1198                2 => TimeUnit::Microsecond,
1199                _ => TimeUnit::Nanosecond,
1200            };
1201
1202            let schema = Schema::new(vec![Field::new(
1203                "ts",
1204                DataType::Timestamp(unit, None),
1205                false,
1206            )]);
1207            let schema = Arc::new(schema);
1208
1209            let mut input_ranged_data = vec![];
1210            let mut output_ranges = vec![];
1211            let mut output_data = vec![];
1212            // generate each input `PartitionRange`
1213            for part_id in 0..rng.usize(0..part_cnt_bound) {
1214                // generate each `PartitionRange`'s timestamp range
1215                let (start, end) = if descending {
1216                    // Use 1..=range_offset_bound to ensure strictly decreasing end values
1217                    let end = bound_val
1218                        .map(
1219                            |i| i
1220                            .checked_sub(rng.i64(1..=range_offset_bound))
1221                            .expect("Bad luck, fuzzy test generate data that will overflow, change seed and try again")
1222                        )
1223                        .unwrap_or_else(|| rng.i64(-100000000..100000000));
1224                    bound_val = Some(end);
1225                    let start = end - rng.i64(1..range_size_bound);
1226                    let start = Timestamp::new(start, unit.into());
1227                    let end = Timestamp::new(end, unit.into());
1228                    (start, end)
1229                } else {
1230                    // Use 1..=range_offset_bound to ensure strictly increasing start values
1231                    let start = bound_val
1232                        .map(|i| i + rng.i64(1..=range_offset_bound))
1233                        .unwrap_or_else(|| rng.i64(..));
1234                    bound_val = Some(start);
1235                    let end = start + rng.i64(1..range_size_bound);
1236                    let start = Timestamp::new(start, unit.into());
1237                    let end = Timestamp::new(end, unit.into());
1238                    (start, end)
1239                };
1240                assert!(start < end);
1241
1242                let mut per_part_sort_data = vec![];
1243                let mut batches = vec![];
1244                for _batch_idx in 0..rng.usize(1..batch_cnt_bound) {
1245                    let cnt = rng.usize(0..batch_size_bound) + 1;
1246                    let iter = 0..rng.usize(0..cnt);
1247                    let mut data_gen = iter
1248                        .map(|_| rng.i64(start.value()..end.value()))
1249                        .collect_vec();
1250                    if data_gen.is_empty() {
1251                        // current batch is empty, skip
1252                        continue;
1253                    }
1254                    // mito always sort on ASC order
1255                    data_gen.sort();
1256                    per_part_sort_data.extend(data_gen.clone());
1257                    let arr = new_ts_array(unit, data_gen.clone());
1258                    let batch = DfRecordBatch::try_new(schema.clone(), vec![arr]).unwrap();
1259                    batches.push(batch);
1260                }
1261
1262                let range = PartitionRange {
1263                    start,
1264                    end,
1265                    num_rows: batches.iter().map(|b| b.num_rows()).sum(),
1266                    identifier: part_id,
1267                };
1268                input_ranged_data.push((range, batches));
1269
1270                output_ranges.push(range);
1271                if per_part_sort_data.is_empty() {
1272                    continue;
1273                }
1274                output_data.extend_from_slice(&per_part_sort_data);
1275            }
1276
1277            // adjust output data with adjacent PartitionRanges
1278            let mut output_data_iter = output_data.iter().peekable();
1279            let mut output_data = vec![];
1280            for range in output_ranges.clone() {
1281                let mut cur_data = vec![];
1282                while let Some(val) = output_data_iter.peek() {
1283                    if **val < range.start.value() || **val >= range.end.value() {
1284                        break;
1285                    }
1286                    cur_data.push(*output_data_iter.next().unwrap());
1287                }
1288
1289                if cur_data.is_empty() {
1290                    continue;
1291                }
1292
1293                if descending {
1294                    cur_data.sort_by(|a, b| b.cmp(a));
1295                } else {
1296                    cur_data.sort();
1297                }
1298                output_data.push(cur_data);
1299            }
1300
1301            let expected_output = if let Some(limit) = limit {
1302                let mut accumulated = Vec::new();
1303                let mut seen = 0usize;
1304                for mut range_values in output_data {
1305                    seen += range_values.len();
1306                    accumulated.append(&mut range_values);
1307                    if seen >= limit {
1308                        break;
1309                    }
1310                }
1311
1312                if accumulated.is_empty() {
1313                    None
1314                } else {
1315                    if descending {
1316                        accumulated.sort_by(|a, b| b.cmp(a));
1317                    } else {
1318                        accumulated.sort();
1319                    }
1320                    accumulated.truncate(limit.min(accumulated.len()));
1321
1322                    Some(
1323                        DfRecordBatch::try_new(
1324                            schema.clone(),
1325                            vec![new_ts_array(unit, accumulated)],
1326                        )
1327                        .unwrap(),
1328                    )
1329                }
1330            } else {
1331                let batches = output_data
1332                    .into_iter()
1333                    .map(|a| {
1334                        DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, a)]).unwrap()
1335                    })
1336                    .collect_vec();
1337                if batches.is_empty() {
1338                    None
1339                } else {
1340                    Some(concat_batches(&schema, &batches).unwrap())
1341                }
1342            };
1343
1344            test_cases.push((
1345                case_id,
1346                unit,
1347                input_ranged_data,
1348                schema,
1349                opt,
1350                limit,
1351                expected_output,
1352            ));
1353        }
1354
1355        for (case_id, _unit, input_ranged_data, schema, opt, limit, expected_output) in test_cases {
1356            run_test(
1357                case_id,
1358                input_ranged_data,
1359                schema,
1360                opt,
1361                limit,
1362                expected_output,
1363                None,
1364            )
1365            .await;
1366        }
1367    }
1368
    /// Hand-written scenarios covering grouping of same-end ranges, empty ranges
    /// and batches, batches spanning multiple ranges, and the TopK (limit) path.
    #[tokio::test]
    async fn simple_cases() {
        // Each testcase: (time unit, (range, batches) inputs, descending?, limit, expected batches).
        let testcases = vec![
            // Case 0: Ascending sort; ranges [0,10) and [5,10) share end=10 and thus
            // form one group, so all rows merge into a single sorted output batch.
            (
                TimeUnit::Millisecond,
                vec![
                    ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]]),
                    ((5, 10), vec![vec![5, 6], vec![7, 8]]),
                ],
                false,
                None,
                vec![vec![1, 2, 3, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9]],
            ),
            // Case 1: Descending sort with overlapping ranges that have the same primary end (end=10).
            // Ranges [5,10) and [0,10) are grouped together, so their data is merged before sorting.
            (
                TimeUnit::Millisecond,
                vec![
                    ((5, 10), vec![vec![5, 6], vec![7, 8, 9]]),
                    ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
                ],
                true,
                None,
                vec![vec![9, 8, 8, 7, 7, 6, 6, 5, 5, 4, 3, 2, 1]],
            ),
            // Case 2: First range has no batches; output comes solely from [0,10).
            (
                TimeUnit::Millisecond,
                vec![
                    ((5, 10), vec![]),
                    ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
                ],
                true,
                None,
                vec![vec![8, 7, 6, 5, 4, 3, 2, 1]],
            ),
            // Case 3: Ranges in distinct groups produce separate sorted batches.
            (
                TimeUnit::Millisecond,
                vec![
                    ((15, 20), vec![vec![17, 18, 19]]),
                    ((10, 15), vec![]),
                    ((5, 10), vec![]),
                    ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
                ],
                true,
                None,
                vec![vec![19, 18, 17], vec![8, 7, 6, 5, 4, 3, 2, 1]],
            ),
            // Case 4: All ranges empty — no output at all.
            (
                TimeUnit::Millisecond,
                vec![
                    ((15, 20), vec![]),
                    ((10, 15), vec![]),
                    ((5, 10), vec![]),
                    ((0, 10), vec![]),
                ],
                true,
                None,
                vec![],
            ),
            // Case 5: Data from one batch spans multiple ranges. Ranges with same end are grouped.
            // Ranges: [15,20) end=20, [10,15) end=15, [5,10) end=10, [0,10) end=10
            // Groups: {[15,20)}, {[10,15)}, {[5,10), [0,10)}
            // The last two ranges are merged because they share end=10.
            (
                TimeUnit::Millisecond,
                vec![
                    (
                        (15, 20),
                        vec![vec![15, 17, 19, 10, 11, 12, 5, 6, 7, 8, 9, 1, 2, 3, 4]],
                    ),
                    ((10, 15), vec![]),
                    ((5, 10), vec![]),
                    ((0, 10), vec![]),
                ],
                true,
                None,
                vec![
                    vec![19, 17, 15],
                    vec![12, 11, 10],
                    vec![9, 8, 7, 6, 5, 4, 3, 2, 1],
                ],
            ),
            // Case 6: Same input as case 5 but with limit=2 — TopK keeps only the
            // top two values across everything it consumed.
            (
                TimeUnit::Millisecond,
                vec![
                    (
                        (15, 20),
                        vec![vec![15, 17, 19, 10, 11, 12, 5, 6, 7, 8, 9, 1, 2, 3, 4]],
                    ),
                    ((10, 15), vec![]),
                    ((5, 10), vec![]),
                    ((0, 10), vec![]),
                ],
                true,
                Some(2),
                vec![vec![19, 17]],
            ),
        ];

        for (identifier, (unit, input_ranged_data, descending, limit, expected_output)) in
            testcases.into_iter().enumerate()
        {
            let schema = Schema::new(vec![Field::new(
                "ts",
                DataType::Timestamp(unit, None),
                false,
            )]);
            let schema = Arc::new(schema);
            let opt = SortOptions {
                descending,
                ..Default::default()
            };

            // Wrap raw (start, end, rows) tuples into PartitionRange + record batches.
            let input_ranged_data = input_ranged_data
                .into_iter()
                .map(|(range, data)| {
                    let part = PartitionRange {
                        start: Timestamp::new(range.0, unit.into()),
                        end: Timestamp::new(range.1, unit.into()),
                        num_rows: data.iter().map(|b| b.len()).sum(),
                        identifier,
                    };

                    let batches = data
                        .into_iter()
                        .map(|b| {
                            let arr = new_ts_array(unit, b);
                            DfRecordBatch::try_new(schema.clone(), vec![arr]).unwrap()
                        })
                        .collect_vec();
                    (part, batches)
                })
                .collect_vec();

            let expected_output = expected_output
                .into_iter()
                .map(|a| {
                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, a)]).unwrap()
                })
                .collect_vec();
            let expected_output = if expected_output.is_empty() {
                None
            } else {
                Some(concat_batches(&schema, &expected_output).unwrap())
            };

            run_test(
                identifier,
                input_ranged_data,
                schema.clone(),
                opt,
                limit,
                expected_output,
                None,
            )
            .await;
        }
    }
1527
    /// Drives a `PartSortExec` over `input_ranged_data` and compares the
    /// concatenated output against `expected_output`, panicking with a JSON dump
    /// of both sides on mismatch.
    ///
    /// * `case_id` - identifier used in failure messages.
    /// * `opt` - sort direction / null ordering under test.
    /// * `limit` - optional TopK limit; when set, at most one output batch is expected.
    /// * `expected_polled_rows` - when set, asserts how many input rows were pulled
    ///   from the mock input (used to verify early-stop behavior).
    #[allow(clippy::print_stdout)]
    async fn run_test(
        case_id: usize,
        input_ranged_data: Vec<(PartitionRange, Vec<DfRecordBatch>)>,
        schema: SchemaRef,
        opt: SortOptions,
        limit: Option<usize>,
        expected_output: Option<DfRecordBatch>,
        expected_polled_rows: Option<usize>,
    ) {
        // Sanity-check the fixture itself: expected output must respect the limit.
        if let (Some(limit), Some(rb)) = (limit, &expected_output) {
            assert!(
                rb.num_rows() <= limit,
                "Expect row count in expected output({}) <= limit({})",
                rb.num_rows(),
                limit
            );
        }

        let mut data_partition = Vec::with_capacity(input_ranged_data.len());
        let mut ranges = Vec::with_capacity(input_ranged_data.len());
        for (part_range, batches) in input_ranged_data {
            data_partition.push(batches);
            ranges.push(part_range);
        }

        let mock_input = Arc::new(MockInputExec::new(data_partition, schema.clone()));

        let exec = PartSortExec::try_new(
            PhysicalSortExpr {
                expr: Arc::new(Column::new("ts", 0)),
                options: opt,
            },
            limit,
            vec![ranges.clone()],
            mock_input.clone(),
        )
        .unwrap();

        let exec_stream = exec.execute(0, Arc::new(TaskContext::default())).unwrap();

        let real_output = exec_stream.map(|r| r.unwrap()).collect::<Vec<_>>().await;
        // TopK mode defers everything to one final emission.
        if limit.is_some() {
            assert!(
                real_output.len() <= 1,
                "case_{case_id} expects a single output batch when limit is set, got {}",
                real_output.len()
            );
        }

        let actual_output = if real_output.is_empty() {
            None
        } else {
            Some(concat_batches(&schema, &real_output).unwrap())
        };

        if let Some(expected_polled_rows) = expected_polled_rows {
            let input_pulled_rows = mock_input.metrics().unwrap().output_rows().unwrap();
            assert_eq!(input_pulled_rows, expected_polled_rows);
        }

        match (actual_output, expected_output) {
            (None, None) => {}
            (Some(actual), Some(expected)) => {
                if actual != expected {
                    // Render both batches as JSON for a readable failure message.
                    let mut actual_json: Vec<u8> = Vec::new();
                    let mut writer = ArrayWriter::new(&mut actual_json);
                    writer.write(&actual).unwrap();
                    writer.finish().unwrap();

                    let mut expected_json: Vec<u8> = Vec::new();
                    let mut writer = ArrayWriter::new(&mut expected_json);
                    writer.write(&expected).unwrap();
                    writer.finish().unwrap();

                    panic!(
                        "case_{} failed (limit {limit:?}), opt: {:?},\nreal_output: {}\nexpected: {}",
                        case_id,
                        opt,
                        String::from_utf8_lossy(&actual_json),
                        String::from_utf8_lossy(&expected_json),
                    );
                }
            }
            (None, Some(expected)) => panic!(
                "case_{} failed (limit {limit:?}), opt: {:?},\nreal output is empty, expected {} rows",
                case_id,
                opt,
                expected.num_rows()
            ),
            (Some(actual), None) => panic!(
                "case_{} failed (limit {limit:?}), opt: {:?},\nreal output has {} rows, expected empty",
                case_id,
                opt,
                actual.num_rows()
            ),
        }
    }
1626
1627    /// Test that verifies the limit is correctly applied per partition when
1628    /// multiple batches are received for the same partition.
1629    #[tokio::test]
1630    async fn test_limit_with_multiple_batches_per_partition() {
1631        let unit = TimeUnit::Millisecond;
1632        let schema = Arc::new(Schema::new(vec![Field::new(
1633            "ts",
1634            DataType::Timestamp(unit, None),
1635            false,
1636        )]));
1637
1638        // Test case: Multiple batches in a single partition with limit=3
1639        // Input: 3 batches with [1,2,3], [4,5,6], [7,8,9] all in partition (0,10)
1640        // Expected: Only top 3 values [9,8,7] for descending sort
1641        let input_ranged_data = vec![(
1642            PartitionRange {
1643                start: Timestamp::new(0, unit.into()),
1644                end: Timestamp::new(10, unit.into()),
1645                num_rows: 9,
1646                identifier: 0,
1647            },
1648            vec![
1649                DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![1, 2, 3])])
1650                    .unwrap(),
1651                DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![4, 5, 6])])
1652                    .unwrap(),
1653                DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![7, 8, 9])])
1654                    .unwrap(),
1655            ],
1656        )];
1657
1658        let expected_output = Some(
1659            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![9, 8, 7])])
1660                .unwrap(),
1661        );
1662
1663        run_test(
1664            1000,
1665            input_ranged_data,
1666            schema.clone(),
1667            SortOptions {
1668                descending: true,
1669                ..Default::default()
1670            },
1671            Some(3),
1672            expected_output,
1673            None,
1674        )
1675        .await;
1676
1677        // Test case: Multiple batches across multiple partitions with limit=2
1678        // Partition 0: batches [10,11,12], [13,14,15] -> top 2 descending = [15,14]
1679        // Partition 1: batches [1,2,3], [4,5] -> top 2 descending = [5,4]
1680        let input_ranged_data = vec![
1681            (
1682                PartitionRange {
1683                    start: Timestamp::new(10, unit.into()),
1684                    end: Timestamp::new(20, unit.into()),
1685                    num_rows: 6,
1686                    identifier: 0,
1687                },
1688                vec![
1689                    DfRecordBatch::try_new(
1690                        schema.clone(),
1691                        vec![new_ts_array(unit, vec![10, 11, 12])],
1692                    )
1693                    .unwrap(),
1694                    DfRecordBatch::try_new(
1695                        schema.clone(),
1696                        vec![new_ts_array(unit, vec![13, 14, 15])],
1697                    )
1698                    .unwrap(),
1699                ],
1700            ),
1701            (
1702                PartitionRange {
1703                    start: Timestamp::new(0, unit.into()),
1704                    end: Timestamp::new(10, unit.into()),
1705                    num_rows: 5,
1706                    identifier: 1,
1707                },
1708                vec![
1709                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![1, 2, 3])])
1710                        .unwrap(),
1711                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![4, 5])])
1712                        .unwrap(),
1713                ],
1714            ),
1715        ];
1716
1717        let expected_output = Some(
1718            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![15, 14])]).unwrap(),
1719        );
1720
1721        run_test(
1722            1001,
1723            input_ranged_data,
1724            schema.clone(),
1725            SortOptions {
1726                descending: true,
1727                ..Default::default()
1728            },
1729            Some(2),
1730            expected_output,
1731            None,
1732        )
1733        .await;
1734
1735        // Test case: Ascending sort with limit
1736        // Partition: batches [7,8,9], [4,5,6], [1,2,3] -> top 2 ascending = [1,2]
1737        let input_ranged_data = vec![(
1738            PartitionRange {
1739                start: Timestamp::new(0, unit.into()),
1740                end: Timestamp::new(10, unit.into()),
1741                num_rows: 9,
1742                identifier: 0,
1743            },
1744            vec![
1745                DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![7, 8, 9])])
1746                    .unwrap(),
1747                DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![4, 5, 6])])
1748                    .unwrap(),
1749                DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![1, 2, 3])])
1750                    .unwrap(),
1751            ],
1752        )];
1753
1754        let expected_output = Some(
1755            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![1, 2])]).unwrap(),
1756        );
1757
1758        run_test(
1759            1002,
1760            input_ranged_data,
1761            schema.clone(),
1762            SortOptions {
1763                descending: false,
1764                ..Default::default()
1765            },
1766            Some(2),
1767            expected_output,
1768            None,
1769        )
1770        .await;
1771    }
1772
1773    /// Test that verifies early termination behavior.
1774    /// Once we've produced limit * num_partitions rows, we should stop
1775    /// pulling from input stream.
1776    #[tokio::test]
1777    async fn test_early_termination() {
1778        let unit = TimeUnit::Millisecond;
1779        let schema = Arc::new(Schema::new(vec![Field::new(
1780            "ts",
1781            DataType::Timestamp(unit, None),
1782            false,
1783        )]));
1784
1785        // Create 3 partitions, each with more data than the limit
1786        // limit=2 per partition, so total expected output = 6 rows
1787        // After producing 6 rows, early termination should kick in
1788        // For descending sort, ranges must be ordered by (end DESC, start DESC)
1789        let input_ranged_data = vec![
1790            (
1791                PartitionRange {
1792                    start: Timestamp::new(20, unit.into()),
1793                    end: Timestamp::new(30, unit.into()),
1794                    num_rows: 10,
1795                    identifier: 2,
1796                },
1797                vec![
1798                    DfRecordBatch::try_new(
1799                        schema.clone(),
1800                        vec![new_ts_array(unit, vec![21, 22, 23, 24, 25])],
1801                    )
1802                    .unwrap(),
1803                    DfRecordBatch::try_new(
1804                        schema.clone(),
1805                        vec![new_ts_array(unit, vec![26, 27, 28, 29, 30])],
1806                    )
1807                    .unwrap(),
1808                ],
1809            ),
1810            (
1811                PartitionRange {
1812                    start: Timestamp::new(10, unit.into()),
1813                    end: Timestamp::new(20, unit.into()),
1814                    num_rows: 10,
1815                    identifier: 1,
1816                },
1817                vec![
1818                    DfRecordBatch::try_new(
1819                        schema.clone(),
1820                        vec![new_ts_array(unit, vec![11, 12, 13, 14, 15])],
1821                    )
1822                    .unwrap(),
1823                    DfRecordBatch::try_new(
1824                        schema.clone(),
1825                        vec![new_ts_array(unit, vec![16, 17, 18, 19, 20])],
1826                    )
1827                    .unwrap(),
1828                ],
1829            ),
1830            (
1831                PartitionRange {
1832                    start: Timestamp::new(0, unit.into()),
1833                    end: Timestamp::new(10, unit.into()),
1834                    num_rows: 10,
1835                    identifier: 0,
1836                },
1837                vec![
1838                    DfRecordBatch::try_new(
1839                        schema.clone(),
1840                        vec![new_ts_array(unit, vec![1, 2, 3, 4, 5])],
1841                    )
1842                    .unwrap(),
1843                    DfRecordBatch::try_new(
1844                        schema.clone(),
1845                        vec![new_ts_array(unit, vec![6, 7, 8, 9, 10])],
1846                    )
1847                    .unwrap(),
1848                ],
1849            ),
1850        ];
1851
1852        // PartSort won't reorder `PartitionRange` (it assumes it's already ordered), so it will not read other partitions.
1853        // This case is just to verify that early termination works as expected.
1854        // First partition [20, 30) produces top 2 values: 29, 28
1855        let expected_output = Some(
1856            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![29, 28])]).unwrap(),
1857        );
1858
1859        run_test(
1860            1003,
1861            input_ranged_data,
1862            schema.clone(),
1863            SortOptions {
1864                descending: true,
1865                ..Default::default()
1866            },
1867            Some(2),
1868            expected_output,
1869            Some(10),
1870        )
1871        .await;
1872    }
1873
1874    /// Example:
1875    /// - Range [70, 100) has data [80, 90, 95]
1876    /// - Range [50, 100) has data [55, 65, 75, 85, 95]
1877    #[tokio::test]
1878    async fn test_primary_end_grouping_with_limit() {
1879        let unit = TimeUnit::Millisecond;
1880        let schema = Arc::new(Schema::new(vec![Field::new(
1881            "ts",
1882            DataType::Timestamp(unit, None),
1883            false,
1884        )]));
1885
1886        // Two ranges with the same end (100) - they should be grouped together
1887        // For descending, ranges are ordered by (end DESC, start DESC)
1888        // So [70, 100) comes before [50, 100) (70 > 50)
1889        let input_ranged_data = vec![
1890            (
1891                PartitionRange {
1892                    start: Timestamp::new(70, unit.into()),
1893                    end: Timestamp::new(100, unit.into()),
1894                    num_rows: 3,
1895                    identifier: 0,
1896                },
1897                vec![
1898                    DfRecordBatch::try_new(
1899                        schema.clone(),
1900                        vec![new_ts_array(unit, vec![80, 90, 95])],
1901                    )
1902                    .unwrap(),
1903                ],
1904            ),
1905            (
1906                PartitionRange {
1907                    start: Timestamp::new(50, unit.into()),
1908                    end: Timestamp::new(100, unit.into()),
1909                    num_rows: 5,
1910                    identifier: 1,
1911                },
1912                vec![
1913                    DfRecordBatch::try_new(
1914                        schema.clone(),
1915                        vec![new_ts_array(unit, vec![55, 65, 75, 85, 95])],
1916                    )
1917                    .unwrap(),
1918                ],
1919            ),
1920        ];
1921
1922        // With limit=4, descending: top 4 values from combined data
1923        // Combined: [80, 90, 95, 55, 65, 75, 85, 95] -> sorted desc: [95, 95, 90, 85, 80, 75, 65, 55]
1924        // Top 4: [95, 95, 90, 85]
1925        let expected_output = Some(
1926            DfRecordBatch::try_new(
1927                schema.clone(),
1928                vec![new_ts_array(unit, vec![95, 95, 90, 85])],
1929            )
1930            .unwrap(),
1931        );
1932
1933        run_test(
1934            2000,
1935            input_ranged_data,
1936            schema.clone(),
1937            SortOptions {
1938                descending: true,
1939                ..Default::default()
1940            },
1941            Some(4),
1942            expected_output,
1943            None,
1944        )
1945        .await;
1946    }
1947
1948    /// Test case with three ranges demonstrating the "keep pulling" behavior.
1949    /// After processing ranges with end=100, the smallest value in top-k might still
1950    /// be reachable by the next group.
1951    ///
1952    /// Ranges: [70, 100), [50, 100), [40, 95)
1953    /// With descending sort and limit=4:
1954    /// - Group 1 (end=100): [70, 100) and [50, 100) merged
1955    /// - Group 2 (end=95): [40, 95)
1956    /// After group 1, smallest in top-4 is 85. Range [40, 95) could have values >= 85,
1957    /// so we continue to group 2.
1958    #[tokio::test]
1959    async fn test_three_ranges_keep_pulling() {
1960        let unit = TimeUnit::Millisecond;
1961        let schema = Arc::new(Schema::new(vec![Field::new(
1962            "ts",
1963            DataType::Timestamp(unit, None),
1964            false,
1965        )]));
1966
1967        // Three ranges, two with same end (100), one with different end (95)
1968        let input_ranged_data = vec![
1969            (
1970                PartitionRange {
1971                    start: Timestamp::new(70, unit.into()),
1972                    end: Timestamp::new(100, unit.into()),
1973                    num_rows: 3,
1974                    identifier: 0,
1975                },
1976                vec![
1977                    DfRecordBatch::try_new(
1978                        schema.clone(),
1979                        vec![new_ts_array(unit, vec![80, 90, 95])],
1980                    )
1981                    .unwrap(),
1982                ],
1983            ),
1984            (
1985                PartitionRange {
1986                    start: Timestamp::new(50, unit.into()),
1987                    end: Timestamp::new(100, unit.into()),
1988                    num_rows: 3,
1989                    identifier: 1,
1990                },
1991                vec![
1992                    DfRecordBatch::try_new(
1993                        schema.clone(),
1994                        vec![new_ts_array(unit, vec![55, 75, 85])],
1995                    )
1996                    .unwrap(),
1997                ],
1998            ),
1999            (
2000                PartitionRange {
2001                    start: Timestamp::new(40, unit.into()),
2002                    end: Timestamp::new(95, unit.into()),
2003                    num_rows: 3,
2004                    identifier: 2,
2005                },
2006                vec![
2007                    DfRecordBatch::try_new(
2008                        schema.clone(),
2009                        vec![new_ts_array(unit, vec![45, 65, 94])],
2010                    )
2011                    .unwrap(),
2012                ],
2013            ),
2014        ];
2015
2016        // All data: [80, 90, 95, 55, 75, 85, 45, 65, 94]
2017        // Sorted descending: [95, 94, 90, 85, 80, 75, 65, 55, 45]
2018        // With limit=4: should be top 4 largest values across all ranges: [95, 94, 90, 85]
2019        let expected_output = Some(
2020            DfRecordBatch::try_new(
2021                schema.clone(),
2022                vec![new_ts_array(unit, vec![95, 94, 90, 85])],
2023            )
2024            .unwrap(),
2025        );
2026
2027        run_test(
2028            2001,
2029            input_ranged_data,
2030            schema.clone(),
2031            SortOptions {
2032                descending: true,
2033                ..Default::default()
2034            },
2035            Some(4),
2036            expected_output,
2037            None,
2038        )
2039        .await;
2040    }
2041
2042    /// Test early termination based on threshold comparison with next group.
2043    /// When the threshold (smallest value for descending) is >= next group's primary end,
2044    /// we can stop early because the next group cannot have better values.
2045    #[tokio::test]
2046    async fn test_threshold_based_early_termination() {
2047        let unit = TimeUnit::Millisecond;
2048        let schema = Arc::new(Schema::new(vec![Field::new(
2049            "ts",
2050            DataType::Timestamp(unit, None),
2051            false,
2052        )]));
2053
2054        // Group 1 (end=100) has 6 rows, TopK will keep top 4
2055        // Group 2 (end=90) has 3 rows - should NOT be processed because
2056        // threshold (96) >= next_primary_end (90)
2057        let input_ranged_data = vec![
2058            (
2059                PartitionRange {
2060                    start: Timestamp::new(70, unit.into()),
2061                    end: Timestamp::new(100, unit.into()),
2062                    num_rows: 6,
2063                    identifier: 0,
2064                },
2065                vec![
2066                    DfRecordBatch::try_new(
2067                        schema.clone(),
2068                        vec![new_ts_array(unit, vec![94, 95, 96, 97, 98, 99])],
2069                    )
2070                    .unwrap(),
2071                ],
2072            ),
2073            (
2074                PartitionRange {
2075                    start: Timestamp::new(50, unit.into()),
2076                    end: Timestamp::new(90, unit.into()),
2077                    num_rows: 3,
2078                    identifier: 1,
2079                },
2080                vec![
2081                    DfRecordBatch::try_new(
2082                        schema.clone(),
2083                        vec![new_ts_array(unit, vec![85, 86, 87])],
2084                    )
2085                    .unwrap(),
2086                ],
2087            ),
2088        ];
2089
2090        // With limit=4, descending: top 4 from group 1 are [99, 98, 97, 96]
2091        // Threshold is 96, next group's primary_end is 90
2092        // Since 96 >= 90, we stop after group 1
2093        let expected_output = Some(
2094            DfRecordBatch::try_new(
2095                schema.clone(),
2096                vec![new_ts_array(unit, vec![99, 98, 97, 96])],
2097            )
2098            .unwrap(),
2099        );
2100
2101        run_test(
2102            2002,
2103            input_ranged_data,
2104            schema.clone(),
2105            SortOptions {
2106                descending: true,
2107                ..Default::default()
2108            },
2109            Some(4),
2110            expected_output,
2111            Some(9), // Pull both batches since all rows fall within the first range
2112        )
2113        .await;
2114    }
2115
2116    /// Test that we continue to next group when threshold is within next group's range.
2117    /// Even after fulfilling limit, if threshold < next_primary_end (descending),
2118    /// we would need to continue... but limit exhaustion stops us first.
2119    #[tokio::test]
2120    async fn test_continue_when_threshold_in_next_group_range() {
2121        let unit = TimeUnit::Millisecond;
2122        let schema = Arc::new(Schema::new(vec![Field::new(
2123            "ts",
2124            DataType::Timestamp(unit, None),
2125            false,
2126        )]));
2127
2128        // Group 1 (end=100) has 6 rows, TopK will keep top 4
2129        // Group 2 (end=98) has 3 rows - threshold (96) < 98, so next group
2130        // could theoretically have better values. Continue reading.
2131        let input_ranged_data = vec![
2132            (
2133                PartitionRange {
2134                    start: Timestamp::new(90, unit.into()),
2135                    end: Timestamp::new(100, unit.into()),
2136                    num_rows: 6,
2137                    identifier: 0,
2138                },
2139                vec![
2140                    DfRecordBatch::try_new(
2141                        schema.clone(),
2142                        vec![new_ts_array(unit, vec![94, 95, 96, 97, 98, 99])],
2143                    )
2144                    .unwrap(),
2145                ],
2146            ),
2147            (
2148                PartitionRange {
2149                    start: Timestamp::new(50, unit.into()),
2150                    end: Timestamp::new(98, unit.into()),
2151                    num_rows: 3,
2152                    identifier: 1,
2153                },
2154                vec![
2155                    // Values must be < 70 (outside group 1's range) to avoid ambiguity
2156                    DfRecordBatch::try_new(
2157                        schema.clone(),
2158                        vec![new_ts_array(unit, vec![55, 60, 65])],
2159                    )
2160                    .unwrap(),
2161                ],
2162            ),
2163        ];
2164
2165        // With limit=4, we get [99, 98, 97, 96] from group 1
2166        // Threshold is 96, next group's primary_end is 98
2167        // 96 < 98, so threshold check says "could continue"
2168        // But limit is exhausted (0), so we stop anyway
2169        let expected_output = Some(
2170            DfRecordBatch::try_new(
2171                schema.clone(),
2172                vec![new_ts_array(unit, vec![99, 98, 97, 96])],
2173            )
2174            .unwrap(),
2175        );
2176
2177        // Note: We pull 9 rows (both batches) because we need to read batch 2
2178        // to detect the group boundary, even though we stop after outputting group 1.
2179        run_test(
2180            2003,
2181            input_ranged_data,
2182            schema.clone(),
2183            SortOptions {
2184                descending: true,
2185                ..Default::default()
2186            },
2187            Some(4),
2188            expected_output,
2189            Some(9), // Pull both batches to detect boundary
2190        )
2191        .await;
2192    }
2193
2194    /// Test ascending sort with threshold-based early termination.
2195    #[tokio::test]
2196    async fn test_ascending_threshold_early_termination() {
2197        let unit = TimeUnit::Millisecond;
2198        let schema = Arc::new(Schema::new(vec![Field::new(
2199            "ts",
2200            DataType::Timestamp(unit, None),
2201            false,
2202        )]));
2203
2204        // For ascending: primary_end is start, ranges sorted by (start ASC, end ASC)
2205        // Group 1 (start=10) has 6 rows
2206        // Group 2 (start=20) has 3 rows - should NOT be processed because
2207        // threshold (13) < next_primary_end (20)
2208        let input_ranged_data = vec![
2209            (
2210                PartitionRange {
2211                    start: Timestamp::new(10, unit.into()),
2212                    end: Timestamp::new(50, unit.into()),
2213                    num_rows: 6,
2214                    identifier: 0,
2215                },
2216                vec![
2217                    DfRecordBatch::try_new(
2218                        schema.clone(),
2219                        vec![new_ts_array(unit, vec![10, 11, 12, 13, 14, 15])],
2220                    )
2221                    .unwrap(),
2222                ],
2223            ),
2224            (
2225                PartitionRange {
2226                    start: Timestamp::new(20, unit.into()),
2227                    end: Timestamp::new(60, unit.into()),
2228                    num_rows: 3,
2229                    identifier: 1,
2230                },
2231                vec![
2232                    DfRecordBatch::try_new(
2233                        schema.clone(),
2234                        vec![new_ts_array(unit, vec![25, 30, 35])],
2235                    )
2236                    .unwrap(),
2237                ],
2238            ),
2239            // still read this batch to detect group boundary(?)
2240            (
2241                PartitionRange {
2242                    start: Timestamp::new(60, unit.into()),
2243                    end: Timestamp::new(70, unit.into()),
2244                    num_rows: 2,
2245                    identifier: 1,
2246                },
2247                vec![
2248                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![60, 61])])
2249                        .unwrap(),
2250                ],
2251            ),
2252            // after boundary detected, this following one should not be read
2253            (
2254                PartitionRange {
2255                    start: Timestamp::new(61, unit.into()),
2256                    end: Timestamp::new(70, unit.into()),
2257                    num_rows: 2,
2258                    identifier: 1,
2259                },
2260                vec![
2261                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![71, 72])])
2262                        .unwrap(),
2263                ],
2264            ),
2265        ];
2266
2267        // With limit=4, ascending: top 4 (smallest) from group 1 are [10, 11, 12, 13]
2268        // Threshold is 13 (largest in top-k), next group's primary_end is 20
2269        // Since 13 < 20, we stop after group 1 (no value in group 2 can be < 13)
2270        let expected_output = Some(
2271            DfRecordBatch::try_new(
2272                schema.clone(),
2273                vec![new_ts_array(unit, vec![10, 11, 12, 13])],
2274            )
2275            .unwrap(),
2276        );
2277
2278        run_test(
2279            2004,
2280            input_ranged_data,
2281            schema.clone(),
2282            SortOptions {
2283                descending: false,
2284                ..Default::default()
2285            },
2286            Some(4),
2287            expected_output,
2288            Some(11), // Pull first two batches to detect boundary
2289        )
2290        .await;
2291    }
2292
2293    #[tokio::test]
2294    async fn test_ascending_threshold_early_termination_case_two() {
2295        let unit = TimeUnit::Millisecond;
2296        let schema = Arc::new(Schema::new(vec![Field::new(
2297            "ts",
2298            DataType::Timestamp(unit, None),
2299            false,
2300        )]));
2301
2302        // For ascending: primary_end is start, ranges sorted by (start ASC, end ASC)
2303        // Group 1 (start=0) has 4 rows, Group 2 (start=4) has 1 row, Group 3 (start=5) has 4 rows
2304        // After reading all data: [9,10,11,12, 21, 5,6,7,8]
2305        // Sorted ascending: [5,6,7,8, 9,10,11,12, 21]
2306        // With limit=4, output should be smallest 4: [5,6,7,8]
2307        // Algorithm continues reading until start=42 > threshold=8, confirming no smaller values exist
2308        let input_ranged_data = vec![
2309            (
2310                PartitionRange {
2311                    start: Timestamp::new(0, unit.into()),
2312                    end: Timestamp::new(20, unit.into()),
2313                    num_rows: 4,
2314                    identifier: 0,
2315                },
2316                vec![
2317                    DfRecordBatch::try_new(
2318                        schema.clone(),
2319                        vec![new_ts_array(unit, vec![9, 10, 11, 12])],
2320                    )
2321                    .unwrap(),
2322                ],
2323            ),
2324            (
2325                PartitionRange {
2326                    start: Timestamp::new(4, unit.into()),
2327                    end: Timestamp::new(25, unit.into()),
2328                    num_rows: 1,
2329                    identifier: 1,
2330                },
2331                vec![
2332                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![21])])
2333                        .unwrap(),
2334                ],
2335            ),
2336            (
2337                PartitionRange {
2338                    start: Timestamp::new(5, unit.into()),
2339                    end: Timestamp::new(25, unit.into()),
2340                    num_rows: 4,
2341                    identifier: 1,
2342                },
2343                vec![
2344                    DfRecordBatch::try_new(
2345                        schema.clone(),
2346                        vec![new_ts_array(unit, vec![5, 6, 7, 8])],
2347                    )
2348                    .unwrap(),
2349                ],
2350            ),
2351            // This still will be read to detect boundary, but should not contribute to output
2352            (
2353                PartitionRange {
2354                    start: Timestamp::new(42, unit.into()),
2355                    end: Timestamp::new(52, unit.into()),
2356                    num_rows: 2,
2357                    identifier: 1,
2358                },
2359                vec![
2360                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![42, 51])])
2361                        .unwrap(),
2362                ],
2363            ),
2364            // This following one should not be read after boundary detected
2365            (
2366                PartitionRange {
2367                    start: Timestamp::new(48, unit.into()),
2368                    end: Timestamp::new(53, unit.into()),
2369                    num_rows: 2,
2370                    identifier: 1,
2371                },
2372                vec![
2373                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![48, 51])])
2374                        .unwrap(),
2375                ],
2376            ),
2377        ];
2378
2379        // With limit=4, ascending: after processing all ranges, smallest 4 are [5, 6, 7, 8]
2380        // Threshold is 8 (4th smallest value), algorithm reads until start=42 > threshold=8
2381        let expected_output = Some(
2382            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![5, 6, 7, 8])])
2383                .unwrap(),
2384        );
2385
2386        run_test(
2387            2005,
2388            input_ranged_data,
2389            schema.clone(),
2390            SortOptions {
2391                descending: false,
2392                ..Default::default()
2393            },
2394            Some(4),
2395            expected_output,
2396            Some(11), // Read first 4 ranges to confirm threshold boundary
2397        )
2398        .await;
2399    }
2400
    /// Test early stop behavior with null values in sort column.
    /// Verifies that nulls are handled correctly based on the `nulls_first` option.
    #[tokio::test]
    async fn test_early_stop_with_nulls() {
        let unit = TimeUnit::Millisecond;
        let schema = Arc::new(Schema::new(vec![Field::new(
            "ts",
            DataType::Timestamp(unit, None),
            true, // nullable
        )]));

        // Helper function to create a nullable timestamp array of the given time unit.
        let new_nullable_ts_array = |unit: TimeUnit, arr: Vec<Option<i64>>| -> ArrayRef {
            match unit {
                TimeUnit::Second => Arc::new(TimestampSecondArray::from(arr)) as ArrayRef,
                TimeUnit::Millisecond => Arc::new(TimestampMillisecondArray::from(arr)) as ArrayRef,
                TimeUnit::Microsecond => Arc::new(TimestampMicrosecondArray::from(arr)) as ArrayRef,
                TimeUnit::Nanosecond => Arc::new(TimestampNanosecondArray::from(arr)) as ArrayRef,
            }
        };

        // Test case 1: nulls_first=true, null values should appear first
        // Group 1 (end=100): [null, null, 99, 98, 97] -> with limit=3, top 3 are [null, null, 99]
        // Threshold is 99, next group end=90, since 99 >= 90, we should stop early
        let input_ranged_data = vec![
            (
                PartitionRange {
                    start: Timestamp::new(70, unit.into()),
                    end: Timestamp::new(100, unit.into()),
                    num_rows: 5,
                    identifier: 0,
                },
                vec![
                    DfRecordBatch::try_new(
                        schema.clone(),
                        vec![new_nullable_ts_array(
                            unit,
                            vec![Some(99), Some(98), None, Some(97), None],
                        )],
                    )
                    .unwrap(),
                ],
            ),
            (
                PartitionRange {
                    start: Timestamp::new(50, unit.into()),
                    end: Timestamp::new(90, unit.into()),
                    num_rows: 3,
                    identifier: 1,
                },
                vec![
                    DfRecordBatch::try_new(
                        schema.clone(),
                        vec![new_nullable_ts_array(
                            unit,
                            vec![Some(89), Some(88), Some(87)],
                        )],
                    )
                    .unwrap(),
                ],
            ),
        ];

        // With nulls_first=true, nulls sort before all values
        // For descending, order is: null, null, 99, 98, 97
        // With limit=3, we get: null, null, 99
        let expected_output = Some(
            DfRecordBatch::try_new(
                schema.clone(),
                vec![new_nullable_ts_array(unit, vec![None, None, Some(99)])],
            )
            .unwrap(),
        );

        run_test(
            3000,
            input_ranged_data,
            schema.clone(),
            SortOptions {
                descending: true,
                nulls_first: true,
            },
            Some(3),
            expected_output,
            Some(8), // Must read both batches to detect group boundary
        )
        .await;

        // Test case 2: nulls_first=false (i.e. nulls sort last), null values should appear last
        // Group 1 (end=100): [99, 98, 97, null, null] -> with limit=3, top 3 are [99, 98, 97]
        // Threshold is 97, next group end=90, since 97 >= 90, we should stop early
        let input_ranged_data = vec![
            (
                PartitionRange {
                    start: Timestamp::new(70, unit.into()),
                    end: Timestamp::new(100, unit.into()),
                    num_rows: 5,
                    identifier: 0,
                },
                vec![
                    DfRecordBatch::try_new(
                        schema.clone(),
                        vec![new_nullable_ts_array(
                            unit,
                            vec![Some(99), Some(98), Some(97), None, None],
                        )],
                    )
                    .unwrap(),
                ],
            ),
            (
                PartitionRange {
                    start: Timestamp::new(50, unit.into()),
                    end: Timestamp::new(90, unit.into()),
                    num_rows: 3,
                    identifier: 1,
                },
                vec![
                    DfRecordBatch::try_new(
                        schema.clone(),
                        vec![new_nullable_ts_array(
                            unit,
                            vec![Some(89), Some(88), Some(87)],
                        )],
                    )
                    .unwrap(),
                ],
            ),
        ];

        // With nulls_first=false, values sort before nulls
        // For descending, order is: 99, 98, 97, null, null
        // With limit=3, we get: 99, 98, 97
        let expected_output = Some(
            DfRecordBatch::try_new(
                schema.clone(),
                vec![new_nullable_ts_array(
                    unit,
                    vec![Some(99), Some(98), Some(97)],
                )],
            )
            .unwrap(),
        );

        run_test(
            3001,
            input_ranged_data,
            schema.clone(),
            SortOptions {
                descending: true,
                nulls_first: false,
            },
            Some(3),
            expected_output,
            Some(8), // Must read both batches to detect group boundary
        )
        .await;
    }
2559
2560    /// Test early stop behavior when there's only one group (no next group).
2561    /// In this case, can_stop_early should return false and we should process all data.
2562    #[tokio::test]
2563    async fn test_early_stop_single_group() {
2564        let unit = TimeUnit::Millisecond;
2565        let schema = Arc::new(Schema::new(vec![Field::new(
2566            "ts",
2567            DataType::Timestamp(unit, None),
2568            false,
2569        )]));
2570
2571        // Only one group (all ranges have the same end), no next group to compare against
2572        let input_ranged_data = vec![
2573            (
2574                PartitionRange {
2575                    start: Timestamp::new(70, unit.into()),
2576                    end: Timestamp::new(100, unit.into()),
2577                    num_rows: 6,
2578                    identifier: 0,
2579                },
2580                vec![
2581                    DfRecordBatch::try_new(
2582                        schema.clone(),
2583                        vec![new_ts_array(unit, vec![94, 95, 96, 97, 98, 99])],
2584                    )
2585                    .unwrap(),
2586                ],
2587            ),
2588            (
2589                PartitionRange {
2590                    start: Timestamp::new(50, unit.into()),
2591                    end: Timestamp::new(100, unit.into()),
2592                    num_rows: 3,
2593                    identifier: 1,
2594                },
2595                vec![
2596                    DfRecordBatch::try_new(
2597                        schema.clone(),
2598                        vec![new_ts_array(unit, vec![85, 86, 87])],
2599                    )
2600                    .unwrap(),
2601                ],
2602            ),
2603        ];
2604
2605        // Even though we have enough data in first range, we must process all
2606        // because there's no next group to compare threshold against
2607        let expected_output = Some(
2608            DfRecordBatch::try_new(
2609                schema.clone(),
2610                vec![new_ts_array(unit, vec![99, 98, 97, 96])],
2611            )
2612            .unwrap(),
2613        );
2614
2615        run_test(
2616            3002,
2617            input_ranged_data,
2618            schema.clone(),
2619            SortOptions {
2620                descending: true,
2621                ..Default::default()
2622            },
2623            Some(4),
2624            expected_output,
2625            Some(9), // Must read all batches since no early stop is possible
2626        )
2627        .await;
2628    }
2629
2630    /// Test early stop behavior when threshold exactly equals next group's boundary.
2631    #[tokio::test]
2632    async fn test_early_stop_exact_boundary_equality() {
2633        let unit = TimeUnit::Millisecond;
2634        let schema = Arc::new(Schema::new(vec![Field::new(
2635            "ts",
2636            DataType::Timestamp(unit, None),
2637            false,
2638        )]));
2639
2640        // Test case 1: Descending sort, threshold == next_group_end
2641        // Group 1 (end=100): data up to 90, threshold = 90, next_group_end = 90
2642        // Since 90 >= 90, we should stop early
2643        let input_ranged_data = vec![
2644            (
2645                PartitionRange {
2646                    start: Timestamp::new(70, unit.into()),
2647                    end: Timestamp::new(100, unit.into()),
2648                    num_rows: 4,
2649                    identifier: 0,
2650                },
2651                vec![
2652                    DfRecordBatch::try_new(
2653                        schema.clone(),
2654                        vec![new_ts_array(unit, vec![92, 91, 90, 89])],
2655                    )
2656                    .unwrap(),
2657                ],
2658            ),
2659            (
2660                PartitionRange {
2661                    start: Timestamp::new(50, unit.into()),
2662                    end: Timestamp::new(90, unit.into()),
2663                    num_rows: 3,
2664                    identifier: 1,
2665                },
2666                vec![
2667                    DfRecordBatch::try_new(
2668                        schema.clone(),
2669                        vec![new_ts_array(unit, vec![88, 87, 86])],
2670                    )
2671                    .unwrap(),
2672                ],
2673            ),
2674        ];
2675
2676        let expected_output = Some(
2677            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![92, 91, 90])])
2678                .unwrap(),
2679        );
2680
2681        run_test(
2682            3003,
2683            input_ranged_data,
2684            schema.clone(),
2685            SortOptions {
2686                descending: true,
2687                ..Default::default()
2688            },
2689            Some(3),
2690            expected_output,
2691            Some(7), // Must read both batches to detect boundary
2692        )
2693        .await;
2694
2695        // Test case 2: Ascending sort, threshold == next_group_start
2696        // Group 1 (start=10): data from 10, threshold = 20, next_group_start = 20
2697        // Since 20 < 20 is false, we should continue
2698        let input_ranged_data = vec![
2699            (
2700                PartitionRange {
2701                    start: Timestamp::new(10, unit.into()),
2702                    end: Timestamp::new(50, unit.into()),
2703                    num_rows: 4,
2704                    identifier: 0,
2705                },
2706                vec![
2707                    DfRecordBatch::try_new(
2708                        schema.clone(),
2709                        vec![new_ts_array(unit, vec![10, 15, 20, 25])],
2710                    )
2711                    .unwrap(),
2712                ],
2713            ),
2714            (
2715                PartitionRange {
2716                    start: Timestamp::new(20, unit.into()),
2717                    end: Timestamp::new(60, unit.into()),
2718                    num_rows: 3,
2719                    identifier: 1,
2720                },
2721                vec![
2722                    DfRecordBatch::try_new(
2723                        schema.clone(),
2724                        vec![new_ts_array(unit, vec![21, 22, 23])],
2725                    )
2726                    .unwrap(),
2727                ],
2728            ),
2729        ];
2730
2731        let expected_output = Some(
2732            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![10, 15, 20])])
2733                .unwrap(),
2734        );
2735
2736        run_test(
2737            3004,
2738            input_ranged_data,
2739            schema.clone(),
2740            SortOptions {
2741                descending: false,
2742                ..Default::default()
2743            },
2744            Some(3),
2745            expected_output,
2746            Some(7), // Must read both batches since 20 is not < 20
2747        )
2748        .await;
2749    }
2750
2751    /// Test early stop behavior with empty partition groups.
2752    #[tokio::test]
2753    async fn test_early_stop_with_empty_partitions() {
2754        let unit = TimeUnit::Millisecond;
2755        let schema = Arc::new(Schema::new(vec![Field::new(
2756            "ts",
2757            DataType::Timestamp(unit, None),
2758            false,
2759        )]));
2760
2761        // Test case 1: First group is empty, second group has data
2762        let input_ranged_data = vec![
2763            (
2764                PartitionRange {
2765                    start: Timestamp::new(70, unit.into()),
2766                    end: Timestamp::new(100, unit.into()),
2767                    num_rows: 0,
2768                    identifier: 0,
2769                },
2770                vec![
2771                    // Empty batch for first range
2772                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![])])
2773                        .unwrap(),
2774                ],
2775            ),
2776            (
2777                PartitionRange {
2778                    start: Timestamp::new(50, unit.into()),
2779                    end: Timestamp::new(100, unit.into()),
2780                    num_rows: 0,
2781                    identifier: 1,
2782                },
2783                vec![
2784                    // Empty batch for second range
2785                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![])])
2786                        .unwrap(),
2787                ],
2788            ),
2789            (
2790                PartitionRange {
2791                    start: Timestamp::new(30, unit.into()),
2792                    end: Timestamp::new(80, unit.into()),
2793                    num_rows: 4,
2794                    identifier: 2,
2795                },
2796                vec![
2797                    DfRecordBatch::try_new(
2798                        schema.clone(),
2799                        vec![new_ts_array(unit, vec![74, 75, 76, 77])],
2800                    )
2801                    .unwrap(),
2802                ],
2803            ),
2804            (
2805                PartitionRange {
2806                    start: Timestamp::new(10, unit.into()),
2807                    end: Timestamp::new(60, unit.into()),
2808                    num_rows: 3,
2809                    identifier: 3,
2810                },
2811                vec![
2812                    DfRecordBatch::try_new(
2813                        schema.clone(),
2814                        vec![new_ts_array(unit, vec![58, 59, 60])],
2815                    )
2816                    .unwrap(),
2817                ],
2818            ),
2819        ];
2820
2821        // Group 1 (end=100) is empty, Group 2 (end=80) has data
2822        // Should continue to Group 2 since Group 1 has no data
2823        let expected_output = Some(
2824            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![77, 76])]).unwrap(),
2825        );
2826
2827        run_test(
2828            3005,
2829            input_ranged_data,
2830            schema.clone(),
2831            SortOptions {
2832                descending: true,
2833                ..Default::default()
2834            },
2835            Some(2),
2836            expected_output,
2837            Some(7), // Must read until finding actual data
2838        )
2839        .await;
2840
2841        // Test case 2: Empty partitions between data groups
2842        let input_ranged_data = vec![
2843            (
2844                PartitionRange {
2845                    start: Timestamp::new(70, unit.into()),
2846                    end: Timestamp::new(100, unit.into()),
2847                    num_rows: 4,
2848                    identifier: 0,
2849                },
2850                vec![
2851                    DfRecordBatch::try_new(
2852                        schema.clone(),
2853                        vec![new_ts_array(unit, vec![96, 97, 98, 99])],
2854                    )
2855                    .unwrap(),
2856                ],
2857            ),
2858            (
2859                PartitionRange {
2860                    start: Timestamp::new(50, unit.into()),
2861                    end: Timestamp::new(90, unit.into()),
2862                    num_rows: 0,
2863                    identifier: 1,
2864                },
2865                vec![
2866                    // Empty range - should be skipped
2867                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![])])
2868                        .unwrap(),
2869                ],
2870            ),
2871            (
2872                PartitionRange {
2873                    start: Timestamp::new(30, unit.into()),
2874                    end: Timestamp::new(70, unit.into()),
2875                    num_rows: 0,
2876                    identifier: 2,
2877                },
2878                vec![
2879                    // Another empty range
2880                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![])])
2881                        .unwrap(),
2882                ],
2883            ),
2884            (
2885                PartitionRange {
2886                    start: Timestamp::new(10, unit.into()),
2887                    end: Timestamp::new(50, unit.into()),
2888                    num_rows: 3,
2889                    identifier: 3,
2890                },
2891                vec![
2892                    DfRecordBatch::try_new(
2893                        schema.clone(),
2894                        vec![new_ts_array(unit, vec![48, 49, 50])],
2895                    )
2896                    .unwrap(),
2897                ],
2898            ),
2899        ];
2900
2901        // With limit=2 from group 1: [99, 98], threshold=98, next group end=50
2902        // Since 98 >= 50, we should stop early
2903        let expected_output = Some(
2904            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![99, 98])]).unwrap(),
2905        );
2906
2907        run_test(
2908            3006,
2909            input_ranged_data,
2910            schema.clone(),
2911            SortOptions {
2912                descending: true,
2913                ..Default::default()
2914            },
2915            Some(2),
2916            expected_output,
2917            Some(7), // Must read to detect early stop condition
2918        )
2919        .await;
2920    }
2921
    /// First group: [0,20), data: [0, 5, 15]
    /// Second group: [10, 30), data: [21, 25, 29]
    /// After sorting the first group, early stop is exercised manually and the
    /// dynamic filter's snapshot generation is checked at each step.
    #[tokio::test]
    async fn test_early_stop_check_update_dyn_filter() {
        let unit = TimeUnit::Millisecond;
        let schema = Arc::new(Schema::new(vec![Field::new(
            "ts",
            DataType::Timestamp(unit, None),
            false,
        )]));

        // The mock input yields no batches; batches are pushed into the stream
        // manually below so each filter update can be observed step by step.
        let mock_input = Arc::new(MockInputExec::new(vec![vec![]], schema.clone()));
        let exec = PartSortExec::try_new(
            PhysicalSortExpr {
                expr: Arc::new(Column::new("ts", 0)),
                options: SortOptions {
                    descending: false,
                    ..Default::default()
                },
            },
            Some(3), // limit of 3 rows, matching the limit given to the stream below
            vec![vec![
                // Two overlapping ranges; per the test doc they form two groups.
                PartitionRange {
                    start: Timestamp::new(0, unit.into()),
                    end: Timestamp::new(20, unit.into()),
                    num_rows: 3,
                    identifier: 1,
                },
                PartitionRange {
                    start: Timestamp::new(10, unit.into()),
                    end: Timestamp::new(30, unit.into()),
                    num_rows: 3,
                    identifier: 1,
                },
            ]],
            mock_input.clone(),
        )
        .unwrap();

        let filter = exec.filter.clone().unwrap();
        let input_stream = mock_input
            .execute(0, Arc::new(TaskContext::default()))
            .unwrap();
        // Build the stream directly (instead of via `execute`) so internal
        // methods like `push_buffer` and `sort_top_buffer` can be driven by hand.
        let mut stream = PartSortStream::new(
            Arc::new(TaskContext::default()),
            &exec,
            Some(3),
            input_stream,
            vec![],
            0,
            Some(filter.clone()),
        )
        .unwrap();

        // initially, snapshot_generation is 1
        assert_eq!(filter.read().expr().snapshot_generation(), 1);
        let batch =
            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![0, 5, 15])])
                .unwrap();
        stream.push_buffer(batch).unwrap();

        // after pushing first batch, snapshot_generation is updated to 2
        assert_eq!(filter.read().expr().snapshot_generation(), 2);
        assert!(!stream.can_stop_early(&schema).unwrap());
        // still two as not updated
        assert_eq!(filter.read().expr().snapshot_generation(), 2);

        let _ = stream.sort_top_buffer().unwrap();

        // Push the second group's data; the filter snapshot must stay at
        // generation 2 throughout the remaining steps.
        let batch =
            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![21, 25, 29])])
                .unwrap();
        stream.push_buffer(batch).unwrap();
        // still two as not updated
        assert_eq!(filter.read().expr().snapshot_generation(), 2);
        let new = stream.sort_top_buffer().unwrap();
        // still two as not updated
        assert_eq!(filter.read().expr().snapshot_generation(), 2);

        // dyn filter kick in, and filter out all rows >= 15(the filter is rows<15)
        assert_eq!(new.num_rows(), 0)
    }
3005}