promql/extension_plan/histogram_fold.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::any::Any;
use std::borrow::Cow;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::task::Poll;
use std::time::Instant;

use common_telemetry::warn;
use datafusion::arrow::array::{Array, AsArray, StringArray};
use datafusion::arrow::compute::{SortOptions, concat_batches};
use datafusion::arrow::datatypes::{DataType, Float64Type, SchemaRef};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::common::stats::Precision;
use datafusion::common::{ColumnStatistics, DFSchema, DFSchemaRef, Statistics};
use datafusion::error::{DataFusionError, Result as DataFusionResult};
use datafusion::execution::TaskContext;
use datafusion::logical_expr::{LogicalPlan, UserDefinedLogicalNodeCore};
use datafusion::physical_expr::{
    EquivalenceProperties, LexRequirement, OrderingRequirements, PhysicalSortRequirement,
};
use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
use datafusion::physical_plan::expressions::{CastExpr as PhyCast, Column as PhyColumn};
use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use datafusion::physical_plan::{
    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PhysicalExpr,
    PlanProperties, RecordBatchStream, SendableRecordBatchStream,
};
use datafusion::prelude::{Column, Expr};
use datatypes::prelude::{ConcreteDataType, DataType as GtDataType};
use datatypes::value::{OrderedF64, Value, ValueRef};
use datatypes::vectors::{Helper, MutableVector, VectorRef};
use futures::{Stream, StreamExt, ready};

/// `HistogramFold` folds a conventional (non-native) histogram ([1]) for later
/// computation.
///
/// Specifically, it transforms the `le` and `field` columns into a complex
/// type, and samples the other tag columns:
/// - `le` becomes a [ListArray] of [f64], with each bucket bound parsed
/// - `field` becomes a [ListArray] of [f64]
/// - other columns are sampled every `bucket_num` elements, but their types don't change.
///
/// Due to the folding and sampling, the output row count becomes `input_rows` / `bucket_num`.
///
/// # Requirements
/// - Input should be sorted on `<tag list>, ts, le ASC`.
/// - The value set of `le` should be the same, i.e., every series should have the same buckets.
///
/// [1]: https://prometheus.io/docs/concepts/metric_types/#histogram
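///
/// # Example (illustrative)
///
/// A minimal sketch, assuming one series with hypothetical columns `host` and
/// `val`, sorted by `(host, ts, le)`, folded with `quantile = 0.5`:
///
/// ```text
/// input:                        output:
/// host | le   | val             host | val
/// a    | 0.1  | 1.0      =>     a    | 0.1
/// a    | 1    | 2.0
/// a    | +Inf | 2.0
/// ```
///
/// The three bucket rows collapse into one row whose `val` holds the estimated
/// 0.5-quantile; see `evaluate_row` below for the interpolation rule.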
#[derive(Debug, PartialEq, Hash, Eq)]
pub struct HistogramFold {
    /// Name of the `le` column. It's a special column in Prometheus
    /// for implementing conventional histograms: a string column
    /// holding "literal" float values, like "+Inf", "0.001" etc.
    le_column: String,
    ts_column: String,
    input: LogicalPlan,
    field_column: String,
    quantile: OrderedF64,
    output_schema: DFSchemaRef,
}

impl UserDefinedLogicalNodeCore for HistogramFold {
    fn name(&self) -> &str {
        Self::name()
    }

    fn inputs(&self) -> Vec<&LogicalPlan> {
        vec![&self.input]
    }

    fn schema(&self) -> &DFSchemaRef {
        &self.output_schema
    }

    fn expressions(&self) -> Vec<Expr> {
        vec![]
    }

    fn fmt_for_explain(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "HistogramFold: le={}, field={}, quantile={}",
            self.le_column, self.field_column, self.quantile
        )
    }

    fn with_exprs_and_inputs(
        &self,
        _exprs: Vec<Expr>,
        inputs: Vec<LogicalPlan>,
    ) -> DataFusionResult<Self> {
        Ok(Self {
            le_column: self.le_column.clone(),
            ts_column: self.ts_column.clone(),
            input: inputs.into_iter().next().unwrap(),
            field_column: self.field_column.clone(),
            quantile: self.quantile,
            // This method cannot return an error; otherwise we would have to
            // re-calculate the output schema here.
            output_schema: self.output_schema.clone(),
        })
    }
}

impl HistogramFold {
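    /// Build a `HistogramFold` on top of `input`, verifying that the three
    /// named columns exist in the input schema.
    ///
    /// Illustrative usage (a sketch; `input` stands for any plan producing a
    /// sorted conventional-histogram table with columns `le`, `ts` and `val`):
    ///
    /// ```ignore
    /// let fold = HistogramFold::new(
    ///     "le".to_string(),  // bucket bound column
    ///     "val".to_string(), // cumulative counter (field) column
    ///     "ts".to_string(),  // timestamp column
    ///     0.99,              // target quantile
    ///     input,
    /// )?;
    /// ```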
    pub fn new(
        le_column: String,
        field_column: String,
        ts_column: String,
        quantile: f64,
        input: LogicalPlan,
    ) -> DataFusionResult<Self> {
        let input_schema = input.schema();
        Self::check_schema(input_schema, &le_column, &field_column, &ts_column)?;
        let output_schema = Self::convert_schema(input_schema, &le_column)?;
        Ok(Self {
            le_column,
            ts_column,
            input,
            field_column,
            quantile: quantile.into(),
            output_schema,
        })
    }

    pub const fn name() -> &'static str {
        "HistogramFold"
    }

    fn check_schema(
        input_schema: &DFSchemaRef,
        le_column: &str,
        field_column: &str,
        ts_column: &str,
    ) -> DataFusionResult<()> {
        let check_column = |col| {
            if !input_schema.has_column_with_unqualified_name(col) {
                Err(DataFusionError::SchemaError(
                    Box::new(datafusion::common::SchemaError::FieldNotFound {
                        field: Box::new(Column::new(None::<String>, col)),
                        valid_fields: input_schema.columns(),
                    }),
                    Box::new(None),
                ))
            } else {
                Ok(())
            }
        };

        check_column(le_column)?;
        check_column(ts_column)?;
        check_column(field_column)
    }

    pub fn to_execution_plan(&self, exec_input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
        let input_schema = self.input.schema();
        // safety: those fields are checked in `check_schema()`
        let le_column_index = input_schema
            .index_of_column_by_name(None, &self.le_column)
            .unwrap();
        let field_column_index = input_schema
            .index_of_column_by_name(None, &self.field_column)
            .unwrap();
        let ts_column_index = input_schema
            .index_of_column_by_name(None, &self.ts_column)
            .unwrap();

        let output_schema: SchemaRef = self.output_schema.inner().clone();
        let properties = PlanProperties::new(
            EquivalenceProperties::new(output_schema.clone()),
            Partitioning::UnknownPartitioning(1),
            EmissionType::Incremental,
            Boundedness::Bounded,
        );
        Arc::new(HistogramFoldExec {
            le_column_index,
            field_column_index,
            ts_column_index,
            input: exec_input,
            quantile: self.quantile.into(),
            output_schema,
            metric: ExecutionPlanMetricsSet::new(),
            properties,
        })
    }

    /// Transform the schema
    ///
    /// - `le` will be removed
    fn convert_schema(
        input_schema: &DFSchemaRef,
        le_column: &str,
    ) -> DataFusionResult<DFSchemaRef> {
        let fields = input_schema.fields();
        // safety: those fields are checked in `check_schema()`
        let mut new_fields = Vec::with_capacity(fields.len() - 1);
        for f in fields {
            if f.name() != le_column {
                new_fields.push((None, f.clone()));
            }
        }
        Ok(Arc::new(DFSchema::new_with_metadata(
            new_fields,
            HashMap::new(),
        )?))
    }
}

impl PartialOrd for HistogramFold {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        // Compare fields in order, excluding output_schema
        match self.le_column.partial_cmp(&other.le_column) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        match self.ts_column.partial_cmp(&other.ts_column) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        match self.input.partial_cmp(&other.input) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        match self.field_column.partial_cmp(&other.field_column) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        self.quantile.partial_cmp(&other.quantile)
    }
}

#[derive(Debug)]
pub struct HistogramFoldExec {
    /// Index of the `le` column in the input schema.
    le_column_index: usize,
    input: Arc<dyn ExecutionPlan>,
    output_schema: SchemaRef,
    /// Index of the field column in the input schema.
    field_column_index: usize,
    ts_column_index: usize,
    quantile: f64,
    metric: ExecutionPlanMetricsSet,
    properties: PlanProperties,
}

impl ExecutionPlan for HistogramFoldExec {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

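    /// Require the input to be sorted on `<tags...>, ts, CAST(le AS Float64) ASC`,
    /// so that rows of one histogram group are adjacent and their buckets appear
    /// in ascending bound order with `+Inf` last. For example (illustrative),
    /// for input columns `(host, le, ts, val)` the requirement is
    /// `host, ts, CAST(le AS Float64) ASC`.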
    fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
        let mut cols = self
            .tag_col_exprs()
            .into_iter()
            .map(|expr| PhysicalSortRequirement {
                expr,
                options: None,
            })
            .collect::<Vec<PhysicalSortRequirement>>();
        // add ts
        cols.push(PhysicalSortRequirement {
            expr: Arc::new(PhyColumn::new(
                self.input.schema().field(self.ts_column_index).name(),
                self.ts_column_index,
            )),
            options: None,
        });
        // add le ASC
        cols.push(PhysicalSortRequirement {
            expr: Arc::new(PhyCast::new(
                Arc::new(PhyColumn::new(
                    self.input.schema().field(self.le_column_index).name(),
                    self.le_column_index,
                )),
                DataType::Float64,
                None,
            )),
            options: Some(SortOptions {
                descending: false,  // `+Inf` sorts last
                nulls_first: false, // not nullable
            }),
        });

        // Safety: `cols` is not empty
        let requirement = LexRequirement::new(cols).unwrap();

        vec![Some(OrderingRequirements::Hard(vec![requirement]))]
    }

    fn required_input_distribution(&self) -> Vec<Distribution> {
        self.input.required_input_distribution()
    }

    fn maintains_input_order(&self) -> Vec<bool> {
        vec![true; self.children().len()]
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    // cannot change the schema with this method
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
        assert!(!children.is_empty());
        Ok(Arc::new(Self {
            input: children[0].clone(),
            metric: self.metric.clone(),
            le_column_index: self.le_column_index,
            ts_column_index: self.ts_column_index,
            quantile: self.quantile,
            output_schema: self.output_schema.clone(),
            field_column_index: self.field_column_index,
            properties: self.properties.clone(),
        }))
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> DataFusionResult<SendableRecordBatchStream> {
        let baseline_metric = BaselineMetrics::new(&self.metric, partition);

        let batch_size = context.session_config().batch_size();
        let input = self.input.execute(partition, context)?;
        let output_schema = self.output_schema.clone();

        let mut normal_indices = (0..input.schema().fields().len()).collect::<HashSet<_>>();
        normal_indices.remove(&self.field_column_index);
        normal_indices.remove(&self.le_column_index);
        Ok(Box::pin(HistogramFoldStream {
            le_column_index: self.le_column_index,
            field_column_index: self.field_column_index,
            quantile: self.quantile,
            normal_indices: normal_indices.into_iter().collect(),
            bucket_size: None,
            input_buffer: vec![],
            input,
            output_schema,
            input_schema: self.input.schema(),
            mode: FoldMode::Optimistic,
            safe_group: None,
            metric: baseline_metric,
            batch_size,
            input_buffered_rows: 0,
            output_buffer: HistogramFoldStream::empty_output_buffer(
                &self.output_schema,
                self.le_column_index,
            )?,
            output_buffered_rows: 0,
        }))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metric.clone_inner())
    }

    fn partition_statistics(&self, _: Option<usize>) -> DataFusionResult<Statistics> {
        Ok(Statistics {
            num_rows: Precision::Absent,
            total_byte_size: Precision::Absent,
            column_statistics: vec![
                ColumnStatistics::new_unknown();
                // plus one more for the column removed by `convert_schema`
                self.schema().flattened_fields().len() + 1
            ],
        })
    }

    fn name(&self) -> &str {
        "HistogramFoldExec"
    }
}

impl HistogramFoldExec {
    /// Return all the [PhysicalExpr]s of tag columns, in order.
    ///
    /// Tag columns are all columns except the `le`, `field` and `ts` columns.
    pub fn tag_col_exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> {
        self.input
            .schema()
            .fields()
            .iter()
            .enumerate()
            .filter_map(|(idx, field)| {
                if idx == self.le_column_index
                    || idx == self.field_column_index
                    || idx == self.ts_column_index
                {
                    None
                } else {
                    Some(Arc::new(PhyColumn::new(field.name(), idx)) as _)
                }
            })
            .collect()
    }
}

impl DisplayAs for HistogramFoldExec {
    fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default
            | DisplayFormatType::Verbose
            | DisplayFormatType::TreeRender => {
                write!(
                    f,
                    "HistogramFoldExec: le=@{}, field=@{}, quantile={}",
                    self.le_column_index, self.field_column_index, self.quantile
                )
            }
        }
    }
}

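/// Folding strategy of [HistogramFoldStream].
///
/// - [FoldMode::Optimistic] assumes every group has the same, fixed number of
///   buckets and folds the input in strides of that size. This is the fast path.
/// - [FoldMode::Safe] scans row by row and groups on tag values. The stream
///   falls back to this mode when a stride fails validation (e.g., a missing
///   `+Inf` bucket or a tag change in the middle of a stride) and stays in it
///   for the rest of the input.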
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FoldMode {
    Optimistic,
    Safe,
}

pub struct HistogramFoldStream {
    // internal states
    le_column_index: usize,
    field_column_index: usize,
    quantile: f64,
    /// Columns that don't need folding. These indices are based on the input schema.
    normal_indices: Vec<usize>,
    bucket_size: Option<usize>,
    /// Expected output batch size
    batch_size: usize,
    output_schema: SchemaRef,
    input_schema: SchemaRef,
    mode: FoldMode,
    safe_group: Option<SafeGroup>,

    // buffers
    input_buffer: Vec<RecordBatch>,
    input_buffered_rows: usize,
    output_buffer: Vec<Box<dyn MutableVector>>,
    output_buffered_rows: usize,

    // runtime things
    input: SendableRecordBatchStream,
    metric: BaselineMetrics,
}

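/// State of the series currently being folded in [FoldMode::Safe]: the
/// group's tag values plus the bucket bounds and cumulative counters seen so
/// far. The group is finalized (evaluated and pushed to the output buffer)
/// when a row with different tag values arrives or the input ends.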
#[derive(Debug, Default)]
struct SafeGroup {
    tag_values: Vec<Value>,
    buckets: Vec<f64>,
    counters: Vec<f64>,
}

impl RecordBatchStream for HistogramFoldStream {
    fn schema(&self) -> SchemaRef {
        self.output_schema.clone()
    }
}

impl Stream for HistogramFoldStream {
    type Item = DataFusionResult<RecordBatch>;

    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let poll = loop {
            match ready!(self.input.poll_next_unpin(cx)) {
                Some(batch) => {
                    let batch = batch?;
                    let timer = Instant::now();
                    let Some(result) = self.fold_input(batch)? else {
                        self.metric.elapsed_compute().add_elapsed(timer);
                        continue;
                    };
                    self.metric.elapsed_compute().add_elapsed(timer);
                    break Poll::Ready(Some(result));
                }
                None => {
                    self.flush_remaining()?;
                    break Poll::Ready(self.take_output_buf()?.map(Ok));
                }
            }
        };
        self.metric.record_poll(poll)
    }
}

impl HistogramFoldStream {
    /// The innermost `Result` is for `poll_next()`
    pub fn fold_input(
        &mut self,
        input: RecordBatch,
    ) -> DataFusionResult<Option<DataFusionResult<RecordBatch>>> {
        match self.mode {
            FoldMode::Safe => {
                self.push_input_buf(input);
                self.process_safe_mode_buffer()?;
            }
            FoldMode::Optimistic => {
                self.push_input_buf(input);
                let Some(bucket_num) = self.calculate_bucket_num_from_buffer()? else {
                    return Ok(None);
                };
                self.bucket_size = Some(bucket_num);

                if self.input_buffered_rows < bucket_num {
                    // not enough rows to fold
                    return Ok(None);
                }

                self.fold_buf(bucket_num)?;
            }
        }

        self.maybe_take_output()
    }

    /// Generate a group of empty [MutableVector]s from the output schema.
    ///
    /// For simplicity, this method inserts a placeholder for `le`, so that
    /// the output buffer has the same schema as the input. This placeholder
    /// needs to be removed before returning the output batch.
    pub fn empty_output_buffer(
        schema: &SchemaRef,
        le_column_index: usize,
    ) -> DataFusionResult<Vec<Box<dyn MutableVector>>> {
        let mut builders = Vec::with_capacity(schema.fields().len() + 1);
        for field in schema.fields() {
            let concrete_datatype = ConcreteDataType::try_from(field.data_type()).unwrap();
            let mutable_vector = concrete_datatype.create_mutable_vector(0);
            builders.push(mutable_vector);
        }
        builders.insert(
            le_column_index,
            ConcreteDataType::float64_datatype().create_mutable_vector(0),
        );

        Ok(builders)
    }

    /// Determines the bucket count using buffered batches, concatenating them
    /// so that the first complete bucket group is detected even when it spans
    /// batch boundaries.
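    ///
    /// For example (illustrative): if the buffered `le` values for one series
    /// are `["0.001", "0.1", "+Inf", "0.001", ...]`, the first `+Inf` appears
    /// at offset 2 within its group, so the bucket count is 3.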
    fn calculate_bucket_num_from_buffer(&mut self) -> DataFusionResult<Option<usize>> {
        if let Some(size) = self.bucket_size {
            return Ok(Some(size));
        }

        if self.input_buffer.is_empty() {
            return Ok(None);
        }

        let batch_refs: Vec<&RecordBatch> = self.input_buffer.iter().collect();
        let batch = concat_batches(&self.input_schema, batch_refs)?;
        self.find_first_complete_bucket(&batch)
    }

    fn find_first_complete_bucket(&self, batch: &RecordBatch) -> DataFusionResult<Option<usize>> {
        if batch.num_rows() == 0 {
            return Ok(None);
        }

        let vectors = Helper::try_into_vectors(batch.columns())
            .map_err(|e| DataFusionError::Execution(e.to_string()))?;
        let le_array = batch.column(self.le_column_index).as_string::<i32>();

        let mut tag_values_buf = Vec::with_capacity(self.normal_indices.len());
        self.collect_tag_values(&vectors, 0, &mut tag_values_buf);
        let mut group_start = 0usize;

        for row in 0..batch.num_rows() {
            if !self.is_same_group(&vectors, row, &tag_values_buf) {
                // new group begins
                self.collect_tag_values(&vectors, row, &mut tag_values_buf);
                group_start = row;
            }

            if Self::is_positive_infinity(le_array, row) {
                return Ok(Some(row - group_start + 1));
            }
        }

        Ok(None)
    }

    /// Fold record batches from the input buffer and push results to the output buffer
    fn fold_buf(&mut self, bucket_num: usize) -> DataFusionResult<()> {
        let batch = concat_batches(&self.input_schema, self.input_buffer.drain(..).as_ref())?;
        let mut remaining_rows = self.input_buffered_rows;
        let mut cursor = 0;

        // TODO(LFC): Try to get rid of the Arrow array to vector conversion here.
        let vectors = Helper::try_into_vectors(batch.columns())
            .map_err(|e| DataFusionError::Execution(e.to_string()))?;
        let le_array = batch.column(self.le_column_index);
        let le_array = le_array.as_string::<i32>();
        let field_array = batch.column(self.field_column_index);
        let field_array = field_array.as_primitive::<Float64Type>();
        let mut tag_values_buf = Vec::with_capacity(self.normal_indices.len());

        while remaining_rows >= bucket_num && self.mode == FoldMode::Optimistic {
            self.collect_tag_values(&vectors, cursor, &mut tag_values_buf);
            if !self.validate_optimistic_group(
                &vectors,
                le_array,
                cursor,
                bucket_num,
                &tag_values_buf,
            ) {
                let remaining_input_batch = batch.slice(cursor, remaining_rows);
                self.switch_to_safe_mode(remaining_input_batch)?;
                return Ok(());
            }

            // "sample" normal columns
            for (idx, value) in self.normal_indices.iter().zip(tag_values_buf.iter()) {
                self.output_buffer[*idx].push_value_ref(value);
            }
            // "fold" `le` and field columns
            let mut bucket = Vec::with_capacity(bucket_num);
            let mut counters = Vec::with_capacity(bucket_num);
            for bias in 0..bucket_num {
                let position = cursor + bias;
                let le = if le_array.is_valid(position) {
                    le_array.value(position).parse::<f64>().unwrap_or(f64::NAN)
                } else {
                    f64::NAN
                };
                bucket.push(le);

                let counter = if field_array.is_valid(position) {
                    field_array.value(position)
                } else {
                    f64::NAN
                };
                counters.push(counter);
            }
            // ignore invalid data
            let result = Self::evaluate_row(self.quantile, &bucket, &counters).unwrap_or(f64::NAN);
            self.output_buffer[self.field_column_index].push_value_ref(&ValueRef::from(result));
            cursor += bucket_num;
            remaining_rows -= bucket_num;
            self.output_buffered_rows += 1;
        }

        let remaining_input_batch = batch.slice(cursor, remaining_rows);
        self.input_buffered_rows = remaining_input_batch.num_rows();
        if self.input_buffered_rows > 0 {
            self.input_buffer.push(remaining_input_batch);
        }

        Ok(())
    }

    fn push_input_buf(&mut self, batch: RecordBatch) {
        self.input_buffered_rows += batch.num_rows();
        self.input_buffer.push(batch);
    }

    fn maybe_take_output(&mut self) -> DataFusionResult<Option<DataFusionResult<RecordBatch>>> {
        if self.output_buffered_rows >= self.batch_size {
            return Ok(self.take_output_buf()?.map(Ok));
        }
        Ok(None)
    }

    fn switch_to_safe_mode(&mut self, remaining_batch: RecordBatch) -> DataFusionResult<()> {
        self.mode = FoldMode::Safe;
        self.bucket_size = None;
        self.input_buffer.clear();
        self.input_buffered_rows = remaining_batch.num_rows();

        if self.input_buffered_rows > 0 {
            self.input_buffer.push(remaining_batch);
            self.process_safe_mode_buffer()?;
        }

        Ok(())
    }

    fn collect_tag_values<'a>(
        &self,
        vectors: &'a [VectorRef],
        row: usize,
        tag_values: &mut Vec<ValueRef<'a>>,
    ) {
        tag_values.clear();
        for idx in self.normal_indices.iter() {
            tag_values.push(vectors[*idx].get_ref(row));
        }
    }

    fn validate_optimistic_group(
        &self,
        vectors: &[VectorRef],
        le_array: &StringArray,
        cursor: usize,
        bucket_num: usize,
        tag_values: &[ValueRef<'_>],
    ) -> bool {
        let inf_index = cursor + bucket_num - 1;
        if !Self::is_positive_infinity(le_array, inf_index) {
            return false;
        }

        for offset in 1..bucket_num {
            let row = cursor + offset;
            for (idx, expected) in self.normal_indices.iter().zip(tag_values.iter()) {
                if vectors[*idx].get_ref(row) != *expected {
                    return false;
                }
            }
        }
        true
    }

    /// Checks whether a row belongs to the current group (same series).
    fn is_same_group(
        &self,
        vectors: &[VectorRef],
        row: usize,
        tag_values: &[ValueRef<'_>],
    ) -> bool {
        self.normal_indices
            .iter()
            .zip(tag_values.iter())
            .all(|(idx, expected)| vectors[*idx].get_ref(row) == *expected)
    }

    fn push_output_row(&mut self, tag_values: &[ValueRef<'_>], result: f64) {
        debug_assert_eq!(self.normal_indices.len(), tag_values.len());
        for (idx, value) in self.normal_indices.iter().zip(tag_values.iter()) {
            self.output_buffer[*idx].push_value_ref(value);
        }
        self.output_buffer[self.field_column_index].push_value_ref(&ValueRef::from(result));
        self.output_buffered_rows += 1;
    }

    fn finalize_safe_group(&mut self) -> DataFusionResult<()> {
        if let Some(group) = self.safe_group.take() {
            if group.tag_values.is_empty() {
                return Ok(());
            }

            let has_inf = group
                .buckets
                .last()
                .map(|v| v.is_infinite() && v.is_sign_positive())
                .unwrap_or(false);
            let result = if group.buckets.len() < 2 || !has_inf {
                f64::NAN
            } else {
                Self::evaluate_row(self.quantile, &group.buckets, &group.counters)
                    .unwrap_or(f64::NAN)
            };
            let mut tag_value_refs = Vec::with_capacity(group.tag_values.len());
            tag_value_refs.extend(group.tag_values.iter().map(|v| v.as_value_ref()));
            self.push_output_row(&tag_value_refs, result);
        }
        Ok(())
    }

    fn process_safe_mode_buffer(&mut self) -> DataFusionResult<()> {
        if self.input_buffer.is_empty() {
            self.input_buffered_rows = 0;
            return Ok(());
        }

        let batch = concat_batches(&self.input_schema, self.input_buffer.drain(..).as_ref())?;
        self.input_buffered_rows = 0;
        let vectors = Helper::try_into_vectors(batch.columns())
            .map_err(|e| DataFusionError::Execution(e.to_string()))?;
        let le_array = batch.column(self.le_column_index).as_string::<i32>();
        let field_array = batch
            .column(self.field_column_index)
            .as_primitive::<Float64Type>();
        let mut tag_values_buf = Vec::with_capacity(self.normal_indices.len());

        for row in 0..batch.num_rows() {
            self.collect_tag_values(&vectors, row, &mut tag_values_buf);
            let should_start_new_group = self
                .safe_group
                .as_ref()
                .is_none_or(|group| !Self::tag_values_equal(&group.tag_values, &tag_values_buf));
            if should_start_new_group {
                self.finalize_safe_group()?;
                self.safe_group = Some(SafeGroup {
                    tag_values: tag_values_buf.iter().cloned().map(Value::from).collect(),
                    buckets: Vec::new(),
                    counters: Vec::new(),
                });
            }

            let Some(group) = self.safe_group.as_mut() else {
                continue;
            };

            let bucket = if le_array.is_valid(row) {
                le_array.value(row).parse::<f64>().unwrap_or(f64::NAN)
            } else {
                f64::NAN
            };
            let counter = if field_array.is_valid(row) {
                field_array.value(row)
            } else {
                f64::NAN
            };

            group.buckets.push(bucket);
            group.counters.push(counter);
        }

        Ok(())
    }

    fn tag_values_equal(group_values: &[Value], current: &[ValueRef<'_>]) -> bool {
        group_values.len() == current.len()
            && group_values
                .iter()
                .zip(current.iter())
                .all(|(group, now)| group.as_value_ref() == *now)
    }

    /// Drain the output buffer into a [RecordBatch]
    fn take_output_buf(&mut self) -> DataFusionResult<Option<RecordBatch>> {
        if self.output_buffered_rows == 0 {
            if self.input_buffered_rows != 0 {
                warn!(
                    "input buffer is not empty, {} rows remaining",
                    self.input_buffered_rows
                );
            }
            return Ok(None);
        }

        let mut output_buf = Self::empty_output_buffer(&self.output_schema, self.le_column_index)?;
        std::mem::swap(&mut self.output_buffer, &mut output_buf);
        let mut columns = Vec::with_capacity(output_buf.len());
        for builder in output_buf.iter_mut() {
            columns.push(builder.to_vector().to_arrow_array());
        }
        // remove the placeholder column for `le`
        columns.remove(self.le_column_index);

        self.output_buffered_rows = 0;
        RecordBatch::try_new(self.output_schema.clone(), columns)
            .map(Some)
            .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))
    }

    fn flush_remaining(&mut self) -> DataFusionResult<()> {
        if self.mode == FoldMode::Optimistic && self.input_buffered_rows > 0 {
            let buffered_batches: Vec<_> = self.input_buffer.drain(..).collect();
            if !buffered_batches.is_empty() {
                let batch = concat_batches(&self.input_schema, buffered_batches.as_slice())?;
                self.switch_to_safe_mode(batch)?;
            } else {
                self.input_buffered_rows = 0;
            }
        }

        if self.mode == FoldMode::Safe {
            self.process_safe_mode_buffer()?;
            self.finalize_safe_group()?;
        }

        Ok(())
    }

    fn is_positive_infinity(le_array: &StringArray, index: usize) -> bool {
        le_array.is_valid(index)
            && matches!(
                le_array.value(index).parse::<f64>(),
                Ok(value) if value.is_infinite() && value.is_sign_positive()
            )
    }

    /// Evaluate the quantile from bucket bounds and cumulative counters.
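    ///
    /// A worked example (illustrative): with `quantile = 0.5`,
    /// `bucket = [1.0, 2.0, +Inf]` and `counter = [5.0, 10.0, 10.0]`,
    /// the expected position is `10.0 * 0.5 = 5.0`. That lands exactly on the
    /// first bucket's cumulative count, so the result interpolates across
    /// bounds `[0.0, 1.0]` with counts `[0.0, 5.0]` and yields `1.0`.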
    fn evaluate_row(quantile: f64, bucket: &[f64], counter: &[f64]) -> DataFusionResult<f64> {
        // check bucket
        if bucket.len() <= 1 {
            return Ok(f64::NAN);
        }
        if bucket.last().unwrap().is_finite() {
            return Err(DataFusionError::Execution(
                "last bucket should be +Inf".to_string(),
            ));
        }
        if bucket.len() != counter.len() {
            return Err(DataFusionError::Execution(
                "bucket and counter should have the same length".to_string(),
            ));
        }
        // check quantile
        if quantile < 0.0 {
            return Ok(f64::NEG_INFINITY);
        } else if quantile > 1.0 {
            return Ok(f64::INFINITY);
        } else if quantile.is_nan() {
            return Ok(f64::NAN);
        }

        // check input values: bucket bounds must be sorted
        if !bucket.windows(2).all(|w| w[0] <= w[1]) {
            return Ok(f64::NAN);
        }
        // repair counters: replace non-finite values with the previous value
        // and force the sequence to be non-decreasing
        let counter = {
            let needs_fix =
                counter.iter().any(|v| !v.is_finite()) || !counter.windows(2).all(|w| w[0] <= w[1]);
            if !needs_fix {
                Cow::Borrowed(counter)
            } else {
                let mut fixed = Vec::with_capacity(counter.len());
                let mut prev = 0.0;
                for (idx, &v) in counter.iter().enumerate() {
                    let mut val = if v.is_finite() { v } else { prev };
                    if idx > 0 && val < prev {
                        val = prev;
                    }
                    fixed.push(val);
                    prev = val;
                }
                Cow::Owned(fixed)
            }
        };

        let total = *counter.last().unwrap();
        let expected_pos = total * quantile;
        // find the first bucket whose cumulative count reaches the expected position
        let mut fit_bucket_pos = 0;
        while fit_bucket_pos < bucket.len() && counter[fit_bucket_pos] < expected_pos {
            fit_bucket_pos += 1;
        }
        if fit_bucket_pos >= bucket.len() - 1 {
            // the quantile falls into the `+Inf` bucket: return the highest finite bound
            Ok(bucket[bucket.len() - 2])
        } else {
            // linear interpolation within the fitting bucket
            let upper_bound = bucket[fit_bucket_pos];
            let upper_count = counter[fit_bucket_pos];
            let mut lower_bound = bucket[0].min(0.0);
            let mut lower_count = 0.0;
            if fit_bucket_pos > 0 {
                lower_bound = bucket[fit_bucket_pos - 1];
                lower_count = counter[fit_bucket_pos - 1];
            }
            if (upper_count - lower_count).abs() < 1e-10 {
                return Ok(f64::NAN);
            }
            Ok(lower_bound
                + (upper_bound - lower_bound) / (upper_count - lower_count)
                    * (expected_pos - lower_count))
        }
    }
}

#[cfg(test)]
mod test {
    use std::sync::Arc;

    use datafusion::arrow::array::{Float64Array, TimestampMillisecondArray};
    use datafusion::arrow::datatypes::{Field, Schema, SchemaRef, TimeUnit};
    use datafusion::common::ToDFSchema;
    use datafusion::datasource::memory::MemorySourceConfig;
    use datafusion::datasource::source::DataSourceExec;
    use datafusion::prelude::SessionContext;
    use datatypes::arrow_array::StringArray;

    use super::*;

    fn prepare_test_data() -> DataSourceExec {
        let schema = Arc::new(Schema::new(vec![
            Field::new("host", DataType::Utf8, true),
            Field::new("le", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ]));

        // 12 items
        let host_column_1 = Arc::new(StringArray::from(vec![
            "host_1", "host_1", "host_1", "host_1", "host_1", "host_1", "host_1", "host_1",
            "host_1", "host_1", "host_1", "host_1",
        ])) as _;
        let le_column_1 = Arc::new(StringArray::from(vec![
            "0.001", "0.1", "10", "1000", "+Inf", "0.001", "0.1", "10", "1000", "+inf", "0.001",
            "0.1",
        ])) as _;
        let val_column_1 = Arc::new(Float64Array::from(vec![
            0.0, 1.0, 1.0, 5.0, 5.0, 0.0, 20.0, 60.0, 70.0, 100.0, 1.0, 1.0,
        ])) as _;

        // 2 items
        let host_column_2 = Arc::new(StringArray::from(vec!["host_1", "host_1"])) as _;
        let le_column_2 = Arc::new(StringArray::from(vec!["10", "1000"])) as _;
        let val_column_2 = Arc::new(Float64Array::from(vec![1.0, 1.0])) as _;

        // 11 items
        let host_column_3 = Arc::new(StringArray::from(vec![
            "host_1", "host_2", "host_2", "host_2", "host_2", "host_2", "host_2", "host_2",
            "host_2", "host_2", "host_2",
        ])) as _;
        let le_column_3 = Arc::new(StringArray::from(vec![
            "+INF", "0.001", "0.1", "10", "1000", "+iNf", "0.001", "0.1", "10", "1000", "+Inf",
        ])) as _;
        let val_column_3 = Arc::new(Float64Array::from(vec![
            1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 4.0,
        ])) as _;

        let data_1 = RecordBatch::try_new(
            schema.clone(),
            vec![host_column_1, le_column_1, val_column_1],
        )
        .unwrap();
        let data_2 = RecordBatch::try_new(
            schema.clone(),
            vec![host_column_2, le_column_2, val_column_2],
        )
        .unwrap();
        let data_3 = RecordBatch::try_new(
            schema.clone(),
            vec![host_column_3, le_column_3, val_column_3],
        )
        .unwrap();

        DataSourceExec::new(Arc::new(
            MemorySourceConfig::try_new(&[vec![data_1, data_2, data_3]], schema, None).unwrap(),
        ))
    }

    fn build_fold_exec_from_batches(
        batches: Vec<RecordBatch>,
        schema: SchemaRef,
        quantile: f64,
        ts_column_index: usize,
    ) -> Arc<HistogramFoldExec> {
        let memory_exec = Arc::new(DataSourceExec::new(Arc::new(
            MemorySourceConfig::try_new(&[batches], schema.clone(), None).unwrap(),
        )));
        let output_schema: SchemaRef = Arc::new(
            HistogramFold::convert_schema(
                &Arc::new(memory_exec.schema().to_dfschema().unwrap()),
                "le",
            )
            .unwrap()
            .as_arrow()
            .clone(),
        );
        let properties = PlanProperties::new(
            EquivalenceProperties::new(output_schema.clone()),
            Partitioning::UnknownPartitioning(1),
            EmissionType::Incremental,
            Boundedness::Bounded,
        );

        Arc::new(HistogramFoldExec {
            le_column_index: 1,
            field_column_index: 2,
            quantile,
            ts_column_index,
            input: memory_exec,
            output_schema,
            metric: ExecutionPlanMetricsSet::new(),
            properties,
        })
    }

    #[tokio::test]
    async fn fold_overall() {
        let memory_exec = Arc::new(prepare_test_data());
        let output_schema: SchemaRef = Arc::new(
            HistogramFold::convert_schema(
                &Arc::new(memory_exec.schema().to_dfschema().unwrap()),
                "le",
            )
            .unwrap()
            .as_arrow()
            .clone(),
        );
        let properties = PlanProperties::new(
            EquivalenceProperties::new(output_schema.clone()),
            Partitioning::UnknownPartitioning(1),
            EmissionType::Incremental,
            Boundedness::Bounded,
        );
        let fold_exec = Arc::new(HistogramFoldExec {
            le_column_index: 1,
            field_column_index: 2,
            quantile: 0.4,
            ts_column_index: 9999, // doesn't exist, but doesn't matter here
            input: memory_exec,
            output_schema,
            metric: ExecutionPlanMetricsSet::new(),
            properties,
        });

        let session_context = SessionContext::default();
        let result = datafusion::physical_plan::collect(fold_exec, session_context.task_ctx())
            .await
            .unwrap();
        let result_literal = datatypes::arrow::util::pretty::pretty_format_batches(&result)
            .unwrap()
            .to_string();

        let expected = String::from(
            "+--------+-------------------+
| host   | val               |
+--------+-------------------+
| host_1 | 257.5             |
| host_1 | 5.05              |
| host_1 | 0.0004            |
| host_2 | NaN               |
| host_2 | 6.040000000000001 |
+--------+-------------------+",
        );
        assert_eq!(result_literal, expected);
    }

    #[test]
    fn confirm_schema() {
        let input_schema = Schema::new(vec![
            Field::new("host", DataType::Utf8, true),
            Field::new("le", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ])
        .to_dfschema_ref()
        .unwrap();
        let expected_output_schema = Schema::new(vec![
            Field::new("host", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ])
        .to_dfschema_ref()
        .unwrap();

        let actual = HistogramFold::convert_schema(&input_schema, "le").unwrap();
        assert_eq!(actual, expected_output_schema)
    }

    #[tokio::test]
    async fn fallback_to_safe_mode_on_missing_inf() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("host", DataType::Utf8, true),
            Field::new("le", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ]));
        let host_column = Arc::new(StringArray::from(vec!["a", "a", "a", "a", "b", "b"])) as _;
        let le_column = Arc::new(StringArray::from(vec![
            "0.1", "+Inf", "0.1", "1.0", "0.1", "+Inf",
        ])) as _;
        let val_column = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 3.0, 1.0, 5.0])) as _;
        let batch =
            RecordBatch::try_new(schema.clone(), vec![host_column, le_column, val_column]).unwrap();
        let fold_exec = build_fold_exec_from_batches(vec![batch], schema, 0.5, 0);
        let session_context = SessionContext::default();
        let result = datafusion::physical_plan::collect(fold_exec, session_context.task_ctx())
            .await
            .unwrap();
        let result_literal = datatypes::arrow::util::pretty::pretty_format_batches(&result)
            .unwrap()
            .to_string();

        let expected = String::from(
            "+------+-----+
| host | val |
+------+-----+
| a    | 0.1 |
| a    | NaN |
| b    | 0.1 |
+------+-----+",
        );
        assert_eq!(result_literal, expected);
    }

    #[tokio::test]
    async fn emit_nan_when_no_inf_present() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("host", DataType::Utf8, true),
            Field::new("le", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ]));
        let host_column = Arc::new(StringArray::from(vec!["c", "c"])) as _;
        let le_column = Arc::new(StringArray::from(vec!["0.1", "1.0"])) as _;
        let val_column = Arc::new(Float64Array::from(vec![1.0, 2.0])) as _;
        let batch =
            RecordBatch::try_new(schema.clone(), vec![host_column, le_column, val_column]).unwrap();
        let fold_exec = build_fold_exec_from_batches(vec![batch], schema, 0.9, 0);
        let session_context = SessionContext::default();
        let result = datafusion::physical_plan::collect(fold_exec, session_context.task_ctx())
            .await
            .unwrap();
        let result_literal = datatypes::arrow::util::pretty::pretty_format_batches(&result)
            .unwrap()
            .to_string();

        let expected = String::from(
            "+------+-----+
| host | val |
+------+-----+
| c    | NaN |
+------+-----+",
        );
        assert_eq!(result_literal, expected);
    }

    #[tokio::test]
    async fn safe_mode_handles_misaligned_groups() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
            Field::new("le", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ]));

        let ts_column = Arc::new(TimestampMillisecondArray::from(vec![
            2900000, 2900000, 2900000, 3000000, 3000000, 3000000, 3000000, 3005000, 3005000,
            3010000, 3010000, 3010000, 3010000, 3010000,
        ])) as _;
        let le_column = Arc::new(StringArray::from(vec![
            "0.1", "1", "5", "0.1", "1", "5", "+Inf", "0.1", "+Inf", "0.1", "1", "3", "5", "+Inf",
        ])) as _;
        let val_column = Arc::new(Float64Array::from(vec![
            0.0, 0.0, 0.0, 50.0, 70.0, 110.0, 120.0, 10.0, 30.0, 10.0, 20.0, 30.0, 40.0, 50.0,
        ])) as _;
        let batch =
            RecordBatch::try_new(schema.clone(), vec![ts_column, le_column, val_column]).unwrap();
        let fold_exec = build_fold_exec_from_batches(vec![batch], schema, 0.5, 0);
        let session_context = SessionContext::default();
        let result = datafusion::physical_plan::collect(fold_exec, session_context.task_ctx())
            .await
            .unwrap();

        let mut values = Vec::new();
        for batch in result {
            let array = batch.column(1).as_primitive::<Float64Type>();
            values.extend(array.iter().map(|v| v.unwrap()));
        }

        assert_eq!(values.len(), 4);
        assert!(values[0].is_nan());
        assert!((values[1] - 0.55).abs() < 1e-10);
        assert!((values[2] - 0.1).abs() < 1e-10);
        assert!((values[3] - 2.0).abs() < 1e-10);
    }

    #[tokio::test]
    async fn missing_buckets_at_first_timestamp() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
            Field::new("le", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ]));

        let ts_column = Arc::new(TimestampMillisecondArray::from(vec![
            2_900_000, 3_000_000, 3_000_000, 3_000_000, 3_000_000, 3_005_000, 3_005_000, 3_010_000,
            3_010_000, 3_010_000, 3_010_000, 3_010_000,
        ])) as _;
        let le_column = Arc::new(StringArray::from(vec![
            "0.1", "0.1", "1", "5", "+Inf", "0.1", "+Inf", "0.1", "1", "3", "5", "+Inf",
        ])) as _;
        let val_column = Arc::new(Float64Array::from(vec![
            0.0, 50.0, 70.0, 110.0, 120.0, 10.0, 30.0, 10.0, 20.0, 30.0, 40.0, 50.0,
        ])) as _;

        let batch =
            RecordBatch::try_new(schema.clone(), vec![ts_column, le_column, val_column]).unwrap();
        let fold_exec = build_fold_exec_from_batches(vec![batch], schema, 0.5, 0);
        let session_context = SessionContext::default();
        let result = datafusion::physical_plan::collect(fold_exec, session_context.task_ctx())
            .await
            .unwrap();

        let mut values = Vec::new();
        for batch in result {
            let array = batch.column(1).as_primitive::<Float64Type>();
            values.extend(array.iter().map(|v| v.unwrap()));
        }

        assert_eq!(values.len(), 4);
        assert!(values[0].is_nan());
        assert!((values[1] - 0.55).abs() < 1e-10);
        assert!((values[2] - 0.1).abs() < 1e-10);
        assert!((values[3] - 2.0).abs() < 1e-10);
    }

    #[tokio::test]
    async fn missing_inf_in_first_group() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
            Field::new("le", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ]));

        let ts_column = Arc::new(TimestampMillisecondArray::from(vec![
            1000, 1000, 1000, 2000, 2000, 2000, 2000,
        ])) as _;
        let le_column = Arc::new(StringArray::from(vec![
            "0.1", "1", "5", "0.1", "1", "5", "+Inf",
        ])) as _;
        let val_column = Arc::new(Float64Array::from(vec![
            0.0, 0.0, 0.0, 10.0, 20.0, 30.0, 30.0,
        ])) as _;
        let batch =
            RecordBatch::try_new(schema.clone(), vec![ts_column, le_column, val_column]).unwrap();
        let fold_exec = build_fold_exec_from_batches(vec![batch], schema, 0.5, 0);
        let session_context = SessionContext::default();
        let result = datafusion::physical_plan::collect(fold_exec, session_context.task_ctx())
            .await
            .unwrap();

        let mut values = Vec::new();
        for batch in result {
            let array = batch.column(1).as_primitive::<Float64Type>();
            values.extend(array.iter().map(|v| v.unwrap()));
        }

        assert_eq!(values.len(), 2);
        assert!(values[0].is_nan());
        assert!((values[1] - 0.55).abs() < 1e-10, "{values:?}");
    }

    #[test]
    fn evaluate_row_normal_case() {
        let bucket = [0.0, 1.0, 2.0, 3.0, 4.0, f64::INFINITY];

        #[derive(Debug)]
        struct Case {
            quantile: f64,
            counters: Vec<f64>,
            expected: f64,
        }

        let cases = [
            Case {
                quantile: 0.9,
                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
                expected: 4.0,
            },
            Case {
                quantile: 0.89,
                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
                expected: 4.0,
            },
            Case {
                quantile: 0.78,
                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
                expected: 3.9,
            },
            Case {
                quantile: 0.5,
                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
                expected: 2.5,
            },
            Case {
                quantile: 0.5,
                counters: vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                expected: f64::NAN,
            },
            Case {
                quantile: 1.0,
                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
                expected: 4.0,
            },
            Case {
                quantile: 0.0,
                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
                expected: f64::NAN,
            },
            Case {
                quantile: 1.1,
                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
                expected: f64::INFINITY,
            },
            Case {
                quantile: -1.0,
                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
                expected: f64::NEG_INFINITY,
            },
        ];

        for case in cases {
            let actual =
                HistogramFoldStream::evaluate_row(case.quantile, &bucket, &case.counters).unwrap();
            assert_eq!(
                format!("{actual}"),
                format!("{}", case.expected),
                "{:?}",
                case
            );
        }
    }

    #[test]
    fn evaluate_out_of_order_input() {
        let bucket = [0.0, 1.0, 2.0, 3.0, 4.0, f64::INFINITY];
        let counters = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0];
        let result = HistogramFoldStream::evaluate_row(0.5, &bucket, &counters).unwrap();
        assert_eq!(0.0, result);
    }

    #[test]
    fn evaluate_wrong_bucket() {
        let bucket = [0.0, 1.0, 2.0, 3.0, 4.0, f64::INFINITY, 5.0];
        let counters = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
        let result = HistogramFoldStream::evaluate_row(0.5, &bucket, &counters);
        assert!(result.is_err());
    }

    #[test]
    fn evaluate_small_fraction() {
        let bucket = [0.0, 2.0, 4.0, 6.0, f64::INFINITY];
        let counters = [0.0, 1.0 / 300.0, 2.0 / 300.0, 0.01, 0.01];
        let result = HistogramFoldStream::evaluate_row(0.5, &bucket, &counters).unwrap();
        assert_eq!(3.0, result);
    }

    #[test]
    fn evaluate_non_monotonic_counter() {
        let bucket = [0.0, 1.0, 2.0, 3.0, f64::INFINITY];
        let counters = [0.1, 0.2, 0.4, 0.17, 0.5];
        let result = HistogramFoldStream::evaluate_row(0.5, &bucket, &counters).unwrap();
        assert!((result - 1.25).abs() < 1e-10, "{result}");
    }

    #[test]
    fn evaluate_nan_counter() {
        let bucket = [0.0, 1.0, 2.0, 3.0, f64::INFINITY];
        let counters = [f64::NAN, 1.0, 2.0, 3.0, 3.0];
        let result = HistogramFoldStream::evaluate_row(0.5, &bucket, &counters).unwrap();
        assert!((result - 1.5).abs() < 1e-10, "{result}");
    }
}