Skip to main content

promql/extension_plan/
histogram_fold.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::any::Any;
16use std::borrow::Cow;
17use std::collections::{HashMap, HashSet};
18use std::sync::Arc;
19use std::task::Poll;
20use std::time::Instant;
21
22use common_telemetry::warn;
23use datafusion::arrow::array::{Array, AsArray, StringArray};
24use datafusion::arrow::compute::{SortOptions, concat_batches};
25use datafusion::arrow::datatypes::{DataType, Float64Type, SchemaRef};
26use datafusion::arrow::record_batch::RecordBatch;
27use datafusion::common::stats::Precision;
28use datafusion::common::{DFSchema, DFSchemaRef, Statistics};
29use datafusion::error::{DataFusionError, Result as DataFusionResult};
30use datafusion::execution::TaskContext;
31use datafusion::logical_expr::{LogicalPlan, UserDefinedLogicalNodeCore};
32use datafusion::physical_expr::{
33    EquivalenceProperties, LexRequirement, OrderingRequirements, PhysicalSortRequirement,
34};
35use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
36use datafusion::physical_plan::expressions::{CastExpr as PhyCast, Column as PhyColumn};
37use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
38use datafusion::physical_plan::{
39    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
40    Partitioning, PhysicalExpr, PlanProperties, RecordBatchStream, SendableRecordBatchStream,
41};
42use datafusion::prelude::{Column, Expr};
43use datafusion_expr::{EmptyRelation, col};
44use datatypes::prelude::{ConcreteDataType, DataType as GtDataType};
45use datatypes::value::{OrderedF64, Value, ValueRef};
46use datatypes::vectors::{Helper, MutableVector, VectorRef};
47use futures::{Stream, StreamExt, ready};
48use greptime_proto::substrait_extension as pb;
49use prost::Message;
50use snafu::ResultExt;
51
52use crate::error::{DeserializeSnafu, Result};
53use crate::extension_plan::{resolve_column_name, serialize_column_index};
54
55/// `HistogramFold` will fold the conventional (non-native) histogram ([1]) for later
56/// computing.
57///
58/// Specifically, it will transform the `le` and `field` column into a complex
59/// type, and samples on other tag columns:
60/// - `le` will become a [ListArray] of [f64]. With each bucket bound parsed
61/// - `field` will become a [ListArray] of [f64]
62/// - other columns will be sampled every `bucket_num` element, but their types won't change.
63///
64/// Due to the folding or sampling, the output rows number will become `input_rows` / `bucket_num`.
65///
66/// # Requirement
67/// - Input should be sorted on `<tag list>, ts, le ASC`.
68/// - The value set of `le` should be same. I.e., buckets of every series should be same.
69///
70/// [1]: https://prometheus.io/docs/concepts/metric_types/#histogram
71#[derive(Debug, PartialEq, Hash, Eq)]
72pub struct HistogramFold {
73    /// Name of the `le` column. It's a special column in prometheus
74    /// for implementing conventional histogram. It's a string column
75    /// with "literal" float value, like "+Inf", "0.001" etc.
76    le_column: String,
77    ts_column: String,
78    input: LogicalPlan,
79    field_column: String,
80    quantile: OrderedF64,
81    output_schema: DFSchemaRef,
82    unfix: Option<UnfixIndices>,
83}
84
/// Column positions captured when a `HistogramFold` is deserialized.
///
/// They are indices into the input schema and are turned back into column
/// names once the real input plan is attached.
#[derive(Debug, PartialEq, Eq, Hash, PartialOrd)]
struct UnfixIndices {
    /// Position of the `le` column in the input schema.
    pub le_column_idx: u64,
    /// Position of the timestamp column in the input schema.
    pub ts_column_idx: u64,
    /// Position of the counter (field) column in the input schema.
    pub field_column_idx: u64,
}
91
92impl UserDefinedLogicalNodeCore for HistogramFold {
93    fn name(&self) -> &str {
94        Self::name()
95    }
96
97    fn inputs(&self) -> Vec<&LogicalPlan> {
98        vec![&self.input]
99    }
100
101    fn schema(&self) -> &DFSchemaRef {
102        &self.output_schema
103    }
104
105    fn expressions(&self) -> Vec<Expr> {
106        if self.unfix.is_some() {
107            return vec![];
108        }
109
110        let mut exprs = vec![
111            col(&self.le_column),
112            col(&self.ts_column),
113            col(&self.field_column),
114        ];
115        exprs.extend(self.input.schema().fields().iter().filter_map(|f| {
116            let name = f.name();
117            if name != &self.le_column && name != &self.ts_column && name != &self.field_column {
118                Some(col(name))
119            } else {
120                None
121            }
122        }));
123        exprs
124    }
125
126    fn necessary_children_exprs(&self, output_columns: &[usize]) -> Option<Vec<Vec<usize>>> {
127        if self.unfix.is_some() {
128            return None;
129        }
130
131        let input_schema = self.input.schema();
132        let le_column_index = input_schema.index_of_column_by_name(None, &self.le_column)?;
133
134        if output_columns.is_empty() {
135            let indices = (0..input_schema.fields().len()).collect::<Vec<_>>();
136            return Some(vec![indices]);
137        }
138
139        let mut necessary_indices = output_columns
140            .iter()
141            .map(|&output_column| {
142                if output_column < le_column_index {
143                    output_column
144                } else {
145                    output_column + 1
146                }
147            })
148            .collect::<Vec<_>>();
149        necessary_indices.push(le_column_index);
150        necessary_indices.sort_unstable();
151        necessary_indices.dedup();
152        Some(vec![necessary_indices])
153    }
154
155    fn fmt_for_explain(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
156        write!(
157            f,
158            "HistogramFold: le={}, field={}, quantile={}",
159            self.le_column, self.field_column, self.quantile
160        )
161    }
162
163    fn with_exprs_and_inputs(
164        &self,
165        _exprs: Vec<Expr>,
166        inputs: Vec<LogicalPlan>,
167    ) -> DataFusionResult<Self> {
168        if inputs.is_empty() {
169            return Err(DataFusionError::Internal(
170                "HistogramFold must have at least one input".to_string(),
171            ));
172        }
173
174        let input: LogicalPlan = inputs.into_iter().next().unwrap();
175        let input_schema = input.schema();
176
177        if let Some(unfix) = &self.unfix {
178            let le_column =
179                resolve_column_name(unfix.le_column_idx, input_schema, "HistogramFold", "le")?;
180            let ts_column =
181                resolve_column_name(unfix.ts_column_idx, input_schema, "HistogramFold", "ts")?;
182            let field_column = resolve_column_name(
183                unfix.field_column_idx,
184                input_schema,
185                "HistogramFold",
186                "field",
187            )?;
188
189            let output_schema = Self::convert_schema(input_schema, &le_column)?;
190
191            Ok(Self {
192                le_column,
193                ts_column,
194                input,
195                field_column,
196                quantile: self.quantile,
197                output_schema,
198                unfix: None,
199            })
200        } else {
201            Ok(Self {
202                le_column: self.le_column.clone(),
203                ts_column: self.ts_column.clone(),
204                input,
205                field_column: self.field_column.clone(),
206                quantile: self.quantile,
207                output_schema: self.output_schema.clone(),
208                unfix: None,
209            })
210        }
211    }
212}
213
214impl HistogramFold {
215    pub fn new(
216        le_column: String,
217        field_column: String,
218        ts_column: String,
219        quantile: f64,
220        input: LogicalPlan,
221    ) -> DataFusionResult<Self> {
222        let input_schema = input.schema();
223        Self::check_schema(input_schema, &le_column, &field_column, &ts_column)?;
224        let output_schema = Self::convert_schema(input_schema, &le_column)?;
225        Ok(Self {
226            le_column,
227            ts_column,
228            input,
229            field_column,
230            quantile: quantile.into(),
231            output_schema,
232            unfix: None,
233        })
234    }
235
236    pub const fn name() -> &'static str {
237        "HistogramFold"
238    }
239
240    fn check_schema(
241        input_schema: &DFSchemaRef,
242        le_column: &str,
243        field_column: &str,
244        ts_column: &str,
245    ) -> DataFusionResult<()> {
246        let check_column = |col| {
247            if !input_schema.has_column_with_unqualified_name(col) {
248                Err(DataFusionError::SchemaError(
249                    Box::new(datafusion::common::SchemaError::FieldNotFound {
250                        field: Box::new(Column::new(None::<String>, col)),
251                        valid_fields: input_schema.columns(),
252                    }),
253                    Box::new(None),
254                ))
255            } else {
256                Ok(())
257            }
258        };
259
260        check_column(le_column)?;
261        check_column(ts_column)?;
262        check_column(field_column)
263    }
264
265    pub fn to_execution_plan(&self, exec_input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
266        let input_schema = self.input.schema();
267        // safety: those fields are checked in `check_schema()`
268        let le_column_index = input_schema
269            .index_of_column_by_name(None, &self.le_column)
270            .unwrap();
271        let field_column_index = input_schema
272            .index_of_column_by_name(None, &self.field_column)
273            .unwrap();
274        let ts_column_index = input_schema
275            .index_of_column_by_name(None, &self.ts_column)
276            .unwrap();
277
278        let tag_columns = exec_input
279            .schema()
280            .fields()
281            .iter()
282            .enumerate()
283            .filter_map(|(idx, field)| {
284                if idx == le_column_index || idx == field_column_index || idx == ts_column_index {
285                    None
286                } else {
287                    Some(Arc::new(PhyColumn::new(field.name(), idx)) as _)
288                }
289            })
290            .collect::<Vec<_>>();
291
292        let mut partition_exprs = tag_columns.clone();
293        partition_exprs.push(Arc::new(PhyColumn::new(
294            self.input.schema().field(ts_column_index).name(),
295            ts_column_index,
296        )) as _);
297
298        let output_schema: SchemaRef = self.output_schema.inner().clone();
299        let properties = Arc::new(PlanProperties::new(
300            EquivalenceProperties::new(output_schema.clone()),
301            Partitioning::Hash(
302                partition_exprs.clone(),
303                exec_input.output_partitioning().partition_count(),
304            ),
305            EmissionType::Incremental,
306            Boundedness::Bounded,
307        ));
308        Arc::new(HistogramFoldExec {
309            le_column_index,
310            field_column_index,
311            ts_column_index,
312            input: exec_input,
313            tag_columns,
314            partition_exprs,
315            quantile: self.quantile.into(),
316            output_schema,
317            metric: ExecutionPlanMetricsSet::new(),
318            properties,
319        })
320    }
321
322    /// Transform the schema
323    ///
324    /// - `le` will be removed
325    ///
326    /// Column qualifiers are preserved so downstream plan nodes can keep
327    /// referencing the columns by their original qualified names.
328    fn convert_schema(
329        input_schema: &DFSchemaRef,
330        le_column: &str,
331    ) -> DataFusionResult<DFSchemaRef> {
332        // safety: those fields are checked in `check_schema()`
333        let mut new_fields = Vec::with_capacity(input_schema.fields().len() - 1);
334        for (qualifier, field) in input_schema.iter() {
335            if field.name() != le_column {
336                new_fields.push((qualifier.cloned(), field.clone()));
337            }
338        }
339        Ok(Arc::new(DFSchema::new_with_metadata(
340            new_fields,
341            HashMap::new(),
342        )?))
343    }
344
345    pub fn serialize(&self) -> Vec<u8> {
346        let le_column_idx = serialize_column_index(self.input.schema(), &self.le_column);
347        let ts_column_idx = serialize_column_index(self.input.schema(), &self.ts_column);
348        let field_column_idx = serialize_column_index(self.input.schema(), &self.field_column);
349
350        pb::HistogramFold {
351            le_column_idx,
352            ts_column_idx,
353            field_column_idx,
354            quantile: self.quantile.into(),
355        }
356        .encode_to_vec()
357    }
358
359    pub fn deserialize(bytes: &[u8]) -> Result<Self> {
360        let pb_histogram_fold = pb::HistogramFold::decode(bytes).context(DeserializeSnafu)?;
361        let placeholder_plan = LogicalPlan::EmptyRelation(EmptyRelation {
362            produce_one_row: false,
363            schema: Arc::new(DFSchema::empty()),
364        });
365
366        let unfix = UnfixIndices {
367            le_column_idx: pb_histogram_fold.le_column_idx,
368            ts_column_idx: pb_histogram_fold.ts_column_idx,
369            field_column_idx: pb_histogram_fold.field_column_idx,
370        };
371
372        Ok(Self {
373            le_column: String::new(),
374            ts_column: String::new(),
375            input: placeholder_plan,
376            field_column: String::new(),
377            quantile: pb_histogram_fold.quantile.into(),
378            output_schema: Arc::new(DFSchema::empty()),
379            unfix: Some(unfix),
380        })
381    }
382}
383
384impl PartialOrd for HistogramFold {
385    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
386        // Compare fields in order excluding output_schema
387        match self.le_column.partial_cmp(&other.le_column) {
388            Some(core::cmp::Ordering::Equal) => {}
389            ord => return ord,
390        }
391        match self.ts_column.partial_cmp(&other.ts_column) {
392            Some(core::cmp::Ordering::Equal) => {}
393            ord => return ord,
394        }
395        match self.input.partial_cmp(&other.input) {
396            Some(core::cmp::Ordering::Equal) => {}
397            ord => return ord,
398        }
399        match self.field_column.partial_cmp(&other.field_column) {
400            Some(core::cmp::Ordering::Equal) => {}
401            ord => return ord,
402        }
403        self.quantile.partial_cmp(&other.quantile)
404    }
405}
406
407#[derive(Debug)]
408pub struct HistogramFoldExec {
409    /// Index for `le` column in the schema of input.
410    le_column_index: usize,
411    input: Arc<dyn ExecutionPlan>,
412    output_schema: SchemaRef,
413    /// Index for field column in the schema of input.
414    field_column_index: usize,
415    ts_column_index: usize,
416    /// Tag columns are all columns except `le`, `field` and `ts` columns.
417    tag_columns: Vec<Arc<dyn PhysicalExpr>>,
418    partition_exprs: Vec<Arc<dyn PhysicalExpr>>,
419    quantile: f64,
420    metric: ExecutionPlanMetricsSet,
421    properties: Arc<PlanProperties>,
422}
423
424impl ExecutionPlan for HistogramFoldExec {
425    fn as_any(&self) -> &dyn Any {
426        self
427    }
428
429    fn properties(&self) -> &Arc<PlanProperties> {
430        &self.properties
431    }
432
433    fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
434        let mut cols = self
435            .tag_columns
436            .iter()
437            .map(|expr| PhysicalSortRequirement {
438                expr: expr.clone(),
439                options: None,
440            })
441            .collect::<Vec<PhysicalSortRequirement>>();
442        // add ts
443        cols.push(PhysicalSortRequirement {
444            expr: Arc::new(PhyColumn::new(
445                self.input.schema().field(self.ts_column_index).name(),
446                self.ts_column_index,
447            )),
448            options: None,
449        });
450        // add le ASC
451        cols.push(PhysicalSortRequirement {
452            expr: Arc::new(PhyCast::new(
453                Arc::new(PhyColumn::new(
454                    self.input.schema().field(self.le_column_index).name(),
455                    self.le_column_index,
456                )),
457                DataType::Float64,
458                None,
459            )),
460            options: Some(SortOptions {
461                descending: false,  // +INF in the last
462                nulls_first: false, // not nullable
463            }),
464        });
465
466        // Safety: `cols` is not empty
467        let requirement = LexRequirement::new(cols).unwrap();
468
469        vec![Some(OrderingRequirements::Hard(vec![requirement]))]
470    }
471
472    fn required_input_distribution(&self) -> Vec<Distribution> {
473        vec![Distribution::HashPartitioned(self.partition_exprs.clone())]
474    }
475
476    fn maintains_input_order(&self) -> Vec<bool> {
477        vec![true; self.children().len()]
478    }
479
480    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
481        vec![&self.input]
482    }
483
484    // cannot change schema with this method
485    fn with_new_children(
486        self: Arc<Self>,
487        children: Vec<Arc<dyn ExecutionPlan>>,
488    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
489        assert!(!children.is_empty());
490        let new_input = children[0].clone();
491        let properties = Arc::new(PlanProperties::new(
492            EquivalenceProperties::new(self.output_schema.clone()),
493            Partitioning::Hash(
494                self.partition_exprs.clone(),
495                new_input.output_partitioning().partition_count(),
496            ),
497            EmissionType::Incremental,
498            Boundedness::Bounded,
499        ));
500        Ok(Arc::new(Self {
501            input: new_input,
502            metric: self.metric.clone(),
503            le_column_index: self.le_column_index,
504            ts_column_index: self.ts_column_index,
505            tag_columns: self.tag_columns.clone(),
506            partition_exprs: self.partition_exprs.clone(),
507            quantile: self.quantile,
508            output_schema: self.output_schema.clone(),
509            field_column_index: self.field_column_index,
510            properties,
511        }))
512    }
513
514    fn execute(
515        &self,
516        partition: usize,
517        context: Arc<TaskContext>,
518    ) -> DataFusionResult<SendableRecordBatchStream> {
519        let baseline_metric = BaselineMetrics::new(&self.metric, partition);
520
521        let batch_size = context.session_config().batch_size();
522        let input = self.input.execute(partition, context)?;
523        let output_schema = self.output_schema.clone();
524
525        let mut normal_indices = (0..input.schema().fields().len()).collect::<HashSet<_>>();
526        normal_indices.remove(&self.field_column_index);
527        normal_indices.remove(&self.le_column_index);
528        Ok(Box::pin(HistogramFoldStream {
529            le_column_index: self.le_column_index,
530            field_column_index: self.field_column_index,
531            quantile: self.quantile,
532            normal_indices: normal_indices.into_iter().collect(),
533            bucket_size: None,
534            input_buffer: vec![],
535            input,
536            output_schema,
537            input_schema: self.input.schema(),
538            mode: FoldMode::Optimistic,
539            safe_group: None,
540            metric: baseline_metric,
541            batch_size,
542            input_buffered_rows: 0,
543            output_buffer: HistogramFoldStream::empty_output_buffer(
544                &self.output_schema,
545                self.le_column_index,
546            )?,
547            output_buffered_rows: 0,
548        }))
549    }
550
551    fn metrics(&self) -> Option<MetricsSet> {
552        Some(self.metric.clone_inner())
553    }
554
555    fn partition_statistics(&self, _: Option<usize>) -> DataFusionResult<Statistics> {
556        Ok(Statistics {
557            num_rows: Precision::Absent,
558            total_byte_size: Precision::Absent,
559            column_statistics: Statistics::unknown_column(&self.schema()),
560        })
561    }
562
563    fn name(&self) -> &str {
564        "HistogramFoldExec"
565    }
566}
567
568impl DisplayAs for HistogramFoldExec {
569    fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
570        match t {
571            DisplayFormatType::Default
572            | DisplayFormatType::Verbose
573            | DisplayFormatType::TreeRender => {
574                write!(
575                    f,
576                    "HistogramFoldExec: le=@{}, field=@{}, quantile={}",
577                    self.le_column_index, self.field_column_index, self.quantile
578                )
579            }
580        }
581    }
582}
583
/// How `HistogramFoldStream` groups input rows into buckets.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FoldMode {
    /// Assume every group has the bucket count detected from the first
    /// complete group, and fold fixed-size chunks.
    Optimistic,
    /// Fallback entered once a chunk fails validation in optimistic mode.
    Safe,
}
589
590pub struct HistogramFoldStream {
591    // internal states
592    le_column_index: usize,
593    field_column_index: usize,
594    quantile: f64,
595    /// Columns need not folding. This indices is based on input schema
596    normal_indices: Vec<usize>,
597    bucket_size: Option<usize>,
598    /// Expected output batch size
599    batch_size: usize,
600    output_schema: SchemaRef,
601    input_schema: SchemaRef,
602    mode: FoldMode,
603    safe_group: Option<SafeGroup>,
604
605    // buffers
606    input_buffer: Vec<RecordBatch>,
607    input_buffered_rows: usize,
608    output_buffer: Vec<Box<dyn MutableVector>>,
609    output_buffered_rows: usize,
610
611    // runtime things
612    input: SendableRecordBatchStream,
613    metric: BaselineMetrics,
614}
615
616#[derive(Debug, Default)]
617struct SafeGroup {
618    tag_values: Vec<Value>,
619    buckets: Vec<f64>,
620    counters: Vec<f64>,
621}
622
623impl RecordBatchStream for HistogramFoldStream {
624    fn schema(&self) -> SchemaRef {
625        self.output_schema.clone()
626    }
627}
628
629impl Stream for HistogramFoldStream {
630    type Item = DataFusionResult<RecordBatch>;
631
632    fn poll_next(
633        mut self: std::pin::Pin<&mut Self>,
634        cx: &mut std::task::Context<'_>,
635    ) -> Poll<Option<Self::Item>> {
636        let poll = loop {
637            match ready!(self.input.poll_next_unpin(cx)) {
638                Some(batch) => {
639                    let batch = batch?;
640                    let timer = Instant::now();
641                    let Some(result) = self.fold_input(batch)? else {
642                        self.metric.elapsed_compute().add_elapsed(timer);
643                        continue;
644                    };
645                    self.metric.elapsed_compute().add_elapsed(timer);
646                    break Poll::Ready(Some(result));
647                }
648                None => {
649                    self.flush_remaining()?;
650                    break Poll::Ready(self.take_output_buf()?.map(Ok));
651                }
652            }
653        };
654        self.metric.record_poll(poll)
655    }
656}
657
658impl HistogramFoldStream {
659    /// The inner most `Result` is for `poll_next()`
660    pub fn fold_input(
661        &mut self,
662        input: RecordBatch,
663    ) -> DataFusionResult<Option<DataFusionResult<RecordBatch>>> {
664        match self.mode {
665            FoldMode::Safe => {
666                self.push_input_buf(input);
667                self.process_safe_mode_buffer()?;
668            }
669            FoldMode::Optimistic => {
670                self.push_input_buf(input);
671                let Some(bucket_num) = self.calculate_bucket_num_from_buffer()? else {
672                    return Ok(None);
673                };
674                self.bucket_size = Some(bucket_num);
675
676                if self.input_buffered_rows < bucket_num {
677                    // not enough rows to fold
678                    return Ok(None);
679                }
680
681                self.fold_buf(bucket_num)?;
682            }
683        }
684
685        self.maybe_take_output()
686    }
687
688    /// Generate a group of empty [MutableVector]s from the output schema.
689    ///
690    /// For simplicity, this method will insert a placeholder for `le`. So that
691    /// the output buffers has the same schema with input. This placeholder needs
692    /// to be removed before returning the output batch.
693    pub fn empty_output_buffer(
694        schema: &SchemaRef,
695        le_column_index: usize,
696    ) -> DataFusionResult<Vec<Box<dyn MutableVector>>> {
697        let mut builders = Vec::with_capacity(schema.fields().len() + 1);
698        for field in schema.fields() {
699            let concrete_datatype = ConcreteDataType::try_from(field.data_type()).unwrap();
700            let mutable_vector = concrete_datatype.create_mutable_vector(0);
701            builders.push(mutable_vector);
702        }
703        builders.insert(
704            le_column_index,
705            ConcreteDataType::float64_datatype().create_mutable_vector(0),
706        );
707
708        Ok(builders)
709    }
710
711    /// Determines bucket count using buffered batches, concatenating them to
712    /// detect the first complete bucket that may span batch boundaries.
713    fn calculate_bucket_num_from_buffer(&mut self) -> DataFusionResult<Option<usize>> {
714        if let Some(size) = self.bucket_size {
715            return Ok(Some(size));
716        }
717
718        if self.input_buffer.is_empty() {
719            return Ok(None);
720        }
721
722        let batch_refs: Vec<&RecordBatch> = self.input_buffer.iter().collect();
723        let batch = concat_batches(&self.input_schema, batch_refs)?;
724        self.find_first_complete_bucket(&batch)
725    }
726
727    fn find_first_complete_bucket(&self, batch: &RecordBatch) -> DataFusionResult<Option<usize>> {
728        if batch.num_rows() == 0 {
729            return Ok(None);
730        }
731
732        let vectors = Helper::try_into_vectors(batch.columns())
733            .map_err(|e| DataFusionError::Execution(e.to_string()))?;
734        let le_array = batch.column(self.le_column_index).as_string::<i32>();
735
736        let mut tag_values_buf = Vec::with_capacity(self.normal_indices.len());
737        self.collect_tag_values(&vectors, 0, &mut tag_values_buf);
738        let mut group_start = 0usize;
739
740        for row in 0..batch.num_rows() {
741            if !self.is_same_group(&vectors, row, &tag_values_buf) {
742                // new group begins
743                self.collect_tag_values(&vectors, row, &mut tag_values_buf);
744                group_start = row;
745            }
746
747            if Self::is_positive_infinity(le_array, row) {
748                return Ok(Some(row - group_start + 1));
749            }
750        }
751
752        Ok(None)
753    }
754
755    /// Fold record batches from input buffer and put to output buffer
756    fn fold_buf(&mut self, bucket_num: usize) -> DataFusionResult<()> {
757        let batch = concat_batches(&self.input_schema, self.input_buffer.drain(..).as_ref())?;
758        let mut remaining_rows = self.input_buffered_rows;
759        let mut cursor = 0;
760
761        // TODO(LFC): Try to get rid of the Arrow array to vector conversion here.
762        let vectors = Helper::try_into_vectors(batch.columns())
763            .map_err(|e| DataFusionError::Execution(e.to_string()))?;
764        let le_array = batch.column(self.le_column_index);
765        let le_array = le_array.as_string::<i32>();
766        let field_array = batch.column(self.field_column_index);
767        let field_array = field_array.as_primitive::<Float64Type>();
768        let mut tag_values_buf = Vec::with_capacity(self.normal_indices.len());
769
770        while remaining_rows >= bucket_num && self.mode == FoldMode::Optimistic {
771            self.collect_tag_values(&vectors, cursor, &mut tag_values_buf);
772            if !self.validate_optimistic_group(
773                &vectors,
774                le_array,
775                cursor,
776                bucket_num,
777                &tag_values_buf,
778            ) {
779                let remaining_input_batch = batch.slice(cursor, remaining_rows);
780                self.switch_to_safe_mode(remaining_input_batch)?;
781                return Ok(());
782            }
783
784            // "sample" normal columns
785            for (idx, value) in self.normal_indices.iter().zip(tag_values_buf.iter()) {
786                self.output_buffer[*idx].push_value_ref(value);
787            }
788            // "fold" `le` and field columns
789            let mut bucket = Vec::with_capacity(bucket_num);
790            let mut counters = Vec::with_capacity(bucket_num);
791            for bias in 0..bucket_num {
792                let position = cursor + bias;
793                let le = if le_array.is_valid(position) {
794                    le_array.value(position).parse::<f64>().unwrap_or(f64::NAN)
795                } else {
796                    f64::NAN
797                };
798                bucket.push(le);
799
800                let counter = if field_array.is_valid(position) {
801                    field_array.value(position)
802                } else {
803                    f64::NAN
804                };
805                counters.push(counter);
806            }
807            // ignore invalid data
808            let result = Self::evaluate_row(self.quantile, &bucket, &counters).unwrap_or(f64::NAN);
809            self.output_buffer[self.field_column_index].push_value_ref(&ValueRef::from(result));
810            cursor += bucket_num;
811            remaining_rows -= bucket_num;
812            self.output_buffered_rows += 1;
813        }
814
815        let remaining_input_batch = batch.slice(cursor, remaining_rows);
816        self.input_buffered_rows = remaining_input_batch.num_rows();
817        if self.input_buffered_rows > 0 {
818            self.input_buffer.push(remaining_input_batch);
819        }
820
821        Ok(())
822    }
823
824    fn push_input_buf(&mut self, batch: RecordBatch) {
825        self.input_buffered_rows += batch.num_rows();
826        self.input_buffer.push(batch);
827    }
828
829    fn maybe_take_output(&mut self) -> DataFusionResult<Option<DataFusionResult<RecordBatch>>> {
830        if self.output_buffered_rows >= self.batch_size {
831            return Ok(self.take_output_buf()?.map(Ok));
832        }
833        Ok(None)
834    }
835
836    fn switch_to_safe_mode(&mut self, remaining_batch: RecordBatch) -> DataFusionResult<()> {
837        self.mode = FoldMode::Safe;
838        self.bucket_size = None;
839        self.input_buffer.clear();
840        self.input_buffered_rows = remaining_batch.num_rows();
841
842        if self.input_buffered_rows > 0 {
843            self.input_buffer.push(remaining_batch);
844            self.process_safe_mode_buffer()?;
845        }
846
847        Ok(())
848    }
849
850    fn collect_tag_values<'a>(
851        &self,
852        vectors: &'a [VectorRef],
853        row: usize,
854        tag_values: &mut Vec<ValueRef<'a>>,
855    ) {
856        tag_values.clear();
857        for idx in self.normal_indices.iter() {
858            tag_values.push(vectors[*idx].get_ref(row));
859        }
860    }
861
862    fn validate_optimistic_group(
863        &self,
864        vectors: &[VectorRef],
865        le_array: &StringArray,
866        cursor: usize,
867        bucket_num: usize,
868        tag_values: &[ValueRef<'_>],
869    ) -> bool {
870        let inf_index = cursor + bucket_num - 1;
871        if !Self::is_positive_infinity(le_array, inf_index) {
872            return false;
873        }
874
875        for offset in 1..bucket_num {
876            let row = cursor + offset;
877            for (idx, expected) in self.normal_indices.iter().zip(tag_values.iter()) {
878                if vectors[*idx].get_ref(row) != *expected {
879                    return false;
880                }
881            }
882        }
883        true
884    }
885
886    /// Checks whether a row belongs to the current group (same series).
887    fn is_same_group(
888        &self,
889        vectors: &[VectorRef],
890        row: usize,
891        tag_values: &[ValueRef<'_>],
892    ) -> bool {
893        self.normal_indices
894            .iter()
895            .zip(tag_values.iter())
896            .all(|(idx, expected)| vectors[*idx].get_ref(row) == *expected)
897    }
898
899    fn push_output_row(&mut self, tag_values: &[ValueRef<'_>], result: f64) {
900        debug_assert_eq!(self.normal_indices.len(), tag_values.len());
901        for (idx, value) in self.normal_indices.iter().zip(tag_values.iter()) {
902            self.output_buffer[*idx].push_value_ref(value);
903        }
904        self.output_buffer[self.field_column_index].push_value_ref(&ValueRef::from(result));
905        self.output_buffered_rows += 1;
906    }
907
908    fn finalize_safe_group(&mut self) -> DataFusionResult<()> {
909        if let Some(group) = self.safe_group.take() {
910            if group.tag_values.is_empty() {
911                return Ok(());
912            }
913
914            let has_inf = group
915                .buckets
916                .last()
917                .map(|v| v.is_infinite() && v.is_sign_positive())
918                .unwrap_or(false);
919            let result = if group.buckets.len() < 2 || !has_inf {
920                f64::NAN
921            } else {
922                Self::evaluate_row(self.quantile, &group.buckets, &group.counters)
923                    .unwrap_or(f64::NAN)
924            };
925            let mut tag_value_refs = Vec::with_capacity(group.tag_values.len());
926            tag_value_refs.extend(group.tag_values.iter().map(|v| v.as_value_ref()));
927            self.push_output_row(&tag_value_refs, result);
928        }
929        Ok(())
930    }
931
932    fn process_safe_mode_buffer(&mut self) -> DataFusionResult<()> {
933        if self.input_buffer.is_empty() {
934            self.input_buffered_rows = 0;
935            return Ok(());
936        }
937
938        let batch = concat_batches(&self.input_schema, self.input_buffer.drain(..).as_ref())?;
939        self.input_buffered_rows = 0;
940        let vectors = Helper::try_into_vectors(batch.columns())
941            .map_err(|e| DataFusionError::Execution(e.to_string()))?;
942        let le_array = batch.column(self.le_column_index).as_string::<i32>();
943        let field_array = batch
944            .column(self.field_column_index)
945            .as_primitive::<Float64Type>();
946        let mut tag_values_buf = Vec::with_capacity(self.normal_indices.len());
947
948        for row in 0..batch.num_rows() {
949            self.collect_tag_values(&vectors, row, &mut tag_values_buf);
950            let should_start_new_group = self
951                .safe_group
952                .as_ref()
953                .is_none_or(|group| !Self::tag_values_equal(&group.tag_values, &tag_values_buf));
954            if should_start_new_group {
955                self.finalize_safe_group()?;
956                self.safe_group = Some(SafeGroup {
957                    tag_values: tag_values_buf.iter().cloned().map(Value::from).collect(),
958                    buckets: Vec::new(),
959                    counters: Vec::new(),
960                });
961            }
962
963            let Some(group) = self.safe_group.as_mut() else {
964                continue;
965            };
966
967            let bucket = if le_array.is_valid(row) {
968                le_array.value(row).parse::<f64>().unwrap_or(f64::NAN)
969            } else {
970                f64::NAN
971            };
972            let counter = if field_array.is_valid(row) {
973                field_array.value(row)
974            } else {
975                f64::NAN
976            };
977
978            group.buckets.push(bucket);
979            group.counters.push(counter);
980        }
981
982        Ok(())
983    }
984
985    fn tag_values_equal(group_values: &[Value], current: &[ValueRef<'_>]) -> bool {
986        group_values.len() == current.len()
987            && group_values
988                .iter()
989                .zip(current.iter())
990                .all(|(group, now)| group.as_value_ref() == *now)
991    }
992
993    /// Compute result from output buffer
994    fn take_output_buf(&mut self) -> DataFusionResult<Option<RecordBatch>> {
995        if self.output_buffered_rows == 0 {
996            if self.input_buffered_rows != 0 {
997                warn!(
998                    "input buffer is not empty, {} rows remaining",
999                    self.input_buffered_rows
1000                );
1001            }
1002            return Ok(None);
1003        }
1004
1005        let mut output_buf = Self::empty_output_buffer(&self.output_schema, self.le_column_index)?;
1006        std::mem::swap(&mut self.output_buffer, &mut output_buf);
1007        let mut columns = Vec::with_capacity(output_buf.len());
1008        for builder in output_buf.iter_mut() {
1009            columns.push(builder.to_vector().to_arrow_array());
1010        }
1011        // remove the placeholder column for `le`
1012        columns.remove(self.le_column_index);
1013
1014        self.output_buffered_rows = 0;
1015        RecordBatch::try_new(self.output_schema.clone(), columns)
1016            .map(Some)
1017            .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))
1018    }
1019
1020    fn flush_remaining(&mut self) -> DataFusionResult<()> {
1021        if self.mode == FoldMode::Optimistic && self.input_buffered_rows > 0 {
1022            let buffered_batches: Vec<_> = self.input_buffer.drain(..).collect();
1023            if !buffered_batches.is_empty() {
1024                let batch = concat_batches(&self.input_schema, buffered_batches.as_slice())?;
1025                self.switch_to_safe_mode(batch)?;
1026            } else {
1027                self.input_buffered_rows = 0;
1028            }
1029        }
1030
1031        if self.mode == FoldMode::Safe {
1032            self.process_safe_mode_buffer()?;
1033            self.finalize_safe_group()?;
1034        }
1035
1036        Ok(())
1037    }
1038
1039    fn is_positive_infinity(le_array: &StringArray, index: usize) -> bool {
1040        le_array.is_valid(index)
1041            && matches!(
1042                le_array.value(index).parse::<f64>(),
1043                Ok(value) if value.is_infinite() && value.is_sign_positive()
1044            )
1045    }
1046
1047    /// Evaluate the field column and return the result
1048    fn evaluate_row(quantile: f64, bucket: &[f64], counter: &[f64]) -> DataFusionResult<f64> {
1049        // check bucket
1050        if bucket.len() <= 1 {
1051            return Ok(f64::NAN);
1052        }
1053        if bucket.last().unwrap().is_finite() {
1054            return Err(DataFusionError::Execution(
1055                "last bucket should be +Inf".to_string(),
1056            ));
1057        }
1058        if bucket.len() != counter.len() {
1059            return Err(DataFusionError::Execution(
1060                "bucket and counter should have the same length".to_string(),
1061            ));
1062        }
1063        // check quantile
1064        if quantile < 0.0 {
1065            return Ok(f64::NEG_INFINITY);
1066        } else if quantile > 1.0 {
1067            return Ok(f64::INFINITY);
1068        } else if quantile.is_nan() {
1069            return Ok(f64::NAN);
1070        }
1071
1072        // check input value
1073        if !bucket.windows(2).all(|w| w[0] <= w[1]) {
1074            return Ok(f64::NAN);
1075        }
1076        let counter = {
1077            let needs_fix =
1078                counter.iter().any(|v| !v.is_finite()) || !counter.windows(2).all(|w| w[0] <= w[1]);
1079            if !needs_fix {
1080                Cow::Borrowed(counter)
1081            } else {
1082                let mut fixed = Vec::with_capacity(counter.len());
1083                let mut prev = 0.0;
1084                for (idx, &v) in counter.iter().enumerate() {
1085                    let mut val = if v.is_finite() { v } else { prev };
1086                    if idx > 0 && val < prev {
1087                        val = prev;
1088                    }
1089                    fixed.push(val);
1090                    prev = val;
1091                }
1092                Cow::Owned(fixed)
1093            }
1094        };
1095
1096        let total = *counter.last().unwrap();
1097        let expected_pos = total * quantile;
1098        let mut fit_bucket_pos = 0;
1099        while fit_bucket_pos < bucket.len() && counter[fit_bucket_pos] < expected_pos {
1100            fit_bucket_pos += 1;
1101        }
1102        if fit_bucket_pos >= bucket.len() - 1 {
1103            Ok(bucket[bucket.len() - 2])
1104        } else {
1105            let upper_bound = bucket[fit_bucket_pos];
1106            let upper_count = counter[fit_bucket_pos];
1107            let mut lower_bound = bucket[0].min(0.0);
1108            let mut lower_count = 0.0;
1109            if fit_bucket_pos > 0 {
1110                lower_bound = bucket[fit_bucket_pos - 1];
1111                lower_count = counter[fit_bucket_pos - 1];
1112            }
1113            if (upper_count - lower_count).abs() < 1e-10 {
1114                return Ok(f64::NAN);
1115            }
1116            Ok(lower_bound
1117                + (upper_bound - lower_bound) / (upper_count - lower_count)
1118                    * (expected_pos - lower_count))
1119        }
1120    }
1121}
1122
1123#[cfg(test)]
1124mod test {
1125    use std::sync::Arc;
1126
1127    use datafusion::arrow::array::{Float64Array, TimestampMillisecondArray};
1128    use datafusion::arrow::datatypes::{Field, Schema, SchemaRef, TimeUnit};
1129    use datafusion::common::ToDFSchema;
1130    use datafusion::datasource::memory::MemorySourceConfig;
1131    use datafusion::datasource::source::DataSourceExec;
1132    use datafusion::logical_expr::EmptyRelation;
1133    use datafusion::prelude::SessionContext;
1134    use datatypes::arrow_array::StringArray;
1135    use futures::FutureExt;
1136
1137    use super::*;
1138
1139    fn project_batch(batch: &RecordBatch, indices: &[usize]) -> RecordBatch {
1140        let fields = indices
1141            .iter()
1142            .map(|&idx| batch.schema().field(idx).clone())
1143            .collect::<Vec<_>>();
1144        let columns = indices
1145            .iter()
1146            .map(|&idx| batch.column(idx).clone())
1147            .collect::<Vec<_>>();
1148        let schema = Arc::new(Schema::new(fields));
1149        RecordBatch::try_new(schema, columns).unwrap()
1150    }
1151
    /// Builds a single-partition in-memory source of `host`/`le`/`val` rows, split into
    /// three batches.
    ///
    /// The rows form five 5-bucket histograms: three for `host_1` and two for `host_2`.
    /// The third `host_1` histogram starts at the end of the first batch, continues
    /// through the second, and ends with the first row of the third. Its folding
    /// therefore has to carry state across batch boundaries.
    fn prepare_test_data() -> DataSourceExec {
        let schema = Arc::new(Schema::new(vec![
            Field::new("host", DataType::Utf8, true),
            Field::new("le", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ]));

        // 12 items
        let host_column_1 = Arc::new(StringArray::from(vec![
            "host_1", "host_1", "host_1", "host_1", "host_1", "host_1", "host_1", "host_1",
            "host_1", "host_1", "host_1", "host_1",
        ])) as _;
        // `+Inf` / `+inf` (and `+INF` / `+iNf` below) exercise case-insensitive parsing.
        let le_column_1 = Arc::new(StringArray::from(vec![
            "0.001", "0.1", "10", "1000", "+Inf", "0.001", "0.1", "10", "1000", "+inf", "0.001",
            "0.1",
        ])) as _;
        // `0_0.0` and `0_1.0` are simply 0.0 and 1.0 written with a digit separator.
        let val_column_1 = Arc::new(Float64Array::from(vec![
            0_0.0, 1.0, 1.0, 5.0, 5.0, 0_0.0, 20.0, 60.0, 70.0, 100.0, 0_1.0, 1.0,
        ])) as _;

        // 2 items
        let host_column_2 = Arc::new(StringArray::from(vec!["host_1", "host_1"])) as _;
        let le_column_2 = Arc::new(StringArray::from(vec!["10", "1000"])) as _;
        let val_column_2 = Arc::new(Float64Array::from(vec![1.0, 1.0])) as _;

        // 11 items
        let host_column_3 = Arc::new(StringArray::from(vec![
            "host_1", "host_2", "host_2", "host_2", "host_2", "host_2", "host_2", "host_2",
            "host_2", "host_2", "host_2",
        ])) as _;
        let le_column_3 = Arc::new(StringArray::from(vec![
            "+INF", "0.001", "0.1", "10", "1000", "+iNf", "0.001", "0.1", "10", "1000", "+Inf",
        ])) as _;
        let val_column_3 = Arc::new(Float64Array::from(vec![
            1.0, 0_0.0, 0.0, 0.0, 0.0, 0.0, 0_0.0, 1.0, 2.0, 3.0, 4.0,
        ])) as _;

        let data_1 = RecordBatch::try_new(
            schema.clone(),
            vec![host_column_1, le_column_1, val_column_1],
        )
        .unwrap();
        let data_2 = RecordBatch::try_new(
            schema.clone(),
            vec![host_column_2, le_column_2, val_column_2],
        )
        .unwrap();
        let data_3 = RecordBatch::try_new(
            schema.clone(),
            vec![host_column_3, le_column_3, val_column_3],
        )
        .unwrap();

        // One partition holding all three batches, in order.
        DataSourceExec::new(Arc::new(
            MemorySourceConfig::try_new(&[vec![data_1, data_2, data_3]], schema, None).unwrap(),
        ))
    }
1209
1210    fn build_fold_exec_from_batches(
1211        batches: Vec<RecordBatch>,
1212        schema: SchemaRef,
1213        quantile: f64,
1214        ts_column_index: usize,
1215    ) -> Arc<HistogramFoldExec> {
1216        let input: Arc<dyn ExecutionPlan> = Arc::new(DataSourceExec::new(Arc::new(
1217            MemorySourceConfig::try_new(&[batches], schema.clone(), None).unwrap(),
1218        )));
1219        let output_schema: SchemaRef = Arc::new(
1220            HistogramFold::convert_schema(&Arc::new(input.schema().to_dfschema().unwrap()), "le")
1221                .unwrap()
1222                .as_arrow()
1223                .clone(),
1224        );
1225
1226        let (tag_columns, partition_exprs, properties) =
1227            build_test_plan_properties(&input, output_schema.clone(), ts_column_index);
1228
1229        Arc::new(HistogramFoldExec {
1230            le_column_index: 1,
1231            field_column_index: 2,
1232            quantile,
1233            ts_column_index,
1234            input,
1235            output_schema,
1236            tag_columns,
1237            partition_exprs,
1238            metric: ExecutionPlanMetricsSet::new(),
1239            properties,
1240        })
1241    }
1242
1243    type PlanPropsResult = (
1244        Vec<Arc<dyn PhysicalExpr>>,
1245        Vec<Arc<dyn PhysicalExpr>>,
1246        Arc<PlanProperties>,
1247    );
1248
1249    fn build_test_plan_properties(
1250        input: &Arc<dyn ExecutionPlan>,
1251        output_schema: SchemaRef,
1252        ts_column_index: usize,
1253    ) -> PlanPropsResult {
1254        let tag_columns = input
1255            .schema()
1256            .fields()
1257            .iter()
1258            .enumerate()
1259            .filter_map(|(idx, field)| {
1260                if idx == 1 || idx == 2 || idx == ts_column_index {
1261                    None
1262                } else {
1263                    Some(Arc::new(PhyColumn::new(field.name(), idx)) as _)
1264                }
1265            })
1266            .collect::<Vec<_>>();
1267
1268        let partition_exprs = if tag_columns.is_empty() {
1269            vec![Arc::new(PhyColumn::new(
1270                input.schema().field(ts_column_index).name(),
1271                ts_column_index,
1272            )) as _]
1273        } else {
1274            tag_columns.clone()
1275        };
1276
1277        let properties = PlanProperties::new(
1278            EquivalenceProperties::new(output_schema.clone()),
1279            Partitioning::Hash(
1280                partition_exprs.clone(),
1281                input.output_partitioning().partition_count(),
1282            ),
1283            EmissionType::Incremental,
1284            Boundedness::Bounded,
1285        );
1286
1287        (tag_columns, partition_exprs, Arc::new(properties))
1288    }
1289
    /// End-to-end fold at quantile 0.4 over the three-batch fixture from
    /// `prepare_test_data`, including the `host_1` histogram that spans batches.
    #[tokio::test]
    async fn fold_overall() {
        let memory_exec: Arc<dyn ExecutionPlan> = Arc::new(prepare_test_data());
        let output_schema: SchemaRef = Arc::new(
            HistogramFold::convert_schema(
                &Arc::new(memory_exec.schema().to_dfschema().unwrap()),
                "le",
            )
            .unwrap()
            .as_arrow()
            .clone(),
        );
        // The fixture has no timestamp column; index 0 (`host`) is passed as the ts column,
        // so it is excluded from the tag columns and used for partitioning instead.
        let (tag_columns, partition_exprs, properties) =
            build_test_plan_properties(&memory_exec, output_schema.clone(), 0);
        let fold_exec = Arc::new(HistogramFoldExec {
            le_column_index: 1,
            field_column_index: 2,
            quantile: 0.4,
            ts_column_index: 0,
            input: memory_exec,
            output_schema,
            tag_columns,
            partition_exprs,
            metric: ExecutionPlanMetricsSet::new(),
            properties,
        });

        let session_context = SessionContext::default();
        let result = datafusion::physical_plan::collect(fold_exec, session_context.task_ctx())
            .await
            .unwrap();
        let result_literal = datatypes::arrow::util::pretty::pretty_format_batches(&result)
            .unwrap()
            .to_string();

        // One row per histogram: three for host_1, two for host_2. The all-zero host_2
        // histogram has no observations and therefore folds to NaN.
        let expected = String::from(
            "+--------+-------------------+
| host   | val               |
+--------+-------------------+
| host_1 | 257.5             |
| host_1 | 5.05              |
| host_1 | 0.0004            |
| host_2 | NaN               |
| host_2 | 6.040000000000001 |
+--------+-------------------+",
        );
        assert_eq!(result_literal, expected);
    }
1338
    /// Column pruning must keep `le` in the child's required columns even though the
    /// fold's output does not contain it. Executing without it is expected to panic.
    #[tokio::test]
    async fn pruning_should_keep_le_column_for_exec() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
            Field::new("le", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ]));
        let df_schema = schema.clone().to_dfschema_ref().unwrap();
        let input = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: df_schema,
        });
        let plan = HistogramFold::new(
            "le".to_string(),
            "val".to_string(),
            "ts".to_string(),
            0.5,
            input,
        )
        .unwrap();

        // Request both output columns (ts, val); the child must still supply all three
        // input columns, including `le` at index 1.
        let output_columns = [0usize, 1usize];
        let required = plan.necessary_children_exprs(&output_columns).unwrap();
        let required = &required[0];
        assert_eq!(required.as_slice(), &[0, 1, 2]);

        let input_batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(TimestampMillisecondArray::from(vec![0, 0])),
                Arc::new(StringArray::from(vec!["0.1", "+Inf"])),
                Arc::new(Float64Array::from(vec![1.0, 2.0])),
            ],
        )
        .unwrap();
        let projected = project_batch(&input_batch, required);
        let projected_schema = projected.schema();
        let memory_exec = Arc::new(DataSourceExec::new(Arc::new(
            MemorySourceConfig::try_new(&[vec![projected]], projected_schema, None).unwrap(),
        )));

        let fold_exec = plan.to_execution_plan(memory_exec);
        let session_context = SessionContext::default();
        let output_batches =
            datafusion::physical_plan::collect(fold_exec, session_context.task_ctx())
                .await
                .unwrap();
        assert_eq!(output_batches.len(), 1);

        let output_batch = &output_batches[0];
        assert_eq!(output_batch.num_rows(), 1);

        let ts = output_batch
            .column(0)
            .as_any()
            .downcast_ref::<TimestampMillisecondArray>()
            .unwrap();
        assert_eq!(ts.values(), &[0i64]);

        // Buckets (0.1 -> 1, +Inf -> 2) at q = 0.5 interpolate to 0.1.
        let values = output_batch
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert!((values.value(0) - 0.1).abs() < 1e-12);

        // Simulate the pre-fix pruning behavior: omit the `le` column from the child input.
        // Output columns [0, 1] map to input columns [0, 2], i.e. only ts and val.
        let le_index = 1usize;
        let broken_required = output_columns
            .iter()
            .map(|&output_column| {
                if output_column < le_index {
                    output_column
                } else {
                    output_column + 1
                }
            })
            .collect::<Vec<_>>();

        let broken = project_batch(&input_batch, &broken_required);
        let broken_schema = broken.schema();
        let broken_exec = Arc::new(DataSourceExec::new(Arc::new(
            MemorySourceConfig::try_new(&[vec![broken]], broken_schema, None).unwrap(),
        )));
        let broken_fold_exec = plan.to_execution_plan(broken_exec);
        let session_context = SessionContext::default();
        // Execution over the `le`-less input must panic; `catch_unwind` turns the panic
        // into an `Err`.
        let broken_result = std::panic::AssertUnwindSafe(async {
            datafusion::physical_plan::collect(broken_fold_exec, session_context.task_ctx()).await
        })
        .catch_unwind()
        .await;
        assert!(broken_result.is_err());
    }
1432
1433    #[test]
1434    fn confirm_schema() {
1435        let input_schema = Schema::new(vec![
1436            Field::new("host", DataType::Utf8, true),
1437            Field::new("le", DataType::Utf8, true),
1438            Field::new("val", DataType::Float64, true),
1439        ])
1440        .to_dfschema_ref()
1441        .unwrap();
1442        let expected_output_schema = Schema::new(vec![
1443            Field::new("host", DataType::Utf8, true),
1444            Field::new("val", DataType::Float64, true),
1445        ])
1446        .to_dfschema_ref()
1447        .unwrap();
1448
1449        let actual = HistogramFold::convert_schema(&input_schema, "le").unwrap();
1450        assert_eq!(actual, expected_output_schema)
1451    }
1452
    /// A histogram without a `+Inf` bucket in the middle of the input must yield NaN
    /// for that histogram only. Later series must still fold correctly.
    #[tokio::test]
    async fn fallback_to_safe_mode_on_missing_inf() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("host", DataType::Utf8, true),
            Field::new("le", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ]));
        // Host `a` has two histograms: (0.1, +Inf) is complete, while (0.1, 1.0) has no
        // `+Inf` bucket. Host `b` has one complete histogram.
        let host_column = Arc::new(StringArray::from(vec!["a", "a", "a", "a", "b", "b"])) as _;
        let le_column = Arc::new(StringArray::from(vec![
            "0.1", "+Inf", "0.1", "1.0", "0.1", "+Inf",
        ])) as _;
        let val_column = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 3.0, 1.0, 5.0])) as _;
        let batch =
            RecordBatch::try_new(schema.clone(), vec![host_column, le_column, val_column]).unwrap();
        let fold_exec = build_fold_exec_from_batches(vec![batch], schema, 0.5, 0);
        let session_context = SessionContext::default();
        let result = datafusion::physical_plan::collect(fold_exec, session_context.task_ctx())
            .await
            .unwrap();
        let result_literal = datatypes::arrow::util::pretty::pretty_format_batches(&result)
            .unwrap()
            .to_string();

        let expected = String::from(
            "+------+-----+
| host | val |
+------+-----+
| a    | 0.1 |
| a    | NaN |
| b    | 0.1 |
+------+-----+",
        );
        assert_eq!(result_literal, expected);
    }
1487
    /// A series whose only histogram has no `+Inf` bucket still produces one output
    /// row, with value NaN.
    #[tokio::test]
    async fn emit_nan_when_no_inf_present() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("host", DataType::Utf8, true),
            Field::new("le", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ]));
        let host_column = Arc::new(StringArray::from(vec!["c", "c"])) as _;
        let le_column = Arc::new(StringArray::from(vec!["0.1", "1.0"])) as _;
        let val_column = Arc::new(Float64Array::from(vec![1.0, 2.0])) as _;
        let batch =
            RecordBatch::try_new(schema.clone(), vec![host_column, le_column, val_column]).unwrap();
        let fold_exec = build_fold_exec_from_batches(vec![batch], schema, 0.9, 0);
        let session_context = SessionContext::default();
        let result = datafusion::physical_plan::collect(fold_exec, session_context.task_ctx())
            .await
            .unwrap();
        let result_literal = datatypes::arrow::util::pretty::pretty_format_batches(&result)
            .unwrap()
            .to_string();

        let expected = String::from(
            "+------+-----+
| host | val |
+------+-----+
| c    | NaN |
+------+-----+",
        );
        assert_eq!(result_literal, expected);
    }
1518
    /// Consecutive histograms with different bucket counts (3, 4, 2 and 5 rows per
    /// timestamp) must each be folded on their own.
    #[tokio::test]
    async fn safe_mode_handles_misaligned_groups() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
            Field::new("le", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ]));

        let ts_column = Arc::new(TimestampMillisecondArray::from(vec![
            2900000, 2900000, 2900000, 3000000, 3000000, 3000000, 3000000, 3005000, 3005000,
            3010000, 3010000, 3010000, 3010000, 3010000,
        ])) as _;
        let le_column = Arc::new(StringArray::from(vec![
            "0.1", "1", "5", "0.1", "1", "5", "+Inf", "0.1", "+Inf", "0.1", "1", "3", "5", "+Inf",
        ])) as _;
        let val_column = Arc::new(Float64Array::from(vec![
            0.0, 0.0, 0.0, 50.0, 70.0, 110.0, 120.0, 10.0, 30.0, 10.0, 20.0, 30.0, 40.0, 50.0,
        ])) as _;
        let batch =
            RecordBatch::try_new(schema.clone(), vec![ts_column, le_column, val_column]).unwrap();
        let fold_exec = build_fold_exec_from_batches(vec![batch], schema, 0.5, 0);
        let session_context = SessionContext::default();
        let result = datafusion::physical_plan::collect(fold_exec, session_context.task_ctx())
            .await
            .unwrap();

        let mut values = Vec::new();
        for batch in result {
            let array = batch.column(1).as_primitive::<Float64Type>();
            values.extend(array.iter().map(|v| v.unwrap()));
        }

        // ts=2900000 has no +Inf bucket -> NaN; 3000000 -> 0.55; 3005000 -> 0.1;
        // 3010000 -> 2.0.
        assert_eq!(values.len(), 4);
        assert!(values[0].is_nan());
        assert!((values[1] - 0.55).abs() < 1e-10);
        assert!((values[2] - 0.1).abs() < 1e-10);
        assert!((values[3] - 2.0).abs() < 1e-10);
    }
1557
    /// The first timestamp carries a single (`0.1`) bucket. It must fold to NaN without
    /// disturbing the following histograms.
    #[tokio::test]
    async fn missing_buckets_at_first_timestamp() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
            Field::new("le", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ]));

        let ts_column = Arc::new(TimestampMillisecondArray::from(vec![
            2_900_000, 3_000_000, 3_000_000, 3_000_000, 3_000_000, 3_005_000, 3_005_000, 3_010_000,
            3_010_000, 3_010_000, 3_010_000, 3_010_000,
        ])) as _;
        let le_column = Arc::new(StringArray::from(vec![
            "0.1", "0.1", "1", "5", "+Inf", "0.1", "+Inf", "0.1", "1", "3", "5", "+Inf",
        ])) as _;
        let val_column = Arc::new(Float64Array::from(vec![
            0.0, 50.0, 70.0, 110.0, 120.0, 10.0, 30.0, 10.0, 20.0, 30.0, 40.0, 50.0,
        ])) as _;

        let batch =
            RecordBatch::try_new(schema.clone(), vec![ts_column, le_column, val_column]).unwrap();
        let fold_exec = build_fold_exec_from_batches(vec![batch], schema, 0.5, 0);
        let session_context = SessionContext::default();
        let result = datafusion::physical_plan::collect(fold_exec, session_context.task_ctx())
            .await
            .unwrap();

        let mut values = Vec::new();
        for batch in result {
            let array = batch.column(1).as_primitive::<Float64Type>();
            values.extend(array.iter().map(|v| v.unwrap()));
        }

        // ts=2_900_000 (single bucket) -> NaN; then 0.55, 0.1 and 2.0.
        assert_eq!(values.len(), 4);
        assert!(values[0].is_nan());
        assert!((values[1] - 0.55).abs() < 1e-10);
        assert!((values[2] - 0.1).abs() < 1e-10);
        assert!((values[3] - 2.0).abs() < 1e-10);
    }
1597
    /// The very first histogram lacks a `+Inf` bucket. It must fold to NaN and leave
    /// the next, well-formed histogram intact.
    #[tokio::test]
    async fn missing_inf_in_first_group() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
            Field::new("le", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ]));

        let ts_column = Arc::new(TimestampMillisecondArray::from(vec![
            1000, 1000, 1000, 2000, 2000, 2000, 2000,
        ])) as _;
        let le_column = Arc::new(StringArray::from(vec![
            "0.1", "1", "5", "0.1", "1", "5", "+Inf",
        ])) as _;
        let val_column = Arc::new(Float64Array::from(vec![
            0.0, 0.0, 0.0, 10.0, 20.0, 30.0, 30.0,
        ])) as _;
        let batch =
            RecordBatch::try_new(schema.clone(), vec![ts_column, le_column, val_column]).unwrap();
        let fold_exec = build_fold_exec_from_batches(vec![batch], schema, 0.5, 0);
        let session_context = SessionContext::default();
        let result = datafusion::physical_plan::collect(fold_exec, session_context.task_ctx())
            .await
            .unwrap();

        let mut values = Vec::new();
        for batch in result {
            let array = batch.column(1).as_primitive::<Float64Type>();
            values.extend(array.iter().map(|v| v.unwrap()));
        }

        // ts=1000 (no +Inf) -> NaN; ts=2000 interpolates between 0.1 and 1 -> 0.55.
        assert_eq!(values.len(), 2);
        assert!(values[0].is_nan());
        assert!((values[1] - 0.55).abs() < 1e-10, "{values:?}");
    }
1633
1634    #[test]
1635    fn evaluate_row_normal_case() {
1636        let bucket = [0.0, 1.0, 2.0, 3.0, 4.0, f64::INFINITY];
1637
1638        #[derive(Debug)]
1639        struct Case {
1640            quantile: f64,
1641            counters: Vec<f64>,
1642            expected: f64,
1643        }
1644
1645        let cases = [
1646            Case {
1647                quantile: 0.9,
1648                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
1649                expected: 4.0,
1650            },
1651            Case {
1652                quantile: 0.89,
1653                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
1654                expected: 4.0,
1655            },
1656            Case {
1657                quantile: 0.78,
1658                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
1659                expected: 3.9,
1660            },
1661            Case {
1662                quantile: 0.5,
1663                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
1664                expected: 2.5,
1665            },
1666            Case {
1667                quantile: 0.5,
1668                counters: vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1669                expected: f64::NAN,
1670            },
1671            Case {
1672                quantile: 1.0,
1673                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
1674                expected: 4.0,
1675            },
1676            Case {
1677                quantile: 0.0,
1678                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
1679                expected: f64::NAN,
1680            },
1681            Case {
1682                quantile: 1.1,
1683                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
1684                expected: f64::INFINITY,
1685            },
1686            Case {
1687                quantile: -1.0,
1688                counters: vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0],
1689                expected: f64::NEG_INFINITY,
1690            },
1691        ];
1692
1693        for case in cases {
1694            let actual =
1695                HistogramFoldStream::evaluate_row(case.quantile, &bucket, &case.counters).unwrap();
1696            assert_eq!(
1697                format!("{actual}"),
1698                format!("{}", case.expected),
1699                "{:?}",
1700                case
1701            );
1702        }
1703    }
1704
1705    #[test]
1706    fn evaluate_out_of_order_input() {
1707        let bucket = [0.0, 1.0, 2.0, 3.0, 4.0, f64::INFINITY];
1708        let counters = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0];
1709        let result = HistogramFoldStream::evaluate_row(0.5, &bucket, &counters).unwrap();
1710        assert_eq!(0.0, result);
1711    }
1712
1713    #[test]
1714    fn evaluate_wrong_bucket() {
1715        let bucket = [0.0, 1.0, 2.0, 3.0, 4.0, f64::INFINITY, 5.0];
1716        let counters = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
1717        let result = HistogramFoldStream::evaluate_row(0.5, &bucket, &counters);
1718        assert!(result.is_err());
1719    }
1720
1721    #[test]
1722    fn evaluate_small_fraction() {
1723        let bucket = [0.0, 2.0, 4.0, 6.0, f64::INFINITY];
1724        let counters = [0.0, 1.0 / 300.0, 2.0 / 300.0, 0.01, 0.01];
1725        let result = HistogramFoldStream::evaluate_row(0.5, &bucket, &counters).unwrap();
1726        assert_eq!(3.0, result);
1727    }
1728
1729    #[test]
1730    fn evaluate_non_monotonic_counter() {
1731        let bucket = [0.0, 1.0, 2.0, 3.0, f64::INFINITY];
1732        let counters = [0.1, 0.2, 0.4, 0.17, 0.5];
1733        let result = HistogramFoldStream::evaluate_row(0.5, &bucket, &counters).unwrap();
1734        assert!((result - 1.25).abs() < 1e-10, "{result}");
1735    }
1736
1737    #[test]
1738    fn evaluate_nan_counter() {
1739        let bucket = [0.0, 1.0, 2.0, 3.0, f64::INFINITY];
1740        let counters = [f64::NAN, 1.0, 2.0, 3.0, 3.0];
1741        let result = HistogramFoldStream::evaluate_row(0.5, &bucket, &counters).unwrap();
1742        assert!((result - 1.5).abs() < 1e-10, "{result}");
1743    }
1744
1745    fn build_empty_relation(schema: &Arc<Schema>) -> LogicalPlan {
1746        LogicalPlan::EmptyRelation(EmptyRelation {
1747            produce_one_row: false,
1748            schema: schema.clone().to_dfschema_ref().unwrap(),
1749        })
1750    }
1751
1752    #[tokio::test]
1753    async fn encode_decode_histogram_fold() {
1754        let schema = Arc::new(Schema::new(vec![
1755            Field::new("ts", DataType::Int64, false),
1756            Field::new("le", DataType::Utf8, false),
1757            Field::new("val", DataType::Float64, false),
1758        ]));
1759        let input_plan = build_empty_relation(&schema);
1760        let plan_node = HistogramFold::new(
1761            "le".to_string(),
1762            "val".to_string(),
1763            "ts".to_string(),
1764            0.8,
1765            input_plan.clone(),
1766        )
1767        .unwrap();
1768
1769        let bytes = plan_node.serialize();
1770
1771        let histogram_fold = HistogramFold::deserialize(&bytes).unwrap();
1772        // need fix
1773        let histogram_fold = histogram_fold
1774            .with_exprs_and_inputs(vec![], vec![input_plan])
1775            .unwrap();
1776
1777        assert_eq!(histogram_fold.le_column, "le");
1778        assert_eq!(histogram_fold.ts_column, "ts");
1779        assert_eq!(histogram_fold.field_column, "val");
1780        assert_eq!(histogram_fold.quantile, OrderedF64::from(0.8));
1781        assert_eq!(histogram_fold.output_schema.fields().len(), 2);
1782        assert_eq!(histogram_fold.output_schema.field(0).name(), "ts");
1783        assert_eq!(histogram_fold.output_schema.field(1).name(), "val");
1784    }
1785}