promql/functions/
quantile.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use datafusion::arrow::array::Float64Array;
18use datafusion::arrow::datatypes::TimeUnit;
19use datafusion::common::DataFusionError;
20use datafusion::logical_expr::{ScalarUDF, Volatility};
21use datafusion::physical_plan::ColumnarValue;
22use datafusion_common::ScalarValue;
23use datafusion_expr::create_udf;
24use datatypes::arrow::array::Array;
25use datatypes::arrow::datatypes::DataType;
26
27use crate::error;
28use crate::functions::extract_array;
29use crate::range_array::RangeArray;
30
31pub struct QuantileOverTime;
32
33impl QuantileOverTime {
34    pub const fn name() -> &'static str {
35        "prom_quantile_over_time"
36    }
37
38    pub fn scalar_udf() -> ScalarUDF {
39        let input_types = vec![
40            // time index column
41            RangeArray::convert_data_type(DataType::Timestamp(TimeUnit::Millisecond, None)),
42            // value column
43            RangeArray::convert_data_type(DataType::Float64),
44            // quantile
45            DataType::Float64,
46        ];
47        create_udf(
48            Self::name(),
49            input_types,
50            DataType::Float64,
51            Volatility::Volatile,
52            Arc::new(Self::quantile_over_time) as _,
53        )
54    }
55
56    fn quantile_over_time(input: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
57        error::ensure(
58            input.len() == 3,
59            DataFusionError::Plan(
60                "prom_quantile_over_time function should have 3 inputs".to_string(),
61            ),
62        )?;
63
64        let ts_array = extract_array(&input[0])?;
65        let value_array = extract_array(&input[1])?;
66        let quantile_col = &input[2];
67
68        let ts_range: RangeArray = RangeArray::try_new(ts_array.to_data().into())?;
69        let value_range: RangeArray = RangeArray::try_new(value_array.to_data().into())?;
70        error::ensure(
71            ts_range.len() == value_range.len(),
72            DataFusionError::Execution(format!(
73                "{}: input arrays should have the same length, found {} and {}",
74                Self::name(),
75                ts_range.len(),
76                value_range.len()
77            )),
78        )?;
79        error::ensure(
80            ts_range.value_type() == DataType::Timestamp(TimeUnit::Millisecond, None),
81            DataFusionError::Execution(format!(
82                "{}: expect TimestampMillisecond as time index array's type, found {}",
83                Self::name(),
84                ts_range.value_type()
85            )),
86        )?;
87        error::ensure(
88            value_range.value_type() == DataType::Float64,
89            DataFusionError::Execution(format!(
90                "{}: expect Float64 as value array's type, found {}",
91                Self::name(),
92                value_range.value_type()
93            )),
94        )?;
95
96        // calculation
97        let mut result_array = Vec::with_capacity(ts_range.len());
98
99        match quantile_col {
100            ColumnarValue::Scalar(quantile_scalar) => {
101                let quantile = if let ScalarValue::Float64(Some(q)) = quantile_scalar {
102                    *q
103                } else {
104                    // For `ScalarValue::Float64(None)` or other scalar types, use NAN,
105                    // which conforms to PromQL's behavior.
106                    f64::NAN
107                };
108
109                for index in 0..ts_range.len() {
110                    let timestamps = ts_range.get(index).unwrap();
111                    let values = value_range.get(index).unwrap();
112                    let values = values
113                        .as_any()
114                        .downcast_ref::<Float64Array>()
115                        .unwrap()
116                        .values();
117                    error::ensure(
118                        timestamps.len() == values.len(),
119                        DataFusionError::Execution(format!(
120                            "{}: time and value arrays in a group should have the same length, found {} and {}",
121                            Self::name(),
122                            timestamps.len(),
123                            values.len()
124                        )),
125                    )?;
126
127                    let result = quantile_impl(values, quantile);
128                    result_array.push(result);
129                }
130            }
131            ColumnarValue::Array(quantile_array) => {
132                let quantile_array = quantile_array
133                    .as_any()
134                    .downcast_ref::<Float64Array>()
135                    .ok_or_else(|| {
136                        DataFusionError::Execution(format!(
137                            "{}: expect Float64 as quantile array's type, found {}",
138                            Self::name(),
139                            quantile_array.data_type()
140                        ))
141                    })?;
142
143                error::ensure(
144                    quantile_array.len() == ts_range.len(),
145                    DataFusionError::Execution(format!(
146                        "{}: quantile array should have the same length as other columns, found {} and {}",
147                        Self::name(),
148                        quantile_array.len(),
149                        ts_range.len()
150                    )),
151                )?;
152                for index in 0..ts_range.len() {
153                    let timestamps = ts_range.get(index).unwrap();
154                    let values = value_range.get(index).unwrap();
155                    let values = values
156                        .as_any()
157                        .downcast_ref::<Float64Array>()
158                        .unwrap()
159                        .values();
160                    error::ensure(
161                        timestamps.len() == values.len(),
162                        DataFusionError::Execution(format!(
163                            "{}: time and value arrays in a group should have the same length, found {} and {}",
164                            Self::name(),
165                            timestamps.len(),
166                            values.len()
167                        )),
168                    )?;
169                    let quantile = if quantile_array.is_null(index) {
170                        f64::NAN
171                    } else {
172                        quantile_array.value(index)
173                    };
174                    let result = quantile_impl(values, quantile);
175                    result_array.push(result);
176                }
177            }
178        }
179
180        let result = ColumnarValue::Array(Arc::new(Float64Array::from_iter(result_array)));
181        Ok(result)
182    }
183}
184
/// Computes the `quantile`-th quantile of `values`, mirroring Prometheus' `quantile`
/// helper.
///
/// Reference:
/// <https://github.com/prometheus/prometheus/blob/6e2905a4d4ff9b47b1f6d201333f5bd53633f921/promql/quantile.go#L357-L386>
///
/// - Returns NaN if `values` is empty or `quantile` is NaN.
/// - Returns -Inf if `quantile < 0` and +Inf if `quantile > 1`.
/// - Otherwise sorts a copy of `values` ascending and linearly interpolates between
///   the two samples around rank `quantile * (len - 1)`.
///   - NaN samples are ordered before all other values, as in Prometheus.
///   - This holds regardless of the NaN's sign bit.
///
/// The result is always `Some`.
pub(crate) fn quantile_impl(values: &[f64], quantile: f64) -> Option<f64> {
    use std::cmp::Ordering;

    if quantile.is_nan() || values.is_empty() {
        return Some(f64::NAN);
    }
    if quantile < 0.0 {
        return Some(f64::NEG_INFINITY);
    }
    if quantile > 1.0 {
        return Some(f64::INFINITY);
    }

    let mut values = values.to_vec();
    // Prometheus sorts NaN before every other value. `f64::total_cmp` would put a
    // positive NaN last and a negative NaN first, so order NaN explicitly.
    values.sort_unstable_by(|a, b| match (a.is_nan(), b.is_nan()) {
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Less,
        (false, true) => Ordering::Greater,
        (false, false) => a
            .partial_cmp(b)
            .expect("non-NaN floats are always comparable"),
    });

    let last = values.len() - 1;
    let rank = quantile * last as f64;
    // `quantile` is within [0, 1] here, so `rank` is within [0, last].
    let lower_index = rank.floor() as usize;
    let upper_index = last.min(lower_index + 1);
    let weight = rank - rank.floor();

    Some(values[lower_index] * (1.0 - weight) + values[upper_index] * weight)
}
210
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `quantile_impl` and unwraps the result, which is expected to be `Some`.
    fn eval(values: &[f64], q: f64) -> f64 {
        quantile_impl(values, q).unwrap()
    }

    #[test]
    fn test_quantile_impl_empty() {
        assert!(eval(&[], 0.5).is_nan());
    }

    #[test]
    fn test_quantile_impl_nan() {
        assert!(eval(&[1.0, 2.0, 3.0], f64::NAN).is_nan());
    }

    #[test]
    fn test_quantile_impl_negative_quantile() {
        assert_eq!(eval(&[1.0, 2.0, 3.0], -0.5), f64::NEG_INFINITY);
    }

    #[test]
    fn test_quantile_impl_greater_than_one_quantile() {
        assert_eq!(eval(&[1.0, 2.0, 3.0], 1.5), f64::INFINITY);
    }

    #[test]
    fn test_quantile_impl_single_element() {
        // With one sample the rank is always 0, whatever the quantile.
        assert_eq!(eval(&[1.0], 0.8), 1.0);
    }

    #[test]
    fn test_quantile_impl_even_length() {
        // Sorted: [1, 2, 3, 5]; rank 1.5 lands halfway between 2 and 3.
        assert_eq!(eval(&[3.0, 1.0, 5.0, 2.0], 0.5), 2.5);
    }

    #[test]
    fn test_quantile_impl_odd_length() {
        // Sorted: [1, 2, 3, 4, 5]; rank 1.0 lands exactly on 2.
        assert_eq!(eval(&[4.0, 1.0, 3.0, 2.0, 5.0], 0.25), 2.0);
    }
}