promql/functions/
double_exponential_smoothing.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Implementation of [`double_exponential_smoothing`](https://prometheus.io/docs/prometheus/latest/querying/functions/#double_exponential_smoothing) in PromQL. Refer to the [original
16//! implementation](https://github.com/prometheus/prometheus/blob/8dba9163f1e923ec213f0f4d5c185d9648e387f0/promql/functions.go#L299).
17
18use std::sync::Arc;
19
20use datafusion::arrow::array::Float64Array;
21use datafusion::arrow::datatypes::TimeUnit;
22use datafusion::common::DataFusionError;
23use datafusion::logical_expr::{ScalarUDF, Volatility};
24use datafusion::physical_plan::ColumnarValue;
25use datafusion_common::ScalarValue;
26use datafusion_expr::create_udf;
27use datatypes::arrow::array::Array;
28use datatypes::arrow::datatypes::DataType;
29
30use crate::error;
31use crate::functions::extract_array;
32use crate::range_array::RangeArray;
33
34/// `FactorIterator` iterates over a `ColumnarValue` that can be a scalar or an array.
35struct FactorIterator<'a> {
36    is_scalar: bool,
37    array: Option<&'a Float64Array>,
38    scalar_val: f64,
39    index: usize,
40    len: usize,
41}
42
43impl<'a> FactorIterator<'a> {
44    fn new(value: &'a ColumnarValue, len: usize) -> Self {
45        let (is_scalar, array, scalar_val) = match value {
46            ColumnarValue::Array(arr) => {
47                (false, arr.as_any().downcast_ref::<Float64Array>(), f64::NAN)
48            }
49            ColumnarValue::Scalar(ScalarValue::Float64(Some(val))) => (true, None, *val),
50            _ => (true, None, f64::NAN),
51        };
52
53        Self {
54            is_scalar,
55            array,
56            scalar_val,
57            index: 0,
58            len,
59        }
60    }
61}
62
63impl<'a> Iterator for FactorIterator<'a> {
64    type Item = f64;
65
66    fn next(&mut self) -> Option<Self::Item> {
67        if self.index >= self.len {
68            return None;
69        }
70        self.index += 1;
71
72        if self.is_scalar {
73            return Some(self.scalar_val);
74        }
75
76        if let Some(array) = self.array {
77            if array.is_null(self.index - 1) {
78                Some(f64::NAN)
79            } else {
80                Some(array.value(self.index - 1))
81            }
82        } else {
83            Some(f64::NAN)
84        }
85    }
86}
87
/// The PromQL `double_exponential_smoothing` function, registered as the scalar UDF
/// `prom_double_exponential_smoothing`.
///
/// There are 3 variants of exponential smoothing:
/// 1) "Simple exponential smoothing": only the `level` component (the weighted average of the
///    observations) is used to make forecasts. This method is applied for time-series data that
///    does not exhibit trend or seasonality.
/// 2) "Holt's linear method" (a.k.a. "double exponential smoothing"): `level` and `trend`
///    components are used to make forecasts. This method is applied for time-series data that
///    exhibits trend but not seasonality.
/// 3) "Holt-Winters' method" (a.k.a. "triple exponential smoothing"): `level`, `trend`, and
///    `seasonality` are used to make forecasts. This method is applied for time-series data that
///    exhibits both trend and seasonality.
///
/// This function implements variant 2. It takes a smoothing factor (`sf`) and a trend factor
/// (`tf`).
///
/// Prometheus used to expose this algorithm as `holt_winters`, even though it
/// implements Holt's linear method ("double exponential smoothing") rather than
/// Holt-Winters triple exponential smoothing. Prometheus 3.x renamed it to
/// `double_exponential_smoothing`.
/// See [discussion](https://github.com/prometheus/prometheus/issues/2458).
pub struct DoubleExponentialSmoothing;
103
104impl DoubleExponentialSmoothing {
105    pub const fn name() -> &'static str {
106        "prom_double_exponential_smoothing"
107    }
108
109    // time index column and value column
110    fn input_type() -> Vec<DataType> {
111        vec![
112            RangeArray::convert_data_type(DataType::Timestamp(TimeUnit::Millisecond, None)),
113            RangeArray::convert_data_type(DataType::Float64),
114            // sf
115            DataType::Float64,
116            // tf
117            DataType::Float64,
118        ]
119    }
120
121    fn return_type() -> DataType {
122        DataType::Float64
123    }
124
125    pub fn scalar_udf() -> ScalarUDF {
126        create_udf(
127            Self::name(),
128            Self::input_type(),
129            Self::return_type(),
130            Volatility::Volatile,
131            Arc::new(Self::double_exponential_smoothing) as _,
132        )
133    }
134
135    fn double_exponential_smoothing(
136        input: &[ColumnarValue],
137    ) -> Result<ColumnarValue, DataFusionError> {
138        error::ensure(
139            input.len() == 4,
140            DataFusionError::Plan(
141                "prom_double_exponential_smoothing function should have 4 inputs".to_string(),
142            ),
143        )?;
144
145        let ts_array = extract_array(&input[0])?;
146        let value_array = extract_array(&input[1])?;
147        let sf_col = &input[2];
148        let tf_col = &input[3];
149
150        let ts_range: RangeArray = RangeArray::try_new(ts_array.to_data().into())?;
151        let value_range: RangeArray = RangeArray::try_new(value_array.to_data().into())?;
152        let num_rows = ts_range.len();
153
154        error::ensure(
155            num_rows == value_range.len(),
156            DataFusionError::Execution(format!(
157                "{}: input arrays should have the same length, found {} and {}",
158                Self::name(),
159                num_rows,
160                value_range.len()
161            )),
162        )?;
163        error::ensure(
164            ts_range.value_type() == DataType::Timestamp(TimeUnit::Millisecond, None),
165            DataFusionError::Execution(format!(
166                "{}: expect TimestampMillisecond as time index array's type, found {}",
167                Self::name(),
168                ts_range.value_type()
169            )),
170        )?;
171        error::ensure(
172            value_range.value_type() == DataType::Float64,
173            DataFusionError::Execution(format!(
174                "{}: expect Float64 as value array's type, found {}",
175                Self::name(),
176                value_range.value_type()
177            )),
178        )?;
179
180        // calculation
181        let mut result_array = Vec::with_capacity(ts_range.len());
182
183        let sf_iter = FactorIterator::new(sf_col, num_rows);
184        let tf_iter = FactorIterator::new(tf_col, num_rows);
185
186        let iter = (0..num_rows)
187            .map(|i| (ts_range.get(i), value_range.get(i)))
188            .zip(sf_iter.zip(tf_iter));
189
190        for ((timestamps, values), (sf, tf)) in iter {
191            let timestamps = timestamps.unwrap();
192            let values = values.unwrap();
193            let values = values
194                .as_any()
195                .downcast_ref::<Float64Array>()
196                .unwrap()
197                .values();
198            error::ensure(
199                timestamps.len() == values.len(),
200                DataFusionError::Execution(format!(
201                    "{}: input arrays should have the same length, found {} and {}",
202                    Self::name(),
203                    timestamps.len(),
204                    values.len()
205                )),
206            )?;
207
208            result_array.push(double_exponential_smoothing_impl(values, sf, tf));
209        }
210
211        let result = ColumnarValue::Array(Arc::new(Float64Array::from_iter(result_array)));
212        Ok(result)
213    }
214}
215
/// Computes the trend component for step `i` of Holt's linear method.
///
/// On the first step (`i == 0`) the initial trend `b` is returned unchanged. Afterwards, the
/// new trend blends two terms:
/// * the latest change in level (`s1 - s0`), weighted by `tf`;
/// * the previous trend `b`, weighted by `1 - tf`.
fn calc_trend_value(i: usize, tf: f64, s0: f64, s1: f64, b: f64) -> f64 {
    match i {
        0 => b,
        _ => {
            let level_delta = tf * (s1 - s0);
            let carried_trend = (1.0 - tf) * b;
            level_delta + carried_trend
        }
    }
}
224
225/// Refer to <https://github.com/prometheus/prometheus/blob/main/promql/functions.go#L299>
226fn double_exponential_smoothing_impl(values: &[f64], sf: f64, tf: f64) -> Option<f64> {
227    if sf.is_nan() || tf.is_nan() || values.is_empty() {
228        return Some(f64::NAN);
229    }
230    if sf < 0.0 || tf < 0.0 {
231        return Some(f64::NEG_INFINITY);
232    }
233    if sf > 1.0 || tf > 1.0 {
234        return Some(f64::INFINITY);
235    }
236
237    let l = values.len();
238    if l <= 2 {
239        // Can't do the smoothing operation with less than two points.
240        return Some(f64::NAN);
241    }
242
243    let values = values.to_vec();
244
245    let mut s0 = 0.0;
246    let mut s1 = values[0];
247    let mut b = values[1] - values[0];
248
249    for (i, value) in values.iter().enumerate().skip(1) {
250        // Scale the raw value against the smoothing factor.
251        let x = sf * value;
252        // Scale the last smoothed value with the trend at this point.
253        b = calc_trend_value(i - 1, tf, s0, s1, b);
254        let y = (1.0 - sf) * (s1 + b);
255        s0 = s1;
256        s1 = x + y;
257    }
258    Some(s1)
259}
260
#[cfg(test)]
mod tests {
    use datafusion::arrow::array::{Float64Array, TimestampMillisecondArray};

    use super::*;
    use crate::functions::test_util::simple_range_udf_runner;

    // Empty input yields NaN, and so does a two-sample input: the implementation needs at
    // least three samples before it produces a smoothed value.
    #[test]
    fn test_double_exponential_smoothing_impl_empty() {
        let sf = 0.5;
        let tf = 0.5;
        let values = &[];
        assert!(
            double_exponential_smoothing_impl(values, sf, tf)
                .unwrap()
                .is_nan()
        );

        let values = &[1.0, 2.0];
        assert!(
            double_exponential_smoothing_impl(values, sf, tf)
                .unwrap()
                .is_nan()
        );
    }

    // A NaN smoothing factor or a NaN trend factor makes the result NaN.
    #[test]
    fn test_double_exponential_smoothing_impl_nan() {
        let values = &[1.0, 2.0, 3.0];
        let sf = f64::NAN;
        let tf = 0.5;
        assert!(
            double_exponential_smoothing_impl(values, sf, tf)
                .unwrap()
                .is_nan()
        );

        let values = &[1.0, 2.0, 3.0];
        let sf = 0.5;
        let tf = f64::NAN;
        assert!(
            double_exponential_smoothing_impl(values, sf, tf)
                .unwrap()
                .is_nan()
        );
    }

    // Out-of-range factors map to sentinel infinities: a negative factor gives -Inf and a
    // factor above 1.0 gives +Inf.
    #[test]
    fn test_double_exponential_smoothing_impl_validation_rules() {
        let values = &[1.0, 2.0, 3.0];
        let sf = -0.5;
        let tf = 0.5;
        assert_eq!(
            double_exponential_smoothing_impl(values, sf, tf).unwrap(),
            f64::NEG_INFINITY
        );

        let values = &[1.0, 2.0, 3.0];
        let sf = 0.5;
        let tf = -0.5;
        assert_eq!(
            double_exponential_smoothing_impl(values, sf, tf).unwrap(),
            f64::NEG_INFINITY
        );

        let values = &[1.0, 2.0, 3.0];
        let sf = 1.5;
        let tf = 0.5;
        assert_eq!(
            double_exponential_smoothing_impl(values, sf, tf).unwrap(),
            f64::INFINITY
        );

        let values = &[1.0, 2.0, 3.0];
        let sf = 0.5;
        let tf = 1.5;
        assert_eq!(
            double_exponential_smoothing_impl(values, sf, tf).unwrap(),
            f64::INFINITY
        );
    }

    // Known-answer checks with sf = 0.5, tf = 0.1: a perfectly linear series smooths to its
    // last sample, and a noisy series yields a fixed reference value.
    #[test]
    fn test_double_exponential_smoothing_impl() {
        let sf = 0.5;
        let tf = 0.1;
        let values = &[1.0, 2.0, 3.0, 4.0, 5.0];
        assert_eq!(double_exponential_smoothing_impl(values, sf, tf), Some(5.0));
        let values = &[50.0, 52.0, 95.0, 59.0, 52.0, 45.0, 38.0, 10.0, 47.0, 40.0];
        assert_eq!(
            double_exponential_smoothing_impl(values, sf, tf),
            Some(38.18119566835938)
        );
    }

    // End-to-end run through the UDF on the linear series from the test above.
    // NOTE(review): 9 timestamps are supplied, but the single range (0, 5) uses only the first 5.
    #[test]
    fn test_prom_double_exponential_smoothing_monotonic() {
        let ranges = [(0, 5)];
        let ts_array = Arc::new(TimestampMillisecondArray::from_iter(
            [1000i64, 3000, 5000, 7000, 9000, 11000, 13000, 15000, 17000]
                .into_iter()
                .map(Some),
        ));
        let values_array = Arc::new(Float64Array::from_iter([1.0, 2.0, 3.0, 4.0, 5.0]));
        let ts_range_array = RangeArray::from_ranges(ts_array, ranges).unwrap();
        let value_range_array = RangeArray::from_ranges(values_array, ranges).unwrap();
        simple_range_udf_runner(
            DoubleExponentialSmoothing::scalar_udf(),
            ts_range_array,
            value_range_array,
            vec![
                ScalarValue::Float64(Some(0.5)),
                ScalarValue::Float64(Some(0.1)),
            ],
            vec![Some(5.0)],
        );
    }

    // End-to-end run through the UDF on the noisy series from the known-answer test above.
    #[test]
    fn test_prom_double_exponential_smoothing_non_monotonic() {
        let ranges = [(0, 10)];
        let ts_array = Arc::new(TimestampMillisecondArray::from_iter(
            [
                1000i64, 3000, 5000, 7000, 9000, 11000, 13000, 15000, 17000, 19000,
            ]
            .into_iter()
            .map(Some),
        ));
        let values_array = Arc::new(Float64Array::from_iter([
            50.0, 52.0, 95.0, 59.0, 52.0, 45.0, 38.0, 10.0, 47.0, 40.0,
        ]));
        let ts_range_array = RangeArray::from_ranges(ts_array, ranges).unwrap();
        let value_range_array = RangeArray::from_ranges(values_array, ranges).unwrap();
        simple_range_udf_runner(
            DoubleExponentialSmoothing::scalar_udf(),
            ts_range_array,
            value_range_array,
            vec![
                ScalarValue::Float64(Some(0.5)),
                ScalarValue::Float64(Some(0.1)),
            ],
            vec![Some(38.18119566835938)],
        );
    }

    // Ports the Prometheus trend test cases linked below. Each series is linear over its first
    // 801 samples (range (0, 801)), so the smoothed result equals the sample at index 800.
    #[test]
    fn test_promql_trends() {
        let ranges = vec![(0, 801)];

        let trends = vec![
            // positive trends https://github.com/prometheus/prometheus/blob/8dba9163f1e923ec213f0f4d5c185d9648e387f0/promql/testdata/functions.test#L475
            ("0+10x1000 100+30x1000", 8000.0),
            ("0+20x1000 200+30x1000", 16000.0),
            ("0+30x1000 300+80x1000", 24000.0),
            ("0+40x2000", 32000.0),
            // negative trends https://github.com/prometheus/prometheus/blob/8dba9163f1e923ec213f0f4d5c185d9648e387f0/promql/testdata/functions.test#L488
            ("8000-10x1000", 0.0),
            ("0-20x1000", -16000.0),
            ("0+30x1000 300-80x1000", 24000.0),
            ("0-40x1000 0+40x1000", -32000.0),
        ];

        for (query, expected) in trends {
            let (ts_range_array, value_range_array) =
                create_ts_and_value_range_arrays(query, ranges.clone());
            simple_range_udf_runner(
                DoubleExponentialSmoothing::scalar_udf(),
                ts_range_array,
                value_range_array,
                vec![
                    ScalarValue::Float64(Some(0.01)),
                    ScalarValue::Float64(Some(0.1)),
                ],
                vec![Some(expected)],
            );
        }
    }

    /// Builds matching timestamp and value range arrays from a Prometheus test-series
    /// expression. Timestamps are simply 0, 1, 2, ... (one per generated sample).
    fn create_ts_and_value_range_arrays(
        input: &str,
        ranges: Vec<(u32, u32)>,
    ) -> (RangeArray, RangeArray) {
        let promql_range = create_test_range_from_promql_series(input);
        let ts_array = Arc::new(TimestampMillisecondArray::from_iter(
            (0..(promql_range.len() as i64)).map(Some),
        ));
        let values_array = Arc::new(Float64Array::from_iter(promql_range));
        let ts_range_array = RangeArray::from_ranges(ts_array, ranges.clone()).unwrap();
        let value_range_array = RangeArray::from_ranges(values_array, ranges).unwrap();
        (ts_range_array, value_range_array)
    }

    /// Expands a Prometheus functions-test series into its f64 samples, honouring resets and trend direction.
    /// Each space-separated entry expands to `end + 1` samples, appended in order.
    /// The input example: "0+10x1000 100+30x1000"
    fn create_test_range_from_promql_series(input: &str) -> Vec<f64> {
        input.split(' ').map(parse_promql_series_entry).fold(
            Vec::new(),
            |mut acc, (start, end, step, operation)| {
                if operation.eq("+") {
                    let iter = (start..=((step * end) + start))
                        .step_by(step as usize)
                        .map(|x| x as f64);
                    acc.extend(iter);
                } else {
                    // Descending series: walk the inclusive range backwards from `start`.
                    let iter = (((-step * end) + start)..=start)
                        .rev()
                        .step_by(step as usize)
                        .map(|x| x as f64);
                    acc.extend(iter);
                };
                acc
            },
        )
    }

    /// Converts a prometheus functions test series entry into separate parts to create a range with a step
    /// The input example: "100+30x1000" gives `(start, end, step, operation)` = `(100, 1000, 30, "+")`,
    /// where `end` is the number of steps taken after `start`.
    /// NOTE(review): `start` must be unsigned; a leading `-` would be taken as the operation.
    fn parse_promql_series_entry(input: &str) -> (i32, i32, i32, &str) {
        let mut parts = input.split('x');
        let start_operation_step = parts.next().unwrap();
        // The first non-numeric piece is the operator ("+" or "-").
        let operation = start_operation_step
            .split(char::is_numeric)
            .find(|&x| !x.is_empty())
            .unwrap();
        let start_step = start_operation_step
            .split(operation)
            .map(|s| s.parse::<i32>().unwrap())
            .collect::<Vec<_>>();
        let start = *start_step.first().unwrap();
        let step = *start_step.last().unwrap();
        let end = parts.next().unwrap().parse::<i32>().unwrap();
        (start, end, step, operation)
    }
}