1use std::sync::Arc;
19
20use datafusion::arrow::array::Float64Array;
21use datafusion::arrow::datatypes::TimeUnit;
22use datafusion::common::DataFusionError;
23use datafusion::logical_expr::{ScalarUDF, Volatility};
24use datafusion::physical_plan::ColumnarValue;
25use datafusion_common::ScalarValue;
26use datafusion_expr::create_udf;
27use datatypes::arrow::array::Array;
28use datatypes::arrow::datatypes::DataType;
29
30use crate::error;
31use crate::functions::extract_array;
32use crate::range_array::RangeArray;
33
/// State for the PromQL `holt_winters` function: the two smoothing
/// parameters taken from the function's scalar arguments.
#[derive(Debug, Clone, Copy)]
pub struct HoltWinters {
    /// Smoothing factor applied to each raw observation.
    sf: f64,
    /// Trend factor applied when updating the trend estimate.
    tf: f64,
}
51
52impl HoltWinters {
53 fn new(sf: f64, tf: f64) -> Self {
54 Self { sf, tf }
55 }
56
57 pub const fn name() -> &'static str {
58 "prom_holt_winters"
59 }
60
61 fn input_type() -> Vec<DataType> {
63 vec![
64 RangeArray::convert_data_type(DataType::Timestamp(TimeUnit::Millisecond, None)),
65 RangeArray::convert_data_type(DataType::Float64),
66 DataType::Float64,
68 DataType::Float64,
70 ]
71 }
72
73 fn return_type() -> DataType {
74 DataType::Float64
75 }
76
77 pub fn scalar_udf() -> ScalarUDF {
78 create_udf(
79 Self::name(),
80 Self::input_type(),
81 Self::return_type(),
82 Volatility::Volatile,
83 Arc::new(move |input: &_| Self::create_function(input)?.calc(input)) as _,
84 )
85 }
86
87 fn create_function(inputs: &[ColumnarValue]) -> Result<Self, DataFusionError> {
88 if inputs.len() != 4 {
89 return Err(DataFusionError::Plan(
90 "HoltWinters function should have 4 inputs".to_string(),
91 ));
92 }
93 let ColumnarValue::Scalar(ScalarValue::Float64(Some(sf))) = inputs[2] else {
94 return Err(DataFusionError::Plan(
95 "HoltWinters function's third input should be a scalar float64".to_string(),
96 ));
97 };
98 let ColumnarValue::Scalar(ScalarValue::Float64(Some(tf))) = inputs[3] else {
99 return Err(DataFusionError::Plan(
100 "HoltWinters function's fourth input should be a scalar float64".to_string(),
101 ));
102 };
103 Ok(Self::new(sf, tf))
104 }
105
106 fn calc(&self, input: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
107 assert_eq!(input.len(), 4);
110
111 let ts_array = extract_array(&input[0])?;
112 let value_array = extract_array(&input[1])?;
113
114 let ts_range: RangeArray = RangeArray::try_new(ts_array.to_data().into())?;
115 let value_range: RangeArray = RangeArray::try_new(value_array.to_data().into())?;
116
117 error::ensure(
118 ts_range.len() == value_range.len(),
119 DataFusionError::Execution(format!(
120 "{}: input arrays should have the same length, found {} and {}",
121 Self::name(),
122 ts_range.len(),
123 value_range.len()
124 )),
125 )?;
126 error::ensure(
127 ts_range.value_type() == DataType::Timestamp(TimeUnit::Millisecond, None),
128 DataFusionError::Execution(format!(
129 "{}: expect TimestampMillisecond as time index array's type, found {}",
130 Self::name(),
131 ts_range.value_type()
132 )),
133 )?;
134 error::ensure(
135 value_range.value_type() == DataType::Float64,
136 DataFusionError::Execution(format!(
137 "{}: expect Float64 as value array's type, found {}",
138 Self::name(),
139 value_range.value_type()
140 )),
141 )?;
142
143 let mut result_array = Vec::with_capacity(ts_range.len());
145 for index in 0..ts_range.len() {
146 let timestamps = ts_range.get(index).unwrap();
147 let values = value_range.get(index).unwrap();
148 let values = values
149 .as_any()
150 .downcast_ref::<Float64Array>()
151 .unwrap()
152 .values();
153 error::ensure(
154 timestamps.len() == values.len(),
155 DataFusionError::Execution(format!(
156 "{}: input arrays should have the same length, found {} and {}",
157 Self::name(),
158 timestamps.len(),
159 values.len()
160 )),
161 )?;
162 result_array.push(holt_winter_impl(values, self.sf, self.tf));
163 }
164
165 let result = ColumnarValue::Array(Arc::new(Float64Array::from_iter(result_array)));
166 Ok(result)
167 }
168}
169
/// Returns the updated trend estimate for step `i`.
///
/// On the very first step (`i == 0`) the seeded trend `b` is returned
/// untouched. Otherwise the level change `s1 - s0` is blended with the
/// previous trend `b`, weighted by the trend factor `tf`.
fn calc_trend_value(i: usize, tf: f64, s0: f64, s1: f64, b: f64) -> f64 {
    match i {
        0 => b,
        _ => tf * (s1 - s0) + (1.0 - tf) * b,
    }
}
178
179fn holt_winter_impl(values: &[f64], sf: f64, tf: f64) -> Option<f64> {
181 if sf.is_nan() || tf.is_nan() || values.is_empty() {
182 return Some(f64::NAN);
183 }
184 if sf < 0.0 || tf < 0.0 {
185 return Some(f64::NEG_INFINITY);
186 }
187 if sf > 1.0 || tf > 1.0 {
188 return Some(f64::INFINITY);
189 }
190
191 let l = values.len();
192 if l <= 2 {
193 return Some(f64::NAN);
195 }
196
197 let values = values.to_vec();
198
199 let mut s0 = 0.0;
200 let mut s1 = values[0];
201 let mut b = values[1] - values[0];
202
203 for (i, value) in values.iter().enumerate().skip(1) {
204 let x = sf * value;
206 b = calc_trend_value(i - 1, tf, s0, s1, b);
208 let y = (1.0 - sf) * (s1 + b);
209 s0 = s1;
210 s1 = x + y;
211 }
212 Some(s1)
213}
214
#[cfg(test)]
mod tests {
    use datafusion::arrow::array::{Float64Array, TimestampMillisecondArray};

    use super::*;
    use crate::functions::test_util::simple_range_udf_runner;

    /// An empty input and a two-sample input both produce `NaN`.
    #[test]
    fn test_holt_winter_impl_empty() {
        let sf = 0.5;
        let tf = 0.5;
        let values = &[];
        assert!(holt_winter_impl(values, sf, tf).unwrap().is_nan());

        // Two samples are still below the minimum `holt_winter_impl` accepts.
        let values = &[1.0, 2.0];
        assert!(holt_winter_impl(values, sf, tf).unwrap().is_nan());
    }

    /// A `NaN` smoothing or trend factor produces `NaN`.
    #[test]
    fn test_holt_winter_impl_nan() {
        let values = &[1.0, 2.0, 3.0];
        let sf = f64::NAN;
        let tf = 0.5;
        assert!(holt_winter_impl(values, sf, tf).unwrap().is_nan());

        let values = &[1.0, 2.0, 3.0];
        let sf = 0.5;
        let tf = f64::NAN;
        assert!(holt_winter_impl(values, sf, tf).unwrap().is_nan());
    }

    /// Negative factors map to `-inf`; factors above one map to `+inf`.
    #[test]
    fn test_holt_winter_impl_validation_rules() {
        let values = &[1.0, 2.0, 3.0];
        let sf = -0.5;
        let tf = 0.5;
        assert_eq!(holt_winter_impl(values, sf, tf).unwrap(), f64::NEG_INFINITY);

        let values = &[1.0, 2.0, 3.0];
        let sf = 0.5;
        let tf = -0.5;
        assert_eq!(holt_winter_impl(values, sf, tf).unwrap(), f64::NEG_INFINITY);

        let values = &[1.0, 2.0, 3.0];
        let sf = 1.5;
        let tf = 0.5;
        assert_eq!(holt_winter_impl(values, sf, tf).unwrap(), f64::INFINITY);

        let values = &[1.0, 2.0, 3.0];
        let sf = 0.5;
        let tf = 1.5;
        assert_eq!(holt_winter_impl(values, sf, tf).unwrap(), f64::INFINITY);
    }

    /// Exact smoothed results for a linear series and a noisy series.
    #[test]
    fn test_holt_winter_impl() {
        let sf = 0.5;
        let tf = 0.1;
        let values = &[1.0, 2.0, 3.0, 4.0, 5.0];
        assert_eq!(holt_winter_impl(values, sf, tf), Some(5.0));
        let values = &[50.0, 52.0, 95.0, 59.0, 52.0, 45.0, 38.0, 10.0, 47.0, 40.0];
        assert_eq!(holt_winter_impl(values, sf, tf), Some(38.18119566835938));
    }

    /// End-to-end through the UDF with a monotonically increasing series.
    #[test]
    fn test_prom_holt_winter_monotonic() {
        let ranges = [(0, 5)];
        // Only the first five timestamps fall inside the range, matching the
        // five values below.
        let ts_array = Arc::new(TimestampMillisecondArray::from_iter(
            [1000i64, 3000, 5000, 7000, 9000, 11000, 13000, 15000, 17000]
                .into_iter()
                .map(Some),
        ));
        let values_array = Arc::new(Float64Array::from_iter([1.0, 2.0, 3.0, 4.0, 5.0]));
        let ts_range_array = RangeArray::from_ranges(ts_array, ranges).unwrap();
        let value_range_array = RangeArray::from_ranges(values_array, ranges).unwrap();
        simple_range_udf_runner(
            HoltWinters::scalar_udf(),
            ts_range_array,
            value_range_array,
            vec![
                ScalarValue::Float64(Some(0.5)),
                ScalarValue::Float64(Some(0.1)),
            ],
            vec![Some(5.0)],
        );
    }

    /// End-to-end through the UDF with a non-monotonic series.
    #[test]
    fn test_prom_holt_winter_non_monotonic() {
        let ranges = [(0, 10)];
        let ts_array = Arc::new(TimestampMillisecondArray::from_iter(
            [
                1000i64, 3000, 5000, 7000, 9000, 11000, 13000, 15000, 17000, 19000,
            ]
            .into_iter()
            .map(Some),
        ));
        let values_array = Arc::new(Float64Array::from_iter([
            50.0, 52.0, 95.0, 59.0, 52.0, 45.0, 38.0, 10.0, 47.0, 40.0,
        ]));
        let ts_range_array = RangeArray::from_ranges(ts_array, ranges).unwrap();
        let value_range_array = RangeArray::from_ranges(values_array, ranges).unwrap();
        simple_range_udf_runner(
            HoltWinters::scalar_udf(),
            ts_range_array,
            value_range_array,
            vec![
                ScalarValue::Float64(Some(0.5)),
                ScalarValue::Float64(Some(0.1)),
            ],
            vec![Some(38.18119566835938)],
        );
    }

    /// Linear trend series written in PromQL test notation (`start+stepxN`).
    /// The single range covers the first 801 samples, which is why every
    /// expected value is `start ± 800 * step`.
    #[test]
    fn test_promql_trends() {
        let ranges = vec![(0, 801)];

        let trends = vec![
            ("0+10x1000 100+30x1000", 8000.0),
            ("0+20x1000 200+30x1000", 16000.0),
            ("0+30x1000 300+80x1000", 24000.0),
            ("0+40x2000", 32000.0),
            ("8000-10x1000", 0.0),
            ("0-20x1000", -16000.0),
            ("0+30x1000 300-80x1000", 24000.0),
            ("0-40x1000 0+40x1000", -32000.0),
        ];

        for (query, expected) in trends {
            let (ts_range_array, value_range_array) =
                create_ts_and_value_range_arrays(query, ranges.clone());
            simple_range_udf_runner(
                HoltWinters::scalar_udf(),
                ts_range_array,
                value_range_array,
                vec![
                    ScalarValue::Float64(Some(0.01)),
                    ScalarValue::Float64(Some(0.1)),
                ],
                vec![Some(expected)],
            );
        }
    }

    /// Expands a PromQL series string into range arrays. Timestamps are simply
    /// `0..len`, one per generated value.
    fn create_ts_and_value_range_arrays(
        input: &str,
        ranges: Vec<(u32, u32)>,
    ) -> (RangeArray, RangeArray) {
        let promql_range = create_test_range_from_promql_series(input);
        let ts_array = Arc::new(TimestampMillisecondArray::from_iter(
            (0..(promql_range.len() as i64)).map(Some),
        ));
        let values_array = Arc::new(Float64Array::from_iter(promql_range));
        let ts_range_array = RangeArray::from_ranges(ts_array, ranges.clone()).unwrap();
        let value_range_array = RangeArray::from_ranges(values_array, ranges).unwrap();
        (ts_range_array, value_range_array)
    }

    /// Expands space-separated PromQL series entries into a flat value list.
    /// `start+stepxN` yields `start, start+step, ..., start+N*step`, and
    /// `start-stepxN` yields `start, start-step, ..., start-N*step` (N+1 values
    /// each).
    fn create_test_range_from_promql_series(input: &str) -> Vec<f64> {
        input.split(' ').map(parse_promql_series_entry).fold(
            Vec::new(),
            |mut acc, (start, end, step, operation)| {
                if operation.eq("+") {
                    let iter = (start..=((step * end) + start))
                        .step_by(step as usize)
                        .map(|x| x as f64);
                    acc.extend(iter);
                } else {
                    // Build the ascending range, then walk it downwards from `start`.
                    let iter = (((-step * end) + start)..=start)
                        .rev()
                        .step_by(step as usize)
                        .map(|x| x as f64);
                    acc.extend(iter);
                };
                acc
            },
        )
    }

    /// Parses one entry such as `"0+10x1000"` into `(start, count, step, op)`,
    /// here `(0, 1000, 10, "+")`. The operator is the first run of non-digit
    /// characters, so negative starts and other malformed entries are not
    /// supported; parse failures panic.
    fn parse_promql_series_entry(input: &str) -> (i32, i32, i32, &str) {
        let mut parts = input.split('x');
        let start_operation_step = parts.next().unwrap();
        let operation = start_operation_step
            .split(char::is_numeric)
            .find(|&x| !x.is_empty())
            .unwrap();
        let start_step = start_operation_step
            .split(operation)
            .map(|s| s.parse::<i32>().unwrap())
            .collect::<Vec<_>>();
        let start = *start_step.first().unwrap();
        let step = *start_step.last().unwrap();
        let end = parts.next().unwrap().parse::<i32>().unwrap();
        (start, end, step, operation)
    }
}