promql/functions/
quantile_aggr.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use datafusion::arrow::array::{ArrayRef, AsArray};
18use datafusion::common::cast::{as_list_array, as_primitive_array, as_struct_array};
19use datafusion::error::{DataFusionError, Result as DfResult};
20use datafusion::logical_expr::{Accumulator as DfAccumulator, AggregateUDF, Volatility};
21use datafusion::physical_plan::expressions::Literal;
22use datafusion::prelude::create_udaf;
23use datafusion_common::ScalarValue;
24use datafusion_expr::function::AccumulatorArgs;
25use datatypes::arrow::array::{ListArray, StructArray};
26use datatypes::arrow::datatypes::{DataType, Field, Float64Type};
27
28use crate::functions::quantile::quantile_impl;
29
/// Name given to the aggregate UDF built by `quantile_udaf`.
pub const QUANTILE_NAME: &'static str = "quantile";

/// Name of the struct field holding the buffered samples in the intermediate state.
const VALUES_FIELD_NAME: &'static str = "values";
/// Name of the element field of the `values` list in the intermediate state.
const DEFAULT_LIST_FIELD_NAME: &'static str = "item";
34
/// Accumulator for the PromQL `quantile` aggregation.
///
/// Every input sample is buffered; the φ-quantile is only computed when the
/// accumulator is evaluated.
#[derive(Default, Debug)]
pub struct QuantileAccumulator {
    /// The φ argument of `quantile(φ, …)`.
    q: f64,
    /// Samples collected so far; `None` entries come from null input values.
    values: Vec<Option<f64>>,
}
40
41/// Create a quantile `AggregateUDF` for PromQL quantile operator,
42/// which calculates φ-quantile (0 ≤ φ ≤ 1) over dimensions
43pub fn quantile_udaf() -> Arc<AggregateUDF> {
44    Arc::new(create_udaf(
45        QUANTILE_NAME,
46        // Input type: (φ, values)
47        vec![DataType::Float64, DataType::Float64],
48        // Output type: the φ-quantile
49        Arc::new(DataType::Float64),
50        Volatility::Volatile,
51        // Create the accumulator
52        Arc::new(QuantileAccumulator::from_args),
53        // Intermediate state types
54        Arc::new(vec![DataType::Struct(
55            vec![Field::new(
56                VALUES_FIELD_NAME,
57                DataType::List(Arc::new(Field::new(
58                    DEFAULT_LIST_FIELD_NAME,
59                    DataType::Float64,
60                    true,
61                ))),
62                false,
63            )]
64            .into(),
65        )]),
66    ))
67}
68
69impl QuantileAccumulator {
70    fn new(q: f64) -> Self {
71        Self {
72            q,
73            ..Default::default()
74        }
75    }
76
77    pub fn from_args(args: AccumulatorArgs) -> DfResult<Box<dyn DfAccumulator>> {
78        if args.exprs.len() != 2 {
79            return Err(DataFusionError::Plan(
80                "Quantile function should have 2 inputs".to_string(),
81            ));
82        }
83
84        let q = match &args.exprs[0]
85            .as_any()
86            .downcast_ref::<Literal>()
87            .map(|lit| lit.value())
88        {
89            Some(ScalarValue::Float64(Some(q))) => *q,
90            _ => {
91                return Err(DataFusionError::Internal(
92                    "Invalid quantile value".to_string(),
93                ))
94            }
95        };
96
97        Ok(Box::new(Self::new(q)))
98    }
99}
100
101impl DfAccumulator for QuantileAccumulator {
102    fn update_batch(&mut self, values: &[ArrayRef]) -> DfResult<()> {
103        let f64_array = values[1].as_primitive::<Float64Type>();
104
105        self.values.extend(f64_array);
106
107        Ok(())
108    }
109
110    fn evaluate(&mut self) -> DfResult<ScalarValue> {
111        let values: Vec<_> = self.values.iter().map(|v| v.unwrap_or(0.0)).collect();
112
113        let result = quantile_impl(&values, self.q);
114
115        ScalarValue::new_primitive::<Float64Type>(result, &DataType::Float64)
116    }
117
118    fn size(&self) -> usize {
119        std::mem::size_of::<Self>() + self.values.capacity() * std::mem::size_of::<Option<f64>>()
120    }
121
122    fn state(&mut self) -> DfResult<Vec<ScalarValue>> {
123        let values_array = Arc::new(ListArray::from_iter_primitive::<Float64Type, _, _>(vec![
124            Some(self.values.clone()),
125        ]));
126
127        let state_struct = StructArray::new(
128            vec![Field::new(
129                VALUES_FIELD_NAME,
130                DataType::List(Arc::new(Field::new(
131                    DEFAULT_LIST_FIELD_NAME,
132                    DataType::Float64,
133                    true,
134                ))),
135                false,
136            )]
137            .into(),
138            vec![values_array],
139            None,
140        );
141
142        Ok(vec![ScalarValue::Struct(Arc::new(state_struct))])
143    }
144
145    fn merge_batch(&mut self, states: &[ArrayRef]) -> DfResult<()> {
146        if states.is_empty() {
147            return Ok(());
148        }
149
150        for state in states {
151            let state = as_struct_array(state)?;
152
153            for list in as_list_array(state.column(0))?.iter().flatten() {
154                let f64_array = as_primitive_array::<Float64Type>(&list)?.clone();
155                self.values.extend(&f64_array);
156            }
157        }
158
159        Ok(())
160    }
161}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datafusion::arrow::array::{ArrayRef, Float64Array};
    use datafusion_common::ScalarValue;

    use super::*;

    fn create_f64_array(values: Vec<Option<f64>>) -> ArrayRef {
        Arc::new(Float64Array::from(values)) as ArrayRef
    }

    #[test]
    fn test_quantile_accumulator_empty() {
        let mut accumulator = QuantileAccumulator::new(0.5);

        let result = accumulator.evaluate().unwrap();

        match result {
            ScalarValue::Float64(_) => (),
            _ => panic!("Expected Float64 scalar value"),
        }
    }

    #[test]
    fn test_quantile_accumulator_single_value() {
        let mut accumulator = QuantileAccumulator::new(0.5);
        let q = create_f64_array(vec![Some(0.5)]);
        let input = create_f64_array(vec![Some(10.0)]);

        accumulator.update_batch(&[q, input]).unwrap();
        let result = accumulator.evaluate().unwrap();

        assert_eq!(result, ScalarValue::Float64(Some(10.0)));
    }

    #[test]
    fn test_quantile_accumulator_multiple_values() {
        let mut accumulator = QuantileAccumulator::new(0.5);
        let q = create_f64_array(vec![Some(0.5)]);
        let input = create_f64_array(vec![Some(1.0), Some(2.0), Some(3.0), Some(4.0), Some(5.0)]);

        accumulator.update_batch(&[q, input]).unwrap();
        let result = accumulator.evaluate().unwrap();

        assert_eq!(result, ScalarValue::Float64(Some(3.0)));
    }

    #[test]
    fn test_quantile_accumulator_with_nulls() {
        let mut accumulator = QuantileAccumulator::new(0.5);
        let q = create_f64_array(vec![Some(0.5)]);
        let input = create_f64_array(vec![Some(1.0), None, Some(3.0), Some(4.0), Some(5.0)]);

        accumulator.update_batch(&[q, input]).unwrap();

        let result = accumulator.evaluate().unwrap();
        assert_eq!(result, ScalarValue::Float64(Some(3.0)));
    }

    #[test]
    fn test_quantile_accumulator_multiple_batches() {
        let mut accumulator = QuantileAccumulator::new(0.5);
        let q = create_f64_array(vec![Some(0.5)]);
        let input1 = create_f64_array(vec![Some(1.0), Some(2.0)]);
        let input2 = create_f64_array(vec![Some(3.0), Some(4.0), Some(5.0)]);

        accumulator.update_batch(&[q.clone(), input1]).unwrap();
        accumulator.update_batch(&[q, input2]).unwrap();

        let result = accumulator.evaluate().unwrap();
        assert_eq!(result, ScalarValue::Float64(Some(3.0)));
    }

    #[test]
    fn test_quantile_accumulator_different_quantiles() {
        let mut min_accumulator = QuantileAccumulator::new(0.0);
        let q = create_f64_array(vec![Some(0.0)]);
        let input = create_f64_array(vec![Some(1.0), Some(2.0), Some(3.0), Some(4.0), Some(5.0)]);
        min_accumulator.update_batch(&[q, input.clone()]).unwrap();
        assert_eq!(
            min_accumulator.evaluate().unwrap(),
            ScalarValue::Float64(Some(1.0))
        );

        let mut q1_accumulator = QuantileAccumulator::new(0.25);
        let q = create_f64_array(vec![Some(0.25)]);
        q1_accumulator.update_batch(&[q, input.clone()]).unwrap();
        assert_eq!(
            q1_accumulator.evaluate().unwrap(),
            ScalarValue::Float64(Some(2.0))
        );

        let mut q3_accumulator = QuantileAccumulator::new(0.75);
        let q = create_f64_array(vec![Some(0.75)]);
        q3_accumulator.update_batch(&[q, input.clone()]).unwrap();
        assert_eq!(
            q3_accumulator.evaluate().unwrap(),
            ScalarValue::Float64(Some(4.0))
        );

        let mut max_accumulator = QuantileAccumulator::new(1.0);
        let q = create_f64_array(vec![Some(1.0)]);
        max_accumulator.update_batch(&[q, input]).unwrap();
        assert_eq!(
            max_accumulator.evaluate().unwrap(),
            ScalarValue::Float64(Some(5.0))
        );
    }

    #[test]
    fn test_quantile_accumulator_size() {
        let mut accumulator = QuantileAccumulator::new(0.5);
        let q = create_f64_array(vec![Some(0.5)]);
        let input = create_f64_array(vec![Some(1.0), Some(2.0), Some(3.0)]);

        let initial_size = accumulator.size();
        accumulator.update_batch(&[q, input]).unwrap();
        let after_update_size = accumulator.size();

        // The buffer starts with zero capacity and must grow to hold 3 samples.
        assert!(after_update_size > initial_size);
    }

    #[test]
    fn test_quantile_accumulator_state_and_merge() -> DfResult<()> {
        let mut acc1 = QuantileAccumulator::new(0.5);
        let q = create_f64_array(vec![Some(0.5)]);
        let input1 = create_f64_array(vec![Some(1.0), Some(2.0)]);
        acc1.update_batch(&[q, input1])?;

        let state1 = acc1.state()?;

        let mut acc2 = QuantileAccumulator::new(0.5);
        let q = create_f64_array(vec![Some(0.5)]);
        let input2 = create_f64_array(vec![Some(3.0), Some(4.0), Some(5.0)]);
        acc2.update_batch(&[q, input2])?;

        let mut struct_builders = vec![];
        for scalar in &state1 {
            if let ScalarValue::Struct(struct_array) = scalar {
                struct_builders.push(struct_array.clone() as ArrayRef);
            }
        }

        acc2.merge_batch(&struct_builders)?;

        let result = acc2.evaluate()?;

        assert_eq!(result, ScalarValue::Float64(Some(3.0)));

        Ok(())
    }

    #[test]
    fn test_quantile_accumulator_with_extreme_values() {
        let mut accumulator = QuantileAccumulator::new(0.5);
        let q = create_f64_array(vec![Some(0.5)]);
        let input = create_f64_array(vec![Some(f64::MAX), Some(f64::MIN), Some(0.0)]);

        accumulator.update_batch(&[q, input]).unwrap();
        let result = accumulator.evaluate().unwrap();

        // Sorted: [MIN, 0.0, MAX]; the median lands exactly on the middle element,
        // so extreme neighbours must not leak into the result.
        assert_eq!(result, ScalarValue::Float64(Some(0.0)));
    }

    #[test]
    fn test_quantile_udaf_creation() {
        let udaf = quantile_udaf();

        assert_eq!(udaf.name(), QUANTILE_NAME);
        assert_eq!(udaf.return_type(&[]).unwrap(), DataType::Float64);
    }
}