query/
part_sort.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Module for sorting input data within each [`PartitionRange`].
16//!
17//! This module defines the [`PartSortExec`] execution plan, which sorts each
18//! partition ([`PartitionRange`]) independently based on the provided physical
19//! sort expressions.
20
21use std::any::Any;
22use std::pin::Pin;
23use std::sync::Arc;
24use std::task::{Context, Poll};
25
26use arrow::array::ArrayRef;
27use arrow::compute::{concat, concat_batches, take_record_batch};
28use arrow_schema::SchemaRef;
29use common_recordbatch::{DfRecordBatch, DfSendableRecordBatchStream};
30use datafusion::common::arrow::compute::sort_to_indices;
31use datafusion::execution::memory_pool::{MemoryConsumer, MemoryReservation};
32use datafusion::execution::{RecordBatchStream, TaskContext};
33use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
34use datafusion::physical_plan::{
35    DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, TopK,
36};
37use datafusion_common::{internal_err, DataFusionError};
38use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr};
39use futures::{Stream, StreamExt};
40use itertools::Itertools;
41use snafu::location;
42use store_api::region_engine::PartitionRange;
43
44use crate::{array_iter_helper, downcast_ts_array};
45
/// Sort input within given PartitionRange
///
/// Every output partition owns an ordered list of `PartitionRange`s. While
/// executing, rows are attributed to the current range by comparing the sort
/// column value against the range's `[start, end)` bounds; the first row that
/// falls outside of it closes the current range, whose buffered rows are then
/// sorted and emitted. Empty input batches are ignored.
///
/// This operator sorts each `PartitionRange` independently; it does not
/// produce a total order across ranges.
#[derive(Debug, Clone)]
pub struct PartSortExec {
    /// Physical sort expression (that is, sort by timestamp)
    expression: PhysicalSortExpr,
    /// Optional row limit; when set, each stream buffers rows in a `TopK`
    /// instead of keeping every batch.
    limit: Option<usize>,
    /// Child plan producing the rows to sort.
    input: Arc<dyn ExecutionPlan>,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Ranges to sort, indexed by partition: `partition_ranges[p]` is the
    /// list of ranges handled when executing partition `p`.
    partition_ranges: Vec<Vec<PartitionRange>>,
    /// Plan properties derived from `input` at construction time.
    properties: PlanProperties,
}
62
63impl PartSortExec {
64    pub fn new(
65        expression: PhysicalSortExpr,
66        limit: Option<usize>,
67        partition_ranges: Vec<Vec<PartitionRange>>,
68        input: Arc<dyn ExecutionPlan>,
69    ) -> Self {
70        let metrics = ExecutionPlanMetricsSet::new();
71        let properties = input.properties();
72        let properties = PlanProperties::new(
73            input.equivalence_properties().clone(),
74            input.output_partitioning().clone(),
75            properties.emission_type,
76            properties.boundedness,
77        );
78
79        Self {
80            expression,
81            limit,
82            input,
83            metrics,
84            partition_ranges,
85            properties,
86        }
87    }
88
89    pub fn to_stream(
90        &self,
91        context: Arc<TaskContext>,
92        partition: usize,
93    ) -> datafusion_common::Result<DfSendableRecordBatchStream> {
94        let input_stream: DfSendableRecordBatchStream =
95            self.input.execute(partition, context.clone())?;
96
97        if partition >= self.partition_ranges.len() {
98            internal_err!(
99                "Partition index out of range: {} >= {}",
100                partition,
101                self.partition_ranges.len()
102            )?;
103        }
104
105        let df_stream = Box::pin(PartSortStream::new(
106            context,
107            self,
108            self.limit,
109            input_stream,
110            self.partition_ranges[partition].clone(),
111            partition,
112        )?) as _;
113
114        Ok(df_stream)
115    }
116}
117
118impl DisplayAs for PartSortExec {
119    fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120        write!(
121            f,
122            "PartSortExec: expr={} num_ranges={}",
123            self.expression,
124            self.partition_ranges.len(),
125        )?;
126        if let Some(limit) = self.limit {
127            write!(f, " limit={}", limit)?;
128        }
129        Ok(())
130    }
131}
132
133impl ExecutionPlan for PartSortExec {
134    fn name(&self) -> &str {
135        "PartSortExec"
136    }
137
138    fn as_any(&self) -> &dyn Any {
139        self
140    }
141
142    fn schema(&self) -> SchemaRef {
143        self.input.schema()
144    }
145
146    fn properties(&self) -> &PlanProperties {
147        &self.properties
148    }
149
150    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
151        vec![&self.input]
152    }
153
154    fn with_new_children(
155        self: Arc<Self>,
156        children: Vec<Arc<dyn ExecutionPlan>>,
157    ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
158        let new_input = if let Some(first) = children.first() {
159            first
160        } else {
161            internal_err!("No children found")?
162        };
163        Ok(Arc::new(Self::new(
164            self.expression.clone(),
165            self.limit,
166            self.partition_ranges.clone(),
167            new_input.clone(),
168        )))
169    }
170
171    fn execute(
172        &self,
173        partition: usize,
174        context: Arc<TaskContext>,
175    ) -> datafusion_common::Result<DfSendableRecordBatchStream> {
176        self.to_stream(context, partition)
177    }
178
179    fn metrics(&self) -> Option<MetricsSet> {
180        Some(self.metrics.clone_inner())
181    }
182
183    /// # Explain
184    ///
185    /// This plan needs to be executed on each partition independently,
186    /// and is expected to run directly on storage engine's output
187    /// distribution / partition.
188    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
189        vec![false]
190    }
191}
192
/// Buffer collecting the rows of the `PartitionRange` currently being read.
enum PartSortBuffer {
    /// Keeps every pushed batch; used when no limit is set.
    All(Vec<DfRecordBatch>),
    /// TopK buffer with row count.
    ///
    /// The counter is increased by the row count of every pushed batch. Since
    /// the heap only keeps `k` elements, the counter does not tell how many
    /// rows are actually retained; it is only used for the emptiness check.
    Top(TopK, usize),
}
201
202impl PartSortBuffer {
203    pub fn is_empty(&self) -> bool {
204        match self {
205            PartSortBuffer::All(v) => v.is_empty(),
206            PartSortBuffer::Top(_, cnt) => *cnt == 0,
207        }
208    }
209}
210
/// Stream that sorts the rows of one partition, one `PartitionRange` at a time.
struct PartSortStream {
    /// Memory pool for this stream
    reservation: MemoryReservation,
    /// Rows collected for the range currently being read.
    buffer: PartSortBuffer,
    /// Sort expression evaluated against every input batch.
    expression: PhysicalSortExpr,
    /// Optional row limit; selects the `Top` buffer variant when set.
    limit: Option<usize>,
    /// Number of rows emitted so far by the unlimited (`All`) sort path.
    produced: usize,
    /// Upstream batches to sort.
    input: DfSendableRecordBatchStream,
    /// Set once `input` has returned `None`.
    input_complete: bool,
    /// Schema of the input, which is also the schema of the output.
    schema: SchemaRef,
    /// Ranges of this partition, visited in order through `cur_part_idx`.
    partition_ranges: Vec<PartitionRange>,
    #[allow(dead_code)] // read by debug-only logging; also passed to `TopK` when the buffer is rebuilt
    partition: usize,
    /// Index into `partition_ranges` of the range currently being collected.
    cur_part_idx: usize,
    /// Tail of an input batch that still has to be split against the
    /// following ranges; it is processed before new input is polled.
    evaluating_batch: Option<DfRecordBatch>,
    /// Baseline metrics recorded on every poll.
    metrics: BaselineMetrics,
    /// Task context, kept to build a fresh `TopK` after each emitted range.
    context: Arc<TaskContext>,
    /// Metrics set of the owning plan, handed to newly built `TopK` buffers.
    root_metrics: ExecutionPlanMetricsSet,
}
230
231impl PartSortStream {
232    fn new(
233        context: Arc<TaskContext>,
234        sort: &PartSortExec,
235        limit: Option<usize>,
236        input: DfSendableRecordBatchStream,
237        partition_ranges: Vec<PartitionRange>,
238        partition: usize,
239    ) -> datafusion_common::Result<Self> {
240        let buffer = if let Some(limit) = limit {
241            PartSortBuffer::Top(
242                TopK::try_new(
243                    partition,
244                    sort.schema().clone(),
245                    LexOrdering::new(vec![sort.expression.clone()]),
246                    limit,
247                    context.session_config().batch_size(),
248                    context.runtime_env(),
249                    &sort.metrics,
250                )?,
251                0,
252            )
253        } else {
254            PartSortBuffer::All(Vec::new())
255        };
256
257        Ok(Self {
258            reservation: MemoryConsumer::new("PartSortStream".to_string())
259                .register(&context.runtime_env().memory_pool),
260            buffer,
261            expression: sort.expression.clone(),
262            limit,
263            produced: 0,
264            input,
265            input_complete: false,
266            schema: sort.input.schema(),
267            partition_ranges,
268            partition,
269            cur_part_idx: 0,
270            evaluating_batch: None,
271            metrics: BaselineMetrics::new(&sort.metrics, partition),
272            context,
273            root_metrics: sort.metrics.clone(),
274        })
275    }
276}
277
/// Checks that two values of a timestamp array lie within a `PartitionRange`.
///
/// Arguments: the arrow primitive type `$t` and its time unit `$unit`, the
/// array `$arr`, the range `$cur_range`, and `$min_max_idx`, a pair of indices
/// into the array pointing at the candidate minimum and maximum (in either
/// order). Expands to a block that bails out of the enclosing function with an
/// internal error when the units differ or when a value is out of range.
macro_rules! array_check_helper {
    ($t:ty, $unit:expr, $arr:expr, $cur_range:expr, $min_max_idx:expr) => {{
        let start_unit_matches = $cur_range.start.unit().as_arrow_time_unit() == $unit;
        if !(start_unit_matches && $cur_range.end.unit().as_arrow_time_unit() == $unit) {
            internal_err!(
                "PartitionRange unit mismatch, expect {:?}, found {:?}",
                $cur_range.start.unit(),
                $unit
            )?;
        }
        let typed = $arr
            .as_any()
            .downcast_ref::<arrow::array::PrimitiveArray<$t>>()
            .unwrap();

        // the two indices may come in either order (e.g. descending sort),
        // so normalize them into a (min, max) pair
        let first = typed.value($min_max_idx.0);
        let second = typed.value($min_max_idx.1);
        let (min, max) = if first < second {
            (first, second)
        } else {
            (second, first)
        };
        let cur_min = $cur_range.start.value();
        let cur_max = $cur_range.end.value();
        // note that PartitionRange is left inclusive and right exclusive
        if min < cur_min || max >= cur_max {
            internal_err!(
                "Sort column min/max value out of partition range: sort_column.min_max=[{:?}, {:?}] not in PartitionRange=[{:?}, {:?}]",
                min,
                max,
                cur_min,
                cur_max
            )?;
        }
    }};
}
315
316impl PartSortStream {
317    /// check whether the sort column's min/max value is within the partition range
318    fn check_in_range(
319        &self,
320        sort_column: &ArrayRef,
321        min_max_idx: (usize, usize),
322    ) -> datafusion_common::Result<()> {
323        if self.cur_part_idx >= self.partition_ranges.len() {
324            internal_err!(
325                "Partition index out of range: {} >= {}",
326                self.cur_part_idx,
327                self.partition_ranges.len()
328            )?;
329        }
330        let cur_range = self.partition_ranges[self.cur_part_idx];
331
332        downcast_ts_array!(
333            sort_column.data_type() => (array_check_helper, sort_column, cur_range, min_max_idx),
334            _ => internal_err!(
335                "Unsupported data type for sort column: {:?}",
336                sort_column.data_type()
337            )?,
338        );
339
340        Ok(())
341    }
342
343    /// Try find data whose value exceeds the current partition range.
344    ///
345    /// Returns `None` if no such data is found, and `Some(idx)` where idx points to
346    /// the first data that exceeds the current partition range.
347    fn try_find_next_range(
348        &self,
349        sort_column: &ArrayRef,
350    ) -> datafusion_common::Result<Option<usize>> {
351        if sort_column.is_empty() {
352            return Ok(Some(0));
353        }
354
355        // check if the current partition index is out of range
356        if self.cur_part_idx >= self.partition_ranges.len() {
357            internal_err!(
358                "Partition index out of range: {} >= {}",
359                self.cur_part_idx,
360                self.partition_ranges.len()
361            )?;
362        }
363        let cur_range = self.partition_ranges[self.cur_part_idx];
364
365        let sort_column_iter = downcast_ts_array!(
366            sort_column.data_type() => (array_iter_helper, sort_column),
367            _ => internal_err!(
368                "Unsupported data type for sort column: {:?}",
369                sort_column.data_type()
370            )?,
371        );
372
373        for (idx, val) in sort_column_iter {
374            // ignore vacant time index data
375            if let Some(val) = val {
376                if val >= cur_range.end.value() || val < cur_range.start.value() {
377                    return Ok(Some(idx));
378                }
379            }
380        }
381
382        Ok(None)
383    }
384
385    fn push_buffer(&mut self, batch: DfRecordBatch) -> datafusion_common::Result<()> {
386        match &mut self.buffer {
387            PartSortBuffer::All(v) => v.push(batch),
388            PartSortBuffer::Top(top, cnt) => {
389                *cnt += batch.num_rows();
390                top.insert_batch(batch)?;
391            }
392        }
393
394        Ok(())
395    }
396
397    /// Sort and clear the buffer and return the sorted record batch
398    ///
399    /// this function will return a empty record batch if the buffer is empty
400    fn sort_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
401        match &mut self.buffer {
402            PartSortBuffer::All(_) => self.sort_all_buffer(),
403            PartSortBuffer::Top(_, _) => self.sort_top_buffer(),
404        }
405    }
406
407    /// Internal method for sorting `All` buffer (without limit).
408    fn sort_all_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
409        let PartSortBuffer::All(buffer) =
410            std::mem::replace(&mut self.buffer, PartSortBuffer::All(Vec::new()))
411        else {
412            unreachable!("buffer type is checked before and should be All variant")
413        };
414
415        if buffer.is_empty() {
416            return Ok(DfRecordBatch::new_empty(self.schema.clone()));
417        }
418        let mut sort_columns = Vec::with_capacity(buffer.len());
419        let mut opt = None;
420        for batch in buffer.iter() {
421            let sort_column = self.expression.evaluate_to_sort_column(batch)?;
422            opt = opt.or(sort_column.options);
423            sort_columns.push(sort_column.values);
424        }
425
426        let sort_column =
427            concat(&sort_columns.iter().map(|a| a.as_ref()).collect_vec()).map_err(|e| {
428                DataFusionError::ArrowError(
429                    e,
430                    Some(format!("Fail to concat sort columns at {}", location!())),
431                )
432            })?;
433
434        let indices = sort_to_indices(&sort_column, opt, self.limit).map_err(|e| {
435            DataFusionError::ArrowError(
436                e,
437                Some(format!("Fail to sort to indices at {}", location!())),
438            )
439        })?;
440        if indices.is_empty() {
441            return Ok(DfRecordBatch::new_empty(self.schema.clone()));
442        }
443
444        self.check_in_range(
445            &sort_column,
446            (
447                indices.value(0) as usize,
448                indices.value(indices.len() - 1) as usize,
449            ),
450        )
451        .inspect_err(|_e| {
452            #[cfg(debug_assertions)]
453            common_telemetry::error!(
454                "Fail to check sort column in range at {}, current_idx: {}, num_rows: {}, err: {}",
455                self.partition,
456                self.cur_part_idx,
457                sort_column.len(),
458                _e
459            );
460        })?;
461
462        // reserve memory for the concat input and sorted output
463        let total_mem: usize = buffer.iter().map(|r| r.get_array_memory_size()).sum();
464        self.reservation.try_grow(total_mem * 2)?;
465
466        let full_input = concat_batches(&self.schema, &buffer).map_err(|e| {
467            DataFusionError::ArrowError(
468                e,
469                Some(format!(
470                    "Fail to concat input batches when sorting at {}",
471                    location!()
472                )),
473            )
474        })?;
475
476        let sorted = take_record_batch(&full_input, &indices).map_err(|e| {
477            DataFusionError::ArrowError(
478                e,
479                Some(format!(
480                    "Fail to take result record batch when sorting at {}",
481                    location!()
482                )),
483            )
484        })?;
485
486        self.produced += sorted.num_rows();
487        drop(full_input);
488        // here remove both buffer and full_input memory
489        self.reservation.shrink(2 * total_mem);
490        Ok(sorted)
491    }
492
493    /// Internal method for sorting `Top` buffer (with limit).
494    fn sort_top_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
495        let new_top_buffer = TopK::try_new(
496            self.partition,
497            self.schema().clone(),
498            LexOrdering::new(vec![self.expression.clone()]),
499            self.limit.unwrap(),
500            self.context.session_config().batch_size(),
501            self.context.runtime_env(),
502            &self.root_metrics,
503        )?;
504        let PartSortBuffer::Top(top_k, _) =
505            std::mem::replace(&mut self.buffer, PartSortBuffer::Top(new_top_buffer, 0))
506        else {
507            unreachable!("buffer type is checked before and should be Top variant")
508        };
509
510        let mut result_stream = top_k.emit()?;
511        let mut placeholder_ctx = std::task::Context::from_waker(futures::task::noop_waker_ref());
512        let mut results = vec![];
513        // according to the current implementation of `TopK`, the result stream will always be ready
514        loop {
515            match result_stream.poll_next_unpin(&mut placeholder_ctx) {
516                Poll::Ready(Some(batch)) => {
517                    let batch = batch?;
518                    results.push(batch);
519                }
520                Poll::Pending => {
521                    #[cfg(debug_assertions)]
522                    unreachable!("TopK result stream should always be ready")
523                }
524                Poll::Ready(None) => {
525                    break;
526                }
527            }
528        }
529
530        let concat_batch = concat_batches(&self.schema, &results).map_err(|e| {
531            DataFusionError::ArrowError(
532                e,
533                Some(format!(
534                    "Fail to concat top k result record batch when sorting at {}",
535                    location!()
536                )),
537            )
538        })?;
539
540        Ok(concat_batch)
541    }
542
543    /// Try to split the input batch if it contains data that exceeds the current partition range.
544    ///
545    /// When the input batch contains data that exceeds the current partition range, this function
546    /// will split the input batch into two parts, the first part is within the current partition
547    /// range will be merged and sorted with previous buffer, and the second part will be registered
548    /// to `evaluating_batch` for next polling.
549    ///
550    /// Returns `None` if the input batch is empty or fully within the current partition range, and
551    /// `Some(batch)` otherwise.
552    fn split_batch(
553        &mut self,
554        batch: DfRecordBatch,
555    ) -> datafusion_common::Result<Option<DfRecordBatch>> {
556        if batch.num_rows() == 0 {
557            return Ok(None);
558        }
559
560        let sort_column = self
561            .expression
562            .expr
563            .evaluate(&batch)?
564            .into_array(batch.num_rows())?;
565
566        let next_range_idx = self.try_find_next_range(&sort_column)?;
567        let Some(idx) = next_range_idx else {
568            self.push_buffer(batch)?;
569            // keep polling input for next batch
570            return Ok(None);
571        };
572
573        let this_range = batch.slice(0, idx);
574        let remaining_range = batch.slice(idx, batch.num_rows() - idx);
575        if this_range.num_rows() != 0 {
576            self.push_buffer(this_range)?;
577        }
578        // mark end of current PartitionRange
579        let sorted_batch = self.sort_buffer();
580        // step to next proper PartitionRange
581        self.cur_part_idx += 1;
582        let next_sort_column = sort_column.slice(idx, batch.num_rows() - idx);
583        if self.try_find_next_range(&next_sort_column)?.is_some() {
584            // remaining batch still contains data that exceeds the current partition range
585            // register the remaining batch for next polling
586            self.evaluating_batch = Some(remaining_range);
587        } else {
588            // remaining batch is within the current partition range
589            // push to the buffer and continue polling
590            if remaining_range.num_rows() != 0 {
591                self.push_buffer(remaining_range)?;
592            }
593        }
594
595        sorted_batch.map(|x| if x.num_rows() == 0 { None } else { Some(x) })
596    }
597
598    pub fn poll_next_inner(
599        mut self: Pin<&mut Self>,
600        cx: &mut Context<'_>,
601    ) -> Poll<Option<datafusion_common::Result<DfRecordBatch>>> {
602        loop {
603            // no more input, sort the buffer and return
604            if self.input_complete {
605                if self.buffer.is_empty() {
606                    return Poll::Ready(None);
607                } else {
608                    return Poll::Ready(Some(self.sort_buffer()));
609                }
610            }
611
612            // if there is a remaining batch being evaluated from last run,
613            // split on it instead of fetching new batch
614            if let Some(evaluating_batch) = self.evaluating_batch.take()
615                && evaluating_batch.num_rows() != 0
616            {
617                if let Some(sorted_batch) = self.split_batch(evaluating_batch)? {
618                    return Poll::Ready(Some(Ok(sorted_batch)));
619                } else {
620                    continue;
621                }
622            }
623
624            // fetch next batch from input
625            let res = self.input.as_mut().poll_next(cx);
626            match res {
627                Poll::Ready(Some(Ok(batch))) => {
628                    if let Some(sorted_batch) = self.split_batch(batch)? {
629                        return Poll::Ready(Some(Ok(sorted_batch)));
630                    } else {
631                        continue;
632                    }
633                }
634                // input stream end, mark and continue
635                Poll::Ready(None) => {
636                    self.input_complete = true;
637                    continue;
638                }
639                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
640                Poll::Pending => return Poll::Pending,
641            }
642        }
643    }
644}
645
646impl Stream for PartSortStream {
647    type Item = datafusion_common::Result<DfRecordBatch>;
648
649    fn poll_next(
650        mut self: Pin<&mut Self>,
651        cx: &mut Context<'_>,
652    ) -> Poll<Option<datafusion_common::Result<DfRecordBatch>>> {
653        let result = self.as_mut().poll_next_inner(cx);
654        self.metrics.record_poll(result)
655    }
656}
657
658impl RecordBatchStream for PartSortStream {
659    fn schema(&self) -> SchemaRef {
660        self.schema.clone()
661    }
662}
663
664#[cfg(test)]
665mod test {
666    use std::sync::Arc;
667
668    use arrow::json::ArrayWriter;
669    use arrow_schema::{DataType, Field, Schema, SortOptions, TimeUnit};
670    use common_time::Timestamp;
671    use datafusion_physical_expr::expressions::Column;
672    use futures::StreamExt;
673    use store_api::region_engine::PartitionRange;
674
675    use super::*;
676    use crate::test_util::{new_ts_array, MockInputExec};
677
678    #[tokio::test]
679    async fn fuzzy_test() {
680        let test_cnt = 100;
681        // bound for total count of PartitionRange
682        let part_cnt_bound = 100;
683        // bound for timestamp range size and offset for each PartitionRange
684        let range_size_bound = 100;
685        let range_offset_bound = 100;
686        // bound for batch count and size within each PartitionRange
687        let batch_cnt_bound = 20;
688        let batch_size_bound = 100;
689
690        let mut rng = fastrand::Rng::new();
691        rng.seed(1337);
692
693        let mut test_cases = Vec::new();
694
695        for case_id in 0..test_cnt {
696            let mut bound_val: Option<i64> = None;
697            let descending = rng.bool();
698            let nulls_first = rng.bool();
699            let opt = SortOptions {
700                descending,
701                nulls_first,
702            };
703            let limit = if rng.bool() {
704                Some(rng.usize(0..batch_cnt_bound * batch_size_bound))
705            } else {
706                None
707            };
708            let unit = match rng.u8(0..3) {
709                0 => TimeUnit::Second,
710                1 => TimeUnit::Millisecond,
711                2 => TimeUnit::Microsecond,
712                _ => TimeUnit::Nanosecond,
713            };
714
715            let schema = Schema::new(vec![Field::new(
716                "ts",
717                DataType::Timestamp(unit, None),
718                false,
719            )]);
720            let schema = Arc::new(schema);
721
722            let mut input_ranged_data = vec![];
723            let mut output_ranges = vec![];
724            let mut output_data = vec![];
725            // generate each input `PartitionRange`
726            for part_id in 0..rng.usize(0..part_cnt_bound) {
727                // generate each `PartitionRange`'s timestamp range
728                let (start, end) = if descending {
729                    let end = bound_val
730                        .map(
731                            |i| i
732                            .checked_sub(rng.i64(0..range_offset_bound))
733                            .expect("Bad luck, fuzzy test generate data that will overflow, change seed and try again")
734                        )
735                        .unwrap_or_else(|| rng.i64(-100000000..100000000));
736                    bound_val = Some(end);
737                    let start = end - rng.i64(1..range_size_bound);
738                    let start = Timestamp::new(start, unit.into());
739                    let end = Timestamp::new(end, unit.into());
740                    (start, end)
741                } else {
742                    let start = bound_val
743                        .map(|i| i + rng.i64(0..range_offset_bound))
744                        .unwrap_or_else(|| rng.i64(..));
745                    bound_val = Some(start);
746                    let end = start + rng.i64(1..range_size_bound);
747                    let start = Timestamp::new(start, unit.into());
748                    let end = Timestamp::new(end, unit.into());
749                    (start, end)
750                };
751                assert!(start < end);
752
753                let mut per_part_sort_data = vec![];
754                let mut batches = vec![];
755                for _batch_idx in 0..rng.usize(1..batch_cnt_bound) {
756                    let cnt = rng.usize(0..batch_size_bound) + 1;
757                    let iter = 0..rng.usize(0..cnt);
758                    let mut data_gen = iter
759                        .map(|_| rng.i64(start.value()..end.value()))
760                        .collect_vec();
761                    if data_gen.is_empty() {
762                        // current batch is empty, skip
763                        continue;
764                    }
765                    // mito always sort on ASC order
766                    data_gen.sort();
767                    per_part_sort_data.extend(data_gen.clone());
768                    let arr = new_ts_array(unit, data_gen.clone());
769                    let batch = DfRecordBatch::try_new(schema.clone(), vec![arr]).unwrap();
770                    batches.push(batch);
771                }
772
773                let range = PartitionRange {
774                    start,
775                    end,
776                    num_rows: batches.iter().map(|b| b.num_rows()).sum(),
777                    identifier: part_id,
778                };
779                input_ranged_data.push((range, batches));
780
781                output_ranges.push(range);
782                if per_part_sort_data.is_empty() {
783                    continue;
784                }
785                output_data.extend_from_slice(&per_part_sort_data);
786            }
787
788            // adjust output data with adjacent PartitionRanges
789            let mut output_data_iter = output_data.iter().peekable();
790            let mut output_data = vec![];
791            for range in output_ranges.clone() {
792                let mut cur_data = vec![];
793                while let Some(val) = output_data_iter.peek() {
794                    if **val < range.start.value() || **val >= range.end.value() {
795                        break;
796                    }
797                    cur_data.push(*output_data_iter.next().unwrap());
798                }
799
800                if cur_data.is_empty() {
801                    continue;
802                }
803
804                if descending {
805                    cur_data.sort_by(|a, b| b.cmp(a));
806                } else {
807                    cur_data.sort();
808                }
809                output_data.push(cur_data);
810            }
811
812            let expected_output = output_data
813                .into_iter()
814                .map(|a| {
815                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, a)]).unwrap()
816                })
817                .map(|rb| {
818                    // trim expected output with limit
819                    if let Some(limit) = limit
820                        && rb.num_rows() > limit
821                    {
822                        rb.slice(0, limit)
823                    } else {
824                        rb
825                    }
826                })
827                .collect_vec();
828
829            test_cases.push((
830                case_id,
831                unit,
832                input_ranged_data,
833                schema,
834                opt,
835                limit,
836                expected_output,
837            ));
838        }
839
840        for (case_id, _unit, input_ranged_data, schema, opt, limit, expected_output) in test_cases {
841            run_test(
842                case_id,
843                input_ranged_data,
844                schema,
845                opt,
846                limit,
847                expected_output,
848            )
849            .await;
850        }
851    }
852
853    #[tokio::test]
854    async fn simple_case() {
855        let testcases = vec![
856            (
857                TimeUnit::Millisecond,
858                vec![
859                    ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]]),
860                    ((5, 10), vec![vec![5, 6], vec![7, 8]]),
861                ],
862                false,
863                None,
864                vec![vec![1, 2, 3, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9]],
865            ),
866            (
867                TimeUnit::Millisecond,
868                vec![
869                    ((5, 10), vec![vec![5, 6], vec![7, 8, 9]]),
870                    ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
871                ],
872                true,
873                None,
874                vec![vec![9, 8, 7, 6, 5], vec![8, 7, 6, 5, 4, 3, 2, 1]],
875            ),
876            (
877                TimeUnit::Millisecond,
878                vec![
879                    ((5, 10), vec![]),
880                    ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
881                ],
882                true,
883                None,
884                vec![vec![8, 7, 6, 5, 4, 3, 2, 1]],
885            ),
886            (
887                TimeUnit::Millisecond,
888                vec![
889                    ((15, 20), vec![vec![17, 18, 19]]),
890                    ((10, 15), vec![]),
891                    ((5, 10), vec![]),
892                    ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
893                ],
894                true,
895                None,
896                vec![vec![19, 18, 17], vec![8, 7, 6, 5, 4, 3, 2, 1]],
897            ),
898            (
899                TimeUnit::Millisecond,
900                vec![
901                    ((15, 20), vec![]),
902                    ((10, 15), vec![]),
903                    ((5, 10), vec![]),
904                    ((0, 10), vec![]),
905                ],
906                true,
907                None,
908                vec![],
909            ),
910            (
911                TimeUnit::Millisecond,
912                vec![
913                    (
914                        (15, 20),
915                        vec![vec![15, 17, 19, 10, 11, 12, 5, 6, 7, 8, 9, 1, 2, 3, 4]],
916                    ),
917                    ((10, 15), vec![]),
918                    ((5, 10), vec![]),
919                    ((0, 10), vec![]),
920                ],
921                true,
922                None,
923                vec![
924                    vec![19, 17, 15],
925                    vec![12, 11, 10],
926                    vec![9, 8, 7, 6, 5],
927                    vec![4, 3, 2, 1],
928                ],
929            ),
930            (
931                TimeUnit::Millisecond,
932                vec![
933                    (
934                        (15, 20),
935                        vec![vec![15, 17, 19, 10, 11, 12, 5, 6, 7, 8, 9, 1, 2, 3, 4]],
936                    ),
937                    ((10, 15), vec![]),
938                    ((5, 10), vec![]),
939                    ((0, 10), vec![]),
940                ],
941                true,
942                Some(2),
943                vec![vec![19, 17], vec![12, 11], vec![9, 8], vec![4, 3]],
944            ),
945        ];
946
947        for (identifier, (unit, input_ranged_data, descending, limit, expected_output)) in
948            testcases.into_iter().enumerate()
949        {
950            let schema = Schema::new(vec![Field::new(
951                "ts",
952                DataType::Timestamp(unit, None),
953                false,
954            )]);
955            let schema = Arc::new(schema);
956            let opt = SortOptions {
957                descending,
958                ..Default::default()
959            };
960
961            let input_ranged_data = input_ranged_data
962                .into_iter()
963                .map(|(range, data)| {
964                    let part = PartitionRange {
965                        start: Timestamp::new(range.0, unit.into()),
966                        end: Timestamp::new(range.1, unit.into()),
967                        num_rows: data.iter().map(|b| b.len()).sum(),
968                        identifier,
969                    };
970
971                    let batches = data
972                        .into_iter()
973                        .map(|b| {
974                            let arr = new_ts_array(unit, b);
975                            DfRecordBatch::try_new(schema.clone(), vec![arr]).unwrap()
976                        })
977                        .collect_vec();
978                    (part, batches)
979                })
980                .collect_vec();
981
982            let expected_output = expected_output
983                .into_iter()
984                .map(|a| {
985                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, a)]).unwrap()
986                })
987                .collect_vec();
988
989            run_test(
990                identifier,
991                input_ranged_data,
992                schema.clone(),
993                opt,
994                limit,
995                expected_output,
996            )
997            .await;
998        }
999    }
1000
1001    #[allow(clippy::print_stdout)]
1002    async fn run_test(
1003        case_id: usize,
1004        input_ranged_data: Vec<(PartitionRange, Vec<DfRecordBatch>)>,
1005        schema: SchemaRef,
1006        opt: SortOptions,
1007        limit: Option<usize>,
1008        expected_output: Vec<DfRecordBatch>,
1009    ) {
1010        for rb in &expected_output {
1011            if let Some(limit) = limit {
1012                assert!(
1013                    rb.num_rows() <= limit,
1014                    "Expect row count in expected output's batch({}) <= limit({})",
1015                    rb.num_rows(),
1016                    limit
1017                );
1018            }
1019        }
1020        let (ranges, batches): (Vec<_>, Vec<_>) = input_ranged_data.clone().into_iter().unzip();
1021
1022        let batches = batches
1023            .into_iter()
1024            .flat_map(|mut cols| {
1025                cols.push(DfRecordBatch::new_empty(schema.clone()));
1026                cols
1027            })
1028            .collect_vec();
1029        let mock_input = MockInputExec::new(batches, schema.clone());
1030
1031        let exec = PartSortExec::new(
1032            PhysicalSortExpr {
1033                expr: Arc::new(Column::new("ts", 0)),
1034                options: opt,
1035            },
1036            limit,
1037            vec![ranges.clone()],
1038            Arc::new(mock_input),
1039        );
1040
1041        let exec_stream = exec.execute(0, Arc::new(TaskContext::default())).unwrap();
1042
1043        let real_output = exec_stream.map(|r| r.unwrap()).collect::<Vec<_>>().await;
1044        // a makeshift solution for compare large data
1045        if real_output != expected_output {
1046            let mut first_diff = 0;
1047            for (idx, (lhs, rhs)) in real_output.iter().zip(expected_output.iter()).enumerate() {
1048                if lhs != rhs {
1049                    first_diff = idx;
1050                    break;
1051                }
1052            }
1053            println!("first diff batch at {}", first_diff);
1054            println!(
1055                "ranges: {:?}",
1056                ranges
1057                    .into_iter()
1058                    .map(|r| (r.start.to_chrono_datetime(), r.end.to_chrono_datetime()))
1059                    .enumerate()
1060                    .collect::<Vec<_>>()
1061            );
1062
1063            let mut full_msg = String::new();
1064            {
1065                let mut buf = Vec::with_capacity(10 * real_output.len());
1066                for batch in real_output.iter().skip(first_diff) {
1067                    let mut rb_json: Vec<u8> = Vec::new();
1068                    let mut writer = ArrayWriter::new(&mut rb_json);
1069                    writer.write(batch).unwrap();
1070                    writer.finish().unwrap();
1071                    buf.append(&mut rb_json);
1072                    buf.push(b',');
1073                }
1074                // TODO(discord9): better ways to print buf
1075                let buf = String::from_utf8_lossy(&buf);
1076                full_msg += &format!("\ncase_id:{case_id}, real_output \n{buf}\n");
1077            }
1078            {
1079                let mut buf = Vec::with_capacity(10 * real_output.len());
1080                for batch in expected_output.iter().skip(first_diff) {
1081                    let mut rb_json: Vec<u8> = Vec::new();
1082                    let mut writer = ArrayWriter::new(&mut rb_json);
1083                    writer.write(batch).unwrap();
1084                    writer.finish().unwrap();
1085                    buf.append(&mut rb_json);
1086                    buf.push(b',');
1087                }
1088                let buf = String::from_utf8_lossy(&buf);
1089                full_msg += &format!("case_id:{case_id}, expected_output \n{buf}");
1090            }
1091            panic!(
1092                "case_{} failed, opt: {:?},\n real output has {} batches, {} rows, expected has {} batches with {} rows\nfull msg: {}",
1093                case_id, opt,
1094                real_output.len(),
1095                real_output.iter().map(|x|x.num_rows()).sum::<usize>(),
1096                expected_output.len(),
1097                expected_output.iter().map(|x|x.num_rows()).sum::<usize>(), full_msg
1098            );
1099        }
1100    }
1101}