promql/functions/
holt_winters.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Implementation of [`holt_winters`](https://prometheus.io/docs/prometheus/latest/querying/functions/#holt_winters) in PromQL. Refer to the [original
16//! implementation](https://github.com/prometheus/prometheus/blob/8dba9163f1e923ec213f0f4d5c185d9648e387f0/promql/functions.go#L299).
17
18use std::sync::Arc;
19
20use datafusion::arrow::array::Float64Array;
21use datafusion::arrow::datatypes::TimeUnit;
22use datafusion::common::DataFusionError;
23use datafusion::logical_expr::{ScalarUDF, Volatility};
24use datafusion::physical_plan::ColumnarValue;
25use datafusion_common::ScalarValue;
26use datafusion_expr::create_udf;
27use datatypes::arrow::array::Array;
28use datatypes::arrow::datatypes::DataType;
29
30use crate::error;
31use crate::functions::extract_array;
32use crate::range_array::RangeArray;
33
/// PromQL `holt_winters` range function.
///
/// There are 3 variants of smoothing functions:
/// 1) "Simple exponential smoothing": only the `level` component (the weighted average of the observations) is used to make forecasts.
///    This method is applied for time-series data that does not exhibit trend or seasonality.
/// 2) "Holt's linear method" (a.k.a. "double exponential smoothing"): `level` and `trend` components are used to make forecasts.
///    This method is applied for time-series data that exhibits trend but not seasonality.
/// 3) "Holt-Winter's method" (a.k.a. "triple exponential smoothing"): `level`, `trend`, and `seasonality` are used to make forecasts.
///    This method is applied for time-series data that exhibits both trend and seasonality.
///
/// In order to keep the parity with the Prometheus functions we had to follow the same naming ("HoltWinters"), however
/// "Holt's linear method" ("double exponential smoothing") is the better fitting name and reflects the implementation.
/// There's a [discussion](https://github.com/prometheus/prometheus/issues/2458) in the Prometheus GitHub that dates back
/// to 2017 highlighting the naming/implementation mismatch.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct HoltWinters {
    /// Smoothing factor applied to the raw samples (third UDF argument).
    /// Values outside `[0, 1]` make the function yield `-inf` / `+inf`.
    sf: f64,
    /// Trend factor applied to the change of the smoothed series (fourth UDF argument).
    /// Values outside `[0, 1]` make the function yield `-inf` / `+inf`.
    tf: f64,
}
51
52impl HoltWinters {
53    fn new(sf: f64, tf: f64) -> Self {
54        Self { sf, tf }
55    }
56
57    pub const fn name() -> &'static str {
58        "prom_holt_winters"
59    }
60
61    // time index column and value column
62    fn input_type() -> Vec<DataType> {
63        vec![
64            RangeArray::convert_data_type(DataType::Timestamp(TimeUnit::Millisecond, None)),
65            RangeArray::convert_data_type(DataType::Float64),
66            // sf
67            DataType::Float64,
68            // tf
69            DataType::Float64,
70        ]
71    }
72
73    fn return_type() -> DataType {
74        DataType::Float64
75    }
76
77    pub fn scalar_udf() -> ScalarUDF {
78        create_udf(
79            Self::name(),
80            Self::input_type(),
81            Self::return_type(),
82            Volatility::Volatile,
83            Arc::new(move |input: &_| Self::create_function(input)?.calc(input)) as _,
84        )
85    }
86
87    fn create_function(inputs: &[ColumnarValue]) -> Result<Self, DataFusionError> {
88        if inputs.len() != 4 {
89            return Err(DataFusionError::Plan(
90                "HoltWinters function should have 4 inputs".to_string(),
91            ));
92        }
93        let ColumnarValue::Scalar(ScalarValue::Float64(Some(sf))) = inputs[2] else {
94            return Err(DataFusionError::Plan(
95                "HoltWinters function's third input should be a scalar float64".to_string(),
96            ));
97        };
98        let ColumnarValue::Scalar(ScalarValue::Float64(Some(tf))) = inputs[3] else {
99            return Err(DataFusionError::Plan(
100                "HoltWinters function's fourth input should be a scalar float64".to_string(),
101            ));
102        };
103        Ok(Self::new(sf, tf))
104    }
105
106    fn calc(&self, input: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
107        // construct matrix from input.
108        // The third one is level param, the fourth - trend param which are included in fields.
109        assert_eq!(input.len(), 4);
110
111        let ts_array = extract_array(&input[0])?;
112        let value_array = extract_array(&input[1])?;
113
114        let ts_range: RangeArray = RangeArray::try_new(ts_array.to_data().into())?;
115        let value_range: RangeArray = RangeArray::try_new(value_array.to_data().into())?;
116
117        error::ensure(
118            ts_range.len() == value_range.len(),
119            DataFusionError::Execution(format!(
120                "{}: input arrays should have the same length, found {} and {}",
121                Self::name(),
122                ts_range.len(),
123                value_range.len()
124            )),
125        )?;
126        error::ensure(
127            ts_range.value_type() == DataType::Timestamp(TimeUnit::Millisecond, None),
128            DataFusionError::Execution(format!(
129                "{}: expect TimestampMillisecond as time index array's type, found {}",
130                Self::name(),
131                ts_range.value_type()
132            )),
133        )?;
134        error::ensure(
135            value_range.value_type() == DataType::Float64,
136            DataFusionError::Execution(format!(
137                "{}: expect Float64 as value array's type, found {}",
138                Self::name(),
139                value_range.value_type()
140            )),
141        )?;
142
143        // calculation
144        let mut result_array = Vec::with_capacity(ts_range.len());
145        for index in 0..ts_range.len() {
146            let timestamps = ts_range.get(index).unwrap();
147            let values = value_range.get(index).unwrap();
148            let values = values
149                .as_any()
150                .downcast_ref::<Float64Array>()
151                .unwrap()
152                .values();
153            error::ensure(
154                timestamps.len() == values.len(),
155                DataFusionError::Execution(format!(
156                    "{}: input arrays should have the same length, found {} and {}",
157                    Self::name(),
158                    timestamps.len(),
159                    values.len()
160                )),
161            )?;
162            result_array.push(holt_winter_impl(values, self.sf, self.tf));
163        }
164
165        let result = ColumnarValue::Array(Arc::new(Float64Array::from_iter(result_array)));
166        Ok(result)
167    }
168}
169
/// Updates the trend estimate `b` for smoothing step `i`.
///
/// Step `0` returns the initial trend `b` unchanged. Every later step blends the latest change
/// of the smoothed series (`s1 - s0`) with the previous trend, weighted by the trend factor `tf`.
fn calc_trend_value(i: usize, tf: f64, s0: f64, s1: f64, b: f64) -> f64 {
    if i == 0 {
        return b;
    }
    let x = tf * (s1 - s0);
    let y = (1.0 - tf) * b;
    x + y
}

/// Smooths `values` with the smoothing factor `sf` and the trend factor `tf` and returns the
/// last smoothed value.
///
/// Special cases, checked in this order:
/// - `NaN` if either factor is `NaN` or `values` is empty,
/// - `-inf` if either factor is negative,
/// - `+inf` if either factor is greater than `1.0`,
/// - `NaN` if `values` holds fewer than three samples.
///
/// The result is always `Some`.
///
/// Refer to <https://github.com/prometheus/prometheus/blob/8dba9163f1e923ec213f0f4d5c185d9648e387f0/promql/functions.go#L299>
fn holt_winter_impl(values: &[f64], sf: f64, tf: f64) -> Option<f64> {
    if sf.is_nan() || tf.is_nan() || values.is_empty() {
        return Some(f64::NAN);
    }
    if sf < 0.0 || tf < 0.0 {
        return Some(f64::NEG_INFINITY);
    }
    if sf > 1.0 || tf > 1.0 {
        return Some(f64::INFINITY);
    }

    // The smoothing operation is only performed on three or more points.
    let &[first, second, _, ..] = values else {
        return Some(f64::NAN);
    };

    // Work directly on the borrowed slice; no copy of the window is needed.
    let mut s0 = 0.0;
    let mut s1 = first;
    let mut b = second - first;

    // `step` counts the smoothing steps from 0; the first sample only seeds `s1`.
    for (step, &value) in values[1..].iter().enumerate() {
        // Scale the raw value against the smoothing factor.
        let x = sf * value;
        // Scale the last smoothed value with the trend at this point.
        b = calc_trend_value(step, tf, s0, s1, b);
        let y = (1.0 - sf) * (s1 + b);
        s0 = s1;
        s1 = x + y;
    }
    Some(s1)
}
214
#[cfg(test)]
mod tests {
    use datafusion::arrow::array::{Float64Array, TimestampMillisecondArray};

    use super::*;
    use crate::functions::test_util::simple_range_udf_runner;

    /// Series without a clear direction, shared by the direct and the UDF-level checks.
    const NON_MONOTONIC: [f64; 10] = [50.0, 52.0, 95.0, 59.0, 52.0, 45.0, 38.0, 10.0, 47.0, 40.0];

    /// Hands the range arrays plus the factors `sf` / `tf` to `simple_range_udf_runner`,
    /// with `expected` as the only expected output value.
    fn check_udf(ts: RangeArray, values: RangeArray, sf: f64, tf: f64, expected: f64) {
        simple_range_udf_runner(
            HoltWinters::scalar_udf(),
            ts,
            values,
            vec![
                ScalarValue::Float64(Some(sf)),
                ScalarValue::Float64(Some(tf)),
            ],
            vec![Some(expected)],
        );
    }

    /// Builds `count` millisecond timestamps 1000, 3000, 5000, ...
    fn two_second_spaced_timestamps(count: i64) -> Arc<TimestampMillisecondArray> {
        Arc::new(TimestampMillisecondArray::from_iter(
            (0..count).map(|i| Some(1000 + 2000 * i)),
        ))
    }

    #[test]
    fn test_holt_winter_impl_empty() {
        let (sf, tf) = (0.5, 0.5);
        let no_samples: &[f64] = &[];
        let two_samples: &[f64] = &[1.0, 2.0];
        for samples in [no_samples, two_samples] {
            assert!(holt_winter_impl(samples, sf, tf).unwrap().is_nan());
        }
    }

    #[test]
    fn test_holt_winter_impl_nan() {
        let samples = [1.0, 2.0, 3.0];
        for (sf, tf) in [(f64::NAN, 0.5), (0.5, f64::NAN)] {
            assert!(holt_winter_impl(&samples, sf, tf).unwrap().is_nan());
        }
    }

    #[test]
    fn test_holt_winter_impl_validation_rules() {
        let samples = [1.0, 2.0, 3.0];
        // (sf, tf, expected result)
        let cases = [
            (-0.5, 0.5, f64::NEG_INFINITY),
            (0.5, -0.5, f64::NEG_INFINITY),
            (1.5, 0.5, f64::INFINITY),
            (0.5, 1.5, f64::INFINITY),
        ];
        for (sf, tf, expected) in cases {
            assert_eq!(holt_winter_impl(&samples, sf, tf).unwrap(), expected);
        }
    }

    #[test]
    fn test_holt_winter_impl() {
        let (sf, tf) = (0.5, 0.1);
        assert_eq!(
            holt_winter_impl(&[1.0, 2.0, 3.0, 4.0, 5.0], sf, tf),
            Some(5.0)
        );
        assert_eq!(
            holt_winter_impl(&NON_MONOTONIC, sf, tf),
            Some(38.18119566835938)
        );
    }

    #[test]
    fn test_prom_holt_winter_monotonic() {
        let window = [(0, 5)];
        let timestamps = two_second_spaced_timestamps(9);
        let samples = Arc::new(Float64Array::from_iter([1.0, 2.0, 3.0, 4.0, 5.0]));
        check_udf(
            RangeArray::from_ranges(timestamps, window).unwrap(),
            RangeArray::from_ranges(samples, window).unwrap(),
            0.5,
            0.1,
            5.0,
        );
    }

    #[test]
    fn test_prom_holt_winter_non_monotonic() {
        let window = [(0, 10)];
        let timestamps = two_second_spaced_timestamps(10);
        let samples = Arc::new(Float64Array::from_iter(NON_MONOTONIC));
        check_udf(
            RangeArray::from_ranges(timestamps, window).unwrap(),
            RangeArray::from_ranges(samples, window).unwrap(),
            0.5,
            0.1,
            38.18119566835938,
        );
    }

    #[test]
    fn test_promql_trends() {
        let window = vec![(0, 801)];

        let cases = [
            // positive trends https://github.com/prometheus/prometheus/blob/8dba9163f1e923ec213f0f4d5c185d9648e387f0/promql/testdata/functions.test#L475
            ("0+10x1000 100+30x1000", 8000.0),
            ("0+20x1000 200+30x1000", 16000.0),
            ("0+30x1000 300+80x1000", 24000.0),
            ("0+40x2000", 32000.0),
            // negative trends https://github.com/prometheus/prometheus/blob/8dba9163f1e923ec213f0f4d5c185d9648e387f0/promql/testdata/functions.test#L488
            ("8000-10x1000", 0.0),
            ("0-20x1000", -16000.0),
            ("0+30x1000 300-80x1000", 24000.0),
            ("0-40x1000 0+40x1000", -32000.0),
        ];

        for (series, expected) in cases {
            let (ts, values) = create_ts_and_value_range_arrays(series, window.clone());
            check_udf(ts, values, 0.01, 0.1, expected);
        }
    }

    /// Expands a prometheus test series into a time index range array (timestamps `0, 1, 2, ...`)
    /// and a value range array, both sliced by `ranges`.
    fn create_ts_and_value_range_arrays(
        input: &str,
        ranges: Vec<(u32, u32)>,
    ) -> (RangeArray, RangeArray) {
        let samples = create_test_range_from_promql_series(input);
        let sample_count = samples.len() as i64;
        let ts_array = Arc::new(TimestampMillisecondArray::from_iter(
            (0..sample_count).map(Some),
        ));
        let values_array = Arc::new(Float64Array::from_iter(samples));
        let ts_ranges = RangeArray::from_ranges(ts_array, ranges.clone()).unwrap();
        let value_ranges = RangeArray::from_ranges(values_array, ranges).unwrap();
        (ts_ranges, value_ranges)
    }

    /// Converts a prometheus functions test series into a vector of f64 element with respect to resets and trend direction
    /// The input example: "0+10x1000 100+30x1000"
    fn create_test_range_from_promql_series(input: &str) -> Vec<f64> {
        let mut samples = Vec::new();
        for entry in input.split(' ') {
            let (start, end, step, operation) = parse_promql_series_entry(entry);
            let stride = step as usize;
            if operation == "+" {
                // Ascending: start, start + step, ..., start + step * end.
                let upper = (step * end) + start;
                samples.extend((start..=upper).step_by(stride).map(|x| x as f64));
            } else {
                // Descending: start, start - step, ..., start - step * end.
                let lower = (-step * end) + start;
                samples.extend((lower..=start).rev().step_by(stride).map(|x| x as f64));
            }
        }
        samples
    }

    /// Converts a prometheus functions test series entry into separate parts to create a range with a step
    /// The input example: "100+30x1000"
    ///
    /// Returns `(start, end, step, operation)`, where `end` is the number after the `x`.
    fn parse_promql_series_entry(input: &str) -> (i32, i32, i32, &str) {
        let mut sections = input.split('x');
        let head = sections.next().unwrap();
        // The operator is the first non-empty piece left once the digits are split away.
        let operation = head
            .split(char::is_numeric)
            .find(|piece| !piece.is_empty())
            .unwrap();
        let mut numbers = head.split(operation).map(|n| n.parse::<i32>().unwrap());
        let start = numbers.next().unwrap();
        let step = numbers.last().unwrap_or(start);
        let end = sections.next().unwrap().parse::<i32>().unwrap();
        (start, end, step, operation)
    }
}