1use std::any::Any;
22use std::pin::Pin;
23use std::sync::Arc;
24use std::task::{Context, Poll};
25
26use arrow::array::ArrayRef;
27use arrow::compute::{concat, concat_batches, take_record_batch};
28use arrow_schema::SchemaRef;
29use common_recordbatch::{DfRecordBatch, DfSendableRecordBatchStream};
30use datafusion::common::arrow::compute::sort_to_indices;
31use datafusion::execution::memory_pool::{MemoryConsumer, MemoryReservation};
32use datafusion::execution::{RecordBatchStream, TaskContext};
33use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
34use datafusion::physical_plan::{
35 DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, TopK,
36};
37use datafusion_common::{internal_err, DataFusionError};
38use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr};
39use futures::{Stream, StreamExt};
40use itertools::Itertools;
41use snafu::location;
42use store_api::region_engine::PartitionRange;
43
44use crate::{array_iter_helper, downcast_ts_array};
45
/// Physical plan node that sorts the rows of each [`PartitionRange`] of its input
/// independently.
///
/// Each output partition walks its list of ranges in order. Rows are buffered
/// until a value outside the current range is seen, then the buffer is sorted and
/// emitted as one batch. Rows are never merged across ranges.
#[derive(Debug, Clone)]
pub struct PartSortExec {
    /// Sort expression evaluated on every input batch; its options control direction.
    expression: PhysicalSortExpr,
    /// Optional row limit applied per range: when set, each range keeps only its
    /// top `limit` rows (via `TopK`).
    limit: Option<usize>,
    /// Child plan whose output is sorted.
    input: Arc<dyn ExecutionPlan>,
    /// Metrics shared by all streams created from this node.
    metrics: ExecutionPlanMetricsSet,
    /// Ranges per partition: `partition_ranges[i]` is the ordered list of ranges
    /// that partition `i` of the input is expected to produce.
    partition_ranges: Vec<Vec<PartitionRange>>,
    /// Plan properties, copied from the input in [`PartSortExec::new`].
    properties: PlanProperties,
}
62
63impl PartSortExec {
64 pub fn new(
65 expression: PhysicalSortExpr,
66 limit: Option<usize>,
67 partition_ranges: Vec<Vec<PartitionRange>>,
68 input: Arc<dyn ExecutionPlan>,
69 ) -> Self {
70 let metrics = ExecutionPlanMetricsSet::new();
71 let properties = input.properties();
72 let properties = PlanProperties::new(
73 input.equivalence_properties().clone(),
74 input.output_partitioning().clone(),
75 properties.emission_type,
76 properties.boundedness,
77 );
78
79 Self {
80 expression,
81 limit,
82 input,
83 metrics,
84 partition_ranges,
85 properties,
86 }
87 }
88
89 pub fn to_stream(
90 &self,
91 context: Arc<TaskContext>,
92 partition: usize,
93 ) -> datafusion_common::Result<DfSendableRecordBatchStream> {
94 let input_stream: DfSendableRecordBatchStream =
95 self.input.execute(partition, context.clone())?;
96
97 if partition >= self.partition_ranges.len() {
98 internal_err!(
99 "Partition index out of range: {} >= {} at {}",
100 partition,
101 self.partition_ranges.len(),
102 snafu::location!()
103 )?;
104 }
105
106 let df_stream = Box::pin(PartSortStream::new(
107 context,
108 self,
109 self.limit,
110 input_stream,
111 self.partition_ranges[partition].clone(),
112 partition,
113 )?) as _;
114
115 Ok(df_stream)
116 }
117}
118
119impl DisplayAs for PartSortExec {
120 fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121 write!(
122 f,
123 "PartSortExec: expr={} num_ranges={}",
124 self.expression,
125 self.partition_ranges.len(),
126 )?;
127 if let Some(limit) = self.limit {
128 write!(f, " limit={}", limit)?;
129 }
130 Ok(())
131 }
132}
133
134impl ExecutionPlan for PartSortExec {
135 fn name(&self) -> &str {
136 "PartSortExec"
137 }
138
139 fn as_any(&self) -> &dyn Any {
140 self
141 }
142
143 fn schema(&self) -> SchemaRef {
144 self.input.schema()
145 }
146
147 fn properties(&self) -> &PlanProperties {
148 &self.properties
149 }
150
151 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
152 vec![&self.input]
153 }
154
155 fn with_new_children(
156 self: Arc<Self>,
157 children: Vec<Arc<dyn ExecutionPlan>>,
158 ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
159 let new_input = if let Some(first) = children.first() {
160 first
161 } else {
162 internal_err!("No children found")?
163 };
164 Ok(Arc::new(Self::new(
165 self.expression.clone(),
166 self.limit,
167 self.partition_ranges.clone(),
168 new_input.clone(),
169 )))
170 }
171
172 fn execute(
173 &self,
174 partition: usize,
175 context: Arc<TaskContext>,
176 ) -> datafusion_common::Result<DfSendableRecordBatchStream> {
177 self.to_stream(context, partition)
178 }
179
180 fn metrics(&self) -> Option<MetricsSet> {
181 Some(self.metrics.clone_inner())
182 }
183
184 fn benefits_from_input_partitioning(&self) -> Vec<bool> {
190 vec![false]
191 }
192}
193
/// Rows accumulated for the partition range that is currently being processed.
enum PartSortBuffer {
    /// Used when there is no limit: every batch is kept and the whole set is
    /// sorted once the range is complete.
    All(Vec<DfRecordBatch>),
    /// Used when a limit is set: rows are inserted into a [`TopK`] buffer. The
    /// `usize` counts rows inserted so far and is read by `is_empty`.
    Top(TopK, usize),
}
202
203impl PartSortBuffer {
204 pub fn is_empty(&self) -> bool {
205 match self {
206 PartSortBuffer::All(v) => v.is_empty(),
207 PartSortBuffer::Top(_, cnt) => *cnt == 0,
208 }
209 }
210}
211
/// Stream that sorts one input partition range by range (see `PartSortExec`).
struct PartSortStream {
    /// Memory reservation; `sort_all_buffer` grows it temporarily while sorting.
    reservation: MemoryReservation,
    /// Rows buffered for the range at `cur_part_idx`.
    buffer: PartSortBuffer,
    /// Sort expression and options.
    expression: PhysicalSortExpr,
    /// Per-range row limit; `Some` exactly when `buffer` is the `Top` variant.
    limit: Option<usize>,
    /// Rows emitted by `sort_all_buffer` so far.
    produced: usize,
    /// Input partition stream.
    input: DfSendableRecordBatchStream,
    /// Set once `input` has returned `None`.
    input_complete: bool,
    /// Output schema (same as the input plan's schema).
    schema: SchemaRef,
    /// Ordered ranges this partition is expected to produce.
    partition_ranges: Vec<PartitionRange>,
    // NOTE(review): `partition` is passed to `TopK::try_new` in `sort_top_buffer`,
    // so this `dead_code` allow looks stale — confirm and drop if so.
    #[allow(dead_code)]
    partition: usize,
    /// Index into `partition_ranges` of the range currently being accumulated.
    cur_part_idx: usize,
    /// Leftover rows that still cross a range boundary and must be split again
    /// before more input is read.
    evaluating_batch: Option<DfRecordBatch>,
    /// Baseline metrics for this partition.
    metrics: BaselineMetrics,
    /// Task context, kept to rebuild the `TopK` buffer after each range.
    context: Arc<TaskContext>,
    /// Plan-level metrics set, kept to rebuild the `TopK` buffer after each range.
    root_metrics: ExecutionPlanMetricsSet,
}
231
232impl PartSortStream {
233 fn new(
234 context: Arc<TaskContext>,
235 sort: &PartSortExec,
236 limit: Option<usize>,
237 input: DfSendableRecordBatchStream,
238 partition_ranges: Vec<PartitionRange>,
239 partition: usize,
240 ) -> datafusion_common::Result<Self> {
241 let buffer = if let Some(limit) = limit {
242 PartSortBuffer::Top(
243 TopK::try_new(
244 partition,
245 sort.schema().clone(),
246 LexOrdering::new(vec![sort.expression.clone()]),
247 limit,
248 context.session_config().batch_size(),
249 context.runtime_env(),
250 &sort.metrics,
251 )?,
252 0,
253 )
254 } else {
255 PartSortBuffer::All(Vec::new())
256 };
257
258 Ok(Self {
259 reservation: MemoryConsumer::new("PartSortStream".to_string())
260 .register(&context.runtime_env().memory_pool),
261 buffer,
262 expression: sort.expression.clone(),
263 limit,
264 produced: 0,
265 input,
266 input_complete: false,
267 schema: sort.input.schema(),
268 partition_ranges,
269 partition,
270 cur_part_idx: 0,
271 evaluating_batch: None,
272 metrics: BaselineMetrics::new(&sort.metrics, partition),
273 context,
274 root_metrics: sort.metrics.clone(),
275 })
276 }
277}
278
/// Checks that the sorted bounds of a timestamp array lie inside a `PartitionRange`.
///
/// * `$t` — arrow primitive type of `$arr`
/// * `$unit` — arrow time unit of `$arr`
/// * `$arr` — the sort column
/// * `$cur_range` — range to check against (end exclusive)
/// * `$min_max_idx` — `(first, last)` indices into `$arr` of the sorted output
///
/// The two values are reordered before comparing, so the check works for both
/// ascending and descending sorts. On a unit mismatch or an out-of-range value
/// the macro fails with an internal error, propagated with `?` from the
/// enclosing function.
macro_rules! array_check_helper {
    ($t:ty, $unit:expr, $arr:expr, $cur_range:expr, $min_max_idx:expr) => {{
        if $cur_range.start.unit().as_arrow_time_unit() != $unit
            || $cur_range.end.unit().as_arrow_time_unit() != $unit
        {
            // The column's unit is the expectation; report both range units,
            // because either of them may be the one that differs.
            internal_err!(
                "PartitionRange unit mismatch, expect {:?}, found start={:?}, end={:?}",
                $unit,
                $cur_range.start.unit(),
                $cur_range.end.unit()
            )?;
        }
        // Presumably safe because callers dispatch on the data type first
        // (via `downcast_ts_array!`) — confirm when adding new call sites.
        let arr = $arr
            .as_any()
            .downcast_ref::<arrow::array::PrimitiveArray<$t>>()
            .unwrap();

        let min = arr.value($min_max_idx.0);
        let max = arr.value($min_max_idx.1);
        let (min, max) = if min < max { (min, max) } else { (max, min) };
        let cur_min = $cur_range.start.value();
        let cur_max = $cur_range.end.value();
        if !(min >= cur_min && max < cur_max) {
            internal_err!(
                "Sort column min/max value out of partition range: sort_column.min_max=[{:?}, {:?}] not in PartitionRange=[{:?}, {:?}]",
                min,
                max,
                cur_min,
                cur_max
            )?;
        }
    }};
}
316
317impl PartSortStream {
318 fn check_in_range(
320 &self,
321 sort_column: &ArrayRef,
322 min_max_idx: (usize, usize),
323 ) -> datafusion_common::Result<()> {
324 if self.cur_part_idx >= self.partition_ranges.len() {
325 internal_err!(
326 "Partition index out of range: {} >= {} at {}",
327 self.cur_part_idx,
328 self.partition_ranges.len(),
329 snafu::location!()
330 )?;
331 }
332 let cur_range = self.partition_ranges[self.cur_part_idx];
333
334 downcast_ts_array!(
335 sort_column.data_type() => (array_check_helper, sort_column, cur_range, min_max_idx),
336 _ => internal_err!(
337 "Unsupported data type for sort column: {:?}",
338 sort_column.data_type()
339 )?,
340 );
341
342 Ok(())
343 }
344
345 fn try_find_next_range(
350 &self,
351 sort_column: &ArrayRef,
352 ) -> datafusion_common::Result<Option<usize>> {
353 if sort_column.is_empty() {
354 return Ok(Some(0));
355 }
356
357 if self.cur_part_idx >= self.partition_ranges.len() {
359 internal_err!(
360 "Partition index out of range: {} >= {} at {}",
361 self.cur_part_idx,
362 self.partition_ranges.len(),
363 snafu::location!()
364 )?;
365 }
366 let cur_range = self.partition_ranges[self.cur_part_idx];
367
368 let sort_column_iter = downcast_ts_array!(
369 sort_column.data_type() => (array_iter_helper, sort_column),
370 _ => internal_err!(
371 "Unsupported data type for sort column: {:?}",
372 sort_column.data_type()
373 )?,
374 );
375
376 for (idx, val) in sort_column_iter {
377 if let Some(val) = val {
379 if val >= cur_range.end.value() || val < cur_range.start.value() {
380 return Ok(Some(idx));
381 }
382 }
383 }
384
385 Ok(None)
386 }
387
388 fn push_buffer(&mut self, batch: DfRecordBatch) -> datafusion_common::Result<()> {
389 match &mut self.buffer {
390 PartSortBuffer::All(v) => v.push(batch),
391 PartSortBuffer::Top(top, cnt) => {
392 *cnt += batch.num_rows();
393 top.insert_batch(batch)?;
394 }
395 }
396
397 Ok(())
398 }
399
400 fn sort_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
404 match &mut self.buffer {
405 PartSortBuffer::All(_) => self.sort_all_buffer(),
406 PartSortBuffer::Top(_, _) => self.sort_top_buffer(),
407 }
408 }
409
410 fn sort_all_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
412 let PartSortBuffer::All(buffer) =
413 std::mem::replace(&mut self.buffer, PartSortBuffer::All(Vec::new()))
414 else {
415 unreachable!("buffer type is checked before and should be All variant")
416 };
417
418 if buffer.is_empty() {
419 return Ok(DfRecordBatch::new_empty(self.schema.clone()));
420 }
421 let mut sort_columns = Vec::with_capacity(buffer.len());
422 let mut opt = None;
423 for batch in buffer.iter() {
424 let sort_column = self.expression.evaluate_to_sort_column(batch)?;
425 opt = opt.or(sort_column.options);
426 sort_columns.push(sort_column.values);
427 }
428
429 let sort_column =
430 concat(&sort_columns.iter().map(|a| a.as_ref()).collect_vec()).map_err(|e| {
431 DataFusionError::ArrowError(
432 e,
433 Some(format!("Fail to concat sort columns at {}", location!())),
434 )
435 })?;
436
437 let indices = sort_to_indices(&sort_column, opt, self.limit).map_err(|e| {
438 DataFusionError::ArrowError(
439 e,
440 Some(format!("Fail to sort to indices at {}", location!())),
441 )
442 })?;
443 if indices.is_empty() {
444 return Ok(DfRecordBatch::new_empty(self.schema.clone()));
445 }
446
447 self.check_in_range(
448 &sort_column,
449 (
450 indices.value(0) as usize,
451 indices.value(indices.len() - 1) as usize,
452 ),
453 )
454 .inspect_err(|_e| {
455 #[cfg(debug_assertions)]
456 common_telemetry::error!(
457 "Fail to check sort column in range at {}, current_idx: {}, num_rows: {}, err: {}",
458 self.partition,
459 self.cur_part_idx,
460 sort_column.len(),
461 _e
462 );
463 })?;
464
465 let total_mem: usize = buffer.iter().map(|r| r.get_array_memory_size()).sum();
467 self.reservation.try_grow(total_mem * 2)?;
468
469 let full_input = concat_batches(&self.schema, &buffer).map_err(|e| {
470 DataFusionError::ArrowError(
471 e,
472 Some(format!(
473 "Fail to concat input batches when sorting at {}",
474 location!()
475 )),
476 )
477 })?;
478
479 let sorted = take_record_batch(&full_input, &indices).map_err(|e| {
480 DataFusionError::ArrowError(
481 e,
482 Some(format!(
483 "Fail to take result record batch when sorting at {}",
484 location!()
485 )),
486 )
487 })?;
488
489 self.produced += sorted.num_rows();
490 drop(full_input);
491 self.reservation.shrink(2 * total_mem);
493 Ok(sorted)
494 }
495
496 fn sort_top_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
498 let new_top_buffer = TopK::try_new(
499 self.partition,
500 self.schema().clone(),
501 LexOrdering::new(vec![self.expression.clone()]),
502 self.limit.unwrap(),
503 self.context.session_config().batch_size(),
504 self.context.runtime_env(),
505 &self.root_metrics,
506 )?;
507 let PartSortBuffer::Top(top_k, _) =
508 std::mem::replace(&mut self.buffer, PartSortBuffer::Top(new_top_buffer, 0))
509 else {
510 unreachable!("buffer type is checked before and should be Top variant")
511 };
512
513 let mut result_stream = top_k.emit()?;
514 let mut placeholder_ctx = std::task::Context::from_waker(futures::task::noop_waker_ref());
515 let mut results = vec![];
516 loop {
518 match result_stream.poll_next_unpin(&mut placeholder_ctx) {
519 Poll::Ready(Some(batch)) => {
520 let batch = batch?;
521 results.push(batch);
522 }
523 Poll::Pending => {
524 #[cfg(debug_assertions)]
525 unreachable!("TopK result stream should always be ready")
526 }
527 Poll::Ready(None) => {
528 break;
529 }
530 }
531 }
532
533 let concat_batch = concat_batches(&self.schema, &results).map_err(|e| {
534 DataFusionError::ArrowError(
535 e,
536 Some(format!(
537 "Fail to concat top k result record batch when sorting at {}",
538 location!()
539 )),
540 )
541 })?;
542
543 Ok(concat_batch)
544 }
545
546 fn split_batch(
556 &mut self,
557 batch: DfRecordBatch,
558 ) -> datafusion_common::Result<Option<DfRecordBatch>> {
559 if batch.num_rows() == 0 {
560 return Ok(None);
561 }
562
563 let sort_column = self
564 .expression
565 .expr
566 .evaluate(&batch)?
567 .into_array(batch.num_rows())?;
568
569 let next_range_idx = self.try_find_next_range(&sort_column)?;
570 let Some(idx) = next_range_idx else {
571 self.push_buffer(batch)?;
572 return Ok(None);
574 };
575
576 let this_range = batch.slice(0, idx);
577 let remaining_range = batch.slice(idx, batch.num_rows() - idx);
578 if this_range.num_rows() != 0 {
579 self.push_buffer(this_range)?;
580 }
581 let sorted_batch = self.sort_buffer();
583 self.cur_part_idx += 1;
585 let next_sort_column = sort_column.slice(idx, batch.num_rows() - idx);
586 if self.try_find_next_range(&next_sort_column)?.is_some() {
587 self.evaluating_batch = Some(remaining_range);
590 } else {
591 if remaining_range.num_rows() != 0 {
594 self.push_buffer(remaining_range)?;
595 }
596 }
597
598 sorted_batch.map(|x| if x.num_rows() == 0 { None } else { Some(x) })
599 }
600
601 pub fn poll_next_inner(
602 mut self: Pin<&mut Self>,
603 cx: &mut Context<'_>,
604 ) -> Poll<Option<datafusion_common::Result<DfRecordBatch>>> {
605 loop {
606 if self.input_complete {
608 if self.buffer.is_empty() {
609 return Poll::Ready(None);
610 } else {
611 return Poll::Ready(Some(self.sort_buffer()));
612 }
613 }
614
615 if let Some(evaluating_batch) = self.evaluating_batch.take()
618 && evaluating_batch.num_rows() != 0
619 {
620 if let Some(sorted_batch) = self.split_batch(evaluating_batch)? {
621 return Poll::Ready(Some(Ok(sorted_batch)));
622 } else {
623 continue;
624 }
625 }
626
627 let res = self.input.as_mut().poll_next(cx);
629 match res {
630 Poll::Ready(Some(Ok(batch))) => {
631 if let Some(sorted_batch) = self.split_batch(batch)? {
632 return Poll::Ready(Some(Ok(sorted_batch)));
633 } else {
634 continue;
635 }
636 }
637 Poll::Ready(None) => {
639 self.input_complete = true;
640 continue;
641 }
642 Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
643 Poll::Pending => return Poll::Pending,
644 }
645 }
646 }
647}
648
649impl Stream for PartSortStream {
650 type Item = datafusion_common::Result<DfRecordBatch>;
651
652 fn poll_next(
653 mut self: Pin<&mut Self>,
654 cx: &mut Context<'_>,
655 ) -> Poll<Option<datafusion_common::Result<DfRecordBatch>>> {
656 let result = self.as_mut().poll_next_inner(cx);
657 self.metrics.record_poll(result)
658 }
659}
660
661impl RecordBatchStream for PartSortStream {
662 fn schema(&self) -> SchemaRef {
663 self.schema.clone()
664 }
665}
666
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow::json::ArrayWriter;
    use arrow_schema::{DataType, Field, Schema, SortOptions, TimeUnit};
    use common_time::Timestamp;
    use datafusion_physical_expr::expressions::Column;
    use futures::StreamExt;
    use store_api::region_engine::PartitionRange;

    use super::*;
    use crate::test_util::{new_ts_array, MockInputExec};

    /// Randomized test.
    ///
    /// Builds `test_cnt` cases. Each case is a sequence of partition ranges filled
    /// with random timestamps (ascending batches). The expected output is computed
    /// in plain Rust: rows grouped per range, sorted in the requested direction,
    /// empty groups dropped, and each group truncated to `limit`. That result is
    /// compared with `PartSortExec`'s output.
    #[tokio::test]
    async fn fuzzy_test() {
        let test_cnt = 100;
        let part_cnt_bound = 100;
        // Exclusive upper bound of a range's width.
        let range_size_bound = 100;
        // Exclusive upper bound of the shift between consecutive range bounds.
        let range_offset_bound = 100;
        let batch_cnt_bound = 20;
        // Upper bound on rows per generated batch.
        let batch_size_bound = 100;

        let mut rng = fastrand::Rng::new();
        // Fixed seed so the generated cases are reproducible.
        rng.seed(1337);

        let mut test_cases = Vec::new();

        for case_id in 0..test_cnt {
            let mut bound_val: Option<i64> = None;
            let descending = rng.bool();
            let nulls_first = rng.bool();
            let opt = SortOptions {
                descending,
                nulls_first,
            };
            let limit = if rng.bool() {
                Some(rng.usize(0..batch_cnt_bound * batch_size_bound))
            } else {
                None
            };
            // NOTE(review): `0..3` never yields 3, so the `Nanosecond` arm is never
            // taken. Widening to `0..4` would add coverage, but it would also change
            // the cases generated from the fixed seed.
            let unit = match rng.u8(0..3) {
                0 => TimeUnit::Second,
                1 => TimeUnit::Millisecond,
                2 => TimeUnit::Microsecond,
                _ => TimeUnit::Nanosecond,
            };

            let schema = Schema::new(vec![Field::new(
                "ts",
                DataType::Timestamp(unit, None),
                false,
            )]);
            let schema = Arc::new(schema);

            let mut input_ranged_data = vec![];
            let mut output_ranges = vec![];
            let mut output_data = vec![];
            for part_id in 0..rng.usize(0..part_cnt_bound) {
                // Descending: each range's end is at or below the previous range's end.
                let (start, end) = if descending {
                    let end = bound_val
                        .map(
                            |i| i
                            .checked_sub(rng.i64(0..range_offset_bound))
                            .expect("Bad luck, fuzzy test generate data that will overflow, change seed and try again")
                        )
                        .unwrap_or_else(|| rng.i64(-100000000..100000000));
                    bound_val = Some(end);
                    let start = end - rng.i64(1..range_size_bound);
                    let start = Timestamp::new(start, unit.into());
                    let end = Timestamp::new(end, unit.into());
                    (start, end)
                } else {
                    // Ascending: each range's start is at or above the previous
                    // range's start.
                    // NOTE(review): unlike the descending branch these additions are
                    // unchecked; a first `rng.i64(..)` close to `i64::MAX` would overflow.
                    let start = bound_val
                        .map(|i| i + rng.i64(0..range_offset_bound))
                        .unwrap_or_else(|| rng.i64(..));
                    bound_val = Some(start);
                    let end = start + rng.i64(1..range_size_bound);
                    let start = Timestamp::new(start, unit.into());
                    let end = Timestamp::new(end, unit.into());
                    (start, end)
                };
                assert!(start < end);

                let mut per_part_sort_data = vec![];
                let mut batches = vec![];
                for _batch_idx in 0..rng.usize(1..batch_cnt_bound) {
                    let cnt = rng.usize(0..batch_size_bound) + 1;
                    let iter = 0..rng.usize(0..cnt);
                    let mut data_gen = iter
                        .map(|_| rng.i64(start.value()..end.value()))
                        .collect_vec();
                    if data_gen.is_empty() {
                        // No rows generated for this batch; do not emit an empty batch.
                        continue;
                    }
                    // Every generated input batch is ascending within itself.
                    data_gen.sort();
                    per_part_sort_data.extend(data_gen.clone());
                    let arr = new_ts_array(unit, data_gen.clone());
                    let batch = DfRecordBatch::try_new(schema.clone(), vec![arr]).unwrap();
                    batches.push(batch);
                }

                let range = PartitionRange {
                    start,
                    end,
                    num_rows: batches.iter().map(|b| b.num_rows()).sum(),
                    identifier: part_id,
                };
                input_ranged_data.push((range, batches));

                output_ranges.push(range);
                if per_part_sort_data.is_empty() {
                    continue;
                }
                output_data.extend_from_slice(&per_part_sort_data);
            }

            // Re-group the flattened rows: walk the ranges in order, assigning
            // consecutive values to the current range until one falls outside
            // `[start, end)`.
            let mut output_data_iter = output_data.iter().peekable();
            let mut output_data = vec![];
            for range in output_ranges.clone() {
                let mut cur_data = vec![];
                while let Some(val) = output_data_iter.peek() {
                    if **val < range.start.value() || **val >= range.end.value() {
                        break;
                    }
                    cur_data.push(*output_data_iter.next().unwrap());
                }

                if cur_data.is_empty() {
                    continue;
                }

                if descending {
                    cur_data.sort_by(|a, b| b.cmp(a));
                } else {
                    cur_data.sort();
                }
                output_data.push(cur_data);
            }

            // One expected batch per non-empty range, truncated to `limit`.
            let expected_output = output_data
                .into_iter()
                .map(|a| {
                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, a)]).unwrap()
                })
                .map(|rb| {
                    if let Some(limit) = limit
                        && rb.num_rows() > limit
                    {
                        rb.slice(0, limit)
                    } else {
                        rb
                    }
                })
                .collect_vec();

            test_cases.push((
                case_id,
                unit,
                input_ranged_data,
                schema,
                opt,
                limit,
                expected_output,
            ));
        }

        for (case_id, _unit, input_ranged_data, schema, opt, limit, expected_output) in test_cases {
            run_test(
                case_id,
                input_ranged_data,
                schema,
                opt,
                limit,
                expected_output,
            )
            .await;
        }
    }

    /// Hand-written cases. Each tuple is
    /// `(unit, [((range_start, range_end), batches)], descending, limit, expected_batches)`.
    /// They cover overlapping ranges, empty ranges, a single batch spanning several
    /// ranges, and a limit.
    #[tokio::test]
    async fn simple_case() {
        let testcases = vec![
            (
                TimeUnit::Millisecond,
                vec![
                    ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]]),
                    ((5, 10), vec![vec![5, 6], vec![7, 8]]),
                ],
                false,
                None,
                vec![vec![1, 2, 3, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9]],
            ),
            (
                TimeUnit::Millisecond,
                vec![
                    ((5, 10), vec![vec![5, 6], vec![7, 8, 9]]),
                    ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
                ],
                true,
                None,
                vec![vec![9, 8, 7, 6, 5], vec![8, 7, 6, 5, 4, 3, 2, 1]],
            ),
            (
                TimeUnit::Millisecond,
                vec![
                    ((5, 10), vec![]),
                    ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
                ],
                true,
                None,
                vec![vec![8, 7, 6, 5, 4, 3, 2, 1]],
            ),
            (
                TimeUnit::Millisecond,
                vec![
                    ((15, 20), vec![vec![17, 18, 19]]),
                    ((10, 15), vec![]),
                    ((5, 10), vec![]),
                    ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
                ],
                true,
                None,
                vec![vec![19, 18, 17], vec![8, 7, 6, 5, 4, 3, 2, 1]],
            ),
            (
                TimeUnit::Millisecond,
                vec![
                    ((15, 20), vec![]),
                    ((10, 15), vec![]),
                    ((5, 10), vec![]),
                    ((0, 10), vec![]),
                ],
                true,
                None,
                vec![],
            ),
            (
                TimeUnit::Millisecond,
                vec![
                    (
                        (15, 20),
                        vec![vec![15, 17, 19, 10, 11, 12, 5, 6, 7, 8, 9, 1, 2, 3, 4]],
                    ),
                    ((10, 15), vec![]),
                    ((5, 10), vec![]),
                    ((0, 10), vec![]),
                ],
                true,
                None,
                vec![
                    vec![19, 17, 15],
                    vec![12, 11, 10],
                    vec![9, 8, 7, 6, 5],
                    vec![4, 3, 2, 1],
                ],
            ),
            (
                TimeUnit::Millisecond,
                vec![
                    (
                        (15, 20),
                        vec![vec![15, 17, 19, 10, 11, 12, 5, 6, 7, 8, 9, 1, 2, 3, 4]],
                    ),
                    ((10, 15), vec![]),
                    ((5, 10), vec![]),
                    ((0, 10), vec![]),
                ],
                true,
                Some(2),
                vec![vec![19, 17], vec![12, 11], vec![9, 8], vec![4, 3]],
            ),
        ];

        for (identifier, (unit, input_ranged_data, descending, limit, expected_output)) in
            testcases.into_iter().enumerate()
        {
            let schema = Schema::new(vec![Field::new(
                "ts",
                DataType::Timestamp(unit, None),
                false,
            )]);
            let schema = Arc::new(schema);
            let opt = SortOptions {
                descending,
                ..Default::default()
            };

            let input_ranged_data = input_ranged_data
                .into_iter()
                .map(|(range, data)| {
                    let part = PartitionRange {
                        start: Timestamp::new(range.0, unit.into()),
                        end: Timestamp::new(range.1, unit.into()),
                        num_rows: data.iter().map(|b| b.len()).sum(),
                        identifier,
                    };

                    let batches = data
                        .into_iter()
                        .map(|b| {
                            let arr = new_ts_array(unit, b);
                            DfRecordBatch::try_new(schema.clone(), vec![arr]).unwrap()
                        })
                        .collect_vec();
                    (part, batches)
                })
                .collect_vec();

            let expected_output = expected_output
                .into_iter()
                .map(|a| {
                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, a)]).unwrap()
                })
                .collect_vec();

            run_test(
                identifier,
                input_ranged_data,
                schema.clone(),
                opt,
                limit,
                expected_output,
            )
            .await;
        }
    }

    /// Runs `PartSortExec` over `input_ranged_data` as a single partition (all
    /// ranges in partition 0) and asserts the output equals `expected_output`.
    ///
    /// An empty batch is appended after each range's batches to exercise
    /// zero-row input. On mismatch, the first differing batch index and the
    /// ranges are printed, then the test panics with both outputs rendered as
    /// JSON from that index on.
    #[allow(clippy::print_stdout)]
    async fn run_test(
        case_id: usize,
        input_ranged_data: Vec<(PartitionRange, Vec<DfRecordBatch>)>,
        schema: SchemaRef,
        opt: SortOptions,
        limit: Option<usize>,
        expected_output: Vec<DfRecordBatch>,
    ) {
        // Sanity check on the test data itself: no expected batch may exceed the limit.
        for rb in &expected_output {
            if let Some(limit) = limit {
                assert!(
                    rb.num_rows() <= limit,
                    "Expect row count in expected output's batch({}) <= limit({})",
                    rb.num_rows(),
                    limit
                );
            }
        }
        let (ranges, batches): (Vec<_>, Vec<_>) = input_ranged_data.clone().into_iter().unzip();

        let batches = batches
            .into_iter()
            .flat_map(|mut cols| {
                cols.push(DfRecordBatch::new_empty(schema.clone()));
                cols
            })
            .collect_vec();
        let mock_input = MockInputExec::new(batches, schema.clone());

        let exec = PartSortExec::new(
            PhysicalSortExpr {
                expr: Arc::new(Column::new("ts", 0)),
                options: opt,
            },
            limit,
            vec![ranges.clone()],
            Arc::new(mock_input),
        );

        let exec_stream = exec.execute(0, Arc::new(TaskContext::default())).unwrap();

        let real_output = exec_stream.map(|r| r.unwrap()).collect::<Vec<_>>().await;
        if real_output != expected_output {
            // Locate the first differing batch to keep the failure message focused.
            let mut first_diff = 0;
            for (idx, (lhs, rhs)) in real_output.iter().zip(expected_output.iter()).enumerate() {
                if lhs != rhs {
                    first_diff = idx;
                    break;
                }
            }
            println!("first diff batch at {}", first_diff);
            println!(
                "ranges: {:?}",
                ranges
                    .into_iter()
                    .map(|r| (r.start.to_chrono_datetime(), r.end.to_chrono_datetime()))
                    .enumerate()
                    .collect::<Vec<_>>()
            );

            let mut full_msg = String::new();
            {
                let mut buf = Vec::with_capacity(10 * real_output.len());
                for batch in real_output.iter().skip(first_diff) {
                    let mut rb_json: Vec<u8> = Vec::new();
                    let mut writer = ArrayWriter::new(&mut rb_json);
                    writer.write(batch).unwrap();
                    writer.finish().unwrap();
                    buf.append(&mut rb_json);
                    buf.push(b',');
                }
                // Render the JSON bytes lossily so any invalid UTF-8 cannot abort the report.
                let buf = String::from_utf8_lossy(&buf);
                full_msg += &format!("\ncase_id:{case_id}, real_output \n{buf}\n");
            }
            {
                let mut buf = Vec::with_capacity(10 * real_output.len());
                for batch in expected_output.iter().skip(first_diff) {
                    let mut rb_json: Vec<u8> = Vec::new();
                    let mut writer = ArrayWriter::new(&mut rb_json);
                    writer.write(batch).unwrap();
                    writer.finish().unwrap();
                    buf.append(&mut rb_json);
                    buf.push(b',');
                }
                let buf = String::from_utf8_lossy(&buf);
                full_msg += &format!("case_id:{case_id}, expected_output \n{buf}");
            }
            panic!(
                "case_{} failed, opt: {:?},\n real output has {} batches, {} rows, expected has {} batches with {} rows\nfull msg: {}",
                case_id, opt,
                real_output.len(),
                real_output.iter().map(|x|x.num_rows()).sum::<usize>(),
                expected_output.len(),
                expected_output.iter().map(|x|x.num_rows()).sum::<usize>(), full_msg
            );
        }
    }
}