1use std::any::Any;
22use std::pin::Pin;
23use std::sync::Arc;
24use std::task::{Context, Poll};
25
26use arrow::array::ArrayRef;
27use arrow::compute::{concat, concat_batches, take_record_batch};
28use arrow_schema::SchemaRef;
29use common_recordbatch::{DfRecordBatch, DfSendableRecordBatchStream};
30use datafusion::common::arrow::compute::sort_to_indices;
31use datafusion::execution::memory_pool::{MemoryConsumer, MemoryReservation};
32use datafusion::execution::{RecordBatchStream, TaskContext};
33use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
34use datafusion::physical_plan::{
35 DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, TopK,
36};
37use datafusion_common::{internal_err, DataFusionError};
38use datafusion_physical_expr::PhysicalSortExpr;
39use futures::{Stream, StreamExt};
40use itertools::Itertools;
41use snafu::location;
42use store_api::region_engine::PartitionRange;
43
44use crate::{array_iter_helper, downcast_ts_array};
45
46#[derive(Debug, Clone)]
52pub struct PartSortExec {
53 expression: PhysicalSortExpr,
55 limit: Option<usize>,
56 input: Arc<dyn ExecutionPlan>,
57 metrics: ExecutionPlanMetricsSet,
59 partition_ranges: Vec<Vec<PartitionRange>>,
60 properties: PlanProperties,
61}
62
63impl PartSortExec {
64 pub fn new(
65 expression: PhysicalSortExpr,
66 limit: Option<usize>,
67 partition_ranges: Vec<Vec<PartitionRange>>,
68 input: Arc<dyn ExecutionPlan>,
69 ) -> Self {
70 let metrics = ExecutionPlanMetricsSet::new();
71 let properties = input.properties();
72 let properties = PlanProperties::new(
73 input.equivalence_properties().clone(),
74 input.output_partitioning().clone(),
75 properties.emission_type,
76 properties.boundedness,
77 );
78
79 Self {
80 expression,
81 limit,
82 input,
83 metrics,
84 partition_ranges,
85 properties,
86 }
87 }
88
89 pub fn to_stream(
90 &self,
91 context: Arc<TaskContext>,
92 partition: usize,
93 ) -> datafusion_common::Result<DfSendableRecordBatchStream> {
94 let input_stream: DfSendableRecordBatchStream =
95 self.input.execute(partition, context.clone())?;
96
97 if partition >= self.partition_ranges.len() {
98 internal_err!(
99 "Partition index out of range: {} >= {} at {}",
100 partition,
101 self.partition_ranges.len(),
102 snafu::location!()
103 )?;
104 }
105
106 let df_stream = Box::pin(PartSortStream::new(
107 context,
108 self,
109 self.limit,
110 input_stream,
111 self.partition_ranges[partition].clone(),
112 partition,
113 )?) as _;
114
115 Ok(df_stream)
116 }
117}
118
119impl DisplayAs for PartSortExec {
120 fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121 write!(
122 f,
123 "PartSortExec: expr={} num_ranges={}",
124 self.expression,
125 self.partition_ranges.len(),
126 )?;
127 if let Some(limit) = self.limit {
128 write!(f, " limit={}", limit)?;
129 }
130 Ok(())
131 }
132}
133
134impl ExecutionPlan for PartSortExec {
135 fn name(&self) -> &str {
136 "PartSortExec"
137 }
138
139 fn as_any(&self) -> &dyn Any {
140 self
141 }
142
143 fn schema(&self) -> SchemaRef {
144 self.input.schema()
145 }
146
147 fn properties(&self) -> &PlanProperties {
148 &self.properties
149 }
150
151 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
152 vec![&self.input]
153 }
154
155 fn with_new_children(
156 self: Arc<Self>,
157 children: Vec<Arc<dyn ExecutionPlan>>,
158 ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
159 let new_input = if let Some(first) = children.first() {
160 first
161 } else {
162 internal_err!("No children found")?
163 };
164 Ok(Arc::new(Self::new(
165 self.expression.clone(),
166 self.limit,
167 self.partition_ranges.clone(),
168 new_input.clone(),
169 )))
170 }
171
172 fn execute(
173 &self,
174 partition: usize,
175 context: Arc<TaskContext>,
176 ) -> datafusion_common::Result<DfSendableRecordBatchStream> {
177 self.to_stream(context, partition)
178 }
179
180 fn metrics(&self) -> Option<MetricsSet> {
181 Some(self.metrics.clone_inner())
182 }
183
184 fn benefits_from_input_partitioning(&self) -> Vec<bool> {
190 vec![false]
191 }
192}
193
194enum PartSortBuffer {
195 All(Vec<DfRecordBatch>),
196 Top(TopK, usize),
201}
202
203impl PartSortBuffer {
204 pub fn is_empty(&self) -> bool {
205 match self {
206 PartSortBuffer::All(v) => v.is_empty(),
207 PartSortBuffer::Top(_, cnt) => *cnt == 0,
208 }
209 }
210}
211
212struct PartSortStream {
213 reservation: MemoryReservation,
215 buffer: PartSortBuffer,
216 expression: PhysicalSortExpr,
217 limit: Option<usize>,
218 produced: usize,
219 input: DfSendableRecordBatchStream,
220 input_complete: bool,
221 schema: SchemaRef,
222 partition_ranges: Vec<PartitionRange>,
223 #[allow(dead_code)] partition: usize,
225 cur_part_idx: usize,
226 evaluating_batch: Option<DfRecordBatch>,
227 metrics: BaselineMetrics,
228 context: Arc<TaskContext>,
229 root_metrics: ExecutionPlanMetricsSet,
230}
231
232impl PartSortStream {
233 fn new(
234 context: Arc<TaskContext>,
235 sort: &PartSortExec,
236 limit: Option<usize>,
237 input: DfSendableRecordBatchStream,
238 partition_ranges: Vec<PartitionRange>,
239 partition: usize,
240 ) -> datafusion_common::Result<Self> {
241 let buffer = if let Some(limit) = limit {
242 PartSortBuffer::Top(
243 TopK::try_new(
244 partition,
245 sort.schema().clone(),
246 vec![],
247 [sort.expression.clone()].into(),
248 limit,
249 context.session_config().batch_size(),
250 context.runtime_env(),
251 &sort.metrics,
252 None,
253 )?,
254 0,
255 )
256 } else {
257 PartSortBuffer::All(Vec::new())
258 };
259
260 Ok(Self {
261 reservation: MemoryConsumer::new("PartSortStream".to_string())
262 .register(&context.runtime_env().memory_pool),
263 buffer,
264 expression: sort.expression.clone(),
265 limit,
266 produced: 0,
267 input,
268 input_complete: false,
269 schema: sort.input.schema(),
270 partition_ranges,
271 partition,
272 cur_part_idx: 0,
273 evaluating_batch: None,
274 metrics: BaselineMetrics::new(&sort.metrics, partition),
275 context,
276 root_metrics: sort.metrics.clone(),
277 })
278 }
279}
280
/// Checks that two values of a timestamp column lie inside a partition range.
///
/// Arguments:
/// - `$t`: arrow primitive type the column is downcast to,
/// - `$unit`: arrow time unit the range bounds must use,
/// - `$arr`: the sort column,
/// - `$cur_range`: the range to check against (start inclusive, end exclusive),
/// - `$min_max_idx`: pair of row indices whose values are the extremes to check.
///
/// Expands to code that returns early with an internal error (via `?`) when
/// the units differ or when either value falls outside the range. The
/// downcast is unwrapped, so `$arr` must really hold `$t` values.
macro_rules! array_check_helper {
    ($t:ty, $unit:expr, $arr:expr, $cur_range:expr, $min_max_idx:expr) => {{
        let start_unit = $cur_range.start.unit().as_arrow_time_unit();
        let end_unit = $cur_range.end.unit().as_arrow_time_unit();
        if start_unit != $unit || end_unit != $unit {
            internal_err!(
                "PartitionRange unit mismatch, expect {:?}, found {:?}",
                $cur_range.start.unit(),
                $unit
            )?;
        }
        let ts_array = $arr
            .as_any()
            .downcast_ref::<arrow::array::PrimitiveArray<$t>>()
            .unwrap();

        // The two indices may come in either order (ascending or descending
        // sort), so normalise them into (lower, upper).
        let first = ts_array.value($min_max_idx.0);
        let second = ts_array.value($min_max_idx.1);
        let (lower, upper) = if first < second {
            (first, second)
        } else {
            (second, first)
        };
        let range_start = $cur_range.start.value();
        let range_end = $cur_range.end.value();
        // The range is left-inclusive and right-exclusive.
        if lower < range_start || upper >= range_end {
            internal_err!(
                "Sort column min/max value out of partition range: sort_column.min_max=[{:?}, {:?}] not in PartitionRange=[{:?}, {:?}]",
                lower,
                upper,
                range_start,
                range_end
            )?;
        }
    }};
}
318
319impl PartSortStream {
320 fn check_in_range(
322 &self,
323 sort_column: &ArrayRef,
324 min_max_idx: (usize, usize),
325 ) -> datafusion_common::Result<()> {
326 if self.cur_part_idx >= self.partition_ranges.len() {
327 internal_err!(
328 "Partition index out of range: {} >= {} at {}",
329 self.cur_part_idx,
330 self.partition_ranges.len(),
331 snafu::location!()
332 )?;
333 }
334 let cur_range = self.partition_ranges[self.cur_part_idx];
335
336 downcast_ts_array!(
337 sort_column.data_type() => (array_check_helper, sort_column, cur_range, min_max_idx),
338 _ => internal_err!(
339 "Unsupported data type for sort column: {:?}",
340 sort_column.data_type()
341 )?,
342 );
343
344 Ok(())
345 }
346
347 fn try_find_next_range(
352 &self,
353 sort_column: &ArrayRef,
354 ) -> datafusion_common::Result<Option<usize>> {
355 if sort_column.is_empty() {
356 return Ok(Some(0));
357 }
358
359 if self.cur_part_idx >= self.partition_ranges.len() {
361 internal_err!(
362 "Partition index out of range: {} >= {} at {}",
363 self.cur_part_idx,
364 self.partition_ranges.len(),
365 snafu::location!()
366 )?;
367 }
368 let cur_range = self.partition_ranges[self.cur_part_idx];
369
370 let sort_column_iter = downcast_ts_array!(
371 sort_column.data_type() => (array_iter_helper, sort_column),
372 _ => internal_err!(
373 "Unsupported data type for sort column: {:?}",
374 sort_column.data_type()
375 )?,
376 );
377
378 for (idx, val) in sort_column_iter {
379 if let Some(val) = val {
381 if val >= cur_range.end.value() || val < cur_range.start.value() {
382 return Ok(Some(idx));
383 }
384 }
385 }
386
387 Ok(None)
388 }
389
390 fn push_buffer(&mut self, batch: DfRecordBatch) -> datafusion_common::Result<()> {
391 match &mut self.buffer {
392 PartSortBuffer::All(v) => v.push(batch),
393 PartSortBuffer::Top(top, cnt) => {
394 *cnt += batch.num_rows();
395 top.insert_batch(batch)?;
396 }
397 }
398
399 Ok(())
400 }
401
402 fn sort_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
406 match &mut self.buffer {
407 PartSortBuffer::All(_) => self.sort_all_buffer(),
408 PartSortBuffer::Top(_, _) => self.sort_top_buffer(),
409 }
410 }
411
412 fn sort_all_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
414 let PartSortBuffer::All(buffer) =
415 std::mem::replace(&mut self.buffer, PartSortBuffer::All(Vec::new()))
416 else {
417 unreachable!("buffer type is checked before and should be All variant")
418 };
419
420 if buffer.is_empty() {
421 return Ok(DfRecordBatch::new_empty(self.schema.clone()));
422 }
423 let mut sort_columns = Vec::with_capacity(buffer.len());
424 let mut opt = None;
425 for batch in buffer.iter() {
426 let sort_column = self.expression.evaluate_to_sort_column(batch)?;
427 opt = opt.or(sort_column.options);
428 sort_columns.push(sort_column.values);
429 }
430
431 let sort_column =
432 concat(&sort_columns.iter().map(|a| a.as_ref()).collect_vec()).map_err(|e| {
433 DataFusionError::ArrowError(
434 Box::new(e),
435 Some(format!("Fail to concat sort columns at {}", location!())),
436 )
437 })?;
438
439 let indices = sort_to_indices(&sort_column, opt, self.limit).map_err(|e| {
440 DataFusionError::ArrowError(
441 Box::new(e),
442 Some(format!("Fail to sort to indices at {}", location!())),
443 )
444 })?;
445 if indices.is_empty() {
446 return Ok(DfRecordBatch::new_empty(self.schema.clone()));
447 }
448
449 self.check_in_range(
450 &sort_column,
451 (
452 indices.value(0) as usize,
453 indices.value(indices.len() - 1) as usize,
454 ),
455 )
456 .inspect_err(|_e| {
457 #[cfg(debug_assertions)]
458 common_telemetry::error!(
459 "Fail to check sort column in range at {}, current_idx: {}, num_rows: {}, err: {}",
460 self.partition,
461 self.cur_part_idx,
462 sort_column.len(),
463 _e
464 );
465 })?;
466
467 let total_mem: usize = buffer.iter().map(|r| r.get_array_memory_size()).sum();
469 self.reservation.try_grow(total_mem * 2)?;
470
471 let full_input = concat_batches(&self.schema, &buffer).map_err(|e| {
472 DataFusionError::ArrowError(
473 Box::new(e),
474 Some(format!(
475 "Fail to concat input batches when sorting at {}",
476 location!()
477 )),
478 )
479 })?;
480
481 let sorted = take_record_batch(&full_input, &indices).map_err(|e| {
482 DataFusionError::ArrowError(
483 Box::new(e),
484 Some(format!(
485 "Fail to take result record batch when sorting at {}",
486 location!()
487 )),
488 )
489 })?;
490
491 self.produced += sorted.num_rows();
492 drop(full_input);
493 self.reservation.shrink(2 * total_mem);
495 Ok(sorted)
496 }
497
498 fn sort_top_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
500 let new_top_buffer = TopK::try_new(
501 self.partition,
502 self.schema().clone(),
503 vec![],
504 [self.expression.clone()].into(),
505 self.limit.unwrap(),
506 self.context.session_config().batch_size(),
507 self.context.runtime_env(),
508 &self.root_metrics,
509 None,
510 )?;
511 let PartSortBuffer::Top(top_k, _) =
512 std::mem::replace(&mut self.buffer, PartSortBuffer::Top(new_top_buffer, 0))
513 else {
514 unreachable!("buffer type is checked before and should be Top variant")
515 };
516
517 let mut result_stream = top_k.emit()?;
518 let mut placeholder_ctx = std::task::Context::from_waker(futures::task::noop_waker_ref());
519 let mut results = vec![];
520 loop {
522 match result_stream.poll_next_unpin(&mut placeholder_ctx) {
523 Poll::Ready(Some(batch)) => {
524 let batch = batch?;
525 results.push(batch);
526 }
527 Poll::Pending => {
528 #[cfg(debug_assertions)]
529 unreachable!("TopK result stream should always be ready")
530 }
531 Poll::Ready(None) => {
532 break;
533 }
534 }
535 }
536
537 let concat_batch = concat_batches(&self.schema, &results).map_err(|e| {
538 DataFusionError::ArrowError(
539 Box::new(e),
540 Some(format!(
541 "Fail to concat top k result record batch when sorting at {}",
542 location!()
543 )),
544 )
545 })?;
546
547 Ok(concat_batch)
548 }
549
550 fn split_batch(
560 &mut self,
561 batch: DfRecordBatch,
562 ) -> datafusion_common::Result<Option<DfRecordBatch>> {
563 if batch.num_rows() == 0 {
564 return Ok(None);
565 }
566
567 let sort_column = self
568 .expression
569 .expr
570 .evaluate(&batch)?
571 .into_array(batch.num_rows())?;
572
573 let next_range_idx = self.try_find_next_range(&sort_column)?;
574 let Some(idx) = next_range_idx else {
575 self.push_buffer(batch)?;
576 return Ok(None);
578 };
579
580 let this_range = batch.slice(0, idx);
581 let remaining_range = batch.slice(idx, batch.num_rows() - idx);
582 if this_range.num_rows() != 0 {
583 self.push_buffer(this_range)?;
584 }
585 let sorted_batch = self.sort_buffer();
587 self.cur_part_idx += 1;
589 let next_sort_column = sort_column.slice(idx, batch.num_rows() - idx);
590 if self.try_find_next_range(&next_sort_column)?.is_some() {
591 self.evaluating_batch = Some(remaining_range);
594 } else {
595 if remaining_range.num_rows() != 0 {
598 self.push_buffer(remaining_range)?;
599 }
600 }
601
602 sorted_batch.map(|x| if x.num_rows() == 0 { None } else { Some(x) })
603 }
604
605 pub fn poll_next_inner(
606 mut self: Pin<&mut Self>,
607 cx: &mut Context<'_>,
608 ) -> Poll<Option<datafusion_common::Result<DfRecordBatch>>> {
609 loop {
610 if self.input_complete {
612 if self.buffer.is_empty() {
613 return Poll::Ready(None);
614 } else {
615 return Poll::Ready(Some(self.sort_buffer()));
616 }
617 }
618
619 if let Some(evaluating_batch) = self.evaluating_batch.take()
622 && evaluating_batch.num_rows() != 0
623 {
624 if let Some(sorted_batch) = self.split_batch(evaluating_batch)? {
625 return Poll::Ready(Some(Ok(sorted_batch)));
626 } else {
627 continue;
628 }
629 }
630
631 let res = self.input.as_mut().poll_next(cx);
633 match res {
634 Poll::Ready(Some(Ok(batch))) => {
635 if let Some(sorted_batch) = self.split_batch(batch)? {
636 return Poll::Ready(Some(Ok(sorted_batch)));
637 } else {
638 continue;
639 }
640 }
641 Poll::Ready(None) => {
643 self.input_complete = true;
644 continue;
645 }
646 Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
647 Poll::Pending => return Poll::Pending,
648 }
649 }
650 }
651}
652
653impl Stream for PartSortStream {
654 type Item = datafusion_common::Result<DfRecordBatch>;
655
656 fn poll_next(
657 mut self: Pin<&mut Self>,
658 cx: &mut Context<'_>,
659 ) -> Poll<Option<datafusion_common::Result<DfRecordBatch>>> {
660 let result = self.as_mut().poll_next_inner(cx);
661 self.metrics.record_poll(result)
662 }
663}
664
665impl RecordBatchStream for PartSortStream {
666 fn schema(&self) -> SchemaRef {
667 self.schema.clone()
668 }
669}
670
671#[cfg(test)]
672mod test {
673 use std::sync::Arc;
674
675 use arrow::json::ArrayWriter;
676 use arrow_schema::{DataType, Field, Schema, SortOptions, TimeUnit};
677 use common_time::Timestamp;
678 use datafusion_physical_expr::expressions::Column;
679 use futures::StreamExt;
680 use store_api::region_engine::PartitionRange;
681
682 use super::*;
683 use crate::test_util::{new_ts_array, MockInputExec};
684
685 #[tokio::test]
686 async fn fuzzy_test() {
687 let test_cnt = 100;
688 let part_cnt_bound = 100;
690 let range_size_bound = 100;
692 let range_offset_bound = 100;
693 let batch_cnt_bound = 20;
695 let batch_size_bound = 100;
696
697 let mut rng = fastrand::Rng::new();
698 rng.seed(1337);
699
700 let mut test_cases = Vec::new();
701
702 for case_id in 0..test_cnt {
703 let mut bound_val: Option<i64> = None;
704 let descending = rng.bool();
705 let nulls_first = rng.bool();
706 let opt = SortOptions {
707 descending,
708 nulls_first,
709 };
710 let limit = if rng.bool() {
711 Some(rng.usize(0..batch_cnt_bound * batch_size_bound))
712 } else {
713 None
714 };
715 let unit = match rng.u8(0..3) {
716 0 => TimeUnit::Second,
717 1 => TimeUnit::Millisecond,
718 2 => TimeUnit::Microsecond,
719 _ => TimeUnit::Nanosecond,
720 };
721
722 let schema = Schema::new(vec![Field::new(
723 "ts",
724 DataType::Timestamp(unit, None),
725 false,
726 )]);
727 let schema = Arc::new(schema);
728
729 let mut input_ranged_data = vec![];
730 let mut output_ranges = vec![];
731 let mut output_data = vec![];
732 for part_id in 0..rng.usize(0..part_cnt_bound) {
734 let (start, end) = if descending {
736 let end = bound_val
737 .map(
738 |i| i
739 .checked_sub(rng.i64(0..range_offset_bound))
740 .expect("Bad luck, fuzzy test generate data that will overflow, change seed and try again")
741 )
742 .unwrap_or_else(|| rng.i64(-100000000..100000000));
743 bound_val = Some(end);
744 let start = end - rng.i64(1..range_size_bound);
745 let start = Timestamp::new(start, unit.into());
746 let end = Timestamp::new(end, unit.into());
747 (start, end)
748 } else {
749 let start = bound_val
750 .map(|i| i + rng.i64(0..range_offset_bound))
751 .unwrap_or_else(|| rng.i64(..));
752 bound_val = Some(start);
753 let end = start + rng.i64(1..range_size_bound);
754 let start = Timestamp::new(start, unit.into());
755 let end = Timestamp::new(end, unit.into());
756 (start, end)
757 };
758 assert!(start < end);
759
760 let mut per_part_sort_data = vec![];
761 let mut batches = vec![];
762 for _batch_idx in 0..rng.usize(1..batch_cnt_bound) {
763 let cnt = rng.usize(0..batch_size_bound) + 1;
764 let iter = 0..rng.usize(0..cnt);
765 let mut data_gen = iter
766 .map(|_| rng.i64(start.value()..end.value()))
767 .collect_vec();
768 if data_gen.is_empty() {
769 continue;
771 }
772 data_gen.sort();
774 per_part_sort_data.extend(data_gen.clone());
775 let arr = new_ts_array(unit, data_gen.clone());
776 let batch = DfRecordBatch::try_new(schema.clone(), vec![arr]).unwrap();
777 batches.push(batch);
778 }
779
780 let range = PartitionRange {
781 start,
782 end,
783 num_rows: batches.iter().map(|b| b.num_rows()).sum(),
784 identifier: part_id,
785 };
786 input_ranged_data.push((range, batches));
787
788 output_ranges.push(range);
789 if per_part_sort_data.is_empty() {
790 continue;
791 }
792 output_data.extend_from_slice(&per_part_sort_data);
793 }
794
795 let mut output_data_iter = output_data.iter().peekable();
797 let mut output_data = vec![];
798 for range in output_ranges.clone() {
799 let mut cur_data = vec![];
800 while let Some(val) = output_data_iter.peek() {
801 if **val < range.start.value() || **val >= range.end.value() {
802 break;
803 }
804 cur_data.push(*output_data_iter.next().unwrap());
805 }
806
807 if cur_data.is_empty() {
808 continue;
809 }
810
811 if descending {
812 cur_data.sort_by(|a, b| b.cmp(a));
813 } else {
814 cur_data.sort();
815 }
816 output_data.push(cur_data);
817 }
818
819 let expected_output = output_data
820 .into_iter()
821 .map(|a| {
822 DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, a)]).unwrap()
823 })
824 .map(|rb| {
825 if let Some(limit) = limit
827 && rb.num_rows() > limit
828 {
829 rb.slice(0, limit)
830 } else {
831 rb
832 }
833 })
834 .collect_vec();
835
836 test_cases.push((
837 case_id,
838 unit,
839 input_ranged_data,
840 schema,
841 opt,
842 limit,
843 expected_output,
844 ));
845 }
846
847 for (case_id, _unit, input_ranged_data, schema, opt, limit, expected_output) in test_cases {
848 run_test(
849 case_id,
850 input_ranged_data,
851 schema,
852 opt,
853 limit,
854 expected_output,
855 )
856 .await;
857 }
858 }
859
860 #[tokio::test]
861 async fn simple_case() {
862 let testcases = vec![
863 (
864 TimeUnit::Millisecond,
865 vec![
866 ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]]),
867 ((5, 10), vec![vec![5, 6], vec![7, 8]]),
868 ],
869 false,
870 None,
871 vec![vec![1, 2, 3, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9]],
872 ),
873 (
874 TimeUnit::Millisecond,
875 vec![
876 ((5, 10), vec![vec![5, 6], vec![7, 8, 9]]),
877 ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
878 ],
879 true,
880 None,
881 vec![vec![9, 8, 7, 6, 5], vec![8, 7, 6, 5, 4, 3, 2, 1]],
882 ),
883 (
884 TimeUnit::Millisecond,
885 vec![
886 ((5, 10), vec![]),
887 ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
888 ],
889 true,
890 None,
891 vec![vec![8, 7, 6, 5, 4, 3, 2, 1]],
892 ),
893 (
894 TimeUnit::Millisecond,
895 vec![
896 ((15, 20), vec![vec![17, 18, 19]]),
897 ((10, 15), vec![]),
898 ((5, 10), vec![]),
899 ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
900 ],
901 true,
902 None,
903 vec![vec![19, 18, 17], vec![8, 7, 6, 5, 4, 3, 2, 1]],
904 ),
905 (
906 TimeUnit::Millisecond,
907 vec![
908 ((15, 20), vec![]),
909 ((10, 15), vec![]),
910 ((5, 10), vec![]),
911 ((0, 10), vec![]),
912 ],
913 true,
914 None,
915 vec![],
916 ),
917 (
918 TimeUnit::Millisecond,
919 vec![
920 (
921 (15, 20),
922 vec![vec![15, 17, 19, 10, 11, 12, 5, 6, 7, 8, 9, 1, 2, 3, 4]],
923 ),
924 ((10, 15), vec![]),
925 ((5, 10), vec![]),
926 ((0, 10), vec![]),
927 ],
928 true,
929 None,
930 vec![
931 vec![19, 17, 15],
932 vec![12, 11, 10],
933 vec![9, 8, 7, 6, 5],
934 vec![4, 3, 2, 1],
935 ],
936 ),
937 (
938 TimeUnit::Millisecond,
939 vec![
940 (
941 (15, 20),
942 vec![vec![15, 17, 19, 10, 11, 12, 5, 6, 7, 8, 9, 1, 2, 3, 4]],
943 ),
944 ((10, 15), vec![]),
945 ((5, 10), vec![]),
946 ((0, 10), vec![]),
947 ],
948 true,
949 Some(2),
950 vec![vec![19, 17], vec![12, 11], vec![9, 8], vec![4, 3]],
951 ),
952 ];
953
954 for (identifier, (unit, input_ranged_data, descending, limit, expected_output)) in
955 testcases.into_iter().enumerate()
956 {
957 let schema = Schema::new(vec![Field::new(
958 "ts",
959 DataType::Timestamp(unit, None),
960 false,
961 )]);
962 let schema = Arc::new(schema);
963 let opt = SortOptions {
964 descending,
965 ..Default::default()
966 };
967
968 let input_ranged_data = input_ranged_data
969 .into_iter()
970 .map(|(range, data)| {
971 let part = PartitionRange {
972 start: Timestamp::new(range.0, unit.into()),
973 end: Timestamp::new(range.1, unit.into()),
974 num_rows: data.iter().map(|b| b.len()).sum(),
975 identifier,
976 };
977
978 let batches = data
979 .into_iter()
980 .map(|b| {
981 let arr = new_ts_array(unit, b);
982 DfRecordBatch::try_new(schema.clone(), vec![arr]).unwrap()
983 })
984 .collect_vec();
985 (part, batches)
986 })
987 .collect_vec();
988
989 let expected_output = expected_output
990 .into_iter()
991 .map(|a| {
992 DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, a)]).unwrap()
993 })
994 .collect_vec();
995
996 run_test(
997 identifier,
998 input_ranged_data,
999 schema.clone(),
1000 opt,
1001 limit,
1002 expected_output,
1003 )
1004 .await;
1005 }
1006 }
1007
1008 #[allow(clippy::print_stdout)]
1009 async fn run_test(
1010 case_id: usize,
1011 input_ranged_data: Vec<(PartitionRange, Vec<DfRecordBatch>)>,
1012 schema: SchemaRef,
1013 opt: SortOptions,
1014 limit: Option<usize>,
1015 expected_output: Vec<DfRecordBatch>,
1016 ) {
1017 for rb in &expected_output {
1018 if let Some(limit) = limit {
1019 assert!(
1020 rb.num_rows() <= limit,
1021 "Expect row count in expected output's batch({}) <= limit({})",
1022 rb.num_rows(),
1023 limit
1024 );
1025 }
1026 }
1027 let (ranges, batches): (Vec<_>, Vec<_>) = input_ranged_data.clone().into_iter().unzip();
1028
1029 let batches = batches
1030 .into_iter()
1031 .flat_map(|mut cols| {
1032 cols.push(DfRecordBatch::new_empty(schema.clone()));
1033 cols
1034 })
1035 .collect_vec();
1036 let mock_input = MockInputExec::new(batches, schema.clone());
1037
1038 let exec = PartSortExec::new(
1039 PhysicalSortExpr {
1040 expr: Arc::new(Column::new("ts", 0)),
1041 options: opt,
1042 },
1043 limit,
1044 vec![ranges.clone()],
1045 Arc::new(mock_input),
1046 );
1047
1048 let exec_stream = exec.execute(0, Arc::new(TaskContext::default())).unwrap();
1049
1050 let real_output = exec_stream.map(|r| r.unwrap()).collect::<Vec<_>>().await;
1051 if real_output != expected_output {
1053 let mut first_diff = 0;
1054 for (idx, (lhs, rhs)) in real_output.iter().zip(expected_output.iter()).enumerate() {
1055 if lhs != rhs {
1056 first_diff = idx;
1057 break;
1058 }
1059 }
1060 println!("first diff batch at {}", first_diff);
1061 println!(
1062 "ranges: {:?}",
1063 ranges
1064 .into_iter()
1065 .map(|r| (r.start.to_chrono_datetime(), r.end.to_chrono_datetime()))
1066 .enumerate()
1067 .collect::<Vec<_>>()
1068 );
1069
1070 let mut full_msg = String::new();
1071 {
1072 let mut buf = Vec::with_capacity(10 * real_output.len());
1073 for batch in real_output.iter().skip(first_diff) {
1074 let mut rb_json: Vec<u8> = Vec::new();
1075 let mut writer = ArrayWriter::new(&mut rb_json);
1076 writer.write(batch).unwrap();
1077 writer.finish().unwrap();
1078 buf.append(&mut rb_json);
1079 buf.push(b',');
1080 }
1081 let buf = String::from_utf8_lossy(&buf);
1083 full_msg += &format!("\ncase_id:{case_id}, real_output \n{buf}\n");
1084 }
1085 {
1086 let mut buf = Vec::with_capacity(10 * real_output.len());
1087 for batch in expected_output.iter().skip(first_diff) {
1088 let mut rb_json: Vec<u8> = Vec::new();
1089 let mut writer = ArrayWriter::new(&mut rb_json);
1090 writer.write(batch).unwrap();
1091 writer.finish().unwrap();
1092 buf.append(&mut rb_json);
1093 buf.push(b',');
1094 }
1095 let buf = String::from_utf8_lossy(&buf);
1096 full_msg += &format!("case_id:{case_id}, expected_output \n{buf}");
1097 }
1098 panic!(
1099 "case_{} failed, opt: {:?},\n real output has {} batches, {} rows, expected has {} batches with {} rows\nfull msg: {}",
1100 case_id, opt,
1101 real_output.len(),
1102 real_output.iter().map(|x|x.num_rows()).sum::<usize>(),
1103 expected_output.len(),
1104 expected_output.iter().map(|x|x.num_rows()).sum::<usize>(), full_msg
1105 );
1106 }
1107 }
1108}