1use std::any::Any;
22use std::pin::Pin;
23use std::sync::Arc;
24use std::task::{Context, Poll};
25
26use arrow::array::{
27 ArrayRef, AsArray, TimestampMicrosecondArray, TimestampMillisecondArray,
28 TimestampNanosecondArray, TimestampSecondArray,
29};
30use arrow::compute::{concat, concat_batches, take_record_batch};
31use arrow_schema::{Schema, SchemaRef};
32use common_recordbatch::{DfRecordBatch, DfSendableRecordBatchStream};
33use common_telemetry::warn;
34use common_time::Timestamp;
35use common_time::timestamp::TimeUnit;
36use datafusion::common::arrow::compute::sort_to_indices;
37use datafusion::execution::memory_pool::{MemoryConsumer, MemoryReservation};
38use datafusion::execution::{RecordBatchStream, TaskContext};
39use datafusion::physical_plan::execution_plan::CardinalityEffect;
40use datafusion::physical_plan::filter_pushdown::{
41 ChildFilterDescription, FilterDescription, FilterPushdownPhase,
42};
43use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
44use datafusion::physical_plan::{
45 DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, TopK,
46 TopKDynamicFilters,
47};
48use datafusion_common::tree_node::{Transformed, TreeNode};
49use datafusion_common::{DataFusionError, internal_err};
50use datafusion_physical_expr::expressions::{Column, DynamicFilterPhysicalExpr, lit};
51use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
52use futures::{Stream, StreamExt};
53use itertools::Itertools;
54use parking_lot::RwLock;
55use snafu::location;
56use store_api::region_engine::PartitionRange;
57
58use crate::error::Result;
59use crate::window_sort::check_partition_range_monotonicity;
60use crate::{array_iter_helper, downcast_ts_array};
61
62fn get_primary_end(range: &PartitionRange, descending: bool) -> Timestamp {
67 if descending { range.end } else { range.start }
68}
69
70fn group_ranges_by_primary_end(
76 ranges: &[PartitionRange],
77 descending: bool,
78) -> Vec<(Timestamp, usize, usize)> {
79 if ranges.is_empty() {
80 return vec![];
81 }
82
83 let mut groups = Vec::new();
84 let mut group_start = 0;
85 let mut current_primary_end = get_primary_end(&ranges[0], descending);
86
87 for (idx, range) in ranges.iter().enumerate().skip(1) {
88 let primary_end = get_primary_end(range, descending);
89 if primary_end != current_primary_end {
90 groups.push((current_primary_end, group_start, idx));
92 group_start = idx;
94 current_primary_end = primary_end;
95 }
96 }
97 groups.push((current_primary_end, group_start, ranges.len()));
99
100 groups
101}
102
#[derive(Debug, Clone)]
pub struct PartSortExec {
    /// Single sort expression applied within each partition range.
    expression: PhysicalSortExpr,
    /// Optional row limit; when set, sorting uses a TopK buffer instead of a
    /// full sort.
    limit: Option<usize>,
    /// Child plan producing the input batches.
    input: Arc<dyn ExecutionPlan>,
    /// Shared metrics set for all streams spawned from this plan.
    metrics: ExecutionPlanMetricsSet,
    /// Per-output-partition list of time ranges the input is split into.
    partition_ranges: Vec<Vec<PartitionRange>>,
    properties: Arc<PlanProperties>,
    /// Dynamic TopK filter, present only when `limit` is set (see `try_new`).
    filter: Option<Arc<RwLock<TopKDynamicFilters>>>,
}
123
impl PartSortExec {
    /// Creates a new `PartSortExec`.
    ///
    /// Fails if `partition_ranges` is not monotonic with respect to the sort
    /// direction (`expression.options.descending`).
    pub fn try_new(
        expression: PhysicalSortExpr,
        limit: Option<usize>,
        partition_ranges: Vec<Vec<PartitionRange>>,
        input: Arc<dyn ExecutionPlan>,
    ) -> Result<Self> {
        check_partition_range_monotonicity(&partition_ranges, expression.options.descending)?;

        let metrics = ExecutionPlanMetricsSet::new();
        // Inherit equivalence/partitioning/emission/boundedness from the input.
        let properties = input.properties();
        let properties = Arc::new(PlanProperties::new(
            input.equivalence_properties().clone(),
            input.output_partitioning().clone(),
            properties.emission_type,
            properties.boundedness,
        ));

        // A dynamic TopK filter only makes sense when a limit is present.
        let filter = limit
            .is_some()
            .then(|| Self::create_filter(expression.expr.clone()));

        Ok(Self {
            expression,
            limit,
            input,
            metrics,
            partition_ranges,
            properties,
            filter,
        })
    }

    /// Builds an initially-permissive (`lit(true)`) dynamic TopK filter over
    /// the sort expression.
    fn create_filter(expr: Arc<dyn PhysicalExpr>) -> Arc<RwLock<TopKDynamicFilters>> {
        Arc::new(RwLock::new(TopKDynamicFilters::new(Arc::new(
            DynamicFilterPhysicalExpr::new(vec![expr], lit(true)),
        ))))
    }

    /// Executes the child for `partition` and wraps it in a `PartSortStream`.
    ///
    /// Errors if `partition` has no corresponding entry in
    /// `self.partition_ranges`.
    pub fn to_stream(
        &self,
        context: Arc<TaskContext>,
        partition: usize,
    ) -> datafusion_common::Result<DfSendableRecordBatchStream> {
        let input_stream: DfSendableRecordBatchStream =
            self.input.execute(partition, context.clone())?;

        if partition >= self.partition_ranges.len() {
            internal_err!(
                "Partition index out of range: {} >= {} at {}",
                partition,
                self.partition_ranges.len(),
                snafu::location!()
            )?;
        }

        let df_stream = Box::pin(PartSortStream::new(
            context,
            self,
            self.limit,
            input_stream,
            self.partition_ranges[partition].clone(),
            partition,
            self.filter.clone(),
        )?) as _;

        Ok(df_stream)
    }
}
194
195impl DisplayAs for PartSortExec {
196 fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
197 write!(
198 f,
199 "PartSortExec: expr={} num_ranges={}",
200 self.expression,
201 self.partition_ranges.len(),
202 )?;
203 if let Some(limit) = self.limit {
204 write!(f, " limit={}", limit)?;
205 }
206 Ok(())
207 }
208}
209
impl ExecutionPlan for PartSortExec {
    fn name(&self) -> &str {
        "PartSortExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        // Sorting does not change the schema; reuse the input's.
        self.input.schema()
    }

    fn properties(&self) -> &Arc<PlanProperties> {
        &self.properties
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
        let new_input = if let Some(first) = children.first() {
            first
        } else {
            internal_err!("No children found")?
        };
        // Rebuild via `try_new` so partition-range monotonicity is
        // re-validated and a fresh filter/metrics set is created.
        let new = Self::try_new(
            self.expression.clone(),
            self.limit,
            self.partition_ranges.clone(),
            new_input.clone(),
        )?;
        Ok(Arc::new(new))
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> datafusion_common::Result<DfSendableRecordBatchStream> {
        self.to_stream(context, partition)
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        // Per-range sorting relies on the existing range layout; extra input
        // partitioning does not help.
        vec![false]
    }

    fn cardinality_effect(&self) -> CardinalityEffect {
        // A plain sort preserves row count; a limit can only reduce it.
        if self.limit.is_none() {
            CardinalityEffect::Equal
        } else {
            CardinalityEffect::LowerEqual
        }
    }

    fn gather_filters_for_pushdown(
        &self,
        phase: FilterPushdownPhase,
        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
        _config: &datafusion::config::ConfigOptions,
    ) -> datafusion_common::Result<FilterDescription> {
        // Only contribute our own dynamic TopK filter in the Post phase.
        if !matches!(phase, FilterPushdownPhase::Post) {
            return FilterDescription::from_children(parent_filters, &self.children());
        }

        let mut child = ChildFilterDescription::from_child(&parent_filters, &self.input)?;

        if let Some(filter) = &self.filter {
            child = child.with_self_filter(filter.read().expr());
        }

        Ok(FilterDescription::new().with_child(child))
    }

    fn reset_state(self: Arc<Self>) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
        // Re-create the dynamic filter so a re-execution starts from a
        // permissive filter again; other fields are shared/cloned as-is.
        let new_filter = self
            .limit
            .is_some()
            .then(|| Self::create_filter(self.expression.expr.clone()));

        Ok(Arc::new(Self {
            expression: self.expression.clone(),
            limit: self.limit,
            input: self.input.clone(),
            metrics: self.metrics.clone(),
            partition_ranges: self.partition_ranges.clone(),
            properties: self.properties.clone(),
            filter: new_filter,
        }))
    }
}
316
/// Buffer of input rows awaiting a sort.
enum PartSortBuffer {
    /// Full sort: keep every batch and sort them all at emission time.
    All(Vec<DfRecordBatch>),
    /// TopK sort (limit set): a `TopK` accumulator plus the number of rows
    /// inserted so far (see `push_buffer`).
    Top(TopK, usize),
}
325
326impl PartSortBuffer {
327 pub fn is_empty(&self) -> bool {
328 match self {
329 PartSortBuffer::All(v) => v.is_empty(),
330 PartSortBuffer::Top(_, cnt) => *cnt == 0,
331 }
332 }
333}
334
/// Stream that sorts input batches per partition-range group.
struct PartSortStream {
    /// Memory accounting for the temporary concat/take buffers in
    /// `sort_all_buffer`.
    reservation: MemoryReservation,
    buffer: PartSortBuffer,
    expression: PhysicalSortExpr,
    limit: Option<usize>,
    input: DfSendableRecordBatchStream,
    /// Set once the input stream is exhausted or early stop triggered.
    input_complete: bool,
    schema: SchemaRef,
    partition_ranges: Vec<PartitionRange>,
    #[allow(dead_code)]
    partition: usize,
    /// Index into `partition_ranges` for the range currently being filled.
    cur_part_idx: usize,
    /// Leftover slice of a batch that still crosses range boundaries and must
    /// be re-split on the next poll.
    evaluating_batch: Option<DfRecordBatch>,
    metrics: BaselineMetrics,
    context: Arc<TaskContext>,
    /// Plan-level metrics set, needed to build fresh `TopK` instances.
    root_metrics: ExecutionPlanMetricsSet,
    /// `(primary_end, start_idx, end_idx)` groups of `partition_ranges`
    /// sharing the same primary end (see `group_ranges_by_primary_end`).
    range_groups: Vec<(Timestamp, usize, usize)>,
    /// Index into `range_groups` for the group currently being processed.
    cur_group_idx: usize,
    /// Dynamic TopK filter; must be `Some` whenever `limit` is set.
    filter: Option<Arc<RwLock<TopKDynamicFilters>>>,
}
361
impl PartSortStream {
    /// Builds the stream, choosing a `TopK` buffer when `limit` is set
    /// (requires `filter` to be `Some`) and a plain batch buffer otherwise.
    fn new(
        context: Arc<TaskContext>,
        sort: &PartSortExec,
        limit: Option<usize>,
        input: DfSendableRecordBatchStream,
        partition_ranges: Vec<PartitionRange>,
        partition: usize,
        filter: Option<Arc<RwLock<TopKDynamicFilters>>>,
    ) -> datafusion_common::Result<Self> {
        let buffer = if let Some(limit) = limit {
            let Some(filter) = filter.clone() else {
                return internal_err!(
                    "TopKDynamicFilters must be provided when limit is set at {}",
                    snafu::location!()
                );
            };

            PartSortBuffer::Top(
                TopK::try_new(
                    partition,
                    sort.schema().clone(),
                    vec![],
                    [sort.expression.clone()].into(),
                    limit,
                    context.session_config().batch_size(),
                    context.runtime_env(),
                    &sort.metrics,
                    filter.clone(),
                )?,
                0,
            )
        } else {
            PartSortBuffer::All(Vec::new())
        };

        // Pre-compute the range groups once; they drive early-stop and
        // effective-range checks during streaming.
        let descending = sort.expression.options.descending;
        let range_groups = group_ranges_by_primary_end(&partition_ranges, descending);

        Ok(Self {
            reservation: MemoryConsumer::new("PartSortStream".to_string())
                .register(&context.runtime_env().memory_pool),
            buffer,
            expression: sort.expression.clone(),
            limit,
            input,
            input_complete: false,
            schema: sort.input.schema(),
            partition_ranges,
            partition,
            cur_part_idx: 0,
            evaluating_batch: None,
            metrics: BaselineMetrics::new(&sort.metrics, partition),
            context,
            root_metrics: sort.metrics.clone(),
            range_groups,
            cur_group_idx: 0,
            filter,
        })
    }
}
424
/// Helper used via `downcast_ts_array!`: downcasts `$arr` to the timestamp
/// primitive type `$t`, then verifies that the values at `$min_max_idx`
/// (normalized so min <= max) fall inside `$cur_range` ([start, end), i.e.
/// end-exclusive). Errors on time-unit mismatch or out-of-range values.
macro_rules! array_check_helper {
    ($t:ty, $unit:expr, $arr:expr, $cur_range:expr, $min_max_idx:expr) => {{
        if $cur_range.start.unit().as_arrow_time_unit() != $unit
            || $cur_range.end.unit().as_arrow_time_unit() != $unit
        {
            internal_err!(
                "PartitionRange unit mismatch, expect {:?}, found {:?}",
                $cur_range.start.unit(),
                $unit
            )?;
        }
        let arr = $arr
            .as_any()
            .downcast_ref::<arrow::array::PrimitiveArray<$t>>()
            .unwrap();

        // The caller may pass the indices in either order; normalize them.
        let min = arr.value($min_max_idx.0);
        let max = arr.value($min_max_idx.1);
        let (min, max) = if min < max{
            (min, max)
        } else {
            (max, min)
        };
        let cur_min = $cur_range.start.value();
        let cur_max = $cur_range.end.value();
        // Range is inclusive at start, exclusive at end.
        if !(min >= cur_min && max < cur_max) {
            internal_err!(
                "Sort column min/max value out of partition range: sort_column.min_max=[{:?}, {:?}] not in PartitionRange=[{:?}, {:?}]",
                min,
                max,
                cur_min,
                cur_max
            )?;
        }
    }};
}
462
463impl PartSortStream {
    /// Sanity check: asserts that the sort column values at `min_max_idx`
    /// lie inside the current group's effective range.
    ///
    /// Errors if the group has no effective range, the timestamp unit
    /// mismatches, or a value is out of range.
    fn check_in_range(
        &self,
        sort_column: &ArrayRef,
        min_max_idx: (usize, usize),
    ) -> datafusion_common::Result<()> {
        let Some(cur_range) = self.get_current_group_effective_range() else {
            internal_err!(
                "No effective range for current group {} at {}",
                self.cur_group_idx,
                snafu::location!()
            )?
        };

        // Dispatch on the concrete timestamp array type.
        downcast_ts_array!(
            sort_column.data_type() => (array_check_helper, sort_column, cur_range, min_max_idx),
            _ => internal_err!(
                "Unsupported data type for sort column: {:?}",
                sort_column.data_type()
            )?,
        );

        Ok(())
    }
491
    /// Scans `sort_column` for the first value that falls outside the current
    /// partition range ([start, end), end-exclusive).
    ///
    /// Returns `Some(idx)` of that first out-of-range row, or `None` if the
    /// column is empty or entirely in range. Null values are skipped.
    fn try_find_next_range(
        &self,
        sort_column: &ArrayRef,
    ) -> datafusion_common::Result<Option<usize>> {
        if sort_column.is_empty() {
            return Ok(None);
        }

        if self.cur_part_idx >= self.partition_ranges.len() {
            internal_err!(
                "Partition index out of range: {} >= {} at {}",
                self.cur_part_idx,
                self.partition_ranges.len(),
                snafu::location!()
            )?;
        }
        let cur_range = self.partition_ranges[self.cur_part_idx];

        // Dispatch to a typed (index, value) iterator over the array.
        let sort_column_iter = downcast_ts_array!(
            sort_column.data_type() => (array_iter_helper, sort_column),
            _ => internal_err!(
                "Unsupported data type for sort column: {:?}",
                sort_column.data_type()
            )?,
        );

        for (idx, val) in sort_column_iter {
            if let Some(val) = val
                && (val >= cur_range.end.value() || val < cur_range.start.value())
            {
                return Ok(Some(idx));
            }
        }

        Ok(None)
    }
534
535 fn push_buffer(&mut self, batch: DfRecordBatch) -> datafusion_common::Result<()> {
536 match &mut self.buffer {
537 PartSortBuffer::All(v) => v.push(batch),
538 PartSortBuffer::Top(top, cnt) => {
539 *cnt += batch.num_rows();
540 top.insert_batch(batch)?;
541 }
542 }
543
544 Ok(())
545 }
546
    /// Decides whether input consumption can stop early for a TopK sort:
    /// returns `true` when enough rows are buffered to satisfy the limit AND
    /// the dynamic filter proves no value from the *next* range group can
    /// enter the top-k result.
    fn can_stop_early(&mut self, schema: &Arc<Schema>) -> datafusion_common::Result<bool> {
        // Early stop only applies to the TopK buffer.
        let topk_cnt = match &self.buffer {
            PartSortBuffer::Top(_, cnt) => *cnt,
            _ => return Ok(false),
        };
        // Not enough rows buffered to fill the limit yet.
        // (Option comparison: `limit` is always Some when buffer is Top.)
        if Some(topk_cnt) < self.limit {
            return Ok(false);
        }
        // There must be a following group whose boundary we can test against.
        let next_group_primary_end = if self.cur_group_idx + 1 < self.range_groups.len() {
            self.range_groups[self.cur_group_idx + 1].0
        } else {
            return Ok(false);
        };

        let filter = self
            .filter
            .as_ref()
            .expect("TopKDynamicFilters must be provided when limit is set");
        let filter = filter.read().expr().current()?;
        let mut ts_index = None;
        // Rewrite any column reference to index 0 so the filter can be
        // evaluated against the single-column batch built below; remember the
        // original index to pick the matching field from `schema`.
        let filter = filter
            .transform_down(|c| {
                if let Some(column) = c.as_any().downcast_ref::<Column>() {
                    ts_index = Some(column.index());
                    Ok(Transformed::yes(
                        Arc::new(Column::new(column.name(), 0)) as Arc<dyn PhysicalExpr>
                    ))
                } else {
                    Ok(Transformed::no(c))
                }
            })?
            .data;
        let Some(ts_index) = ts_index else {
            // The filter references no column yet (e.g. still `lit(true)`).
            return Ok(false);
        };
        let field = if schema.fields().len() <= ts_index {
            warn!(
                "Schema mismatch when evaluating dynamic filter for PartSortExec at {}, schema: {:?}, ts_index: {}",
                self.partition, schema, ts_index
            );
            return Ok(false);
        } else {
            schema.field(ts_index)
        };
        let schema = Arc::new(Schema::new(vec![field.clone()]));
        // Build a one-row batch holding the next group's primary-end
        // timestamp in the matching arrow time unit.
        let primary_end_array = match next_group_primary_end.unit() {
            TimeUnit::Second => Arc::new(TimestampSecondArray::from(vec![
                next_group_primary_end.value(),
            ])) as ArrayRef,
            TimeUnit::Millisecond => Arc::new(TimestampMillisecondArray::from(vec![
                next_group_primary_end.value(),
            ])) as ArrayRef,
            TimeUnit::Microsecond => Arc::new(TimestampMicrosecondArray::from(vec![
                next_group_primary_end.value(),
            ])) as ArrayRef,
            TimeUnit::Nanosecond => Arc::new(TimestampNanosecondArray::from(vec![
                next_group_primary_end.value(),
            ])) as ArrayRef,
        };
        let primary_end_batch = DfRecordBatch::try_new(schema, vec![primary_end_array])?;
        let res = filter.evaluate(&primary_end_batch)?;
        let array = res.into_array(primary_end_batch.num_rows())?;
        let filter = array.as_boolean().clone();
        let overlap = filter.iter().next().flatten();
        // `Some(false)` means the boundary timestamp fails the dynamic
        // filter, so no later input can change the top-k result.
        if let Some(false) = overlap {
            Ok(true)
        } else {
            Ok(false)
        }
    }
628
629 fn is_in_current_group(&self, part_idx: usize) -> bool {
631 if self.cur_group_idx >= self.range_groups.len() {
632 return false;
633 }
634 let (_, start, end) = self.range_groups[self.cur_group_idx];
635 part_idx >= start && part_idx < end
636 }
637
638 fn advance_to_next_group(&mut self) -> bool {
640 self.cur_group_idx += 1;
641 self.cur_group_idx < self.range_groups.len()
642 }
643
644 fn get_current_group_effective_range(&self) -> Option<PartitionRange> {
648 if self.cur_group_idx >= self.range_groups.len() {
649 return None;
650 }
651 let (_, start_idx, end_idx) = self.range_groups[self.cur_group_idx];
652 if start_idx >= end_idx || start_idx >= self.partition_ranges.len() {
653 return None;
654 }
655
656 let ranges_in_group =
657 &self.partition_ranges[start_idx..end_idx.min(self.partition_ranges.len())];
658 if ranges_in_group.is_empty() {
659 return None;
660 }
661
662 let mut min_start = ranges_in_group[0].start;
664 let mut max_end = ranges_in_group[0].end;
665 for range in ranges_in_group.iter().skip(1) {
666 if range.start < min_start {
667 min_start = range.start;
668 }
669 if range.end > max_end {
670 max_end = range.end;
671 }
672 }
673
674 Some(PartitionRange {
675 start: min_start,
676 end: max_end,
677 num_rows: 0, identifier: 0, })
680 }
681
682 fn sort_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
686 match &mut self.buffer {
687 PartSortBuffer::All(_) => self.sort_all_buffer(),
688 PartSortBuffer::Top(_, _) => self.sort_top_buffer(),
689 }
690 }
691
    /// Full-sort path: takes all buffered batches, sorts them by the sort
    /// expression (honoring `self.limit`), and returns one sorted batch.
    ///
    /// The buffer is replaced with an empty `All` buffer. Memory for the
    /// temporary concat + take is accounted via `self.reservation`.
    fn sort_all_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
        let PartSortBuffer::All(buffer) =
            std::mem::replace(&mut self.buffer, PartSortBuffer::All(Vec::new()))
        else {
            unreachable!("buffer type is checked before and should be All variant")
        };

        if buffer.is_empty() {
            return Ok(DfRecordBatch::new_empty(self.schema.clone()));
        }
        // Evaluate the sort expression once per batch, keeping the first
        // non-None sort options seen.
        let mut sort_columns = Vec::with_capacity(buffer.len());
        let mut opt = None;
        for batch in buffer.iter() {
            let sort_column = self.expression.evaluate_to_sort_column(batch)?;
            opt = opt.or(sort_column.options);
            sort_columns.push(sort_column.values);
        }

        let sort_column =
            concat(&sort_columns.iter().map(|a| a.as_ref()).collect_vec()).map_err(|e| {
                DataFusionError::ArrowError(
                    Box::new(e),
                    Some(format!("Fail to concat sort columns at {}", location!())),
                )
            })?;

        let indices = sort_to_indices(&sort_column, opt, self.limit).map_err(|e| {
            DataFusionError::ArrowError(
                Box::new(e),
                Some(format!("Fail to sort to indices at {}", location!())),
            )
        })?;
        if indices.is_empty() {
            return Ok(DfRecordBatch::new_empty(self.schema.clone()));
        }

        // First/last sorted indices give the extreme values; verify they lie
        // inside the current group's effective range.
        self.check_in_range(
            &sort_column,
            (
                indices.value(0) as usize,
                indices.value(indices.len() - 1) as usize,
            ),
        )
        .inspect_err(|_e| {
            #[cfg(debug_assertions)]
            common_telemetry::error!(
                "Fail to check sort column in range at {}, current_idx: {}, num_rows: {}, err: {}",
                self.partition,
                self.cur_part_idx,
                sort_column.len(),
                _e
            );
        })?;

        // Reserve for both the concatenated input and the taken output.
        let total_mem: usize = buffer.iter().map(|r| r.get_array_memory_size()).sum();
        self.reservation.try_grow(total_mem * 2)?;

        let full_input = concat_batches(&self.schema, &buffer).map_err(|e| {
            DataFusionError::ArrowError(
                Box::new(e),
                Some(format!(
                    "Fail to concat input batches when sorting at {}",
                    location!()
                )),
            )
        })?;

        let sorted = take_record_batch(&full_input, &indices).map_err(|e| {
            DataFusionError::ArrowError(
                Box::new(e),
                Some(format!(
                    "Fail to take result record batch when sorting at {}",
                    location!()
                )),
            )
        })?;

        drop(full_input);
        self.reservation.shrink(2 * total_mem);
        Ok(sorted)
    }
776
    /// TopK path: swaps in a fresh `TopK` accumulator, drains the old one via
    /// `emit()`, and concatenates the emitted batches into one result batch.
    fn sort_top_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
        let Some(filter) = self.filter.clone() else {
            return internal_err!(
                "TopKDynamicFilters must be provided when sorting with limit at {}",
                snafu::location!()
            );
        };

        // Build the replacement accumulator before taking the old one out.
        let new_top_buffer = TopK::try_new(
            self.partition,
            self.schema().clone(),
            vec![],
            [self.expression.clone()].into(),
            self.limit.unwrap(),
            self.context.session_config().batch_size(),
            self.context.runtime_env(),
            &self.root_metrics,
            filter,
        )?;
        let PartSortBuffer::Top(top_k, _) =
            std::mem::replace(&mut self.buffer, PartSortBuffer::Top(new_top_buffer, 0))
        else {
            unreachable!("buffer type is checked before and should be Top variant")
        };

        // `TopK::emit` yields an in-memory stream; drive it synchronously
        // with a no-op waker — it should never return Pending.
        let mut result_stream = top_k.emit()?;
        let mut placeholder_ctx = std::task::Context::from_waker(futures::task::noop_waker_ref());
        let mut results = vec![];
        loop {
            match result_stream.poll_next_unpin(&mut placeholder_ctx) {
                Poll::Ready(Some(batch)) => {
                    let batch = batch?;
                    results.push(batch);
                }
                Poll::Pending => {
                    #[cfg(debug_assertions)]
                    unreachable!("TopK result stream should always be ready")
                }
                Poll::Ready(None) => {
                    break;
                }
            }
        }

        let concat_batch = concat_batches(&self.schema, &results).map_err(|e| {
            DataFusionError::ArrowError(
                Box::new(e),
                Some(format!(
                    "Fail to concat top k result record batch when sorting at {}",
                    location!()
                )),
            )
        })?;

        Ok(concat_batch)
    }
835
836 fn sorted_buffer_if_non_empty(&mut self) -> datafusion_common::Result<Option<DfRecordBatch>> {
838 if self.buffer.is_empty() {
839 return Ok(None);
840 }
841
842 let sorted = self.sort_buffer()?;
843 if sorted.num_rows() == 0 {
844 Ok(None)
845 } else {
846 Ok(Some(sorted))
847 }
848 }
849
850 fn split_batch(
867 &mut self,
868 batch: DfRecordBatch,
869 ) -> datafusion_common::Result<Option<DfRecordBatch>> {
870 if matches!(self.buffer, PartSortBuffer::Top(_, _)) {
871 self.split_batch_topk(batch)?;
872 return Ok(None);
873 }
874
875 self.split_batch_all(batch)
876 }
877
    /// TopK variant of batch splitting: feeds in-range rows to the TopK
    /// buffer, advances the current range/group on a boundary, and may flag
    /// early stop (`input_complete`) when the dynamic filter proves later
    /// groups irrelevant.
    fn split_batch_topk(&mut self, batch: DfRecordBatch) -> datafusion_common::Result<()> {
        if batch.num_rows() == 0 {
            return Ok(());
        }

        let sort_column = self
            .expression
            .expr
            .evaluate(&batch)?
            .into_array(batch.num_rows())?;

        let next_range_idx = self.try_find_next_range(&sort_column)?;
        // Whole batch in current range: buffer it and return.
        let Some(idx) = next_range_idx else {
            self.push_buffer(batch)?;
            return Ok(());
        };

        // Split at the first out-of-range row.
        let this_range = batch.slice(0, idx);
        let remaining_range = batch.slice(idx, batch.num_rows() - idx);
        if this_range.num_rows() != 0 {
            self.push_buffer(this_range)?;
        }

        self.cur_part_idx += 1;

        if self.cur_part_idx >= self.partition_ranges.len() {
            // No more ranges; leftover rows would be a contract violation.
            debug_assert!(remaining_range.num_rows() == 0);
            self.input_complete = true;
            return Ok(());
        }

        let in_same_group = self.is_in_current_group(self.cur_part_idx);

        // Crossing into a new group: check whether the TopK result is already
        // final and we can stop consuming input.
        if !in_same_group && self.can_stop_early(&batch.schema())? {
            self.input_complete = true;
            self.evaluating_batch = None;
            return Ok(());
        }

        if !in_same_group {
            self.advance_to_next_group();
        }

        // If the remainder itself crosses another boundary, park it for the
        // next poll; otherwise buffer it directly.
        let next_sort_column = sort_column.slice(idx, batch.num_rows() - idx);
        if self.try_find_next_range(&next_sort_column)?.is_some() {
            self.evaluating_batch = Some(remaining_range);
        } else if remaining_range.num_rows() != 0 {
            self.push_buffer(remaining_range)?;
        }

        Ok(())
    }
946
    /// Full-sort variant of batch splitting: buffers in-range rows and, when
    /// a range-group boundary is crossed, emits the sorted contents of the
    /// finished group as a batch.
    fn split_batch_all(
        &mut self,
        batch: DfRecordBatch,
    ) -> datafusion_common::Result<Option<DfRecordBatch>> {
        if batch.num_rows() == 0 {
            return Ok(None);
        }

        let sort_column = self
            .expression
            .expr
            .evaluate(&batch)?
            .into_array(batch.num_rows())?;

        let next_range_idx = self.try_find_next_range(&sort_column)?;
        // Whole batch in current range: buffer it; nothing to emit yet.
        let Some(idx) = next_range_idx else {
            self.push_buffer(batch)?;
            return Ok(None);
        };

        // Split at the first out-of-range row.
        let this_range = batch.slice(0, idx);
        let remaining_range = batch.slice(idx, batch.num_rows() - idx);
        if this_range.num_rows() != 0 {
            self.push_buffer(this_range)?;
        }

        self.cur_part_idx += 1;

        if self.cur_part_idx >= self.partition_ranges.len() {
            // No more ranges; leftover rows would be a contract violation.
            debug_assert!(remaining_range.num_rows() == 0);

            return self.sorted_buffer_if_non_empty();
        }

        // Still inside the same group: keep buffering, park any remainder
        // that crosses another boundary.
        if self.is_in_current_group(self.cur_part_idx) {
            let next_sort_column = sort_column.slice(idx, batch.num_rows() - idx);
            if self.try_find_next_range(&next_sort_column)?.is_some() {
                self.evaluating_batch = Some(remaining_range);
            } else {
                if remaining_range.num_rows() != 0 {
                    self.push_buffer(remaining_range)?;
                }
            }
            return Ok(None);
        }

        // Group boundary crossed: emit the finished group, then start the
        // next one with the remainder.
        let sorted_batch = self.sorted_buffer_if_non_empty()?;
        self.advance_to_next_group();

        let next_sort_column = sort_column.slice(idx, batch.num_rows() - idx);
        if self.try_find_next_range(&next_sort_column)?.is_some() {
            self.evaluating_batch = Some(remaining_range);
        } else {
            if remaining_range.num_rows() != 0 {
                self.push_buffer(remaining_range)?;
            }
        }

        Ok(sorted_batch)
    }
1022
    /// Core poll loop: first drains any parked `evaluating_batch`, then pulls
    /// from the input stream, splitting batches along range boundaries and
    /// emitting a sorted batch whenever one becomes available.
    pub fn poll_next_inner(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<datafusion_common::Result<DfRecordBatch>>> {
        loop {
            // Input exhausted (or early-stopped): flush what remains, then end.
            if self.input_complete {
                if let Some(sorted_batch) = self.sorted_buffer_if_non_empty()? {
                    return Poll::Ready(Some(Ok(sorted_batch)));
                }
                return Poll::Ready(None);
            }

            // A previously-split remainder takes priority over new input.
            if let Some(evaluating_batch) = self.evaluating_batch.take()
                && evaluating_batch.num_rows() != 0
            {
                if self.cur_part_idx >= self.partition_ranges.len() {
                    // Ranges exhausted: flush and finish.
                    if let Some(sorted_batch) = self.sorted_buffer_if_non_empty()? {
                        return Poll::Ready(Some(Ok(sorted_batch)));
                    }
                    return Poll::Ready(None);
                }

                if let Some(sorted_batch) = self.split_batch(evaluating_batch)? {
                    return Poll::Ready(Some(Ok(sorted_batch)));
                }
                continue;
            }

            let res = self.input.as_mut().poll_next(cx);
            match res {
                Poll::Ready(Some(Ok(batch))) => {
                    if let Some(sorted_batch) = self.split_batch(batch)? {
                        return Poll::Ready(Some(Ok(sorted_batch)));
                    }
                    // Nothing to emit yet; loop to pull more input.
                }
                Poll::Ready(None) => {
                    // Loop once more so the `input_complete` branch flushes.
                    self.input_complete = true;
                }
                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
                Poll::Pending => return Poll::Pending,
            }
        }
    }
1072}
1073
impl Stream for PartSortStream {
    type Item = datafusion_common::Result<DfRecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<datafusion_common::Result<DfRecordBatch>>> {
        let result = self.as_mut().poll_next_inner(cx);
        // Record output rows / timing on every poll result.
        self.metrics.record_poll(result)
    }
}
1085
1086impl RecordBatchStream for PartSortStream {
1087 fn schema(&self) -> SchemaRef {
1088 self.schema.clone()
1089 }
1090}
1091
1092#[cfg(test)]
1093mod test {
1094 use std::sync::Arc;
1095
1096 use arrow::array::{
1097 TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
1098 TimestampSecondArray,
1099 };
1100 use arrow::json::ArrayWriter;
1101 use arrow_schema::{DataType, Field, Schema, SortOptions, TimeUnit};
1102 use common_time::Timestamp;
1103 use datafusion_physical_expr::expressions::Column;
1104 use futures::StreamExt;
1105 use store_api::region_engine::PartitionRange;
1106
1107 use super::*;
1108 use crate::test_util::{MockInputExec, new_ts_array};
1109
    #[tokio::test]
    async fn test_can_stop_early_with_empty_topk_buffer() {
        let unit = TimeUnit::Millisecond;
        let schema = Arc::new(Schema::new(vec![Field::new(
            "ts",
            DataType::Timestamp(unit, None),
            false,
        )]));

        // Input with a single empty partition.
        let mock_input = Arc::new(MockInputExec::new(vec![vec![]], schema.clone()));
        let exec = PartSortExec::try_new(
            PhysicalSortExpr {
                expr: Arc::new(Column::new("ts", 0)),
                options: SortOptions {
                    descending: true,
                    ..Default::default()
                },
            },
            Some(3),
            vec![vec![]],
            mock_input.clone(),
        )
        .unwrap();

        // Filter with no columns (`lit(false)` only) — `can_stop_early`
        // should bail out because it cannot find a column index.
        let filter = Arc::new(RwLock::new(TopKDynamicFilters::new(Arc::new(
            DynamicFilterPhysicalExpr::new(vec![], lit(false)),
        ))));

        let input_stream = mock_input
            .execute(0, Arc::new(TaskContext::default()))
            .unwrap();
        let mut stream = PartSortStream::new(
            Arc::new(TaskContext::default()),
            &exec,
            Some(3),
            input_stream,
            vec![],
            0,
            Some(filter),
        )
        .unwrap();

        // Fill the TopK buffer up to the limit of 3 rows.
        let batch = DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![1, 2, 3])])
            .unwrap();
        stream.push_buffer(batch).unwrap();

        // With no next range group and no column in the filter, early stop
        // must not trigger.
        assert!(!stream.can_stop_early(&schema).unwrap());
    }
1163
    // Randomized end-to-end test: generates monotonic partition ranges with
    // random data, computes the expected per-group sorted output (optionally
    // truncated by a random limit), and runs each case through `run_test`.
    #[ignore = "hard to gen expected data correctly here, TODO(discord9): fix it later"]
    #[tokio::test]
    async fn fuzzy_test() {
        // Generation bounds for cases, ranges, and batches.
        let test_cnt = 100;
        let part_cnt_bound = 100;
        let range_size_bound = 100;
        let range_offset_bound = 100;
        let batch_cnt_bound = 20;
        let batch_size_bound = 100;

        // Fixed seed so failures are reproducible.
        let mut rng = fastrand::Rng::new();
        rng.seed(1337);

        let mut test_cases = Vec::new();

        for case_id in 0..test_cnt {
            let mut bound_val: Option<i64> = None;
            let descending = rng.bool();
            let nulls_first = rng.bool();
            let opt = SortOptions {
                descending,
                nulls_first,
            };
            let limit = if rng.bool() {
                Some(rng.usize(1..batch_cnt_bound * batch_size_bound))
            } else {
                None
            };
            let unit = match rng.u8(0..3) {
                0 => TimeUnit::Second,
                1 => TimeUnit::Millisecond,
                2 => TimeUnit::Microsecond,
                _ => TimeUnit::Nanosecond,
            };

            let schema = Schema::new(vec![Field::new(
                "ts",
                DataType::Timestamp(unit, None),
                false,
            )]);
            let schema = Arc::new(schema);

            let mut input_ranged_data = vec![];
            let mut output_ranges = vec![];
            let mut output_data = vec![];
            for part_id in 0..rng.usize(0..part_cnt_bound) {
                // Build ranges monotonically in the chosen sort direction:
                // each new range's bound moves past the previous one.
                let (start, end) = if descending {
                    let end = bound_val
                        .map(
                            |i| i
                            .checked_sub(rng.i64(1..=range_offset_bound))
                            .expect("Bad luck, fuzzy test generate data that will overflow, change seed and try again")
                        )
                        .unwrap_or_else(|| rng.i64(-100000000..100000000));
                    bound_val = Some(end);
                    let start = end - rng.i64(1..range_size_bound);
                    let start = Timestamp::new(start, unit.into());
                    let end = Timestamp::new(end, unit.into());
                    (start, end)
                } else {
                    let start = bound_val
                        .map(|i| i + rng.i64(1..=range_offset_bound))
                        .unwrap_or_else(|| rng.i64(..));
                    bound_val = Some(start);
                    let end = start + rng.i64(1..range_size_bound);
                    let start = Timestamp::new(start, unit.into());
                    let end = Timestamp::new(end, unit.into());
                    (start, end)
                };
                assert!(start < end);

                // Random sorted batches with values inside [start, end).
                let mut per_part_sort_data = vec![];
                let mut batches = vec![];
                for _batch_idx in 0..rng.usize(1..batch_cnt_bound) {
                    let cnt = rng.usize(0..batch_size_bound) + 1;
                    let iter = 0..rng.usize(0..cnt);
                    let mut data_gen = iter
                        .map(|_| rng.i64(start.value()..end.value()))
                        .collect_vec();
                    if data_gen.is_empty() {
                        continue;
                    }
                    data_gen.sort();
                    per_part_sort_data.extend(data_gen.clone());
                    let arr = new_ts_array(unit, data_gen.clone());
                    let batch = DfRecordBatch::try_new(schema.clone(), vec![arr]).unwrap();
                    batches.push(batch);
                }

                let range = PartitionRange {
                    start,
                    end,
                    num_rows: batches.iter().map(|b| b.num_rows()).sum(),
                    identifier: part_id,
                };
                input_ranged_data.push((range, batches));

                output_ranges.push(range);
                if per_part_sort_data.is_empty() {
                    continue;
                }
                output_data.extend_from_slice(&per_part_sort_data);
            }

            // Regroup the flat expected data by range, sorting each group in
            // the requested direction.
            let mut output_data_iter = output_data.iter().peekable();
            let mut output_data = vec![];
            for range in output_ranges.clone() {
                let mut cur_data = vec![];
                while let Some(val) = output_data_iter.peek() {
                    if **val < range.start.value() || **val >= range.end.value() {
                        break;
                    }
                    cur_data.push(*output_data_iter.next().unwrap());
                }

                if cur_data.is_empty() {
                    continue;
                }

                if descending {
                    cur_data.sort_by(|a, b| b.cmp(a));
                } else {
                    cur_data.sort();
                }
                output_data.push(cur_data);
            }

            let expected_output = if let Some(limit) = limit {
                // With a limit: accumulate whole groups until `limit` rows
                // are covered, then sort and truncate.
                let mut accumulated = Vec::new();
                let mut seen = 0usize;
                for mut range_values in output_data {
                    seen += range_values.len();
                    accumulated.append(&mut range_values);
                    if seen >= limit {
                        break;
                    }
                }

                if accumulated.is_empty() {
                    None
                } else {
                    if descending {
                        accumulated.sort_by(|a, b| b.cmp(a));
                    } else {
                        accumulated.sort();
                    }
                    accumulated.truncate(limit.min(accumulated.len()));

                    Some(
                        DfRecordBatch::try_new(
                            schema.clone(),
                            vec![new_ts_array(unit, accumulated)],
                        )
                        .unwrap(),
                    )
                }
            } else {
                // No limit: one batch per group, concatenated in order.
                let batches = output_data
                    .into_iter()
                    .map(|a| {
                        DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, a)]).unwrap()
                    })
                    .collect_vec();
                if batches.is_empty() {
                    None
                } else {
                    Some(concat_batches(&schema, &batches).unwrap())
                }
            };

            test_cases.push((
                case_id,
                unit,
                input_ranged_data,
                schema,
                opt,
                limit,
                expected_output,
            ));
        }

        for (case_id, _unit, input_ranged_data, schema, opt, limit, expected_output) in test_cases {
            run_test(
                case_id,
                input_ranged_data,
                schema,
                opt,
                limit,
                expected_output,
                None,
            )
            .await;
        }
    }
1368
1369 #[tokio::test]
1370 async fn simple_cases() {
1371 let testcases = vec![
1372 (
1373 TimeUnit::Millisecond,
1374 vec![
1375 ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]]),
1376 ((5, 10), vec![vec![5, 6], vec![7, 8]]),
1377 ],
1378 false,
1379 None,
1380 vec![vec![1, 2, 3, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9]],
1381 ),
1382 (
1385 TimeUnit::Millisecond,
1386 vec![
1387 ((5, 10), vec![vec![5, 6], vec![7, 8, 9]]),
1388 ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
1389 ],
1390 true,
1391 None,
1392 vec![vec![9, 8, 8, 7, 7, 6, 6, 5, 5, 4, 3, 2, 1]],
1393 ),
1394 (
1395 TimeUnit::Millisecond,
1396 vec![
1397 ((5, 10), vec![]),
1398 ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
1399 ],
1400 true,
1401 None,
1402 vec![vec![8, 7, 6, 5, 4, 3, 2, 1]],
1403 ),
1404 (
1405 TimeUnit::Millisecond,
1406 vec![
1407 ((15, 20), vec![vec![17, 18, 19]]),
1408 ((10, 15), vec![]),
1409 ((5, 10), vec![]),
1410 ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
1411 ],
1412 true,
1413 None,
1414 vec![vec![19, 18, 17], vec![8, 7, 6, 5, 4, 3, 2, 1]],
1415 ),
1416 (
1417 TimeUnit::Millisecond,
1418 vec![
1419 ((15, 20), vec![]),
1420 ((10, 15), vec![]),
1421 ((5, 10), vec![]),
1422 ((0, 10), vec![]),
1423 ],
1424 true,
1425 None,
1426 vec![],
1427 ),
1428 (
1433 TimeUnit::Millisecond,
1434 vec![
1435 (
1436 (15, 20),
1437 vec![vec![15, 17, 19, 10, 11, 12, 5, 6, 7, 8, 9, 1, 2, 3, 4]],
1438 ),
1439 ((10, 15), vec![]),
1440 ((5, 10), vec![]),
1441 ((0, 10), vec![]),
1442 ],
1443 true,
1444 None,
1445 vec![
1446 vec![19, 17, 15],
1447 vec![12, 11, 10],
1448 vec![9, 8, 7, 6, 5, 4, 3, 2, 1],
1449 ],
1450 ),
1451 (
1452 TimeUnit::Millisecond,
1453 vec![
1454 (
1455 (15, 20),
1456 vec![vec![15, 17, 19, 10, 11, 12, 5, 6, 7, 8, 9, 1, 2, 3, 4]],
1457 ),
1458 ((10, 15), vec![]),
1459 ((5, 10), vec![]),
1460 ((0, 10), vec![]),
1461 ],
1462 true,
1463 Some(2),
1464 vec![vec![19, 17]],
1465 ),
1466 ];
1467
1468 for (identifier, (unit, input_ranged_data, descending, limit, expected_output)) in
1469 testcases.into_iter().enumerate()
1470 {
1471 let schema = Schema::new(vec![Field::new(
1472 "ts",
1473 DataType::Timestamp(unit, None),
1474 false,
1475 )]);
1476 let schema = Arc::new(schema);
1477 let opt = SortOptions {
1478 descending,
1479 ..Default::default()
1480 };
1481
1482 let input_ranged_data = input_ranged_data
1483 .into_iter()
1484 .map(|(range, data)| {
1485 let part = PartitionRange {
1486 start: Timestamp::new(range.0, unit.into()),
1487 end: Timestamp::new(range.1, unit.into()),
1488 num_rows: data.iter().map(|b| b.len()).sum(),
1489 identifier,
1490 };
1491
1492 let batches = data
1493 .into_iter()
1494 .map(|b| {
1495 let arr = new_ts_array(unit, b);
1496 DfRecordBatch::try_new(schema.clone(), vec![arr]).unwrap()
1497 })
1498 .collect_vec();
1499 (part, batches)
1500 })
1501 .collect_vec();
1502
1503 let expected_output = expected_output
1504 .into_iter()
1505 .map(|a| {
1506 DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, a)]).unwrap()
1507 })
1508 .collect_vec();
1509 let expected_output = if expected_output.is_empty() {
1510 None
1511 } else {
1512 Some(concat_batches(&schema, &expected_output).unwrap())
1513 };
1514
1515 run_test(
1516 identifier,
1517 input_ranged_data,
1518 schema.clone(),
1519 opt,
1520 limit,
1521 expected_output,
1522 None,
1523 )
1524 .await;
1525 }
1526 }
1527
    /// Shared harness: drives a `PartSortExec` over `input_ranged_data` and
    /// compares its output against `expected_output`.
    ///
    /// * `case_id` - echoed in panic messages so a failing case is identifiable.
    /// * `input_ranged_data` - one `PartitionRange` plus its input batches per entry.
    /// * `opt` - sort direction / null ordering handed to the sort expression.
    /// * `limit` - optional top-k limit for `PartSortExec`; when set, the plan is
    ///   expected to produce at most one output batch.
    /// * `expected_output` - `None` means the plan must emit no batches at all.
    /// * `expected_polled_rows` - when set, asserts how many rows were pulled from
    ///   the mock input (checks early-termination behavior via input metrics).
    #[allow(clippy::print_stdout)]
    async fn run_test(
        case_id: usize,
        input_ranged_data: Vec<(PartitionRange, Vec<DfRecordBatch>)>,
        schema: SchemaRef,
        opt: SortOptions,
        limit: Option<usize>,
        expected_output: Option<DfRecordBatch>,
        expected_polled_rows: Option<usize>,
    ) {
        // Sanity-check the fixture itself: the expectation may not exceed the limit.
        if let (Some(limit), Some(rb)) = (limit, &expected_output) {
            assert!(
                rb.num_rows() <= limit,
                "Expect row count in expected output({}) <= limit({})",
                rb.num_rows(),
                limit
            );
        }

        // Split the fixture into the batch partitions fed to the mock input and
        // the partition ranges handed to the sort exec (kept in the same order).
        let mut data_partition = Vec::with_capacity(input_ranged_data.len());
        let mut ranges = Vec::with_capacity(input_ranged_data.len());
        for (part_range, batches) in input_ranged_data {
            data_partition.push(batches);
            ranges.push(part_range);
        }

        let mock_input = Arc::new(MockInputExec::new(data_partition, schema.clone()));

        let exec = PartSortExec::try_new(
            PhysicalSortExpr {
                expr: Arc::new(Column::new("ts", 0)),
                options: opt,
            },
            limit,
            vec![ranges.clone()],
            mock_input.clone(),
        )
        .unwrap();

        let exec_stream = exec.execute(0, Arc::new(TaskContext::default())).unwrap();

        let real_output = exec_stream.map(|r| r.unwrap()).collect::<Vec<_>>().await;
        // With a limit, the operator is expected to produce a single batch.
        if limit.is_some() {
            assert!(
                real_output.len() <= 1,
                "case_{case_id} expects a single output batch when limit is set, got {}",
                real_output.len()
            );
        }

        let actual_output = if real_output.is_empty() {
            None
        } else {
            Some(concat_batches(&schema, &real_output).unwrap())
        };

        // Read the mock input's metrics AFTER the stream is fully drained, so the
        // output-row counter reflects everything the sort actually pulled.
        if let Some(expected_polled_rows) = expected_polled_rows {
            let input_pulled_rows = mock_input.metrics().unwrap().output_rows().unwrap();
            assert_eq!(input_pulled_rows, expected_polled_rows);
        }

        match (actual_output, expected_output) {
            (None, None) => {}
            (Some(actual), Some(expected)) => {
                if actual != expected {
                    // Render both batches as JSON so the panic message is readable.
                    let mut actual_json: Vec<u8> = Vec::new();
                    let mut writer = ArrayWriter::new(&mut actual_json);
                    writer.write(&actual).unwrap();
                    writer.finish().unwrap();

                    let mut expected_json: Vec<u8> = Vec::new();
                    let mut writer = ArrayWriter::new(&mut expected_json);
                    writer.write(&expected).unwrap();
                    writer.finish().unwrap();

                    panic!(
                        "case_{} failed (limit {limit:?}), opt: {:?},\nreal_output: {}\nexpected: {}",
                        case_id,
                        opt,
                        String::from_utf8_lossy(&actual_json),
                        String::from_utf8_lossy(&expected_json),
                    );
                }
            }
            (None, Some(expected)) => panic!(
                "case_{} failed (limit {limit:?}), opt: {:?},\nreal output is empty, expected {} rows",
                case_id,
                opt,
                expected.num_rows()
            ),
            (Some(actual), None) => panic!(
                "case_{} failed (limit {limit:?}), opt: {:?},\nreal output has {} rows, expected empty",
                case_id,
                opt,
                actual.num_rows()
            ),
        }
    }
1626
1627 #[tokio::test]
1630 async fn test_limit_with_multiple_batches_per_partition() {
1631 let unit = TimeUnit::Millisecond;
1632 let schema = Arc::new(Schema::new(vec![Field::new(
1633 "ts",
1634 DataType::Timestamp(unit, None),
1635 false,
1636 )]));
1637
1638 let input_ranged_data = vec![(
1642 PartitionRange {
1643 start: Timestamp::new(0, unit.into()),
1644 end: Timestamp::new(10, unit.into()),
1645 num_rows: 9,
1646 identifier: 0,
1647 },
1648 vec![
1649 DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![1, 2, 3])])
1650 .unwrap(),
1651 DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![4, 5, 6])])
1652 .unwrap(),
1653 DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![7, 8, 9])])
1654 .unwrap(),
1655 ],
1656 )];
1657
1658 let expected_output = Some(
1659 DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![9, 8, 7])])
1660 .unwrap(),
1661 );
1662
1663 run_test(
1664 1000,
1665 input_ranged_data,
1666 schema.clone(),
1667 SortOptions {
1668 descending: true,
1669 ..Default::default()
1670 },
1671 Some(3),
1672 expected_output,
1673 None,
1674 )
1675 .await;
1676
1677 let input_ranged_data = vec![
1681 (
1682 PartitionRange {
1683 start: Timestamp::new(10, unit.into()),
1684 end: Timestamp::new(20, unit.into()),
1685 num_rows: 6,
1686 identifier: 0,
1687 },
1688 vec![
1689 DfRecordBatch::try_new(
1690 schema.clone(),
1691 vec![new_ts_array(unit, vec![10, 11, 12])],
1692 )
1693 .unwrap(),
1694 DfRecordBatch::try_new(
1695 schema.clone(),
1696 vec![new_ts_array(unit, vec![13, 14, 15])],
1697 )
1698 .unwrap(),
1699 ],
1700 ),
1701 (
1702 PartitionRange {
1703 start: Timestamp::new(0, unit.into()),
1704 end: Timestamp::new(10, unit.into()),
1705 num_rows: 5,
1706 identifier: 1,
1707 },
1708 vec![
1709 DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![1, 2, 3])])
1710 .unwrap(),
1711 DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![4, 5])])
1712 .unwrap(),
1713 ],
1714 ),
1715 ];
1716
1717 let expected_output = Some(
1718 DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![15, 14])]).unwrap(),
1719 );
1720
1721 run_test(
1722 1001,
1723 input_ranged_data,
1724 schema.clone(),
1725 SortOptions {
1726 descending: true,
1727 ..Default::default()
1728 },
1729 Some(2),
1730 expected_output,
1731 None,
1732 )
1733 .await;
1734
1735 let input_ranged_data = vec![(
1738 PartitionRange {
1739 start: Timestamp::new(0, unit.into()),
1740 end: Timestamp::new(10, unit.into()),
1741 num_rows: 9,
1742 identifier: 0,
1743 },
1744 vec![
1745 DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![7, 8, 9])])
1746 .unwrap(),
1747 DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![4, 5, 6])])
1748 .unwrap(),
1749 DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![1, 2, 3])])
1750 .unwrap(),
1751 ],
1752 )];
1753
1754 let expected_output = Some(
1755 DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![1, 2])]).unwrap(),
1756 );
1757
1758 run_test(
1759 1002,
1760 input_ranged_data,
1761 schema.clone(),
1762 SortOptions {
1763 descending: false,
1764 ..Default::default()
1765 },
1766 Some(2),
1767 expected_output,
1768 None,
1769 )
1770 .await;
1771 }
1772
1773 #[tokio::test]
1777 async fn test_early_termination() {
1778 let unit = TimeUnit::Millisecond;
1779 let schema = Arc::new(Schema::new(vec![Field::new(
1780 "ts",
1781 DataType::Timestamp(unit, None),
1782 false,
1783 )]));
1784
1785 let input_ranged_data = vec![
1790 (
1791 PartitionRange {
1792 start: Timestamp::new(20, unit.into()),
1793 end: Timestamp::new(30, unit.into()),
1794 num_rows: 10,
1795 identifier: 2,
1796 },
1797 vec![
1798 DfRecordBatch::try_new(
1799 schema.clone(),
1800 vec![new_ts_array(unit, vec![21, 22, 23, 24, 25])],
1801 )
1802 .unwrap(),
1803 DfRecordBatch::try_new(
1804 schema.clone(),
1805 vec![new_ts_array(unit, vec![26, 27, 28, 29, 30])],
1806 )
1807 .unwrap(),
1808 ],
1809 ),
1810 (
1811 PartitionRange {
1812 start: Timestamp::new(10, unit.into()),
1813 end: Timestamp::new(20, unit.into()),
1814 num_rows: 10,
1815 identifier: 1,
1816 },
1817 vec![
1818 DfRecordBatch::try_new(
1819 schema.clone(),
1820 vec![new_ts_array(unit, vec![11, 12, 13, 14, 15])],
1821 )
1822 .unwrap(),
1823 DfRecordBatch::try_new(
1824 schema.clone(),
1825 vec![new_ts_array(unit, vec![16, 17, 18, 19, 20])],
1826 )
1827 .unwrap(),
1828 ],
1829 ),
1830 (
1831 PartitionRange {
1832 start: Timestamp::new(0, unit.into()),
1833 end: Timestamp::new(10, unit.into()),
1834 num_rows: 10,
1835 identifier: 0,
1836 },
1837 vec![
1838 DfRecordBatch::try_new(
1839 schema.clone(),
1840 vec![new_ts_array(unit, vec![1, 2, 3, 4, 5])],
1841 )
1842 .unwrap(),
1843 DfRecordBatch::try_new(
1844 schema.clone(),
1845 vec![new_ts_array(unit, vec![6, 7, 8, 9, 10])],
1846 )
1847 .unwrap(),
1848 ],
1849 ),
1850 ];
1851
1852 let expected_output = Some(
1856 DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![29, 28])]).unwrap(),
1857 );
1858
1859 run_test(
1860 1003,
1861 input_ranged_data,
1862 schema.clone(),
1863 SortOptions {
1864 descending: true,
1865 ..Default::default()
1866 },
1867 Some(2),
1868 expected_output,
1869 Some(10),
1870 )
1871 .await;
1872 }
1873
1874 #[tokio::test]
1878 async fn test_primary_end_grouping_with_limit() {
1879 let unit = TimeUnit::Millisecond;
1880 let schema = Arc::new(Schema::new(vec![Field::new(
1881 "ts",
1882 DataType::Timestamp(unit, None),
1883 false,
1884 )]));
1885
1886 let input_ranged_data = vec![
1890 (
1891 PartitionRange {
1892 start: Timestamp::new(70, unit.into()),
1893 end: Timestamp::new(100, unit.into()),
1894 num_rows: 3,
1895 identifier: 0,
1896 },
1897 vec![
1898 DfRecordBatch::try_new(
1899 schema.clone(),
1900 vec![new_ts_array(unit, vec![80, 90, 95])],
1901 )
1902 .unwrap(),
1903 ],
1904 ),
1905 (
1906 PartitionRange {
1907 start: Timestamp::new(50, unit.into()),
1908 end: Timestamp::new(100, unit.into()),
1909 num_rows: 5,
1910 identifier: 1,
1911 },
1912 vec![
1913 DfRecordBatch::try_new(
1914 schema.clone(),
1915 vec![new_ts_array(unit, vec![55, 65, 75, 85, 95])],
1916 )
1917 .unwrap(),
1918 ],
1919 ),
1920 ];
1921
1922 let expected_output = Some(
1926 DfRecordBatch::try_new(
1927 schema.clone(),
1928 vec![new_ts_array(unit, vec![95, 95, 90, 85])],
1929 )
1930 .unwrap(),
1931 );
1932
1933 run_test(
1934 2000,
1935 input_ranged_data,
1936 schema.clone(),
1937 SortOptions {
1938 descending: true,
1939 ..Default::default()
1940 },
1941 Some(4),
1942 expected_output,
1943 None,
1944 )
1945 .await;
1946 }
1947
1948 #[tokio::test]
1959 async fn test_three_ranges_keep_pulling() {
1960 let unit = TimeUnit::Millisecond;
1961 let schema = Arc::new(Schema::new(vec![Field::new(
1962 "ts",
1963 DataType::Timestamp(unit, None),
1964 false,
1965 )]));
1966
1967 let input_ranged_data = vec![
1969 (
1970 PartitionRange {
1971 start: Timestamp::new(70, unit.into()),
1972 end: Timestamp::new(100, unit.into()),
1973 num_rows: 3,
1974 identifier: 0,
1975 },
1976 vec![
1977 DfRecordBatch::try_new(
1978 schema.clone(),
1979 vec![new_ts_array(unit, vec![80, 90, 95])],
1980 )
1981 .unwrap(),
1982 ],
1983 ),
1984 (
1985 PartitionRange {
1986 start: Timestamp::new(50, unit.into()),
1987 end: Timestamp::new(100, unit.into()),
1988 num_rows: 3,
1989 identifier: 1,
1990 },
1991 vec![
1992 DfRecordBatch::try_new(
1993 schema.clone(),
1994 vec![new_ts_array(unit, vec![55, 75, 85])],
1995 )
1996 .unwrap(),
1997 ],
1998 ),
1999 (
2000 PartitionRange {
2001 start: Timestamp::new(40, unit.into()),
2002 end: Timestamp::new(95, unit.into()),
2003 num_rows: 3,
2004 identifier: 2,
2005 },
2006 vec![
2007 DfRecordBatch::try_new(
2008 schema.clone(),
2009 vec![new_ts_array(unit, vec![45, 65, 94])],
2010 )
2011 .unwrap(),
2012 ],
2013 ),
2014 ];
2015
2016 let expected_output = Some(
2020 DfRecordBatch::try_new(
2021 schema.clone(),
2022 vec![new_ts_array(unit, vec![95, 94, 90, 85])],
2023 )
2024 .unwrap(),
2025 );
2026
2027 run_test(
2028 2001,
2029 input_ranged_data,
2030 schema.clone(),
2031 SortOptions {
2032 descending: true,
2033 ..Default::default()
2034 },
2035 Some(4),
2036 expected_output,
2037 None,
2038 )
2039 .await;
2040 }
2041
2042 #[tokio::test]
2046 async fn test_threshold_based_early_termination() {
2047 let unit = TimeUnit::Millisecond;
2048 let schema = Arc::new(Schema::new(vec![Field::new(
2049 "ts",
2050 DataType::Timestamp(unit, None),
2051 false,
2052 )]));
2053
2054 let input_ranged_data = vec![
2058 (
2059 PartitionRange {
2060 start: Timestamp::new(70, unit.into()),
2061 end: Timestamp::new(100, unit.into()),
2062 num_rows: 6,
2063 identifier: 0,
2064 },
2065 vec![
2066 DfRecordBatch::try_new(
2067 schema.clone(),
2068 vec![new_ts_array(unit, vec![94, 95, 96, 97, 98, 99])],
2069 )
2070 .unwrap(),
2071 ],
2072 ),
2073 (
2074 PartitionRange {
2075 start: Timestamp::new(50, unit.into()),
2076 end: Timestamp::new(90, unit.into()),
2077 num_rows: 3,
2078 identifier: 1,
2079 },
2080 vec![
2081 DfRecordBatch::try_new(
2082 schema.clone(),
2083 vec![new_ts_array(unit, vec![85, 86, 87])],
2084 )
2085 .unwrap(),
2086 ],
2087 ),
2088 ];
2089
2090 let expected_output = Some(
2094 DfRecordBatch::try_new(
2095 schema.clone(),
2096 vec![new_ts_array(unit, vec![99, 98, 97, 96])],
2097 )
2098 .unwrap(),
2099 );
2100
2101 run_test(
2102 2002,
2103 input_ranged_data,
2104 schema.clone(),
2105 SortOptions {
2106 descending: true,
2107 ..Default::default()
2108 },
2109 Some(4),
2110 expected_output,
2111 Some(9), )
2113 .await;
2114 }
2115
2116 #[tokio::test]
2120 async fn test_continue_when_threshold_in_next_group_range() {
2121 let unit = TimeUnit::Millisecond;
2122 let schema = Arc::new(Schema::new(vec![Field::new(
2123 "ts",
2124 DataType::Timestamp(unit, None),
2125 false,
2126 )]));
2127
2128 let input_ranged_data = vec![
2132 (
2133 PartitionRange {
2134 start: Timestamp::new(90, unit.into()),
2135 end: Timestamp::new(100, unit.into()),
2136 num_rows: 6,
2137 identifier: 0,
2138 },
2139 vec![
2140 DfRecordBatch::try_new(
2141 schema.clone(),
2142 vec![new_ts_array(unit, vec![94, 95, 96, 97, 98, 99])],
2143 )
2144 .unwrap(),
2145 ],
2146 ),
2147 (
2148 PartitionRange {
2149 start: Timestamp::new(50, unit.into()),
2150 end: Timestamp::new(98, unit.into()),
2151 num_rows: 3,
2152 identifier: 1,
2153 },
2154 vec![
2155 DfRecordBatch::try_new(
2157 schema.clone(),
2158 vec![new_ts_array(unit, vec![55, 60, 65])],
2159 )
2160 .unwrap(),
2161 ],
2162 ),
2163 ];
2164
2165 let expected_output = Some(
2170 DfRecordBatch::try_new(
2171 schema.clone(),
2172 vec![new_ts_array(unit, vec![99, 98, 97, 96])],
2173 )
2174 .unwrap(),
2175 );
2176
2177 run_test(
2180 2003,
2181 input_ranged_data,
2182 schema.clone(),
2183 SortOptions {
2184 descending: true,
2185 ..Default::default()
2186 },
2187 Some(4),
2188 expected_output,
2189 Some(9), )
2191 .await;
2192 }
2193
2194 #[tokio::test]
2196 async fn test_ascending_threshold_early_termination() {
2197 let unit = TimeUnit::Millisecond;
2198 let schema = Arc::new(Schema::new(vec![Field::new(
2199 "ts",
2200 DataType::Timestamp(unit, None),
2201 false,
2202 )]));
2203
2204 let input_ranged_data = vec![
2209 (
2210 PartitionRange {
2211 start: Timestamp::new(10, unit.into()),
2212 end: Timestamp::new(50, unit.into()),
2213 num_rows: 6,
2214 identifier: 0,
2215 },
2216 vec![
2217 DfRecordBatch::try_new(
2218 schema.clone(),
2219 vec![new_ts_array(unit, vec![10, 11, 12, 13, 14, 15])],
2220 )
2221 .unwrap(),
2222 ],
2223 ),
2224 (
2225 PartitionRange {
2226 start: Timestamp::new(20, unit.into()),
2227 end: Timestamp::new(60, unit.into()),
2228 num_rows: 3,
2229 identifier: 1,
2230 },
2231 vec![
2232 DfRecordBatch::try_new(
2233 schema.clone(),
2234 vec![new_ts_array(unit, vec![25, 30, 35])],
2235 )
2236 .unwrap(),
2237 ],
2238 ),
2239 (
2241 PartitionRange {
2242 start: Timestamp::new(60, unit.into()),
2243 end: Timestamp::new(70, unit.into()),
2244 num_rows: 2,
2245 identifier: 1,
2246 },
2247 vec![
2248 DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![60, 61])])
2249 .unwrap(),
2250 ],
2251 ),
2252 (
2254 PartitionRange {
2255 start: Timestamp::new(61, unit.into()),
2256 end: Timestamp::new(70, unit.into()),
2257 num_rows: 2,
2258 identifier: 1,
2259 },
2260 vec![
2261 DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![71, 72])])
2262 .unwrap(),
2263 ],
2264 ),
2265 ];
2266
2267 let expected_output = Some(
2271 DfRecordBatch::try_new(
2272 schema.clone(),
2273 vec![new_ts_array(unit, vec![10, 11, 12, 13])],
2274 )
2275 .unwrap(),
2276 );
2277
2278 run_test(
2279 2004,
2280 input_ranged_data,
2281 schema.clone(),
2282 SortOptions {
2283 descending: false,
2284 ..Default::default()
2285 },
2286 Some(4),
2287 expected_output,
2288 Some(11), )
2290 .await;
2291 }
2292
2293 #[tokio::test]
2294 async fn test_ascending_threshold_early_termination_case_two() {
2295 let unit = TimeUnit::Millisecond;
2296 let schema = Arc::new(Schema::new(vec![Field::new(
2297 "ts",
2298 DataType::Timestamp(unit, None),
2299 false,
2300 )]));
2301
2302 let input_ranged_data = vec![
2309 (
2310 PartitionRange {
2311 start: Timestamp::new(0, unit.into()),
2312 end: Timestamp::new(20, unit.into()),
2313 num_rows: 4,
2314 identifier: 0,
2315 },
2316 vec![
2317 DfRecordBatch::try_new(
2318 schema.clone(),
2319 vec![new_ts_array(unit, vec![9, 10, 11, 12])],
2320 )
2321 .unwrap(),
2322 ],
2323 ),
2324 (
2325 PartitionRange {
2326 start: Timestamp::new(4, unit.into()),
2327 end: Timestamp::new(25, unit.into()),
2328 num_rows: 1,
2329 identifier: 1,
2330 },
2331 vec![
2332 DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![21])])
2333 .unwrap(),
2334 ],
2335 ),
2336 (
2337 PartitionRange {
2338 start: Timestamp::new(5, unit.into()),
2339 end: Timestamp::new(25, unit.into()),
2340 num_rows: 4,
2341 identifier: 1,
2342 },
2343 vec![
2344 DfRecordBatch::try_new(
2345 schema.clone(),
2346 vec![new_ts_array(unit, vec![5, 6, 7, 8])],
2347 )
2348 .unwrap(),
2349 ],
2350 ),
2351 (
2353 PartitionRange {
2354 start: Timestamp::new(42, unit.into()),
2355 end: Timestamp::new(52, unit.into()),
2356 num_rows: 2,
2357 identifier: 1,
2358 },
2359 vec![
2360 DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![42, 51])])
2361 .unwrap(),
2362 ],
2363 ),
2364 (
2366 PartitionRange {
2367 start: Timestamp::new(48, unit.into()),
2368 end: Timestamp::new(53, unit.into()),
2369 num_rows: 2,
2370 identifier: 1,
2371 },
2372 vec![
2373 DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![48, 51])])
2374 .unwrap(),
2375 ],
2376 ),
2377 ];
2378
2379 let expected_output = Some(
2382 DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![5, 6, 7, 8])])
2383 .unwrap(),
2384 );
2385
2386 run_test(
2387 2005,
2388 input_ranged_data,
2389 schema.clone(),
2390 SortOptions {
2391 descending: false,
2392 ..Default::default()
2393 },
2394 Some(4),
2395 expected_output,
2396 Some(11), )
2398 .await;
2399 }
2400
2401 #[tokio::test]
2404 async fn test_early_stop_with_nulls() {
2405 let unit = TimeUnit::Millisecond;
2406 let schema = Arc::new(Schema::new(vec![Field::new(
2407 "ts",
2408 DataType::Timestamp(unit, None),
2409 true, )]));
2411
2412 let new_nullable_ts_array = |unit: TimeUnit, arr: Vec<Option<i64>>| -> ArrayRef {
2414 match unit {
2415 TimeUnit::Second => Arc::new(TimestampSecondArray::from(arr)) as ArrayRef,
2416 TimeUnit::Millisecond => Arc::new(TimestampMillisecondArray::from(arr)) as ArrayRef,
2417 TimeUnit::Microsecond => Arc::new(TimestampMicrosecondArray::from(arr)) as ArrayRef,
2418 TimeUnit::Nanosecond => Arc::new(TimestampNanosecondArray::from(arr)) as ArrayRef,
2419 }
2420 };
2421
2422 let input_ranged_data = vec![
2426 (
2427 PartitionRange {
2428 start: Timestamp::new(70, unit.into()),
2429 end: Timestamp::new(100, unit.into()),
2430 num_rows: 5,
2431 identifier: 0,
2432 },
2433 vec![
2434 DfRecordBatch::try_new(
2435 schema.clone(),
2436 vec![new_nullable_ts_array(
2437 unit,
2438 vec![Some(99), Some(98), None, Some(97), None],
2439 )],
2440 )
2441 .unwrap(),
2442 ],
2443 ),
2444 (
2445 PartitionRange {
2446 start: Timestamp::new(50, unit.into()),
2447 end: Timestamp::new(90, unit.into()),
2448 num_rows: 3,
2449 identifier: 1,
2450 },
2451 vec![
2452 DfRecordBatch::try_new(
2453 schema.clone(),
2454 vec![new_nullable_ts_array(
2455 unit,
2456 vec![Some(89), Some(88), Some(87)],
2457 )],
2458 )
2459 .unwrap(),
2460 ],
2461 ),
2462 ];
2463
2464 let expected_output = Some(
2468 DfRecordBatch::try_new(
2469 schema.clone(),
2470 vec![new_nullable_ts_array(unit, vec![None, None, Some(99)])],
2471 )
2472 .unwrap(),
2473 );
2474
2475 run_test(
2476 3000,
2477 input_ranged_data,
2478 schema.clone(),
2479 SortOptions {
2480 descending: true,
2481 nulls_first: true,
2482 },
2483 Some(3),
2484 expected_output,
2485 Some(8), )
2487 .await;
2488
2489 let input_ranged_data = vec![
2493 (
2494 PartitionRange {
2495 start: Timestamp::new(70, unit.into()),
2496 end: Timestamp::new(100, unit.into()),
2497 num_rows: 5,
2498 identifier: 0,
2499 },
2500 vec![
2501 DfRecordBatch::try_new(
2502 schema.clone(),
2503 vec![new_nullable_ts_array(
2504 unit,
2505 vec![Some(99), Some(98), Some(97), None, None],
2506 )],
2507 )
2508 .unwrap(),
2509 ],
2510 ),
2511 (
2512 PartitionRange {
2513 start: Timestamp::new(50, unit.into()),
2514 end: Timestamp::new(90, unit.into()),
2515 num_rows: 3,
2516 identifier: 1,
2517 },
2518 vec![
2519 DfRecordBatch::try_new(
2520 schema.clone(),
2521 vec![new_nullable_ts_array(
2522 unit,
2523 vec![Some(89), Some(88), Some(87)],
2524 )],
2525 )
2526 .unwrap(),
2527 ],
2528 ),
2529 ];
2530
2531 let expected_output = Some(
2535 DfRecordBatch::try_new(
2536 schema.clone(),
2537 vec![new_nullable_ts_array(
2538 unit,
2539 vec![Some(99), Some(98), Some(97)],
2540 )],
2541 )
2542 .unwrap(),
2543 );
2544
2545 run_test(
2546 3001,
2547 input_ranged_data,
2548 schema.clone(),
2549 SortOptions {
2550 descending: true,
2551 nulls_first: false,
2552 },
2553 Some(3),
2554 expected_output,
2555 Some(8), )
2557 .await;
2558 }
2559
2560 #[tokio::test]
2563 async fn test_early_stop_single_group() {
2564 let unit = TimeUnit::Millisecond;
2565 let schema = Arc::new(Schema::new(vec![Field::new(
2566 "ts",
2567 DataType::Timestamp(unit, None),
2568 false,
2569 )]));
2570
2571 let input_ranged_data = vec![
2573 (
2574 PartitionRange {
2575 start: Timestamp::new(70, unit.into()),
2576 end: Timestamp::new(100, unit.into()),
2577 num_rows: 6,
2578 identifier: 0,
2579 },
2580 vec![
2581 DfRecordBatch::try_new(
2582 schema.clone(),
2583 vec![new_ts_array(unit, vec![94, 95, 96, 97, 98, 99])],
2584 )
2585 .unwrap(),
2586 ],
2587 ),
2588 (
2589 PartitionRange {
2590 start: Timestamp::new(50, unit.into()),
2591 end: Timestamp::new(100, unit.into()),
2592 num_rows: 3,
2593 identifier: 1,
2594 },
2595 vec![
2596 DfRecordBatch::try_new(
2597 schema.clone(),
2598 vec![new_ts_array(unit, vec![85, 86, 87])],
2599 )
2600 .unwrap(),
2601 ],
2602 ),
2603 ];
2604
2605 let expected_output = Some(
2608 DfRecordBatch::try_new(
2609 schema.clone(),
2610 vec![new_ts_array(unit, vec![99, 98, 97, 96])],
2611 )
2612 .unwrap(),
2613 );
2614
2615 run_test(
2616 3002,
2617 input_ranged_data,
2618 schema.clone(),
2619 SortOptions {
2620 descending: true,
2621 ..Default::default()
2622 },
2623 Some(4),
2624 expected_output,
2625 Some(9), )
2627 .await;
2628 }
2629
2630 #[tokio::test]
2632 async fn test_early_stop_exact_boundary_equality() {
2633 let unit = TimeUnit::Millisecond;
2634 let schema = Arc::new(Schema::new(vec![Field::new(
2635 "ts",
2636 DataType::Timestamp(unit, None),
2637 false,
2638 )]));
2639
2640 let input_ranged_data = vec![
2644 (
2645 PartitionRange {
2646 start: Timestamp::new(70, unit.into()),
2647 end: Timestamp::new(100, unit.into()),
2648 num_rows: 4,
2649 identifier: 0,
2650 },
2651 vec![
2652 DfRecordBatch::try_new(
2653 schema.clone(),
2654 vec![new_ts_array(unit, vec![92, 91, 90, 89])],
2655 )
2656 .unwrap(),
2657 ],
2658 ),
2659 (
2660 PartitionRange {
2661 start: Timestamp::new(50, unit.into()),
2662 end: Timestamp::new(90, unit.into()),
2663 num_rows: 3,
2664 identifier: 1,
2665 },
2666 vec![
2667 DfRecordBatch::try_new(
2668 schema.clone(),
2669 vec![new_ts_array(unit, vec![88, 87, 86])],
2670 )
2671 .unwrap(),
2672 ],
2673 ),
2674 ];
2675
2676 let expected_output = Some(
2677 DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![92, 91, 90])])
2678 .unwrap(),
2679 );
2680
2681 run_test(
2682 3003,
2683 input_ranged_data,
2684 schema.clone(),
2685 SortOptions {
2686 descending: true,
2687 ..Default::default()
2688 },
2689 Some(3),
2690 expected_output,
2691 Some(7), )
2693 .await;
2694
2695 let input_ranged_data = vec![
2699 (
2700 PartitionRange {
2701 start: Timestamp::new(10, unit.into()),
2702 end: Timestamp::new(50, unit.into()),
2703 num_rows: 4,
2704 identifier: 0,
2705 },
2706 vec![
2707 DfRecordBatch::try_new(
2708 schema.clone(),
2709 vec![new_ts_array(unit, vec![10, 15, 20, 25])],
2710 )
2711 .unwrap(),
2712 ],
2713 ),
2714 (
2715 PartitionRange {
2716 start: Timestamp::new(20, unit.into()),
2717 end: Timestamp::new(60, unit.into()),
2718 num_rows: 3,
2719 identifier: 1,
2720 },
2721 vec![
2722 DfRecordBatch::try_new(
2723 schema.clone(),
2724 vec![new_ts_array(unit, vec![21, 22, 23])],
2725 )
2726 .unwrap(),
2727 ],
2728 ),
2729 ];
2730
2731 let expected_output = Some(
2732 DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![10, 15, 20])])
2733 .unwrap(),
2734 );
2735
2736 run_test(
2737 3004,
2738 input_ranged_data,
2739 schema.clone(),
2740 SortOptions {
2741 descending: false,
2742 ..Default::default()
2743 },
2744 Some(3),
2745 expected_output,
2746 Some(7), )
2748 .await;
2749 }
2750
2751 #[tokio::test]
2753 async fn test_early_stop_with_empty_partitions() {
2754 let unit = TimeUnit::Millisecond;
2755 let schema = Arc::new(Schema::new(vec![Field::new(
2756 "ts",
2757 DataType::Timestamp(unit, None),
2758 false,
2759 )]));
2760
2761 let input_ranged_data = vec![
2763 (
2764 PartitionRange {
2765 start: Timestamp::new(70, unit.into()),
2766 end: Timestamp::new(100, unit.into()),
2767 num_rows: 0,
2768 identifier: 0,
2769 },
2770 vec![
2771 DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![])])
2773 .unwrap(),
2774 ],
2775 ),
2776 (
2777 PartitionRange {
2778 start: Timestamp::new(50, unit.into()),
2779 end: Timestamp::new(100, unit.into()),
2780 num_rows: 0,
2781 identifier: 1,
2782 },
2783 vec![
2784 DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![])])
2786 .unwrap(),
2787 ],
2788 ),
2789 (
2790 PartitionRange {
2791 start: Timestamp::new(30, unit.into()),
2792 end: Timestamp::new(80, unit.into()),
2793 num_rows: 4,
2794 identifier: 2,
2795 },
2796 vec![
2797 DfRecordBatch::try_new(
2798 schema.clone(),
2799 vec![new_ts_array(unit, vec![74, 75, 76, 77])],
2800 )
2801 .unwrap(),
2802 ],
2803 ),
2804 (
2805 PartitionRange {
2806 start: Timestamp::new(10, unit.into()),
2807 end: Timestamp::new(60, unit.into()),
2808 num_rows: 3,
2809 identifier: 3,
2810 },
2811 vec![
2812 DfRecordBatch::try_new(
2813 schema.clone(),
2814 vec![new_ts_array(unit, vec![58, 59, 60])],
2815 )
2816 .unwrap(),
2817 ],
2818 ),
2819 ];
2820
2821 let expected_output = Some(
2824 DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![77, 76])]).unwrap(),
2825 );
2826
2827 run_test(
2828 3005,
2829 input_ranged_data,
2830 schema.clone(),
2831 SortOptions {
2832 descending: true,
2833 ..Default::default()
2834 },
2835 Some(2),
2836 expected_output,
2837 Some(7), )
2839 .await;
2840
2841 let input_ranged_data = vec![
2843 (
2844 PartitionRange {
2845 start: Timestamp::new(70, unit.into()),
2846 end: Timestamp::new(100, unit.into()),
2847 num_rows: 4,
2848 identifier: 0,
2849 },
2850 vec![
2851 DfRecordBatch::try_new(
2852 schema.clone(),
2853 vec![new_ts_array(unit, vec![96, 97, 98, 99])],
2854 )
2855 .unwrap(),
2856 ],
2857 ),
2858 (
2859 PartitionRange {
2860 start: Timestamp::new(50, unit.into()),
2861 end: Timestamp::new(90, unit.into()),
2862 num_rows: 0,
2863 identifier: 1,
2864 },
2865 vec![
2866 DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![])])
2868 .unwrap(),
2869 ],
2870 ),
2871 (
2872 PartitionRange {
2873 start: Timestamp::new(30, unit.into()),
2874 end: Timestamp::new(70, unit.into()),
2875 num_rows: 0,
2876 identifier: 2,
2877 },
2878 vec![
2879 DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![])])
2881 .unwrap(),
2882 ],
2883 ),
2884 (
2885 PartitionRange {
2886 start: Timestamp::new(10, unit.into()),
2887 end: Timestamp::new(50, unit.into()),
2888 num_rows: 3,
2889 identifier: 3,
2890 },
2891 vec![
2892 DfRecordBatch::try_new(
2893 schema.clone(),
2894 vec![new_ts_array(unit, vec![48, 49, 50])],
2895 )
2896 .unwrap(),
2897 ],
2898 ),
2899 ];
2900
2901 let expected_output = Some(
2904 DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![99, 98])]).unwrap(),
2905 );
2906
2907 run_test(
2908 3006,
2909 input_ranged_data,
2910 schema.clone(),
2911 SortOptions {
2912 descending: true,
2913 ..Default::default()
2914 },
2915 Some(2),
2916 expected_output,
2917 Some(7), )
2919 .await;
2920 }
2921
2922 #[tokio::test]
2926 async fn test_early_stop_check_update_dyn_filter() {
2927 let unit = TimeUnit::Millisecond;
2928 let schema = Arc::new(Schema::new(vec![Field::new(
2929 "ts",
2930 DataType::Timestamp(unit, None),
2931 false,
2932 )]));
2933
2934 let mock_input = Arc::new(MockInputExec::new(vec![vec![]], schema.clone()));
2935 let exec = PartSortExec::try_new(
2936 PhysicalSortExpr {
2937 expr: Arc::new(Column::new("ts", 0)),
2938 options: SortOptions {
2939 descending: false,
2940 ..Default::default()
2941 },
2942 },
2943 Some(3),
2944 vec![vec![
2945 PartitionRange {
2946 start: Timestamp::new(0, unit.into()),
2947 end: Timestamp::new(20, unit.into()),
2948 num_rows: 3,
2949 identifier: 1,
2950 },
2951 PartitionRange {
2952 start: Timestamp::new(10, unit.into()),
2953 end: Timestamp::new(30, unit.into()),
2954 num_rows: 3,
2955 identifier: 1,
2956 },
2957 ]],
2958 mock_input.clone(),
2959 )
2960 .unwrap();
2961
2962 let filter = exec.filter.clone().unwrap();
2963 let input_stream = mock_input
2964 .execute(0, Arc::new(TaskContext::default()))
2965 .unwrap();
2966 let mut stream = PartSortStream::new(
2967 Arc::new(TaskContext::default()),
2968 &exec,
2969 Some(3),
2970 input_stream,
2971 vec![],
2972 0,
2973 Some(filter.clone()),
2974 )
2975 .unwrap();
2976
2977 assert_eq!(filter.read().expr().snapshot_generation(), 1);
2979 let batch =
2980 DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![0, 5, 15])])
2981 .unwrap();
2982 stream.push_buffer(batch).unwrap();
2983
2984 assert_eq!(filter.read().expr().snapshot_generation(), 2);
2986 assert!(!stream.can_stop_early(&schema).unwrap());
2987 assert_eq!(filter.read().expr().snapshot_generation(), 2);
2989
2990 let _ = stream.sort_top_buffer().unwrap();
2991
2992 let batch =
2993 DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![21, 25, 29])])
2994 .unwrap();
2995 stream.push_buffer(batch).unwrap();
2996 assert_eq!(filter.read().expr().snapshot_generation(), 2);
2998 let new = stream.sort_top_buffer().unwrap();
2999 assert_eq!(filter.read().expr().snapshot_generation(), 2);
3001
3002 assert_eq!(new.num_rows(), 0)
3004 }
3005}