promql/extension_plan/histogram_fold.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::any::Any;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::task::Poll;
use std::time::Instant;

use common_recordbatch::RecordBatch as GtRecordBatch;
use common_telemetry::warn;
use datafusion::arrow::array::AsArray;
use datafusion::arrow::compute::{self, concat_batches, SortOptions};
use datafusion::arrow::datatypes::{DataType, Float64Type, SchemaRef};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::common::stats::Precision;
use datafusion::common::{ColumnStatistics, DFSchema, DFSchemaRef, Statistics};
use datafusion::error::{DataFusionError, Result as DataFusionResult};
use datafusion::execution::TaskContext;
use datafusion::logical_expr::{LogicalPlan, UserDefinedLogicalNodeCore};
use datafusion::physical_expr::{EquivalenceProperties, LexRequirement, PhysicalSortRequirement};
use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
use datafusion::physical_plan::expressions::{CastExpr as PhyCast, Column as PhyColumn};
use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use datafusion::physical_plan::{
    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PhysicalExpr,
    PlanProperties, RecordBatchStream, SendableRecordBatchStream,
};
use datafusion::prelude::{Column, Expr};
use datatypes::prelude::{ConcreteDataType, DataType as GtDataType};
use datatypes::schema::Schema as GtSchema;
use datatypes::value::{OrderedF64, ValueRef};
use datatypes::vectors::MutableVector;
use futures::{ready, Stream, StreamExt};

/// `HistogramFold` folds a conventional (non-native) histogram ([1]) for later
/// computation.
///
/// Specifically, it transforms the `le` and `field` columns into a complex
/// type, and samples the other tag columns:
/// - `le` becomes a [ListArray] of [f64], with each bucket bound parsed
/// - `field` becomes a [ListArray] of [f64]
/// - other columns are sampled every `bucket_num` elements, but their types don't change.
///
/// Due to the folding and sampling, the number of output rows is `input_rows` / `bucket_num`.
///
/// # Requirements
/// - Input must be sorted on `<tag list>, ts, le ASC`.
/// - The value set of `le` must be the same for every series, i.e., all series
///   must share the same buckets.
///
/// [1]: https://prometheus.io/docs/concepts/metric_types/#histogram
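///
/// # Example (illustrative, not from the original source)
///
/// With buckets `0.1`, `1` and `+Inf` (`bucket_num = 3`) and input sorted as
/// required:
///
/// ```text
/// host    | ts | le   | val
/// --------|----|------|----
/// host_1  | 0  | 0.1  | 1
/// host_1  | 0  | 1    | 2
/// host_1  | 0  | +Inf | 4
/// ```
///
/// the three rows fold into a single output row: `le` is removed and the
/// field column holds the quantile computed from the bucket group.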
#[derive(Debug, PartialEq, Hash, Eq)]
pub struct HistogramFold {
    /// Name of the `le` column. It's a special column in Prometheus
    /// for implementing conventional histograms. It's a string column
    /// with "literal" float values, like "+Inf", "0.001" etc.
    le_column: String,
    ts_column: String,
    input: LogicalPlan,
    field_column: String,
    quantile: OrderedF64,
    output_schema: DFSchemaRef,
}

impl UserDefinedLogicalNodeCore for HistogramFold {
    fn name(&self) -> &str {
        Self::name()
    }

    fn inputs(&self) -> Vec<&LogicalPlan> {
        vec![&self.input]
    }

    fn schema(&self) -> &DFSchemaRef {
        &self.output_schema
    }

    fn expressions(&self) -> Vec<Expr> {
        vec![]
    }

    fn fmt_for_explain(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "HistogramFold: le={}, field={}, quantile={}",
            self.le_column, self.field_column, self.quantile
        )
    }

    fn with_exprs_and_inputs(
        &self,
        _exprs: Vec<Expr>,
        inputs: Vec<LogicalPlan>,
    ) -> DataFusionResult<Self> {
        Ok(Self {
            le_column: self.le_column.clone(),
            ts_column: self.ts_column.clone(),
            input: inputs.into_iter().next().unwrap(),
            field_column: self.field_column.clone(),
            quantile: self.quantile,
            // This method cannot return an error, so we reuse the existing
            // output schema instead of re-calculating it.
            output_schema: self.output_schema.clone(),
        })
    }
}

impl HistogramFold {
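    /// Create a new [HistogramFold] plan node.
    ///
    /// A minimal usage sketch (illustrative; `input` is assumed to be an
    /// upstream [LogicalPlan] that produces the three named columns):
    ///
    /// ```ignore
    /// let fold = HistogramFold::new(
    ///     "le".to_string(),  // bucket upper-bound column
    ///     "val".to_string(), // counter (field) column
    ///     "ts".to_string(),  // timestamp column
    ///     0.99,              // quantile to compute
    ///     input,
    /// )?;
    /// ```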
    pub fn new(
        le_column: String,
        field_column: String,
        ts_column: String,
        quantile: f64,
        input: LogicalPlan,
    ) -> DataFusionResult<Self> {
        let input_schema = input.schema();
        Self::check_schema(input_schema, &le_column, &field_column, &ts_column)?;
        let output_schema = Self::convert_schema(input_schema, &le_column)?;
        Ok(Self {
            le_column,
            ts_column,
            input,
            field_column,
            quantile: quantile.into(),
            output_schema,
        })
    }

    pub const fn name() -> &'static str {
        "HistogramFold"
    }

    fn check_schema(
        input_schema: &DFSchemaRef,
        le_column: &str,
        field_column: &str,
        ts_column: &str,
    ) -> DataFusionResult<()> {
        let check_column = |col| {
            if !input_schema.has_column_with_unqualified_name(col) {
                Err(DataFusionError::SchemaError(
                    datafusion::common::SchemaError::FieldNotFound {
                        field: Box::new(Column::new(None::<String>, col)),
                        valid_fields: input_schema.columns(),
                    },
                    Box::new(None),
                ))
            } else {
                Ok(())
            }
        };

        check_column(le_column)?;
        check_column(ts_column)?;
        check_column(field_column)
    }

    pub fn to_execution_plan(&self, exec_input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
        let input_schema = self.input.schema();
        // safety: those fields are checked in `check_schema()`
        let le_column_index = input_schema
            .index_of_column_by_name(None, &self.le_column)
            .unwrap();
        let field_column_index = input_schema
            .index_of_column_by_name(None, &self.field_column)
            .unwrap();
        let ts_column_index = input_schema
            .index_of_column_by_name(None, &self.ts_column)
            .unwrap();

        let output_schema: SchemaRef = Arc::new(self.output_schema.as_ref().into());
        let properties = PlanProperties::new(
            EquivalenceProperties::new(output_schema.clone()),
            Partitioning::UnknownPartitioning(1),
            EmissionType::Incremental,
            Boundedness::Bounded,
        );
        Arc::new(HistogramFoldExec {
            le_column_index,
            field_column_index,
            ts_column_index,
            input: exec_input,
            quantile: self.quantile.into(),
            output_schema,
            metric: ExecutionPlanMetricsSet::new(),
            properties,
        })
    }

    /// Transform the schema
    ///
    /// - `le` will be removed
    fn convert_schema(
        input_schema: &DFSchemaRef,
        le_column: &str,
    ) -> DataFusionResult<DFSchemaRef> {
        let fields = input_schema.fields();
        // safety: those fields are checked in `check_schema()`
        let mut new_fields = Vec::with_capacity(fields.len() - 1);
        for f in fields {
            if f.name() != le_column {
                new_fields.push((None, f.clone()));
            }
        }
        Ok(Arc::new(DFSchema::new_with_metadata(
            new_fields,
            HashMap::new(),
        )?))
    }
}

impl PartialOrd for HistogramFold {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        // Compare fields in order, excluding output_schema
        match self.le_column.partial_cmp(&other.le_column) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        match self.ts_column.partial_cmp(&other.ts_column) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        match self.input.partial_cmp(&other.input) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        match self.field_column.partial_cmp(&other.field_column) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        self.quantile.partial_cmp(&other.quantile)
    }
}

#[derive(Debug)]
pub struct HistogramFoldExec {
    /// Index of the `le` column in the input schema.
    le_column_index: usize,
    input: Arc<dyn ExecutionPlan>,
    output_schema: SchemaRef,
    /// Index of the field column in the input schema.
    field_column_index: usize,
    ts_column_index: usize,
    quantile: f64,
    metric: ExecutionPlanMetricsSet,
    properties: PlanProperties,
}

impl ExecutionPlan for HistogramFoldExec {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn required_input_ordering(&self) -> Vec<Option<LexRequirement>> {
        let mut cols = self
            .tag_col_exprs()
            .into_iter()
            .map(|expr| PhysicalSortRequirement {
                expr,
                options: None,
            })
            .collect::<Vec<PhysicalSortRequirement>>();
        // add ts
        cols.push(PhysicalSortRequirement {
            expr: Arc::new(PhyColumn::new(
                self.input.schema().field(self.ts_column_index).name(),
                self.ts_column_index,
            )),
            options: None,
        });
        // add le ASC
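        // Note: `le` is a string column, so it is cast to `Float64` for this
        // ordering requirement; comparing the raw strings would not place
        // "+Inf" after the finite bucket bounds.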
        cols.push(PhysicalSortRequirement {
            expr: Arc::new(PhyCast::new(
                Arc::new(PhyColumn::new(
                    self.input.schema().field(self.le_column_index).name(),
                    self.le_column_index,
                )),
                DataType::Float64,
                None,
            )),
            options: Some(SortOptions {
                descending: false,  // "+Inf" sorts last
                nulls_first: false, // not nullable
            }),
        });

        vec![Some(LexRequirement::new(cols))]
    }

    fn required_input_distribution(&self) -> Vec<Distribution> {
        self.input.required_input_distribution()
    }

    fn maintains_input_order(&self) -> Vec<bool> {
        vec![true; self.children().len()]
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    // cannot change the schema with this method
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
        assert!(!children.is_empty());
        Ok(Arc::new(Self {
            input: children[0].clone(),
            metric: self.metric.clone(),
            le_column_index: self.le_column_index,
            ts_column_index: self.ts_column_index,
            quantile: self.quantile,
            output_schema: self.output_schema.clone(),
            field_column_index: self.field_column_index,
            properties: self.properties.clone(),
        }))
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> DataFusionResult<SendableRecordBatchStream> {
        let baseline_metric = BaselineMetrics::new(&self.metric, partition);

        let batch_size = context.session_config().batch_size();
        let input = self.input.execute(partition, context)?;
        let output_schema = self.output_schema.clone();

        let mut normal_indices = (0..input.schema().fields().len()).collect::<HashSet<_>>();
        normal_indices.remove(&self.field_column_index);
        normal_indices.remove(&self.le_column_index);
        Ok(Box::pin(HistogramFoldStream {
            le_column_index: self.le_column_index,
            field_column_index: self.field_column_index,
            quantile: self.quantile,
            normal_indices: normal_indices.into_iter().collect(),
            bucket_size: None,
            input_buffer: vec![],
            input,
            output_schema,
            metric: baseline_metric,
            batch_size,
            input_buffered_rows: 0,
            output_buffer: HistogramFoldStream::empty_output_buffer(
                &self.output_schema,
                self.le_column_index,
            )?,
            output_buffered_rows: 0,
        }))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metric.clone_inner())
    }

    fn statistics(&self) -> DataFusionResult<Statistics> {
        Ok(Statistics {
            num_rows: Precision::Absent,
            total_byte_size: Precision::Absent,
            column_statistics: vec![
                ColumnStatistics::new_unknown();
                // plus one for the `le` column removed by `convert_schema()`
                self.schema().flattened_fields().len() + 1
            ],
        })
    }

    fn name(&self) -> &str {
        "HistogramFoldExec"
    }
}

impl HistogramFoldExec {
    /// Return the [PhysicalExpr]s of all tag columns, in order.
    ///
    /// Tag columns are all columns except the `le`, `field` and `ts` columns.
    pub fn tag_col_exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> {
        self.input
            .schema()
            .fields()
            .iter()
            .enumerate()
            .filter_map(|(idx, field)| {
                if idx == self.le_column_index
                    || idx == self.field_column_index
                    || idx == self.ts_column_index
                {
                    None
                } else {
                    Some(Arc::new(PhyColumn::new(field.name(), idx)) as _)
                }
            })
            .collect()
    }
}

impl DisplayAs for HistogramFoldExec {
    fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(
                    f,
                    "HistogramFoldExec: le=@{}, field=@{}, quantile={}",
                    self.le_column_index, self.field_column_index, self.quantile
                )
            }
        }
    }
}

pub struct HistogramFoldStream {
    // internal states
    le_column_index: usize,
    field_column_index: usize,
    quantile: f64,
    /// Columns that don't need folding. These indices are based on the input schema.
    normal_indices: Vec<usize>,
    bucket_size: Option<usize>,
    /// Expected output batch size
    batch_size: usize,
    output_schema: SchemaRef,

    // buffers
    input_buffer: Vec<RecordBatch>,
    input_buffered_rows: usize,
    output_buffer: Vec<Box<dyn MutableVector>>,
    output_buffered_rows: usize,

    // runtime things
    input: SendableRecordBatchStream,
    metric: BaselineMetrics,
}

impl RecordBatchStream for HistogramFoldStream {
    fn schema(&self) -> SchemaRef {
        self.output_schema.clone()
    }
}

impl Stream for HistogramFoldStream {
    type Item = DataFusionResult<RecordBatch>;

    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let poll = loop {
            match ready!(self.input.poll_next_unpin(cx)) {
                Some(batch) => {
                    let batch = batch?;
                    let timer = Instant::now();
                    let Some(result) = self.fold_input(batch)? else {
                        self.metric.elapsed_compute().add_elapsed(timer);
                        continue;
                    };
                    self.metric.elapsed_compute().add_elapsed(timer);
                    break Poll::Ready(Some(result));
                }
                None => break Poll::Ready(self.take_output_buf()?.map(Ok)),
            }
        };
        self.metric.record_poll(poll)
    }
}

impl HistogramFoldStream {
    /// Fold the given input batch into the output buffer.
    ///
    /// Returns `Ok(None)` when more input is needed. The innermost `Result`
    /// is the item handed back to `poll_next()`.
    pub fn fold_input(
        &mut self,
        input: RecordBatch,
    ) -> DataFusionResult<Option<DataFusionResult<RecordBatch>>> {
        let Some(bucket_num) = self.calculate_bucket_num(&input)? else {
            return Ok(None);
        };

        if self.input_buffered_rows + input.num_rows() < bucket_num {
            // not enough rows to fold
            self.push_input_buf(input);
            return Ok(None);
        }

        self.fold_buf(bucket_num, input)?;
        if self.output_buffered_rows >= self.batch_size {
            return Ok(self.take_output_buf()?.map(Ok));
        }

        Ok(None)
    }

    /// Generate a group of empty [MutableVector]s from the output schema.
    ///
    /// For simplicity, this method inserts a placeholder for `le` so that the
    /// output buffer has the same schema as the input. This placeholder needs
    /// to be removed before returning the output batch.
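    ///
    /// For example (illustrative): with input schema `[host, le, val]` and
    /// output schema `[host, val]`, the returned builders are
    /// `[host, le, val]`, where the `le` builder is a `Float64` placeholder
    /// that `take_output_buf()` removes again.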
    pub fn empty_output_buffer(
        schema: &SchemaRef,
        le_column_index: usize,
    ) -> DataFusionResult<Vec<Box<dyn MutableVector>>> {
        let mut builders = Vec::with_capacity(schema.fields().len() + 1);
        for field in schema.fields() {
            let concrete_datatype = ConcreteDataType::try_from(field.data_type()).unwrap();
            let mutable_vector = concrete_datatype.create_mutable_vector(0);
            builders.push(mutable_vector);
        }
        builders.insert(
            le_column_index,
            ConcreteDataType::float64_datatype().create_mutable_vector(0),
        );

        Ok(builders)
    }

    fn calculate_bucket_num(&mut self, batch: &RecordBatch) -> DataFusionResult<Option<usize>> {
        if let Some(size) = self.bucket_size {
            return Ok(Some(size));
        }

        let inf_pos = self.find_positive_inf(batch)?;
        if inf_pos == batch.num_rows() {
            // no positive inf found, append to buffer and wait for the next batch
            self.push_input_buf(batch.clone());
            return Ok(None);
        }

        // else we found the positive inf.
        // calculate the bucket size
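        // E.g. (illustrative): with 3 rows already buffered from previous
        // batches and the first "+Inf" at index 1 of this batch, the bucket
        // group spans 1 + 3 + 1 = 5 rows.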
        let bucket_size = inf_pos + self.input_buffered_rows + 1;
        Ok(Some(bucket_size))
    }

    /// Fold record batches from the input buffer and put the results into the output buffer
    fn fold_buf(&mut self, bucket_num: usize, input: RecordBatch) -> DataFusionResult<()> {
        self.push_input_buf(input);
        // TODO(ruihang): this concat is avoidable.
        let batch = concat_batches(&self.input.schema(), self.input_buffer.drain(..).as_ref())?;
        let mut remaining_rows = self.input_buffered_rows;
        let mut cursor = 0;

        let gt_schema = GtSchema::try_from(self.input.schema()).unwrap();
        let batch = GtRecordBatch::try_from_df_record_batch(Arc::new(gt_schema), batch).unwrap();

        while remaining_rows >= bucket_num {
            // "sample" normal columns
            for normal_index in &self.normal_indices {
                let val = batch.column(*normal_index).get(cursor);
                self.output_buffer[*normal_index].push_value_ref(val.as_value_ref());
            }
            // "fold" `le` and field columns
            let le_array = batch.column(self.le_column_index);
            let field_array = batch.column(self.field_column_index);
            let mut bucket = vec![];
            let mut counters = vec![];
            for bias in 0..bucket_num {
                let le_str_val = le_array.get(cursor + bias);
                let le_str_val_ref = le_str_val.as_value_ref();
                let le_str = le_str_val_ref
                    .as_string()
                    .unwrap()
                    .expect("le column should not be nullable");
                let le = le_str.parse::<f64>().unwrap();
                bucket.push(le);

                let counter = field_array
                    .get(cursor + bias)
                    .as_value_ref()
                    .as_f64()
                    .unwrap()
                    .expect("field column should not be nullable");
                counters.push(counter);
            }
            // ignore invalid data
            let result = Self::evaluate_row(self.quantile, &bucket, &counters).unwrap_or(f64::NAN);
            self.output_buffer[self.field_column_index].push_value_ref(ValueRef::from(result));
            cursor += bucket_num;
            remaining_rows -= bucket_num;
            self.output_buffered_rows += 1;
        }

        let remaining_input_batch = batch.into_df_record_batch().slice(cursor, remaining_rows);
        self.input_buffered_rows = remaining_input_batch.num_rows();
        self.input_buffer.push(remaining_input_batch);

        Ok(())
    }

    fn push_input_buf(&mut self, batch: RecordBatch) {
        self.input_buffered_rows += batch.num_rows();
        self.input_buffer.push(batch);
    }

    /// Compute the result from the output buffer
    fn take_output_buf(&mut self) -> DataFusionResult<Option<RecordBatch>> {
        if self.output_buffered_rows == 0 {
            if self.input_buffered_rows != 0 {
                warn!(
                    "input buffer is not empty, {} rows remaining",
                    self.input_buffered_rows
                );
            }
            return Ok(None);
        }

        let mut output_buf = Self::empty_output_buffer(&self.output_schema, self.le_column_index)?;
        std::mem::swap(&mut self.output_buffer, &mut output_buf);
        let mut columns = Vec::with_capacity(output_buf.len());
        for builder in output_buf.iter_mut() {
            columns.push(builder.to_vector().to_arrow_array());
        }
        // remove the placeholder column for `le`
        columns.remove(self.le_column_index);

        self.output_buffered_rows = 0;
        RecordBatch::try_new(self.output_schema.clone(), columns)
            .map(Some)
            .map_err(|e| DataFusionError::ArrowError(e, None))
    }

    /// Find the first `+Inf`, which indicates the end of a bucket group.
    ///
    /// If the return value equals the batch's `num_rows`, `+Inf` was not found
    /// in this batch.
    fn find_positive_inf(&self, batch: &RecordBatch) -> DataFusionResult<usize> {
        // Fuse this function: it should not be called once the
        // bucket size is already known.
        if let Some(bucket_size) = self.bucket_size {
            return Ok(bucket_size);
        }
        let string_le_array = batch.column(self.le_column_index);
        let float_le_array = compute::cast(&string_le_array, &DataType::Float64).map_err(|e| {
            DataFusionError::Execution(format!(
                "cannot cast {} array to float64 array: {:?}",
                string_le_array.data_type(),
                e
            ))
        })?;
        let le_as_f64_array = float_le_array
            .as_primitive_opt::<Float64Type>()
            .ok_or_else(|| {
                DataFusionError::Execution(format!(
                    "expect a float64 array, but found {}",
                    float_le_array.data_type()
                ))
            })?;
        for (i, v) in le_as_f64_array.iter().enumerate() {
            if let Some(v) = v
                && v == f64::INFINITY
            {
                return Ok(i);
            }
        }

        Ok(batch.num_rows())
    }

    /// Evaluate one bucket group and return the quantile value for the field column.
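    ///
    /// Worked example (illustrative): with `bucket = [0.0, 1.0, 2.0, +Inf]`,
    /// `counter = [0.0, 10.0, 20.0, 20.0]` and `quantile = 0.5`, the expected
    /// position is `20.0 * 0.5 = 10.0`, first reached at bucket index 1, so
    /// the result interpolates to
    /// `0.0 + (1.0 - 0.0) / (10.0 - 0.0) * (10.0 - 0.0) = 1.0`.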
    fn evaluate_row(quantile: f64, bucket: &[f64], counter: &[f64]) -> DataFusionResult<f64> {
        // check bucket
        if bucket.len() <= 1 {
            return Ok(f64::NAN);
        }
        if bucket.last().unwrap().is_finite() {
            return Err(DataFusionError::Execution(
                "last bucket should be +Inf".to_string(),
            ));
        }
        if bucket.len() != counter.len() {
            return Err(DataFusionError::Execution(
                "bucket and counter should have the same length".to_string(),
            ));
        }
        // check quantile
        if quantile < 0.0 {
            return Ok(f64::NEG_INFINITY);
        } else if quantile > 1.0 {
            return Ok(f64::INFINITY);
        } else if quantile.is_nan() {
            return Ok(f64::NAN);
        }

        // check input value
        debug_assert!(bucket.windows(2).all(|w| w[0] <= w[1]), "{bucket:?}");
        debug_assert!(counter.windows(2).all(|w| w[0] <= w[1]), "{counter:?}");

        let total = *counter.last().unwrap();
        let expected_pos = total * quantile;
        let mut fit_bucket_pos = 0;
        while fit_bucket_pos < bucket.len() && counter[fit_bucket_pos] < expected_pos {
            fit_bucket_pos += 1;
        }
        if fit_bucket_pos >= bucket.len() - 1 {
            Ok(bucket[bucket.len() - 2])
        } else {
            let upper_bound = bucket[fit_bucket_pos];
            let upper_count = counter[fit_bucket_pos];
            let mut lower_bound = bucket[0].min(0.0);
            let mut lower_count = 0.0;
            if fit_bucket_pos > 0 {
                lower_bound = bucket[fit_bucket_pos - 1];
                lower_count = counter[fit_bucket_pos - 1];
            }
            Ok(lower_bound
                + (upper_bound - lower_bound) / (upper_count - lower_count)
                    * (expected_pos - lower_count))
        }
    }
}

#[cfg(test)]
mod test {
    use std::sync::Arc;

    use datafusion::arrow::array::Float64Array;
    use datafusion::arrow::datatypes::{Field, Schema};
    use datafusion::common::ToDFSchema;
    use datafusion::physical_plan::memory::MemoryExec;
    use datafusion::prelude::SessionContext;
    use datatypes::arrow_array::StringArray;

    use super::*;

    fn prepare_test_data() -> MemoryExec {
        let schema = Arc::new(Schema::new(vec![
            Field::new("host", DataType::Utf8, true),
            Field::new("le", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ]));

        // 12 items
        let host_column_1 = Arc::new(StringArray::from(vec![
            "host_1", "host_1", "host_1", "host_1", "host_1", "host_1", "host_1", "host_1",
            "host_1", "host_1", "host_1", "host_1",
        ])) as _;
        let le_column_1 = Arc::new(StringArray::from(vec![
            "0.001", "0.1", "10", "1000", "+Inf", "0.001", "0.1", "10", "1000", "+inf", "0.001",
            "0.1",
        ])) as _;
        let val_column_1 = Arc::new(Float64Array::from(vec![
            0_0.0, 1.0, 1.0, 5.0, 5.0, 0_0.0, 20.0, 60.0, 70.0, 100.0, 0_1.0, 1.0,
        ])) as _;

        // 2 items
        let host_column_2 = Arc::new(StringArray::from(vec!["host_1", "host_1"])) as _;
        let le_column_2 = Arc::new(StringArray::from(vec!["10", "1000"])) as _;
        let val_column_2 = Arc::new(Float64Array::from(vec![1.0, 1.0])) as _;

        // 11 items
        let host_column_3 = Arc::new(StringArray::from(vec![
            "host_1", "host_2", "host_2", "host_2", "host_2", "host_2", "host_2", "host_2",
            "host_2", "host_2", "host_2",
        ])) as _;
        let le_column_3 = Arc::new(StringArray::from(vec![
            "+INF", "0.001", "0.1", "10", "1000", "+iNf", "0.001", "0.1", "10", "1000", "+Inf",
        ])) as _;
        let val_column_3 = Arc::new(Float64Array::from(vec![
            1.0, 0_0.0, 0.0, 0.0, 0.0, 0.0, 0_0.0, 1.0, 2.0, 3.0, 4.0,
        ])) as _;

        let data_1 = RecordBatch::try_new(
            schema.clone(),
            vec![host_column_1, le_column_1, val_column_1],
        )
        .unwrap();
        let data_2 = RecordBatch::try_new(
            schema.clone(),
            vec![host_column_2, le_column_2, val_column_2],
        )
        .unwrap();
        let data_3 = RecordBatch::try_new(
            schema.clone(),
            vec![host_column_3, le_column_3, val_column_3],
        )
        .unwrap();

        MemoryExec::try_new(&[vec![data_1, data_2, data_3]], schema, None).unwrap()
    }

    #[tokio::test]
    async fn fold_overall() {
        let memory_exec = Arc::new(prepare_test_data());
        let output_schema: SchemaRef = Arc::new(
            (*HistogramFold::convert_schema(
                &Arc::new(memory_exec.schema().to_dfschema().unwrap()),
                "le",
            )
            .unwrap()
            .as_ref())
            .clone()
            .into(),
        );
        let properties = PlanProperties::new(
            EquivalenceProperties::new(output_schema.clone()),
            Partitioning::UnknownPartitioning(1),
            EmissionType::Incremental,
            Boundedness::Bounded,
        );
        let fold_exec = Arc::new(HistogramFoldExec {
            le_column_index: 1,
            field_column_index: 2,
            quantile: 0.4,
            ts_column_index: 9999, // doesn't exist, but doesn't matter here
            input: memory_exec,
            output_schema,
            metric: ExecutionPlanMetricsSet::new(),
            properties,
        });

        let session_context = SessionContext::default();
        let result = datafusion::physical_plan::collect(fold_exec, session_context.task_ctx())
            .await
            .unwrap();
        let result_literal = datatypes::arrow::util::pretty::pretty_format_batches(&result)
            .unwrap()
            .to_string();

        let expected = String::from(
            "+--------+-------------------+
| host   | val               |
+--------+-------------------+
| host_1 | 257.5             |
| host_1 | 5.05              |
| host_1 | 0.0004            |
| host_2 | NaN               |
| host_2 | 6.040000000000001 |
+--------+-------------------+",
        );
        assert_eq!(result_literal, expected);
    }

    #[test]
    fn confirm_schema() {
        let input_schema = Schema::new(vec![
            Field::new("host", DataType::Utf8, true),
            Field::new("le", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ])
        .to_dfschema_ref()
        .unwrap();
        let expected_output_schema = Schema::new(vec![
            Field::new("host", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ])
        .to_dfschema_ref()
        .unwrap();

        let actual = HistogramFold::convert_schema(&input_schema, "le").unwrap();
        assert_eq!(actual, expected_output_schema)
    }

    #[test]
    fn evaluate_row_normal_case() {
        let bucket = [0.0, 1.0, 2.0, 3.0, 4.0, f64::INFINITY];

        #[derive(Debug)]
        struct Case {
            quantile: f64,
            counters: Vec<f64>,
            expected: f64,
        }

        let cases = [
            Case {
                quantile: 0.9,
                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
                expected: 4.0,
            },
            Case {
                quantile: 0.89,
                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
                expected: 4.0,
            },
            Case {
                quantile: 0.78,
                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
                expected: 3.9,
            },
            Case {
                quantile: 0.5,
                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
                expected: 2.5,
            },
            Case {
                quantile: 0.5,
                counters: vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                expected: f64::NAN,
            },
            Case {
                quantile: 1.0,
                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
                expected: 4.0,
            },
            Case {
                quantile: 0.0,
                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
                expected: f64::NAN,
            },
            Case {
                quantile: 1.1,
                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
                expected: f64::INFINITY,
            },
            Case {
                quantile: -1.0,
                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
                expected: f64::NEG_INFINITY,
            },
        ];

        for case in cases {
            let actual =
                HistogramFoldStream::evaluate_row(case.quantile, &bucket, &case.counters).unwrap();
            assert_eq!(
                format!("{actual}"),
                format!("{}", case.expected),
                "{:?}",
                case
            );
        }
    }

    #[test]
    #[should_panic]
    fn evaluate_out_of_order_input() {
        let bucket = [0.0, 1.0, 2.0, 3.0, 4.0, f64::INFINITY];
        let counters = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0];
        HistogramFoldStream::evaluate_row(0.5, &bucket, &counters).unwrap();
    }

    #[test]
    fn evaluate_wrong_bucket() {
        let bucket = [0.0, 1.0, 2.0, 3.0, 4.0, f64::INFINITY, 5.0];
        let counters = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
        let result = HistogramFoldStream::evaluate_row(0.5, &bucket, &counters);
        assert!(result.is_err());
    }

    #[test]
    fn evaluate_small_fraction() {
        let bucket = [0.0, 2.0, 4.0, 6.0, f64::INFINITY];
        let counters = [0.0, 1.0 / 300.0, 2.0 / 300.0, 0.01, 0.01];
        let result = HistogramFoldStream::evaluate_row(0.5, &bucket, &counters).unwrap();
        assert_eq!(3.0, result);
    }
}