promql/functions/
holt_winters.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Implementation of [`holt_winters`](https://prometheus.io/docs/prometheus/latest/querying/functions/#holt_winters) in PromQL. Refer to the [original
16//! implementation](https://github.com/prometheus/prometheus/blob/8dba9163f1e923ec213f0f4d5c185d9648e387f0/promql/functions.go#L299).
17
18use std::sync::Arc;
19
20use datafusion::arrow::array::Float64Array;
21use datafusion::arrow::datatypes::TimeUnit;
22use datafusion::common::DataFusionError;
23use datafusion::logical_expr::{ScalarUDF, Volatility};
24use datafusion::physical_plan::ColumnarValue;
25use datafusion_common::ScalarValue;
26use datafusion_expr::create_udf;
27use datatypes::arrow::array::Array;
28use datatypes::arrow::datatypes::DataType;
29
30use crate::error;
31use crate::functions::extract_array;
32use crate::range_array::RangeArray;
33
34/// `FactorIterator` iterates over a `ColumnarValue` that can be a scalar or an array.
35struct FactorIterator<'a> {
36    is_scalar: bool,
37    array: Option<&'a Float64Array>,
38    scalar_val: f64,
39    index: usize,
40    len: usize,
41}
42
43impl<'a> FactorIterator<'a> {
44    fn new(value: &'a ColumnarValue, len: usize) -> Self {
45        let (is_scalar, array, scalar_val) = match value {
46            ColumnarValue::Array(arr) => {
47                (false, arr.as_any().downcast_ref::<Float64Array>(), f64::NAN)
48            }
49            ColumnarValue::Scalar(ScalarValue::Float64(Some(val))) => (true, None, *val),
50            _ => (true, None, f64::NAN),
51        };
52
53        Self {
54            is_scalar,
55            array,
56            scalar_val,
57            index: 0,
58            len,
59        }
60    }
61}
62
63impl<'a> Iterator for FactorIterator<'a> {
64    type Item = f64;
65
66    fn next(&mut self) -> Option<Self::Item> {
67        if self.index >= self.len {
68            return None;
69        }
70        self.index += 1;
71
72        if self.is_scalar {
73            return Some(self.scalar_val);
74        }
75
76        if let Some(array) = self.array {
77            if array.is_null(self.index - 1) {
78                Some(f64::NAN)
79            } else {
80                Some(array.value(self.index - 1))
81            }
82        } else {
83            Some(f64::NAN)
84        }
85    }
86}
87
/// PromQL `holt_winters` function, registered as the `prom_holt_winters` scalar UDF.
///
/// There are 3 variants of smoothing functions:
/// 1) "Simple exponential smoothing": only the `level` component (the weighted average of the observations) is used to make forecasts.
///    This method is applied for time-series data that does not exhibit trend or seasonality.
/// 2) "Holt's linear method" (a.k.a. "double exponential smoothing"): `level` and `trend` components are used to make forecasts.
///    This method is applied for time-series data that exhibits trend but not seasonality.
/// 3) "Holt-Winters' method" (a.k.a. "triple exponential smoothing"): `level`, `trend`, and `seasonality` are used to make forecasts.
///    This method is applied for time-series data that exhibits both trend and seasonality.
///
/// To keep parity with the Prometheus functions we follow the same naming ("HoltWinters"), although the
/// implementation only tracks `level` and `trend` (no seasonality), i.e. it is really "Holt's linear method"
/// ("double exponential smoothing").
/// This [discussion](https://github.com/prometheus/prometheus/issues/2458) on the Prometheus GitHub repository,
/// dating back to 2017, highlights the naming/implementation mismatch.
#[derive(Debug)]
pub struct HoltWinters;
102
103impl HoltWinters {
104    pub const fn name() -> &'static str {
105        "prom_holt_winters"
106    }
107
108    // time index column and value column
109    fn input_type() -> Vec<DataType> {
110        vec![
111            RangeArray::convert_data_type(DataType::Timestamp(TimeUnit::Millisecond, None)),
112            RangeArray::convert_data_type(DataType::Float64),
113            // sf
114            DataType::Float64,
115            // tf
116            DataType::Float64,
117        ]
118    }
119
120    fn return_type() -> DataType {
121        DataType::Float64
122    }
123
124    pub fn scalar_udf() -> ScalarUDF {
125        create_udf(
126            Self::name(),
127            Self::input_type(),
128            Self::return_type(),
129            Volatility::Volatile,
130            Arc::new(Self::holt_winters) as _,
131        )
132    }
133
134    fn holt_winters(input: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
135        error::ensure(
136            input.len() == 4,
137            DataFusionError::Plan("prom_holt_winters function should have 4 inputs".to_string()),
138        )?;
139
140        let ts_array = extract_array(&input[0])?;
141        let value_array = extract_array(&input[1])?;
142        let sf_col = &input[2];
143        let tf_col = &input[3];
144
145        let ts_range: RangeArray = RangeArray::try_new(ts_array.to_data().into())?;
146        let value_range: RangeArray = RangeArray::try_new(value_array.to_data().into())?;
147        let num_rows = ts_range.len();
148
149        error::ensure(
150            num_rows == value_range.len(),
151            DataFusionError::Execution(format!(
152                "{}: input arrays should have the same length, found {} and {}",
153                Self::name(),
154                num_rows,
155                value_range.len()
156            )),
157        )?;
158        error::ensure(
159            ts_range.value_type() == DataType::Timestamp(TimeUnit::Millisecond, None),
160            DataFusionError::Execution(format!(
161                "{}: expect TimestampMillisecond as time index array's type, found {}",
162                Self::name(),
163                ts_range.value_type()
164            )),
165        )?;
166        error::ensure(
167            value_range.value_type() == DataType::Float64,
168            DataFusionError::Execution(format!(
169                "{}: expect Float64 as value array's type, found {}",
170                Self::name(),
171                value_range.value_type()
172            )),
173        )?;
174
175        // calculation
176        let mut result_array = Vec::with_capacity(ts_range.len());
177
178        let sf_iter = FactorIterator::new(sf_col, num_rows);
179        let tf_iter = FactorIterator::new(tf_col, num_rows);
180
181        let iter = (0..num_rows)
182            .map(|i| (ts_range.get(i), value_range.get(i)))
183            .zip(sf_iter.zip(tf_iter));
184
185        for ((timestamps, values), (sf, tf)) in iter {
186            let timestamps = timestamps.unwrap();
187            let values = values.unwrap();
188            let values = values
189                .as_any()
190                .downcast_ref::<Float64Array>()
191                .unwrap()
192                .values();
193            error::ensure(
194                timestamps.len() == values.len(),
195                DataFusionError::Execution(format!(
196                    "{}: input arrays should have the same length, found {} and {}",
197                    Self::name(),
198                    timestamps.len(),
199                    values.len()
200                )),
201            )?;
202
203            result_array.push(holt_winter_impl(values, sf, tf));
204        }
205
206        let result = ColumnarValue::Array(Arc::new(Float64Array::from_iter(result_array)));
207        Ok(result)
208    }
209}
210
/// Computes the trend estimate for step `i` of the smoothing loop.
///
/// On the first step (`i == 0`) the initial trend `b` is returned as-is.
/// Afterwards the result blends the latest level change `s1 - s0` (weighted by `tf`)
/// with the previous trend `b` (weighted by `1 - tf`).
fn calc_trend_value(i: usize, tf: f64, s0: f64, s1: f64, b: f64) -> f64 {
    match i {
        0 => b,
        _ => {
            let latest_slope = tf * (s1 - s0);
            let previous_trend = (1.0 - tf) * b;
            latest_slope + previous_trend
        }
    }
}
219
220/// Refer to <https://github.com/prometheus/prometheus/blob/main/promql/functions.go#L299>
221fn holt_winter_impl(values: &[f64], sf: f64, tf: f64) -> Option<f64> {
222    if sf.is_nan() || tf.is_nan() || values.is_empty() {
223        return Some(f64::NAN);
224    }
225    if sf < 0.0 || tf < 0.0 {
226        return Some(f64::NEG_INFINITY);
227    }
228    if sf > 1.0 || tf > 1.0 {
229        return Some(f64::INFINITY);
230    }
231
232    let l = values.len();
233    if l <= 2 {
234        // Can't do the smoothing operation with less than two points.
235        return Some(f64::NAN);
236    }
237
238    let values = values.to_vec();
239
240    let mut s0 = 0.0;
241    let mut s1 = values[0];
242    let mut b = values[1] - values[0];
243
244    for (i, value) in values.iter().enumerate().skip(1) {
245        // Scale the raw value against the smoothing factor.
246        let x = sf * value;
247        // Scale the last smoothed value with the trend at this point.
248        b = calc_trend_value(i - 1, tf, s0, s1, b);
249        let y = (1.0 - sf) * (s1 + b);
250        s0 = s1;
251        s1 = x + y;
252    }
253    Some(s1)
254}
255
256#[cfg(test)]
257mod tests {
258    use datafusion::arrow::array::{Float64Array, TimestampMillisecondArray};
259
260    use super::*;
261    use crate::functions::test_util::simple_range_udf_runner;
262
263    #[test]
264    fn test_holt_winter_impl_empty() {
265        let sf = 0.5;
266        let tf = 0.5;
267        let values = &[];
268        assert!(holt_winter_impl(values, sf, tf).unwrap().is_nan());
269
270        let values = &[1.0, 2.0];
271        assert!(holt_winter_impl(values, sf, tf).unwrap().is_nan());
272    }
273
274    #[test]
275    fn test_holt_winter_impl_nan() {
276        let values = &[1.0, 2.0, 3.0];
277        let sf = f64::NAN;
278        let tf = 0.5;
279        assert!(holt_winter_impl(values, sf, tf).unwrap().is_nan());
280
281        let values = &[1.0, 2.0, 3.0];
282        let sf = 0.5;
283        let tf = f64::NAN;
284        assert!(holt_winter_impl(values, sf, tf).unwrap().is_nan());
285    }
286
287    #[test]
288    fn test_holt_winter_impl_validation_rules() {
289        let values = &[1.0, 2.0, 3.0];
290        let sf = -0.5;
291        let tf = 0.5;
292        assert_eq!(holt_winter_impl(values, sf, tf).unwrap(), f64::NEG_INFINITY);
293
294        let values = &[1.0, 2.0, 3.0];
295        let sf = 0.5;
296        let tf = -0.5;
297        assert_eq!(holt_winter_impl(values, sf, tf).unwrap(), f64::NEG_INFINITY);
298
299        let values = &[1.0, 2.0, 3.0];
300        let sf = 1.5;
301        let tf = 0.5;
302        assert_eq!(holt_winter_impl(values, sf, tf).unwrap(), f64::INFINITY);
303
304        let values = &[1.0, 2.0, 3.0];
305        let sf = 0.5;
306        let tf = 1.5;
307        assert_eq!(holt_winter_impl(values, sf, tf).unwrap(), f64::INFINITY);
308    }
309
310    #[test]
311    fn test_holt_winter_impl() {
312        let sf = 0.5;
313        let tf = 0.1;
314        let values = &[1.0, 2.0, 3.0, 4.0, 5.0];
315        assert_eq!(holt_winter_impl(values, sf, tf), Some(5.0));
316        let values = &[50.0, 52.0, 95.0, 59.0, 52.0, 45.0, 38.0, 10.0, 47.0, 40.0];
317        assert_eq!(holt_winter_impl(values, sf, tf), Some(38.18119566835938));
318    }
319
320    #[test]
321    fn test_prom_holt_winter_monotonic() {
322        let ranges = [(0, 5)];
323        let ts_array = Arc::new(TimestampMillisecondArray::from_iter(
324            [1000i64, 3000, 5000, 7000, 9000, 11000, 13000, 15000, 17000]
325                .into_iter()
326                .map(Some),
327        ));
328        let values_array = Arc::new(Float64Array::from_iter([1.0, 2.0, 3.0, 4.0, 5.0]));
329        let ts_range_array = RangeArray::from_ranges(ts_array, ranges).unwrap();
330        let value_range_array = RangeArray::from_ranges(values_array, ranges).unwrap();
331        simple_range_udf_runner(
332            HoltWinters::scalar_udf(),
333            ts_range_array,
334            value_range_array,
335            vec![
336                ScalarValue::Float64(Some(0.5)),
337                ScalarValue::Float64(Some(0.1)),
338            ],
339            vec![Some(5.0)],
340        );
341    }
342
343    #[test]
344    fn test_prom_holt_winter_non_monotonic() {
345        let ranges = [(0, 10)];
346        let ts_array = Arc::new(TimestampMillisecondArray::from_iter(
347            [
348                1000i64, 3000, 5000, 7000, 9000, 11000, 13000, 15000, 17000, 19000,
349            ]
350            .into_iter()
351            .map(Some),
352        ));
353        let values_array = Arc::new(Float64Array::from_iter([
354            50.0, 52.0, 95.0, 59.0, 52.0, 45.0, 38.0, 10.0, 47.0, 40.0,
355        ]));
356        let ts_range_array = RangeArray::from_ranges(ts_array, ranges).unwrap();
357        let value_range_array = RangeArray::from_ranges(values_array, ranges).unwrap();
358        simple_range_udf_runner(
359            HoltWinters::scalar_udf(),
360            ts_range_array,
361            value_range_array,
362            vec![
363                ScalarValue::Float64(Some(0.5)),
364                ScalarValue::Float64(Some(0.1)),
365            ],
366            vec![Some(38.18119566835938)],
367        );
368    }
369
370    #[test]
371    fn test_promql_trends() {
372        let ranges = vec![(0, 801)];
373
374        let trends = vec![
375            // positive trends https://github.com/prometheus/prometheus/blob/8dba9163f1e923ec213f0f4d5c185d9648e387f0/promql/testdata/functions.test#L475
376            ("0+10x1000 100+30x1000", 8000.0),
377            ("0+20x1000 200+30x1000", 16000.0),
378            ("0+30x1000 300+80x1000", 24000.0),
379            ("0+40x2000", 32000.0),
380            // negative trends https://github.com/prometheus/prometheus/blob/8dba9163f1e923ec213f0f4d5c185d9648e387f0/promql/testdata/functions.test#L488
381            ("8000-10x1000", 0.0),
382            ("0-20x1000", -16000.0),
383            ("0+30x1000 300-80x1000", 24000.0),
384            ("0-40x1000 0+40x1000", -32000.0),
385        ];
386
387        for (query, expected) in trends {
388            let (ts_range_array, value_range_array) =
389                create_ts_and_value_range_arrays(query, ranges.clone());
390            simple_range_udf_runner(
391                HoltWinters::scalar_udf(),
392                ts_range_array,
393                value_range_array,
394                vec![
395                    ScalarValue::Float64(Some(0.01)),
396                    ScalarValue::Float64(Some(0.1)),
397                ],
398                vec![Some(expected)],
399            );
400        }
401    }
402
403    fn create_ts_and_value_range_arrays(
404        input: &str,
405        ranges: Vec<(u32, u32)>,
406    ) -> (RangeArray, RangeArray) {
407        let promql_range = create_test_range_from_promql_series(input);
408        let ts_array = Arc::new(TimestampMillisecondArray::from_iter(
409            (0..(promql_range.len() as i64)).map(Some),
410        ));
411        let values_array = Arc::new(Float64Array::from_iter(promql_range));
412        let ts_range_array = RangeArray::from_ranges(ts_array, ranges.clone()).unwrap();
413        let value_range_array = RangeArray::from_ranges(values_array, ranges).unwrap();
414        (ts_range_array, value_range_array)
415    }
416
417    /// Converts a prometheus functions test series into a vector of f64 element with respect to resets and trend direction   
418    /// The input example: "0+10x1000 100+30x1000"
419    fn create_test_range_from_promql_series(input: &str) -> Vec<f64> {
420        input.split(' ').map(parse_promql_series_entry).fold(
421            Vec::new(),
422            |mut acc, (start, end, step, operation)| {
423                if operation.eq("+") {
424                    let iter = (start..=((step * end) + start))
425                        .step_by(step as usize)
426                        .map(|x| x as f64);
427                    acc.extend(iter);
428                } else {
429                    let iter = (((-step * end) + start)..=start)
430                        .rev()
431                        .step_by(step as usize)
432                        .map(|x| x as f64);
433                    acc.extend(iter);
434                };
435                acc
436            },
437        )
438    }
439
440    /// Converts a prometheus functions test series entry into separate parts to create a range with a step
441    /// The input example: "100+30x1000"
442    fn parse_promql_series_entry(input: &str) -> (i32, i32, i32, &str) {
443        let mut parts = input.split('x');
444        let start_operation_step = parts.next().unwrap();
445        let operation = start_operation_step
446            .split(char::is_numeric)
447            .find(|&x| !x.is_empty())
448            .unwrap();
449        let start_step = start_operation_step
450            .split(operation)
451            .map(|s| s.parse::<i32>().unwrap())
452            .collect::<Vec<_>>();
453        let start = *start_step.first().unwrap();
454        let step = *start_step.last().unwrap();
455        let end = parts.next().unwrap().parse::<i32>().unwrap();
456        (start, end, step, operation)
457    }
458}