1use std::any::Any;
22use std::pin::Pin;
23use std::sync::Arc;
24use std::task::{Context, Poll};
25
26use arrow::array::ArrayRef;
27use arrow::compute::{concat, concat_batches, take_record_batch};
28use arrow_schema::SchemaRef;
29use common_recordbatch::{DfRecordBatch, DfSendableRecordBatchStream};
30use datafusion::common::arrow::compute::sort_to_indices;
31use datafusion::execution::memory_pool::{MemoryConsumer, MemoryReservation};
32use datafusion::execution::{RecordBatchStream, TaskContext};
33use datafusion::physical_plan::execution_plan::CardinalityEffect;
34use datafusion::physical_plan::filter_pushdown::{
35 ChildFilterDescription, FilterDescription, FilterPushdownPhase,
36};
37use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
38use datafusion::physical_plan::{
39 DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, TopK,
40 TopKDynamicFilters,
41};
42use datafusion_common::{DataFusionError, internal_err};
43use datafusion_physical_expr::expressions::{DynamicFilterPhysicalExpr, lit};
44use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
45use futures::{Stream, StreamExt};
46use itertools::Itertools;
47use parking_lot::RwLock;
48use snafu::location;
49use store_api::region_engine::PartitionRange;
50
51use crate::{array_iter_helper, downcast_ts_array};
52
/// Sorts the rows of each [`PartitionRange`] independently, producing one sorted
/// output per range instead of sorting the whole input partition at once.
///
/// Execution partition `i` walks the ranges in `partition_ranges[i]` in order. When an
/// input row falls outside the current range, the rows buffered for that range are
/// sorted and emitted, and the stream advances to the next range.
#[derive(Debug, Clone)]
pub struct PartSortExec {
    /// Sort expression applied within every range.
    expression: PhysicalSortExpr,
    /// Optional row limit applied to each range's sorted output (not to the whole
    /// partition). When set, rows are buffered in a `TopK` heap instead of in full.
    limit: Option<usize>,
    /// Child plan producing the rows to sort.
    input: Arc<dyn ExecutionPlan>,
    /// Metrics of this operator; the streams created by `to_stream` report into it.
    metrics: ExecutionPlanMetricsSet,
    /// Ranges per execution partition, indexed by the partition number passed to
    /// `execute`.
    partition_ranges: Vec<Vec<PartitionRange>>,
    /// Plan properties derived from the input in `PartSortExec::new`.
    properties: PlanProperties,
    /// Dynamic filter updated by the `TopK` buffer. It is present only when `limit` is
    /// set, and is offered to the child as a self filter during post-phase filter
    /// pushdown.
    filter: Option<Arc<RwLock<TopKDynamicFilters>>>,
}
73
74impl PartSortExec {
75 pub fn new(
76 expression: PhysicalSortExpr,
77 limit: Option<usize>,
78 partition_ranges: Vec<Vec<PartitionRange>>,
79 input: Arc<dyn ExecutionPlan>,
80 ) -> Self {
81 let metrics = ExecutionPlanMetricsSet::new();
82 let properties = input.properties();
83 let properties = PlanProperties::new(
84 input.equivalence_properties().clone(),
85 input.output_partitioning().clone(),
86 properties.emission_type,
87 properties.boundedness,
88 );
89
90 let filter = limit
91 .is_some()
92 .then(|| Self::create_filter(expression.expr.clone()));
93
94 Self {
95 expression,
96 limit,
97 input,
98 metrics,
99 partition_ranges,
100 properties,
101 filter,
102 }
103 }
104
105 fn create_filter(expr: Arc<dyn PhysicalExpr>) -> Arc<RwLock<TopKDynamicFilters>> {
107 Arc::new(RwLock::new(TopKDynamicFilters::new(Arc::new(
108 DynamicFilterPhysicalExpr::new(vec![expr], lit(true)),
109 ))))
110 }
111
112 pub fn to_stream(
113 &self,
114 context: Arc<TaskContext>,
115 partition: usize,
116 ) -> datafusion_common::Result<DfSendableRecordBatchStream> {
117 let input_stream: DfSendableRecordBatchStream =
118 self.input.execute(partition, context.clone())?;
119
120 if partition >= self.partition_ranges.len() {
121 internal_err!(
122 "Partition index out of range: {} >= {} at {}",
123 partition,
124 self.partition_ranges.len(),
125 snafu::location!()
126 )?;
127 }
128
129 let df_stream = Box::pin(PartSortStream::new(
130 context,
131 self,
132 self.limit,
133 input_stream,
134 self.partition_ranges[partition].clone(),
135 partition,
136 self.filter.clone(),
137 )?) as _;
138
139 Ok(df_stream)
140 }
141}
142
143impl DisplayAs for PartSortExec {
144 fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145 write!(
146 f,
147 "PartSortExec: expr={} num_ranges={}",
148 self.expression,
149 self.partition_ranges.len(),
150 )?;
151 if let Some(limit) = self.limit {
152 write!(f, " limit={}", limit)?;
153 }
154 Ok(())
155 }
156}
157
158impl ExecutionPlan for PartSortExec {
159 fn name(&self) -> &str {
160 "PartSortExec"
161 }
162
163 fn as_any(&self) -> &dyn Any {
164 self
165 }
166
167 fn schema(&self) -> SchemaRef {
168 self.input.schema()
169 }
170
171 fn properties(&self) -> &PlanProperties {
172 &self.properties
173 }
174
175 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
176 vec![&self.input]
177 }
178
179 fn with_new_children(
180 self: Arc<Self>,
181 children: Vec<Arc<dyn ExecutionPlan>>,
182 ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
183 let new_input = if let Some(first) = children.first() {
184 first
185 } else {
186 internal_err!("No children found")?
187 };
188 Ok(Arc::new(Self::new(
189 self.expression.clone(),
190 self.limit,
191 self.partition_ranges.clone(),
192 new_input.clone(),
193 )))
194 }
195
196 fn execute(
197 &self,
198 partition: usize,
199 context: Arc<TaskContext>,
200 ) -> datafusion_common::Result<DfSendableRecordBatchStream> {
201 self.to_stream(context, partition)
202 }
203
204 fn metrics(&self) -> Option<MetricsSet> {
205 Some(self.metrics.clone_inner())
206 }
207
208 fn benefits_from_input_partitioning(&self) -> Vec<bool> {
214 vec![false]
215 }
216
217 fn cardinality_effect(&self) -> CardinalityEffect {
218 if self.limit.is_none() {
219 CardinalityEffect::Equal
220 } else {
221 CardinalityEffect::LowerEqual
222 }
223 }
224
225 fn gather_filters_for_pushdown(
226 &self,
227 phase: FilterPushdownPhase,
228 parent_filters: Vec<Arc<dyn PhysicalExpr>>,
229 _config: &datafusion::config::ConfigOptions,
230 ) -> datafusion_common::Result<FilterDescription> {
231 if !matches!(phase, FilterPushdownPhase::Post) {
232 return FilterDescription::from_children(parent_filters, &self.children());
233 }
234
235 let mut child = ChildFilterDescription::from_child(&parent_filters, &self.input)?;
236
237 if let Some(filter) = &self.filter {
238 child = child.with_self_filter(filter.read().expr());
239 }
240
241 Ok(FilterDescription::new().with_child(child))
242 }
243
244 fn reset_state(self: Arc<Self>) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
245 let new_filter = self
247 .limit
248 .is_some()
249 .then(|| Self::create_filter(self.expression.expr.clone()));
250
251 Ok(Arc::new(Self {
252 expression: self.expression.clone(),
253 limit: self.limit,
254 input: self.input.clone(),
255 metrics: self.metrics.clone(),
256 partition_ranges: self.partition_ranges.clone(),
257 properties: self.properties.clone(),
258 filter: new_filter,
259 }))
260 }
261}
262
/// Rows buffered for the partition range currently being processed.
enum PartSortBuffer {
    /// No limit: every batch of the current range is kept, then sorted when the range
    /// is flushed.
    All(Vec<DfRecordBatch>),
    /// With a limit: a `TopK` heap plus the total number of rows inserted into it
    /// (used by `is_empty`).
    Top(TopK, usize),
}
271
272impl PartSortBuffer {
273 pub fn is_empty(&self) -> bool {
274 match self {
275 PartSortBuffer::All(v) => v.is_empty(),
276 PartSortBuffer::Top(_, cnt) => *cnt == 0,
277 }
278 }
279}
280
281struct PartSortStream {
282 reservation: MemoryReservation,
284 buffer: PartSortBuffer,
285 expression: PhysicalSortExpr,
286 limit: Option<usize>,
287 produced: usize,
288 input: DfSendableRecordBatchStream,
289 input_complete: bool,
290 schema: SchemaRef,
291 partition_ranges: Vec<PartitionRange>,
292 #[allow(dead_code)] partition: usize,
294 cur_part_idx: usize,
295 evaluating_batch: Option<DfRecordBatch>,
296 metrics: BaselineMetrics,
297 context: Arc<TaskContext>,
298 root_metrics: ExecutionPlanMetricsSet,
299}
300
301impl PartSortStream {
302 fn new(
303 context: Arc<TaskContext>,
304 sort: &PartSortExec,
305 limit: Option<usize>,
306 input: DfSendableRecordBatchStream,
307 partition_ranges: Vec<PartitionRange>,
308 partition: usize,
309 filter: Option<Arc<RwLock<TopKDynamicFilters>>>,
310 ) -> datafusion_common::Result<Self> {
311 let buffer = if let Some(limit) = limit {
312 let Some(filter) = filter else {
313 return internal_err!(
314 "TopKDynamicFilters must be provided when limit is set at {}",
315 snafu::location!()
316 );
317 };
318
319 PartSortBuffer::Top(
320 TopK::try_new(
321 partition,
322 sort.schema().clone(),
323 vec![],
324 [sort.expression.clone()].into(),
325 limit,
326 context.session_config().batch_size(),
327 context.runtime_env(),
328 &sort.metrics,
329 filter,
330 )?,
331 0,
332 )
333 } else {
334 PartSortBuffer::All(Vec::new())
335 };
336
337 Ok(Self {
338 reservation: MemoryConsumer::new("PartSortStream".to_string())
339 .register(&context.runtime_env().memory_pool),
340 buffer,
341 expression: sort.expression.clone(),
342 limit,
343 produced: 0,
344 input,
345 input_complete: false,
346 schema: sort.input.schema(),
347 partition_ranges,
348 partition,
349 cur_part_idx: 0,
350 evaluating_batch: None,
351 metrics: BaselineMetrics::new(&sort.metrics, partition),
352 context,
353 root_metrics: sort.metrics.clone(),
354 })
355 }
356}
357
// Checks that the values at the two row positions in `$min_max_idx` of a timestamp
// array lie inside `$cur_range` (`[start, end)`), returning an internal error otherwise.
//
// Parameters (presumably supplied by `downcast_ts_array!` as `$t, $unit`, followed by
// the arguments listed at the call site; confirm against that macro):
// * `$t`         - arrow primitive type used to downcast `$arr`.
// * `$unit`      - arrow `TimeUnit` of the array, compared with the range bounds' units.
// * `$arr`       - the sort column (`ArrayRef`).
// * `$cur_range` - the current `PartitionRange`.
// * `$min_max_idx` - `(usize, usize)` row positions; either order is accepted.
//
// Must be expanded inside a function returning `datafusion_common::Result`, because
// it uses `?`.
macro_rules! array_check_helper {
    ($t:ty, $unit:expr, $arr:expr, $cur_range:expr, $min_max_idx:expr) => {{
        // Both range bounds must use the same time unit as the array.
        if $cur_range.start.unit().as_arrow_time_unit() != $unit
            || $cur_range.end.unit().as_arrow_time_unit() != $unit
        {
            internal_err!(
                "PartitionRange unit mismatch, expect {:?}, found {:?}",
                $cur_range.start.unit(),
                $unit
            )?;
        }
        // The caller dispatched on the data type, so this downcast is expected to succeed.
        let arr = $arr
            .as_any()
            .downcast_ref::<arrow::array::PrimitiveArray<$t>>()
            .unwrap();

        let min = arr.value($min_max_idx.0);
        let max = arr.value($min_max_idx.1);
        // Normalize so the check works for both ascending and descending output.
        let (min, max) = if min < max{
            (min, max)
        } else {
            (max, min)
        };
        let cur_min = $cur_range.start.value();
        let cur_max = $cur_range.end.value();
        // The range end is exclusive.
        if !(min >= cur_min && max < cur_max) {
            internal_err!(
                "Sort column min/max value out of partition range: sort_column.min_max=[{:?}, {:?}] not in PartitionRange=[{:?}, {:?}]",
                min,
                max,
                cur_min,
                cur_max
            )?;
        }
    }};
}
395
396impl PartSortStream {
397 fn check_in_range(
399 &self,
400 sort_column: &ArrayRef,
401 min_max_idx: (usize, usize),
402 ) -> datafusion_common::Result<()> {
403 if self.cur_part_idx >= self.partition_ranges.len() {
404 internal_err!(
405 "Partition index out of range: {} >= {} at {}",
406 self.cur_part_idx,
407 self.partition_ranges.len(),
408 snafu::location!()
409 )?;
410 }
411 let cur_range = self.partition_ranges[self.cur_part_idx];
412
413 downcast_ts_array!(
414 sort_column.data_type() => (array_check_helper, sort_column, cur_range, min_max_idx),
415 _ => internal_err!(
416 "Unsupported data type for sort column: {:?}",
417 sort_column.data_type()
418 )?,
419 );
420
421 Ok(())
422 }
423
424 fn try_find_next_range(
429 &self,
430 sort_column: &ArrayRef,
431 ) -> datafusion_common::Result<Option<usize>> {
432 if sort_column.is_empty() {
433 return Ok(Some(0));
434 }
435
436 if self.cur_part_idx >= self.partition_ranges.len() {
438 internal_err!(
439 "Partition index out of range: {} >= {} at {}",
440 self.cur_part_idx,
441 self.partition_ranges.len(),
442 snafu::location!()
443 )?;
444 }
445 let cur_range = self.partition_ranges[self.cur_part_idx];
446
447 let sort_column_iter = downcast_ts_array!(
448 sort_column.data_type() => (array_iter_helper, sort_column),
449 _ => internal_err!(
450 "Unsupported data type for sort column: {:?}",
451 sort_column.data_type()
452 )?,
453 );
454
455 for (idx, val) in sort_column_iter {
456 if let Some(val) = val
458 && (val >= cur_range.end.value() || val < cur_range.start.value())
459 {
460 return Ok(Some(idx));
461 }
462 }
463
464 Ok(None)
465 }
466
467 fn push_buffer(&mut self, batch: DfRecordBatch) -> datafusion_common::Result<()> {
468 match &mut self.buffer {
469 PartSortBuffer::All(v) => v.push(batch),
470 PartSortBuffer::Top(top, cnt) => {
471 *cnt += batch.num_rows();
472 top.insert_batch(batch)?;
473 }
474 }
475
476 Ok(())
477 }
478
479 fn sort_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
483 match &mut self.buffer {
484 PartSortBuffer::All(_) => self.sort_all_buffer(),
485 PartSortBuffer::Top(_, _) => self.sort_top_buffer(),
486 }
487 }
488
489 fn sort_all_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
491 let PartSortBuffer::All(buffer) =
492 std::mem::replace(&mut self.buffer, PartSortBuffer::All(Vec::new()))
493 else {
494 unreachable!("buffer type is checked before and should be All variant")
495 };
496
497 if buffer.is_empty() {
498 return Ok(DfRecordBatch::new_empty(self.schema.clone()));
499 }
500 let mut sort_columns = Vec::with_capacity(buffer.len());
501 let mut opt = None;
502 for batch in buffer.iter() {
503 let sort_column = self.expression.evaluate_to_sort_column(batch)?;
504 opt = opt.or(sort_column.options);
505 sort_columns.push(sort_column.values);
506 }
507
508 let sort_column =
509 concat(&sort_columns.iter().map(|a| a.as_ref()).collect_vec()).map_err(|e| {
510 DataFusionError::ArrowError(
511 Box::new(e),
512 Some(format!("Fail to concat sort columns at {}", location!())),
513 )
514 })?;
515
516 let indices = sort_to_indices(&sort_column, opt, self.limit).map_err(|e| {
517 DataFusionError::ArrowError(
518 Box::new(e),
519 Some(format!("Fail to sort to indices at {}", location!())),
520 )
521 })?;
522 if indices.is_empty() {
523 return Ok(DfRecordBatch::new_empty(self.schema.clone()));
524 }
525
526 self.check_in_range(
527 &sort_column,
528 (
529 indices.value(0) as usize,
530 indices.value(indices.len() - 1) as usize,
531 ),
532 )
533 .inspect_err(|_e| {
534 #[cfg(debug_assertions)]
535 common_telemetry::error!(
536 "Fail to check sort column in range at {}, current_idx: {}, num_rows: {}, err: {}",
537 self.partition,
538 self.cur_part_idx,
539 sort_column.len(),
540 _e
541 );
542 })?;
543
544 let total_mem: usize = buffer.iter().map(|r| r.get_array_memory_size()).sum();
546 self.reservation.try_grow(total_mem * 2)?;
547
548 let full_input = concat_batches(&self.schema, &buffer).map_err(|e| {
549 DataFusionError::ArrowError(
550 Box::new(e),
551 Some(format!(
552 "Fail to concat input batches when sorting at {}",
553 location!()
554 )),
555 )
556 })?;
557
558 let sorted = take_record_batch(&full_input, &indices).map_err(|e| {
559 DataFusionError::ArrowError(
560 Box::new(e),
561 Some(format!(
562 "Fail to take result record batch when sorting at {}",
563 location!()
564 )),
565 )
566 })?;
567
568 self.produced += sorted.num_rows();
569 drop(full_input);
570 self.reservation.shrink(2 * total_mem);
572 Ok(sorted)
573 }
574
575 fn sort_top_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
577 let filter = Arc::new(RwLock::new(TopKDynamicFilters::new(Arc::new(
578 DynamicFilterPhysicalExpr::new(vec![], lit(true)),
579 ))));
580 let new_top_buffer = TopK::try_new(
581 self.partition,
582 self.schema().clone(),
583 vec![],
584 [self.expression.clone()].into(),
585 self.limit.unwrap(),
586 self.context.session_config().batch_size(),
587 self.context.runtime_env(),
588 &self.root_metrics,
589 filter,
590 )?;
591 let PartSortBuffer::Top(top_k, _) =
592 std::mem::replace(&mut self.buffer, PartSortBuffer::Top(new_top_buffer, 0))
593 else {
594 unreachable!("buffer type is checked before and should be Top variant")
595 };
596
597 let mut result_stream = top_k.emit()?;
598 let mut placeholder_ctx = std::task::Context::from_waker(futures::task::noop_waker_ref());
599 let mut results = vec![];
600 loop {
602 match result_stream.poll_next_unpin(&mut placeholder_ctx) {
603 Poll::Ready(Some(batch)) => {
604 let batch = batch?;
605 results.push(batch);
606 }
607 Poll::Pending => {
608 #[cfg(debug_assertions)]
609 unreachable!("TopK result stream should always be ready")
610 }
611 Poll::Ready(None) => {
612 break;
613 }
614 }
615 }
616
617 let concat_batch = concat_batches(&self.schema, &results).map_err(|e| {
618 DataFusionError::ArrowError(
619 Box::new(e),
620 Some(format!(
621 "Fail to concat top k result record batch when sorting at {}",
622 location!()
623 )),
624 )
625 })?;
626
627 Ok(concat_batch)
628 }
629
630 fn split_batch(
640 &mut self,
641 batch: DfRecordBatch,
642 ) -> datafusion_common::Result<Option<DfRecordBatch>> {
643 if batch.num_rows() == 0 {
644 return Ok(None);
645 }
646
647 let sort_column = self
648 .expression
649 .expr
650 .evaluate(&batch)?
651 .into_array(batch.num_rows())?;
652
653 let next_range_idx = self.try_find_next_range(&sort_column)?;
654 let Some(idx) = next_range_idx else {
655 self.push_buffer(batch)?;
656 return Ok(None);
658 };
659
660 let this_range = batch.slice(0, idx);
661 let remaining_range = batch.slice(idx, batch.num_rows() - idx);
662 if this_range.num_rows() != 0 {
663 self.push_buffer(this_range)?;
664 }
665 let sorted_batch = self.sort_buffer();
667 self.cur_part_idx += 1;
669 let next_sort_column = sort_column.slice(idx, batch.num_rows() - idx);
670 if self.try_find_next_range(&next_sort_column)?.is_some() {
671 self.evaluating_batch = Some(remaining_range);
674 } else {
675 if remaining_range.num_rows() != 0 {
678 self.push_buffer(remaining_range)?;
679 }
680 }
681
682 sorted_batch.map(|x| if x.num_rows() == 0 { None } else { Some(x) })
683 }
684
685 pub fn poll_next_inner(
686 mut self: Pin<&mut Self>,
687 cx: &mut Context<'_>,
688 ) -> Poll<Option<datafusion_common::Result<DfRecordBatch>>> {
689 loop {
690 if self.input_complete {
692 if self.buffer.is_empty() {
693 return Poll::Ready(None);
694 } else {
695 return Poll::Ready(Some(self.sort_buffer()));
696 }
697 }
698
699 if let Some(evaluating_batch) = self.evaluating_batch.take()
702 && evaluating_batch.num_rows() != 0
703 {
704 if let Some(sorted_batch) = self.split_batch(evaluating_batch)? {
705 return Poll::Ready(Some(Ok(sorted_batch)));
706 } else {
707 continue;
708 }
709 }
710
711 let res = self.input.as_mut().poll_next(cx);
713 match res {
714 Poll::Ready(Some(Ok(batch))) => {
715 if let Some(sorted_batch) = self.split_batch(batch)? {
716 return Poll::Ready(Some(Ok(sorted_batch)));
717 } else {
718 continue;
719 }
720 }
721 Poll::Ready(None) => {
723 self.input_complete = true;
724 continue;
725 }
726 Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
727 Poll::Pending => return Poll::Pending,
728 }
729 }
730 }
731}
732
733impl Stream for PartSortStream {
734 type Item = datafusion_common::Result<DfRecordBatch>;
735
736 fn poll_next(
737 mut self: Pin<&mut Self>,
738 cx: &mut Context<'_>,
739 ) -> Poll<Option<datafusion_common::Result<DfRecordBatch>>> {
740 let result = self.as_mut().poll_next_inner(cx);
741 self.metrics.record_poll(result)
742 }
743}
744
745impl RecordBatchStream for PartSortStream {
746 fn schema(&self) -> SchemaRef {
747 self.schema.clone()
748 }
749}
750
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow::json::ArrayWriter;
    use arrow_schema::{DataType, Field, Schema, SortOptions, TimeUnit};
    use common_time::Timestamp;
    use datafusion_physical_expr::expressions::Column;
    use futures::StreamExt;
    use store_api::region_engine::PartitionRange;

    use super::*;
    use crate::test_util::{MockInputExec, new_ts_array};

    /// Randomized test: generates consecutive partition ranges with sorted random
    /// data, computes the expected per-range sorted (and optionally limited) output,
    /// and compares it with `PartSortExec`'s output. The RNG is seeded, so the run is
    /// deterministic.
    #[tokio::test]
    async fn fuzzy_test() {
        let test_cnt = 100;
        let part_cnt_bound = 100;
        // Upper bound (exclusive) of a range's width.
        let range_size_bound = 100;
        // Upper bound (exclusive) of the shift between consecutive range bounds.
        let range_offset_bound = 100;
        let batch_cnt_bound = 20;
        // Upper bound on rows per generated batch (see `cnt` below).
        let batch_size_bound = 100;

        let mut rng = fastrand::Rng::new();
        rng.seed(1337);

        let mut test_cases = Vec::new();

        for case_id in 0..test_cnt {
            let mut bound_val: Option<i64> = None;
            let descending = rng.bool();
            let nulls_first = rng.bool();
            let opt = SortOptions {
                descending,
                nulls_first,
            };
            let limit = if rng.bool() {
                Some(rng.usize(0..batch_cnt_bound * batch_size_bound))
            } else {
                None
            };
            // NOTE(review): `rng.u8(0..3)` yields only 0..=2, so the `Nanosecond` arm is
            // never taken. Widening the range would change the seeded data, so verify
            // before touching it.
            let unit = match rng.u8(0..3) {
                0 => TimeUnit::Second,
                1 => TimeUnit::Millisecond,
                2 => TimeUnit::Microsecond,
                _ => TimeUnit::Nanosecond,
            };

            let schema = Schema::new(vec![Field::new(
                "ts",
                DataType::Timestamp(unit, None),
                false,
            )]);
            let schema = Arc::new(schema);

            let mut input_ranged_data = vec![];
            let mut output_ranges = vec![];
            let mut output_data = vec![];
            for part_id in 0..rng.usize(0..part_cnt_bound) {
                // Generate the range bounds; they move downward when descending and
                // upward when ascending.
                let (start, end) = if descending {
                    // Each new range ends at or below the previous range's end.
                    let end = bound_val
                        .map(
                            |i| i
                                .checked_sub(rng.i64(0..range_offset_bound))
                                .expect("Bad luck, fuzzy test generate data that will overflow, change seed and try again")
                        )
                        .unwrap_or_else(|| rng.i64(-100000000..100000000));
                    bound_val = Some(end);
                    let start = end - rng.i64(1..range_size_bound);
                    let start = Timestamp::new(start, unit.into());
                    let end = Timestamp::new(end, unit.into());
                    (start, end)
                } else {
                    // Each new range starts at or above the previous range's start.
                    // NOTE(review): the first start is drawn from the full i64 range and
                    // the additions below are unchecked, so they could overflow for some
                    // seeds (unlike the `checked_sub` in the descending branch).
                    let start = bound_val
                        .map(|i| i + rng.i64(0..range_offset_bound))
                        .unwrap_or_else(|| rng.i64(..));
                    bound_val = Some(start);
                    let end = start + rng.i64(1..range_size_bound);
                    let start = Timestamp::new(start, unit.into());
                    let end = Timestamp::new(end, unit.into());
                    (start, end)
                };
                assert!(start < end);

                // Fill the range with batches whose values lie in `[start, end)` and are
                // sorted ascending within each batch.
                let mut per_part_sort_data = vec![];
                let mut batches = vec![];
                for _batch_idx in 0..rng.usize(1..batch_cnt_bound) {
                    let cnt = rng.usize(0..batch_size_bound) + 1;
                    let iter = 0..rng.usize(0..cnt);
                    let mut data_gen = iter
                        .map(|_| rng.i64(start.value()..end.value()))
                        .collect_vec();
                    if data_gen.is_empty() {
                        continue;
                    }
                    data_gen.sort();
                    per_part_sort_data.extend(data_gen.clone());
                    let arr = new_ts_array(unit, data_gen.clone());
                    let batch = DfRecordBatch::try_new(schema.clone(), vec![arr]).unwrap();
                    batches.push(batch);
                }

                let range = PartitionRange {
                    start,
                    end,
                    num_rows: batches.iter().map(|b| b.num_rows()).sum(),
                    identifier: part_id,
                };
                input_ranged_data.push((range, batches));

                output_ranges.push(range);
                if per_part_sort_data.is_empty() {
                    continue;
                }
                output_data.extend_from_slice(&per_part_sort_data);
            }

            // Re-group all generated values by range, in stream order: consume values
            // while they fall inside the current range, then sort each group in the
            // requested direction. Empty groups produce no batch.
            let mut output_data_iter = output_data.iter().peekable();
            let mut output_data = vec![];
            for range in output_ranges.clone() {
                let mut cur_data = vec![];
                while let Some(val) = output_data_iter.peek() {
                    if **val < range.start.value() || **val >= range.end.value() {
                        break;
                    }
                    cur_data.push(*output_data_iter.next().unwrap());
                }

                if cur_data.is_empty() {
                    continue;
                }

                if descending {
                    cur_data.sort_by(|a, b| b.cmp(a));
                } else {
                    cur_data.sort();
                }
                output_data.push(cur_data);
            }

            // Build expected batches, truncating each range's output to `limit`.
            let expected_output = output_data
                .into_iter()
                .map(|a| {
                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, a)]).unwrap()
                })
                .map(|rb| {
                    if let Some(limit) = limit
                        && rb.num_rows() > limit
                    {
                        rb.slice(0, limit)
                    } else {
                        rb
                    }
                })
                .collect_vec();

            test_cases.push((
                case_id,
                unit,
                input_ranged_data,
                schema,
                opt,
                limit,
                expected_output,
            ));
        }

        for (case_id, _unit, input_ranged_data, schema, opt, limit, expected_output) in test_cases {
            run_test(
                case_id,
                input_ranged_data,
                schema,
                opt,
                limit,
                expected_output,
            )
            .await;
        }
    }

    /// Hand-written cases. Each tuple is
    /// `(unit, [((range_start, range_end), batches)], descending, limit, expected_batches)`.
    #[tokio::test]
    async fn simple_case() {
        let testcases = vec![
            (
                TimeUnit::Millisecond,
                vec![
                    ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]]),
                    ((5, 10), vec![vec![5, 6], vec![7, 8]]),
                ],
                false,
                None,
                vec![vec![1, 2, 3, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9]],
            ),
            (
                TimeUnit::Millisecond,
                vec![
                    ((5, 10), vec![vec![5, 6], vec![7, 8, 9]]),
                    ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
                ],
                true,
                None,
                vec![vec![9, 8, 7, 6, 5], vec![8, 7, 6, 5, 4, 3, 2, 1]],
            ),
            (
                TimeUnit::Millisecond,
                vec![
                    ((5, 10), vec![]),
                    ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
                ],
                true,
                None,
                vec![vec![8, 7, 6, 5, 4, 3, 2, 1]],
            ),
            (
                TimeUnit::Millisecond,
                vec![
                    ((15, 20), vec![vec![17, 18, 19]]),
                    ((10, 15), vec![]),
                    ((5, 10), vec![]),
                    ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
                ],
                true,
                None,
                vec![vec![19, 18, 17], vec![8, 7, 6, 5, 4, 3, 2, 1]],
            ),
            (
                TimeUnit::Millisecond,
                vec![
                    ((15, 20), vec![]),
                    ((10, 15), vec![]),
                    ((5, 10), vec![]),
                    ((0, 10), vec![]),
                ],
                true,
                None,
                vec![],
            ),
            // A single batch spanning several ranges must be split at each boundary.
            (
                TimeUnit::Millisecond,
                vec![
                    (
                        (15, 20),
                        vec![vec![15, 17, 19, 10, 11, 12, 5, 6, 7, 8, 9, 1, 2, 3, 4]],
                    ),
                    ((10, 15), vec![]),
                    ((5, 10), vec![]),
                    ((0, 10), vec![]),
                ],
                true,
                None,
                vec![
                    vec![19, 17, 15],
                    vec![12, 11, 10],
                    vec![9, 8, 7, 6, 5],
                    vec![4, 3, 2, 1],
                ],
            ),
            // Same input with a limit: the limit applies to each range separately.
            (
                TimeUnit::Millisecond,
                vec![
                    (
                        (15, 20),
                        vec![vec![15, 17, 19, 10, 11, 12, 5, 6, 7, 8, 9, 1, 2, 3, 4]],
                    ),
                    ((10, 15), vec![]),
                    ((5, 10), vec![]),
                    ((0, 10), vec![]),
                ],
                true,
                Some(2),
                vec![vec![19, 17], vec![12, 11], vec![9, 8], vec![4, 3]],
            ),
        ];

        for (identifier, (unit, input_ranged_data, descending, limit, expected_output)) in
            testcases.into_iter().enumerate()
        {
            let schema = Schema::new(vec![Field::new(
                "ts",
                DataType::Timestamp(unit, None),
                false,
            )]);
            let schema = Arc::new(schema);
            let opt = SortOptions {
                descending,
                ..Default::default()
            };

            let input_ranged_data = input_ranged_data
                .into_iter()
                .map(|(range, data)| {
                    let part = PartitionRange {
                        start: Timestamp::new(range.0, unit.into()),
                        end: Timestamp::new(range.1, unit.into()),
                        num_rows: data.iter().map(|b| b.len()).sum(),
                        identifier,
                    };

                    let batches = data
                        .into_iter()
                        .map(|b| {
                            let arr = new_ts_array(unit, b);
                            DfRecordBatch::try_new(schema.clone(), vec![arr]).unwrap()
                        })
                        .collect_vec();
                    (part, batches)
                })
                .collect_vec();

            let expected_output = expected_output
                .into_iter()
                .map(|a| {
                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, a)]).unwrap()
                })
                .collect_vec();

            run_test(
                identifier,
                input_ranged_data,
                schema.clone(),
                opt,
                limit,
                expected_output,
            )
            .await;
        }
    }

    /// Runs a `PartSortExec` with a single execution partition over the given ranges
    /// and batches, and asserts that its output equals `expected_output`.
    ///
    /// An empty batch is appended after each range's batches before they are fed
    /// through `MockInputExec`. On mismatch, the ranges and both outputs (from the
    /// first differing batch onward, as JSON) are printed and the test panics.
    #[allow(clippy::print_stdout)]
    async fn run_test(
        case_id: usize,
        input_ranged_data: Vec<(PartitionRange, Vec<DfRecordBatch>)>,
        schema: SchemaRef,
        opt: SortOptions,
        limit: Option<usize>,
        expected_output: Vec<DfRecordBatch>,
    ) {
        // Guard against malformed expectations: no expected batch may exceed the limit.
        for rb in &expected_output {
            if let Some(limit) = limit {
                assert!(
                    rb.num_rows() <= limit,
                    "Expect row count in expected output's batch({}) <= limit({})",
                    rb.num_rows(),
                    limit
                );
            }
        }
        let (ranges, batches): (Vec<_>, Vec<_>) = input_ranged_data.clone().into_iter().unzip();

        let batches = batches
            .into_iter()
            .flat_map(|mut cols| {
                cols.push(DfRecordBatch::new_empty(schema.clone()));
                cols
            })
            .collect_vec();
        let mock_input = MockInputExec::new(batches, schema.clone());

        let exec = PartSortExec::new(
            PhysicalSortExpr {
                expr: Arc::new(Column::new("ts", 0)),
                options: opt,
            },
            limit,
            vec![ranges.clone()],
            Arc::new(mock_input),
        );

        let exec_stream = exec.execute(0, Arc::new(TaskContext::default())).unwrap();

        let real_output = exec_stream.map(|r| r.unwrap()).collect::<Vec<_>>().await;
        if real_output != expected_output {
            // NOTE(review): if one output is a strict prefix of the other, this stays 0.
            let mut first_diff = 0;
            for (idx, (lhs, rhs)) in real_output.iter().zip(expected_output.iter()).enumerate() {
                if lhs != rhs {
                    first_diff = idx;
                    break;
                }
            }
            println!("first diff batch at {}", first_diff);
            println!(
                "ranges: {:?}",
                ranges
                    .into_iter()
                    .map(|r| (r.start.to_chrono_datetime(), r.end.to_chrono_datetime()))
                    .enumerate()
                    .collect::<Vec<_>>()
            );

            let mut full_msg = String::new();
            {
                let mut buf = Vec::with_capacity(10 * real_output.len());
                for batch in real_output.iter().skip(first_diff) {
                    let mut rb_json: Vec<u8> = Vec::new();
                    let mut writer = ArrayWriter::new(&mut rb_json);
                    writer.write(batch).unwrap();
                    writer.finish().unwrap();
                    buf.append(&mut rb_json);
                    buf.push(b',');
                }
                let buf = String::from_utf8_lossy(&buf);
                full_msg += &format!("\ncase_id:{case_id}, real_output \n{buf}\n");
            }
            {
                let mut buf = Vec::with_capacity(10 * real_output.len());
                for batch in expected_output.iter().skip(first_diff) {
                    let mut rb_json: Vec<u8> = Vec::new();
                    let mut writer = ArrayWriter::new(&mut rb_json);
                    writer.write(batch).unwrap();
                    writer.finish().unwrap();
                    buf.append(&mut rb_json);
                    buf.push(b',');
                }
                let buf = String::from_utf8_lossy(&buf);
                full_msg += &format!("case_id:{case_id}, expected_output \n{buf}");
            }
            panic!(
                "case_{} failed, opt: {:?},\n real output has {} batches, {} rows, expected has {} batches with {} rows\nfull msg: {}",
                case_id,
                opt,
                real_output.len(),
                real_output.iter().map(|x| x.num_rows()).sum::<usize>(),
                expected_output.len(),
                expected_output.iter().map(|x| x.num_rows()).sum::<usize>(),
                full_msg
            );
        }
    }
}