promql/extension_plan/histogram_fold.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::any::Any;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::task::Poll;
use std::time::Instant;

use common_telemetry::warn;
use datafusion::arrow::array::AsArray;
use datafusion::arrow::compute::{self, SortOptions, concat_batches};
use datafusion::arrow::datatypes::{DataType, Float64Type, SchemaRef};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::common::stats::Precision;
use datafusion::common::{ColumnStatistics, DFSchema, DFSchemaRef, Statistics};
use datafusion::error::{DataFusionError, Result as DataFusionResult};
use datafusion::execution::TaskContext;
use datafusion::logical_expr::{LogicalPlan, UserDefinedLogicalNodeCore};
use datafusion::physical_expr::{
    EquivalenceProperties, LexRequirement, OrderingRequirements, PhysicalSortRequirement,
};
use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
use datafusion::physical_plan::expressions::{CastExpr as PhyCast, Column as PhyColumn};
use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use datafusion::physical_plan::{
    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PhysicalExpr,
    PlanProperties, RecordBatchStream, SendableRecordBatchStream,
};
use datafusion::prelude::{Column, Expr};
use datatypes::prelude::{ConcreteDataType, DataType as GtDataType};
use datatypes::value::{OrderedF64, ValueRef};
use datatypes::vectors::{Helper, MutableVector};
use futures::{Stream, StreamExt, ready};
/// `HistogramFold` will fold the conventional (non-native) histogram ([1]) for
/// later computation.
///
/// Specifically, it will transform the `le` and `field` columns into a complex
/// type, and sample the other tag columns:
/// - `le` will become a [ListArray] of [f64], with each bucket bound parsed
/// - `field` will become a [ListArray] of [f64]
/// - other columns will be sampled every `bucket_num` elements, but their types won't change.
///
/// Due to the folding or sampling, the number of output rows will become
/// `input_rows` / `bucket_num`.
///
/// # Requirement
/// - Input should be sorted on `<tag list>, ts, le ASC`.
/// - The value set of `le` should be the same, i.e., the buckets of every
///   series should be the same.
///
/// [1]: https://prometheus.io/docs/concepts/metric_types/#histogram
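///
/// As an illustrative sketch (the column names here are hypothetical), folding
/// one series with the three buckets `0.1`, `10` and `+Inf` looks like:
///
/// ```text
/// host | le   | field          host | le_folded         | field_folded
/// -----+------+-------   =>   ------+-------------------+-----------------
/// a    | 0.1  | 1.0            a    | [0.1, 10.0, +Inf] | [1.0, 2.0, 3.0]
/// a    | 10   | 2.0
/// a    | +Inf | 3.0
/// ```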
#[derive(Debug, PartialEq, Hash, Eq)]
pub struct HistogramFold {
    /// Name of the `le` column. It's a special column in Prometheus
    /// for implementing conventional histograms. It's a string column
    /// with "literal" float values, like "+Inf", "0.001" etc.
    le_column: String,
    ts_column: String,
    input: LogicalPlan,
    field_column: String,
    quantile: OrderedF64,
    output_schema: DFSchemaRef,
}

impl UserDefinedLogicalNodeCore for HistogramFold {
    fn name(&self) -> &str {
        Self::name()
    }

    fn inputs(&self) -> Vec<&LogicalPlan> {
        vec![&self.input]
    }

    fn schema(&self) -> &DFSchemaRef {
        &self.output_schema
    }

    fn expressions(&self) -> Vec<Expr> {
        vec![]
    }

    fn fmt_for_explain(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "HistogramFold: le={}, field={}, quantile={}",
            self.le_column, self.field_column, self.quantile
        )
    }

    fn with_exprs_and_inputs(
        &self,
        _exprs: Vec<Expr>,
        inputs: Vec<LogicalPlan>,
    ) -> DataFusionResult<Self> {
        Ok(Self {
            le_column: self.le_column.clone(),
            ts_column: self.ts_column.clone(),
            input: inputs.into_iter().next().unwrap(),
            field_column: self.field_column.clone(),
            quantile: self.quantile,
            // This method cannot return an error, so the output schema is
            // reused instead of being recalculated.
            output_schema: self.output_schema.clone(),
        })
    }
}

impl HistogramFold {
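    /// Create a new [HistogramFold]. This validates that the `le`, `field` and
    /// `ts` columns all exist in the input schema, and derives the output
    /// schema by removing the `le` column.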
    pub fn new(
        le_column: String,
        field_column: String,
        ts_column: String,
        quantile: f64,
        input: LogicalPlan,
    ) -> DataFusionResult<Self> {
        let input_schema = input.schema();
        Self::check_schema(input_schema, &le_column, &field_column, &ts_column)?;
        let output_schema = Self::convert_schema(input_schema, &le_column)?;
        Ok(Self {
            le_column,
            ts_column,
            input,
            field_column,
            quantile: quantile.into(),
            output_schema,
        })
    }

    pub const fn name() -> &'static str {
        "HistogramFold"
    }

    fn check_schema(
        input_schema: &DFSchemaRef,
        le_column: &str,
        field_column: &str,
        ts_column: &str,
    ) -> DataFusionResult<()> {
        let check_column = |col| {
            if !input_schema.has_column_with_unqualified_name(col) {
                Err(DataFusionError::SchemaError(
                    Box::new(datafusion::common::SchemaError::FieldNotFound {
                        field: Box::new(Column::new(None::<String>, col)),
                        valid_fields: input_schema.columns(),
                    }),
                    Box::new(None),
                ))
            } else {
                Ok(())
            }
        };

        check_column(le_column)?;
        check_column(ts_column)?;
        check_column(field_column)
    }

    pub fn to_execution_plan(&self, exec_input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
        let input_schema = self.input.schema();
        // safety: those fields are checked in `check_schema()`
        let le_column_index = input_schema
            .index_of_column_by_name(None, &self.le_column)
            .unwrap();
        let field_column_index = input_schema
            .index_of_column_by_name(None, &self.field_column)
            .unwrap();
        let ts_column_index = input_schema
            .index_of_column_by_name(None, &self.ts_column)
            .unwrap();

        let output_schema: SchemaRef = self.output_schema.inner().clone();
        let properties = PlanProperties::new(
            EquivalenceProperties::new(output_schema.clone()),
            Partitioning::UnknownPartitioning(1),
            EmissionType::Incremental,
            Boundedness::Bounded,
        );
        Arc::new(HistogramFoldExec {
            le_column_index,
            field_column_index,
            ts_column_index,
            input: exec_input,
            quantile: self.quantile.into(),
            output_schema,
            metric: ExecutionPlanMetricsSet::new(),
            properties,
        })
    }

    /// Transform the schema
    ///
    /// - `le` will be removed
    fn convert_schema(
        input_schema: &DFSchemaRef,
        le_column: &str,
    ) -> DataFusionResult<DFSchemaRef> {
        let fields = input_schema.fields();
        // safety: those fields are checked in `check_schema()`
        let mut new_fields = Vec::with_capacity(fields.len() - 1);
        for f in fields {
            if f.name() != le_column {
                new_fields.push((None, f.clone()));
            }
        }
        Ok(Arc::new(DFSchema::new_with_metadata(
            new_fields,
            HashMap::new(),
        )?))
    }
}

impl PartialOrd for HistogramFold {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        // Compare fields in order excluding output_schema
        match self.le_column.partial_cmp(&other.le_column) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        match self.ts_column.partial_cmp(&other.ts_column) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        match self.input.partial_cmp(&other.input) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        match self.field_column.partial_cmp(&other.field_column) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        self.quantile.partial_cmp(&other.quantile)
    }
}

#[derive(Debug)]
pub struct HistogramFoldExec {
    /// Index of the `le` column in the input schema.
    le_column_index: usize,
    input: Arc<dyn ExecutionPlan>,
    output_schema: SchemaRef,
    /// Index of the field column in the input schema.
    field_column_index: usize,
    ts_column_index: usize,
    quantile: f64,
    metric: ExecutionPlanMetricsSet,
    properties: PlanProperties,
}

impl ExecutionPlan for HistogramFoldExec {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
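        // The full requirement is `<tag list..., ts, le ASC>`, matching the
        // input ordering documented on [HistogramFold].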
        let mut cols = self
            .tag_col_exprs()
            .into_iter()
            .map(|expr| PhysicalSortRequirement {
                expr,
                options: None,
            })
            .collect::<Vec<PhysicalSortRequirement>>();
        // add ts
        cols.push(PhysicalSortRequirement {
            expr: Arc::new(PhyColumn::new(
                self.input.schema().field(self.ts_column_index).name(),
                self.ts_column_index,
            )),
            options: None,
        });
        // add le ASC
        cols.push(PhysicalSortRequirement {
            expr: Arc::new(PhyCast::new(
                Arc::new(PhyColumn::new(
                    self.input.schema().field(self.le_column_index).name(),
                    self.le_column_index,
                )),
                DataType::Float64,
                None,
            )),
            options: Some(SortOptions {
                descending: false,  // so that +Inf comes last
                nulls_first: false, // not nullable
            }),
        });

        // Safety: `cols` is not empty
        let requirement = LexRequirement::new(cols).unwrap();

        vec![Some(OrderingRequirements::Hard(vec![requirement]))]
    }

    fn required_input_distribution(&self) -> Vec<Distribution> {
        self.input.required_input_distribution()
    }

    fn maintains_input_order(&self) -> Vec<bool> {
        vec![true; self.children().len()]
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    // cannot change schema with this method
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
        assert!(!children.is_empty());
        Ok(Arc::new(Self {
            input: children[0].clone(),
            metric: self.metric.clone(),
            le_column_index: self.le_column_index,
            ts_column_index: self.ts_column_index,
            quantile: self.quantile,
            output_schema: self.output_schema.clone(),
            field_column_index: self.field_column_index,
            properties: self.properties.clone(),
        }))
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> DataFusionResult<SendableRecordBatchStream> {
        let baseline_metric = BaselineMetrics::new(&self.metric, partition);

        let batch_size = context.session_config().batch_size();
        let input = self.input.execute(partition, context)?;
        let output_schema = self.output_schema.clone();

        let mut normal_indices = (0..input.schema().fields().len()).collect::<HashSet<_>>();
        normal_indices.remove(&self.field_column_index);
        normal_indices.remove(&self.le_column_index);
        Ok(Box::pin(HistogramFoldStream {
            le_column_index: self.le_column_index,
            field_column_index: self.field_column_index,
            quantile: self.quantile,
            normal_indices: normal_indices.into_iter().collect(),
            bucket_size: None,
            input_buffer: vec![],
            input,
            output_schema,
            metric: baseline_metric,
            batch_size,
            input_buffered_rows: 0,
            output_buffer: HistogramFoldStream::empty_output_buffer(
                &self.output_schema,
                self.le_column_index,
            )?,
            output_buffered_rows: 0,
        }))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metric.clone_inner())
    }

    fn partition_statistics(&self, _: Option<usize>) -> DataFusionResult<Statistics> {
        Ok(Statistics {
            num_rows: Precision::Absent,
            total_byte_size: Precision::Absent,
            column_statistics: vec![
                ColumnStatistics::new_unknown();
                // plus one for the `le` column removed by `convert_schema`
                self.schema().flattened_fields().len() + 1
            ],
        })
    }

    fn name(&self) -> &str {
        "HistogramFoldExec"
    }
}

impl HistogramFoldExec {
    /// Return all the [PhysicalExpr] of tag columns in order.
    ///
    /// Tag columns are all columns except `le`, `field` and `ts` columns.
    pub fn tag_col_exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> {
        self.input
            .schema()
            .fields()
            .iter()
            .enumerate()
            .filter_map(|(idx, field)| {
                if idx == self.le_column_index
                    || idx == self.field_column_index
                    || idx == self.ts_column_index
                {
                    None
                } else {
                    Some(Arc::new(PhyColumn::new(field.name(), idx)) as _)
                }
            })
            .collect()
    }
}

impl DisplayAs for HistogramFoldExec {
    fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default
            | DisplayFormatType::Verbose
            | DisplayFormatType::TreeRender => {
                write!(
                    f,
                    "HistogramFoldExec: le=@{}, field=@{}, quantile={}",
                    self.le_column_index, self.field_column_index, self.quantile
                )
            }
        }
    }
}

pub struct HistogramFoldStream {
    // internal states
    le_column_index: usize,
    field_column_index: usize,
    quantile: f64,
    /// Columns that don't need folding. These indices are based on the input schema.
    normal_indices: Vec<usize>,
    bucket_size: Option<usize>,
    /// Expected output batch size
    batch_size: usize,
    output_schema: SchemaRef,

    // buffers
    input_buffer: Vec<RecordBatch>,
    input_buffered_rows: usize,
    output_buffer: Vec<Box<dyn MutableVector>>,
    output_buffered_rows: usize,

    // runtime things
    input: SendableRecordBatchStream,
    metric: BaselineMetrics,
}

impl RecordBatchStream for HistogramFoldStream {
    fn schema(&self) -> SchemaRef {
        self.output_schema.clone()
    }
}

impl Stream for HistogramFoldStream {
    type Item = DataFusionResult<RecordBatch>;

    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let poll = loop {
            match ready!(self.input.poll_next_unpin(cx)) {
                Some(batch) => {
                    let batch = batch?;
                    let timer = Instant::now();
                    let Some(result) = self.fold_input(batch)? else {
                        self.metric.elapsed_compute().add_elapsed(timer);
                        continue;
                    };
                    self.metric.elapsed_compute().add_elapsed(timer);
                    break Poll::Ready(Some(result));
                }
                None => break Poll::Ready(self.take_output_buf()?.map(Ok)),
            }
        };
        self.metric.record_poll(poll)
    }
}

impl HistogramFoldStream {
    /// The innermost `Result` is for `poll_next()`. Returns `Ok(None)` when the
    /// input is only buffered and no output batch is ready yet.
    pub fn fold_input(
        &mut self,
        input: RecordBatch,
    ) -> DataFusionResult<Option<DataFusionResult<RecordBatch>>> {
        let Some(bucket_num) = self.calculate_bucket_num(&input)? else {
            return Ok(None);
        };

        if self.input_buffered_rows + input.num_rows() < bucket_num {
            // not enough rows to fold
            self.push_input_buf(input);
            return Ok(None);
        }

        self.fold_buf(bucket_num, input)?;
        if self.output_buffered_rows >= self.batch_size {
            return Ok(self.take_output_buf()?.map(Ok));
        }

        Ok(None)
    }

    /// Generate a group of empty [MutableVector]s from the output schema.
    ///
    /// For simplicity, this method will insert a placeholder for `le`, so that
    /// the output buffer has the same schema as the input. This placeholder needs
    /// to be removed before returning the output batch (see `take_output_buf`).
    pub fn empty_output_buffer(
        schema: &SchemaRef,
        le_column_index: usize,
    ) -> DataFusionResult<Vec<Box<dyn MutableVector>>> {
        let mut builders = Vec::with_capacity(schema.fields().len() + 1);
        for field in schema.fields() {
            let concrete_datatype = ConcreteDataType::try_from(field.data_type()).unwrap();
            let mutable_vector = concrete_datatype.create_mutable_vector(0);
            builders.push(mutable_vector);
        }
        builders.insert(
            le_column_index,
            ConcreteDataType::float64_datatype().create_mutable_vector(0),
        );

        Ok(builders)
    }

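    /// Derive the number of buckets per series by locating the first `+Inf` in
    /// the `le` column: `bucket size = inf position + buffered rows + 1`. For
    /// example, if the first `+Inf` appears at index 4 of the first batch (with
    /// no rows buffered), the bucket size is 5. Returns `Ok(None)` when `+Inf`
    /// is not found in this batch yet.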
    fn calculate_bucket_num(&mut self, batch: &RecordBatch) -> DataFusionResult<Option<usize>> {
        if let Some(size) = self.bucket_size {
            return Ok(Some(size));
        }

        let inf_pos = self.find_positive_inf(batch)?;
        if inf_pos == batch.num_rows() {
            // no positive inf found, append to buffer and wait for next batch
            self.push_input_buf(batch.clone());
            return Ok(None);
        }

        // else we found the positive inf.
        // calculate the bucket size
        let bucket_size = inf_pos + self.input_buffered_rows + 1;
        Ok(Some(bucket_size))
    }

    /// Fold record batches from the input buffer and put them into the output buffer.
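    ///
    /// Each bucket group yields one output row: "normal" columns take the value
    /// of the group's first row, while the field column receives the quantile
    /// computed by `evaluate_row`.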
    fn fold_buf(&mut self, bucket_num: usize, input: RecordBatch) -> DataFusionResult<()> {
        self.push_input_buf(input);
        // TODO(ruihang): this concat is avoidable.
        let batch = concat_batches(&self.input.schema(), self.input_buffer.drain(..).as_ref())?;
        let mut remaining_rows = self.input_buffered_rows;
        let mut cursor = 0;

        // TODO(LFC): Try to get rid of the Arrow array to vector conversion here.
        let vectors = Helper::try_into_vectors(batch.columns())
            .map_err(|e| DataFusionError::Execution(e.to_string()))?;

        while remaining_rows >= bucket_num {
            // "sample" normal columns
            for normal_index in &self.normal_indices {
                let val = vectors[*normal_index].get(cursor);
                self.output_buffer[*normal_index].push_value_ref(&val.as_value_ref());
            }
            // "fold" `le` and field columns
            let le_array = batch.column(self.le_column_index);
            let le_array = le_array.as_string::<i32>();
            let field_array = batch.column(self.field_column_index);
            let field_array = field_array.as_primitive::<Float64Type>();
            let mut bucket = vec![];
            let mut counters = vec![];
            for bias in 0..bucket_num {
                let le_str = le_array.value(cursor + bias);
                let le = le_str.parse::<f64>().unwrap();
                bucket.push(le);

                let counter = field_array.value(cursor + bias);
                counters.push(counter);
            }
            // ignore invalid data
            let result = Self::evaluate_row(self.quantile, &bucket, &counters).unwrap_or(f64::NAN);
            self.output_buffer[self.field_column_index].push_value_ref(&ValueRef::from(result));
            cursor += bucket_num;
            remaining_rows -= bucket_num;
            self.output_buffered_rows += 1;
        }

        let remaining_input_batch = batch.slice(cursor, remaining_rows);
        self.input_buffered_rows = remaining_input_batch.num_rows();
        self.input_buffer.push(remaining_input_batch);

        Ok(())
    }

    fn push_input_buf(&mut self, batch: RecordBatch) {
        self.input_buffered_rows += batch.num_rows();
        self.input_buffer.push(batch);
    }

    /// Compute result from output buffer
    fn take_output_buf(&mut self) -> DataFusionResult<Option<RecordBatch>> {
        if self.output_buffered_rows == 0 {
            if self.input_buffered_rows != 0 {
                warn!(
                    "input buffer is not empty, {} rows remaining",
                    self.input_buffered_rows
                );
            }
            return Ok(None);
        }

        let mut output_buf = Self::empty_output_buffer(&self.output_schema, self.le_column_index)?;
        std::mem::swap(&mut self.output_buffer, &mut output_buf);
        let mut columns = Vec::with_capacity(output_buf.len());
        for builder in output_buf.iter_mut() {
            columns.push(builder.to_vector().to_arrow_array());
        }
        // remove the placeholder column for `le`
        columns.remove(self.le_column_index);

        self.output_buffered_rows = 0;
        RecordBatch::try_new(self.output_schema.clone(), columns)
            .map(Some)
            .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))
    }

    /// Find the first `+Inf`, which indicates the end of the bucket group.
    ///
    /// If the return value equals the batch's `num_rows`, it means `+Inf` is
    /// not found in this batch.
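    ///
    /// For example (with made-up data), for `le` values `["0.1", "10", "+Inf"]`
    /// this returns 2, while for `["0.1", "10"]` it returns 2 as well, which
    /// equals `num_rows` and thus signals "not found".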
    fn find_positive_inf(&self, batch: &RecordBatch) -> DataFusionResult<usize> {
        // Fuse this function: it should not be called when the
        // bucket size is already known.
        if let Some(bucket_size) = self.bucket_size {
            return Ok(bucket_size);
        }
        let string_le_array = batch.column(self.le_column_index);
        let float_le_array = compute::cast(&string_le_array, &DataType::Float64).map_err(|e| {
            DataFusionError::Execution(format!(
                "cannot cast {} array to float64 array: {:?}",
                string_le_array.data_type(),
                e
            ))
        })?;
        let le_as_f64_array = float_le_array
            .as_primitive_opt::<Float64Type>()
            .ok_or_else(|| {
                DataFusionError::Execution(format!(
                    "expect a float64 array, but found {}",
                    float_le_array.data_type()
                ))
            })?;
        for (i, v) in le_as_f64_array.iter().enumerate() {
            if let Some(v) = v
                && v == f64::INFINITY
            {
                return Ok(i);
            }
        }

        Ok(batch.num_rows())
    }

    /// Evaluate the quantile for one folded row of buckets and return the result.
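    ///
    /// A worked example: with `quantile = 0.5`, `bucket = [0, 1, 2, 3, 4, +Inf]`
    /// and cumulative `counter = [0, 10, 20, 30, 40, 50]`, the expected position
    /// is `50 * 0.5 = 25`, which falls into the bucket `(2, 3]`. Linear
    /// interpolation gives `2 + (3 - 2) / (30 - 20) * (25 - 20) = 2.5`.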
    fn evaluate_row(quantile: f64, bucket: &[f64], counter: &[f64]) -> DataFusionResult<f64> {
        // check bucket
        if bucket.len() <= 1 {
            return Ok(f64::NAN);
        }
        if bucket.last().unwrap().is_finite() {
            return Err(DataFusionError::Execution(
                "last bucket should be +Inf".to_string(),
            ));
        }
        if bucket.len() != counter.len() {
            return Err(DataFusionError::Execution(
                "bucket and counter should have the same length".to_string(),
            ));
        }
        // check quantile
        if quantile < 0.0 {
            return Ok(f64::NEG_INFINITY);
        } else if quantile > 1.0 {
            return Ok(f64::INFINITY);
        } else if quantile.is_nan() {
            return Ok(f64::NAN);
        }

        // check input value
        debug_assert!(bucket.windows(2).all(|w| w[0] <= w[1]), "{bucket:?}");
        debug_assert!(counter.windows(2).all(|w| w[0] <= w[1]), "{counter:?}");

        let total = *counter.last().unwrap();
        let expected_pos = total * quantile;
        let mut fit_bucket_pos = 0;
        while fit_bucket_pos < bucket.len() && counter[fit_bucket_pos] < expected_pos {
            fit_bucket_pos += 1;
        }
        if fit_bucket_pos >= bucket.len() - 1 {
            Ok(bucket[bucket.len() - 2])
        } else {
            let upper_bound = bucket[fit_bucket_pos];
            let upper_count = counter[fit_bucket_pos];
            let mut lower_bound = bucket[0].min(0.0);
            let mut lower_count = 0.0;
            if fit_bucket_pos > 0 {
                lower_bound = bucket[fit_bucket_pos - 1];
                lower_count = counter[fit_bucket_pos - 1];
            }
            Ok(lower_bound
                + (upper_bound - lower_bound) / (upper_count - lower_count)
                    * (expected_pos - lower_count))
        }
    }
}

#[cfg(test)]
mod test {
    use std::sync::Arc;

    use datafusion::arrow::array::Float64Array;
    use datafusion::arrow::datatypes::{Field, Schema};
    use datafusion::common::ToDFSchema;
    use datafusion::datasource::memory::MemorySourceConfig;
    use datafusion::datasource::source::DataSourceExec;
    use datafusion::prelude::SessionContext;
    use datatypes::arrow_array::StringArray;

    use super::*;

    fn prepare_test_data() -> DataSourceExec {
        let schema = Arc::new(Schema::new(vec![
            Field::new("host", DataType::Utf8, true),
            Field::new("le", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ]));

        // 12 items
        let host_column_1 = Arc::new(StringArray::from(vec![
            "host_1", "host_1", "host_1", "host_1", "host_1", "host_1", "host_1", "host_1",
            "host_1", "host_1", "host_1", "host_1",
        ])) as _;
        let le_column_1 = Arc::new(StringArray::from(vec![
            "0.001", "0.1", "10", "1000", "+Inf", "0.001", "0.1", "10", "1000", "+inf", "0.001",
            "0.1",
        ])) as _;
        let val_column_1 = Arc::new(Float64Array::from(vec![
            0.0, 1.0, 1.0, 5.0, 5.0, 0.0, 20.0, 60.0, 70.0, 100.0, 1.0, 1.0,
        ])) as _;

        // 2 items
        let host_column_2 = Arc::new(StringArray::from(vec!["host_1", "host_1"])) as _;
        let le_column_2 = Arc::new(StringArray::from(vec!["10", "1000"])) as _;
        let val_column_2 = Arc::new(Float64Array::from(vec![1.0, 1.0])) as _;

        // 11 items
        let host_column_3 = Arc::new(StringArray::from(vec![
            "host_1", "host_2", "host_2", "host_2", "host_2", "host_2", "host_2", "host_2",
            "host_2", "host_2", "host_2",
        ])) as _;
        let le_column_3 = Arc::new(StringArray::from(vec![
            "+INF", "0.001", "0.1", "10", "1000", "+iNf", "0.001", "0.1", "10", "1000", "+Inf",
        ])) as _;
        let val_column_3 = Arc::new(Float64Array::from(vec![
            1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 4.0,
        ])) as _;

        let data_1 = RecordBatch::try_new(
            schema.clone(),
            vec![host_column_1, le_column_1, val_column_1],
        )
        .unwrap();
        let data_2 = RecordBatch::try_new(
            schema.clone(),
            vec![host_column_2, le_column_2, val_column_2],
        )
        .unwrap();
        let data_3 = RecordBatch::try_new(
            schema.clone(),
            vec![host_column_3, le_column_3, val_column_3],
        )
        .unwrap();

        DataSourceExec::new(Arc::new(
            MemorySourceConfig::try_new(&[vec![data_1, data_2, data_3]], schema, None).unwrap(),
        ))
    }

    #[tokio::test]
    async fn fold_overall() {
        let memory_exec = Arc::new(prepare_test_data());
        let output_schema: SchemaRef = Arc::new(
            HistogramFold::convert_schema(
                &Arc::new(memory_exec.schema().to_dfschema().unwrap()),
                "le",
            )
            .unwrap()
            .as_arrow()
            .clone(),
        );
        let properties = PlanProperties::new(
            EquivalenceProperties::new(output_schema.clone()),
            Partitioning::UnknownPartitioning(1),
            EmissionType::Incremental,
            Boundedness::Bounded,
        );
        let fold_exec = Arc::new(HistogramFoldExec {
            le_column_index: 1,
            field_column_index: 2,
            quantile: 0.4,
            ts_column_index: 9999, // doesn't exist, but doesn't matter
            input: memory_exec,
            output_schema,
            metric: ExecutionPlanMetricsSet::new(),
            properties,
        });

        let session_context = SessionContext::default();
        let result = datafusion::physical_plan::collect(fold_exec, session_context.task_ctx())
            .await
            .unwrap();
        let result_literal = datatypes::arrow::util::pretty::pretty_format_batches(&result)
            .unwrap()
            .to_string();

        let expected = String::from(
            "+--------+-------------------+
| host   | val               |
+--------+-------------------+
| host_1 | 257.5             |
| host_1 | 5.05              |
| host_1 | 0.0004            |
| host_2 | NaN               |
| host_2 | 6.040000000000001 |
+--------+-------------------+",
        );
        assert_eq!(result_literal, expected);
    }

    #[test]
    fn confirm_schema() {
        let input_schema = Schema::new(vec![
            Field::new("host", DataType::Utf8, true),
            Field::new("le", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ])
        .to_dfschema_ref()
        .unwrap();
        let expected_output_schema = Schema::new(vec![
            Field::new("host", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ])
        .to_dfschema_ref()
        .unwrap();

        let actual = HistogramFold::convert_schema(&input_schema, "le").unwrap();
        assert_eq!(actual, expected_output_schema)
    }

    #[test]
    fn evaluate_row_normal_case() {
        let bucket = [0.0, 1.0, 2.0, 3.0, 4.0, f64::INFINITY];

        #[derive(Debug)]
        struct Case {
            quantile: f64,
            counters: Vec<f64>,
            expected: f64,
        }

        let cases = [
            Case {
                quantile: 0.9,
                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
                expected: 4.0,
            },
            Case {
                quantile: 0.89,
                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
                expected: 4.0,
            },
            Case {
                quantile: 0.78,
                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
                expected: 3.9,
            },
            Case {
                quantile: 0.5,
                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
                expected: 2.5,
            },
            Case {
                quantile: 0.5,
                counters: vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                expected: f64::NAN,
            },
            Case {
                quantile: 1.0,
                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
                expected: 4.0,
            },
            Case {
                quantile: 0.0,
                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
                expected: f64::NAN,
            },
            Case {
                quantile: 1.1,
                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
                expected: f64::INFINITY,
            },
            Case {
                quantile: -1.0,
                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
                expected: f64::NEG_INFINITY,
            },
        ];

        for case in cases {
            let actual =
                HistogramFoldStream::evaluate_row(case.quantile, &bucket, &case.counters).unwrap();
            assert_eq!(
                format!("{actual}"),
                format!("{}", case.expected),
                "{:?}",
                case
            );
        }
    }

    #[test]
    #[should_panic]
    fn evaluate_out_of_order_input() {
        let bucket = [0.0, 1.0, 2.0, 3.0, 4.0, f64::INFINITY];
        let counters = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0];
        HistogramFoldStream::evaluate_row(0.5, &bucket, &counters).unwrap();
    }

    #[test]
    fn evaluate_wrong_bucket() {
        let bucket = [0.0, 1.0, 2.0, 3.0, 4.0, f64::INFINITY, 5.0];
        let counters = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
        let result = HistogramFoldStream::evaluate_row(0.5, &bucket, &counters);
        assert!(result.is_err());
    }

    #[test]
    fn evaluate_small_fraction() {
        let bucket = [0.0, 2.0, 4.0, 6.0, f64::INFINITY];
        let counters = [0.0, 1.0 / 300.0, 2.0 / 300.0, 0.01, 0.01];
        let result = HistogramFoldStream::evaluate_row(0.5, &bucket, &counters).unwrap();
        assert_eq!(3.0, result);
    }
}
960}