promql/functions/
predict_linear.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Implementation of [`predict_linear`](https://prometheus.io/docs/prometheus/latest/querying/functions/#predict_linear) in PromQL. Refer to the [original
16//! implementation](https://github.com/prometheus/prometheus/blob/90b2f7a540b8a70d8d81372e6692dcbb67ccbaaa/promql/functions.go#L859-L872).
17
18use std::sync::Arc;
19
20use datafusion::arrow::array::{Float64Array, TimestampMillisecondArray};
21use datafusion::arrow::datatypes::TimeUnit;
22use datafusion::common::DataFusionError;
23use datafusion::logical_expr::{ScalarUDF, Volatility};
24use datafusion::physical_plan::ColumnarValue;
25use datafusion_common::ScalarValue;
26use datafusion_expr::create_udf;
27use datatypes::arrow::array::Array;
28use datatypes::arrow::datatypes::DataType;
29
30use crate::error;
31use crate::functions::{extract_array, linear_regression};
32use crate::range_array::RangeArray;
33
/// Evaluator for PromQL's `predict_linear(v range-vector, t scalar)`.
///
/// Holds the scalar argument `t`; the range-vector argument (time index and value
/// ranges) is supplied on every evaluation.
#[derive(Debug)]
pub struct PredictLinear {
    /// Duration `t`, the second param of `predict_linear(v range-vector, t scalar)`:
    /// how far past the evaluation timestamp the fitted line is extrapolated
    /// (seconds, as in PromQL).
    t: i64,
}
38
39impl PredictLinear {
40    fn new(t: i64) -> Self {
41        Self { t }
42    }
43
44    pub const fn name() -> &'static str {
45        "prom_predict_linear"
46    }
47
48    pub fn scalar_udf() -> ScalarUDF {
49        let input_types = vec![
50            // time index column
51            RangeArray::convert_data_type(DataType::Timestamp(TimeUnit::Millisecond, None)),
52            // value column
53            RangeArray::convert_data_type(DataType::Float64),
54            // t
55            DataType::Int64,
56        ];
57        create_udf(
58            Self::name(),
59            input_types,
60            DataType::Float64,
61            Volatility::Volatile,
62            Arc::new(move |input: &_| Self::create_function(input)?.predict_linear(input)) as _,
63        )
64    }
65
66    fn create_function(inputs: &[ColumnarValue]) -> Result<Self, DataFusionError> {
67        if inputs.len() != 3 {
68            return Err(DataFusionError::Plan(
69                "PredictLinear function should have 3 inputs".to_string(),
70            ));
71        }
72        let ColumnarValue::Scalar(ScalarValue::Int64(Some(t))) = inputs[2] else {
73            return Err(DataFusionError::Plan(
74                "PredictLinear function's third input should be a scalar int64".to_string(),
75            ));
76        };
77        Ok(Self::new(t))
78    }
79
80    fn predict_linear(&self, input: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
81        // construct matrix from input.
82        assert_eq!(input.len(), 3);
83        let ts_array = extract_array(&input[0])?;
84        let value_array = extract_array(&input[1])?;
85
86        let ts_range: RangeArray = RangeArray::try_new(ts_array.to_data().into())?;
87        let value_range: RangeArray = RangeArray::try_new(value_array.to_data().into())?;
88        error::ensure(
89            ts_range.len() == value_range.len(),
90            DataFusionError::Execution(format!(
91                "{}: input arrays should have the same length, found {} and {}",
92                Self::name(),
93                ts_range.len(),
94                value_range.len()
95            )),
96        )?;
97        error::ensure(
98            ts_range.value_type() == DataType::Timestamp(TimeUnit::Millisecond, None),
99            DataFusionError::Execution(format!(
100                "{}: expect TimestampMillisecond as time index array's type, found {}",
101                Self::name(),
102                ts_range.value_type()
103            )),
104        )?;
105        error::ensure(
106            value_range.value_type() == DataType::Float64,
107            DataFusionError::Execution(format!(
108                "{}: expect Float64 as value array's type, found {}",
109                Self::name(),
110                value_range.value_type()
111            )),
112        )?;
113
114        // calculation
115        let mut result_array = Vec::with_capacity(ts_range.len());
116
117        for index in 0..ts_range.len() {
118            let timestamps = ts_range
119                .get(index)
120                .unwrap()
121                .as_any()
122                .downcast_ref::<TimestampMillisecondArray>()
123                .unwrap()
124                .clone();
125            let values = value_range
126                .get(index)
127                .unwrap()
128                .as_any()
129                .downcast_ref::<Float64Array>()
130                .unwrap()
131                .clone();
132            error::ensure(
133                timestamps.len() == values.len(),
134                DataFusionError::Execution(format!(
135                    "{}: input arrays should have the same length, found {} and {}",
136                    Self::name(),
137                    timestamps.len(),
138                    values.len()
139                )),
140            )?;
141
142            let ret = predict_linear_impl(&timestamps, &values, self.t);
143
144            result_array.push(ret);
145        }
146
147        let result = ColumnarValue::Array(Arc::new(Float64Array::from_iter(result_array)));
148        Ok(result)
149    }
150}
151
152fn predict_linear_impl(
153    timestamps: &TimestampMillisecondArray,
154    values: &Float64Array,
155    t: i64,
156) -> Option<f64> {
157    if timestamps.len() < 2 {
158        return None;
159    }
160
161    // last timestamp is evaluation timestamp
162    let evaluate_ts = timestamps.value(timestamps.len() - 1);
163    let (slope, intercept) = linear_regression(timestamps, values, evaluate_ts);
164
165    if slope.is_none() || intercept.is_none() {
166        return None;
167    }
168
169    Some(slope.unwrap() * t as f64 + intercept.unwrap())
170}
171
#[cfg(test)]
mod test {
    use super::*;
    use crate::functions::test_util::simple_range_udf_runner;

    /// Shared fixture: a single range of eleven samples spaced 300 ms apart
    /// (0 ms ..= 3000 ms). Values climb by 10 per sample, drop back to 0 at the sixth
    /// sample and climb again up to 50.
    fn build_test_range_arrays() -> (RangeArray, RangeArray) {
        let millis: Vec<i64> = (0..=10).map(|step| step * 300).collect();
        let ts_array = Arc::new(TimestampMillisecondArray::from_iter(
            millis.into_iter().map(Some),
        ));
        let values_array = Arc::new(Float64Array::from_iter([
            0.0, 10.0, 20.0, 30.0, 40.0, 0.0, 10.0, 20.0, 30.0, 40.0, 50.0,
        ]));
        let ranges = [(0, 11)];

        (
            RangeArray::from_ranges(ts_array, ranges).unwrap(),
            RangeArray::from_ranges(values_array, ranges).unwrap(),
        )
    }

    /// Runs the UDF over the shared fixture with duration `t` and expects the single
    /// output row to equal `expected`.
    fn assert_prediction(t: i64, expected: f64) {
        let (ts_array, value_array) = build_test_range_arrays();
        simple_range_udf_runner(
            PredictLinear::scalar_udf(),
            ts_array,
            value_array,
            vec![ScalarValue::Int64(Some(t))],
            vec![Some(expected)],
        );
    }

    #[test]
    fn calculate_predict_linear_none() {
        // An empty range and a single-sample range: both are too short to fit a line.
        let ranges = [(0, 0), (0, 1)];
        let ts_array = RangeArray::from_ranges(
            Arc::new(TimestampMillisecondArray::from_iter([Some(0i64)])),
            ranges,
        )
        .unwrap();
        let value_array =
            RangeArray::from_ranges(Arc::new(Float64Array::from_iter([0.0])), ranges).unwrap();

        simple_range_udf_runner(
            PredictLinear::scalar_udf(),
            ts_array,
            value_array,
            vec![ScalarValue::Int64(Some(0))],
            vec![None, None],
        );
    }

    #[test]
    fn calculate_predict_linear_test1() {
        // t = 0: the fitted line evaluated at the last sample (3000 ms).
        assert_prediction(0, 38.63636363636364);
    }

    #[test]
    fn calculate_predict_linear_test2() {
        // Extrapolated 3000 s past the last sample.
        assert_prediction(3000, 31856.818181818187);
    }

    #[test]
    fn calculate_predict_linear_test3() {
        // Extrapolated 4200 s past the last sample.
        assert_prediction(4200, 44584.09090909091);
    }

    #[test]
    fn calculate_predict_linear_test4() {
        // Extrapolated 6600 s past the last sample.
        assert_prediction(6600, 70038.63636363638);
    }

    #[test]
    fn calculate_predict_linear_test5() {
        // Extrapolated 7800 s past the last sample.
        assert_prediction(7800, 82765.9090909091);
    }
}