query/
part_sort.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Module for sorting input data within each [`PartitionRange`].
16//!
17//! This module defines the [`PartSortExec`] execution plan, which sorts each
18//! partition ([`PartitionRange`]) independently based on the provided physical
19//! sort expressions.
20
21use std::any::Any;
22use std::pin::Pin;
23use std::sync::Arc;
24use std::task::{Context, Poll};
25
26use arrow::array::ArrayRef;
27use arrow::compute::{concat, concat_batches, take_record_batch};
28use arrow_schema::SchemaRef;
29use common_recordbatch::{DfRecordBatch, DfSendableRecordBatchStream};
30use datafusion::common::arrow::compute::sort_to_indices;
31use datafusion::execution::memory_pool::{MemoryConsumer, MemoryReservation};
32use datafusion::execution::{RecordBatchStream, TaskContext};
33use datafusion::physical_plan::execution_plan::CardinalityEffect;
34use datafusion::physical_plan::filter_pushdown::{
35    ChildFilterDescription, FilterDescription, FilterPushdownPhase,
36};
37use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
38use datafusion::physical_plan::{
39    DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, TopK,
40    TopKDynamicFilters,
41};
42use datafusion_common::{DataFusionError, internal_err};
43use datafusion_physical_expr::expressions::{DynamicFilterPhysicalExpr, lit};
44use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
45use futures::{Stream, StreamExt};
46use itertools::Itertools;
47use parking_lot::RwLock;
48use snafu::location;
49use store_api::region_engine::PartitionRange;
50
51use crate::{array_iter_helper, downcast_ts_array};
52
/// Sort input within given [`PartitionRange`].
///
/// Input is assumed to be segmented by empty `RecordBatch`es, each of which indicates that a
/// new `PartitionRange` is starting, and this operator sorts the data of each
/// `PartitionRange` independently of the others.
#[derive(Debug, Clone)]
pub struct PartSortExec {
    /// Physical sort expressions(that is, sort by timestamp)
    expression: PhysicalSortExpr,
    /// Optional row limit. When set, `filter` is created as well and each stream
    /// buffers its rows in a `TopK` instead of keeping every batch.
    limit: Option<usize>,
    /// The single child plan whose output is sorted.
    input: Arc<dyn ExecutionPlan>,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Partition ranges indexed by partition number: `partition_ranges[p]` is handed
    /// to the stream created for partition `p`.
    partition_ranges: Vec<Vec<PartitionRange>>,
    /// Plan properties, derived from the input's properties in `new`.
    properties: PlanProperties,
    /// Filter matching the state of the sort for dynamic filter pushdown.
    /// If `limit` is `Some`, this will also be set and a TopK operator may be used.
    /// If `limit` is `None`, this will be `None`.
    filter: Option<Arc<RwLock<TopKDynamicFilters>>>,
}
73
74impl PartSortExec {
75    pub fn new(
76        expression: PhysicalSortExpr,
77        limit: Option<usize>,
78        partition_ranges: Vec<Vec<PartitionRange>>,
79        input: Arc<dyn ExecutionPlan>,
80    ) -> Self {
81        let metrics = ExecutionPlanMetricsSet::new();
82        let properties = input.properties();
83        let properties = PlanProperties::new(
84            input.equivalence_properties().clone(),
85            input.output_partitioning().clone(),
86            properties.emission_type,
87            properties.boundedness,
88        );
89
90        let filter = limit
91            .is_some()
92            .then(|| Self::create_filter(expression.expr.clone()));
93
94        Self {
95            expression,
96            limit,
97            input,
98            metrics,
99            partition_ranges,
100            properties,
101            filter,
102        }
103    }
104
105    /// Add or reset `self.filter` to a new `TopKDynamicFilters`.
106    fn create_filter(expr: Arc<dyn PhysicalExpr>) -> Arc<RwLock<TopKDynamicFilters>> {
107        Arc::new(RwLock::new(TopKDynamicFilters::new(Arc::new(
108            DynamicFilterPhysicalExpr::new(vec![expr], lit(true)),
109        ))))
110    }
111
112    pub fn to_stream(
113        &self,
114        context: Arc<TaskContext>,
115        partition: usize,
116    ) -> datafusion_common::Result<DfSendableRecordBatchStream> {
117        let input_stream: DfSendableRecordBatchStream =
118            self.input.execute(partition, context.clone())?;
119
120        if partition >= self.partition_ranges.len() {
121            internal_err!(
122                "Partition index out of range: {} >= {} at {}",
123                partition,
124                self.partition_ranges.len(),
125                snafu::location!()
126            )?;
127        }
128
129        let df_stream = Box::pin(PartSortStream::new(
130            context,
131            self,
132            self.limit,
133            input_stream,
134            self.partition_ranges[partition].clone(),
135            partition,
136            self.filter.clone(),
137        )?) as _;
138
139        Ok(df_stream)
140    }
141}
142
143impl DisplayAs for PartSortExec {
144    fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145        write!(
146            f,
147            "PartSortExec: expr={} num_ranges={}",
148            self.expression,
149            self.partition_ranges.len(),
150        )?;
151        if let Some(limit) = self.limit {
152            write!(f, " limit={}", limit)?;
153        }
154        Ok(())
155    }
156}
157
158impl ExecutionPlan for PartSortExec {
159    fn name(&self) -> &str {
160        "PartSortExec"
161    }
162
163    fn as_any(&self) -> &dyn Any {
164        self
165    }
166
167    fn schema(&self) -> SchemaRef {
168        self.input.schema()
169    }
170
171    fn properties(&self) -> &PlanProperties {
172        &self.properties
173    }
174
175    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
176        vec![&self.input]
177    }
178
179    fn with_new_children(
180        self: Arc<Self>,
181        children: Vec<Arc<dyn ExecutionPlan>>,
182    ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
183        let new_input = if let Some(first) = children.first() {
184            first
185        } else {
186            internal_err!("No children found")?
187        };
188        Ok(Arc::new(Self::new(
189            self.expression.clone(),
190            self.limit,
191            self.partition_ranges.clone(),
192            new_input.clone(),
193        )))
194    }
195
196    fn execute(
197        &self,
198        partition: usize,
199        context: Arc<TaskContext>,
200    ) -> datafusion_common::Result<DfSendableRecordBatchStream> {
201        self.to_stream(context, partition)
202    }
203
204    fn metrics(&self) -> Option<MetricsSet> {
205        Some(self.metrics.clone_inner())
206    }
207
208    /// # Explain
209    ///
210    /// This plan needs to be executed on each partition independently,
211    /// and is expected to run directly on storage engine's output
212    /// distribution / partition.
213    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
214        vec![false]
215    }
216
217    fn cardinality_effect(&self) -> CardinalityEffect {
218        if self.limit.is_none() {
219            CardinalityEffect::Equal
220        } else {
221            CardinalityEffect::LowerEqual
222        }
223    }
224
225    fn gather_filters_for_pushdown(
226        &self,
227        phase: FilterPushdownPhase,
228        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
229        _config: &datafusion::config::ConfigOptions,
230    ) -> datafusion_common::Result<FilterDescription> {
231        if !matches!(phase, FilterPushdownPhase::Post) {
232            return FilterDescription::from_children(parent_filters, &self.children());
233        }
234
235        let mut child = ChildFilterDescription::from_child(&parent_filters, &self.input)?;
236
237        if let Some(filter) = &self.filter {
238            child = child.with_self_filter(filter.read().expr());
239        }
240
241        Ok(FilterDescription::new().with_child(child))
242    }
243
244    fn reset_state(self: Arc<Self>) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
245        // shared dynamic filter needs to be reset
246        let new_filter = self
247            .limit
248            .is_some()
249            .then(|| Self::create_filter(self.expression.expr.clone()));
250
251        Ok(Arc::new(Self {
252            expression: self.expression.clone(),
253            limit: self.limit,
254            input: self.input.clone(),
255            metrics: self.metrics.clone(),
256            partition_ranges: self.partition_ranges.clone(),
257            properties: self.properties.clone(),
258            filter: new_filter,
259        }))
260    }
261}
262
/// Buffer holding the rows collected so far for the partition range being sorted.
enum PartSortBuffer {
    /// Keeps every pushed batch; used when no limit is set.
    All(Vec<DfRecordBatch>),
    /// TopK buffer with row count.
    ///
    /// Since the heap only keeps k elements, the row count does not reflect how many
    /// rows are actually retained; it is only used for the emptiness check.
    Top(TopK, usize),
}
271
272impl PartSortBuffer {
273    pub fn is_empty(&self) -> bool {
274        match self {
275            PartSortBuffer::All(v) => v.is_empty(),
276            PartSortBuffer::Top(_, cnt) => *cnt == 0,
277        }
278    }
279}
280
/// Stream that sorts the input of one partition, one `PartitionRange` at a time.
struct PartSortStream {
    /// Memory pool for this stream
    reservation: MemoryReservation,
    /// Rows collected for the partition range currently being processed.
    buffer: PartSortBuffer,
    /// Sort expression evaluated on every input batch.
    expression: PhysicalSortExpr,
    /// Optional row limit; `Some` whenever `buffer` is the `Top` variant.
    limit: Option<usize>,
    /// Number of rows emitted so far by the no-limit sort path.
    produced: usize,
    /// The input stream being sorted.
    input: DfSendableRecordBatchStream,
    /// Set once the input stream has returned `None`.
    input_complete: bool,
    /// Schema of the input, which is also the schema of the output.
    schema: SchemaRef,
    /// Ranges of this partition, visited in order via `cur_part_idx`.
    partition_ranges: Vec<PartitionRange>,
    /// Partition number this stream was created for.
    #[allow(dead_code)] // this is used under #[debug_assertions]
    partition: usize,
    /// Index into `partition_ranges` of the range currently being filled.
    cur_part_idx: usize,
    /// Remainder of a split input batch that still has to be processed before
    /// polling the input again.
    evaluating_batch: Option<DfRecordBatch>,
    /// Baseline metrics updated on every poll.
    metrics: BaselineMetrics,
    /// Task context, kept to build replacement `TopK` buffers.
    context: Arc<TaskContext>,
    /// Metrics set of the owning plan, kept to build replacement `TopK` buffers.
    root_metrics: ExecutionPlanMetricsSet,
}
300
301impl PartSortStream {
302    fn new(
303        context: Arc<TaskContext>,
304        sort: &PartSortExec,
305        limit: Option<usize>,
306        input: DfSendableRecordBatchStream,
307        partition_ranges: Vec<PartitionRange>,
308        partition: usize,
309        filter: Option<Arc<RwLock<TopKDynamicFilters>>>,
310    ) -> datafusion_common::Result<Self> {
311        let buffer = if let Some(limit) = limit {
312            let Some(filter) = filter else {
313                return internal_err!(
314                    "TopKDynamicFilters must be provided when limit is set at {}",
315                    snafu::location!()
316                );
317            };
318
319            PartSortBuffer::Top(
320                TopK::try_new(
321                    partition,
322                    sort.schema().clone(),
323                    vec![],
324                    [sort.expression.clone()].into(),
325                    limit,
326                    context.session_config().batch_size(),
327                    context.runtime_env(),
328                    &sort.metrics,
329                    filter,
330                )?,
331                0,
332            )
333        } else {
334            PartSortBuffer::All(Vec::new())
335        };
336
337        Ok(Self {
338            reservation: MemoryConsumer::new("PartSortStream".to_string())
339                .register(&context.runtime_env().memory_pool),
340            buffer,
341            expression: sort.expression.clone(),
342            limit,
343            produced: 0,
344            input,
345            input_complete: false,
346            schema: sort.input.schema(),
347            partition_ranges,
348            partition,
349            cur_part_idx: 0,
350            evaluating_batch: None,
351            metrics: BaselineMetrics::new(&sort.metrics, partition),
352            context,
353            root_metrics: sort.metrics.clone(),
354        })
355    }
356}
357
/// Checks that two values of a timestamp array lie inside a partition range.
///
/// Arguments: the arrow primitive type `$t`, its time unit `$unit`, the array `$arr`,
/// the partition range `$cur_range` and a pair of row indices `$min_max_idx`.
/// On a unit mismatch or an out-of-range value an internal error is raised with `?`,
/// so the macro must be expanded inside a function returning a DataFusion `Result`.
macro_rules! array_check_helper {
    ($t:ty, $unit:expr, $arr:expr, $cur_range:expr, $min_max_idx:expr) => {{
        let start_unit = $cur_range.start.unit().as_arrow_time_unit();
        let end_unit = $cur_range.end.unit().as_arrow_time_unit();
        if start_unit != $unit || end_unit != $unit {
            internal_err!(
                "PartitionRange unit mismatch, expect {:?}, found {:?}",
                $cur_range.start.unit(),
                $unit
            )?;
        }
        let typed = $arr
            .as_any()
            .downcast_ref::<arrow::array::PrimitiveArray<$t>>()
            .unwrap();

        // the two indices may come in either order, so normalize them to (min, max)
        let first = typed.value($min_max_idx.0);
        let second = typed.value($min_max_idx.1);
        let (min, max) = if first < second {
            (first, second)
        } else {
            (second, first)
        };
        let range_start = $cur_range.start.value();
        let range_end = $cur_range.end.value();
        // note that PartitionRange is left inclusive and right exclusive
        if min < range_start || max >= range_end {
            internal_err!(
                "Sort column min/max value out of partition range: sort_column.min_max=[{:?}, {:?}] not in PartitionRange=[{:?}, {:?}]",
                min,
                max,
                range_start,
                range_end
            )?;
        }
    }};
}
395
396impl PartSortStream {
397    /// check whether the sort column's min/max value is within the partition range
398    fn check_in_range(
399        &self,
400        sort_column: &ArrayRef,
401        min_max_idx: (usize, usize),
402    ) -> datafusion_common::Result<()> {
403        if self.cur_part_idx >= self.partition_ranges.len() {
404            internal_err!(
405                "Partition index out of range: {} >= {} at {}",
406                self.cur_part_idx,
407                self.partition_ranges.len(),
408                snafu::location!()
409            )?;
410        }
411        let cur_range = self.partition_ranges[self.cur_part_idx];
412
413        downcast_ts_array!(
414            sort_column.data_type() => (array_check_helper, sort_column, cur_range, min_max_idx),
415            _ => internal_err!(
416                "Unsupported data type for sort column: {:?}",
417                sort_column.data_type()
418            )?,
419        );
420
421        Ok(())
422    }
423
424    /// Try find data whose value exceeds the current partition range.
425    ///
426    /// Returns `None` if no such data is found, and `Some(idx)` where idx points to
427    /// the first data that exceeds the current partition range.
428    fn try_find_next_range(
429        &self,
430        sort_column: &ArrayRef,
431    ) -> datafusion_common::Result<Option<usize>> {
432        if sort_column.is_empty() {
433            return Ok(Some(0));
434        }
435
436        // check if the current partition index is out of range
437        if self.cur_part_idx >= self.partition_ranges.len() {
438            internal_err!(
439                "Partition index out of range: {} >= {} at {}",
440                self.cur_part_idx,
441                self.partition_ranges.len(),
442                snafu::location!()
443            )?;
444        }
445        let cur_range = self.partition_ranges[self.cur_part_idx];
446
447        let sort_column_iter = downcast_ts_array!(
448            sort_column.data_type() => (array_iter_helper, sort_column),
449            _ => internal_err!(
450                "Unsupported data type for sort column: {:?}",
451                sort_column.data_type()
452            )?,
453        );
454
455        for (idx, val) in sort_column_iter {
456            // ignore vacant time index data
457            if let Some(val) = val
458                && (val >= cur_range.end.value() || val < cur_range.start.value())
459            {
460                return Ok(Some(idx));
461            }
462        }
463
464        Ok(None)
465    }
466
467    fn push_buffer(&mut self, batch: DfRecordBatch) -> datafusion_common::Result<()> {
468        match &mut self.buffer {
469            PartSortBuffer::All(v) => v.push(batch),
470            PartSortBuffer::Top(top, cnt) => {
471                *cnt += batch.num_rows();
472                top.insert_batch(batch)?;
473            }
474        }
475
476        Ok(())
477    }
478
479    /// Sort and clear the buffer and return the sorted record batch
480    ///
481    /// this function will return a empty record batch if the buffer is empty
482    fn sort_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
483        match &mut self.buffer {
484            PartSortBuffer::All(_) => self.sort_all_buffer(),
485            PartSortBuffer::Top(_, _) => self.sort_top_buffer(),
486        }
487    }
488
489    /// Internal method for sorting `All` buffer (without limit).
490    fn sort_all_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
491        let PartSortBuffer::All(buffer) =
492            std::mem::replace(&mut self.buffer, PartSortBuffer::All(Vec::new()))
493        else {
494            unreachable!("buffer type is checked before and should be All variant")
495        };
496
497        if buffer.is_empty() {
498            return Ok(DfRecordBatch::new_empty(self.schema.clone()));
499        }
500        let mut sort_columns = Vec::with_capacity(buffer.len());
501        let mut opt = None;
502        for batch in buffer.iter() {
503            let sort_column = self.expression.evaluate_to_sort_column(batch)?;
504            opt = opt.or(sort_column.options);
505            sort_columns.push(sort_column.values);
506        }
507
508        let sort_column =
509            concat(&sort_columns.iter().map(|a| a.as_ref()).collect_vec()).map_err(|e| {
510                DataFusionError::ArrowError(
511                    Box::new(e),
512                    Some(format!("Fail to concat sort columns at {}", location!())),
513                )
514            })?;
515
516        let indices = sort_to_indices(&sort_column, opt, self.limit).map_err(|e| {
517            DataFusionError::ArrowError(
518                Box::new(e),
519                Some(format!("Fail to sort to indices at {}", location!())),
520            )
521        })?;
522        if indices.is_empty() {
523            return Ok(DfRecordBatch::new_empty(self.schema.clone()));
524        }
525
526        self.check_in_range(
527            &sort_column,
528            (
529                indices.value(0) as usize,
530                indices.value(indices.len() - 1) as usize,
531            ),
532        )
533        .inspect_err(|_e| {
534            #[cfg(debug_assertions)]
535            common_telemetry::error!(
536                "Fail to check sort column in range at {}, current_idx: {}, num_rows: {}, err: {}",
537                self.partition,
538                self.cur_part_idx,
539                sort_column.len(),
540                _e
541            );
542        })?;
543
544        // reserve memory for the concat input and sorted output
545        let total_mem: usize = buffer.iter().map(|r| r.get_array_memory_size()).sum();
546        self.reservation.try_grow(total_mem * 2)?;
547
548        let full_input = concat_batches(&self.schema, &buffer).map_err(|e| {
549            DataFusionError::ArrowError(
550                Box::new(e),
551                Some(format!(
552                    "Fail to concat input batches when sorting at {}",
553                    location!()
554                )),
555            )
556        })?;
557
558        let sorted = take_record_batch(&full_input, &indices).map_err(|e| {
559            DataFusionError::ArrowError(
560                Box::new(e),
561                Some(format!(
562                    "Fail to take result record batch when sorting at {}",
563                    location!()
564                )),
565            )
566        })?;
567
568        self.produced += sorted.num_rows();
569        drop(full_input);
570        // here remove both buffer and full_input memory
571        self.reservation.shrink(2 * total_mem);
572        Ok(sorted)
573    }
574
575    /// Internal method for sorting `Top` buffer (with limit).
576    fn sort_top_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
577        let filter = Arc::new(RwLock::new(TopKDynamicFilters::new(Arc::new(
578            DynamicFilterPhysicalExpr::new(vec![], lit(true)),
579        ))));
580        let new_top_buffer = TopK::try_new(
581            self.partition,
582            self.schema().clone(),
583            vec![],
584            [self.expression.clone()].into(),
585            self.limit.unwrap(),
586            self.context.session_config().batch_size(),
587            self.context.runtime_env(),
588            &self.root_metrics,
589            filter,
590        )?;
591        let PartSortBuffer::Top(top_k, _) =
592            std::mem::replace(&mut self.buffer, PartSortBuffer::Top(new_top_buffer, 0))
593        else {
594            unreachable!("buffer type is checked before and should be Top variant")
595        };
596
597        let mut result_stream = top_k.emit()?;
598        let mut placeholder_ctx = std::task::Context::from_waker(futures::task::noop_waker_ref());
599        let mut results = vec![];
600        // according to the current implementation of `TopK`, the result stream will always be ready
601        loop {
602            match result_stream.poll_next_unpin(&mut placeholder_ctx) {
603                Poll::Ready(Some(batch)) => {
604                    let batch = batch?;
605                    results.push(batch);
606                }
607                Poll::Pending => {
608                    #[cfg(debug_assertions)]
609                    unreachable!("TopK result stream should always be ready")
610                }
611                Poll::Ready(None) => {
612                    break;
613                }
614            }
615        }
616
617        let concat_batch = concat_batches(&self.schema, &results).map_err(|e| {
618            DataFusionError::ArrowError(
619                Box::new(e),
620                Some(format!(
621                    "Fail to concat top k result record batch when sorting at {}",
622                    location!()
623                )),
624            )
625        })?;
626
627        Ok(concat_batch)
628    }
629
630    /// Try to split the input batch if it contains data that exceeds the current partition range.
631    ///
632    /// When the input batch contains data that exceeds the current partition range, this function
633    /// will split the input batch into two parts, the first part is within the current partition
634    /// range will be merged and sorted with previous buffer, and the second part will be registered
635    /// to `evaluating_batch` for next polling.
636    ///
637    /// Returns `None` if the input batch is empty or fully within the current partition range, and
638    /// `Some(batch)` otherwise.
639    fn split_batch(
640        &mut self,
641        batch: DfRecordBatch,
642    ) -> datafusion_common::Result<Option<DfRecordBatch>> {
643        if batch.num_rows() == 0 {
644            return Ok(None);
645        }
646
647        let sort_column = self
648            .expression
649            .expr
650            .evaluate(&batch)?
651            .into_array(batch.num_rows())?;
652
653        let next_range_idx = self.try_find_next_range(&sort_column)?;
654        let Some(idx) = next_range_idx else {
655            self.push_buffer(batch)?;
656            // keep polling input for next batch
657            return Ok(None);
658        };
659
660        let this_range = batch.slice(0, idx);
661        let remaining_range = batch.slice(idx, batch.num_rows() - idx);
662        if this_range.num_rows() != 0 {
663            self.push_buffer(this_range)?;
664        }
665        // mark end of current PartitionRange
666        let sorted_batch = self.sort_buffer();
667        // step to next proper PartitionRange
668        self.cur_part_idx += 1;
669        let next_sort_column = sort_column.slice(idx, batch.num_rows() - idx);
670        if self.try_find_next_range(&next_sort_column)?.is_some() {
671            // remaining batch still contains data that exceeds the current partition range
672            // register the remaining batch for next polling
673            self.evaluating_batch = Some(remaining_range);
674        } else {
675            // remaining batch is within the current partition range
676            // push to the buffer and continue polling
677            if remaining_range.num_rows() != 0 {
678                self.push_buffer(remaining_range)?;
679            }
680        }
681
682        sorted_batch.map(|x| if x.num_rows() == 0 { None } else { Some(x) })
683    }
684
685    pub fn poll_next_inner(
686        mut self: Pin<&mut Self>,
687        cx: &mut Context<'_>,
688    ) -> Poll<Option<datafusion_common::Result<DfRecordBatch>>> {
689        loop {
690            // no more input, sort the buffer and return
691            if self.input_complete {
692                if self.buffer.is_empty() {
693                    return Poll::Ready(None);
694                } else {
695                    return Poll::Ready(Some(self.sort_buffer()));
696                }
697            }
698
699            // if there is a remaining batch being evaluated from last run,
700            // split on it instead of fetching new batch
701            if let Some(evaluating_batch) = self.evaluating_batch.take()
702                && evaluating_batch.num_rows() != 0
703            {
704                if let Some(sorted_batch) = self.split_batch(evaluating_batch)? {
705                    return Poll::Ready(Some(Ok(sorted_batch)));
706                } else {
707                    continue;
708                }
709            }
710
711            // fetch next batch from input
712            let res = self.input.as_mut().poll_next(cx);
713            match res {
714                Poll::Ready(Some(Ok(batch))) => {
715                    if let Some(sorted_batch) = self.split_batch(batch)? {
716                        return Poll::Ready(Some(Ok(sorted_batch)));
717                    } else {
718                        continue;
719                    }
720                }
721                // input stream end, mark and continue
722                Poll::Ready(None) => {
723                    self.input_complete = true;
724                    continue;
725                }
726                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
727                Poll::Pending => return Poll::Pending,
728            }
729        }
730    }
731}
732
733impl Stream for PartSortStream {
734    type Item = datafusion_common::Result<DfRecordBatch>;
735
736    fn poll_next(
737        mut self: Pin<&mut Self>,
738        cx: &mut Context<'_>,
739    ) -> Poll<Option<datafusion_common::Result<DfRecordBatch>>> {
740        let result = self.as_mut().poll_next_inner(cx);
741        self.metrics.record_poll(result)
742    }
743}
744
745impl RecordBatchStream for PartSortStream {
746    fn schema(&self) -> SchemaRef {
747        self.schema.clone()
748    }
749}
750
751#[cfg(test)]
752mod test {
753    use std::sync::Arc;
754
755    use arrow::json::ArrayWriter;
756    use arrow_schema::{DataType, Field, Schema, SortOptions, TimeUnit};
757    use common_time::Timestamp;
758    use datafusion_physical_expr::expressions::Column;
759    use futures::StreamExt;
760    use store_api::region_engine::PartitionRange;
761
762    use super::*;
763    use crate::test_util::{MockInputExec, new_ts_array};
764
765    #[tokio::test]
766    async fn fuzzy_test() {
767        let test_cnt = 100;
768        // bound for total count of PartitionRange
769        let part_cnt_bound = 100;
770        // bound for timestamp range size and offset for each PartitionRange
771        let range_size_bound = 100;
772        let range_offset_bound = 100;
773        // bound for batch count and size within each PartitionRange
774        let batch_cnt_bound = 20;
775        let batch_size_bound = 100;
776
777        let mut rng = fastrand::Rng::new();
778        rng.seed(1337);
779
780        let mut test_cases = Vec::new();
781
782        for case_id in 0..test_cnt {
783            let mut bound_val: Option<i64> = None;
784            let descending = rng.bool();
785            let nulls_first = rng.bool();
786            let opt = SortOptions {
787                descending,
788                nulls_first,
789            };
790            let limit = if rng.bool() {
791                Some(rng.usize(0..batch_cnt_bound * batch_size_bound))
792            } else {
793                None
794            };
795            let unit = match rng.u8(0..3) {
796                0 => TimeUnit::Second,
797                1 => TimeUnit::Millisecond,
798                2 => TimeUnit::Microsecond,
799                _ => TimeUnit::Nanosecond,
800            };
801
802            let schema = Schema::new(vec![Field::new(
803                "ts",
804                DataType::Timestamp(unit, None),
805                false,
806            )]);
807            let schema = Arc::new(schema);
808
809            let mut input_ranged_data = vec![];
810            let mut output_ranges = vec![];
811            let mut output_data = vec![];
812            // generate each input `PartitionRange`
813            for part_id in 0..rng.usize(0..part_cnt_bound) {
814                // generate each `PartitionRange`'s timestamp range
815                let (start, end) = if descending {
816                    let end = bound_val
817                        .map(
818                            |i| i
819                            .checked_sub(rng.i64(0..range_offset_bound))
820                            .expect("Bad luck, fuzzy test generate data that will overflow, change seed and try again")
821                        )
822                        .unwrap_or_else(|| rng.i64(-100000000..100000000));
823                    bound_val = Some(end);
824                    let start = end - rng.i64(1..range_size_bound);
825                    let start = Timestamp::new(start, unit.into());
826                    let end = Timestamp::new(end, unit.into());
827                    (start, end)
828                } else {
829                    let start = bound_val
830                        .map(|i| i + rng.i64(0..range_offset_bound))
831                        .unwrap_or_else(|| rng.i64(..));
832                    bound_val = Some(start);
833                    let end = start + rng.i64(1..range_size_bound);
834                    let start = Timestamp::new(start, unit.into());
835                    let end = Timestamp::new(end, unit.into());
836                    (start, end)
837                };
838                assert!(start < end);
839
840                let mut per_part_sort_data = vec![];
841                let mut batches = vec![];
842                for _batch_idx in 0..rng.usize(1..batch_cnt_bound) {
843                    let cnt = rng.usize(0..batch_size_bound) + 1;
844                    let iter = 0..rng.usize(0..cnt);
845                    let mut data_gen = iter
846                        .map(|_| rng.i64(start.value()..end.value()))
847                        .collect_vec();
848                    if data_gen.is_empty() {
849                        // current batch is empty, skip
850                        continue;
851                    }
852                    // mito always sort on ASC order
853                    data_gen.sort();
854                    per_part_sort_data.extend(data_gen.clone());
855                    let arr = new_ts_array(unit, data_gen.clone());
856                    let batch = DfRecordBatch::try_new(schema.clone(), vec![arr]).unwrap();
857                    batches.push(batch);
858                }
859
860                let range = PartitionRange {
861                    start,
862                    end,
863                    num_rows: batches.iter().map(|b| b.num_rows()).sum(),
864                    identifier: part_id,
865                };
866                input_ranged_data.push((range, batches));
867
868                output_ranges.push(range);
869                if per_part_sort_data.is_empty() {
870                    continue;
871                }
872                output_data.extend_from_slice(&per_part_sort_data);
873            }
874
875            // adjust output data with adjacent PartitionRanges
876            let mut output_data_iter = output_data.iter().peekable();
877            let mut output_data = vec![];
878            for range in output_ranges.clone() {
879                let mut cur_data = vec![];
880                while let Some(val) = output_data_iter.peek() {
881                    if **val < range.start.value() || **val >= range.end.value() {
882                        break;
883                    }
884                    cur_data.push(*output_data_iter.next().unwrap());
885                }
886
887                if cur_data.is_empty() {
888                    continue;
889                }
890
891                if descending {
892                    cur_data.sort_by(|a, b| b.cmp(a));
893                } else {
894                    cur_data.sort();
895                }
896                output_data.push(cur_data);
897            }
898
899            let expected_output = output_data
900                .into_iter()
901                .map(|a| {
902                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, a)]).unwrap()
903                })
904                .map(|rb| {
905                    // trim expected output with limit
906                    if let Some(limit) = limit
907                        && rb.num_rows() > limit
908                    {
909                        rb.slice(0, limit)
910                    } else {
911                        rb
912                    }
913                })
914                .collect_vec();
915
916            test_cases.push((
917                case_id,
918                unit,
919                input_ranged_data,
920                schema,
921                opt,
922                limit,
923                expected_output,
924            ));
925        }
926
927        for (case_id, _unit, input_ranged_data, schema, opt, limit, expected_output) in test_cases {
928            run_test(
929                case_id,
930                input_ranged_data,
931                schema,
932                opt,
933                limit,
934                expected_output,
935            )
936            .await;
937        }
938    }
939
940    #[tokio::test]
941    async fn simple_case() {
942        let testcases = vec![
943            (
944                TimeUnit::Millisecond,
945                vec![
946                    ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]]),
947                    ((5, 10), vec![vec![5, 6], vec![7, 8]]),
948                ],
949                false,
950                None,
951                vec![vec![1, 2, 3, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9]],
952            ),
953            (
954                TimeUnit::Millisecond,
955                vec![
956                    ((5, 10), vec![vec![5, 6], vec![7, 8, 9]]),
957                    ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
958                ],
959                true,
960                None,
961                vec![vec![9, 8, 7, 6, 5], vec![8, 7, 6, 5, 4, 3, 2, 1]],
962            ),
963            (
964                TimeUnit::Millisecond,
965                vec![
966                    ((5, 10), vec![]),
967                    ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
968                ],
969                true,
970                None,
971                vec![vec![8, 7, 6, 5, 4, 3, 2, 1]],
972            ),
973            (
974                TimeUnit::Millisecond,
975                vec![
976                    ((15, 20), vec![vec![17, 18, 19]]),
977                    ((10, 15), vec![]),
978                    ((5, 10), vec![]),
979                    ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
980                ],
981                true,
982                None,
983                vec![vec![19, 18, 17], vec![8, 7, 6, 5, 4, 3, 2, 1]],
984            ),
985            (
986                TimeUnit::Millisecond,
987                vec![
988                    ((15, 20), vec![]),
989                    ((10, 15), vec![]),
990                    ((5, 10), vec![]),
991                    ((0, 10), vec![]),
992                ],
993                true,
994                None,
995                vec![],
996            ),
997            (
998                TimeUnit::Millisecond,
999                vec![
1000                    (
1001                        (15, 20),
1002                        vec![vec![15, 17, 19, 10, 11, 12, 5, 6, 7, 8, 9, 1, 2, 3, 4]],
1003                    ),
1004                    ((10, 15), vec![]),
1005                    ((5, 10), vec![]),
1006                    ((0, 10), vec![]),
1007                ],
1008                true,
1009                None,
1010                vec![
1011                    vec![19, 17, 15],
1012                    vec![12, 11, 10],
1013                    vec![9, 8, 7, 6, 5],
1014                    vec![4, 3, 2, 1],
1015                ],
1016            ),
1017            (
1018                TimeUnit::Millisecond,
1019                vec![
1020                    (
1021                        (15, 20),
1022                        vec![vec![15, 17, 19, 10, 11, 12, 5, 6, 7, 8, 9, 1, 2, 3, 4]],
1023                    ),
1024                    ((10, 15), vec![]),
1025                    ((5, 10), vec![]),
1026                    ((0, 10), vec![]),
1027                ],
1028                true,
1029                Some(2),
1030                vec![vec![19, 17], vec![12, 11], vec![9, 8], vec![4, 3]],
1031            ),
1032        ];
1033
1034        for (identifier, (unit, input_ranged_data, descending, limit, expected_output)) in
1035            testcases.into_iter().enumerate()
1036        {
1037            let schema = Schema::new(vec![Field::new(
1038                "ts",
1039                DataType::Timestamp(unit, None),
1040                false,
1041            )]);
1042            let schema = Arc::new(schema);
1043            let opt = SortOptions {
1044                descending,
1045                ..Default::default()
1046            };
1047
1048            let input_ranged_data = input_ranged_data
1049                .into_iter()
1050                .map(|(range, data)| {
1051                    let part = PartitionRange {
1052                        start: Timestamp::new(range.0, unit.into()),
1053                        end: Timestamp::new(range.1, unit.into()),
1054                        num_rows: data.iter().map(|b| b.len()).sum(),
1055                        identifier,
1056                    };
1057
1058                    let batches = data
1059                        .into_iter()
1060                        .map(|b| {
1061                            let arr = new_ts_array(unit, b);
1062                            DfRecordBatch::try_new(schema.clone(), vec![arr]).unwrap()
1063                        })
1064                        .collect_vec();
1065                    (part, batches)
1066                })
1067                .collect_vec();
1068
1069            let expected_output = expected_output
1070                .into_iter()
1071                .map(|a| {
1072                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, a)]).unwrap()
1073                })
1074                .collect_vec();
1075
1076            run_test(
1077                identifier,
1078                input_ranged_data,
1079                schema.clone(),
1080                opt,
1081                limit,
1082                expected_output,
1083            )
1084            .await;
1085        }
1086    }
1087
1088    #[allow(clippy::print_stdout)]
1089    async fn run_test(
1090        case_id: usize,
1091        input_ranged_data: Vec<(PartitionRange, Vec<DfRecordBatch>)>,
1092        schema: SchemaRef,
1093        opt: SortOptions,
1094        limit: Option<usize>,
1095        expected_output: Vec<DfRecordBatch>,
1096    ) {
1097        for rb in &expected_output {
1098            if let Some(limit) = limit {
1099                assert!(
1100                    rb.num_rows() <= limit,
1101                    "Expect row count in expected output's batch({}) <= limit({})",
1102                    rb.num_rows(),
1103                    limit
1104                );
1105            }
1106        }
1107        let (ranges, batches): (Vec<_>, Vec<_>) = input_ranged_data.clone().into_iter().unzip();
1108
1109        let batches = batches
1110            .into_iter()
1111            .flat_map(|mut cols| {
1112                cols.push(DfRecordBatch::new_empty(schema.clone()));
1113                cols
1114            })
1115            .collect_vec();
1116        let mock_input = MockInputExec::new(batches, schema.clone());
1117
1118        let exec = PartSortExec::new(
1119            PhysicalSortExpr {
1120                expr: Arc::new(Column::new("ts", 0)),
1121                options: opt,
1122            },
1123            limit,
1124            vec![ranges.clone()],
1125            Arc::new(mock_input),
1126        );
1127
1128        let exec_stream = exec.execute(0, Arc::new(TaskContext::default())).unwrap();
1129
1130        let real_output = exec_stream.map(|r| r.unwrap()).collect::<Vec<_>>().await;
1131        // a makeshift solution for compare large data
1132        if real_output != expected_output {
1133            let mut first_diff = 0;
1134            for (idx, (lhs, rhs)) in real_output.iter().zip(expected_output.iter()).enumerate() {
1135                if lhs != rhs {
1136                    first_diff = idx;
1137                    break;
1138                }
1139            }
1140            println!("first diff batch at {}", first_diff);
1141            println!(
1142                "ranges: {:?}",
1143                ranges
1144                    .into_iter()
1145                    .map(|r| (r.start.to_chrono_datetime(), r.end.to_chrono_datetime()))
1146                    .enumerate()
1147                    .collect::<Vec<_>>()
1148            );
1149
1150            let mut full_msg = String::new();
1151            {
1152                let mut buf = Vec::with_capacity(10 * real_output.len());
1153                for batch in real_output.iter().skip(first_diff) {
1154                    let mut rb_json: Vec<u8> = Vec::new();
1155                    let mut writer = ArrayWriter::new(&mut rb_json);
1156                    writer.write(batch).unwrap();
1157                    writer.finish().unwrap();
1158                    buf.append(&mut rb_json);
1159                    buf.push(b',');
1160                }
1161                // TODO(discord9): better ways to print buf
1162                let buf = String::from_utf8_lossy(&buf);
1163                full_msg += &format!("\ncase_id:{case_id}, real_output \n{buf}\n");
1164            }
1165            {
1166                let mut buf = Vec::with_capacity(10 * real_output.len());
1167                for batch in expected_output.iter().skip(first_diff) {
1168                    let mut rb_json: Vec<u8> = Vec::new();
1169                    let mut writer = ArrayWriter::new(&mut rb_json);
1170                    writer.write(batch).unwrap();
1171                    writer.finish().unwrap();
1172                    buf.append(&mut rb_json);
1173                    buf.push(b',');
1174                }
1175                let buf = String::from_utf8_lossy(&buf);
1176                full_msg += &format!("case_id:{case_id}, expected_output \n{buf}");
1177            }
1178            panic!(
1179                "case_{} failed, opt: {:?},\n real output has {} batches, {} rows, expected has {} batches with {} rows\nfull msg: {}",
1180                case_id,
1181                opt,
1182                real_output.len(),
1183                real_output.iter().map(|x| x.num_rows()).sum::<usize>(),
1184                expected_output.len(),
1185                expected_output.iter().map(|x| x.num_rows()).sum::<usize>(),
1186                full_msg
1187            );
1188        }
1189    }
1190}