promql/functions/
predict_linear.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Implementation of [`predict_linear`](https://prometheus.io/docs/prometheus/latest/querying/functions/#predict_linear) in PromQL. Refer to the [original
16//! implementation](https://github.com/prometheus/prometheus/blob/90b2f7a540b8a70d8d81372e6692dcbb67ccbaaa/promql/functions.go#L859-L872).
17
18use std::sync::Arc;
19
20use datafusion::arrow::array::{Float64Array, TimestampMillisecondArray};
21use datafusion::arrow::datatypes::TimeUnit;
22use datafusion::common::DataFusionError;
23use datafusion::logical_expr::{ScalarUDF, Volatility};
24use datafusion::physical_plan::ColumnarValue;
25use datafusion_common::ScalarValue;
26use datafusion_expr::create_udf;
27use datatypes::arrow::array::Array;
28use datatypes::arrow::datatypes::DataType;
29
30use crate::error;
31use crate::functions::{extract_array, linear_regression};
32use crate::range_array::RangeArray;
33
34pub struct PredictLinear;
35
36impl PredictLinear {
37    pub const fn name() -> &'static str {
38        "prom_predict_linear"
39    }
40
41    pub fn scalar_udf() -> ScalarUDF {
42        let input_types = vec![
43            // time index column
44            RangeArray::convert_data_type(DataType::Timestamp(TimeUnit::Millisecond, None)),
45            // value column
46            RangeArray::convert_data_type(DataType::Float64),
47            // t
48            DataType::Int64,
49        ];
50        create_udf(
51            Self::name(),
52            input_types,
53            DataType::Float64,
54            Volatility::Volatile,
55            Arc::new(Self::predict_linear) as _,
56        )
57    }
58
59    fn predict_linear(input: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
60        error::ensure(
61            input.len() == 3,
62            DataFusionError::Plan("prom_predict_linear function should have 3 inputs".to_string()),
63        )?;
64
65        let ts_array = extract_array(&input[0])?;
66        let value_array = extract_array(&input[1])?;
67        let t_col = &input[2];
68
69        let ts_range: RangeArray = RangeArray::try_new(ts_array.to_data().into())?;
70        let value_range: RangeArray = RangeArray::try_new(value_array.to_data().into())?;
71        error::ensure(
72            ts_range.len() == value_range.len(),
73            DataFusionError::Execution(format!(
74                "{}: input arrays should have the same length, found {} and {}",
75                Self::name(),
76                ts_range.len(),
77                value_range.len()
78            )),
79        )?;
80        error::ensure(
81            ts_range.value_type() == DataType::Timestamp(TimeUnit::Millisecond, None),
82            DataFusionError::Execution(format!(
83                "{}: expect TimestampMillisecond as time index array's type, found {}",
84                Self::name(),
85                ts_range.value_type()
86            )),
87        )?;
88        error::ensure(
89            value_range.value_type() == DataType::Float64,
90            DataFusionError::Execution(format!(
91                "{}: expect Float64 as value array's type, found {}",
92                Self::name(),
93                value_range.value_type()
94            )),
95        )?;
96
97        let t_iter: Box<dyn Iterator<Item = Option<i64>>> = match t_col {
98            ColumnarValue::Scalar(t_scalar) => {
99                let t = if let ScalarValue::Int64(Some(t_val)) = t_scalar {
100                    *t_val
101                } else {
102                    // For `ScalarValue::Int64(None)` or other scalar types, returns NULL array,
103                    // which conforms to PromQL's behavior.
104                    let null_array = Float64Array::new_null(ts_range.len());
105                    return Ok(ColumnarValue::Array(Arc::new(null_array)));
106                };
107                Box::new((0..ts_range.len()).map(move |_| Some(t)))
108            }
109            ColumnarValue::Array(t_array) => {
110                let t_array = t_array
111                    .as_any()
112                    .downcast_ref::<datafusion::arrow::array::Int64Array>()
113                    .ok_or_else(|| {
114                        DataFusionError::Execution(format!(
115                            "{}: expect Int64 as t array's type, found {}",
116                            Self::name(),
117                            t_array.data_type()
118                        ))
119                    })?;
120                error::ensure(
121                    t_array.len() == ts_range.len(),
122                    DataFusionError::Execution(format!(
123                        "{}: t array should have the same length as other columns, found {} and {}",
124                        Self::name(),
125                        t_array.len(),
126                        ts_range.len()
127                    )),
128                )?;
129
130                Box::new(t_array.iter())
131            }
132        };
133        let mut result_array = Vec::with_capacity(ts_range.len());
134        for (index, t) in t_iter.enumerate() {
135            let (timestamps, values) = get_ts_values(&ts_range, &value_range, index, Self::name())?;
136            let ret = predict_linear_impl(&timestamps, &values, t.unwrap());
137            result_array.push(ret);
138        }
139
140        let result = ColumnarValue::Array(Arc::new(Float64Array::from_iter(result_array)));
141        Ok(result)
142    }
143}
144
145fn get_ts_values(
146    ts_range: &RangeArray,
147    value_range: &RangeArray,
148    index: usize,
149    func_name: &str,
150) -> Result<(TimestampMillisecondArray, Float64Array), DataFusionError> {
151    let timestamps = ts_range
152        .get(index)
153        .unwrap()
154        .as_any()
155        .downcast_ref::<TimestampMillisecondArray>()
156        .unwrap()
157        .clone();
158    let values = value_range
159        .get(index)
160        .unwrap()
161        .as_any()
162        .downcast_ref::<Float64Array>()
163        .unwrap()
164        .clone();
165    error::ensure(
166        timestamps.len() == values.len(),
167        DataFusionError::Execution(format!(
168            "{}: time and value arrays in a group should have the same length, found {} and {}",
169            func_name,
170            timestamps.len(),
171            values.len()
172        )),
173    )?;
174    Ok((timestamps, values))
175}
176
177fn predict_linear_impl(
178    timestamps: &TimestampMillisecondArray,
179    values: &Float64Array,
180    t: i64,
181) -> Option<f64> {
182    if timestamps.len() < 2 {
183        return None;
184    }
185
186    // last timestamp is evaluation timestamp
187    let evaluate_ts = timestamps.value(timestamps.len() - 1);
188    let (slope, intercept) = linear_regression(timestamps, values, evaluate_ts);
189
190    if slope.is_none() || intercept.is_none() {
191        return None;
192    }
193
194    Some(slope.unwrap() * t as f64 + intercept.unwrap())
195}
196
#[cfg(test)]
mod test {
    use std::vec;

    use super::*;
    use crate::functions::test_util::simple_range_udf_runner;

    /// Timestamps (ms) of the shared fixture: eleven samples, 300 ms apart.
    const FIXTURE_TIMESTAMPS: [i64; 11] = [
        0, 300, 600, 900, 1200, 1500, 1800, 2100, 2400, 2700, 3000,
    ];

    /// Values of the shared fixture, one per entry of `FIXTURE_TIMESTAMPS`.
    const FIXTURE_VALUES: [f64; 11] = [
        0.0, 10.0, 20.0, 30.0, 40.0, 0.0, 10.0, 20.0, 30.0, 40.0, 50.0,
    ];

    /// Builds the timestamp and value range arrays shared by the tests.
    ///
    /// Both arrays expose the single range `(0, 11)` over the fixture data.
    fn build_test_range_arrays() -> (RangeArray, RangeArray) {
        let ranges = [(0, 11)];

        let ts_array = Arc::new(TimestampMillisecondArray::from_iter(
            FIXTURE_TIMESTAMPS.into_iter().map(Some),
        ));
        let values_array = Arc::new(Float64Array::from_iter(FIXTURE_VALUES));

        let ts_range_array = RangeArray::from_ranges(ts_array, ranges).unwrap();
        let value_range_array = RangeArray::from_ranges(values_array, ranges).unwrap();

        (ts_range_array, value_range_array)
    }

    /// Runs the UDF over the shared fixture with scalar `t` and checks that the
    /// single output row equals `expected`.
    fn assert_prediction(t: i64, expected: f64) {
        let (ts_array, value_array) = build_test_range_arrays();
        simple_range_udf_runner(
            PredictLinear::scalar_udf(),
            ts_array,
            value_array,
            vec![ScalarValue::Int64(Some(t))],
            vec![Some(expected)],
        );
    }

    /// Ranges `(0, 0)` and `(0, 1)` over a one-sample array must both yield NULL.
    #[test]
    fn calculate_predict_linear_none() {
        let ranges = [(0, 0), (0, 1)];
        let raw_ts = Arc::new(TimestampMillisecondArray::from_iter(
            [0i64].into_iter().map(Some),
        ));
        let raw_values = Arc::new(Float64Array::from_iter([0.0]));
        let ts_array = RangeArray::from_ranges(raw_ts, ranges).unwrap();
        let value_array = RangeArray::from_ranges(raw_values, ranges).unwrap();
        simple_range_udf_runner(
            PredictLinear::scalar_udf(),
            ts_array,
            value_array,
            vec![ScalarValue::Int64(Some(0))],
            vec![None, None],
        );
    }

    #[test]
    fn calculate_predict_linear_test1() {
        // value at t = 0
        assert_prediction(0, 38.63636363636364);
    }

    #[test]
    fn calculate_predict_linear_test2() {
        // value at t = 3000
        assert_prediction(3000, 31856.818181818187);
    }

    #[test]
    fn calculate_predict_linear_test3() {
        // value at t = 4200
        assert_prediction(4200, 44584.09090909091);
    }

    #[test]
    fn calculate_predict_linear_test4() {
        // value at t = 6600
        assert_prediction(6600, 70038.63636363638);
    }

    #[test]
    fn calculate_predict_linear_test5() {
        // value at t = 7800
        assert_prediction(7800, 82765.9090909091);
    }
}