query/part_sort.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Module for sorting input data within each [`PartitionRange`].
//!
//! This module defines the [`PartSortExec`] execution plan, which sorts each
//! partition ([`PartitionRange`]) independently based on the provided physical
//! sort expressions.
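//!
//! A rough sketch of the expected data flow (illustrative only; the ranges and batch
//! contents below are hypothetical):
//!
//! ```text
//! input :  b0 b1 [empty] b2 b3 [empty] ...
//!          \_ range 0 _/ \_ range 1 _/
//! output:  sorted(b0 + b1)  sorted(b2 + b3) ...
//! ```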

use std::any::Any;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use arrow::array::ArrayRef;
use arrow::compute::{concat, concat_batches, take_record_batch};
use arrow_schema::SchemaRef;
use common_recordbatch::{DfRecordBatch, DfSendableRecordBatchStream};
use datafusion::common::arrow::compute::sort_to_indices;
use datafusion::execution::memory_pool::{MemoryConsumer, MemoryReservation};
use datafusion::execution::{RecordBatchStream, TaskContext};
use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use datafusion::physical_plan::{
    DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, TopK,
};
use datafusion_common::{internal_err, DataFusionError};
use datafusion_physical_expr::PhysicalSortExpr;
use futures::{Stream, StreamExt};
use itertools::Itertools;
use snafu::location;
use store_api::region_engine::PartitionRange;

use crate::{array_iter_helper, downcast_ts_array};

/// Sort input within the given `PartitionRange`s.
///
/// Input is assumed to be segmented by empty `RecordBatch`es, each of which indicates that a new
/// `PartitionRange` is starting, and this operator sorts each `PartitionRange` independently.
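///
/// A minimal construction sketch (hedged: the column name, sort options, and the `input` plan
/// below are hypothetical, shown only to illustrate the constructor's arguments):
///
/// ```ignore
/// use datafusion_physical_expr::expressions::Column;
///
/// let expr = PhysicalSortExpr {
///     expr: Arc::new(Column::new("ts", 0)),
///     options: Default::default(),
/// };
/// // one `Vec<PartitionRange>` per output partition of `input`
/// let ranges: Vec<Vec<PartitionRange>> = vec![vec![]];
/// let exec = PartSortExec::new(expr, /* limit */ None, ranges, input);
/// ```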
#[derive(Debug, Clone)]
pub struct PartSortExec {
    /// Physical sort expression (that is, sort by timestamp)
    expression: PhysicalSortExpr,
    limit: Option<usize>,
    input: Arc<dyn ExecutionPlan>,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    partition_ranges: Vec<Vec<PartitionRange>>,
    properties: PlanProperties,
}

impl PartSortExec {
    pub fn new(
        expression: PhysicalSortExpr,
        limit: Option<usize>,
        partition_ranges: Vec<Vec<PartitionRange>>,
        input: Arc<dyn ExecutionPlan>,
    ) -> Self {
        let metrics = ExecutionPlanMetricsSet::new();
        let properties = input.properties();
        let properties = PlanProperties::new(
            input.equivalence_properties().clone(),
            input.output_partitioning().clone(),
            properties.emission_type,
            properties.boundedness,
        );

        Self {
            expression,
            limit,
            input,
            metrics,
            partition_ranges,
            properties,
        }
    }

    pub fn to_stream(
        &self,
        context: Arc<TaskContext>,
        partition: usize,
    ) -> datafusion_common::Result<DfSendableRecordBatchStream> {
        let input_stream: DfSendableRecordBatchStream =
            self.input.execute(partition, context.clone())?;

        if partition >= self.partition_ranges.len() {
            internal_err!(
                "Partition index out of range: {} >= {} at {}",
                partition,
                self.partition_ranges.len(),
                snafu::location!()
            )?;
        }

        let df_stream = Box::pin(PartSortStream::new(
            context,
            self,
            self.limit,
            input_stream,
            self.partition_ranges[partition].clone(),
            partition,
        )?) as _;

        Ok(df_stream)
    }
}

impl DisplayAs for PartSortExec {
    fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "PartSortExec: expr={} num_ranges={}",
            self.expression,
            self.partition_ranges.len(),
        )?;
        if let Some(limit) = self.limit {
            write!(f, " limit={}", limit)?;
        }
        Ok(())
    }
}

impl ExecutionPlan for PartSortExec {
    fn name(&self) -> &str {
        "PartSortExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        self.input.schema()
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
        let new_input = if let Some(first) = children.first() {
            first
        } else {
            internal_err!("No children found")?
        };
        Ok(Arc::new(Self::new(
            self.expression.clone(),
            self.limit,
            self.partition_ranges.clone(),
            new_input.clone(),
        )))
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> datafusion_common::Result<DfSendableRecordBatchStream> {
        self.to_stream(context, partition)
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    /// # Explain
    ///
    /// This plan needs to be executed on each partition independently,
    /// and is expected to run directly on the storage engine's output
    /// distribution / partitioning.
    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        vec![false]
    }
}

enum PartSortBuffer {
    All(Vec<DfRecordBatch>),
    /// TopK buffer with a row count.
    ///
    /// Since the heap only keeps k elements, the row count in this buffer
    /// is not accurate and is only used for emptiness checks.
    Top(TopK, usize),
}

impl PartSortBuffer {
    pub fn is_empty(&self) -> bool {
        match self {
            PartSortBuffer::All(v) => v.is_empty(),
            PartSortBuffer::Top(_, cnt) => *cnt == 0,
        }
    }
}

struct PartSortStream {
    /// Memory pool for this stream
    reservation: MemoryReservation,
    buffer: PartSortBuffer,
    expression: PhysicalSortExpr,
    limit: Option<usize>,
    produced: usize,
    input: DfSendableRecordBatchStream,
    input_complete: bool,
    schema: SchemaRef,
    partition_ranges: Vec<PartitionRange>,
    #[allow(dead_code)] // this is used under #[debug_assertions]
    partition: usize,
    cur_part_idx: usize,
    evaluating_batch: Option<DfRecordBatch>,
    metrics: BaselineMetrics,
    context: Arc<TaskContext>,
    root_metrics: ExecutionPlanMetricsSet,
}

impl PartSortStream {
    fn new(
        context: Arc<TaskContext>,
        sort: &PartSortExec,
        limit: Option<usize>,
        input: DfSendableRecordBatchStream,
        partition_ranges: Vec<PartitionRange>,
        partition: usize,
    ) -> datafusion_common::Result<Self> {
        let buffer = if let Some(limit) = limit {
            PartSortBuffer::Top(
                TopK::try_new(
                    partition,
                    sort.schema().clone(),
                    vec![],
                    [sort.expression.clone()].into(),
                    limit,
                    context.session_config().batch_size(),
                    context.runtime_env(),
                    &sort.metrics,
                    None,
                )?,
                0,
            )
        } else {
            PartSortBuffer::All(Vec::new())
        };

        Ok(Self {
            reservation: MemoryConsumer::new("PartSortStream".to_string())
                .register(&context.runtime_env().memory_pool),
            buffer,
            expression: sort.expression.clone(),
            limit,
            produced: 0,
            input,
            input_complete: false,
            schema: sort.input.schema(),
            partition_ranges,
            partition,
            cur_part_idx: 0,
            evaluating_batch: None,
            metrics: BaselineMetrics::new(&sort.metrics, partition),
            context,
            root_metrics: sort.metrics.clone(),
        })
    }
}

macro_rules! array_check_helper {
    ($t:ty, $unit:expr, $arr:expr, $cur_range:expr, $min_max_idx:expr) => {{
        if $cur_range.start.unit().as_arrow_time_unit() != $unit
            || $cur_range.end.unit().as_arrow_time_unit() != $unit
        {
            internal_err!(
                "PartitionRange unit mismatch, expect {:?}, found {:?}",
                $cur_range.start.unit(),
                $unit
            )?;
        }
        let arr = $arr
            .as_any()
            .downcast_ref::<arrow::array::PrimitiveArray<$t>>()
            .unwrap();

        let min = arr.value($min_max_idx.0);
        let max = arr.value($min_max_idx.1);
        let (min, max) = if min < max {
            (min, max)
        } else {
            (max, min)
        };
        let cur_min = $cur_range.start.value();
        let cur_max = $cur_range.end.value();
        // note that PartitionRange is left inclusive and right exclusive
        if !(min >= cur_min && max < cur_max) {
            internal_err!(
                "Sort column min/max value out of partition range: sort_column.min_max=[{:?}, {:?}] not in PartitionRange=[{:?}, {:?}]",
                min,
                max,
                cur_min,
                cur_max
            )?;
        }
    }};
}

impl PartSortStream {
    /// Check whether the sort column's min/max values are within the partition range.
    fn check_in_range(
        &self,
        sort_column: &ArrayRef,
        min_max_idx: (usize, usize),
    ) -> datafusion_common::Result<()> {
        if self.cur_part_idx >= self.partition_ranges.len() {
            internal_err!(
                "Partition index out of range: {} >= {} at {}",
                self.cur_part_idx,
                self.partition_ranges.len(),
                snafu::location!()
            )?;
        }
        let cur_range = self.partition_ranges[self.cur_part_idx];

        downcast_ts_array!(
            sort_column.data_type() => (array_check_helper, sort_column, cur_range, min_max_idx),
            _ => internal_err!(
                "Unsupported data type for sort column: {:?}",
                sort_column.data_type()
            )?,
        );

        Ok(())
    }

    /// Try to find data whose value exceeds the current partition range.
    ///
    /// Returns `None` if no such data is found, and `Some(idx)` where `idx` points to
    /// the first row that falls outside the current partition range.
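    ///
    /// A hedged, illustrative example (values are hypothetical): with the current
    /// `PartitionRange` covering `[0, 10)` and a sort column of `[3, 5, 12, 1]`, this
    /// returns `Ok(Some(2))` because `12` is the first value outside the range; for
    /// `[3, 5, 7]` it returns `Ok(None)`.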
    fn try_find_next_range(
        &self,
        sort_column: &ArrayRef,
    ) -> datafusion_common::Result<Option<usize>> {
        if sort_column.is_empty() {
            return Ok(Some(0));
        }

        // check if the current partition index is out of range
        if self.cur_part_idx >= self.partition_ranges.len() {
            internal_err!(
                "Partition index out of range: {} >= {} at {}",
                self.cur_part_idx,
                self.partition_ranges.len(),
                snafu::location!()
            )?;
        }
        let cur_range = self.partition_ranges[self.cur_part_idx];

        let sort_column_iter = downcast_ts_array!(
            sort_column.data_type() => (array_iter_helper, sort_column),
            _ => internal_err!(
                "Unsupported data type for sort column: {:?}",
                sort_column.data_type()
            )?,
        );

        for (idx, val) in sort_column_iter {
            // ignore vacant time index data
            if let Some(val) = val {
                if val >= cur_range.end.value() || val < cur_range.start.value() {
                    return Ok(Some(idx));
                }
            }
        }

        Ok(None)
    }

    fn push_buffer(&mut self, batch: DfRecordBatch) -> datafusion_common::Result<()> {
        match &mut self.buffer {
            PartSortBuffer::All(v) => v.push(batch),
            PartSortBuffer::Top(top, cnt) => {
                *cnt += batch.num_rows();
                top.insert_batch(batch)?;
            }
        }

        Ok(())
    }

    /// Sort and clear the buffer, returning the sorted record batch.
    ///
    /// This function returns an empty record batch if the buffer is empty.
    fn sort_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
        match &mut self.buffer {
            PartSortBuffer::All(_) => self.sort_all_buffer(),
            PartSortBuffer::Top(_, _) => self.sort_top_buffer(),
        }
    }

    /// Internal method for sorting `All` buffer (without limit).
    fn sort_all_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
        let PartSortBuffer::All(buffer) =
            std::mem::replace(&mut self.buffer, PartSortBuffer::All(Vec::new()))
        else {
            unreachable!("buffer type is checked before and should be All variant")
        };

        if buffer.is_empty() {
            return Ok(DfRecordBatch::new_empty(self.schema.clone()));
        }
        let mut sort_columns = Vec::with_capacity(buffer.len());
        let mut opt = None;
        for batch in buffer.iter() {
            let sort_column = self.expression.evaluate_to_sort_column(batch)?;
            opt = opt.or(sort_column.options);
            sort_columns.push(sort_column.values);
        }

        let sort_column =
            concat(&sort_columns.iter().map(|a| a.as_ref()).collect_vec()).map_err(|e| {
                DataFusionError::ArrowError(
                    Box::new(e),
                    Some(format!("Fail to concat sort columns at {}", location!())),
                )
            })?;

        let indices = sort_to_indices(&sort_column, opt, self.limit).map_err(|e| {
            DataFusionError::ArrowError(
                Box::new(e),
                Some(format!("Fail to sort to indices at {}", location!())),
            )
        })?;
        if indices.is_empty() {
            return Ok(DfRecordBatch::new_empty(self.schema.clone()));
        }

        self.check_in_range(
            &sort_column,
            (
                indices.value(0) as usize,
                indices.value(indices.len() - 1) as usize,
            ),
        )
        .inspect_err(|_e| {
            #[cfg(debug_assertions)]
            common_telemetry::error!(
                "Fail to check sort column in range at {}, current_idx: {}, num_rows: {}, err: {}",
                self.partition,
                self.cur_part_idx,
                sort_column.len(),
                _e
            );
        })?;

        // reserve memory for the concat input and sorted output
        let total_mem: usize = buffer.iter().map(|r| r.get_array_memory_size()).sum();
        self.reservation.try_grow(total_mem * 2)?;

        let full_input = concat_batches(&self.schema, &buffer).map_err(|e| {
            DataFusionError::ArrowError(
                Box::new(e),
                Some(format!(
                    "Fail to concat input batches when sorting at {}",
                    location!()
                )),
            )
        })?;

        let sorted = take_record_batch(&full_input, &indices).map_err(|e| {
            DataFusionError::ArrowError(
                Box::new(e),
                Some(format!(
                    "Fail to take result record batch when sorting at {}",
                    location!()
                )),
            )
        })?;

        self.produced += sorted.num_rows();
        drop(full_input);
        // here remove both buffer and full_input memory
        self.reservation.shrink(2 * total_mem);
        Ok(sorted)
    }

    /// Internal method for sorting `Top` buffer (with limit).
    fn sort_top_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
        let new_top_buffer = TopK::try_new(
            self.partition,
            self.schema().clone(),
            vec![],
            [self.expression.clone()].into(),
            self.limit.unwrap(),
            self.context.session_config().batch_size(),
            self.context.runtime_env(),
            &self.root_metrics,
            None,
        )?;
        let PartSortBuffer::Top(top_k, _) =
            std::mem::replace(&mut self.buffer, PartSortBuffer::Top(new_top_buffer, 0))
        else {
            unreachable!("buffer type is checked before and should be Top variant")
        };

        let mut result_stream = top_k.emit()?;
        let mut placeholder_ctx = std::task::Context::from_waker(futures::task::noop_waker_ref());
        let mut results = vec![];
        // according to the current implementation of `TopK`, the result stream will always be ready
        loop {
            match result_stream.poll_next_unpin(&mut placeholder_ctx) {
                Poll::Ready(Some(batch)) => {
                    let batch = batch?;
                    results.push(batch);
                }
                Poll::Pending => {
                    #[cfg(debug_assertions)]
                    unreachable!("TopK result stream should always be ready")
                }
                Poll::Ready(None) => {
                    break;
                }
            }
        }

        let concat_batch = concat_batches(&self.schema, &results).map_err(|e| {
            DataFusionError::ArrowError(
                Box::new(e),
                Some(format!(
                    "Fail to concat top k result record batch when sorting at {}",
                    location!()
                )),
            )
        })?;

        Ok(concat_batch)
    }

    /// Try to split the input batch if it contains data that exceeds the current partition range.
    ///
    /// When the input batch contains such data, this function splits it into two parts: the first
    /// part, which is within the current partition range, is merged and sorted with the previous
    /// buffer, while the second part is registered to `evaluating_batch` for the next poll.
    ///
    /// Returns `None` if the input batch is empty or fully within the current partition range, and
    /// `Some(sorted_batch)` holding the sorted data of the completed partition otherwise (or `None`
    /// if that sorted batch is empty).
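    ///
    /// A hedged, illustrative example (values are hypothetical): with the current
    /// `PartitionRange` covering `[0, 10)` and an incoming batch whose sort column is
    /// `[4, 6, 11, 12]`, rows `[4, 6]` are pushed into the buffer, the buffer is sorted
    /// and returned, `cur_part_idx` advances, and `[11, 12]` is either buffered for the
    /// next range or kept in `evaluating_batch` for the next poll.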
    fn split_batch(
        &mut self,
        batch: DfRecordBatch,
    ) -> datafusion_common::Result<Option<DfRecordBatch>> {
        if batch.num_rows() == 0 {
            return Ok(None);
        }

        let sort_column = self
            .expression
            .expr
            .evaluate(&batch)?
            .into_array(batch.num_rows())?;

        let next_range_idx = self.try_find_next_range(&sort_column)?;
        let Some(idx) = next_range_idx else {
            self.push_buffer(batch)?;
            // keep polling input for next batch
            return Ok(None);
        };

        let this_range = batch.slice(0, idx);
        let remaining_range = batch.slice(idx, batch.num_rows() - idx);
        if this_range.num_rows() != 0 {
            self.push_buffer(this_range)?;
        }
        // mark end of current PartitionRange
        let sorted_batch = self.sort_buffer();
        // step to next proper PartitionRange
        self.cur_part_idx += 1;
        let next_sort_column = sort_column.slice(idx, batch.num_rows() - idx);
        if self.try_find_next_range(&next_sort_column)?.is_some() {
            // remaining batch still contains data that exceeds the current partition range
            // register the remaining batch for next polling
            self.evaluating_batch = Some(remaining_range);
        } else {
            // remaining batch is within the current partition range
            // push to the buffer and continue polling
            if remaining_range.num_rows() != 0 {
                self.push_buffer(remaining_range)?;
            }
        }

        sorted_batch.map(|x| if x.num_rows() == 0 { None } else { Some(x) })
    }

    pub fn poll_next_inner(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<datafusion_common::Result<DfRecordBatch>>> {
        loop {
            // no more input, sort the buffer and return
            if self.input_complete {
                if self.buffer.is_empty() {
                    return Poll::Ready(None);
                } else {
                    return Poll::Ready(Some(self.sort_buffer()));
                }
            }

            // if there is a remaining batch being evaluated from last run,
            // split on it instead of fetching new batch
            if let Some(evaluating_batch) = self.evaluating_batch.take()
                && evaluating_batch.num_rows() != 0
            {
                if let Some(sorted_batch) = self.split_batch(evaluating_batch)? {
                    return Poll::Ready(Some(Ok(sorted_batch)));
                } else {
                    continue;
                }
            }

            // fetch next batch from input
            let res = self.input.as_mut().poll_next(cx);
            match res {
                Poll::Ready(Some(Ok(batch))) => {
                    if let Some(sorted_batch) = self.split_batch(batch)? {
                        return Poll::Ready(Some(Ok(sorted_batch)));
                    } else {
                        continue;
                    }
                }
                // input stream end, mark and continue
                Poll::Ready(None) => {
                    self.input_complete = true;
                    continue;
                }
                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}

impl Stream for PartSortStream {
    type Item = datafusion_common::Result<DfRecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<datafusion_common::Result<DfRecordBatch>>> {
        let result = self.as_mut().poll_next_inner(cx);
        self.metrics.record_poll(result)
    }
}

impl RecordBatchStream for PartSortStream {
    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }
}

#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow::json::ArrayWriter;
    use arrow_schema::{DataType, Field, Schema, SortOptions, TimeUnit};
    use common_time::Timestamp;
    use datafusion_physical_expr::expressions::Column;
    use futures::StreamExt;
    use store_api::region_engine::PartitionRange;

    use super::*;
    use crate::test_util::{new_ts_array, MockInputExec};

    #[tokio::test]
    async fn fuzzy_test() {
        let test_cnt = 100;
        // bound for total count of PartitionRange
        let part_cnt_bound = 100;
        // bound for timestamp range size and offset for each PartitionRange
        let range_size_bound = 100;
        let range_offset_bound = 100;
        // bound for batch count and size within each PartitionRange
        let batch_cnt_bound = 20;
        let batch_size_bound = 100;

        let mut rng = fastrand::Rng::new();
        rng.seed(1337);

        let mut test_cases = Vec::new();

        for case_id in 0..test_cnt {
            let mut bound_val: Option<i64> = None;
            let descending = rng.bool();
            let nulls_first = rng.bool();
            let opt = SortOptions {
                descending,
                nulls_first,
            };
            let limit = if rng.bool() {
                Some(rng.usize(0..batch_cnt_bound * batch_size_bound))
            } else {
                None
            };
            let unit = match rng.u8(0..3) {
                0 => TimeUnit::Second,
                1 => TimeUnit::Millisecond,
                2 => TimeUnit::Microsecond,
                _ => TimeUnit::Nanosecond,
            };

            let schema = Schema::new(vec![Field::new(
                "ts",
                DataType::Timestamp(unit, None),
                false,
            )]);
            let schema = Arc::new(schema);

            let mut input_ranged_data = vec![];
            let mut output_ranges = vec![];
            let mut output_data = vec![];
            // generate each input `PartitionRange`
            for part_id in 0..rng.usize(0..part_cnt_bound) {
                // generate each `PartitionRange`'s timestamp range
                let (start, end) = if descending {
                    let end = bound_val
                        .map(|i| {
                            i.checked_sub(rng.i64(0..range_offset_bound)).expect(
                                "Bad luck, the fuzzy test generated data that would overflow; change the seed and try again",
                            )
                        })
                        .unwrap_or_else(|| rng.i64(-100000000..100000000));
                    bound_val = Some(end);
                    let start = end - rng.i64(1..range_size_bound);
                    let start = Timestamp::new(start, unit.into());
                    let end = Timestamp::new(end, unit.into());
                    (start, end)
                } else {
                    let start = bound_val
                        .map(|i| i + rng.i64(0..range_offset_bound))
                        .unwrap_or_else(|| rng.i64(..));
                    bound_val = Some(start);
                    let end = start + rng.i64(1..range_size_bound);
                    let start = Timestamp::new(start, unit.into());
                    let end = Timestamp::new(end, unit.into());
                    (start, end)
                };
                assert!(start < end);

                let mut per_part_sort_data = vec![];
                let mut batches = vec![];
                for _batch_idx in 0..rng.usize(1..batch_cnt_bound) {
                    let cnt = rng.usize(0..batch_size_bound) + 1;
                    let iter = 0..rng.usize(0..cnt);
                    let mut data_gen = iter
                        .map(|_| rng.i64(start.value()..end.value()))
                        .collect_vec();
                    if data_gen.is_empty() {
                        // current batch is empty, skip
                        continue;
                    }
                    // mito always sorts in ASC order
                    data_gen.sort();
                    per_part_sort_data.extend(data_gen.clone());
                    let arr = new_ts_array(unit, data_gen.clone());
                    let batch = DfRecordBatch::try_new(schema.clone(), vec![arr]).unwrap();
                    batches.push(batch);
                }

                let range = PartitionRange {
                    start,
                    end,
                    num_rows: batches.iter().map(|b| b.num_rows()).sum(),
                    identifier: part_id,
                };
                input_ranged_data.push((range, batches));

                output_ranges.push(range);
                if per_part_sort_data.is_empty() {
                    continue;
                }
                output_data.extend_from_slice(&per_part_sort_data);
            }

            // adjust output data with adjacent PartitionRanges
            let mut output_data_iter = output_data.iter().peekable();
            let mut output_data = vec![];
            for range in output_ranges.clone() {
                let mut cur_data = vec![];
                while let Some(val) = output_data_iter.peek() {
                    if **val < range.start.value() || **val >= range.end.value() {
                        break;
                    }
                    cur_data.push(*output_data_iter.next().unwrap());
                }

                if cur_data.is_empty() {
                    continue;
                }

                if descending {
                    cur_data.sort_by(|a, b| b.cmp(a));
                } else {
                    cur_data.sort();
                }
                output_data.push(cur_data);
            }

            let expected_output = output_data
                .into_iter()
                .map(|a| {
                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, a)]).unwrap()
                })
                .map(|rb| {
                    // trim expected output with limit
                    if let Some(limit) = limit
                        && rb.num_rows() > limit
                    {
                        rb.slice(0, limit)
                    } else {
                        rb
                    }
                })
                .collect_vec();

            test_cases.push((
                case_id,
                unit,
                input_ranged_data,
                schema,
                opt,
                limit,
                expected_output,
            ));
        }

        for (case_id, _unit, input_ranged_data, schema, opt, limit, expected_output) in test_cases {
            run_test(
                case_id,
                input_ranged_data,
                schema,
                opt,
                limit,
                expected_output,
            )
            .await;
        }
    }

    #[tokio::test]
    async fn simple_case() {
        let testcases = vec![
            (
                TimeUnit::Millisecond,
                vec![
                    ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]]),
                    ((5, 10), vec![vec![5, 6], vec![7, 8]]),
                ],
                false,
                None,
                vec![vec![1, 2, 3, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9]],
            ),
            (
                TimeUnit::Millisecond,
                vec![
                    ((5, 10), vec![vec![5, 6], vec![7, 8, 9]]),
                    ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
                ],
                true,
                None,
                vec![vec![9, 8, 7, 6, 5], vec![8, 7, 6, 5, 4, 3, 2, 1]],
            ),
            (
                TimeUnit::Millisecond,
                vec![
                    ((5, 10), vec![]),
                    ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
                ],
                true,
                None,
                vec![vec![8, 7, 6, 5, 4, 3, 2, 1]],
            ),
            (
                TimeUnit::Millisecond,
                vec![
                    ((15, 20), vec![vec![17, 18, 19]]),
                    ((10, 15), vec![]),
                    ((5, 10), vec![]),
                    ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
                ],
                true,
                None,
                vec![vec![19, 18, 17], vec![8, 7, 6, 5, 4, 3, 2, 1]],
            ),
            (
                TimeUnit::Millisecond,
                vec![
                    ((15, 20), vec![]),
                    ((10, 15), vec![]),
                    ((5, 10), vec![]),
                    ((0, 10), vec![]),
                ],
                true,
                None,
                vec![],
            ),
            (
                TimeUnit::Millisecond,
                vec![
                    (
                        (15, 20),
                        vec![vec![15, 17, 19, 10, 11, 12, 5, 6, 7, 8, 9, 1, 2, 3, 4]],
                    ),
                    ((10, 15), vec![]),
                    ((5, 10), vec![]),
                    ((0, 10), vec![]),
                ],
                true,
                None,
                vec![
                    vec![19, 17, 15],
                    vec![12, 11, 10],
                    vec![9, 8, 7, 6, 5],
                    vec![4, 3, 2, 1],
                ],
            ),
            (
                TimeUnit::Millisecond,
                vec![
                    (
                        (15, 20),
                        vec![vec![15, 17, 19, 10, 11, 12, 5, 6, 7, 8, 9, 1, 2, 3, 4]],
                    ),
                    ((10, 15), vec![]),
                    ((5, 10), vec![]),
                    ((0, 10), vec![]),
                ],
                true,
                Some(2),
                vec![vec![19, 17], vec![12, 11], vec![9, 8], vec![4, 3]],
            ),
        ];

        for (identifier, (unit, input_ranged_data, descending, limit, expected_output)) in
            testcases.into_iter().enumerate()
        {
            let schema = Schema::new(vec![Field::new(
                "ts",
                DataType::Timestamp(unit, None),
                false,
            )]);
            let schema = Arc::new(schema);
            let opt = SortOptions {
                descending,
                ..Default::default()
            };

            let input_ranged_data = input_ranged_data
                .into_iter()
                .map(|(range, data)| {
                    let part = PartitionRange {
                        start: Timestamp::new(range.0, unit.into()),
                        end: Timestamp::new(range.1, unit.into()),
                        num_rows: data.iter().map(|b| b.len()).sum(),
                        identifier,
                    };

                    let batches = data
                        .into_iter()
                        .map(|b| {
                            let arr = new_ts_array(unit, b);
                            DfRecordBatch::try_new(schema.clone(), vec![arr]).unwrap()
                        })
                        .collect_vec();
                    (part, batches)
                })
                .collect_vec();

            let expected_output = expected_output
                .into_iter()
                .map(|a| {
                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, a)]).unwrap()
                })
                .collect_vec();

            run_test(
                identifier,
                input_ranged_data,
                schema.clone(),
                opt,
                limit,
                expected_output,
            )
            .await;
        }
    }

    #[allow(clippy::print_stdout)]
    async fn run_test(
        case_id: usize,
        input_ranged_data: Vec<(PartitionRange, Vec<DfRecordBatch>)>,
        schema: SchemaRef,
        opt: SortOptions,
        limit: Option<usize>,
        expected_output: Vec<DfRecordBatch>,
    ) {
        for rb in &expected_output {
            if let Some(limit) = limit {
                assert!(
                    rb.num_rows() <= limit,
                    "Expect row count in expected output's batch({}) <= limit({})",
                    rb.num_rows(),
                    limit
                );
            }
        }
        let (ranges, batches): (Vec<_>, Vec<_>) = input_ranged_data.clone().into_iter().unzip();

        let batches = batches
            .into_iter()
            .flat_map(|mut cols| {
                cols.push(DfRecordBatch::new_empty(schema.clone()));
                cols
            })
            .collect_vec();
        let mock_input = MockInputExec::new(batches, schema.clone());

        let exec = PartSortExec::new(
            PhysicalSortExpr {
                expr: Arc::new(Column::new("ts", 0)),
                options: opt,
            },
            limit,
            vec![ranges.clone()],
            Arc::new(mock_input),
        );

        let exec_stream = exec.execute(0, Arc::new(TaskContext::default())).unwrap();

        let real_output = exec_stream.map(|r| r.unwrap()).collect::<Vec<_>>().await;
        // a makeshift solution for comparing large data
        if real_output != expected_output {
            let mut first_diff = 0;
            for (idx, (lhs, rhs)) in real_output.iter().zip(expected_output.iter()).enumerate() {
                if lhs != rhs {
                    first_diff = idx;
                    break;
                }
            }
            println!("first diff batch at {}", first_diff);
            println!(
                "ranges: {:?}",
                ranges
                    .into_iter()
                    .map(|r| (r.start.to_chrono_datetime(), r.end.to_chrono_datetime()))
                    .enumerate()
                    .collect::<Vec<_>>()
            );

            let mut full_msg = String::new();
            {
                let mut buf = Vec::with_capacity(10 * real_output.len());
                for batch in real_output.iter().skip(first_diff) {
                    let mut rb_json: Vec<u8> = Vec::new();
                    let mut writer = ArrayWriter::new(&mut rb_json);
                    writer.write(batch).unwrap();
                    writer.finish().unwrap();
                    buf.append(&mut rb_json);
                    buf.push(b',');
                }
                // TODO(discord9): better ways to print buf
                let buf = String::from_utf8_lossy(&buf);
                full_msg += &format!("\ncase_id:{case_id}, real_output \n{buf}\n");
            }
            {
                let mut buf = Vec::with_capacity(10 * real_output.len());
                for batch in expected_output.iter().skip(first_diff) {
                    let mut rb_json: Vec<u8> = Vec::new();
                    let mut writer = ArrayWriter::new(&mut rb_json);
                    writer.write(batch).unwrap();
                    writer.finish().unwrap();
                    buf.append(&mut rb_json);
                    buf.push(b',');
                }
                let buf = String::from_utf8_lossy(&buf);
                full_msg += &format!("case_id:{case_id}, expected_output \n{buf}");
            }
            panic!(
                "case_{} failed, opt: {:?},\n real output has {} batches, {} rows, expected has {} batches with {} rows\nfull msg: {}",
                case_id, opt,
                real_output.len(),
                real_output.iter().map(|x|x.num_rows()).sum::<usize>(),
                expected_output.len(),
                expected_output.iter().map(|x|x.num_rows()).sum::<usize>(), full_msg
            );
        }
    }
}