promql/extension_plan/
histogram_fold.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::any::Any;
use std::borrow::Cow;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::task::Poll;
use std::time::Instant;

use common_telemetry::warn;
use datafusion::arrow::array::{Array, AsArray, StringArray};
use datafusion::arrow::compute::{SortOptions, concat_batches};
use datafusion::arrow::datatypes::{DataType, Float64Type, SchemaRef};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::common::stats::Precision;
use datafusion::common::{DFSchema, DFSchemaRef, Statistics};
use datafusion::error::{DataFusionError, Result as DataFusionResult};
use datafusion::execution::TaskContext;
use datafusion::logical_expr::{LogicalPlan, UserDefinedLogicalNodeCore};
use datafusion::physical_expr::{
    EquivalenceProperties, LexRequirement, OrderingRequirements, PhysicalSortRequirement,
};
use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
use datafusion::physical_plan::expressions::{CastExpr as PhyCast, Column as PhyColumn};
use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use datafusion::physical_plan::{
    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
    Partitioning, PhysicalExpr, PlanProperties, RecordBatchStream, SendableRecordBatchStream,
};
use datafusion::prelude::{Column, Expr};
use datafusion_expr::{EmptyRelation, col};
use datatypes::prelude::{ConcreteDataType, DataType as GtDataType};
use datatypes::value::{OrderedF64, Value, ValueRef};
use datatypes::vectors::{Helper, MutableVector, VectorRef};
use futures::{Stream, StreamExt, ready};
use greptime_proto::substrait_extension as pb;
use prost::Message;
use snafu::ResultExt;

use crate::error::{DeserializeSnafu, Result};
use crate::extension_plan::{resolve_column_name, serialize_column_index};
/// `HistogramFold` folds the conventional (non-native) histogram ([1]) so that
/// the requested quantile can be computed from it.
///
/// Specifically, it groups the input rows by the tag columns, folds each
/// group's `le` and `field` columns into bucket bounds and cumulative counters
/// (each `le` bound is parsed from its string form into [f64]), and evaluates
/// the quantile over them. The other columns are sampled every `bucket_num`
/// rows, and their types won't change.
///
/// Due to the folding and sampling, the number of output rows becomes
/// `input_rows` / `bucket_num`.
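///
/// For example (illustrative): with buckets `0.1`, `1` and `+Inf` and
/// `quantile = 0.5`, the three input rows `(host_1, "0.1", 1.0)`,
/// `(host_1, "1", 3.0)`, `(host_1, "+Inf", 4.0)` fold into one output row.
/// The target rank is `4 * 0.5 = 2`, which falls in the `(0.1, 1]` bucket,
/// so the row's value is `0.1 + (1 - 0.1) / (3 - 1) * (2 - 1) = 0.55`.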
///
/// # Requirement
/// - Input should be sorted on `<tag list>, ts, le ASC`.
/// - The value set of `le` should be the same, i.e., every series should have
///   the same buckets.
///
/// [1]: https://prometheus.io/docs/concepts/metric_types/#histogram
#[derive(Debug, PartialEq, Hash, Eq)]
pub struct HistogramFold {
    /// Name of the `le` column. It's a special column in Prometheus
    /// used to implement the conventional histogram: a string column
    /// whose values are "literal" floats, like "+Inf", "0.001", etc.
    le_column: String,
    ts_column: String,
    input: LogicalPlan,
    field_column: String,
    quantile: OrderedF64,
    output_schema: DFSchemaRef,
    unfix: Option<UnfixIndices>,
}

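/// Column indices recorded by [HistogramFold::serialize]. They are resolved
/// back to column names when the plan is re-attached to a real input in
/// `with_exprs_and_inputs`.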
#[derive(Debug, PartialEq, Eq, Hash, PartialOrd)]
struct UnfixIndices {
    pub le_column_idx: u64,
    pub ts_column_idx: u64,
    pub field_column_idx: u64,
}

impl UserDefinedLogicalNodeCore for HistogramFold {
    fn name(&self) -> &str {
        Self::name()
    }

    fn inputs(&self) -> Vec<&LogicalPlan> {
        vec![&self.input]
    }

    fn schema(&self) -> &DFSchemaRef {
        &self.output_schema
    }

    fn expressions(&self) -> Vec<Expr> {
        if self.unfix.is_some() {
            return vec![];
        }

        let mut exprs = vec![
            col(&self.le_column),
            col(&self.ts_column),
            col(&self.field_column),
        ];
        exprs.extend(self.input.schema().fields().iter().filter_map(|f| {
            let name = f.name();
            if name != &self.le_column && name != &self.ts_column && name != &self.field_column {
                Some(col(name))
            } else {
                None
            }
        }));
        exprs
    }

    fn necessary_children_exprs(&self, output_columns: &[usize]) -> Option<Vec<Vec<usize>>> {
        if self.unfix.is_some() {
            return None;
        }

        let input_schema = self.input.schema();
        let le_column_index = input_schema.index_of_column_by_name(None, &self.le_column)?;

        if output_columns.is_empty() {
            let indices = (0..input_schema.fields().len()).collect::<Vec<_>>();
            return Some(vec![indices]);
        }

        let mut necessary_indices = output_columns
            .iter()
            .map(|&output_column| {
                if output_column < le_column_index {
                    output_column
                } else {
                    output_column + 1
                }
            })
            .collect::<Vec<_>>();
        necessary_indices.push(le_column_index);
        necessary_indices.sort_unstable();
        necessary_indices.dedup();
        Some(vec![necessary_indices])
    }

    fn fmt_for_explain(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "HistogramFold: le={}, field={}, quantile={}",
            self.le_column, self.field_column, self.quantile
        )
    }

    fn with_exprs_and_inputs(
        &self,
        _exprs: Vec<Expr>,
        inputs: Vec<LogicalPlan>,
    ) -> DataFusionResult<Self> {
        if inputs.is_empty() {
            return Err(DataFusionError::Internal(
                "HistogramFold must have at least one input".to_string(),
            ));
        }

        let input: LogicalPlan = inputs.into_iter().next().unwrap();
        let input_schema = input.schema();

        if let Some(unfix) = &self.unfix {
            let le_column =
                resolve_column_name(unfix.le_column_idx, input_schema, "HistogramFold", "le")?;
            let ts_column =
                resolve_column_name(unfix.ts_column_idx, input_schema, "HistogramFold", "ts")?;
            let field_column = resolve_column_name(
                unfix.field_column_idx,
                input_schema,
                "HistogramFold",
                "field",
            )?;

            let output_schema = Self::convert_schema(input_schema, &le_column)?;

            Ok(Self {
                le_column,
                ts_column,
                input,
                field_column,
                quantile: self.quantile,
                output_schema,
                unfix: None,
            })
        } else {
            Ok(Self {
                le_column: self.le_column.clone(),
                ts_column: self.ts_column.clone(),
                input,
                field_column: self.field_column.clone(),
                quantile: self.quantile,
                output_schema: self.output_schema.clone(),
                unfix: None,
            })
        }
    }
}

impl HistogramFold {
    pub fn new(
        le_column: String,
        field_column: String,
        ts_column: String,
        quantile: f64,
        input: LogicalPlan,
    ) -> DataFusionResult<Self> {
        let input_schema = input.schema();
        Self::check_schema(input_schema, &le_column, &field_column, &ts_column)?;
        let output_schema = Self::convert_schema(input_schema, &le_column)?;
        Ok(Self {
            le_column,
            ts_column,
            input,
            field_column,
            quantile: quantile.into(),
            output_schema,
            unfix: None,
        })
    }

    pub const fn name() -> &'static str {
        "HistogramFold"
    }

    fn check_schema(
        input_schema: &DFSchemaRef,
        le_column: &str,
        field_column: &str,
        ts_column: &str,
    ) -> DataFusionResult<()> {
        let check_column = |col| {
            if !input_schema.has_column_with_unqualified_name(col) {
                Err(DataFusionError::SchemaError(
                    Box::new(datafusion::common::SchemaError::FieldNotFound {
                        field: Box::new(Column::new(None::<String>, col)),
                        valid_fields: input_schema.columns(),
                    }),
                    Box::new(None),
                ))
            } else {
                Ok(())
            }
        };

        check_column(le_column)?;
        check_column(ts_column)?;
        check_column(field_column)
    }

    pub fn to_execution_plan(&self, exec_input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
        let input_schema = self.input.schema();
        // safety: those fields are checked in `check_schema()`
        let le_column_index = input_schema
            .index_of_column_by_name(None, &self.le_column)
            .unwrap();
        let field_column_index = input_schema
            .index_of_column_by_name(None, &self.field_column)
            .unwrap();
        let ts_column_index = input_schema
            .index_of_column_by_name(None, &self.ts_column)
            .unwrap();

        let tag_columns = exec_input
            .schema()
            .fields()
            .iter()
            .enumerate()
            .filter_map(|(idx, field)| {
                if idx == le_column_index || idx == field_column_index || idx == ts_column_index {
                    None
                } else {
                    Some(Arc::new(PhyColumn::new(field.name(), idx)) as _)
                }
            })
            .collect::<Vec<_>>();

        let mut partition_exprs = tag_columns.clone();
        partition_exprs.push(Arc::new(PhyColumn::new(
            self.input.schema().field(ts_column_index).name(),
            ts_column_index,
        )) as _);

        let output_schema: SchemaRef = self.output_schema.inner().clone();
        let properties = PlanProperties::new(
            EquivalenceProperties::new(output_schema.clone()),
            Partitioning::Hash(
                partition_exprs.clone(),
                exec_input.output_partitioning().partition_count(),
            ),
            EmissionType::Incremental,
            Boundedness::Bounded,
        );
        Arc::new(HistogramFoldExec {
            le_column_index,
            field_column_index,
            ts_column_index,
            input: exec_input,
            tag_columns,
            partition_exprs,
            quantile: self.quantile.into(),
            output_schema,
            metric: ExecutionPlanMetricsSet::new(),
            properties,
        })
    }

    /// Transform the schema
    ///
    /// - `le` will be removed
    fn convert_schema(
        input_schema: &DFSchemaRef,
        le_column: &str,
    ) -> DataFusionResult<DFSchemaRef> {
        let fields = input_schema.fields();
        // safety: those fields are checked in `check_schema()`
        let mut new_fields = Vec::with_capacity(fields.len() - 1);
        for f in fields {
            if f.name() != le_column {
                new_fields.push((None, f.clone()));
            }
        }
        Ok(Arc::new(DFSchema::new_with_metadata(
            new_fields,
            HashMap::new(),
        )?))
    }

    pub fn serialize(&self) -> Vec<u8> {
        let le_column_idx = serialize_column_index(self.input.schema(), &self.le_column);
        let ts_column_idx = serialize_column_index(self.input.schema(), &self.ts_column);
        let field_column_idx = serialize_column_index(self.input.schema(), &self.field_column);

        pb::HistogramFold {
            le_column_idx,
            ts_column_idx,
            field_column_idx,
            quantile: self.quantile.into(),
        }
        .encode_to_vec()
    }

    pub fn deserialize(bytes: &[u8]) -> Result<Self> {
        let pb_histogram_fold = pb::HistogramFold::decode(bytes).context(DeserializeSnafu)?;
        let placeholder_plan = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: Arc::new(DFSchema::empty()),
        });

        let unfix = UnfixIndices {
            le_column_idx: pb_histogram_fold.le_column_idx,
            ts_column_idx: pb_histogram_fold.ts_column_idx,
            field_column_idx: pb_histogram_fold.field_column_idx,
        };

        Ok(Self {
            le_column: String::new(),
            ts_column: String::new(),
            input: placeholder_plan,
            field_column: String::new(),
            quantile: pb_histogram_fold.quantile.into(),
            output_schema: Arc::new(DFSchema::empty()),
            unfix: Some(unfix),
        })
    }
}

impl PartialOrd for HistogramFold {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        // Compare fields in order excluding output_schema
        match self.le_column.partial_cmp(&other.le_column) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        match self.ts_column.partial_cmp(&other.ts_column) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        match self.input.partial_cmp(&other.input) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        match self.field_column.partial_cmp(&other.field_column) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        self.quantile.partial_cmp(&other.quantile)
    }
}

#[derive(Debug)]
pub struct HistogramFoldExec {
    /// Index for `le` column in the schema of input.
    le_column_index: usize,
    input: Arc<dyn ExecutionPlan>,
    output_schema: SchemaRef,
    /// Index for field column in the schema of input.
    field_column_index: usize,
    ts_column_index: usize,
    /// Tag columns are all columns except `le`, `field` and `ts` columns.
    tag_columns: Vec<Arc<dyn PhysicalExpr>>,
    partition_exprs: Vec<Arc<dyn PhysicalExpr>>,
    quantile: f64,
    metric: ExecutionPlanMetricsSet,
    properties: PlanProperties,
}

impl ExecutionPlan for HistogramFoldExec {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
        let mut cols = self
            .tag_columns
            .iter()
            .map(|expr| PhysicalSortRequirement {
                expr: expr.clone(),
                options: None,
            })
            .collect::<Vec<PhysicalSortRequirement>>();
        // add ts
        cols.push(PhysicalSortRequirement {
            expr: Arc::new(PhyColumn::new(
                self.input.schema().field(self.ts_column_index).name(),
                self.ts_column_index,
            )),
            options: None,
        });
        // add le ASC
        cols.push(PhysicalSortRequirement {
            expr: Arc::new(PhyCast::new(
                Arc::new(PhyColumn::new(
                    self.input.schema().field(self.le_column_index).name(),
                    self.le_column_index,
                )),
                DataType::Float64,
                None,
            )),
            options: Some(SortOptions {
                descending: false,  // +Inf sorts last
                nulls_first: false, // not nullable
            }),
        });

        // Safety: `cols` is not empty
        let requirement = LexRequirement::new(cols).unwrap();

        vec![Some(OrderingRequirements::Hard(vec![requirement]))]
    }

    fn required_input_distribution(&self) -> Vec<Distribution> {
        vec![Distribution::HashPartitioned(self.partition_exprs.clone())]
    }

    fn maintains_input_order(&self) -> Vec<bool> {
        vec![true; self.children().len()]
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    // cannot change schema with this method
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
        assert!(!children.is_empty());
        let new_input = children[0].clone();
        let properties = PlanProperties::new(
            EquivalenceProperties::new(self.output_schema.clone()),
            Partitioning::Hash(
                self.partition_exprs.clone(),
                new_input.output_partitioning().partition_count(),
            ),
            EmissionType::Incremental,
            Boundedness::Bounded,
        );
        Ok(Arc::new(Self {
            input: new_input,
            metric: self.metric.clone(),
            le_column_index: self.le_column_index,
            ts_column_index: self.ts_column_index,
            tag_columns: self.tag_columns.clone(),
            partition_exprs: self.partition_exprs.clone(),
            quantile: self.quantile,
            output_schema: self.output_schema.clone(),
            field_column_index: self.field_column_index,
            properties,
        }))
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> DataFusionResult<SendableRecordBatchStream> {
        let baseline_metric = BaselineMetrics::new(&self.metric, partition);

        let batch_size = context.session_config().batch_size();
        let input = self.input.execute(partition, context)?;
        let output_schema = self.output_schema.clone();

        let mut normal_indices = (0..input.schema().fields().len()).collect::<HashSet<_>>();
        normal_indices.remove(&self.field_column_index);
        normal_indices.remove(&self.le_column_index);
        Ok(Box::pin(HistogramFoldStream {
            le_column_index: self.le_column_index,
            field_column_index: self.field_column_index,
            quantile: self.quantile,
            normal_indices: normal_indices.into_iter().collect(),
            bucket_size: None,
            input_buffer: vec![],
            input,
            output_schema,
            input_schema: self.input.schema(),
            mode: FoldMode::Optimistic,
            safe_group: None,
            metric: baseline_metric,
            batch_size,
            input_buffered_rows: 0,
            output_buffer: HistogramFoldStream::empty_output_buffer(
                &self.output_schema,
                self.le_column_index,
            )?,
            output_buffered_rows: 0,
        }))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metric.clone_inner())
    }

    fn partition_statistics(&self, _: Option<usize>) -> DataFusionResult<Statistics> {
        Ok(Statistics {
            num_rows: Precision::Absent,
            total_byte_size: Precision::Absent,
            column_statistics: Statistics::unknown_column(&self.schema()),
        })
    }

    fn name(&self) -> &str {
        "HistogramFoldExec"
    }
}

impl DisplayAs for HistogramFoldExec {
    fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default
            | DisplayFormatType::Verbose
            | DisplayFormatType::TreeRender => {
                write!(
                    f,
                    "HistogramFoldExec: le=@{}, field=@{}, quantile={}",
                    self.le_column_index, self.field_column_index, self.quantile
                )
            }
        }
    }
}

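/// How the stream folds histogram rows.
///
/// The stream starts in [FoldMode::Optimistic]: it detects the bucket count
/// from the first complete group (one ending with `+Inf`) and folds fixed-size
/// chunks of that many rows. If a chunk fails validation (misaligned buckets or
/// changing tag values mid-chunk), the stream falls back to [FoldMode::Safe],
/// which scans row by row and groups on tag values.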
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FoldMode {
    Optimistic,
    Safe,
}

pub struct HistogramFoldStream {
    // internal states
    le_column_index: usize,
    field_column_index: usize,
    quantile: f64,
    /// Columns that don't need folding. These indices are based on the input schema.
    normal_indices: Vec<usize>,
    bucket_size: Option<usize>,
    /// Expected output batch size
    batch_size: usize,
    output_schema: SchemaRef,
    input_schema: SchemaRef,
    mode: FoldMode,
    safe_group: Option<SafeGroup>,

    // buffers
    input_buffer: Vec<RecordBatch>,
    input_buffered_rows: usize,
    output_buffer: Vec<Box<dyn MutableVector>>,
    output_buffered_rows: usize,

    // runtime things
    input: SendableRecordBatchStream,
    metric: BaselineMetrics,
}

#[derive(Debug, Default)]
struct SafeGroup {
    tag_values: Vec<Value>,
    buckets: Vec<f64>,
    counters: Vec<f64>,
}

impl RecordBatchStream for HistogramFoldStream {
    fn schema(&self) -> SchemaRef {
        self.output_schema.clone()
    }
}

impl Stream for HistogramFoldStream {
    type Item = DataFusionResult<RecordBatch>;

    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let poll = loop {
            match ready!(self.input.poll_next_unpin(cx)) {
                Some(batch) => {
                    let batch = batch?;
                    let timer = Instant::now();
                    let Some(result) = self.fold_input(batch)? else {
                        self.metric.elapsed_compute().add_elapsed(timer);
                        continue;
                    };
                    self.metric.elapsed_compute().add_elapsed(timer);
                    break Poll::Ready(Some(result));
                }
                None => {
                    self.flush_remaining()?;
                    break Poll::Ready(self.take_output_buf()?.map(Ok));
                }
            }
        };
        self.metric.record_poll(poll)
    }
}

impl HistogramFoldStream {
    /// The innermost `Result` is for `poll_next()`.
    pub fn fold_input(
        &mut self,
        input: RecordBatch,
    ) -> DataFusionResult<Option<DataFusionResult<RecordBatch>>> {
        match self.mode {
            FoldMode::Safe => {
                self.push_input_buf(input);
                self.process_safe_mode_buffer()?;
            }
            FoldMode::Optimistic => {
                self.push_input_buf(input);
                let Some(bucket_num) = self.calculate_bucket_num_from_buffer()? else {
                    return Ok(None);
                };
                self.bucket_size = Some(bucket_num);

                if self.input_buffered_rows < bucket_num {
                    // not enough rows to fold
                    return Ok(None);
                }

                self.fold_buf(bucket_num)?;
            }
        }

        self.maybe_take_output()
    }

    /// Generate a group of empty [MutableVector]s from the output schema.
    ///
    /// For simplicity, this method inserts a placeholder for `le` so that
    /// the output buffers have the same schema as the input. This placeholder
    /// needs to be removed before returning the output batch.
    pub fn empty_output_buffer(
        schema: &SchemaRef,
        le_column_index: usize,
    ) -> DataFusionResult<Vec<Box<dyn MutableVector>>> {
        let mut builders = Vec::with_capacity(schema.fields().len() + 1);
        for field in schema.fields() {
            let concrete_datatype = ConcreteDataType::try_from(field.data_type()).unwrap();
            let mutable_vector = concrete_datatype.create_mutable_vector(0);
            builders.push(mutable_vector);
        }
        builders.insert(
            le_column_index,
            ConcreteDataType::float64_datatype().create_mutable_vector(0),
        );

        Ok(builders)
    }

    /// Determines bucket count using buffered batches, concatenating them to
    /// detect the first complete bucket that may span batch boundaries.
    fn calculate_bucket_num_from_buffer(&mut self) -> DataFusionResult<Option<usize>> {
        if let Some(size) = self.bucket_size {
            return Ok(Some(size));
        }

        if self.input_buffer.is_empty() {
            return Ok(None);
        }

        let batch_refs: Vec<&RecordBatch> = self.input_buffer.iter().collect();
        let batch = concat_batches(&self.input_schema, batch_refs)?;
        self.find_first_complete_bucket(&batch)
    }

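    /// Scan for the first `+Inf` bound and return the number of rows in that
    /// (first complete) bucket group. A change in tag values resets the group
    /// start, so a truncated leading group doesn't skew the detected count.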
    fn find_first_complete_bucket(&self, batch: &RecordBatch) -> DataFusionResult<Option<usize>> {
        if batch.num_rows() == 0 {
            return Ok(None);
        }

        let vectors = Helper::try_into_vectors(batch.columns())
            .map_err(|e| DataFusionError::Execution(e.to_string()))?;
        let le_array = batch.column(self.le_column_index).as_string::<i32>();

        let mut tag_values_buf = Vec::with_capacity(self.normal_indices.len());
        self.collect_tag_values(&vectors, 0, &mut tag_values_buf);
        let mut group_start = 0usize;

        for row in 0..batch.num_rows() {
            if !self.is_same_group(&vectors, row, &tag_values_buf) {
                // new group begins
                self.collect_tag_values(&vectors, row, &mut tag_values_buf);
                group_start = row;
            }

            if Self::is_positive_infinity(le_array, row) {
                return Ok(Some(row - group_start + 1));
            }
        }

        Ok(None)
    }

    /// Fold record batches from the input buffer and push the results to the output buffer.
    fn fold_buf(&mut self, bucket_num: usize) -> DataFusionResult<()> {
        let batch = concat_batches(&self.input_schema, self.input_buffer.drain(..).as_ref())?;
        let mut remaining_rows = self.input_buffered_rows;
        let mut cursor = 0;

        // TODO(LFC): Try to get rid of the Arrow array to vector conversion here.
        let vectors = Helper::try_into_vectors(batch.columns())
            .map_err(|e| DataFusionError::Execution(e.to_string()))?;
        let le_array = batch.column(self.le_column_index);
        let le_array = le_array.as_string::<i32>();
        let field_array = batch.column(self.field_column_index);
        let field_array = field_array.as_primitive::<Float64Type>();
        let mut tag_values_buf = Vec::with_capacity(self.normal_indices.len());

        while remaining_rows >= bucket_num && self.mode == FoldMode::Optimistic {
            self.collect_tag_values(&vectors, cursor, &mut tag_values_buf);
            if !self.validate_optimistic_group(
                &vectors,
                le_array,
                cursor,
                bucket_num,
                &tag_values_buf,
            ) {
                let remaining_input_batch = batch.slice(cursor, remaining_rows);
                self.switch_to_safe_mode(remaining_input_batch)?;
                return Ok(());
            }

            // "sample" normal columns
            for (idx, value) in self.normal_indices.iter().zip(tag_values_buf.iter()) {
                self.output_buffer[*idx].push_value_ref(value);
            }
            // "fold" `le` and field columns
            let mut bucket = Vec::with_capacity(bucket_num);
            let mut counters = Vec::with_capacity(bucket_num);
            for bias in 0..bucket_num {
                let position = cursor + bias;
                let le = if le_array.is_valid(position) {
                    le_array.value(position).parse::<f64>().unwrap_or(f64::NAN)
                } else {
                    f64::NAN
                };
                bucket.push(le);

                let counter = if field_array.is_valid(position) {
                    field_array.value(position)
                } else {
                    f64::NAN
                };
                counters.push(counter);
            }
            // ignore invalid data
            let result = Self::evaluate_row(self.quantile, &bucket, &counters).unwrap_or(f64::NAN);
            self.output_buffer[self.field_column_index].push_value_ref(&ValueRef::from(result));
            cursor += bucket_num;
            remaining_rows -= bucket_num;
            self.output_buffered_rows += 1;
        }

        let remaining_input_batch = batch.slice(cursor, remaining_rows);
        self.input_buffered_rows = remaining_input_batch.num_rows();
        if self.input_buffered_rows > 0 {
            self.input_buffer.push(remaining_input_batch);
        }

        Ok(())
    }

    fn push_input_buf(&mut self, batch: RecordBatch) {
        self.input_buffered_rows += batch.num_rows();
        self.input_buffer.push(batch);
    }

    fn maybe_take_output(&mut self) -> DataFusionResult<Option<DataFusionResult<RecordBatch>>> {
        if self.output_buffered_rows >= self.batch_size {
            return Ok(self.take_output_buf()?.map(Ok));
        }
        Ok(None)
    }

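    /// Abandon the fixed-bucket assumption: reset the detected bucket size,
    /// clear the input buffer, and re-process the remaining rows group by group.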
    fn switch_to_safe_mode(&mut self, remaining_batch: RecordBatch) -> DataFusionResult<()> {
        self.mode = FoldMode::Safe;
        self.bucket_size = None;
        self.input_buffer.clear();
        self.input_buffered_rows = remaining_batch.num_rows();

        if self.input_buffered_rows > 0 {
            self.input_buffer.push(remaining_batch);
            self.process_safe_mode_buffer()?;
        }

        Ok(())
    }

    fn collect_tag_values<'a>(
        &self,
        vectors: &'a [VectorRef],
        row: usize,
        tag_values: &mut Vec<ValueRef<'a>>,
    ) {
        tag_values.clear();
        for idx in self.normal_indices.iter() {
            tag_values.push(vectors[*idx].get_ref(row));
        }
    }

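    /// Check that the `bucket_num` rows starting at `cursor` form one complete
    /// group: the last row must be `+Inf` and every row must carry the same tag
    /// values as the first.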
    fn validate_optimistic_group(
        &self,
        vectors: &[VectorRef],
        le_array: &StringArray,
        cursor: usize,
        bucket_num: usize,
        tag_values: &[ValueRef<'_>],
    ) -> bool {
        let inf_index = cursor + bucket_num - 1;
        if !Self::is_positive_infinity(le_array, inf_index) {
            return false;
        }

        for offset in 1..bucket_num {
            let row = cursor + offset;
            for (idx, expected) in self.normal_indices.iter().zip(tag_values.iter()) {
                if vectors[*idx].get_ref(row) != *expected {
                    return false;
                }
            }
        }
        true
    }

    /// Checks whether a row belongs to the current group (same series).
    fn is_same_group(
        &self,
        vectors: &[VectorRef],
        row: usize,
        tag_values: &[ValueRef<'_>],
    ) -> bool {
        self.normal_indices
            .iter()
            .zip(tag_values.iter())
            .all(|(idx, expected)| vectors[*idx].get_ref(row) == *expected)
    }

    fn push_output_row(&mut self, tag_values: &[ValueRef<'_>], result: f64) {
        debug_assert_eq!(self.normal_indices.len(), tag_values.len());
        for (idx, value) in self.normal_indices.iter().zip(tag_values.iter()) {
            self.output_buffer[*idx].push_value_ref(value);
        }
        self.output_buffer[self.field_column_index].push_value_ref(&ValueRef::from(result));
        self.output_buffered_rows += 1;
    }

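    /// Evaluate the quantile for the accumulated group and emit one output row.
    /// Groups with fewer than two buckets, or without a trailing `+Inf` bound,
    /// produce `NaN`.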
    fn finalize_safe_group(&mut self) -> DataFusionResult<()> {
        if let Some(group) = self.safe_group.take() {
            if group.tag_values.is_empty() {
                return Ok(());
            }

            let has_inf = group
                .buckets
                .last()
                .map(|v| v.is_infinite() && v.is_sign_positive())
                .unwrap_or(false);
            let result = if group.buckets.len() < 2 || !has_inf {
                f64::NAN
            } else {
                Self::evaluate_row(self.quantile, &group.buckets, &group.counters)
                    .unwrap_or(f64::NAN)
            };
            let mut tag_value_refs = Vec::with_capacity(group.tag_values.len());
            tag_value_refs.extend(group.tag_values.iter().map(|v| v.as_value_ref()));
            self.push_output_row(&tag_value_refs, result);
        }
        Ok(())
    }

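    /// Consume the buffered input row by row, accumulating bucket bounds and
    /// counters into the current [SafeGroup]. A change in tag values finalizes
    /// the previous group and starts a new one.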
    fn process_safe_mode_buffer(&mut self) -> DataFusionResult<()> {
        if self.input_buffer.is_empty() {
            self.input_buffered_rows = 0;
            return Ok(());
        }

        let batch = concat_batches(&self.input_schema, self.input_buffer.drain(..).as_ref())?;
        self.input_buffered_rows = 0;
        let vectors = Helper::try_into_vectors(batch.columns())
            .map_err(|e| DataFusionError::Execution(e.to_string()))?;
        let le_array = batch.column(self.le_column_index).as_string::<i32>();
        let field_array = batch
            .column(self.field_column_index)
            .as_primitive::<Float64Type>();
        let mut tag_values_buf = Vec::with_capacity(self.normal_indices.len());

        for row in 0..batch.num_rows() {
            self.collect_tag_values(&vectors, row, &mut tag_values_buf);
            let should_start_new_group = self
                .safe_group
                .as_ref()
                .is_none_or(|group| !Self::tag_values_equal(&group.tag_values, &tag_values_buf));
            if should_start_new_group {
                self.finalize_safe_group()?;
                self.safe_group = Some(SafeGroup {
                    tag_values: tag_values_buf.iter().cloned().map(Value::from).collect(),
                    buckets: Vec::new(),
                    counters: Vec::new(),
                });
            }

            let Some(group) = self.safe_group.as_mut() else {
                continue;
            };

            let bucket = if le_array.is_valid(row) {
                le_array.value(row).parse::<f64>().unwrap_or(f64::NAN)
            } else {
                f64::NAN
            };
            let counter = if field_array.is_valid(row) {
                field_array.value(row)
            } else {
                f64::NAN
            };

            group.buckets.push(bucket);
            group.counters.push(counter);
        }

        Ok(())
    }

    fn tag_values_equal(group_values: &[Value], current: &[ValueRef<'_>]) -> bool {
        group_values.len() == current.len()
            && group_values
                .iter()
                .zip(current.iter())
                .all(|(group, now)| group.as_value_ref() == *now)
    }

    /// Take the output buffer and build a [RecordBatch] from it, removing the placeholder `le` column.
    fn take_output_buf(&mut self) -> DataFusionResult<Option<RecordBatch>> {
        if self.output_buffered_rows == 0 {
            if self.input_buffered_rows != 0 {
                warn!(
                    "input buffer is not empty, {} rows remaining",
                    self.input_buffered_rows
                );
            }
            return Ok(None);
        }

        let mut output_buf = Self::empty_output_buffer(&self.output_schema, self.le_column_index)?;
        std::mem::swap(&mut self.output_buffer, &mut output_buf);
        let mut columns = Vec::with_capacity(output_buf.len());
        for builder in output_buf.iter_mut() {
            columns.push(builder.to_vector().to_arrow_array());
        }
        // remove the placeholder column for `le`
        columns.remove(self.le_column_index);

        self.output_buffered_rows = 0;
        RecordBatch::try_new(self.output_schema.clone(), columns)
            .map(Some)
            .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))
    }

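    /// Handle end of input: leftover optimistic-mode rows (an incomplete
    /// trailing group) are re-processed in safe mode so they still produce
    /// output, then the final safe-mode group is finalized.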
    fn flush_remaining(&mut self) -> DataFusionResult<()> {
        if self.mode == FoldMode::Optimistic && self.input_buffered_rows > 0 {
            let buffered_batches: Vec<_> = self.input_buffer.drain(..).collect();
            if !buffered_batches.is_empty() {
                let batch = concat_batches(&self.input_schema, buffered_batches.as_slice())?;
                self.switch_to_safe_mode(batch)?;
            } else {
                self.input_buffered_rows = 0;
            }
        }

        if self.mode == FoldMode::Safe {
            self.process_safe_mode_buffer()?;
            self.finalize_safe_group()?;
        }

        Ok(())
    }

    fn is_positive_infinity(le_array: &StringArray, index: usize) -> bool {
        le_array.is_valid(index)
            && matches!(
                le_array.value(index).parse::<f64>(),
                Ok(value) if value.is_infinite() && value.is_sign_positive()
            )
    }

    /// Evaluate the quantile over one folded group of bucket bounds and
    /// cumulative counters, using PromQL-style linear interpolation within
    /// the matched bucket.
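    ///
    /// Worked example: `quantile = 0.5`, `bucket = [1.0, 2.0, +Inf]` and
    /// `counter = [10, 30, 40]`: the target rank is `40 * 0.5 = 20`, which
    /// falls in the `(1.0, 2.0]` bucket, so the result is
    /// `1.0 + (2.0 - 1.0) / (30 - 10) * (20 - 10) = 1.5`.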
    fn evaluate_row(quantile: f64, bucket: &[f64], counter: &[f64]) -> DataFusionResult<f64> {
        // check bucket
        if bucket.len() <= 1 {
            return Ok(f64::NAN);
        }
        if bucket.last().unwrap().is_finite() {
            return Err(DataFusionError::Execution(
                "last bucket should be +Inf".to_string(),
            ));
        }
        if bucket.len() != counter.len() {
            return Err(DataFusionError::Execution(
                "bucket and counter should have the same length".to_string(),
            ));
        }
        // check quantile
        if quantile < 0.0 {
            return Ok(f64::NEG_INFINITY);
        } else if quantile > 1.0 {
            return Ok(f64::INFINITY);
        } else if quantile.is_nan() {
            return Ok(f64::NAN);
        }

        // check input value
        if !bucket.windows(2).all(|w| w[0] <= w[1]) {
            return Ok(f64::NAN);
        }
        let counter = {
            let needs_fix =
                counter.iter().any(|v| !v.is_finite()) || !counter.windows(2).all(|w| w[0] <= w[1]);
            if !needs_fix {
                Cow::Borrowed(counter)
            } else {
                let mut fixed = Vec::with_capacity(counter.len());
                let mut prev = 0.0;
                for (idx, &v) in counter.iter().enumerate() {
                    let mut val = if v.is_finite() { v } else { prev };
                    if idx > 0 && val < prev {
                        val = prev;
                    }
                    fixed.push(val);
                    prev = val;
                }
                Cow::Owned(fixed)
            }
        };

        let total = *counter.last().unwrap();
        let expected_pos = total * quantile;
        let mut fit_bucket_pos = 0;
        while fit_bucket_pos < bucket.len() && counter[fit_bucket_pos] < expected_pos {
            fit_bucket_pos += 1;
        }
        if fit_bucket_pos >= bucket.len() - 1 {
            Ok(bucket[bucket.len() - 2])
        } else {
            let upper_bound = bucket[fit_bucket_pos];
            let upper_count = counter[fit_bucket_pos];
            let mut lower_bound = bucket[0].min(0.0);
            let mut lower_count = 0.0;
            if fit_bucket_pos > 0 {
                lower_bound = bucket[fit_bucket_pos - 1];
                lower_count = counter[fit_bucket_pos - 1];
            }
            if (upper_count - lower_count).abs() < 1e-10 {
                return Ok(f64::NAN);
            }
            Ok(lower_bound
                + (upper_bound - lower_bound) / (upper_count - lower_count)
                    * (expected_pos - lower_count))
        }
    }
}

#[cfg(test)]
mod test {
    use std::sync::Arc;

    use datafusion::arrow::array::{Float64Array, TimestampMillisecondArray};
    use datafusion::arrow::datatypes::{Field, Schema, SchemaRef, TimeUnit};
    use datafusion::common::ToDFSchema;
    use datafusion::datasource::memory::MemorySourceConfig;
    use datafusion::datasource::source::DataSourceExec;
    use datafusion::logical_expr::EmptyRelation;
    use datafusion::prelude::SessionContext;
    use datatypes::arrow_array::StringArray;
    use futures::FutureExt;

    use super::*;

    fn project_batch(batch: &RecordBatch, indices: &[usize]) -> RecordBatch {
        let fields = indices
            .iter()
            .map(|&idx| batch.schema().field(idx).clone())
            .collect::<Vec<_>>();
        let columns = indices
            .iter()
            .map(|&idx| batch.column(idx).clone())
            .collect::<Vec<_>>();
        let schema = Arc::new(Schema::new(fields));
        RecordBatch::try_new(schema, columns).unwrap()
    }

    fn prepare_test_data() -> DataSourceExec {
        let schema = Arc::new(Schema::new(vec![
            Field::new("host", DataType::Utf8, true),
            Field::new("le", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ]));

        // 12 items
        let host_column_1 = Arc::new(StringArray::from(vec![
            "host_1", "host_1", "host_1", "host_1", "host_1", "host_1", "host_1", "host_1",
            "host_1", "host_1", "host_1", "host_1",
        ])) as _;
        let le_column_1 = Arc::new(StringArray::from(vec![
            "0.001", "0.1", "10", "1000", "+Inf", "0.001", "0.1", "10", "1000", "+inf", "0.001",
            "0.1",
        ])) as _;
        let val_column_1 = Arc::new(Float64Array::from(vec![
            0_0.0, 1.0, 1.0, 5.0, 5.0, 0_0.0, 20.0, 60.0, 70.0, 100.0, 0_1.0, 1.0,
        ])) as _;

        // 2 items
        let host_column_2 = Arc::new(StringArray::from(vec!["host_1", "host_1"])) as _;
        let le_column_2 = Arc::new(StringArray::from(vec!["10", "1000"])) as _;
        let val_column_2 = Arc::new(Float64Array::from(vec![1.0, 1.0])) as _;

        // 11 items
        let host_column_3 = Arc::new(StringArray::from(vec![
            "host_1", "host_2", "host_2", "host_2", "host_2", "host_2", "host_2", "host_2",
            "host_2", "host_2", "host_2",
        ])) as _;
        let le_column_3 = Arc::new(StringArray::from(vec![
            "+INF", "0.001", "0.1", "10", "1000", "+iNf", "0.001", "0.1", "10", "1000", "+Inf",
        ])) as _;
        let val_column_3 = Arc::new(Float64Array::from(vec![
            1.0, 0_0.0, 0.0, 0.0, 0.0, 0.0, 0_0.0, 1.0, 2.0, 3.0, 4.0,
        ])) as _;

        let data_1 = RecordBatch::try_new(
            schema.clone(),
            vec![host_column_1, le_column_1, val_column_1],
        )
        .unwrap();
        let data_2 = RecordBatch::try_new(
            schema.clone(),
            vec![host_column_2, le_column_2, val_column_2],
        )
        .unwrap();
        let data_3 = RecordBatch::try_new(
            schema.clone(),
            vec![host_column_3, le_column_3, val_column_3],
        )
        .unwrap();

        DataSourceExec::new(Arc::new(
            MemorySourceConfig::try_new(&[vec![data_1, data_2, data_3]], schema, None).unwrap(),
        ))
    }

    fn build_fold_exec_from_batches(
        batches: Vec<RecordBatch>,
        schema: SchemaRef,
        quantile: f64,
        ts_column_index: usize,
    ) -> Arc<HistogramFoldExec> {
        let input: Arc<dyn ExecutionPlan> = Arc::new(DataSourceExec::new(Arc::new(
            MemorySourceConfig::try_new(&[batches], schema.clone(), None).unwrap(),
        )));
        let output_schema: SchemaRef = Arc::new(
            HistogramFold::convert_schema(&Arc::new(input.schema().to_dfschema().unwrap()), "le")
                .unwrap()
                .as_arrow()
                .clone(),
        );

        let (tag_columns, partition_exprs, properties) =
            build_test_plan_properties(&input, output_schema.clone(), ts_column_index);

        Arc::new(HistogramFoldExec {
            le_column_index: 1,
            field_column_index: 2,
            quantile,
            ts_column_index,
            input,
            output_schema,
            tag_columns,
            partition_exprs,
            metric: ExecutionPlanMetricsSet::new(),
            properties,
        })
    }

    type PlanPropsResult = (
        Vec<Arc<dyn PhysicalExpr>>,
        Vec<Arc<dyn PhysicalExpr>>,
        PlanProperties,
    );

    fn build_test_plan_properties(
        input: &Arc<dyn ExecutionPlan>,
        output_schema: SchemaRef,
        ts_column_index: usize,
    ) -> PlanPropsResult {
        let tag_columns = input
            .schema()
            .fields()
            .iter()
            .enumerate()
            .filter_map(|(idx, field)| {
                if idx == 1 || idx == 2 || idx == ts_column_index {
                    None
                } else {
                    Some(Arc::new(PhyColumn::new(field.name(), idx)) as _)
                }
            })
            .collect::<Vec<_>>();

        let partition_exprs = if tag_columns.is_empty() {
            vec![Arc::new(PhyColumn::new(
                input.schema().field(ts_column_index).name(),
                ts_column_index,
            )) as _]
        } else {
            tag_columns.clone()
        };

        let properties = PlanProperties::new(
            EquivalenceProperties::new(output_schema.clone()),
            Partitioning::Hash(
                partition_exprs.clone(),
                input.output_partitioning().partition_count(),
            ),
            EmissionType::Incremental,
            Boundedness::Bounded,
        );

        (tag_columns, partition_exprs, properties)
    }

    #[tokio::test]
    async fn fold_overall() {
        let memory_exec: Arc<dyn ExecutionPlan> = Arc::new(prepare_test_data());
        let output_schema: SchemaRef = Arc::new(
            HistogramFold::convert_schema(
                &Arc::new(memory_exec.schema().to_dfschema().unwrap()),
                "le",
            )
            .unwrap()
            .as_arrow()
            .clone(),
        );
        let (tag_columns, partition_exprs, properties) =
            build_test_plan_properties(&memory_exec, output_schema.clone(), 0);
        let fold_exec = Arc::new(HistogramFoldExec {
            le_column_index: 1,
            field_column_index: 2,
            quantile: 0.4,
            ts_column_index: 0,
            input: memory_exec,
            output_schema,
            tag_columns,
            partition_exprs,
            metric: ExecutionPlanMetricsSet::new(),
            properties,
        });

        let session_context = SessionContext::default();
        let result = datafusion::physical_plan::collect(fold_exec, session_context.task_ctx())
            .await
            .unwrap();
        let result_literal = datatypes::arrow::util::pretty::pretty_format_batches(&result)
            .unwrap()
            .to_string();

        let expected = String::from(
            "+--------+-------------------+
| host   | val               |
+--------+-------------------+
| host_1 | 257.5             |
| host_1 | 5.05              |
| host_1 | 0.0004            |
| host_2 | NaN               |
| host_2 | 6.040000000000001 |
+--------+-------------------+",
        );
        assert_eq!(result_literal, expected);
    }

    #[tokio::test]
    async fn pruning_should_keep_le_column_for_exec() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
            Field::new("le", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ]));
        let df_schema = schema.clone().to_dfschema_ref().unwrap();
        let input = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: df_schema,
        });
        let plan = HistogramFold::new(
            "le".to_string(),
            "val".to_string(),
            "ts".to_string(),
            0.5,
            input,
        )
        .unwrap();

        let output_columns = [0usize, 1usize];
        let required = plan.necessary_children_exprs(&output_columns).unwrap();
        let required = &required[0];
        assert_eq!(required.as_slice(), &[0, 1, 2]);

        let input_batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(TimestampMillisecondArray::from(vec![0, 0])),
                Arc::new(StringArray::from(vec!["0.1", "+Inf"])),
                Arc::new(Float64Array::from(vec![1.0, 2.0])),
            ],
        )
        .unwrap();
        let projected = project_batch(&input_batch, required);
        let projected_schema = projected.schema();
        let memory_exec = Arc::new(DataSourceExec::new(Arc::new(
            MemorySourceConfig::try_new(&[vec![projected]], projected_schema, None).unwrap(),
        )));

        let fold_exec = plan.to_execution_plan(memory_exec);
        let session_context = SessionContext::default();
        let output_batches =
            datafusion::physical_plan::collect(fold_exec, session_context.task_ctx())
                .await
                .unwrap();
        assert_eq!(output_batches.len(), 1);

        let output_batch = &output_batches[0];
        assert_eq!(output_batch.num_rows(), 1);

        let ts = output_batch
            .column(0)
            .as_any()
            .downcast_ref::<TimestampMillisecondArray>()
            .unwrap();
        assert_eq!(ts.values(), &[0i64]);

        let values = output_batch
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert!((values.value(0) - 0.1).abs() < 1e-12);

        // Simulate the pre-fix pruning behavior: omit the `le` column from the child input.
        let le_index = 1usize;
        let broken_required = output_columns
            .iter()
            .map(|&output_column| {
                if output_column < le_index {
                    output_column
                } else {
                    output_column + 1
                }
            })
            .collect::<Vec<_>>();

        let broken = project_batch(&input_batch, &broken_required);
        let broken_schema = broken.schema();
        let broken_exec = Arc::new(DataSourceExec::new(Arc::new(
            MemorySourceConfig::try_new(&[vec![broken]], broken_schema, None).unwrap(),
        )));
        let broken_fold_exec = plan.to_execution_plan(broken_exec);
        let session_context = SessionContext::default();
        let broken_result = std::panic::AssertUnwindSafe(async {
            datafusion::physical_plan::collect(broken_fold_exec, session_context.task_ctx()).await
        })
        .catch_unwind()
        .await;
        assert!(broken_result.is_err());
    }

    #[test]
    fn confirm_schema() {
        let input_schema = Schema::new(vec![
            Field::new("host", DataType::Utf8, true),
            Field::new("le", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ])
        .to_dfschema_ref()
        .unwrap();
        let expected_output_schema = Schema::new(vec![
            Field::new("host", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ])
        .to_dfschema_ref()
        .unwrap();

        let actual = HistogramFold::convert_schema(&input_schema, "le").unwrap();
        assert_eq!(actual, expected_output_schema)
    }
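
    #[test]
    fn evaluate_row_interpolation() {
        // Sanity check of `evaluate_row`'s interpolation, mirroring the worked
        // example in its doc comment: total = 40, target rank = 40 * 0.5 = 20,
        // which falls in the (1.0, 2.0] bucket, giving
        // 1.0 + (2.0 - 1.0) / (30 - 10) * (20 - 10) = 1.5.
        let bucket = [1.0, 2.0, f64::INFINITY];
        let counter = [10.0, 30.0, 40.0];
        let result = HistogramFoldStream::evaluate_row(0.5, &bucket, &counter).unwrap();
        assert!((result - 1.5).abs() < 1e-10);

        // Out-of-range quantiles saturate to -Inf / +Inf.
        assert_eq!(
            HistogramFoldStream::evaluate_row(-0.1, &bucket, &counter).unwrap(),
            f64::NEG_INFINITY
        );
        assert_eq!(
            HistogramFoldStream::evaluate_row(1.1, &bucket, &counter).unwrap(),
            f64::INFINITY
        );
    }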
1450
    #[tokio::test]
    async fn fallback_to_safe_mode_on_missing_inf() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("host", DataType::Utf8, true),
            Field::new("le", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ]));
        let host_column = Arc::new(StringArray::from(vec!["a", "a", "a", "a", "b", "b"])) as _;
        let le_column = Arc::new(StringArray::from(vec![
            "0.1", "+Inf", "0.1", "1.0", "0.1", "+Inf",
        ])) as _;
        let val_column = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 3.0, 1.0, 5.0])) as _;
        let batch =
            RecordBatch::try_new(schema.clone(), vec![host_column, le_column, val_column]).unwrap();
        let fold_exec = build_fold_exec_from_batches(vec![batch], schema, 0.5, 0);
        let session_context = SessionContext::default();
        let result = datafusion::physical_plan::collect(fold_exec, session_context.task_ctx())
            .await
            .unwrap();
        let result_literal = datatypes::arrow::util::pretty::pretty_format_batches(&result)
            .unwrap()
            .to_string();

        let expected = String::from(
            "+------+-----+
| host | val |
+------+-----+
| a    | 0.1 |
| a    | NaN |
| b    | 0.1 |
+------+-----+",
        );
        assert_eq!(result_literal, expected);
    }

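    // A series whose buckets never include `+Inf` cannot be folded into a
    // quantile; it emits NaN instead of returning an error.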
    #[tokio::test]
    async fn emit_nan_when_no_inf_present() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("host", DataType::Utf8, true),
            Field::new("le", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ]));
        let host_column = Arc::new(StringArray::from(vec!["c", "c"])) as _;
        let le_column = Arc::new(StringArray::from(vec!["0.1", "1.0"])) as _;
        let val_column = Arc::new(Float64Array::from(vec![1.0, 2.0])) as _;
        let batch =
            RecordBatch::try_new(schema.clone(), vec![host_column, le_column, val_column]).unwrap();
        let fold_exec = build_fold_exec_from_batches(vec![batch], schema, 0.9, 0);
        let session_context = SessionContext::default();
        let result = datafusion::physical_plan::collect(fold_exec, session_context.task_ctx())
            .await
            .unwrap();
        let result_literal = datatypes::arrow::util::pretty::pretty_format_batches(&result)
            .unwrap()
            .to_string();

        let expected = String::from(
            "+------+-----+
| host | val |
+------+-----+
| c    | NaN |
+------+-----+",
        );
        assert_eq!(result_literal, expected);
    }

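    // Consecutive timestamps carry different bucket sets (3, 4, 2 and 5
    // buckets respectively), so the fold has to realign groups instead of
    // assuming a fixed bucket count per timestamp.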
    #[tokio::test]
    async fn safe_mode_handles_misaligned_groups() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
            Field::new("le", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ]));

        let ts_column = Arc::new(TimestampMillisecondArray::from(vec![
            2900000, 2900000, 2900000, 3000000, 3000000, 3000000, 3000000, 3005000, 3005000,
            3010000, 3010000, 3010000, 3010000, 3010000,
        ])) as _;
        let le_column = Arc::new(StringArray::from(vec![
            "0.1", "1", "5", "0.1", "1", "5", "+Inf", "0.1", "+Inf", "0.1", "1", "3", "5", "+Inf",
        ])) as _;
        let val_column = Arc::new(Float64Array::from(vec![
            0.0, 0.0, 0.0, 50.0, 70.0, 110.0, 120.0, 10.0, 30.0, 10.0, 20.0, 30.0, 40.0, 50.0,
        ])) as _;
        let batch =
            RecordBatch::try_new(schema.clone(), vec![ts_column, le_column, val_column]).unwrap();
        let fold_exec = build_fold_exec_from_batches(vec![batch], schema, 0.5, 0);
        let session_context = SessionContext::default();
        let result = datafusion::physical_plan::collect(fold_exec, session_context.task_ctx())
            .await
            .unwrap();

        let mut values = Vec::new();
        for batch in result {
            let array = batch.column(1).as_primitive::<Float64Type>();
            values.extend(array.iter().map(|v| v.unwrap()));
        }

        assert_eq!(values.len(), 4);
        assert!(values[0].is_nan());
        assert!((values[1] - 0.55).abs() < 1e-10);
        assert!((values[2] - 0.1).abs() < 1e-10);
        assert!((values[3] - 2.0).abs() < 1e-10);
    }

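    // Only a single bucket (without `+Inf`) exists at the first timestamp;
    // that group folds to NaN while the complete later groups are unaffected.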
    #[tokio::test]
    async fn missing_buckets_at_first_timestamp() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
            Field::new("le", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ]));

        let ts_column = Arc::new(TimestampMillisecondArray::from(vec![
            2_900_000, 3_000_000, 3_000_000, 3_000_000, 3_000_000, 3_005_000, 3_005_000, 3_010_000,
            3_010_000, 3_010_000, 3_010_000, 3_010_000,
        ])) as _;
        let le_column = Arc::new(StringArray::from(vec![
            "0.1", "0.1", "1", "5", "+Inf", "0.1", "+Inf", "0.1", "1", "3", "5", "+Inf",
        ])) as _;
        let val_column = Arc::new(Float64Array::from(vec![
            0.0, 50.0, 70.0, 110.0, 120.0, 10.0, 30.0, 10.0, 20.0, 30.0, 40.0, 50.0,
        ])) as _;

        let batch =
            RecordBatch::try_new(schema.clone(), vec![ts_column, le_column, val_column]).unwrap();
        let fold_exec = build_fold_exec_from_batches(vec![batch], schema, 0.5, 0);
        let session_context = SessionContext::default();
        let result = datafusion::physical_plan::collect(fold_exec, session_context.task_ctx())
            .await
            .unwrap();

        let mut values = Vec::new();
        for batch in result {
            let array = batch.column(1).as_primitive::<Float64Type>();
            values.extend(array.iter().map(|v| v.unwrap()));
        }

        assert_eq!(values.len(), 4);
        assert!(values[0].is_nan());
        assert!((values[1] - 0.55).abs() < 1e-10);
        assert!((values[2] - 0.1).abs() < 1e-10);
        assert!((values[3] - 2.0).abs() < 1e-10);
    }

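    // The `+Inf` bucket is absent from the first group only: it yields NaN,
    // and the complete second group yields a real quantile.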
    #[tokio::test]
    async fn missing_inf_in_first_group() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
            Field::new("le", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ]));

        let ts_column = Arc::new(TimestampMillisecondArray::from(vec![
            1000, 1000, 1000, 2000, 2000, 2000, 2000,
        ])) as _;
        let le_column = Arc::new(StringArray::from(vec![
            "0.1", "1", "5", "0.1", "1", "5", "+Inf",
        ])) as _;
        let val_column = Arc::new(Float64Array::from(vec![
            0.0, 0.0, 0.0, 10.0, 20.0, 30.0, 30.0,
        ])) as _;
        let batch =
            RecordBatch::try_new(schema.clone(), vec![ts_column, le_column, val_column]).unwrap();
        let fold_exec = build_fold_exec_from_batches(vec![batch], schema, 0.5, 0);
        let session_context = SessionContext::default();
        let result = datafusion::physical_plan::collect(fold_exec, session_context.task_ctx())
            .await
            .unwrap();

        let mut values = Vec::new();
        for batch in result {
            let array = batch.column(1).as_primitive::<Float64Type>();
            values.extend(array.iter().map(|v| v.unwrap()));
        }

        assert_eq!(values.len(), 2);
        assert!(values[0].is_nan());
        assert!((values[1] - 0.55).abs() < 1e-10, "{values:?}");
    }

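    // Table-driven check of the quantile interpolation, including the edge
    // conventions asserted below: all-zero counters and `q = 0` give NaN,
    // `q > 1` gives +Inf, and `q < 0` gives -Inf.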
    #[test]
    fn evaluate_row_normal_case() {
        let bucket = [0.0, 1.0, 2.0, 3.0, 4.0, f64::INFINITY];

        #[derive(Debug)]
        struct Case {
            quantile: f64,
            counters: Vec<f64>,
            expected: f64,
        }

        let cases = [
            Case {
                quantile: 0.9,
                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
                expected: 4.0,
            },
            Case {
                quantile: 0.89,
                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
                expected: 4.0,
            },
            Case {
                quantile: 0.78,
                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
                expected: 3.9,
            },
            Case {
                quantile: 0.5,
                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
                expected: 2.5,
            },
            Case {
                quantile: 0.5,
                counters: vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                expected: f64::NAN,
            },
            Case {
                quantile: 1.0,
                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
                expected: 4.0,
            },
            Case {
                quantile: 0.0,
                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
                expected: f64::NAN,
            },
            Case {
                quantile: 1.1,
                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
                expected: f64::INFINITY,
            },
            Case {
                quantile: -1.0,
                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
                expected: f64::NEG_INFINITY,
            },
        ];

        for case in cases {
            let actual =
                HistogramFoldStream::evaluate_row(case.quantile, &bucket, &case.counters).unwrap();
            // Compare through `Display` so that NaN expectations match
            // (`NaN != NaN` under direct float comparison).
            assert_eq!(
                format!("{actual}"),
                format!("{}", case.expected),
                "{:?}",
                case
            );
        }
    }

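    // Strictly decreasing counters are nonsensical input; `evaluate_row`
    // should still return a value rather than panic.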
    #[test]
    fn evaluate_out_of_order_input() {
        let bucket = [0.0, 1.0, 2.0, 3.0, 4.0, f64::INFINITY];
        let counters = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0];
        let result = HistogramFoldStream::evaluate_row(0.5, &bucket, &counters).unwrap();
        assert_eq!(0.0, result);
    }

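    // A bucket bound appearing after `+Inf` means the upper bounds are not
    // sorted ascending; `evaluate_row` must reject this with an error.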
    #[test]
    fn evaluate_wrong_bucket() {
        let bucket = [0.0, 1.0, 2.0, 3.0, 4.0, f64::INFINITY, 5.0];
        let counters = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
        let result = HistogramFoldStream::evaluate_row(0.5, &bucket, &counters);
        assert!(result.is_err());
    }

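    // Tiny fractional counter increments exercise the floating-point
    // interpolation path.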
    #[test]
    fn evaluate_small_fraction() {
        let bucket = [0.0, 2.0, 4.0, 6.0, f64::INFINITY];
        let counters = [0.0, 1.0 / 300.0, 2.0 / 300.0, 0.01, 0.01];
        let result = HistogramFoldStream::evaluate_row(0.5, &bucket, &counters).unwrap();
        assert_eq!(3.0, result);
    }

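    // A dip in the cumulative counters (0.4 -> 0.17) can occur after a
    // counter reset; interpolation should tolerate it instead of failing.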
    #[test]
    fn evaluate_non_monotonic_counter() {
        let bucket = [0.0, 1.0, 2.0, 3.0, f64::INFINITY];
        let counters = [0.1, 0.2, 0.4, 0.17, 0.5];
        let result = HistogramFoldStream::evaluate_row(0.5, &bucket, &counters).unwrap();
        assert!((result - 1.25).abs() < 1e-10, "{result}");
    }

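    // A NaN counter value must not poison the rest of the row.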
    #[test]
    fn evaluate_nan_counter() {
        let bucket = [0.0, 1.0, 2.0, 3.0, f64::INFINITY];
        let counters = [f64::NAN, 1.0, 2.0, 3.0, 3.0];
        let result = HistogramFoldStream::evaluate_row(0.5, &bucket, &counters).unwrap();
        assert!((result - 1.5).abs() < 1e-10, "{result}");
    }

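    // Builds a no-op logical input used by the serialization round-trip test.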
    fn build_empty_relation(schema: &Arc<Schema>) -> LogicalPlan {
        LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: schema.clone().to_dfschema_ref().unwrap(),
        })
    }

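    // Round-trip the plan node through its serialized form and verify that
    // all parameters and the output schema survive.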
    #[tokio::test]
    async fn encode_decode_histogram_fold() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Int64, false),
            Field::new("le", DataType::Utf8, false),
            Field::new("val", DataType::Float64, false),
        ]));
        let input_plan = build_empty_relation(&schema);
        let plan_node = HistogramFold::new(
            "le".to_string(),
            "val".to_string(),
            "ts".to_string(),
            0.8,
            input_plan.clone(),
        )
        .unwrap();

        let bytes = plan_node.serialize();

        let histogram_fold = HistogramFold::deserialize(&bytes).unwrap();
        // The serialized form does not carry the input plan, so re-attach the
        // real input before inspecting the node.
        let histogram_fold = histogram_fold
            .with_exprs_and_inputs(vec![], vec![input_plan])
            .unwrap();

        assert_eq!(histogram_fold.le_column, "le");
        assert_eq!(histogram_fold.ts_column, "ts");
        assert_eq!(histogram_fold.field_column, "val");
        assert_eq!(histogram_fold.quantile, OrderedF64::from(0.8));
        assert_eq!(histogram_fold.output_schema.fields().len(), 2);
        assert_eq!(histogram_fold.output_schema.field(0).name(), "ts");
        assert_eq!(histogram_fold.output_schema.field(1).name(), "val");
    }
}