query/
part_sort.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Module for sorting input data within each [`PartitionRange`].
16//!
17//! This module defines the [`PartSortExec`] execution plan, which sorts each
18//! partition ([`PartitionRange`]) independently based on the provided physical
19//! sort expressions.
20
21use std::any::Any;
22use std::pin::Pin;
23use std::sync::Arc;
24use std::task::{Context, Poll};
25
26use arrow::array::ArrayRef;
27use arrow::compute::{concat, concat_batches, take_record_batch};
28use arrow_schema::SchemaRef;
29use common_recordbatch::{DfRecordBatch, DfSendableRecordBatchStream};
30use datafusion::common::arrow::compute::sort_to_indices;
31use datafusion::execution::memory_pool::{MemoryConsumer, MemoryReservation};
32use datafusion::execution::{RecordBatchStream, TaskContext};
33use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
34use datafusion::physical_plan::{
35    DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, TopK,
36};
37use datafusion_common::{internal_err, DataFusionError};
38use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr};
39use futures::{Stream, StreamExt};
40use itertools::Itertools;
41use snafu::location;
42use store_api::region_engine::PartitionRange;
43
44use crate::{array_iter_helper, downcast_ts_array};
45
/// Sorts the input independently within each [`PartitionRange`].
///
/// One list of ranges is supplied per output partition. While streaming,
/// rows are buffered for the current range and a sorted batch is emitted once
/// a row outside that range is seen (see `PartSortStream::split_batch`).
/// Range boundaries are therefore detected from the sort column values;
/// empty input batches are simply skipped.
///
/// When `limit` is set, it is applied to every range separately, not to the
/// output as a whole.
#[derive(Debug, Clone)]
pub struct PartSortExec {
    /// Physical sort expression (that is, sort by timestamp).
    expression: PhysicalSortExpr,
    /// Optional per-range row limit; `None` keeps every row of a range.
    limit: Option<usize>,
    /// Child plan that produces the data to sort.
    input: Arc<dyn ExecutionPlan>,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Ranges to process, indexed by partition: `partition_ranges[p]` is the
    /// list of ranges walked, in order, by the stream of partition `p`.
    partition_ranges: Vec<Vec<PartitionRange>>,
    /// Plan properties, built in `new` from the input's properties.
    properties: PlanProperties,
}
62
63impl PartSortExec {
64    pub fn new(
65        expression: PhysicalSortExpr,
66        limit: Option<usize>,
67        partition_ranges: Vec<Vec<PartitionRange>>,
68        input: Arc<dyn ExecutionPlan>,
69    ) -> Self {
70        let metrics = ExecutionPlanMetricsSet::new();
71        let properties = input.properties();
72        let properties = PlanProperties::new(
73            input.equivalence_properties().clone(),
74            input.output_partitioning().clone(),
75            properties.emission_type,
76            properties.boundedness,
77        );
78
79        Self {
80            expression,
81            limit,
82            input,
83            metrics,
84            partition_ranges,
85            properties,
86        }
87    }
88
89    pub fn to_stream(
90        &self,
91        context: Arc<TaskContext>,
92        partition: usize,
93    ) -> datafusion_common::Result<DfSendableRecordBatchStream> {
94        let input_stream: DfSendableRecordBatchStream =
95            self.input.execute(partition, context.clone())?;
96
97        if partition >= self.partition_ranges.len() {
98            internal_err!(
99                "Partition index out of range: {} >= {} at {}",
100                partition,
101                self.partition_ranges.len(),
102                snafu::location!()
103            )?;
104        }
105
106        let df_stream = Box::pin(PartSortStream::new(
107            context,
108            self,
109            self.limit,
110            input_stream,
111            self.partition_ranges[partition].clone(),
112            partition,
113        )?) as _;
114
115        Ok(df_stream)
116    }
117}
118
119impl DisplayAs for PartSortExec {
120    fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121        write!(
122            f,
123            "PartSortExec: expr={} num_ranges={}",
124            self.expression,
125            self.partition_ranges.len(),
126        )?;
127        if let Some(limit) = self.limit {
128            write!(f, " limit={}", limit)?;
129        }
130        Ok(())
131    }
132}
133
134impl ExecutionPlan for PartSortExec {
135    fn name(&self) -> &str {
136        "PartSortExec"
137    }
138
139    fn as_any(&self) -> &dyn Any {
140        self
141    }
142
143    fn schema(&self) -> SchemaRef {
144        self.input.schema()
145    }
146
147    fn properties(&self) -> &PlanProperties {
148        &self.properties
149    }
150
151    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
152        vec![&self.input]
153    }
154
155    fn with_new_children(
156        self: Arc<Self>,
157        children: Vec<Arc<dyn ExecutionPlan>>,
158    ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
159        let new_input = if let Some(first) = children.first() {
160            first
161        } else {
162            internal_err!("No children found")?
163        };
164        Ok(Arc::new(Self::new(
165            self.expression.clone(),
166            self.limit,
167            self.partition_ranges.clone(),
168            new_input.clone(),
169        )))
170    }
171
172    fn execute(
173        &self,
174        partition: usize,
175        context: Arc<TaskContext>,
176    ) -> datafusion_common::Result<DfSendableRecordBatchStream> {
177        self.to_stream(context, partition)
178    }
179
180    fn metrics(&self) -> Option<MetricsSet> {
181        Some(self.metrics.clone_inner())
182    }
183
184    /// # Explain
185    ///
186    /// This plan needs to be executed on each partition independently,
187    /// and is expected to run directly on storage engine's output
188    /// distribution / partition.
189    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
190        vec![false]
191    }
192}
193
/// Buffer holding the rows of the range currently being collected.
enum PartSortBuffer {
    /// Keeps every pushed batch; used when no limit is set.
    All(Vec<DfRecordBatch>),
    /// TopK buffer with row count.
    ///
    /// The counter is the number of rows pushed so far, not the number of
    /// rows retained: the heap only keeps `k` elements. It is therefore only
    /// meaningful for the emptiness check.
    Top(TopK, usize),
}
202
203impl PartSortBuffer {
204    pub fn is_empty(&self) -> bool {
205        match self {
206            PartSortBuffer::All(v) => v.is_empty(),
207            PartSortBuffer::Top(_, cnt) => *cnt == 0,
208        }
209    }
210}
211
/// Stream that sorts the input of one partition range by range.
struct PartSortStream {
    /// Memory pool for this stream
    reservation: MemoryReservation,
    /// Rows collected for the range currently being processed.
    buffer: PartSortBuffer,
    /// Sort expression evaluated on every batch.
    expression: PhysicalSortExpr,
    /// Optional per-range row limit.
    limit: Option<usize>,
    /// Number of rows emitted by the no-limit sort path so far.
    produced: usize,
    /// Input stream being sorted.
    input: DfSendableRecordBatchStream,
    /// Set once `input` has returned `None`.
    input_complete: bool,
    /// Schema of the input, which is also the output schema.
    schema: SchemaRef,
    /// Ranges this stream walks through, in order.
    partition_ranges: Vec<PartitionRange>,
    /// Index of the partition this stream belongs to.
    #[allow(dead_code)] // this is used under #[debug_assertions]
    partition: usize,
    /// Index into `partition_ranges` of the range currently being collected.
    cur_part_idx: usize,
    /// Remainder of a split batch that still has to be assigned to a range.
    evaluating_batch: Option<DfRecordBatch>,
    /// Baseline metrics recorded on every poll.
    metrics: BaselineMetrics,
    /// Task context, kept to rebuild the `TopK` buffer after each range.
    context: Arc<TaskContext>,
    /// Metrics set of the owning plan, passed to each rebuilt `TopK` buffer.
    root_metrics: ExecutionPlanMetricsSet,
}
231
232impl PartSortStream {
233    fn new(
234        context: Arc<TaskContext>,
235        sort: &PartSortExec,
236        limit: Option<usize>,
237        input: DfSendableRecordBatchStream,
238        partition_ranges: Vec<PartitionRange>,
239        partition: usize,
240    ) -> datafusion_common::Result<Self> {
241        let buffer = if let Some(limit) = limit {
242            PartSortBuffer::Top(
243                TopK::try_new(
244                    partition,
245                    sort.schema().clone(),
246                    LexOrdering::new(vec![sort.expression.clone()]),
247                    limit,
248                    context.session_config().batch_size(),
249                    context.runtime_env(),
250                    &sort.metrics,
251                )?,
252                0,
253            )
254        } else {
255            PartSortBuffer::All(Vec::new())
256        };
257
258        Ok(Self {
259            reservation: MemoryConsumer::new("PartSortStream".to_string())
260                .register(&context.runtime_env().memory_pool),
261            buffer,
262            expression: sort.expression.clone(),
263            limit,
264            produced: 0,
265            input,
266            input_complete: false,
267            schema: sort.input.schema(),
268            partition_ranges,
269            partition,
270            cur_part_idx: 0,
271            evaluating_batch: None,
272            metrics: BaselineMetrics::new(&sort.metrics, partition),
273            context,
274            root_metrics: sort.metrics.clone(),
275        })
276    }
277}
278
/// Verifies that the two values of `$arr` at the indices in `$min_max_idx`
/// lie inside `$cur_range` (start inclusive, end exclusive).
///
/// `$t` is the arrow primitive type of the array and `$unit` its time unit;
/// both bounds of the range must use that unit. The two indices may be given
/// in either order. Any violation is raised with `internal_err!(..)?`, so the
/// macro must be expanded inside a function returning a DataFusion `Result`.
macro_rules! array_check_helper {
    ($t:ty, $unit:expr, $arr:expr, $cur_range:expr, $min_max_idx:expr) => {{
        if $cur_range.start.unit().as_arrow_time_unit() != $unit
            || $cur_range.end.unit().as_arrow_time_unit() != $unit
        {
            internal_err!(
                "PartitionRange unit mismatch, expect {:?}, found {:?}",
                $cur_range.start.unit(),
                $unit
            )?;
        }
        let typed = $arr
            .as_any()
            .downcast_ref::<arrow::array::PrimitiveArray<$t>>()
            .unwrap();

        // The indices come from a sort that may be ascending or descending,
        // so order the two values before comparing them with the range.
        let first = typed.value($min_max_idx.0);
        let second = typed.value($min_max_idx.1);
        let (lowest, highest) = if first < second {
            (first, second)
        } else {
            (second, first)
        };
        let range_start = $cur_range.start.value();
        let range_end = $cur_range.end.value();
        // note that PartitionRange is left inclusive and right exclusive
        if lowest < range_start || highest >= range_end {
            internal_err!(
                "Sort column min/max value out of partition range: sort_column.min_max=[{:?}, {:?}] not in PartitionRange=[{:?}, {:?}]",
                lowest,
                highest,
                range_start,
                range_end
            )?;
        }
    }};
}
316
317impl PartSortStream {
318    /// check whether the sort column's min/max value is within the partition range
319    fn check_in_range(
320        &self,
321        sort_column: &ArrayRef,
322        min_max_idx: (usize, usize),
323    ) -> datafusion_common::Result<()> {
324        if self.cur_part_idx >= self.partition_ranges.len() {
325            internal_err!(
326                "Partition index out of range: {} >= {} at {}",
327                self.cur_part_idx,
328                self.partition_ranges.len(),
329                snafu::location!()
330            )?;
331        }
332        let cur_range = self.partition_ranges[self.cur_part_idx];
333
334        downcast_ts_array!(
335            sort_column.data_type() => (array_check_helper, sort_column, cur_range, min_max_idx),
336            _ => internal_err!(
337                "Unsupported data type for sort column: {:?}",
338                sort_column.data_type()
339            )?,
340        );
341
342        Ok(())
343    }
344
345    /// Try find data whose value exceeds the current partition range.
346    ///
347    /// Returns `None` if no such data is found, and `Some(idx)` where idx points to
348    /// the first data that exceeds the current partition range.
349    fn try_find_next_range(
350        &self,
351        sort_column: &ArrayRef,
352    ) -> datafusion_common::Result<Option<usize>> {
353        if sort_column.is_empty() {
354            return Ok(Some(0));
355        }
356
357        // check if the current partition index is out of range
358        if self.cur_part_idx >= self.partition_ranges.len() {
359            internal_err!(
360                "Partition index out of range: {} >= {} at {}",
361                self.cur_part_idx,
362                self.partition_ranges.len(),
363                snafu::location!()
364            )?;
365        }
366        let cur_range = self.partition_ranges[self.cur_part_idx];
367
368        let sort_column_iter = downcast_ts_array!(
369            sort_column.data_type() => (array_iter_helper, sort_column),
370            _ => internal_err!(
371                "Unsupported data type for sort column: {:?}",
372                sort_column.data_type()
373            )?,
374        );
375
376        for (idx, val) in sort_column_iter {
377            // ignore vacant time index data
378            if let Some(val) = val {
379                if val >= cur_range.end.value() || val < cur_range.start.value() {
380                    return Ok(Some(idx));
381                }
382            }
383        }
384
385        Ok(None)
386    }
387
388    fn push_buffer(&mut self, batch: DfRecordBatch) -> datafusion_common::Result<()> {
389        match &mut self.buffer {
390            PartSortBuffer::All(v) => v.push(batch),
391            PartSortBuffer::Top(top, cnt) => {
392                *cnt += batch.num_rows();
393                top.insert_batch(batch)?;
394            }
395        }
396
397        Ok(())
398    }
399
400    /// Sort and clear the buffer and return the sorted record batch
401    ///
402    /// this function will return a empty record batch if the buffer is empty
403    fn sort_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
404        match &mut self.buffer {
405            PartSortBuffer::All(_) => self.sort_all_buffer(),
406            PartSortBuffer::Top(_, _) => self.sort_top_buffer(),
407        }
408    }
409
410    /// Internal method for sorting `All` buffer (without limit).
411    fn sort_all_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
412        let PartSortBuffer::All(buffer) =
413            std::mem::replace(&mut self.buffer, PartSortBuffer::All(Vec::new()))
414        else {
415            unreachable!("buffer type is checked before and should be All variant")
416        };
417
418        if buffer.is_empty() {
419            return Ok(DfRecordBatch::new_empty(self.schema.clone()));
420        }
421        let mut sort_columns = Vec::with_capacity(buffer.len());
422        let mut opt = None;
423        for batch in buffer.iter() {
424            let sort_column = self.expression.evaluate_to_sort_column(batch)?;
425            opt = opt.or(sort_column.options);
426            sort_columns.push(sort_column.values);
427        }
428
429        let sort_column =
430            concat(&sort_columns.iter().map(|a| a.as_ref()).collect_vec()).map_err(|e| {
431                DataFusionError::ArrowError(
432                    e,
433                    Some(format!("Fail to concat sort columns at {}", location!())),
434                )
435            })?;
436
437        let indices = sort_to_indices(&sort_column, opt, self.limit).map_err(|e| {
438            DataFusionError::ArrowError(
439                e,
440                Some(format!("Fail to sort to indices at {}", location!())),
441            )
442        })?;
443        if indices.is_empty() {
444            return Ok(DfRecordBatch::new_empty(self.schema.clone()));
445        }
446
447        self.check_in_range(
448            &sort_column,
449            (
450                indices.value(0) as usize,
451                indices.value(indices.len() - 1) as usize,
452            ),
453        )
454        .inspect_err(|_e| {
455            #[cfg(debug_assertions)]
456            common_telemetry::error!(
457                "Fail to check sort column in range at {}, current_idx: {}, num_rows: {}, err: {}",
458                self.partition,
459                self.cur_part_idx,
460                sort_column.len(),
461                _e
462            );
463        })?;
464
465        // reserve memory for the concat input and sorted output
466        let total_mem: usize = buffer.iter().map(|r| r.get_array_memory_size()).sum();
467        self.reservation.try_grow(total_mem * 2)?;
468
469        let full_input = concat_batches(&self.schema, &buffer).map_err(|e| {
470            DataFusionError::ArrowError(
471                e,
472                Some(format!(
473                    "Fail to concat input batches when sorting at {}",
474                    location!()
475                )),
476            )
477        })?;
478
479        let sorted = take_record_batch(&full_input, &indices).map_err(|e| {
480            DataFusionError::ArrowError(
481                e,
482                Some(format!(
483                    "Fail to take result record batch when sorting at {}",
484                    location!()
485                )),
486            )
487        })?;
488
489        self.produced += sorted.num_rows();
490        drop(full_input);
491        // here remove both buffer and full_input memory
492        self.reservation.shrink(2 * total_mem);
493        Ok(sorted)
494    }
495
496    /// Internal method for sorting `Top` buffer (with limit).
497    fn sort_top_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
498        let new_top_buffer = TopK::try_new(
499            self.partition,
500            self.schema().clone(),
501            LexOrdering::new(vec![self.expression.clone()]),
502            self.limit.unwrap(),
503            self.context.session_config().batch_size(),
504            self.context.runtime_env(),
505            &self.root_metrics,
506        )?;
507        let PartSortBuffer::Top(top_k, _) =
508            std::mem::replace(&mut self.buffer, PartSortBuffer::Top(new_top_buffer, 0))
509        else {
510            unreachable!("buffer type is checked before and should be Top variant")
511        };
512
513        let mut result_stream = top_k.emit()?;
514        let mut placeholder_ctx = std::task::Context::from_waker(futures::task::noop_waker_ref());
515        let mut results = vec![];
516        // according to the current implementation of `TopK`, the result stream will always be ready
517        loop {
518            match result_stream.poll_next_unpin(&mut placeholder_ctx) {
519                Poll::Ready(Some(batch)) => {
520                    let batch = batch?;
521                    results.push(batch);
522                }
523                Poll::Pending => {
524                    #[cfg(debug_assertions)]
525                    unreachable!("TopK result stream should always be ready")
526                }
527                Poll::Ready(None) => {
528                    break;
529                }
530            }
531        }
532
533        let concat_batch = concat_batches(&self.schema, &results).map_err(|e| {
534            DataFusionError::ArrowError(
535                e,
536                Some(format!(
537                    "Fail to concat top k result record batch when sorting at {}",
538                    location!()
539                )),
540            )
541        })?;
542
543        Ok(concat_batch)
544    }
545
546    /// Try to split the input batch if it contains data that exceeds the current partition range.
547    ///
548    /// When the input batch contains data that exceeds the current partition range, this function
549    /// will split the input batch into two parts, the first part is within the current partition
550    /// range will be merged and sorted with previous buffer, and the second part will be registered
551    /// to `evaluating_batch` for next polling.
552    ///
553    /// Returns `None` if the input batch is empty or fully within the current partition range, and
554    /// `Some(batch)` otherwise.
555    fn split_batch(
556        &mut self,
557        batch: DfRecordBatch,
558    ) -> datafusion_common::Result<Option<DfRecordBatch>> {
559        if batch.num_rows() == 0 {
560            return Ok(None);
561        }
562
563        let sort_column = self
564            .expression
565            .expr
566            .evaluate(&batch)?
567            .into_array(batch.num_rows())?;
568
569        let next_range_idx = self.try_find_next_range(&sort_column)?;
570        let Some(idx) = next_range_idx else {
571            self.push_buffer(batch)?;
572            // keep polling input for next batch
573            return Ok(None);
574        };
575
576        let this_range = batch.slice(0, idx);
577        let remaining_range = batch.slice(idx, batch.num_rows() - idx);
578        if this_range.num_rows() != 0 {
579            self.push_buffer(this_range)?;
580        }
581        // mark end of current PartitionRange
582        let sorted_batch = self.sort_buffer();
583        // step to next proper PartitionRange
584        self.cur_part_idx += 1;
585        let next_sort_column = sort_column.slice(idx, batch.num_rows() - idx);
586        if self.try_find_next_range(&next_sort_column)?.is_some() {
587            // remaining batch still contains data that exceeds the current partition range
588            // register the remaining batch for next polling
589            self.evaluating_batch = Some(remaining_range);
590        } else {
591            // remaining batch is within the current partition range
592            // push to the buffer and continue polling
593            if remaining_range.num_rows() != 0 {
594                self.push_buffer(remaining_range)?;
595            }
596        }
597
598        sorted_batch.map(|x| if x.num_rows() == 0 { None } else { Some(x) })
599    }
600
601    pub fn poll_next_inner(
602        mut self: Pin<&mut Self>,
603        cx: &mut Context<'_>,
604    ) -> Poll<Option<datafusion_common::Result<DfRecordBatch>>> {
605        loop {
606            // no more input, sort the buffer and return
607            if self.input_complete {
608                if self.buffer.is_empty() {
609                    return Poll::Ready(None);
610                } else {
611                    return Poll::Ready(Some(self.sort_buffer()));
612                }
613            }
614
615            // if there is a remaining batch being evaluated from last run,
616            // split on it instead of fetching new batch
617            if let Some(evaluating_batch) = self.evaluating_batch.take()
618                && evaluating_batch.num_rows() != 0
619            {
620                if let Some(sorted_batch) = self.split_batch(evaluating_batch)? {
621                    return Poll::Ready(Some(Ok(sorted_batch)));
622                } else {
623                    continue;
624                }
625            }
626
627            // fetch next batch from input
628            let res = self.input.as_mut().poll_next(cx);
629            match res {
630                Poll::Ready(Some(Ok(batch))) => {
631                    if let Some(sorted_batch) = self.split_batch(batch)? {
632                        return Poll::Ready(Some(Ok(sorted_batch)));
633                    } else {
634                        continue;
635                    }
636                }
637                // input stream end, mark and continue
638                Poll::Ready(None) => {
639                    self.input_complete = true;
640                    continue;
641                }
642                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
643                Poll::Pending => return Poll::Pending,
644            }
645        }
646    }
647}
648
649impl Stream for PartSortStream {
650    type Item = datafusion_common::Result<DfRecordBatch>;
651
652    fn poll_next(
653        mut self: Pin<&mut Self>,
654        cx: &mut Context<'_>,
655    ) -> Poll<Option<datafusion_common::Result<DfRecordBatch>>> {
656        let result = self.as_mut().poll_next_inner(cx);
657        self.metrics.record_poll(result)
658    }
659}
660
661impl RecordBatchStream for PartSortStream {
662    fn schema(&self) -> SchemaRef {
663        self.schema.clone()
664    }
665}
666
667#[cfg(test)]
668mod test {
669    use std::sync::Arc;
670
671    use arrow::json::ArrayWriter;
672    use arrow_schema::{DataType, Field, Schema, SortOptions, TimeUnit};
673    use common_time::Timestamp;
674    use datafusion_physical_expr::expressions::Column;
675    use futures::StreamExt;
676    use store_api::region_engine::PartitionRange;
677
678    use super::*;
679    use crate::test_util::{new_ts_array, MockInputExec};
680
    /// Randomised test of the partition sort.
    ///
    /// With a fixed seed it generates `test_cnt` cases. Each case picks a
    /// sort direction, an optional limit and a time unit, then builds a
    /// sequence of partition ranges with random batches inside each range,
    /// together with the expected per-range sorted output. Every case is
    /// then handed to `run_test`.
    #[tokio::test]
    async fn fuzzy_test() {
        let test_cnt = 100;
        // bound for total count of PartitionRange
        let part_cnt_bound = 100;
        // bound for timestamp range size and offset for each PartitionRange
        let range_size_bound = 100;
        let range_offset_bound = 100;
        // bound for batch count and size within each PartitionRange
        let batch_cnt_bound = 20;
        let batch_size_bound = 100;

        // fixed seed keeps the generated cases reproducible
        let mut rng = fastrand::Rng::new();
        rng.seed(1337);

        let mut test_cases = Vec::new();

        for case_id in 0..test_cnt {
            // start (ascending) or end (descending) of the previous range
            let mut bound_val: Option<i64> = None;
            let descending = rng.bool();
            let nulls_first = rng.bool();
            let opt = SortOptions {
                descending,
                nulls_first,
            };
            let limit = if rng.bool() {
                Some(rng.usize(0..batch_cnt_bound * batch_size_bound))
            } else {
                None
            };
            // NOTE(review): `0..3` only yields 0, 1 and 2, so the fallback arm
            // (`TimeUnit::Nanosecond`) is never selected — confirm whether
            // `0..4` was intended.
            let unit = match rng.u8(0..3) {
                0 => TimeUnit::Second,
                1 => TimeUnit::Millisecond,
                2 => TimeUnit::Microsecond,
                _ => TimeUnit::Nanosecond,
            };

            let schema = Schema::new(vec![Field::new(
                "ts",
                DataType::Timestamp(unit, None),
                false,
            )]);
            let schema = Arc::new(schema);

            let mut input_ranged_data = vec![];
            let mut output_ranges = vec![];
            let mut output_data = vec![];
            // generate each input `PartitionRange`
            for part_id in 0..rng.usize(0..part_cnt_bound) {
                // generate each `PartitionRange`'s timestamp range
                let (start, end) = if descending {
                    // each range ends at or before the end of the previous one
                    let end = bound_val
                        .map(
                            |i| i
                            .checked_sub(rng.i64(0..range_offset_bound))
                            .expect("Bad luck, fuzzy test generate data that will overflow, change seed and try again")
                        )
                        .unwrap_or_else(|| rng.i64(-100000000..100000000));
                    bound_val = Some(end);
                    let start = end - rng.i64(1..range_size_bound);
                    let start = Timestamp::new(start, unit.into());
                    let end = Timestamp::new(end, unit.into());
                    (start, end)
                } else {
                    // each range starts at or after the start of the previous one
                    let start = bound_val
                        .map(|i| i + rng.i64(0..range_offset_bound))
                        .unwrap_or_else(|| rng.i64(..));
                    bound_val = Some(start);
                    let end = start + rng.i64(1..range_size_bound);
                    let start = Timestamp::new(start, unit.into());
                    let end = Timestamp::new(end, unit.into());
                    (start, end)
                };
                assert!(start < end);

                let mut per_part_sort_data = vec![];
                let mut batches = vec![];
                for _batch_idx in 0..rng.usize(1..batch_cnt_bound) {
                    let cnt = rng.usize(0..batch_size_bound) + 1;
                    let iter = 0..rng.usize(0..cnt);
                    // every generated value lies inside [start, end)
                    let mut data_gen = iter
                        .map(|_| rng.i64(start.value()..end.value()))
                        .collect_vec();
                    if data_gen.is_empty() {
                        // current batch is empty, skip
                        continue;
                    }
                    // mito always sort on ASC order
                    data_gen.sort();
                    per_part_sort_data.extend(data_gen.clone());
                    let arr = new_ts_array(unit, data_gen.clone());
                    let batch = DfRecordBatch::try_new(schema.clone(), vec![arr]).unwrap();
                    batches.push(batch);
                }

                let range = PartitionRange {
                    start,
                    end,
                    num_rows: batches.iter().map(|b| b.num_rows()).sum(),
                    identifier: part_id,
                };
                input_ranged_data.push((range, batches));

                output_ranges.push(range);
                if per_part_sort_data.is_empty() {
                    continue;
                }
                output_data.extend_from_slice(&per_part_sort_data);
            }

            // adjust output data with adjacent PartitionRanges
            let mut output_data_iter = output_data.iter().peekable();
            let mut output_data = vec![];
            for range in output_ranges.clone() {
                // take values for as long as they belong to this range
                let mut cur_data = vec![];
                while let Some(val) = output_data_iter.peek() {
                    if **val < range.start.value() || **val >= range.end.value() {
                        break;
                    }
                    cur_data.push(*output_data_iter.next().unwrap());
                }

                if cur_data.is_empty() {
                    continue;
                }

                if descending {
                    cur_data.sort_by(|a, b| b.cmp(a));
                } else {
                    cur_data.sort();
                }
                output_data.push(cur_data);
            }

            // one expected batch per non-empty group of values
            let expected_output = output_data
                .into_iter()
                .map(|a| {
                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, a)]).unwrap()
                })
                .map(|rb| {
                    // trim expected output with limit
                    if let Some(limit) = limit
                        && rb.num_rows() > limit
                    {
                        rb.slice(0, limit)
                    } else {
                        rb
                    }
                })
                .collect_vec();

            test_cases.push((
                case_id,
                unit,
                input_ranged_data,
                schema,
                opt,
                limit,
                expected_output,
            ));
        }

        // run the generated cases one after another
        for (case_id, _unit, input_ranged_data, schema, opt, limit, expected_output) in test_cases {
            run_test(
                case_id,
                input_ranged_data,
                schema,
                opt,
                limit,
                expected_output,
            )
            .await;
        }
    }
855
856    #[tokio::test]
857    async fn simple_case() {
858        let testcases = vec![
859            (
860                TimeUnit::Millisecond,
861                vec![
862                    ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]]),
863                    ((5, 10), vec![vec![5, 6], vec![7, 8]]),
864                ],
865                false,
866                None,
867                vec![vec![1, 2, 3, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9]],
868            ),
869            (
870                TimeUnit::Millisecond,
871                vec![
872                    ((5, 10), vec![vec![5, 6], vec![7, 8, 9]]),
873                    ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
874                ],
875                true,
876                None,
877                vec![vec![9, 8, 7, 6, 5], vec![8, 7, 6, 5, 4, 3, 2, 1]],
878            ),
879            (
880                TimeUnit::Millisecond,
881                vec![
882                    ((5, 10), vec![]),
883                    ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
884                ],
885                true,
886                None,
887                vec![vec![8, 7, 6, 5, 4, 3, 2, 1]],
888            ),
889            (
890                TimeUnit::Millisecond,
891                vec![
892                    ((15, 20), vec![vec![17, 18, 19]]),
893                    ((10, 15), vec![]),
894                    ((5, 10), vec![]),
895                    ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
896                ],
897                true,
898                None,
899                vec![vec![19, 18, 17], vec![8, 7, 6, 5, 4, 3, 2, 1]],
900            ),
901            (
902                TimeUnit::Millisecond,
903                vec![
904                    ((15, 20), vec![]),
905                    ((10, 15), vec![]),
906                    ((5, 10), vec![]),
907                    ((0, 10), vec![]),
908                ],
909                true,
910                None,
911                vec![],
912            ),
913            (
914                TimeUnit::Millisecond,
915                vec![
916                    (
917                        (15, 20),
918                        vec![vec![15, 17, 19, 10, 11, 12, 5, 6, 7, 8, 9, 1, 2, 3, 4]],
919                    ),
920                    ((10, 15), vec![]),
921                    ((5, 10), vec![]),
922                    ((0, 10), vec![]),
923                ],
924                true,
925                None,
926                vec![
927                    vec![19, 17, 15],
928                    vec![12, 11, 10],
929                    vec![9, 8, 7, 6, 5],
930                    vec![4, 3, 2, 1],
931                ],
932            ),
933            (
934                TimeUnit::Millisecond,
935                vec![
936                    (
937                        (15, 20),
938                        vec![vec![15, 17, 19, 10, 11, 12, 5, 6, 7, 8, 9, 1, 2, 3, 4]],
939                    ),
940                    ((10, 15), vec![]),
941                    ((5, 10), vec![]),
942                    ((0, 10), vec![]),
943                ],
944                true,
945                Some(2),
946                vec![vec![19, 17], vec![12, 11], vec![9, 8], vec![4, 3]],
947            ),
948        ];
949
950        for (identifier, (unit, input_ranged_data, descending, limit, expected_output)) in
951            testcases.into_iter().enumerate()
952        {
953            let schema = Schema::new(vec![Field::new(
954                "ts",
955                DataType::Timestamp(unit, None),
956                false,
957            )]);
958            let schema = Arc::new(schema);
959            let opt = SortOptions {
960                descending,
961                ..Default::default()
962            };
963
964            let input_ranged_data = input_ranged_data
965                .into_iter()
966                .map(|(range, data)| {
967                    let part = PartitionRange {
968                        start: Timestamp::new(range.0, unit.into()),
969                        end: Timestamp::new(range.1, unit.into()),
970                        num_rows: data.iter().map(|b| b.len()).sum(),
971                        identifier,
972                    };
973
974                    let batches = data
975                        .into_iter()
976                        .map(|b| {
977                            let arr = new_ts_array(unit, b);
978                            DfRecordBatch::try_new(schema.clone(), vec![arr]).unwrap()
979                        })
980                        .collect_vec();
981                    (part, batches)
982                })
983                .collect_vec();
984
985            let expected_output = expected_output
986                .into_iter()
987                .map(|a| {
988                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, a)]).unwrap()
989                })
990                .collect_vec();
991
992            run_test(
993                identifier,
994                input_ranged_data,
995                schema.clone(),
996                opt,
997                limit,
998                expected_output,
999            )
1000            .await;
1001        }
1002    }
1003
1004    #[allow(clippy::print_stdout)]
1005    async fn run_test(
1006        case_id: usize,
1007        input_ranged_data: Vec<(PartitionRange, Vec<DfRecordBatch>)>,
1008        schema: SchemaRef,
1009        opt: SortOptions,
1010        limit: Option<usize>,
1011        expected_output: Vec<DfRecordBatch>,
1012    ) {
1013        for rb in &expected_output {
1014            if let Some(limit) = limit {
1015                assert!(
1016                    rb.num_rows() <= limit,
1017                    "Expect row count in expected output's batch({}) <= limit({})",
1018                    rb.num_rows(),
1019                    limit
1020                );
1021            }
1022        }
1023        let (ranges, batches): (Vec<_>, Vec<_>) = input_ranged_data.clone().into_iter().unzip();
1024
1025        let batches = batches
1026            .into_iter()
1027            .flat_map(|mut cols| {
1028                cols.push(DfRecordBatch::new_empty(schema.clone()));
1029                cols
1030            })
1031            .collect_vec();
1032        let mock_input = MockInputExec::new(batches, schema.clone());
1033
1034        let exec = PartSortExec::new(
1035            PhysicalSortExpr {
1036                expr: Arc::new(Column::new("ts", 0)),
1037                options: opt,
1038            },
1039            limit,
1040            vec![ranges.clone()],
1041            Arc::new(mock_input),
1042        );
1043
1044        let exec_stream = exec.execute(0, Arc::new(TaskContext::default())).unwrap();
1045
1046        let real_output = exec_stream.map(|r| r.unwrap()).collect::<Vec<_>>().await;
1047        // a makeshift solution for compare large data
1048        if real_output != expected_output {
1049            let mut first_diff = 0;
1050            for (idx, (lhs, rhs)) in real_output.iter().zip(expected_output.iter()).enumerate() {
1051                if lhs != rhs {
1052                    first_diff = idx;
1053                    break;
1054                }
1055            }
1056            println!("first diff batch at {}", first_diff);
1057            println!(
1058                "ranges: {:?}",
1059                ranges
1060                    .into_iter()
1061                    .map(|r| (r.start.to_chrono_datetime(), r.end.to_chrono_datetime()))
1062                    .enumerate()
1063                    .collect::<Vec<_>>()
1064            );
1065
1066            let mut full_msg = String::new();
1067            {
1068                let mut buf = Vec::with_capacity(10 * real_output.len());
1069                for batch in real_output.iter().skip(first_diff) {
1070                    let mut rb_json: Vec<u8> = Vec::new();
1071                    let mut writer = ArrayWriter::new(&mut rb_json);
1072                    writer.write(batch).unwrap();
1073                    writer.finish().unwrap();
1074                    buf.append(&mut rb_json);
1075                    buf.push(b',');
1076                }
1077                // TODO(discord9): better ways to print buf
1078                let buf = String::from_utf8_lossy(&buf);
1079                full_msg += &format!("\ncase_id:{case_id}, real_output \n{buf}\n");
1080            }
1081            {
1082                let mut buf = Vec::with_capacity(10 * real_output.len());
1083                for batch in expected_output.iter().skip(first_diff) {
1084                    let mut rb_json: Vec<u8> = Vec::new();
1085                    let mut writer = ArrayWriter::new(&mut rb_json);
1086                    writer.write(batch).unwrap();
1087                    writer.finish().unwrap();
1088                    buf.append(&mut rb_json);
1089                    buf.push(b',');
1090                }
1091                let buf = String::from_utf8_lossy(&buf);
1092                full_msg += &format!("case_id:{case_id}, expected_output \n{buf}");
1093            }
1094            panic!(
1095                "case_{} failed, opt: {:?},\n real output has {} batches, {} rows, expected has {} batches with {} rows\nfull msg: {}",
1096                case_id, opt,
1097                real_output.len(),
1098                real_output.iter().map(|x|x.num_rows()).sum::<usize>(),
1099                expected_output.len(),
1100                expected_output.iter().map(|x|x.num_rows()).sum::<usize>(), full_msg
1101            );
1102        }
1103    }
1104}