promql/extension_plan/histogram_fold.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::any::Any;
16use std::collections::{HashMap, HashSet};
17use std::sync::Arc;
18use std::task::Poll;
19use std::time::Instant;
20
21use common_recordbatch::RecordBatch as GtRecordBatch;
22use common_telemetry::warn;
23use datafusion::arrow::array::AsArray;
24use datafusion::arrow::compute::{self, concat_batches, SortOptions};
25use datafusion::arrow::datatypes::{DataType, Float64Type, SchemaRef};
26use datafusion::arrow::record_batch::RecordBatch;
27use datafusion::common::stats::Precision;
28use datafusion::common::{ColumnStatistics, DFSchema, DFSchemaRef, Statistics};
29use datafusion::error::{DataFusionError, Result as DataFusionResult};
30use datafusion::execution::TaskContext;
31use datafusion::logical_expr::{LogicalPlan, UserDefinedLogicalNodeCore};
32use datafusion::physical_expr::{
33    EquivalenceProperties, LexRequirement, OrderingRequirements, PhysicalSortRequirement,
34};
35use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
36use datafusion::physical_plan::expressions::{CastExpr as PhyCast, Column as PhyColumn};
37use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
38use datafusion::physical_plan::{
39    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PhysicalExpr,
40    PlanProperties, RecordBatchStream, SendableRecordBatchStream,
41};
42use datafusion::prelude::{Column, Expr};
43use datatypes::prelude::{ConcreteDataType, DataType as GtDataType};
44use datatypes::schema::Schema as GtSchema;
45use datatypes::value::{OrderedF64, ValueRef};
46use datatypes::vectors::MutableVector;
47use futures::{ready, Stream, StreamExt};
48
/// `HistogramFold` will fold the conventional (non-native) histogram ([1]) and
/// compute the requested `quantile` of every bucket group.
///
/// Specifically, every group of `bucket_num` consecutive input rows (one row per
/// `le` bucket, the group ending at the `+Inf` bucket) is folded into one output row:
/// - `le` is removed from the output schema. Its "literal" float values
///   (like "0.001", "+Inf") are parsed as the bucket upper bounds.
/// - `field` keeps its name, but its value becomes the quantile estimated from the
///   bucket bounds and the (cumulative) counters of the group.
/// - other columns are sampled from the first row of each group; their types won't change.
///
/// Due to the folding or sampling, the output rows number will become `input_rows` / `bucket_num`.
///
/// # Requirement
/// - Input should be sorted on `<tag list>, ts, le ASC`.
/// - The value set of `le` should be same. I.e., buckets of every series should be same.
///
/// [1]: https://prometheus.io/docs/concepts/metric_types/#histogram
#[derive(Debug, PartialEq, Hash, Eq)]
pub struct HistogramFold {
    /// Name of the `le` column. It's a special column in prometheus
    /// for implementing conventional histogram. It's a string column
    /// with "literal" float value, like "+Inf", "0.001" etc.
    le_column: String,
    /// Name of the timestamp column; part of the required input ordering.
    ts_column: String,
    /// The child logical plan producing the un-folded histogram rows.
    input: LogicalPlan,
    /// Name of the value column holding the bucket counters.
    field_column: String,
    /// The quantile to estimate for every bucket group.
    quantile: OrderedF64,
    /// Output schema: the input schema without the `le` column.
    output_schema: DFSchemaRef,
}
77
78impl UserDefinedLogicalNodeCore for HistogramFold {
79    fn name(&self) -> &str {
80        Self::name()
81    }
82
83    fn inputs(&self) -> Vec<&LogicalPlan> {
84        vec![&self.input]
85    }
86
87    fn schema(&self) -> &DFSchemaRef {
88        &self.output_schema
89    }
90
91    fn expressions(&self) -> Vec<Expr> {
92        vec![]
93    }
94
95    fn fmt_for_explain(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96        write!(
97            f,
98            "HistogramFold: le={}, field={}, quantile={}",
99            self.le_column, self.field_column, self.quantile
100        )
101    }
102
103    fn with_exprs_and_inputs(
104        &self,
105        _exprs: Vec<Expr>,
106        inputs: Vec<LogicalPlan>,
107    ) -> DataFusionResult<Self> {
108        Ok(Self {
109            le_column: self.le_column.clone(),
110            ts_column: self.ts_column.clone(),
111            input: inputs.into_iter().next().unwrap(),
112            field_column: self.field_column.clone(),
113            quantile: self.quantile,
114            // This method cannot return error. Otherwise we should re-calculate
115            // the output schema
116            output_schema: self.output_schema.clone(),
117        })
118    }
119}
120
121impl HistogramFold {
122    pub fn new(
123        le_column: String,
124        field_column: String,
125        ts_column: String,
126        quantile: f64,
127        input: LogicalPlan,
128    ) -> DataFusionResult<Self> {
129        let input_schema = input.schema();
130        Self::check_schema(input_schema, &le_column, &field_column, &ts_column)?;
131        let output_schema = Self::convert_schema(input_schema, &le_column)?;
132        Ok(Self {
133            le_column,
134            ts_column,
135            input,
136            field_column,
137            quantile: quantile.into(),
138            output_schema,
139        })
140    }
141
142    pub const fn name() -> &'static str {
143        "HistogramFold"
144    }
145
146    fn check_schema(
147        input_schema: &DFSchemaRef,
148        le_column: &str,
149        field_column: &str,
150        ts_column: &str,
151    ) -> DataFusionResult<()> {
152        let check_column = |col| {
153            if !input_schema.has_column_with_unqualified_name(col) {
154                Err(DataFusionError::SchemaError(
155                    Box::new(datafusion::common::SchemaError::FieldNotFound {
156                        field: Box::new(Column::new(None::<String>, col)),
157                        valid_fields: input_schema.columns(),
158                    }),
159                    Box::new(None),
160                ))
161            } else {
162                Ok(())
163            }
164        };
165
166        check_column(le_column)?;
167        check_column(ts_column)?;
168        check_column(field_column)
169    }
170
171    pub fn to_execution_plan(&self, exec_input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
172        let input_schema = self.input.schema();
173        // safety: those fields are checked in `check_schema()`
174        let le_column_index = input_schema
175            .index_of_column_by_name(None, &self.le_column)
176            .unwrap();
177        let field_column_index = input_schema
178            .index_of_column_by_name(None, &self.field_column)
179            .unwrap();
180        let ts_column_index = input_schema
181            .index_of_column_by_name(None, &self.ts_column)
182            .unwrap();
183
184        let output_schema: SchemaRef = Arc::new(self.output_schema.as_ref().into());
185        let properties = PlanProperties::new(
186            EquivalenceProperties::new(output_schema.clone()),
187            Partitioning::UnknownPartitioning(1),
188            EmissionType::Incremental,
189            Boundedness::Bounded,
190        );
191        Arc::new(HistogramFoldExec {
192            le_column_index,
193            field_column_index,
194            ts_column_index,
195            input: exec_input,
196            quantile: self.quantile.into(),
197            output_schema,
198            metric: ExecutionPlanMetricsSet::new(),
199            properties,
200        })
201    }
202
203    /// Transform the schema
204    ///
205    /// - `le` will be removed
206    fn convert_schema(
207        input_schema: &DFSchemaRef,
208        le_column: &str,
209    ) -> DataFusionResult<DFSchemaRef> {
210        let fields = input_schema.fields();
211        // safety: those fields are checked in `check_schema()`
212        let mut new_fields = Vec::with_capacity(fields.len() - 1);
213        for f in fields {
214            if f.name() != le_column {
215                new_fields.push((None, f.clone()));
216            }
217        }
218        Ok(Arc::new(DFSchema::new_with_metadata(
219            new_fields,
220            HashMap::new(),
221        )?))
222    }
223}
224
225impl PartialOrd for HistogramFold {
226    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
227        // Compare fields in order excluding output_schema
228        match self.le_column.partial_cmp(&other.le_column) {
229            Some(core::cmp::Ordering::Equal) => {}
230            ord => return ord,
231        }
232        match self.ts_column.partial_cmp(&other.ts_column) {
233            Some(core::cmp::Ordering::Equal) => {}
234            ord => return ord,
235        }
236        match self.input.partial_cmp(&other.input) {
237            Some(core::cmp::Ordering::Equal) => {}
238            ord => return ord,
239        }
240        match self.field_column.partial_cmp(&other.field_column) {
241            Some(core::cmp::Ordering::Equal) => {}
242            ord => return ord,
243        }
244        self.quantile.partial_cmp(&other.quantile)
245    }
246}
247
/// Physical counterpart of [`HistogramFold`]: folds each group of `le` bucket
/// rows into one row whose field column holds the estimated `quantile`.
#[derive(Debug)]
pub struct HistogramFoldExec {
    /// Index for `le` column in the schema of input.
    le_column_index: usize,
    /// The child plan producing the un-folded histogram rows.
    input: Arc<dyn ExecutionPlan>,
    /// Schema of the produced batches (the input schema without `le`).
    output_schema: SchemaRef,
    /// Index for field column in the schema of input.
    field_column_index: usize,
    /// Index for the timestamp column in the schema of input; only used to
    /// build the required input ordering.
    ts_column_index: usize,
    /// The quantile to estimate for every bucket group.
    quantile: f64,
    /// Metrics recorded by the streams this plan executes.
    metric: ExecutionPlanMetricsSet,
    /// Cached plan properties (equivalence, partitioning, emission, boundedness).
    properties: PlanProperties,
}
261
262impl ExecutionPlan for HistogramFoldExec {
263    fn as_any(&self) -> &dyn Any {
264        self
265    }
266
267    fn properties(&self) -> &PlanProperties {
268        &self.properties
269    }
270
271    fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
272        let mut cols = self
273            .tag_col_exprs()
274            .into_iter()
275            .map(|expr| PhysicalSortRequirement {
276                expr,
277                options: None,
278            })
279            .collect::<Vec<PhysicalSortRequirement>>();
280        // add ts
281        cols.push(PhysicalSortRequirement {
282            expr: Arc::new(PhyColumn::new(
283                self.input.schema().field(self.ts_column_index).name(),
284                self.ts_column_index,
285            )),
286            options: None,
287        });
288        // add le ASC
289        cols.push(PhysicalSortRequirement {
290            expr: Arc::new(PhyCast::new(
291                Arc::new(PhyColumn::new(
292                    self.input.schema().field(self.le_column_index).name(),
293                    self.le_column_index,
294                )),
295                DataType::Float64,
296                None,
297            )),
298            options: Some(SortOptions {
299                descending: false,  // +INF in the last
300                nulls_first: false, // not nullable
301            }),
302        });
303
304        // Safety: `cols` is not empty
305        let requirement = LexRequirement::new(cols).unwrap();
306
307        vec![Some(OrderingRequirements::Hard(vec![requirement]))]
308    }
309
310    fn required_input_distribution(&self) -> Vec<Distribution> {
311        self.input.required_input_distribution()
312    }
313
314    fn maintains_input_order(&self) -> Vec<bool> {
315        vec![true; self.children().len()]
316    }
317
318    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
319        vec![&self.input]
320    }
321
322    // cannot change schema with this method
323    fn with_new_children(
324        self: Arc<Self>,
325        children: Vec<Arc<dyn ExecutionPlan>>,
326    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
327        assert!(!children.is_empty());
328        Ok(Arc::new(Self {
329            input: children[0].clone(),
330            metric: self.metric.clone(),
331            le_column_index: self.le_column_index,
332            ts_column_index: self.ts_column_index,
333            quantile: self.quantile,
334            output_schema: self.output_schema.clone(),
335            field_column_index: self.field_column_index,
336            properties: self.properties.clone(),
337        }))
338    }
339
340    fn execute(
341        &self,
342        partition: usize,
343        context: Arc<TaskContext>,
344    ) -> DataFusionResult<SendableRecordBatchStream> {
345        let baseline_metric = BaselineMetrics::new(&self.metric, partition);
346
347        let batch_size = context.session_config().batch_size();
348        let input = self.input.execute(partition, context)?;
349        let output_schema = self.output_schema.clone();
350
351        let mut normal_indices = (0..input.schema().fields().len()).collect::<HashSet<_>>();
352        normal_indices.remove(&self.field_column_index);
353        normal_indices.remove(&self.le_column_index);
354        Ok(Box::pin(HistogramFoldStream {
355            le_column_index: self.le_column_index,
356            field_column_index: self.field_column_index,
357            quantile: self.quantile,
358            normal_indices: normal_indices.into_iter().collect(),
359            bucket_size: None,
360            input_buffer: vec![],
361            input,
362            output_schema,
363            metric: baseline_metric,
364            batch_size,
365            input_buffered_rows: 0,
366            output_buffer: HistogramFoldStream::empty_output_buffer(
367                &self.output_schema,
368                self.le_column_index,
369            )?,
370            output_buffered_rows: 0,
371        }))
372    }
373
374    fn metrics(&self) -> Option<MetricsSet> {
375        Some(self.metric.clone_inner())
376    }
377
378    fn partition_statistics(&self, _: Option<usize>) -> DataFusionResult<Statistics> {
379        Ok(Statistics {
380            num_rows: Precision::Absent,
381            total_byte_size: Precision::Absent,
382            column_statistics: vec![
383                ColumnStatistics::new_unknown();
384                // plus one more for the removed column by function `convert_schema`
385                self.schema().flattened_fields().len() + 1
386            ],
387        })
388    }
389
390    fn name(&self) -> &str {
391        "HistogramFoldExec"
392    }
393}
394
395impl HistogramFoldExec {
396    /// Return all the [PhysicalExpr] of tag columns in order.
397    ///
398    /// Tag columns are all columns except `le`, `field` and `ts` columns.
399    pub fn tag_col_exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> {
400        self.input
401            .schema()
402            .fields()
403            .iter()
404            .enumerate()
405            .filter_map(|(idx, field)| {
406                if idx == self.le_column_index
407                    || idx == self.field_column_index
408                    || idx == self.ts_column_index
409                {
410                    None
411                } else {
412                    Some(Arc::new(PhyColumn::new(field.name(), idx)) as _)
413                }
414            })
415            .collect()
416    }
417}
418
419impl DisplayAs for HistogramFoldExec {
420    fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
421        match t {
422            DisplayFormatType::Default
423            | DisplayFormatType::Verbose
424            | DisplayFormatType::TreeRender => {
425                write!(
426                    f,
427                    "HistogramFoldExec: le=@{}, field=@{}, quantile={}",
428                    self.le_column_index, self.field_column_index, self.quantile
429                )
430            }
431        }
432    }
433}
434
/// Stream produced by [`HistogramFoldExec::execute`]: buffers input rows and folds
/// every complete bucket group into one output row.
pub struct HistogramFoldStream {
    // internal states
    /// Index of the `le` column in the input schema.
    le_column_index: usize,
    /// Index of the field (counter) column in the input schema.
    field_column_index: usize,
    /// The quantile to estimate for every bucket group.
    quantile: f64,
    /// Columns that don't need folding; they are sampled from the first row of
    /// each group. These indices are based on the input schema.
    normal_indices: Vec<usize>,
    /// Optional cached number of buckets per group.
    /// NOTE(review): it is initialised to `None` and never assigned in this file,
    /// so the bucket count is recomputed for every input batch.
    bucket_size: Option<usize>,
    /// Expected output batch size
    batch_size: usize,
    /// Schema of the emitted batches.
    output_schema: SchemaRef,

    // buffers
    /// Input rows not folded yet (after a fold, only an incomplete group remains).
    input_buffer: Vec<RecordBatch>,
    /// Total number of rows held in `input_buffer`.
    input_buffered_rows: usize,
    /// Builders for the output columns, with a placeholder at `le_column_index`
    /// that is removed before a batch is emitted.
    output_buffer: Vec<Box<dyn MutableVector>>,
    /// Number of folded rows held in `output_buffer`.
    output_buffered_rows: usize,

    // runtime things
    /// Upstream stream of un-folded rows.
    input: SendableRecordBatchStream,
    /// Baseline metrics (compute time, poll results) of this stream.
    metric: BaselineMetrics,
}
457
458impl RecordBatchStream for HistogramFoldStream {
459    fn schema(&self) -> SchemaRef {
460        self.output_schema.clone()
461    }
462}
463
464impl Stream for HistogramFoldStream {
465    type Item = DataFusionResult<RecordBatch>;
466
467    fn poll_next(
468        mut self: std::pin::Pin<&mut Self>,
469        cx: &mut std::task::Context<'_>,
470    ) -> Poll<Option<Self::Item>> {
471        let poll = loop {
472            match ready!(self.input.poll_next_unpin(cx)) {
473                Some(batch) => {
474                    let batch = batch?;
475                    let timer = Instant::now();
476                    let Some(result) = self.fold_input(batch)? else {
477                        self.metric.elapsed_compute().add_elapsed(timer);
478                        continue;
479                    };
480                    self.metric.elapsed_compute().add_elapsed(timer);
481                    break Poll::Ready(Some(result));
482                }
483                None => break Poll::Ready(self.take_output_buf()?.map(Ok)),
484            }
485        };
486        self.metric.record_poll(poll)
487    }
488}
489
490impl HistogramFoldStream {
491    /// The inner most `Result` is for `poll_next()`
492    pub fn fold_input(
493        &mut self,
494        input: RecordBatch,
495    ) -> DataFusionResult<Option<DataFusionResult<RecordBatch>>> {
496        let Some(bucket_num) = self.calculate_bucket_num(&input)? else {
497            return Ok(None);
498        };
499
500        if self.input_buffered_rows + input.num_rows() < bucket_num {
501            // not enough rows to fold
502            self.push_input_buf(input);
503            return Ok(None);
504        }
505
506        self.fold_buf(bucket_num, input)?;
507        if self.output_buffered_rows >= self.batch_size {
508            return Ok(self.take_output_buf()?.map(Ok));
509        }
510
511        Ok(None)
512    }
513
514    /// Generate a group of empty [MutableVector]s from the output schema.
515    ///
516    /// For simplicity, this method will insert a placeholder for `le`. So that
517    /// the output buffers has the same schema with input. This placeholder needs
518    /// to be removed before returning the output batch.
519    pub fn empty_output_buffer(
520        schema: &SchemaRef,
521        le_column_index: usize,
522    ) -> DataFusionResult<Vec<Box<dyn MutableVector>>> {
523        let mut builders = Vec::with_capacity(schema.fields().len() + 1);
524        for field in schema.fields() {
525            let concrete_datatype = ConcreteDataType::try_from(field.data_type()).unwrap();
526            let mutable_vector = concrete_datatype.create_mutable_vector(0);
527            builders.push(mutable_vector);
528        }
529        builders.insert(
530            le_column_index,
531            ConcreteDataType::float64_datatype().create_mutable_vector(0),
532        );
533
534        Ok(builders)
535    }
536
537    fn calculate_bucket_num(&mut self, batch: &RecordBatch) -> DataFusionResult<Option<usize>> {
538        if let Some(size) = self.bucket_size {
539            return Ok(Some(size));
540        }
541
542        let inf_pos = self.find_positive_inf(batch)?;
543        if inf_pos == batch.num_rows() {
544            // no positive inf found, append to buffer and wait for next batch
545            self.push_input_buf(batch.clone());
546            return Ok(None);
547        }
548
549        // else we found the positive inf.
550        // calculate the bucket size
551        let bucket_size = inf_pos + self.input_buffered_rows + 1;
552        Ok(Some(bucket_size))
553    }
554
555    /// Fold record batches from input buffer and put to output buffer
556    fn fold_buf(&mut self, bucket_num: usize, input: RecordBatch) -> DataFusionResult<()> {
557        self.push_input_buf(input);
558        // TODO(ruihang): this concat is avoidable.
559        let batch = concat_batches(&self.input.schema(), self.input_buffer.drain(..).as_ref())?;
560        let mut remaining_rows = self.input_buffered_rows;
561        let mut cursor = 0;
562
563        let gt_schema = GtSchema::try_from(self.input.schema()).unwrap();
564        let batch = GtRecordBatch::try_from_df_record_batch(Arc::new(gt_schema), batch).unwrap();
565
566        while remaining_rows >= bucket_num {
567            // "sample" normal columns
568            for normal_index in &self.normal_indices {
569                let val = batch.column(*normal_index).get(cursor);
570                self.output_buffer[*normal_index].push_value_ref(val.as_value_ref());
571            }
572            // "fold" `le` and field columns
573            let le_array = batch.column(self.le_column_index);
574            let field_array = batch.column(self.field_column_index);
575            let mut bucket = vec![];
576            let mut counters = vec![];
577            for bias in 0..bucket_num {
578                let le_str_val = le_array.get(cursor + bias);
579                let le_str_val_ref = le_str_val.as_value_ref();
580                let le_str = le_str_val_ref
581                    .as_string()
582                    .unwrap()
583                    .expect("le column should not be nullable");
584                let le = le_str.parse::<f64>().unwrap();
585                bucket.push(le);
586
587                let counter = field_array
588                    .get(cursor + bias)
589                    .as_value_ref()
590                    .as_f64()
591                    .unwrap()
592                    .expect("field column should not be nullable");
593                counters.push(counter);
594            }
595            // ignore invalid data
596            let result = Self::evaluate_row(self.quantile, &bucket, &counters).unwrap_or(f64::NAN);
597            self.output_buffer[self.field_column_index].push_value_ref(ValueRef::from(result));
598            cursor += bucket_num;
599            remaining_rows -= bucket_num;
600            self.output_buffered_rows += 1;
601        }
602
603        let remaining_input_batch = batch.into_df_record_batch().slice(cursor, remaining_rows);
604        self.input_buffered_rows = remaining_input_batch.num_rows();
605        self.input_buffer.push(remaining_input_batch);
606
607        Ok(())
608    }
609
610    fn push_input_buf(&mut self, batch: RecordBatch) {
611        self.input_buffered_rows += batch.num_rows();
612        self.input_buffer.push(batch);
613    }
614
615    /// Compute result from output buffer
616    fn take_output_buf(&mut self) -> DataFusionResult<Option<RecordBatch>> {
617        if self.output_buffered_rows == 0 {
618            if self.input_buffered_rows != 0 {
619                warn!(
620                    "input buffer is not empty, {} rows remaining",
621                    self.input_buffered_rows
622                );
623            }
624            return Ok(None);
625        }
626
627        let mut output_buf = Self::empty_output_buffer(&self.output_schema, self.le_column_index)?;
628        std::mem::swap(&mut self.output_buffer, &mut output_buf);
629        let mut columns = Vec::with_capacity(output_buf.len());
630        for builder in output_buf.iter_mut() {
631            columns.push(builder.to_vector().to_arrow_array());
632        }
633        // remove the placeholder column for `le`
634        columns.remove(self.le_column_index);
635
636        self.output_buffered_rows = 0;
637        RecordBatch::try_new(self.output_schema.clone(), columns)
638            .map(Some)
639            .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))
640    }
641
642    /// Find the first `+Inf` which indicates the end of the bucket group
643    ///
644    /// If the return value equals to batch's num_rows means the it's not found
645    /// in this batch
646    fn find_positive_inf(&self, batch: &RecordBatch) -> DataFusionResult<usize> {
647        // fuse this function. It should not be called when the
648        // bucket size is already know.
649        if let Some(bucket_size) = self.bucket_size {
650            return Ok(bucket_size);
651        }
652        let string_le_array = batch.column(self.le_column_index);
653        let float_le_array = compute::cast(&string_le_array, &DataType::Float64).map_err(|e| {
654            DataFusionError::Execution(format!(
655                "cannot cast {} array to float64 array: {:?}",
656                string_le_array.data_type(),
657                e
658            ))
659        })?;
660        let le_as_f64_array = float_le_array
661            .as_primitive_opt::<Float64Type>()
662            .ok_or_else(|| {
663                DataFusionError::Execution(format!(
664                    "expect a float64 array, but found {}",
665                    float_le_array.data_type()
666                ))
667            })?;
668        for (i, v) in le_as_f64_array.iter().enumerate() {
669            if let Some(v) = v
670                && v == f64::INFINITY
671            {
672                return Ok(i);
673            }
674        }
675
676        Ok(batch.num_rows())
677    }
678
679    /// Evaluate the field column and return the result
680    fn evaluate_row(quantile: f64, bucket: &[f64], counter: &[f64]) -> DataFusionResult<f64> {
681        // check bucket
682        if bucket.len() <= 1 {
683            return Ok(f64::NAN);
684        }
685        if bucket.last().unwrap().is_finite() {
686            return Err(DataFusionError::Execution(
687                "last bucket should be +Inf".to_string(),
688            ));
689        }
690        if bucket.len() != counter.len() {
691            return Err(DataFusionError::Execution(
692                "bucket and counter should have the same length".to_string(),
693            ));
694        }
695        // check quantile
696        if quantile < 0.0 {
697            return Ok(f64::NEG_INFINITY);
698        } else if quantile > 1.0 {
699            return Ok(f64::INFINITY);
700        } else if quantile.is_nan() {
701            return Ok(f64::NAN);
702        }
703
704        // check input value
705        debug_assert!(bucket.windows(2).all(|w| w[0] <= w[1]), "{bucket:?}");
706        debug_assert!(counter.windows(2).all(|w| w[0] <= w[1]), "{counter:?}");
707
708        let total = *counter.last().unwrap();
709        let expected_pos = total * quantile;
710        let mut fit_bucket_pos = 0;
711        while fit_bucket_pos < bucket.len() && counter[fit_bucket_pos] < expected_pos {
712            fit_bucket_pos += 1;
713        }
714        if fit_bucket_pos >= bucket.len() - 1 {
715            Ok(bucket[bucket.len() - 2])
716        } else {
717            let upper_bound = bucket[fit_bucket_pos];
718            let upper_count = counter[fit_bucket_pos];
719            let mut lower_bound = bucket[0].min(0.0);
720            let mut lower_count = 0.0;
721            if fit_bucket_pos > 0 {
722                lower_bound = bucket[fit_bucket_pos - 1];
723                lower_count = counter[fit_bucket_pos - 1];
724            }
725            Ok(lower_bound
726                + (upper_bound - lower_bound) / (upper_count - lower_count)
727                    * (expected_pos - lower_count))
728        }
729    }
730}
731
732#[cfg(test)]
733mod test {
734    use std::sync::Arc;
735
736    use datafusion::arrow::array::Float64Array;
737    use datafusion::arrow::datatypes::{Field, Schema};
738    use datafusion::common::ToDFSchema;
739    use datafusion::datasource::memory::MemorySourceConfig;
740    use datafusion::datasource::source::DataSourceExec;
741    use datafusion::prelude::SessionContext;
742    use datatypes::arrow_array::StringArray;
743
744    use super::*;
745
746    fn prepare_test_data() -> DataSourceExec {
747        let schema = Arc::new(Schema::new(vec![
748            Field::new("host", DataType::Utf8, true),
749            Field::new("le", DataType::Utf8, true),
750            Field::new("val", DataType::Float64, true),
751        ]));
752
753        // 12 items
754        let host_column_1 = Arc::new(StringArray::from(vec![
755            "host_1", "host_1", "host_1", "host_1", "host_1", "host_1", "host_1", "host_1",
756            "host_1", "host_1", "host_1", "host_1",
757        ])) as _;
758        let le_column_1 = Arc::new(StringArray::from(vec![
759            "0.001", "0.1", "10", "1000", "+Inf", "0.001", "0.1", "10", "1000", "+inf", "0.001",
760            "0.1",
761        ])) as _;
762        let val_column_1 = Arc::new(Float64Array::from(vec![
763            0_0.0, 1.0, 1.0, 5.0, 5.0, 0_0.0, 20.0, 60.0, 70.0, 100.0, 0_1.0, 1.0,
764        ])) as _;
765
766        // 2 items
767        let host_column_2 = Arc::new(StringArray::from(vec!["host_1", "host_1"])) as _;
768        let le_column_2 = Arc::new(StringArray::from(vec!["10", "1000"])) as _;
769        let val_column_2 = Arc::new(Float64Array::from(vec![1.0, 1.0])) as _;
770
771        // 11 items
772        let host_column_3 = Arc::new(StringArray::from(vec![
773            "host_1", "host_2", "host_2", "host_2", "host_2", "host_2", "host_2", "host_2",
774            "host_2", "host_2", "host_2",
775        ])) as _;
776        let le_column_3 = Arc::new(StringArray::from(vec![
777            "+INF", "0.001", "0.1", "10", "1000", "+iNf", "0.001", "0.1", "10", "1000", "+Inf",
778        ])) as _;
779        let val_column_3 = Arc::new(Float64Array::from(vec![
780            1.0, 0_0.0, 0.0, 0.0, 0.0, 0.0, 0_0.0, 1.0, 2.0, 3.0, 4.0,
781        ])) as _;
782
783        let data_1 = RecordBatch::try_new(
784            schema.clone(),
785            vec![host_column_1, le_column_1, val_column_1],
786        )
787        .unwrap();
788        let data_2 = RecordBatch::try_new(
789            schema.clone(),
790            vec![host_column_2, le_column_2, val_column_2],
791        )
792        .unwrap();
793        let data_3 = RecordBatch::try_new(
794            schema.clone(),
795            vec![host_column_3, le_column_3, val_column_3],
796        )
797        .unwrap();
798
799        DataSourceExec::new(Arc::new(
800            MemorySourceConfig::try_new(&[vec![data_1, data_2, data_3]], schema, None).unwrap(),
801        ))
802    }
803
804    #[tokio::test]
805    async fn fold_overall() {
806        let memory_exec = Arc::new(prepare_test_data());
807        let output_schema: SchemaRef = Arc::new(
808            (*HistogramFold::convert_schema(
809                &Arc::new(memory_exec.schema().to_dfschema().unwrap()),
810                "le",
811            )
812            .unwrap()
813            .as_ref())
814            .clone()
815            .into(),
816        );
817        let properties = PlanProperties::new(
818            EquivalenceProperties::new(output_schema.clone()),
819            Partitioning::UnknownPartitioning(1),
820            EmissionType::Incremental,
821            Boundedness::Bounded,
822        );
823        let fold_exec = Arc::new(HistogramFoldExec {
824            le_column_index: 1,
825            field_column_index: 2,
826            quantile: 0.4,
827            ts_column_index: 9999, // not exist but doesn't matter
828            input: memory_exec,
829            output_schema,
830            metric: ExecutionPlanMetricsSet::new(),
831            properties,
832        });
833
834        let session_context = SessionContext::default();
835        let result = datafusion::physical_plan::collect(fold_exec, session_context.task_ctx())
836            .await
837            .unwrap();
838        let result_literal = datatypes::arrow::util::pretty::pretty_format_batches(&result)
839            .unwrap()
840            .to_string();
841
842        let expected = String::from(
843            "+--------+-------------------+
844| host   | val               |
845+--------+-------------------+
846| host_1 | 257.5             |
847| host_1 | 5.05              |
848| host_1 | 0.0004            |
849| host_2 | NaN               |
850| host_2 | 6.040000000000001 |
851+--------+-------------------+",
852        );
853        assert_eq!(result_literal, expected);
854    }
855
856    #[test]
857    fn confirm_schema() {
858        let input_schema = Schema::new(vec![
859            Field::new("host", DataType::Utf8, true),
860            Field::new("le", DataType::Utf8, true),
861            Field::new("val", DataType::Float64, true),
862        ])
863        .to_dfschema_ref()
864        .unwrap();
865        let expected_output_schema = Schema::new(vec![
866            Field::new("host", DataType::Utf8, true),
867            Field::new("val", DataType::Float64, true),
868        ])
869        .to_dfschema_ref()
870        .unwrap();
871
872        let actual = HistogramFold::convert_schema(&input_schema, "le").unwrap();
873        assert_eq!(actual, expected_output_schema)
874    }
875
876    #[test]
877    fn evaluate_row_normal_case() {
878        let bucket = [0.0, 1.0, 2.0, 3.0, 4.0, f64::INFINITY];
879
880        #[derive(Debug)]
881        struct Case {
882            quantile: f64,
883            counters: Vec<f64>,
884            expected: f64,
885        }
886
887        let cases = [
888            Case {
889                quantile: 0.9,
890                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
891                expected: 4.0,
892            },
893            Case {
894                quantile: 0.89,
895                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
896                expected: 4.0,
897            },
898            Case {
899                quantile: 0.78,
900                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
901                expected: 3.9,
902            },
903            Case {
904                quantile: 0.5,
905                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
906                expected: 2.5,
907            },
908            Case {
909                quantile: 0.5,
910                counters: vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
911                expected: f64::NAN,
912            },
913            Case {
914                quantile: 1.0,
915                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
916                expected: 4.0,
917            },
918            Case {
919                quantile: 0.0,
920                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
921                expected: f64::NAN,
922            },
923            Case {
924                quantile: 1.1,
925                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
926                expected: f64::INFINITY,
927            },
928            Case {
929                quantile: -1.0,
930                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
931                expected: f64::NEG_INFINITY,
932            },
933        ];
934
935        for case in cases {
936            let actual =
937                HistogramFoldStream::evaluate_row(case.quantile, &bucket, &case.counters).unwrap();
938            assert_eq!(
939                format!("{actual}"),
940                format!("{}", case.expected),
941                "{:?}",
942                case
943            );
944        }
945    }
946
947    #[test]
948    #[should_panic]
949    fn evaluate_out_of_order_input() {
950        let bucket = [0.0, 1.0, 2.0, 3.0, 4.0, f64::INFINITY];
951        let counters = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0];
952        HistogramFoldStream::evaluate_row(0.5, &bucket, &counters).unwrap();
953    }
954
955    #[test]
956    fn evaluate_wrong_bucket() {
957        let bucket = [0.0, 1.0, 2.0, 3.0, 4.0, f64::INFINITY, 5.0];
958        let counters = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
959        let result = HistogramFoldStream::evaluate_row(0.5, &bucket, &counters);
960        assert!(result.is_err());
961    }
962
963    #[test]
964    fn evaluate_small_fraction() {
965        let bucket = [0.0, 2.0, 4.0, 6.0, f64::INFINITY];
966        let counters = [0.0, 1.0 / 300.0, 2.0 / 300.0, 0.01, 0.01];
967        let result = HistogramFoldStream::evaluate_row(0.5, &bucket, &counters).unwrap();
968        assert_eq!(3.0, result);
969    }
970}