promql/extension_plan/histogram_fold.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::any::Any;
use std::borrow::Cow;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::task::Poll;
use std::time::Instant;

use common_telemetry::warn;
use datafusion::arrow::array::{Array, AsArray, StringArray};
use datafusion::arrow::compute::{SortOptions, concat_batches};
use datafusion::arrow::datatypes::{DataType, Float64Type, SchemaRef};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::common::stats::Precision;
use datafusion::common::{ColumnStatistics, DFSchema, DFSchemaRef, Statistics};
use datafusion::error::{DataFusionError, Result as DataFusionResult};
use datafusion::execution::TaskContext;
use datafusion::logical_expr::{LogicalPlan, UserDefinedLogicalNodeCore};
use datafusion::physical_expr::{
    EquivalenceProperties, LexRequirement, OrderingRequirements, PhysicalSortRequirement,
};
use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
use datafusion::physical_plan::expressions::{CastExpr as PhyCast, Column as PhyColumn};
use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use datafusion::physical_plan::{
    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
    Partitioning, PhysicalExpr, PlanProperties, RecordBatchStream, SendableRecordBatchStream,
};
use datafusion::prelude::{Column, Expr};
use datatypes::prelude::{ConcreteDataType, DataType as GtDataType};
use datatypes::value::{OrderedF64, Value, ValueRef};
use datatypes::vectors::{Helper, MutableVector, VectorRef};
use futures::{Stream, StreamExt, ready};
/// `HistogramFold` will fold the conventional (non-native) histogram ([1]) for later
/// computing.
///
/// Specifically, it will transform the `le` and `field` columns into a complex
/// type, and sample the other tag columns:
/// - `le` will become a [ListArray] of [f64], with each bucket bound parsed
/// - `field` will become a [ListArray] of [f64]
/// - other columns will be sampled every `bucket_num` elements, but their types won't change.
///
/// Due to the folding or sampling, the number of output rows becomes `input_rows` / `bucket_num`.
///
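/// # Example
/// An illustrative (hypothetical) input with `bucket_num == 3`, computing
/// quantile `0.5`:
///
/// | host | le   | val |
/// |------|------|-----|
/// | a    | 0.1  | 1.0 |
/// | a    | 1    | 2.0 |
/// | a    | +Inf | 2.0 |
///
/// folds into a single output row `(host = "a", val = 0.1)`, where `val` is
/// the interpolated quantile over the three buckets.
///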
/// # Requirement
/// - Input should be sorted on `<tag list>, ts, le ASC`.
/// - The value set of `le` should be the same, i.e., every series should have
///   the same buckets.
///
/// [1]: https://prometheus.io/docs/concepts/metric_types/#histogram
#[derive(Debug, PartialEq, Hash, Eq)]
pub struct HistogramFold {
    /// Name of the `le` column. It's a special column in Prometheus
    /// used to implement conventional histograms. It's a string column
    /// with "literal" float values, like "+Inf", "0.001" etc.
    le_column: String,
    ts_column: String,
    input: LogicalPlan,
    field_column: String,
    quantile: OrderedF64,
    output_schema: DFSchemaRef,
}

impl UserDefinedLogicalNodeCore for HistogramFold {
    fn name(&self) -> &str {
        Self::name()
    }

    fn inputs(&self) -> Vec<&LogicalPlan> {
        vec![&self.input]
    }

    fn schema(&self) -> &DFSchemaRef {
        &self.output_schema
    }

    fn expressions(&self) -> Vec<Expr> {
        vec![]
    }

    fn fmt_for_explain(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "HistogramFold: le={}, field={}, quantile={}",
            self.le_column, self.field_column, self.quantile
        )
    }

    fn with_exprs_and_inputs(
        &self,
        _exprs: Vec<Expr>,
        inputs: Vec<LogicalPlan>,
    ) -> DataFusionResult<Self> {
        Ok(Self {
            le_column: self.le_column.clone(),
            ts_column: self.ts_column.clone(),
            input: inputs.into_iter().next().unwrap(),
            field_column: self.field_column.clone(),
            quantile: self.quantile,
            // This method cannot return an error; otherwise we would have to
            // re-calculate the output schema here.
            output_schema: self.output_schema.clone(),
        })
    }
}

impl HistogramFold {
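    /// Creates a `HistogramFold` node over `input`, validating that the named
    /// columns exist. A minimal usage sketch (assuming `scan` is some upstream
    /// `LogicalPlan` whose schema contains these columns):
    ///
    /// ```ignore
    /// let fold = HistogramFold::new(
    ///     "le".to_string(),  // bucket bound column
    ///     "val".to_string(), // counter column
    ///     "ts".to_string(),  // timestamp column
    ///     0.99,              // quantile to compute
    ///     scan,
    /// )?;
    /// ```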
    pub fn new(
        le_column: String,
        field_column: String,
        ts_column: String,
        quantile: f64,
        input: LogicalPlan,
    ) -> DataFusionResult<Self> {
        let input_schema = input.schema();
        Self::check_schema(input_schema, &le_column, &field_column, &ts_column)?;
        let output_schema = Self::convert_schema(input_schema, &le_column)?;
        Ok(Self {
            le_column,
            ts_column,
            input,
            field_column,
            quantile: quantile.into(),
            output_schema,
        })
    }

    pub const fn name() -> &'static str {
        "HistogramFold"
    }

    fn check_schema(
        input_schema: &DFSchemaRef,
        le_column: &str,
        field_column: &str,
        ts_column: &str,
    ) -> DataFusionResult<()> {
        let check_column = |col| {
            if !input_schema.has_column_with_unqualified_name(col) {
                Err(DataFusionError::SchemaError(
                    Box::new(datafusion::common::SchemaError::FieldNotFound {
                        field: Box::new(Column::new(None::<String>, col)),
                        valid_fields: input_schema.columns(),
                    }),
                    Box::new(None),
                ))
            } else {
                Ok(())
            }
        };

        check_column(le_column)?;
        check_column(ts_column)?;
        check_column(field_column)
    }

    pub fn to_execution_plan(&self, exec_input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
        let input_schema = self.input.schema();
        // safety: those fields are checked in `check_schema()`
        let le_column_index = input_schema
            .index_of_column_by_name(None, &self.le_column)
            .unwrap();
        let field_column_index = input_schema
            .index_of_column_by_name(None, &self.field_column)
            .unwrap();
        let ts_column_index = input_schema
            .index_of_column_by_name(None, &self.ts_column)
            .unwrap();

        let tag_columns = exec_input
            .schema()
            .fields()
            .iter()
            .enumerate()
            .filter_map(|(idx, field)| {
                if idx == le_column_index || idx == field_column_index || idx == ts_column_index {
                    None
                } else {
                    Some(Arc::new(PhyColumn::new(field.name(), idx)) as _)
                }
            })
            .collect::<Vec<_>>();

        let mut partition_exprs = tag_columns.clone();
        partition_exprs.push(Arc::new(PhyColumn::new(
            self.input.schema().field(ts_column_index).name(),
            ts_column_index,
        )) as _);

        let output_schema: SchemaRef = self.output_schema.inner().clone();
        let properties = PlanProperties::new(
            EquivalenceProperties::new(output_schema.clone()),
            Partitioning::Hash(
                partition_exprs.clone(),
                exec_input.output_partitioning().partition_count(),
            ),
            EmissionType::Incremental,
            Boundedness::Bounded,
        );
        Arc::new(HistogramFoldExec {
            le_column_index,
            field_column_index,
            ts_column_index,
            input: exec_input,
            tag_columns,
            partition_exprs,
            quantile: self.quantile.into(),
            output_schema,
            metric: ExecutionPlanMetricsSet::new(),
            properties,
        })
    }

    /// Transform the schema
    ///
    /// - `le` will be removed
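    ///
    /// For example (mirroring the `confirm_schema` test below), an input
    /// schema of `(host, le, val)` becomes `(host, val)`.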
    fn convert_schema(
        input_schema: &DFSchemaRef,
        le_column: &str,
    ) -> DataFusionResult<DFSchemaRef> {
        let fields = input_schema.fields();
        // safety: those fields are checked in `check_schema()`
        let mut new_fields = Vec::with_capacity(fields.len() - 1);
        for f in fields {
            if f.name() != le_column {
                new_fields.push((None, f.clone()));
            }
        }
        Ok(Arc::new(DFSchema::new_with_metadata(
            new_fields,
            HashMap::new(),
        )?))
    }
}

impl PartialOrd for HistogramFold {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        // Compare fields in order, excluding output_schema
        match self.le_column.partial_cmp(&other.le_column) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        match self.ts_column.partial_cmp(&other.ts_column) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        match self.input.partial_cmp(&other.input) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        match self.field_column.partial_cmp(&other.field_column) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        self.quantile.partial_cmp(&other.quantile)
    }
}

#[derive(Debug)]
pub struct HistogramFoldExec {
    /// Index of the `le` column in the input schema.
    le_column_index: usize,
    input: Arc<dyn ExecutionPlan>,
    output_schema: SchemaRef,
    /// Index of the field column in the input schema.
    field_column_index: usize,
    ts_column_index: usize,
    /// Tag columns are all columns except the `le`, `field` and `ts` columns.
    tag_columns: Vec<Arc<dyn PhysicalExpr>>,
    partition_exprs: Vec<Arc<dyn PhysicalExpr>>,
    quantile: f64,
    metric: ExecutionPlanMetricsSet,
    properties: PlanProperties,
}

impl ExecutionPlan for HistogramFoldExec {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

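    /// Requires the input sorted by `<tag columns>, ts, le ASC`, with `le`
    /// cast to `Float64` so that "+Inf" sorts last. This matches the sorting
    /// contract stated in the `HistogramFold` docs.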
    fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
        let mut cols = self
            .tag_columns
            .iter()
            .map(|expr| PhysicalSortRequirement {
                expr: expr.clone(),
                options: None,
            })
            .collect::<Vec<PhysicalSortRequirement>>();
        // add ts
        cols.push(PhysicalSortRequirement {
            expr: Arc::new(PhyColumn::new(
                self.input.schema().field(self.ts_column_index).name(),
                self.ts_column_index,
            )),
            options: None,
        });
        // add le ASC
        cols.push(PhysicalSortRequirement {
            expr: Arc::new(PhyCast::new(
                Arc::new(PhyColumn::new(
                    self.input.schema().field(self.le_column_index).name(),
                    self.le_column_index,
                )),
                DataType::Float64,
                None,
            )),
            options: Some(SortOptions {
                descending: false,  // +Inf sorts last
                nulls_first: false, // not nullable
            }),
        });

        // Safety: `cols` is not empty
        let requirement = LexRequirement::new(cols).unwrap();

        vec![Some(OrderingRequirements::Hard(vec![requirement]))]
    }

    fn required_input_distribution(&self) -> Vec<Distribution> {
        vec![Distribution::HashPartitioned(self.partition_exprs.clone())]
    }

    fn maintains_input_order(&self) -> Vec<bool> {
        vec![true; self.children().len()]
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    // cannot change schema with this method
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
        assert!(!children.is_empty());
        let new_input = children[0].clone();
        let properties = PlanProperties::new(
            EquivalenceProperties::new(self.output_schema.clone()),
            Partitioning::Hash(
                self.partition_exprs.clone(),
                new_input.output_partitioning().partition_count(),
            ),
            EmissionType::Incremental,
            Boundedness::Bounded,
        );
        Ok(Arc::new(Self {
            input: new_input,
            metric: self.metric.clone(),
            le_column_index: self.le_column_index,
            ts_column_index: self.ts_column_index,
            tag_columns: self.tag_columns.clone(),
            partition_exprs: self.partition_exprs.clone(),
            quantile: self.quantile,
            output_schema: self.output_schema.clone(),
            field_column_index: self.field_column_index,
            properties,
        }))
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> DataFusionResult<SendableRecordBatchStream> {
        let baseline_metric = BaselineMetrics::new(&self.metric, partition);

        let batch_size = context.session_config().batch_size();
        let input = self.input.execute(partition, context)?;
        let output_schema = self.output_schema.clone();

        let mut normal_indices = (0..input.schema().fields().len()).collect::<HashSet<_>>();
        normal_indices.remove(&self.field_column_index);
        normal_indices.remove(&self.le_column_index);
        Ok(Box::pin(HistogramFoldStream {
            le_column_index: self.le_column_index,
            field_column_index: self.field_column_index,
            quantile: self.quantile,
            normal_indices: normal_indices.into_iter().collect(),
            bucket_size: None,
            input_buffer: vec![],
            input,
            output_schema,
            input_schema: self.input.schema(),
            mode: FoldMode::Optimistic,
            safe_group: None,
            metric: baseline_metric,
            batch_size,
            input_buffered_rows: 0,
            output_buffer: HistogramFoldStream::empty_output_buffer(
                &self.output_schema,
                self.le_column_index,
            )?,
            output_buffered_rows: 0,
        }))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metric.clone_inner())
    }

    fn partition_statistics(&self, _: Option<usize>) -> DataFusionResult<Statistics> {
        Ok(Statistics {
            num_rows: Precision::Absent,
            total_byte_size: Precision::Absent,
            column_statistics: vec![
                ColumnStatistics::new_unknown();
                // plus one for the `le` column removed by `convert_schema()`
                self.schema().flattened_fields().len() + 1
            ],
        })
    }

    fn name(&self) -> &str {
        "HistogramFoldExec"
    }
}

impl DisplayAs for HistogramFoldExec {
    fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default
            | DisplayFormatType::Verbose
            | DisplayFormatType::TreeRender => {
                write!(
                    f,
                    "HistogramFoldExec: le=@{}, field=@{}, quantile={}",
                    self.le_column_index, self.field_column_index, self.quantile
                )
            }
        }
    }
}

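/// Folding strategy of [HistogramFoldStream].
///
/// In `Optimistic` mode the stream assumes every series carries an identical,
/// complete bucket set, so it folds fixed-size groups of `bucket_size` rows.
/// When that assumption is violated (e.g., a group without a trailing "+Inf"
/// bucket), the stream falls back to `Safe` mode, which walks rows one by one
/// and groups them by tag values.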
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FoldMode {
    Optimistic,
    Safe,
}

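/// Stream implementation that folds sorted histogram rows into one quantile
/// value per group of tag values (including the timestamp).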
pub struct HistogramFoldStream {
    // internal states
    le_column_index: usize,
    field_column_index: usize,
    quantile: f64,
    /// Columns that need no folding. These indices are based on the input schema.
    normal_indices: Vec<usize>,
    bucket_size: Option<usize>,
    /// Expected output batch size
    batch_size: usize,
    output_schema: SchemaRef,
    input_schema: SchemaRef,
    mode: FoldMode,
    safe_group: Option<SafeGroup>,

    // buffers
    input_buffer: Vec<RecordBatch>,
    input_buffered_rows: usize,
    output_buffer: Vec<Box<dyn MutableVector>>,
    output_buffered_rows: usize,

    // runtime things
    input: SendableRecordBatchStream,
    metric: BaselineMetrics,
}

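/// Accumulated state of the current series (one tag-value combination) while
/// folding in safe mode.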
#[derive(Debug, Default)]
struct SafeGroup {
    tag_values: Vec<Value>,
    buckets: Vec<f64>,
    counters: Vec<f64>,
}

impl RecordBatchStream for HistogramFoldStream {
    fn schema(&self) -> SchemaRef {
        self.output_schema.clone()
    }
}

impl Stream for HistogramFoldStream {
    type Item = DataFusionResult<RecordBatch>;

    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let poll = loop {
            match ready!(self.input.poll_next_unpin(cx)) {
                Some(batch) => {
                    let batch = batch?;
                    let timer = Instant::now();
                    let Some(result) = self.fold_input(batch)? else {
                        self.metric.elapsed_compute().add_elapsed(timer);
                        continue;
                    };
                    self.metric.elapsed_compute().add_elapsed(timer);
                    break Poll::Ready(Some(result));
                }
                None => {
                    self.flush_remaining()?;
                    break Poll::Ready(self.take_output_buf()?.map(Ok));
                }
            }
        };
        self.metric.record_poll(poll)
    }
}

impl HistogramFoldStream {
    /// The innermost `Result` is for `poll_next()`
    pub fn fold_input(
        &mut self,
        input: RecordBatch,
    ) -> DataFusionResult<Option<DataFusionResult<RecordBatch>>> {
        match self.mode {
            FoldMode::Safe => {
                self.push_input_buf(input);
                self.process_safe_mode_buffer()?;
            }
            FoldMode::Optimistic => {
                self.push_input_buf(input);
                let Some(bucket_num) = self.calculate_bucket_num_from_buffer()? else {
                    return Ok(None);
                };
                self.bucket_size = Some(bucket_num);

                if self.input_buffered_rows < bucket_num {
                    // not enough rows to fold
                    return Ok(None);
                }

                self.fold_buf(bucket_num)?;
            }
        }

        self.maybe_take_output()
    }

    /// Generate a group of empty [MutableVector]s from the output schema.
    ///
    /// For simplicity, this method inserts a placeholder for `le` so that
    /// the output buffer has the same schema as the input. This placeholder
    /// needs to be removed before returning the output batch.
    pub fn empty_output_buffer(
        schema: &SchemaRef,
        le_column_index: usize,
    ) -> DataFusionResult<Vec<Box<dyn MutableVector>>> {
        let mut builders = Vec::with_capacity(schema.fields().len() + 1);
        for field in schema.fields() {
            let concrete_datatype = ConcreteDataType::try_from(field.data_type()).unwrap();
            let mutable_vector = concrete_datatype.create_mutable_vector(0);
            builders.push(mutable_vector);
        }
        builders.insert(
            le_column_index,
            ConcreteDataType::float64_datatype().create_mutable_vector(0),
        );

        Ok(builders)
    }

    /// Determines the bucket count using buffered batches, concatenating them
    /// to detect the first complete bucket set, which may span batch boundaries.
    fn calculate_bucket_num_from_buffer(&mut self) -> DataFusionResult<Option<usize>> {
        if let Some(size) = self.bucket_size {
            return Ok(Some(size));
        }

        if self.input_buffer.is_empty() {
            return Ok(None);
        }

        let batch_refs: Vec<&RecordBatch> = self.input_buffer.iter().collect();
        let batch = concat_batches(&self.input_schema, batch_refs)?;
        self.find_first_complete_bucket(&batch)
    }

    fn find_first_complete_bucket(&self, batch: &RecordBatch) -> DataFusionResult<Option<usize>> {
        if batch.num_rows() == 0 {
            return Ok(None);
        }

        let vectors = Helper::try_into_vectors(batch.columns())
            .map_err(|e| DataFusionError::Execution(e.to_string()))?;
        let le_array = batch.column(self.le_column_index).as_string::<i32>();

        let mut tag_values_buf = Vec::with_capacity(self.normal_indices.len());
        self.collect_tag_values(&vectors, 0, &mut tag_values_buf);
        let mut group_start = 0usize;

        for row in 0..batch.num_rows() {
            if !self.is_same_group(&vectors, row, &tag_values_buf) {
                // new group begins
                self.collect_tag_values(&vectors, row, &mut tag_values_buf);
                group_start = row;
            }

            if Self::is_positive_infinity(le_array, row) {
                return Ok(Some(row - group_start + 1));
            }
        }

        Ok(None)
    }

    /// Fold record batches from the input buffer and put the results into the
    /// output buffer.
    fn fold_buf(&mut self, bucket_num: usize) -> DataFusionResult<()> {
        let batch = concat_batches(&self.input_schema, self.input_buffer.drain(..).as_ref())?;
        let mut remaining_rows = self.input_buffered_rows;
        let mut cursor = 0;

        // TODO(LFC): Try to get rid of the Arrow array to vector conversion here.
        let vectors = Helper::try_into_vectors(batch.columns())
            .map_err(|e| DataFusionError::Execution(e.to_string()))?;
        let le_array = batch.column(self.le_column_index);
        let le_array = le_array.as_string::<i32>();
        let field_array = batch.column(self.field_column_index);
        let field_array = field_array.as_primitive::<Float64Type>();
        let mut tag_values_buf = Vec::with_capacity(self.normal_indices.len());

        while remaining_rows >= bucket_num && self.mode == FoldMode::Optimistic {
            self.collect_tag_values(&vectors, cursor, &mut tag_values_buf);
            if !self.validate_optimistic_group(
                &vectors,
                le_array,
                cursor,
                bucket_num,
                &tag_values_buf,
            ) {
                let remaining_input_batch = batch.slice(cursor, remaining_rows);
                self.switch_to_safe_mode(remaining_input_batch)?;
                return Ok(());
            }

            // "sample" normal columns
            for (idx, value) in self.normal_indices.iter().zip(tag_values_buf.iter()) {
                self.output_buffer[*idx].push_value_ref(value);
            }
            // "fold" `le` and field columns
            let mut bucket = Vec::with_capacity(bucket_num);
            let mut counters = Vec::with_capacity(bucket_num);
            for bias in 0..bucket_num {
                let position = cursor + bias;
                let le = if le_array.is_valid(position) {
                    le_array.value(position).parse::<f64>().unwrap_or(f64::NAN)
                } else {
                    f64::NAN
                };
                bucket.push(le);

                let counter = if field_array.is_valid(position) {
                    field_array.value(position)
                } else {
                    f64::NAN
                };
                counters.push(counter);
            }
            // ignore invalid data
            let result = Self::evaluate_row(self.quantile, &bucket, &counters).unwrap_or(f64::NAN);
            self.output_buffer[self.field_column_index].push_value_ref(&ValueRef::from(result));
            cursor += bucket_num;
            remaining_rows -= bucket_num;
            self.output_buffered_rows += 1;
        }

        let remaining_input_batch = batch.slice(cursor, remaining_rows);
        self.input_buffered_rows = remaining_input_batch.num_rows();
        if self.input_buffered_rows > 0 {
            self.input_buffer.push(remaining_input_batch);
        }

        Ok(())
    }

    fn push_input_buf(&mut self, batch: RecordBatch) {
        self.input_buffered_rows += batch.num_rows();
        self.input_buffer.push(batch);
    }

    fn maybe_take_output(&mut self) -> DataFusionResult<Option<DataFusionResult<RecordBatch>>> {
        if self.output_buffered_rows >= self.batch_size {
            return Ok(self.take_output_buf()?.map(Ok));
        }
        Ok(None)
    }

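    /// Abandons optimistic folding: resets the fixed bucket size, re-buffers
    /// the unconsumed rows and re-processes them row by row in safe mode.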
    fn switch_to_safe_mode(&mut self, remaining_batch: RecordBatch) -> DataFusionResult<()> {
        self.mode = FoldMode::Safe;
        self.bucket_size = None;
        self.input_buffer.clear();
        self.input_buffered_rows = remaining_batch.num_rows();

        if self.input_buffered_rows > 0 {
            self.input_buffer.push(remaining_batch);
            self.process_safe_mode_buffer()?;
        }

        Ok(())
    }

    fn collect_tag_values<'a>(
        &self,
        vectors: &'a [VectorRef],
        row: usize,
        tag_values: &mut Vec<ValueRef<'a>>,
    ) {
        tag_values.clear();
        for idx in self.normal_indices.iter() {
            tag_values.push(vectors[*idx].get_ref(row));
        }
    }

    fn validate_optimistic_group(
        &self,
        vectors: &[VectorRef],
        le_array: &StringArray,
        cursor: usize,
        bucket_num: usize,
        tag_values: &[ValueRef<'_>],
    ) -> bool {
        let inf_index = cursor + bucket_num - 1;
        if !Self::is_positive_infinity(le_array, inf_index) {
            return false;
        }

        for offset in 1..bucket_num {
            let row = cursor + offset;
            for (idx, expected) in self.normal_indices.iter().zip(tag_values.iter()) {
                if vectors[*idx].get_ref(row) != *expected {
                    return false;
                }
            }
        }
        true
    }

    /// Checks whether a row belongs to the current group (same series).
    fn is_same_group(
        &self,
        vectors: &[VectorRef],
        row: usize,
        tag_values: &[ValueRef<'_>],
    ) -> bool {
        self.normal_indices
            .iter()
            .zip(tag_values.iter())
            .all(|(idx, expected)| vectors[*idx].get_ref(row) == *expected)
    }

    fn push_output_row(&mut self, tag_values: &[ValueRef<'_>], result: f64) {
        debug_assert_eq!(self.normal_indices.len(), tag_values.len());
        for (idx, value) in self.normal_indices.iter().zip(tag_values.iter()) {
            self.output_buffer[*idx].push_value_ref(value);
        }
        self.output_buffer[self.field_column_index].push_value_ref(&ValueRef::from(result));
        self.output_buffered_rows += 1;
    }

    fn finalize_safe_group(&mut self) -> DataFusionResult<()> {
        if let Some(group) = self.safe_group.take() {
            if group.tag_values.is_empty() {
                return Ok(());
            }

            let has_inf = group
                .buckets
                .last()
                .map(|v| v.is_infinite() && v.is_sign_positive())
                .unwrap_or(false);
            let result = if group.buckets.len() < 2 || !has_inf {
                f64::NAN
            } else {
                Self::evaluate_row(self.quantile, &group.buckets, &group.counters)
                    .unwrap_or(f64::NAN)
            };
            let mut tag_value_refs = Vec::with_capacity(group.tag_values.len());
            tag_value_refs.extend(group.tag_values.iter().map(|v| v.as_value_ref()));
            self.push_output_row(&tag_value_refs, result);
        }
        Ok(())
    }

    fn process_safe_mode_buffer(&mut self) -> DataFusionResult<()> {
        if self.input_buffer.is_empty() {
            self.input_buffered_rows = 0;
            return Ok(());
        }

        let batch = concat_batches(&self.input_schema, self.input_buffer.drain(..).as_ref())?;
        self.input_buffered_rows = 0;
        let vectors = Helper::try_into_vectors(batch.columns())
            .map_err(|e| DataFusionError::Execution(e.to_string()))?;
        let le_array = batch.column(self.le_column_index).as_string::<i32>();
        let field_array = batch
            .column(self.field_column_index)
            .as_primitive::<Float64Type>();
        let mut tag_values_buf = Vec::with_capacity(self.normal_indices.len());

        for row in 0..batch.num_rows() {
            self.collect_tag_values(&vectors, row, &mut tag_values_buf);
            let should_start_new_group = self
                .safe_group
                .as_ref()
                .is_none_or(|group| !Self::tag_values_equal(&group.tag_values, &tag_values_buf));
            if should_start_new_group {
                self.finalize_safe_group()?;
                self.safe_group = Some(SafeGroup {
                    tag_values: tag_values_buf.iter().cloned().map(Value::from).collect(),
                    buckets: Vec::new(),
                    counters: Vec::new(),
                });
            }

            let Some(group) = self.safe_group.as_mut() else {
                continue;
            };

            let bucket = if le_array.is_valid(row) {
                le_array.value(row).parse::<f64>().unwrap_or(f64::NAN)
            } else {
                f64::NAN
            };
            let counter = if field_array.is_valid(row) {
                field_array.value(row)
            } else {
                f64::NAN
            };

            group.buckets.push(bucket);
            group.counters.push(counter);
        }

        Ok(())
    }

    fn tag_values_equal(group_values: &[Value], current: &[ValueRef<'_>]) -> bool {
        group_values.len() == current.len()
            && group_values
                .iter()
                .zip(current.iter())
                .all(|(group, now)| group.as_value_ref() == *now)
    }

    /// Takes the accumulated output buffer and builds a record batch from it.
    fn take_output_buf(&mut self) -> DataFusionResult<Option<RecordBatch>> {
        if self.output_buffered_rows == 0 {
            if self.input_buffered_rows != 0 {
                warn!(
                    "input buffer is not empty, {} rows remaining",
                    self.input_buffered_rows
                );
            }
            return Ok(None);
        }

        let mut output_buf = Self::empty_output_buffer(&self.output_schema, self.le_column_index)?;
        std::mem::swap(&mut self.output_buffer, &mut output_buf);
        let mut columns = Vec::with_capacity(output_buf.len());
        for builder in output_buf.iter_mut() {
            columns.push(builder.to_vector().to_arrow_array());
        }
        // remove the placeholder column for `le`
        columns.remove(self.le_column_index);

        self.output_buffered_rows = 0;
        RecordBatch::try_new(self.output_schema.clone(), columns)
            .map(Some)
            .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))
    }

    fn flush_remaining(&mut self) -> DataFusionResult<()> {
        if self.mode == FoldMode::Optimistic && self.input_buffered_rows > 0 {
            let buffered_batches: Vec<_> = self.input_buffer.drain(..).collect();
            if !buffered_batches.is_empty() {
                let batch = concat_batches(&self.input_schema, buffered_batches.as_slice())?;
                self.switch_to_safe_mode(batch)?;
            } else {
                self.input_buffered_rows = 0;
            }
        }

        if self.mode == FoldMode::Safe {
            self.process_safe_mode_buffer()?;
            self.finalize_safe_group()?;
        }

        Ok(())
    }

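    /// Returns true when the `le` string at `index` parses to positive
    /// infinity. Rust's `f64` parser accepts "inf"/"infinity" case-insensitively
    /// with an optional sign, so "+Inf", "+inf" and "+INF" all match.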
    fn is_positive_infinity(le_array: &StringArray, index: usize) -> bool {
        le_array.is_valid(index)
            && matches!(
                le_array.value(index).parse::<f64>(),
                Ok(value) if value.is_infinite() && value.is_sign_positive()
            )
    }

    /// Evaluates the quantile from the bucket upper bounds and their
    /// cumulative counters, similar in spirit to PromQL's `histogram_quantile`.
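    ///
    /// Worked example (taken from `evaluate_row_normal_case` below): with
    /// `bucket = [0, 1, 2, 3, 4, +Inf]`, `counter = [0, 10, 20, 30, 40, 50]`
    /// and `quantile = 0.5`, the target rank is `50 * 0.5 = 25`, which falls
    /// into the bucket `(2, 3]` (cumulative count 20..30). Linear interpolation
    /// gives `2 + (3 - 2) / (30 - 20) * (25 - 20) = 2.5`.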
    fn evaluate_row(quantile: f64, bucket: &[f64], counter: &[f64]) -> DataFusionResult<f64> {
        // check bucket
        if bucket.len() <= 1 {
            return Ok(f64::NAN);
        }
        if bucket.last().unwrap().is_finite() {
            return Err(DataFusionError::Execution(
                "last bucket should be +Inf".to_string(),
            ));
        }
        if bucket.len() != counter.len() {
            return Err(DataFusionError::Execution(
                "bucket and counter should have the same length".to_string(),
            ));
        }
        // check quantile
        if quantile < 0.0 {
            return Ok(f64::NEG_INFINITY);
        } else if quantile > 1.0 {
            return Ok(f64::INFINITY);
        } else if quantile.is_nan() {
            return Ok(f64::NAN);
        }

        // check input value
        if !bucket.windows(2).all(|w| w[0] <= w[1]) {
            return Ok(f64::NAN);
        }
        let counter = {
            let needs_fix =
                counter.iter().any(|v| !v.is_finite()) || !counter.windows(2).all(|w| w[0] <= w[1]);
            if !needs_fix {
                Cow::Borrowed(counter)
            } else {
                // clamp non-finite and decreasing values to keep the counter
                // monotonically non-decreasing
                let mut fixed = Vec::with_capacity(counter.len());
                let mut prev = 0.0;
                for (idx, &v) in counter.iter().enumerate() {
                    let mut val = if v.is_finite() { v } else { prev };
                    if idx > 0 && val < prev {
                        val = prev;
                    }
                    fixed.push(val);
                    prev = val;
                }
                Cow::Owned(fixed)
            }
        };

        let total = *counter.last().unwrap();
        let expected_pos = total * quantile;
        let mut fit_bucket_pos = 0;
        while fit_bucket_pos < bucket.len() && counter[fit_bucket_pos] < expected_pos {
            fit_bucket_pos += 1;
        }
        if fit_bucket_pos >= bucket.len() - 1 {
            Ok(bucket[bucket.len() - 2])
        } else {
            let upper_bound = bucket[fit_bucket_pos];
            let upper_count = counter[fit_bucket_pos];
            let mut lower_bound = bucket[0].min(0.0);
            let mut lower_count = 0.0;
            if fit_bucket_pos > 0 {
                lower_bound = bucket[fit_bucket_pos - 1];
                lower_count = counter[fit_bucket_pos - 1];
            }
            if (upper_count - lower_count).abs() < 1e-10 {
                return Ok(f64::NAN);
            }
            Ok(lower_bound
                + (upper_bound - lower_bound) / (upper_count - lower_count)
                    * (expected_pos - lower_count))
        }
    }
}

#[cfg(test)]
mod test {
    use std::sync::Arc;

    use datafusion::arrow::array::{Float64Array, TimestampMillisecondArray};
    use datafusion::arrow::datatypes::{Field, Schema, SchemaRef, TimeUnit};
    use datafusion::common::ToDFSchema;
    use datafusion::datasource::memory::MemorySourceConfig;
    use datafusion::datasource::source::DataSourceExec;
    use datafusion::prelude::SessionContext;
    use datatypes::arrow_array::StringArray;

    use super::*;

    fn prepare_test_data() -> DataSourceExec {
        let schema = Arc::new(Schema::new(vec![
            Field::new("host", DataType::Utf8, true),
            Field::new("le", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ]));

        // 12 items
        let host_column_1 = Arc::new(StringArray::from(vec![
            "host_1", "host_1", "host_1", "host_1", "host_1", "host_1", "host_1", "host_1",
            "host_1", "host_1", "host_1", "host_1",
        ])) as _;
        let le_column_1 = Arc::new(StringArray::from(vec![
            "0.001", "0.1", "10", "1000", "+Inf", "0.001", "0.1", "10", "1000", "+inf", "0.001",
            "0.1",
        ])) as _;
        let val_column_1 = Arc::new(Float64Array::from(vec![
            0.0, 1.0, 1.0, 5.0, 5.0, 0.0, 20.0, 60.0, 70.0, 100.0, 1.0, 1.0,
        ])) as _;

        // 2 items
        let host_column_2 = Arc::new(StringArray::from(vec!["host_1", "host_1"])) as _;
        let le_column_2 = Arc::new(StringArray::from(vec!["10", "1000"])) as _;
        let val_column_2 = Arc::new(Float64Array::from(vec![1.0, 1.0])) as _;

        // 11 items
        let host_column_3 = Arc::new(StringArray::from(vec![
            "host_1", "host_2", "host_2", "host_2", "host_2", "host_2", "host_2", "host_2",
            "host_2", "host_2", "host_2",
        ])) as _;
        let le_column_3 = Arc::new(StringArray::from(vec![
            "+INF", "0.001", "0.1", "10", "1000", "+iNf", "0.001", "0.1", "10", "1000", "+Inf",
        ])) as _;
        let val_column_3 = Arc::new(Float64Array::from(vec![
            1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 4.0,
        ])) as _;

        let data_1 = RecordBatch::try_new(
            schema.clone(),
            vec![host_column_1, le_column_1, val_column_1],
        )
        .unwrap();
        let data_2 = RecordBatch::try_new(
            schema.clone(),
            vec![host_column_2, le_column_2, val_column_2],
        )
        .unwrap();
        let data_3 = RecordBatch::try_new(
            schema.clone(),
            vec![host_column_3, le_column_3, val_column_3],
        )
        .unwrap();

        DataSourceExec::new(Arc::new(
            MemorySourceConfig::try_new(&[vec![data_1, data_2, data_3]], schema, None).unwrap(),
        ))
    }

    fn build_fold_exec_from_batches(
        batches: Vec<RecordBatch>,
        schema: SchemaRef,
        quantile: f64,
        ts_column_index: usize,
    ) -> Arc<HistogramFoldExec> {
        let input: Arc<dyn ExecutionPlan> = Arc::new(DataSourceExec::new(Arc::new(
            MemorySourceConfig::try_new(&[batches], schema.clone(), None).unwrap(),
        )));
        let output_schema: SchemaRef = Arc::new(
            HistogramFold::convert_schema(&Arc::new(input.schema().to_dfschema().unwrap()), "le")
                .unwrap()
                .as_arrow()
                .clone(),
        );

        let (tag_columns, partition_exprs, properties) =
            build_test_plan_properties(&input, output_schema.clone(), ts_column_index);

        Arc::new(HistogramFoldExec {
            le_column_index: 1,
            field_column_index: 2,
            quantile,
            ts_column_index,
            input,
            output_schema,
            tag_columns,
            partition_exprs,
            metric: ExecutionPlanMetricsSet::new(),
            properties,
        })
    }

    type PlanPropsResult = (
        Vec<Arc<dyn PhysicalExpr>>,
        Vec<Arc<dyn PhysicalExpr>>,
        PlanProperties,
    );

    fn build_test_plan_properties(
        input: &Arc<dyn ExecutionPlan>,
        output_schema: SchemaRef,
        ts_column_index: usize,
    ) -> PlanPropsResult {
        let tag_columns = input
            .schema()
            .fields()
            .iter()
            .enumerate()
            .filter_map(|(idx, field)| {
                if idx == 1 || idx == 2 || idx == ts_column_index {
                    None
                } else {
                    Some(Arc::new(PhyColumn::new(field.name(), idx)) as _)
                }
            })
            .collect::<Vec<_>>();

        let partition_exprs = if tag_columns.is_empty() {
            vec![Arc::new(PhyColumn::new(
                input.schema().field(ts_column_index).name(),
                ts_column_index,
            )) as _]
        } else {
            tag_columns.clone()
        };

        let properties = PlanProperties::new(
            EquivalenceProperties::new(output_schema.clone()),
            Partitioning::Hash(
                partition_exprs.clone(),
                input.output_partitioning().partition_count(),
            ),
            EmissionType::Incremental,
            Boundedness::Bounded,
        );

        (tag_columns, partition_exprs, properties)
    }

    #[tokio::test]
    async fn fold_overall() {
        let memory_exec: Arc<dyn ExecutionPlan> = Arc::new(prepare_test_data());
        let output_schema: SchemaRef = Arc::new(
            HistogramFold::convert_schema(
                &Arc::new(memory_exec.schema().to_dfschema().unwrap()),
                "le",
            )
            .unwrap()
            .as_arrow()
            .clone(),
        );
        let (tag_columns, partition_exprs, properties) =
            build_test_plan_properties(&memory_exec, output_schema.clone(), 0);
        let fold_exec = Arc::new(HistogramFoldExec {
            le_column_index: 1,
            field_column_index: 2,
            quantile: 0.4,
            ts_column_index: 0,
            input: memory_exec,
            output_schema,
            tag_columns,
            partition_exprs,
            metric: ExecutionPlanMetricsSet::new(),
            properties,
        });

        let session_context = SessionContext::default();
        let result = datafusion::physical_plan::collect(fold_exec, session_context.task_ctx())
            .await
            .unwrap();
        let result_literal = datatypes::arrow::util::pretty::pretty_format_batches(&result)
            .unwrap()
            .to_string();

        let expected = String::from(
            "+--------+-------------------+
| host   | val               |
+--------+-------------------+
| host_1 | 257.5             |
| host_1 | 5.05              |
| host_1 | 0.0004            |
| host_2 | NaN               |
| host_2 | 6.040000000000001 |
+--------+-------------------+",
        );
        assert_eq!(result_literal, expected);
    }

    #[test]
    fn confirm_schema() {
        let input_schema = Schema::new(vec![
            Field::new("host", DataType::Utf8, true),
            Field::new("le", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ])
        .to_dfschema_ref()
        .unwrap();
        let expected_output_schema = Schema::new(vec![
            Field::new("host", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ])
        .to_dfschema_ref()
        .unwrap();

        let actual = HistogramFold::convert_schema(&input_schema, "le").unwrap();
        assert_eq!(actual, expected_output_schema)
    }

    #[tokio::test]
    async fn fallback_to_safe_mode_on_missing_inf() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("host", DataType::Utf8, true),
            Field::new("le", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ]));
        let host_column = Arc::new(StringArray::from(vec!["a", "a", "a", "a", "b", "b"])) as _;
        let le_column = Arc::new(StringArray::from(vec![
            "0.1", "+Inf", "0.1", "1.0", "0.1", "+Inf",
        ])) as _;
        let val_column = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 3.0, 1.0, 5.0])) as _;
        let batch =
            RecordBatch::try_new(schema.clone(), vec![host_column, le_column, val_column]).unwrap();
        let fold_exec = build_fold_exec_from_batches(vec![batch], schema, 0.5, 0);
        let session_context = SessionContext::default();
        let result = datafusion::physical_plan::collect(fold_exec, session_context.task_ctx())
            .await
            .unwrap();
        let result_literal = datatypes::arrow::util::pretty::pretty_format_batches(&result)
            .unwrap()
            .to_string();

        let expected = String::from(
            "+------+-----+
| host | val |
+------+-----+
| a    | 0.1 |
| a    | NaN |
| b    | 0.1 |
+------+-----+",
        );
        assert_eq!(result_literal, expected);
    }

    #[tokio::test]
    async fn emit_nan_when_no_inf_present() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("host", DataType::Utf8, true),
            Field::new("le", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ]));
        let host_column = Arc::new(StringArray::from(vec!["c", "c"])) as _;
        let le_column = Arc::new(StringArray::from(vec!["0.1", "1.0"])) as _;
        let val_column = Arc::new(Float64Array::from(vec![1.0, 2.0])) as _;
        let batch =
            RecordBatch::try_new(schema.clone(), vec![host_column, le_column, val_column]).unwrap();
        let fold_exec = build_fold_exec_from_batches(vec![batch], schema, 0.9, 0);
        let session_context = SessionContext::default();
        let result = datafusion::physical_plan::collect(fold_exec, session_context.task_ctx())
            .await
            .unwrap();
        let result_literal = datatypes::arrow::util::pretty::pretty_format_batches(&result)
            .unwrap()
            .to_string();

        let expected = String::from(
            "+------+-----+
| host | val |
+------+-----+
| c    | NaN |
+------+-----+",
        );
        assert_eq!(result_literal, expected);
    }

    #[tokio::test]
    async fn safe_mode_handles_misaligned_groups() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
            Field::new("le", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ]));

        let ts_column = Arc::new(TimestampMillisecondArray::from(vec![
            2900000, 2900000, 2900000, 3000000, 3000000, 3000000, 3000000, 3005000, 3005000,
            3010000, 3010000, 3010000, 3010000, 3010000,
        ])) as _;
        let le_column = Arc::new(StringArray::from(vec![
            "0.1", "1", "5", "0.1", "1", "5", "+Inf", "0.1", "+Inf", "0.1", "1", "3", "5", "+Inf",
        ])) as _;
        let val_column = Arc::new(Float64Array::from(vec![
            0.0, 0.0, 0.0, 50.0, 70.0, 110.0, 120.0, 10.0, 30.0, 10.0, 20.0, 30.0, 40.0, 50.0,
        ])) as _;
        let batch =
            RecordBatch::try_new(schema.clone(), vec![ts_column, le_column, val_column]).unwrap();
        let fold_exec = build_fold_exec_from_batches(vec![batch], schema, 0.5, 0);
        let session_context = SessionContext::default();
        let result = datafusion::physical_plan::collect(fold_exec, session_context.task_ctx())
            .await
            .unwrap();

        let mut values = Vec::new();
        for batch in result {
            let array = batch.column(1).as_primitive::<Float64Type>();
            values.extend(array.iter().map(|v| v.unwrap()));
        }

        assert_eq!(values.len(), 4);
        assert!(values[0].is_nan());
        assert!((values[1] - 0.55).abs() < 1e-10);
        assert!((values[2] - 0.1).abs() < 1e-10);
        assert!((values[3] - 2.0).abs() < 1e-10);
    }

    #[tokio::test]
    async fn missing_buckets_at_first_timestamp() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
            Field::new("le", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ]));

        let ts_column = Arc::new(TimestampMillisecondArray::from(vec![
            2_900_000, 3_000_000, 3_000_000, 3_000_000, 3_000_000, 3_005_000, 3_005_000, 3_010_000,
            3_010_000, 3_010_000, 3_010_000, 3_010_000,
        ])) as _;
        let le_column = Arc::new(StringArray::from(vec![
            "0.1", "0.1", "1", "5", "+Inf", "0.1", "+Inf", "0.1", "1", "3", "5", "+Inf",
        ])) as _;
        let val_column = Arc::new(Float64Array::from(vec![
            0.0, 50.0, 70.0, 110.0, 120.0, 10.0, 30.0, 10.0, 20.0, 30.0, 40.0, 50.0,
        ])) as _;

        let batch =
            RecordBatch::try_new(schema.clone(), vec![ts_column, le_column, val_column]).unwrap();
        let fold_exec = build_fold_exec_from_batches(vec![batch], schema, 0.5, 0);
        let session_context = SessionContext::default();
        let result = datafusion::physical_plan::collect(fold_exec, session_context.task_ctx())
            .await
            .unwrap();

        let mut values = Vec::new();
        for batch in result {
            let array = batch.column(1).as_primitive::<Float64Type>();
            values.extend(array.iter().map(|v| v.unwrap()));
        }

        assert_eq!(values.len(), 4);
        assert!(values[0].is_nan());
        assert!((values[1] - 0.55).abs() < 1e-10);
        assert!((values[2] - 0.1).abs() < 1e-10);
        assert!((values[3] - 2.0).abs() < 1e-10);
    }

    #[tokio::test]
    async fn missing_inf_in_first_group() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
            Field::new("le", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ]));

        let ts_column = Arc::new(TimestampMillisecondArray::from(vec![
            1000, 1000, 1000, 2000, 2000, 2000, 2000,
        ])) as _;
        let le_column = Arc::new(StringArray::from(vec![
            "0.1", "1", "5", "0.1", "1", "5", "+Inf",
        ])) as _;
        let val_column = Arc::new(Float64Array::from(vec![
            0.0, 0.0, 0.0, 10.0, 20.0, 30.0, 30.0,
        ])) as _;
        let batch =
            RecordBatch::try_new(schema.clone(), vec![ts_column, le_column, val_column]).unwrap();
        let fold_exec = build_fold_exec_from_batches(vec![batch], schema, 0.5, 0);
        let session_context = SessionContext::default();
        let result = datafusion::physical_plan::collect(fold_exec, session_context.task_ctx())
            .await
            .unwrap();

        let mut values = Vec::new();
        for batch in result {
            let array = batch.column(1).as_primitive::<Float64Type>();
            values.extend(array.iter().map(|v| v.unwrap()));
        }

        assert_eq!(values.len(), 2);
        assert!(values[0].is_nan());
        assert!((values[1] - 0.55).abs() < 1e-10, "{values:?}");
    }

    #[test]
    fn evaluate_row_normal_case() {
        let bucket = [0.0, 1.0, 2.0, 3.0, 4.0, f64::INFINITY];

        #[derive(Debug)]
        struct Case {
            quantile: f64,
            counters: Vec<f64>,
            expected: f64,
        }

        let cases = [
            Case {
                quantile: 0.9,
                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
                expected: 4.0,
            },
            Case {
                quantile: 0.89,
                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
                expected: 4.0,
            },
            Case {
                quantile: 0.78,
                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
                expected: 3.9,
            },
            Case {
                quantile: 0.5,
                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
                expected: 2.5,
            },
            Case {
                quantile: 0.5,
                counters: vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                expected: f64::NAN,
            },
            Case {
                quantile: 1.0,
                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
                expected: 4.0,
            },
            Case {
                quantile: 0.0,
                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
                expected: f64::NAN,
            },
            Case {
                quantile: 1.1,
                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
                expected: f64::INFINITY,
            },
            Case {
                quantile: -1.0,
                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
                expected: f64::NEG_INFINITY,
            },
        ];

        for case in cases {
            let actual =
                HistogramFoldStream::evaluate_row(case.quantile, &bucket, &case.counters).unwrap();
            assert_eq!(
                format!("{actual}"),
                format!("{}", case.expected),
                "{:?}",
                case
            );
        }
    }

    #[test]
    fn evaluate_out_of_order_input() {
        let bucket = [0.0, 1.0, 2.0, 3.0, 4.0, f64::INFINITY];
        let counters = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0];
        let result = HistogramFoldStream::evaluate_row(0.5, &bucket, &counters).unwrap();
        assert_eq!(0.0, result);
    }

    #[test]
    fn evaluate_wrong_bucket() {
        let bucket = [0.0, 1.0, 2.0, 3.0, 4.0, f64::INFINITY, 5.0];
        let counters = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
        let result = HistogramFoldStream::evaluate_row(0.5, &bucket, &counters);
        assert!(result.is_err());
    }

    #[test]
    fn evaluate_small_fraction() {
        let bucket = [0.0, 2.0, 4.0, 6.0, f64::INFINITY];
        let counters = [0.0, 1.0 / 300.0, 2.0 / 300.0, 0.01, 0.01];
        let result = HistogramFoldStream::evaluate_row(0.5, &bucket, &counters).unwrap();
        assert_eq!(3.0, result);
    }

    #[test]
    fn evaluate_non_monotonic_counter() {
        let bucket = [0.0, 1.0, 2.0, 3.0, f64::INFINITY];
        let counters = [0.1, 0.2, 0.4, 0.17, 0.5];
        let result = HistogramFoldStream::evaluate_row(0.5, &bucket, &counters).unwrap();
        assert!((result - 1.25).abs() < 1e-10, "{result}");
    }

    #[test]
    fn evaluate_nan_counter() {
        let bucket = [0.0, 1.0, 2.0, 3.0, f64::INFINITY];
        let counters = [f64::NAN, 1.0, 2.0, 3.0, 3.0];
        let result = HistogramFoldStream::evaluate_row(0.5, &bucket, &counters).unwrap();
        assert!((result - 1.5).abs() < 1e-10, "{result}");
    }
}