1use std::any::Any;
22use std::pin::Pin;
23use std::sync::Arc;
24use std::task::{Context, Poll};
25
26use arrow::array::ArrayRef;
27use arrow::compute::{concat, concat_batches, take_record_batch};
28use arrow_schema::SchemaRef;
29use common_recordbatch::{DfRecordBatch, DfSendableRecordBatchStream};
30use datafusion::common::arrow::compute::sort_to_indices;
31use datafusion::execution::memory_pool::{MemoryConsumer, MemoryReservation};
32use datafusion::execution::{RecordBatchStream, TaskContext};
33use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
34use datafusion::physical_plan::{
35 DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, TopK,
36};
37use datafusion_common::{internal_err, DataFusionError};
38use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr};
39use futures::{Stream, StreamExt};
40use itertools::Itertools;
41use snafu::location;
42use store_api::region_engine::PartitionRange;
43
44use crate::{array_iter_helper, downcast_ts_array};
45
46#[derive(Debug, Clone)]
52pub struct PartSortExec {
53 expression: PhysicalSortExpr,
55 limit: Option<usize>,
56 input: Arc<dyn ExecutionPlan>,
57 metrics: ExecutionPlanMetricsSet,
59 partition_ranges: Vec<Vec<PartitionRange>>,
60 properties: PlanProperties,
61}
62
63impl PartSortExec {
64 pub fn new(
65 expression: PhysicalSortExpr,
66 limit: Option<usize>,
67 partition_ranges: Vec<Vec<PartitionRange>>,
68 input: Arc<dyn ExecutionPlan>,
69 ) -> Self {
70 let metrics = ExecutionPlanMetricsSet::new();
71 let properties = input.properties();
72 let properties = PlanProperties::new(
73 input.equivalence_properties().clone(),
74 input.output_partitioning().clone(),
75 properties.emission_type,
76 properties.boundedness,
77 );
78
79 Self {
80 expression,
81 limit,
82 input,
83 metrics,
84 partition_ranges,
85 properties,
86 }
87 }
88
89 pub fn to_stream(
90 &self,
91 context: Arc<TaskContext>,
92 partition: usize,
93 ) -> datafusion_common::Result<DfSendableRecordBatchStream> {
94 let input_stream: DfSendableRecordBatchStream =
95 self.input.execute(partition, context.clone())?;
96
97 if partition >= self.partition_ranges.len() {
98 internal_err!(
99 "Partition index out of range: {} >= {}",
100 partition,
101 self.partition_ranges.len()
102 )?;
103 }
104
105 let df_stream = Box::pin(PartSortStream::new(
106 context,
107 self,
108 self.limit,
109 input_stream,
110 self.partition_ranges[partition].clone(),
111 partition,
112 )?) as _;
113
114 Ok(df_stream)
115 }
116}
117
118impl DisplayAs for PartSortExec {
119 fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120 write!(
121 f,
122 "PartSortExec: expr={} num_ranges={}",
123 self.expression,
124 self.partition_ranges.len(),
125 )?;
126 if let Some(limit) = self.limit {
127 write!(f, " limit={}", limit)?;
128 }
129 Ok(())
130 }
131}
132
133impl ExecutionPlan for PartSortExec {
134 fn name(&self) -> &str {
135 "PartSortExec"
136 }
137
138 fn as_any(&self) -> &dyn Any {
139 self
140 }
141
142 fn schema(&self) -> SchemaRef {
143 self.input.schema()
144 }
145
146 fn properties(&self) -> &PlanProperties {
147 &self.properties
148 }
149
150 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
151 vec![&self.input]
152 }
153
154 fn with_new_children(
155 self: Arc<Self>,
156 children: Vec<Arc<dyn ExecutionPlan>>,
157 ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
158 let new_input = if let Some(first) = children.first() {
159 first
160 } else {
161 internal_err!("No children found")?
162 };
163 Ok(Arc::new(Self::new(
164 self.expression.clone(),
165 self.limit,
166 self.partition_ranges.clone(),
167 new_input.clone(),
168 )))
169 }
170
171 fn execute(
172 &self,
173 partition: usize,
174 context: Arc<TaskContext>,
175 ) -> datafusion_common::Result<DfSendableRecordBatchStream> {
176 self.to_stream(context, partition)
177 }
178
179 fn metrics(&self) -> Option<MetricsSet> {
180 Some(self.metrics.clone_inner())
181 }
182
183 fn benefits_from_input_partitioning(&self) -> Vec<bool> {
189 vec![false]
190 }
191}
192
193enum PartSortBuffer {
194 All(Vec<DfRecordBatch>),
195 Top(TopK, usize),
200}
201
202impl PartSortBuffer {
203 pub fn is_empty(&self) -> bool {
204 match self {
205 PartSortBuffer::All(v) => v.is_empty(),
206 PartSortBuffer::Top(_, cnt) => *cnt == 0,
207 }
208 }
209}
210
211struct PartSortStream {
212 reservation: MemoryReservation,
214 buffer: PartSortBuffer,
215 expression: PhysicalSortExpr,
216 limit: Option<usize>,
217 produced: usize,
218 input: DfSendableRecordBatchStream,
219 input_complete: bool,
220 schema: SchemaRef,
221 partition_ranges: Vec<PartitionRange>,
222 #[allow(dead_code)] partition: usize,
224 cur_part_idx: usize,
225 evaluating_batch: Option<DfRecordBatch>,
226 metrics: BaselineMetrics,
227 context: Arc<TaskContext>,
228 root_metrics: ExecutionPlanMetricsSet,
229}
230
231impl PartSortStream {
232 fn new(
233 context: Arc<TaskContext>,
234 sort: &PartSortExec,
235 limit: Option<usize>,
236 input: DfSendableRecordBatchStream,
237 partition_ranges: Vec<PartitionRange>,
238 partition: usize,
239 ) -> datafusion_common::Result<Self> {
240 let buffer = if let Some(limit) = limit {
241 PartSortBuffer::Top(
242 TopK::try_new(
243 partition,
244 sort.schema().clone(),
245 LexOrdering::new(vec![sort.expression.clone()]),
246 limit,
247 context.session_config().batch_size(),
248 context.runtime_env(),
249 &sort.metrics,
250 )?,
251 0,
252 )
253 } else {
254 PartSortBuffer::All(Vec::new())
255 };
256
257 Ok(Self {
258 reservation: MemoryConsumer::new("PartSortStream".to_string())
259 .register(&context.runtime_env().memory_pool),
260 buffer,
261 expression: sort.expression.clone(),
262 limit,
263 produced: 0,
264 input,
265 input_complete: false,
266 schema: sort.input.schema(),
267 partition_ranges,
268 partition,
269 cur_part_idx: 0,
270 evaluating_batch: None,
271 metrics: BaselineMetrics::new(&sort.metrics, partition),
272 context,
273 root_metrics: sort.metrics.clone(),
274 })
275 }
276}
277
/// Verifies that two selected values of a timestamp array lie inside a range.
///
/// - `$t`: arrow primitive type `$arr` is downcast to (panics on mismatch).
/// - `$unit`: time unit both bounds of `$cur_range` must have.
/// - `$arr`: the array holding the sort column.
/// - `$cur_range`: range whose `start` is inclusive and whose `end` is exclusive.
/// - `$min_max_idx`: pair of row indices pointing at the extreme values, in
///   either order.
///
/// Must be expanded inside a function returning a DataFusion `Result`, because
/// failures are raised with `internal_err!(..)?`.
macro_rules! array_check_helper {
    ($t:ty, $unit:expr, $arr:expr, $cur_range:expr, $min_max_idx:expr) => {{
        let start_unit_ok = $cur_range.start.unit().as_arrow_time_unit() == $unit;
        let units_ok = start_unit_ok && $cur_range.end.unit().as_arrow_time_unit() == $unit;
        if !units_ok {
            internal_err!(
                "PartitionRange unit mismatch, expect {:?}, found {:?}",
                $cur_range.start.unit(),
                $unit
            )?;
        }

        let typed = $arr
            .as_any()
            .downcast_ref::<arrow::array::PrimitiveArray<$t>>()
            .unwrap();

        // The two indices may arrive in either order, depending on the sort direction.
        let first = typed.value($min_max_idx.0);
        let second = typed.value($min_max_idx.1);
        let (lowest, highest) = if second < first {
            (second, first)
        } else {
            (first, second)
        };

        let range_start = $cur_range.start.value();
        let range_end = $cur_range.end.value();
        if lowest < range_start || highest >= range_end {
            internal_err!(
                "Sort column min/max value out of partition range: sort_column.min_max=[{:?}, {:?}] not in PartitionRange=[{:?}, {:?}]",
                lowest,
                highest,
                range_start,
                range_end
            )?;
        }
    }};
}
315
316impl PartSortStream {
317 fn check_in_range(
319 &self,
320 sort_column: &ArrayRef,
321 min_max_idx: (usize, usize),
322 ) -> datafusion_common::Result<()> {
323 if self.cur_part_idx >= self.partition_ranges.len() {
324 internal_err!(
325 "Partition index out of range: {} >= {}",
326 self.cur_part_idx,
327 self.partition_ranges.len()
328 )?;
329 }
330 let cur_range = self.partition_ranges[self.cur_part_idx];
331
332 downcast_ts_array!(
333 sort_column.data_type() => (array_check_helper, sort_column, cur_range, min_max_idx),
334 _ => internal_err!(
335 "Unsupported data type for sort column: {:?}",
336 sort_column.data_type()
337 )?,
338 );
339
340 Ok(())
341 }
342
343 fn try_find_next_range(
348 &self,
349 sort_column: &ArrayRef,
350 ) -> datafusion_common::Result<Option<usize>> {
351 if sort_column.is_empty() {
352 return Ok(Some(0));
353 }
354
355 if self.cur_part_idx >= self.partition_ranges.len() {
357 internal_err!(
358 "Partition index out of range: {} >= {}",
359 self.cur_part_idx,
360 self.partition_ranges.len()
361 )?;
362 }
363 let cur_range = self.partition_ranges[self.cur_part_idx];
364
365 let sort_column_iter = downcast_ts_array!(
366 sort_column.data_type() => (array_iter_helper, sort_column),
367 _ => internal_err!(
368 "Unsupported data type for sort column: {:?}",
369 sort_column.data_type()
370 )?,
371 );
372
373 for (idx, val) in sort_column_iter {
374 if let Some(val) = val {
376 if val >= cur_range.end.value() || val < cur_range.start.value() {
377 return Ok(Some(idx));
378 }
379 }
380 }
381
382 Ok(None)
383 }
384
385 fn push_buffer(&mut self, batch: DfRecordBatch) -> datafusion_common::Result<()> {
386 match &mut self.buffer {
387 PartSortBuffer::All(v) => v.push(batch),
388 PartSortBuffer::Top(top, cnt) => {
389 *cnt += batch.num_rows();
390 top.insert_batch(batch)?;
391 }
392 }
393
394 Ok(())
395 }
396
397 fn sort_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
401 match &mut self.buffer {
402 PartSortBuffer::All(_) => self.sort_all_buffer(),
403 PartSortBuffer::Top(_, _) => self.sort_top_buffer(),
404 }
405 }
406
407 fn sort_all_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
409 let PartSortBuffer::All(buffer) =
410 std::mem::replace(&mut self.buffer, PartSortBuffer::All(Vec::new()))
411 else {
412 unreachable!("buffer type is checked before and should be All variant")
413 };
414
415 if buffer.is_empty() {
416 return Ok(DfRecordBatch::new_empty(self.schema.clone()));
417 }
418 let mut sort_columns = Vec::with_capacity(buffer.len());
419 let mut opt = None;
420 for batch in buffer.iter() {
421 let sort_column = self.expression.evaluate_to_sort_column(batch)?;
422 opt = opt.or(sort_column.options);
423 sort_columns.push(sort_column.values);
424 }
425
426 let sort_column =
427 concat(&sort_columns.iter().map(|a| a.as_ref()).collect_vec()).map_err(|e| {
428 DataFusionError::ArrowError(
429 e,
430 Some(format!("Fail to concat sort columns at {}", location!())),
431 )
432 })?;
433
434 let indices = sort_to_indices(&sort_column, opt, self.limit).map_err(|e| {
435 DataFusionError::ArrowError(
436 e,
437 Some(format!("Fail to sort to indices at {}", location!())),
438 )
439 })?;
440 if indices.is_empty() {
441 return Ok(DfRecordBatch::new_empty(self.schema.clone()));
442 }
443
444 self.check_in_range(
445 &sort_column,
446 (
447 indices.value(0) as usize,
448 indices.value(indices.len() - 1) as usize,
449 ),
450 )
451 .inspect_err(|_e| {
452 #[cfg(debug_assertions)]
453 common_telemetry::error!(
454 "Fail to check sort column in range at {}, current_idx: {}, num_rows: {}, err: {}",
455 self.partition,
456 self.cur_part_idx,
457 sort_column.len(),
458 _e
459 );
460 })?;
461
462 let total_mem: usize = buffer.iter().map(|r| r.get_array_memory_size()).sum();
464 self.reservation.try_grow(total_mem * 2)?;
465
466 let full_input = concat_batches(&self.schema, &buffer).map_err(|e| {
467 DataFusionError::ArrowError(
468 e,
469 Some(format!(
470 "Fail to concat input batches when sorting at {}",
471 location!()
472 )),
473 )
474 })?;
475
476 let sorted = take_record_batch(&full_input, &indices).map_err(|e| {
477 DataFusionError::ArrowError(
478 e,
479 Some(format!(
480 "Fail to take result record batch when sorting at {}",
481 location!()
482 )),
483 )
484 })?;
485
486 self.produced += sorted.num_rows();
487 drop(full_input);
488 self.reservation.shrink(2 * total_mem);
490 Ok(sorted)
491 }
492
493 fn sort_top_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
495 let new_top_buffer = TopK::try_new(
496 self.partition,
497 self.schema().clone(),
498 LexOrdering::new(vec![self.expression.clone()]),
499 self.limit.unwrap(),
500 self.context.session_config().batch_size(),
501 self.context.runtime_env(),
502 &self.root_metrics,
503 )?;
504 let PartSortBuffer::Top(top_k, _) =
505 std::mem::replace(&mut self.buffer, PartSortBuffer::Top(new_top_buffer, 0))
506 else {
507 unreachable!("buffer type is checked before and should be Top variant")
508 };
509
510 let mut result_stream = top_k.emit()?;
511 let mut placeholder_ctx = std::task::Context::from_waker(futures::task::noop_waker_ref());
512 let mut results = vec![];
513 loop {
515 match result_stream.poll_next_unpin(&mut placeholder_ctx) {
516 Poll::Ready(Some(batch)) => {
517 let batch = batch?;
518 results.push(batch);
519 }
520 Poll::Pending => {
521 #[cfg(debug_assertions)]
522 unreachable!("TopK result stream should always be ready")
523 }
524 Poll::Ready(None) => {
525 break;
526 }
527 }
528 }
529
530 let concat_batch = concat_batches(&self.schema, &results).map_err(|e| {
531 DataFusionError::ArrowError(
532 e,
533 Some(format!(
534 "Fail to concat top k result record batch when sorting at {}",
535 location!()
536 )),
537 )
538 })?;
539
540 Ok(concat_batch)
541 }
542
543 fn split_batch(
553 &mut self,
554 batch: DfRecordBatch,
555 ) -> datafusion_common::Result<Option<DfRecordBatch>> {
556 if batch.num_rows() == 0 {
557 return Ok(None);
558 }
559
560 let sort_column = self
561 .expression
562 .expr
563 .evaluate(&batch)?
564 .into_array(batch.num_rows())?;
565
566 let next_range_idx = self.try_find_next_range(&sort_column)?;
567 let Some(idx) = next_range_idx else {
568 self.push_buffer(batch)?;
569 return Ok(None);
571 };
572
573 let this_range = batch.slice(0, idx);
574 let remaining_range = batch.slice(idx, batch.num_rows() - idx);
575 if this_range.num_rows() != 0 {
576 self.push_buffer(this_range)?;
577 }
578 let sorted_batch = self.sort_buffer();
580 self.cur_part_idx += 1;
582 let next_sort_column = sort_column.slice(idx, batch.num_rows() - idx);
583 if self.try_find_next_range(&next_sort_column)?.is_some() {
584 self.evaluating_batch = Some(remaining_range);
587 } else {
588 if remaining_range.num_rows() != 0 {
591 self.push_buffer(remaining_range)?;
592 }
593 }
594
595 sorted_batch.map(|x| if x.num_rows() == 0 { None } else { Some(x) })
596 }
597
598 pub fn poll_next_inner(
599 mut self: Pin<&mut Self>,
600 cx: &mut Context<'_>,
601 ) -> Poll<Option<datafusion_common::Result<DfRecordBatch>>> {
602 loop {
603 if self.input_complete {
605 if self.buffer.is_empty() {
606 return Poll::Ready(None);
607 } else {
608 return Poll::Ready(Some(self.sort_buffer()));
609 }
610 }
611
612 if let Some(evaluating_batch) = self.evaluating_batch.take()
615 && evaluating_batch.num_rows() != 0
616 {
617 if let Some(sorted_batch) = self.split_batch(evaluating_batch)? {
618 return Poll::Ready(Some(Ok(sorted_batch)));
619 } else {
620 continue;
621 }
622 }
623
624 let res = self.input.as_mut().poll_next(cx);
626 match res {
627 Poll::Ready(Some(Ok(batch))) => {
628 if let Some(sorted_batch) = self.split_batch(batch)? {
629 return Poll::Ready(Some(Ok(sorted_batch)));
630 } else {
631 continue;
632 }
633 }
634 Poll::Ready(None) => {
636 self.input_complete = true;
637 continue;
638 }
639 Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
640 Poll::Pending => return Poll::Pending,
641 }
642 }
643 }
644}
645
646impl Stream for PartSortStream {
647 type Item = datafusion_common::Result<DfRecordBatch>;
648
649 fn poll_next(
650 mut self: Pin<&mut Self>,
651 cx: &mut Context<'_>,
652 ) -> Poll<Option<datafusion_common::Result<DfRecordBatch>>> {
653 let result = self.as_mut().poll_next_inner(cx);
654 self.metrics.record_poll(result)
655 }
656}
657
658impl RecordBatchStream for PartSortStream {
659 fn schema(&self) -> SchemaRef {
660 self.schema.clone()
661 }
662}
663
664#[cfg(test)]
665mod test {
666 use std::sync::Arc;
667
668 use arrow::json::ArrayWriter;
669 use arrow_schema::{DataType, Field, Schema, SortOptions, TimeUnit};
670 use common_time::Timestamp;
671 use datafusion_physical_expr::expressions::Column;
672 use futures::StreamExt;
673 use store_api::region_engine::PartitionRange;
674
675 use super::*;
676 use crate::test_util::{new_ts_array, MockInputExec};
677
678 #[tokio::test]
679 async fn fuzzy_test() {
680 let test_cnt = 100;
681 let part_cnt_bound = 100;
683 let range_size_bound = 100;
685 let range_offset_bound = 100;
686 let batch_cnt_bound = 20;
688 let batch_size_bound = 100;
689
690 let mut rng = fastrand::Rng::new();
691 rng.seed(1337);
692
693 let mut test_cases = Vec::new();
694
695 for case_id in 0..test_cnt {
696 let mut bound_val: Option<i64> = None;
697 let descending = rng.bool();
698 let nulls_first = rng.bool();
699 let opt = SortOptions {
700 descending,
701 nulls_first,
702 };
703 let limit = if rng.bool() {
704 Some(rng.usize(0..batch_cnt_bound * batch_size_bound))
705 } else {
706 None
707 };
708 let unit = match rng.u8(0..3) {
709 0 => TimeUnit::Second,
710 1 => TimeUnit::Millisecond,
711 2 => TimeUnit::Microsecond,
712 _ => TimeUnit::Nanosecond,
713 };
714
715 let schema = Schema::new(vec![Field::new(
716 "ts",
717 DataType::Timestamp(unit, None),
718 false,
719 )]);
720 let schema = Arc::new(schema);
721
722 let mut input_ranged_data = vec![];
723 let mut output_ranges = vec![];
724 let mut output_data = vec![];
725 for part_id in 0..rng.usize(0..part_cnt_bound) {
727 let (start, end) = if descending {
729 let end = bound_val
730 .map(
731 |i| i
732 .checked_sub(rng.i64(0..range_offset_bound))
733 .expect("Bad luck, fuzzy test generate data that will overflow, change seed and try again")
734 )
735 .unwrap_or_else(|| rng.i64(-100000000..100000000));
736 bound_val = Some(end);
737 let start = end - rng.i64(1..range_size_bound);
738 let start = Timestamp::new(start, unit.into());
739 let end = Timestamp::new(end, unit.into());
740 (start, end)
741 } else {
742 let start = bound_val
743 .map(|i| i + rng.i64(0..range_offset_bound))
744 .unwrap_or_else(|| rng.i64(..));
745 bound_val = Some(start);
746 let end = start + rng.i64(1..range_size_bound);
747 let start = Timestamp::new(start, unit.into());
748 let end = Timestamp::new(end, unit.into());
749 (start, end)
750 };
751 assert!(start < end);
752
753 let mut per_part_sort_data = vec![];
754 let mut batches = vec![];
755 for _batch_idx in 0..rng.usize(1..batch_cnt_bound) {
756 let cnt = rng.usize(0..batch_size_bound) + 1;
757 let iter = 0..rng.usize(0..cnt);
758 let mut data_gen = iter
759 .map(|_| rng.i64(start.value()..end.value()))
760 .collect_vec();
761 if data_gen.is_empty() {
762 continue;
764 }
765 data_gen.sort();
767 per_part_sort_data.extend(data_gen.clone());
768 let arr = new_ts_array(unit, data_gen.clone());
769 let batch = DfRecordBatch::try_new(schema.clone(), vec![arr]).unwrap();
770 batches.push(batch);
771 }
772
773 let range = PartitionRange {
774 start,
775 end,
776 num_rows: batches.iter().map(|b| b.num_rows()).sum(),
777 identifier: part_id,
778 };
779 input_ranged_data.push((range, batches));
780
781 output_ranges.push(range);
782 if per_part_sort_data.is_empty() {
783 continue;
784 }
785 output_data.extend_from_slice(&per_part_sort_data);
786 }
787
788 let mut output_data_iter = output_data.iter().peekable();
790 let mut output_data = vec![];
791 for range in output_ranges.clone() {
792 let mut cur_data = vec![];
793 while let Some(val) = output_data_iter.peek() {
794 if **val < range.start.value() || **val >= range.end.value() {
795 break;
796 }
797 cur_data.push(*output_data_iter.next().unwrap());
798 }
799
800 if cur_data.is_empty() {
801 continue;
802 }
803
804 if descending {
805 cur_data.sort_by(|a, b| b.cmp(a));
806 } else {
807 cur_data.sort();
808 }
809 output_data.push(cur_data);
810 }
811
812 let expected_output = output_data
813 .into_iter()
814 .map(|a| {
815 DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, a)]).unwrap()
816 })
817 .map(|rb| {
818 if let Some(limit) = limit
820 && rb.num_rows() > limit
821 {
822 rb.slice(0, limit)
823 } else {
824 rb
825 }
826 })
827 .collect_vec();
828
829 test_cases.push((
830 case_id,
831 unit,
832 input_ranged_data,
833 schema,
834 opt,
835 limit,
836 expected_output,
837 ));
838 }
839
840 for (case_id, _unit, input_ranged_data, schema, opt, limit, expected_output) in test_cases {
841 run_test(
842 case_id,
843 input_ranged_data,
844 schema,
845 opt,
846 limit,
847 expected_output,
848 )
849 .await;
850 }
851 }
852
853 #[tokio::test]
854 async fn simple_case() {
855 let testcases = vec![
856 (
857 TimeUnit::Millisecond,
858 vec![
859 ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]]),
860 ((5, 10), vec![vec![5, 6], vec![7, 8]]),
861 ],
862 false,
863 None,
864 vec![vec![1, 2, 3, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9]],
865 ),
866 (
867 TimeUnit::Millisecond,
868 vec![
869 ((5, 10), vec![vec![5, 6], vec![7, 8, 9]]),
870 ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
871 ],
872 true,
873 None,
874 vec![vec![9, 8, 7, 6, 5], vec![8, 7, 6, 5, 4, 3, 2, 1]],
875 ),
876 (
877 TimeUnit::Millisecond,
878 vec![
879 ((5, 10), vec![]),
880 ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
881 ],
882 true,
883 None,
884 vec![vec![8, 7, 6, 5, 4, 3, 2, 1]],
885 ),
886 (
887 TimeUnit::Millisecond,
888 vec![
889 ((15, 20), vec![vec![17, 18, 19]]),
890 ((10, 15), vec![]),
891 ((5, 10), vec![]),
892 ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
893 ],
894 true,
895 None,
896 vec![vec![19, 18, 17], vec![8, 7, 6, 5, 4, 3, 2, 1]],
897 ),
898 (
899 TimeUnit::Millisecond,
900 vec![
901 ((15, 20), vec![]),
902 ((10, 15), vec![]),
903 ((5, 10), vec![]),
904 ((0, 10), vec![]),
905 ],
906 true,
907 None,
908 vec![],
909 ),
910 (
911 TimeUnit::Millisecond,
912 vec![
913 (
914 (15, 20),
915 vec![vec![15, 17, 19, 10, 11, 12, 5, 6, 7, 8, 9, 1, 2, 3, 4]],
916 ),
917 ((10, 15), vec![]),
918 ((5, 10), vec![]),
919 ((0, 10), vec![]),
920 ],
921 true,
922 None,
923 vec![
924 vec![19, 17, 15],
925 vec![12, 11, 10],
926 vec![9, 8, 7, 6, 5],
927 vec![4, 3, 2, 1],
928 ],
929 ),
930 (
931 TimeUnit::Millisecond,
932 vec![
933 (
934 (15, 20),
935 vec![vec![15, 17, 19, 10, 11, 12, 5, 6, 7, 8, 9, 1, 2, 3, 4]],
936 ),
937 ((10, 15), vec![]),
938 ((5, 10), vec![]),
939 ((0, 10), vec![]),
940 ],
941 true,
942 Some(2),
943 vec![vec![19, 17], vec![12, 11], vec![9, 8], vec![4, 3]],
944 ),
945 ];
946
947 for (identifier, (unit, input_ranged_data, descending, limit, expected_output)) in
948 testcases.into_iter().enumerate()
949 {
950 let schema = Schema::new(vec![Field::new(
951 "ts",
952 DataType::Timestamp(unit, None),
953 false,
954 )]);
955 let schema = Arc::new(schema);
956 let opt = SortOptions {
957 descending,
958 ..Default::default()
959 };
960
961 let input_ranged_data = input_ranged_data
962 .into_iter()
963 .map(|(range, data)| {
964 let part = PartitionRange {
965 start: Timestamp::new(range.0, unit.into()),
966 end: Timestamp::new(range.1, unit.into()),
967 num_rows: data.iter().map(|b| b.len()).sum(),
968 identifier,
969 };
970
971 let batches = data
972 .into_iter()
973 .map(|b| {
974 let arr = new_ts_array(unit, b);
975 DfRecordBatch::try_new(schema.clone(), vec![arr]).unwrap()
976 })
977 .collect_vec();
978 (part, batches)
979 })
980 .collect_vec();
981
982 let expected_output = expected_output
983 .into_iter()
984 .map(|a| {
985 DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, a)]).unwrap()
986 })
987 .collect_vec();
988
989 run_test(
990 identifier,
991 input_ranged_data,
992 schema.clone(),
993 opt,
994 limit,
995 expected_output,
996 )
997 .await;
998 }
999 }
1000
1001 #[allow(clippy::print_stdout)]
1002 async fn run_test(
1003 case_id: usize,
1004 input_ranged_data: Vec<(PartitionRange, Vec<DfRecordBatch>)>,
1005 schema: SchemaRef,
1006 opt: SortOptions,
1007 limit: Option<usize>,
1008 expected_output: Vec<DfRecordBatch>,
1009 ) {
1010 for rb in &expected_output {
1011 if let Some(limit) = limit {
1012 assert!(
1013 rb.num_rows() <= limit,
1014 "Expect row count in expected output's batch({}) <= limit({})",
1015 rb.num_rows(),
1016 limit
1017 );
1018 }
1019 }
1020 let (ranges, batches): (Vec<_>, Vec<_>) = input_ranged_data.clone().into_iter().unzip();
1021
1022 let batches = batches
1023 .into_iter()
1024 .flat_map(|mut cols| {
1025 cols.push(DfRecordBatch::new_empty(schema.clone()));
1026 cols
1027 })
1028 .collect_vec();
1029 let mock_input = MockInputExec::new(batches, schema.clone());
1030
1031 let exec = PartSortExec::new(
1032 PhysicalSortExpr {
1033 expr: Arc::new(Column::new("ts", 0)),
1034 options: opt,
1035 },
1036 limit,
1037 vec![ranges.clone()],
1038 Arc::new(mock_input),
1039 );
1040
1041 let exec_stream = exec.execute(0, Arc::new(TaskContext::default())).unwrap();
1042
1043 let real_output = exec_stream.map(|r| r.unwrap()).collect::<Vec<_>>().await;
1044 if real_output != expected_output {
1046 let mut first_diff = 0;
1047 for (idx, (lhs, rhs)) in real_output.iter().zip(expected_output.iter()).enumerate() {
1048 if lhs != rhs {
1049 first_diff = idx;
1050 break;
1051 }
1052 }
1053 println!("first diff batch at {}", first_diff);
1054 println!(
1055 "ranges: {:?}",
1056 ranges
1057 .into_iter()
1058 .map(|r| (r.start.to_chrono_datetime(), r.end.to_chrono_datetime()))
1059 .enumerate()
1060 .collect::<Vec<_>>()
1061 );
1062
1063 let mut full_msg = String::new();
1064 {
1065 let mut buf = Vec::with_capacity(10 * real_output.len());
1066 for batch in real_output.iter().skip(first_diff) {
1067 let mut rb_json: Vec<u8> = Vec::new();
1068 let mut writer = ArrayWriter::new(&mut rb_json);
1069 writer.write(batch).unwrap();
1070 writer.finish().unwrap();
1071 buf.append(&mut rb_json);
1072 buf.push(b',');
1073 }
1074 let buf = String::from_utf8_lossy(&buf);
1076 full_msg += &format!("\ncase_id:{case_id}, real_output \n{buf}\n");
1077 }
1078 {
1079 let mut buf = Vec::with_capacity(10 * real_output.len());
1080 for batch in expected_output.iter().skip(first_diff) {
1081 let mut rb_json: Vec<u8> = Vec::new();
1082 let mut writer = ArrayWriter::new(&mut rb_json);
1083 writer.write(batch).unwrap();
1084 writer.finish().unwrap();
1085 buf.append(&mut rb_json);
1086 buf.push(b',');
1087 }
1088 let buf = String::from_utf8_lossy(&buf);
1089 full_msg += &format!("case_id:{case_id}, expected_output \n{buf}");
1090 }
1091 panic!(
1092 "case_{} failed, opt: {:?},\n real output has {} batches, {} rows, expected has {} batches with {} rows\nfull msg: {}",
1093 case_id, opt,
1094 real_output.len(),
1095 real_output.iter().map(|x|x.num_rows()).sum::<usize>(),
1096 expected_output.len(),
1097 expected_output.iter().map(|x|x.num_rows()).sum::<usize>(), full_msg
1098 );
1099 }
1100 }
1101}