Skip to main content

promql/functions/
quantile.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use datafusion::arrow::array::{Float64Array, Float64Builder};
18use datafusion::arrow::datatypes::TimeUnit;
19use datafusion::common::DataFusionError;
20use datafusion::logical_expr::{ScalarUDF, Volatility};
21use datafusion::physical_plan::ColumnarValue;
22use datafusion_common::ScalarValue;
23use datafusion_expr::create_udf;
24use datatypes::arrow::array::Array;
25use datatypes::arrow::datatypes::DataType;
26
27use crate::error;
28use crate::functions::extract_array;
29use crate::range_array::RangeArray;
30
31pub struct QuantileOverTime;
32
33impl QuantileOverTime {
34    pub const fn name() -> &'static str {
35        "prom_quantile_over_time"
36    }
37
38    pub fn scalar_udf() -> ScalarUDF {
39        let input_types = vec![
40            // time index column
41            RangeArray::convert_data_type(DataType::Timestamp(TimeUnit::Millisecond, None)),
42            // value column
43            RangeArray::convert_data_type(DataType::Float64),
44            // quantile
45            DataType::Float64,
46        ];
47        create_udf(
48            Self::name(),
49            input_types,
50            DataType::Float64,
51            Volatility::Volatile,
52            Arc::new(Self::quantile_over_time) as _,
53        )
54    }
55
56    fn quantile_over_time(input: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
57        error::ensure(
58            input.len() == 3,
59            DataFusionError::Plan(
60                "prom_quantile_over_time function should have 3 inputs".to_string(),
61            ),
62        )?;
63
64        let ts_array = extract_array(&input[0])?;
65        let value_array = extract_array(&input[1])?;
66        let quantile_col = &input[2];
67
68        let ts_range: RangeArray = RangeArray::try_new(ts_array.to_data().into())?;
69        let value_range: RangeArray = RangeArray::try_new(value_array.to_data().into())?;
70        error::ensure(
71            ts_range.len() == value_range.len(),
72            DataFusionError::Execution(format!(
73                "{}: input arrays should have the same length, found {} and {}",
74                Self::name(),
75                ts_range.len(),
76                value_range.len()
77            )),
78        )?;
79        error::ensure(
80            ts_range.value_type() == DataType::Timestamp(TimeUnit::Millisecond, None),
81            DataFusionError::Execution(format!(
82                "{}: expect TimestampMillisecond as time index array's type, found {}",
83                Self::name(),
84                ts_range.value_type()
85            )),
86        )?;
87        error::ensure(
88            value_range.value_type() == DataType::Float64,
89            DataFusionError::Execution(format!(
90                "{}: expect Float64 as value array's type, found {}",
91                Self::name(),
92                value_range.value_type()
93            )),
94        )?;
95
96        let all_values = value_range
97            .values()
98            .as_any()
99            .downcast_ref::<Float64Array>()
100            .unwrap()
101            .values();
102        let mut result_builder = Float64Builder::with_capacity(ts_range.len());
103        let mut scratch = Vec::new();
104
105        match quantile_col {
106            ColumnarValue::Scalar(quantile_scalar) => {
107                let quantile = if let ScalarValue::Float64(Some(q)) = quantile_scalar {
108                    *q
109                } else {
110                    // For `ScalarValue::Float64(None)` or other scalar types, use NAN,
111                    // which conforms to PromQL's behavior.
112                    f64::NAN
113                };
114
115                for index in 0..ts_range.len() {
116                    let (_, ts_len) = ts_range.get_offset_length(index).unwrap();
117                    let (value_offset, value_len) = value_range.get_offset_length(index).unwrap();
118                    error::ensure(
119                        ts_len == value_len,
120                        DataFusionError::Execution(format!(
121                            "{}: time and value arrays in a group should have the same length, found {} and {}",
122                            Self::name(),
123                            ts_len,
124                            value_len
125                        )),
126                    )?;
127
128                    match quantile_with_scratch(
129                        &all_values[value_offset..value_offset + value_len],
130                        quantile,
131                        &mut scratch,
132                    ) {
133                        Some(value) => result_builder.append_value(value),
134                        None => result_builder.append_null(),
135                    }
136                }
137            }
138            ColumnarValue::Array(quantile_array) => {
139                let quantile_array = quantile_array
140                    .as_any()
141                    .downcast_ref::<Float64Array>()
142                    .ok_or_else(|| {
143                        DataFusionError::Execution(format!(
144                            "{}: expect Float64 as quantile array's type, found {}",
145                            Self::name(),
146                            quantile_array.data_type()
147                        ))
148                    })?;
149
150                error::ensure(
151                    quantile_array.len() == ts_range.len(),
152                    DataFusionError::Execution(format!(
153                        "{}: quantile array should have the same length as other columns, found {} and {}",
154                        Self::name(),
155                        quantile_array.len(),
156                        ts_range.len()
157                    )),
158                )?;
159                for index in 0..ts_range.len() {
160                    let (_, ts_len) = ts_range.get_offset_length(index).unwrap();
161                    let (value_offset, value_len) = value_range.get_offset_length(index).unwrap();
162                    error::ensure(
163                        ts_len == value_len,
164                        DataFusionError::Execution(format!(
165                            "{}: time and value arrays in a group should have the same length, found {} and {}",
166                            Self::name(),
167                            ts_len,
168                            value_len
169                        )),
170                    )?;
171                    let quantile = if quantile_array.is_null(index) {
172                        f64::NAN
173                    } else {
174                        quantile_array.value(index)
175                    };
176                    match quantile_with_scratch(
177                        &all_values[value_offset..value_offset + value_len],
178                        quantile,
179                        &mut scratch,
180                    ) {
181                        Some(value) => result_builder.append_value(value),
182                        None => result_builder.append_null(),
183                    }
184                }
185            }
186        }
187
188        let result = ColumnarValue::Array(Arc::new(result_builder.finish()));
189        Ok(result)
190    }
191}
192
193/// Refer to <https://github.com/prometheus/prometheus/blob/6e2905a4d4ff9b47b1f6d201333f5bd53633f921/promql/quantile.go#L357-L386>
194pub(crate) fn quantile_impl(values: &[f64], quantile: f64) -> Option<f64> {
195    let mut scratch = Vec::new();
196    quantile_with_scratch(values, quantile, &mut scratch)
197}
198
/// Same as [quantile_impl] but reuses a caller-provided scratch buffer to avoid
/// per-call allocation.
///
/// Mirrors Prometheus' `quantile`:
/// - returns `NaN` if `values` is empty or `quantile` is `NaN`;
/// - returns `-Inf` if `quantile < 0` and `+Inf` if `quantile > 1`;
/// - otherwise copies `values` into `scratch` (clearing it first), sorts it, and linearly
///   interpolates between the two samples around rank `quantile * (len - 1)`.
///
/// `NaN` samples are ordered before all other values, as Prometheus'
/// `vectorByValueHeap.Less` does. As a result they only influence the lowest quantiles.
///
/// The result is currently always `Some`.
fn quantile_with_scratch(values: &[f64], quantile: f64, scratch: &mut Vec<f64>) -> Option<f64> {
    if quantile.is_nan() || values.is_empty() {
        return Some(f64::NAN);
    }
    if quantile < 0.0 {
        return Some(f64::NEG_INFINITY);
    }
    if quantile > 1.0 {
        return Some(f64::INFINITY);
    }

    scratch.clear();
    scratch.extend_from_slice(values);
    // NaN first: `false < true`, hence the swapped operands. Everything else uses the
    // IEEE 754 total order. Plain `f64::total_cmp` would put positive NaNs last and
    // diverge from Prometheus for quantiles close to 1.
    scratch.sort_unstable_by(|a, b| b.is_nan().cmp(&a.is_nan()).then_with(|| a.total_cmp(b)));

    let last_index = scratch.len() - 1;
    let rank = quantile * last_index as f64;

    // `rank` is within `[0, last_index]` because `quantile` is within `[0, 1]`.
    let lower_index = rank.floor() as usize;
    let upper_index = last_index.min(lower_index + 1);
    let weight = rank - rank.floor();

    Some(scratch[lower_index] * (1.0 - weight) + scratch[upper_index] * weight)
}
226
#[cfg(test)]
mod tests {
    use super::*;

    /// Three-sample input shared by the edge-case tests.
    const ONE_TWO_THREE: [f64; 3] = [1.0, 2.0, 3.0];

    #[test]
    fn test_quantile_impl_empty() {
        let result = quantile_impl(&[], 0.5);
        assert!(matches!(result, Some(v) if v.is_nan()));
    }

    #[test]
    fn test_quantile_impl_nan() {
        let result = quantile_impl(&ONE_TWO_THREE, f64::NAN);
        assert!(matches!(result, Some(v) if v.is_nan()));
    }

    #[test]
    fn test_quantile_impl_negative_quantile() {
        assert_eq!(quantile_impl(&ONE_TWO_THREE, -0.5), Some(f64::NEG_INFINITY));
    }

    #[test]
    fn test_quantile_impl_greater_than_one_quantile() {
        assert_eq!(quantile_impl(&ONE_TWO_THREE, 1.5), Some(f64::INFINITY));
    }

    #[test]
    fn test_quantile_impl_single_element() {
        assert_eq!(quantile_impl(&[1.0], 0.8), Some(1.0));
    }

    #[test]
    fn test_quantile_impl_even_length() {
        // Sorted: [1, 2, 3, 5]; rank 1.5 interpolates halfway between 2 and 3.
        assert_eq!(quantile_impl(&[3.0, 1.0, 5.0, 2.0], 0.5), Some(2.5));
    }

    #[test]
    fn test_quantile_impl_odd_length() {
        // Sorted: [1, 2, 3, 4, 5]; rank 1.0 lands exactly on 2.
        assert_eq!(quantile_impl(&[4.0, 1.0, 3.0, 2.0, 5.0], 0.25), Some(2.0));
    }
}