promql/functions/
quantile.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use datafusion::arrow::array::Float64Array;
18use datafusion::arrow::datatypes::TimeUnit;
19use datafusion::common::DataFusionError;
20use datafusion::logical_expr::{ScalarUDF, Volatility};
21use datafusion::physical_plan::ColumnarValue;
22use datafusion_common::ScalarValue;
23use datafusion_expr::create_udf;
24use datatypes::arrow::array::Array;
25use datatypes::arrow::datatypes::DataType;
26
27use crate::error;
28use crate::functions::extract_array;
29use crate::range_array::RangeArray;
30
/// Evaluator behind the `prom_quantile_over_time` UDF.
///
/// Holds the quantile requested by the query and computes it over every time range
/// of a value column.
#[derive(Debug, Clone)]
pub struct QuantileOverTime {
    /// The quantile to compute. Values below 0 or above 1 (and NaN) are not rejected here;
    /// `quantile_impl` maps them to `-inf`, `+inf` and `NaN` respectively.
    quantile: f64,
}
34
35impl QuantileOverTime {
36    fn new(quantile: f64) -> Self {
37        Self { quantile }
38    }
39
40    pub const fn name() -> &'static str {
41        "prom_quantile_over_time"
42    }
43
44    pub fn scalar_udf() -> ScalarUDF {
45        let input_types = vec![
46            // time index column
47            RangeArray::convert_data_type(DataType::Timestamp(TimeUnit::Millisecond, None)),
48            // value column
49            RangeArray::convert_data_type(DataType::Float64),
50            // quantile
51            DataType::Float64,
52        ];
53        create_udf(
54            Self::name(),
55            input_types,
56            DataType::Float64,
57            Volatility::Volatile,
58            Arc::new(move |input: &_| Self::create_function(input)?.quantile_over_time(input)) as _,
59        )
60    }
61
62    fn create_function(inputs: &[ColumnarValue]) -> Result<Self, DataFusionError> {
63        if inputs.len() != 3 {
64            return Err(DataFusionError::Plan(
65                "QuantileOverTime function should have 3 inputs".to_string(),
66            ));
67        }
68        let ColumnarValue::Scalar(ScalarValue::Float64(Some(quantile))) = inputs[2] else {
69            return Err(DataFusionError::Plan(
70                "QuantileOverTime function's third input should be a scalar float64".to_string(),
71            ));
72        };
73        Ok(Self::new(quantile))
74    }
75
76    fn quantile_over_time(
77        &self,
78        input: &[ColumnarValue],
79    ) -> Result<ColumnarValue, DataFusionError> {
80        // construct matrix from input.
81        assert_eq!(input.len(), 2);
82        let ts_array = extract_array(&input[0])?;
83        let value_array = extract_array(&input[1])?;
84
85        let ts_range: RangeArray = RangeArray::try_new(ts_array.to_data().into())?;
86        let value_range: RangeArray = RangeArray::try_new(value_array.to_data().into())?;
87        error::ensure(
88            ts_range.len() == value_range.len(),
89            DataFusionError::Execution(format!(
90                "{}: input arrays should have the same length, found {} and {}",
91                Self::name(),
92                ts_range.len(),
93                value_range.len()
94            )),
95        )?;
96        error::ensure(
97            ts_range.value_type() == DataType::Timestamp(TimeUnit::Millisecond, None),
98            DataFusionError::Execution(format!(
99                "{}: expect TimestampMillisecond as time index array's type, found {}",
100                Self::name(),
101                ts_range.value_type()
102            )),
103        )?;
104        error::ensure(
105            value_range.value_type() == DataType::Float64,
106            DataFusionError::Execution(format!(
107                "{}: expect Float64 as value array's type, found {}",
108                Self::name(),
109                value_range.value_type()
110            )),
111        )?;
112
113        // calculation
114        let mut result_array = Vec::with_capacity(ts_range.len());
115
116        for index in 0..ts_range.len() {
117            let timestamps = ts_range.get(index).unwrap();
118            let values = value_range.get(index).unwrap();
119            let values = values
120                .as_any()
121                .downcast_ref::<Float64Array>()
122                .unwrap()
123                .values();
124            error::ensure(
125                timestamps.len() == values.len(),
126                DataFusionError::Execution(format!(
127                    "{}: input arrays should have the same length, found {} and {}",
128                    Self::name(),
129                    timestamps.len(),
130                    values.len()
131                )),
132            )?;
133
134            let retule = quantile_impl(values, self.quantile);
135
136            result_array.push(retule);
137        }
138
139        let result = ColumnarValue::Array(Arc::new(Float64Array::from_iter(result_array)));
140        Ok(result)
141    }
142}
143
/// Computes the `quantile`-th quantile of `values`, interpolating linearly between the two
/// nearest ranks. It mirrors Prometheus' `quantile` helper.
///
/// Special cases:
/// - `values` is empty, or `quantile` is NaN: returns `NaN`;
/// - `quantile < 0`: returns negative infinity;
/// - `quantile > 1`: returns positive infinity.
///
/// `values` is copied and the copy is sorted with [`f64::total_cmp`], so the caller's slice
/// is not reordered. The result is always `Some`.
///
/// Refer to <https://github.com/prometheus/prometheus/blob/6e2905a4d4ff9b47b1f6d201333f5bd53633f921/promql/quantile.go#L357-L386>
pub(crate) fn quantile_impl(values: &[f64], quantile: f64) -> Option<f64> {
    if values.is_empty() || quantile.is_nan() {
        return Some(f64::NAN);
    }
    if quantile < 0.0 {
        return Some(f64::NEG_INFINITY);
    }
    if quantile > 1.0 {
        return Some(f64::INFINITY);
    }

    let mut sorted = values.to_vec();
    sorted.sort_unstable_by(|a, b| a.total_cmp(b));

    // Fractional position of the quantile among the sorted samples.
    let last = sorted.len() - 1;
    let rank = quantile * last as f64;
    let floor = rank.floor();

    // `rank` is non-negative here, so the cast never needs clamping at zero.
    let below = floor as usize;
    let above = (below + 1).min(last);
    let fraction = rank - floor;

    Some(sorted[below] * (1.0 - fraction) + sorted[above] * fraction)
}
169
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantile_impl_empty() {
        let values = &[];
        let q = 0.5;
        assert!(quantile_impl(values, q).unwrap().is_nan());
    }

    #[test]
    fn test_quantile_impl_nan() {
        let values = &[1.0, 2.0, 3.0];
        let q = f64::NAN;
        assert!(quantile_impl(values, q).unwrap().is_nan());
    }

    #[test]
    fn test_quantile_impl_negative_quantile() {
        let values = &[1.0, 2.0, 3.0];
        let q = -0.5;
        assert_eq!(quantile_impl(values, q).unwrap(), f64::NEG_INFINITY);
    }

    #[test]
    fn test_quantile_impl_greater_than_one_quantile() {
        let values = &[1.0, 2.0, 3.0];
        let q = 1.5;
        assert_eq!(quantile_impl(values, q).unwrap(), f64::INFINITY);
    }

    #[test]
    fn test_quantile_impl_single_element() {
        let values = &[1.0];
        let q = 0.8;
        assert_eq!(quantile_impl(values, q).unwrap(), 1.0);
    }

    #[test]
    fn test_quantile_impl_even_length() {
        let values = &[3.0, 1.0, 5.0, 2.0];
        let q = 0.5;
        assert_eq!(quantile_impl(values, q).unwrap(), 2.5);
    }

    #[test]
    fn test_quantile_impl_odd_length() {
        let values = &[4.0, 1.0, 3.0, 2.0, 5.0];
        let q = 0.25;
        assert_eq!(quantile_impl(values, q).unwrap(), 2.0);
    }

    /// `q == 0` must land exactly on the smallest sample (rank 0, zero weight).
    #[test]
    fn test_quantile_impl_zero_quantile_is_minimum() {
        let values = &[4.0, 1.0, 3.0];
        assert_eq!(quantile_impl(values, 0.0).unwrap(), 1.0);
    }

    /// `q == 1` must land exactly on the largest sample. The upper index is clamped to the
    /// last element instead of running past the end.
    #[test]
    fn test_quantile_impl_one_quantile_is_maximum() {
        let values = &[4.0, 1.0, 3.0];
        assert_eq!(quantile_impl(values, 1.0).unwrap(), 4.0);
    }

    /// A non-integer rank interpolates between the two neighbouring samples.
    #[test]
    fn test_quantile_impl_interpolates_between_samples() {
        // Sorted: [0.0, 10.0]; rank = 0.25 * (2 - 1) = 0.25 -> 0.0 * 0.75 + 10.0 * 0.25.
        let values = &[10.0, 0.0];
        assert_eq!(quantile_impl(values, 0.25).unwrap(), 2.5);
    }
}