1use std::sync::Arc;
19
20use datafusion::arrow::array::Float64Array;
21use datafusion::arrow::datatypes::TimeUnit;
22use datafusion::common::DataFusionError;
23use datafusion::logical_expr::{ScalarUDF, Volatility};
24use datafusion::physical_plan::ColumnarValue;
25use datafusion_common::ScalarValue;
26use datafusion_expr::create_udf;
27use datatypes::arrow::array::Array;
28use datatypes::arrow::datatypes::DataType;
29
30use crate::error;
31use crate::functions::extract_array;
32use crate::range_array::RangeArray;
33
/// Row-by-row view over one of the `Float64` factor arguments of the UDF.
///
/// Yields exactly `len` items. A scalar argument is repeated for every row, while an
/// array argument is read position by position. Whatever cannot be read as a `Float64`
/// (a null slot, an array that is not a `Float64Array`, any other scalar) is reported
/// as `NaN`.
struct FactorIterator<'a> {
    /// `true` when every row shares `scalar_val`; `false` when rows are read from `array`.
    is_scalar: bool,
    /// Borrowed column of per-row factors; `None` for scalars or when the downcast failed.
    array: Option<&'a Float64Array>,
    /// Value returned for every row when `is_scalar` is set (`NaN` if no usable scalar).
    scalar_val: f64,
    /// Number of items handed out so far, i.e. the position of the next row.
    index: usize,
    /// Total number of items to yield.
    len: usize,
}
42
43impl<'a> FactorIterator<'a> {
44 fn new(value: &'a ColumnarValue, len: usize) -> Self {
45 let (is_scalar, array, scalar_val) = match value {
46 ColumnarValue::Array(arr) => {
47 (false, arr.as_any().downcast_ref::<Float64Array>(), f64::NAN)
48 }
49 ColumnarValue::Scalar(ScalarValue::Float64(Some(val))) => (true, None, *val),
50 _ => (true, None, f64::NAN),
51 };
52
53 Self {
54 is_scalar,
55 array,
56 scalar_val,
57 index: 0,
58 len,
59 }
60 }
61}
62
63impl<'a> Iterator for FactorIterator<'a> {
64 type Item = f64;
65
66 fn next(&mut self) -> Option<Self::Item> {
67 if self.index >= self.len {
68 return None;
69 }
70 self.index += 1;
71
72 if self.is_scalar {
73 return Some(self.scalar_val);
74 }
75
76 if let Some(array) = self.array {
77 if array.is_null(self.index - 1) {
78 Some(f64::NAN)
79 } else {
80 Some(array.value(self.index - 1))
81 }
82 } else {
83 Some(f64::NAN)
84 }
85 }
86}
87
/// Namespace type for the `prom_holt_winters` scalar UDF.
///
/// The function takes four arguments: a range array of millisecond timestamps, a range
/// array of `Float64` values, and two `Float64` factors (`sf` and `tf`). It produces one
/// `Float64` per row.
pub struct HoltWinters;
102
103impl HoltWinters {
104 pub const fn name() -> &'static str {
105 "prom_holt_winters"
106 }
107
108 fn input_type() -> Vec<DataType> {
110 vec![
111 RangeArray::convert_data_type(DataType::Timestamp(TimeUnit::Millisecond, None)),
112 RangeArray::convert_data_type(DataType::Float64),
113 DataType::Float64,
115 DataType::Float64,
117 ]
118 }
119
120 fn return_type() -> DataType {
121 DataType::Float64
122 }
123
124 pub fn scalar_udf() -> ScalarUDF {
125 create_udf(
126 Self::name(),
127 Self::input_type(),
128 Self::return_type(),
129 Volatility::Volatile,
130 Arc::new(Self::holt_winters) as _,
131 )
132 }
133
134 fn holt_winters(input: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
135 error::ensure(
136 input.len() == 4,
137 DataFusionError::Plan("prom_holt_winters function should have 4 inputs".to_string()),
138 )?;
139
140 let ts_array = extract_array(&input[0])?;
141 let value_array = extract_array(&input[1])?;
142 let sf_col = &input[2];
143 let tf_col = &input[3];
144
145 let ts_range: RangeArray = RangeArray::try_new(ts_array.to_data().into())?;
146 let value_range: RangeArray = RangeArray::try_new(value_array.to_data().into())?;
147 let num_rows = ts_range.len();
148
149 error::ensure(
150 num_rows == value_range.len(),
151 DataFusionError::Execution(format!(
152 "{}: input arrays should have the same length, found {} and {}",
153 Self::name(),
154 num_rows,
155 value_range.len()
156 )),
157 )?;
158 error::ensure(
159 ts_range.value_type() == DataType::Timestamp(TimeUnit::Millisecond, None),
160 DataFusionError::Execution(format!(
161 "{}: expect TimestampMillisecond as time index array's type, found {}",
162 Self::name(),
163 ts_range.value_type()
164 )),
165 )?;
166 error::ensure(
167 value_range.value_type() == DataType::Float64,
168 DataFusionError::Execution(format!(
169 "{}: expect Float64 as value array's type, found {}",
170 Self::name(),
171 value_range.value_type()
172 )),
173 )?;
174
175 let mut result_array = Vec::with_capacity(ts_range.len());
177
178 let sf_iter = FactorIterator::new(sf_col, num_rows);
179 let tf_iter = FactorIterator::new(tf_col, num_rows);
180
181 let iter = (0..num_rows)
182 .map(|i| (ts_range.get(i), value_range.get(i)))
183 .zip(sf_iter.zip(tf_iter));
184
185 for ((timestamps, values), (sf, tf)) in iter {
186 let timestamps = timestamps.unwrap();
187 let values = values.unwrap();
188 let values = values
189 .as_any()
190 .downcast_ref::<Float64Array>()
191 .unwrap()
192 .values();
193 error::ensure(
194 timestamps.len() == values.len(),
195 DataFusionError::Execution(format!(
196 "{}: input arrays should have the same length, found {} and {}",
197 Self::name(),
198 timestamps.len(),
199 values.len()
200 )),
201 )?;
202
203 result_array.push(holt_winter_impl(values, sf, tf));
204 }
205
206 let result = ColumnarValue::Array(Arc::new(Float64Array::from_iter(result_array)));
207 Ok(result)
208 }
209}
210
/// Updates the trend estimate `b` for step `i`.
///
/// At step 0 the trend is returned unchanged. Afterwards it is the blend
/// `tf * (s1 - s0) + (1 - tf) * b`, where `s0` and `s1` are the two most recent
/// smoothed values.
fn calc_trend_value(i: usize, tf: f64, s0: f64, s1: f64, b: f64) -> f64 {
    match i {
        0 => b,
        _ => {
            let observed_slope = tf * (s1 - s0);
            let carried_trend = (1.0 - tf) * b;
            observed_slope + carried_trend
        }
    }
}
219
220fn holt_winter_impl(values: &[f64], sf: f64, tf: f64) -> Option<f64> {
222 if sf.is_nan() || tf.is_nan() || values.is_empty() {
223 return Some(f64::NAN);
224 }
225 if sf < 0.0 || tf < 0.0 {
226 return Some(f64::NEG_INFINITY);
227 }
228 if sf > 1.0 || tf > 1.0 {
229 return Some(f64::INFINITY);
230 }
231
232 let l = values.len();
233 if l <= 2 {
234 return Some(f64::NAN);
236 }
237
238 let values = values.to_vec();
239
240 let mut s0 = 0.0;
241 let mut s1 = values[0];
242 let mut b = values[1] - values[0];
243
244 for (i, value) in values.iter().enumerate().skip(1) {
245 let x = sf * value;
247 b = calc_trend_value(i - 1, tf, s0, s1, b);
249 let y = (1.0 - sf) * (s1 + b);
250 s0 = s1;
251 s1 = x + y;
252 }
253 Some(s1)
254}
255
#[cfg(test)]
mod tests {
    use datafusion::arrow::array::{Float64Array, TimestampMillisecondArray};

    use super::*;
    use crate::functions::test_util::simple_range_udf_runner;

    /// An empty window and a two-sample window both evaluate to NaN.
    #[test]
    fn test_holt_winter_impl_empty() {
        let sf = 0.5;
        let tf = 0.5;
        let values = &[];
        assert!(holt_winter_impl(values, sf, tf).unwrap().is_nan());

        let values = &[1.0, 2.0];
        assert!(holt_winter_impl(values, sf, tf).unwrap().is_nan());
    }

    /// A NaN in either factor yields NaN.
    #[test]
    fn test_holt_winter_impl_nan() {
        let values = &[1.0, 2.0, 3.0];
        let sf = f64::NAN;
        let tf = 0.5;
        assert!(holt_winter_impl(values, sf, tf).unwrap().is_nan());

        let values = &[1.0, 2.0, 3.0];
        let sf = 0.5;
        let tf = f64::NAN;
        assert!(holt_winter_impl(values, sf, tf).unwrap().is_nan());
    }

    /// Factors below 0 map to negative infinity, factors above 1 to positive infinity.
    #[test]
    fn test_holt_winter_impl_validation_rules() {
        let values = &[1.0, 2.0, 3.0];
        let sf = -0.5;
        let tf = 0.5;
        assert_eq!(holt_winter_impl(values, sf, tf).unwrap(), f64::NEG_INFINITY);

        let values = &[1.0, 2.0, 3.0];
        let sf = 0.5;
        let tf = -0.5;
        assert_eq!(holt_winter_impl(values, sf, tf).unwrap(), f64::NEG_INFINITY);

        let values = &[1.0, 2.0, 3.0];
        let sf = 1.5;
        let tf = 0.5;
        assert_eq!(holt_winter_impl(values, sf, tf).unwrap(), f64::INFINITY);

        let values = &[1.0, 2.0, 3.0];
        let sf = 0.5;
        let tf = 1.5;
        assert_eq!(holt_winter_impl(values, sf, tf).unwrap(), f64::INFINITY);
    }

    /// Known results for a linear series and for an irregular series.
    #[test]
    fn test_holt_winter_impl() {
        let sf = 0.5;
        let tf = 0.1;
        let values = &[1.0, 2.0, 3.0, 4.0, 5.0];
        assert_eq!(holt_winter_impl(values, sf, tf), Some(5.0));
        let values = &[50.0, 52.0, 95.0, 59.0, 52.0, 45.0, 38.0, 10.0, 47.0, 40.0];
        assert_eq!(holt_winter_impl(values, sf, tf), Some(38.18119566835938));
    }

    /// Runs the UDF end to end on the linear series with scalar factors.
    #[test]
    fn test_prom_holt_winter_monotonic() {
        let ranges = [(0, 5)];
        let ts_array = Arc::new(TimestampMillisecondArray::from_iter(
            [1000i64, 3000, 5000, 7000, 9000, 11000, 13000, 15000, 17000]
                .into_iter()
                .map(Some),
        ));
        let values_array = Arc::new(Float64Array::from_iter([1.0, 2.0, 3.0, 4.0, 5.0]));
        let ts_range_array = RangeArray::from_ranges(ts_array, ranges).unwrap();
        let value_range_array = RangeArray::from_ranges(values_array, ranges).unwrap();
        simple_range_udf_runner(
            HoltWinters::scalar_udf(),
            ts_range_array,
            value_range_array,
            vec![
                ScalarValue::Float64(Some(0.5)),
                ScalarValue::Float64(Some(0.1)),
            ],
            vec![Some(5.0)],
        );
    }

    /// Runs the UDF end to end on the irregular series with scalar factors.
    #[test]
    fn test_prom_holt_winter_non_monotonic() {
        let ranges = [(0, 10)];
        let ts_array = Arc::new(TimestampMillisecondArray::from_iter(
            [
                1000i64, 3000, 5000, 7000, 9000, 11000, 13000, 15000, 17000, 19000,
            ]
            .into_iter()
            .map(Some),
        ));
        let values_array = Arc::new(Float64Array::from_iter([
            50.0, 52.0, 95.0, 59.0, 52.0, 45.0, 38.0, 10.0, 47.0, 40.0,
        ]));
        let ts_range_array = RangeArray::from_ranges(ts_array, ranges).unwrap();
        let value_range_array = RangeArray::from_ranges(values_array, ranges).unwrap();
        simple_range_udf_runner(
            HoltWinters::scalar_udf(),
            ts_range_array,
            value_range_array,
            vec![
                ScalarValue::Float64(Some(0.5)),
                ScalarValue::Float64(Some(0.1)),
            ],
            vec![Some(38.18119566835938)],
        );
    }

    /// Checks the UDF against series written in the `start±stepxcount` notation.
    ///
    /// Every series is generated by the helpers below and evaluated over the single
    /// range `(0, 801)` with `sf = 0.01` and `tf = 0.1`.
    #[test]
    fn test_promql_trends() {
        let ranges = vec![(0, 801)];

        // (series description, expected result)
        let trends = vec![
            // Increasing series.
            ("0+10x1000 100+30x1000", 8000.0),
            ("0+20x1000 200+30x1000", 16000.0),
            ("0+30x1000 300+80x1000", 24000.0),
            ("0+40x2000", 32000.0),
            // Series that decrease within the evaluated range.
            ("8000-10x1000", 0.0),
            ("0-20x1000", -16000.0),
            // Series whose second segment changes direction.
            ("0+30x1000 300-80x1000", 24000.0),
            ("0-40x1000 0+40x1000", -32000.0),
        ];

        for (query, expected) in trends {
            let (ts_range_array, value_range_array) =
                create_ts_and_value_range_arrays(query, ranges.clone());
            simple_range_udf_runner(
                HoltWinters::scalar_udf(),
                ts_range_array,
                value_range_array,
                vec![
                    ScalarValue::Float64(Some(0.01)),
                    ScalarValue::Float64(Some(0.1)),
                ],
                vec![Some(expected)],
            );
        }
    }

    /// Expands `input` into a value array and a matching timestamp array
    /// (`0, 1, 2, ...`), then wraps both with the same `ranges`.
    fn create_ts_and_value_range_arrays(
        input: &str,
        ranges: Vec<(u32, u32)>,
    ) -> (RangeArray, RangeArray) {
        let promql_range = create_test_range_from_promql_series(input);
        let ts_array = Arc::new(TimestampMillisecondArray::from_iter(
            (0..(promql_range.len() as i64)).map(Some),
        ));
        let values_array = Arc::new(Float64Array::from_iter(promql_range));
        let ts_range_array = RangeArray::from_ranges(ts_array, ranges.clone()).unwrap();
        let value_range_array = RangeArray::from_ranges(values_array, ranges).unwrap();
        (ts_range_array, value_range_array)
    }

    /// Expands a space-separated list of `start±stepxcount` entries into samples.
    ///
    /// Each entry contributes `count + 1` values: `start`, then `count` further values
    /// moving by `step` upwards (`+`) or downwards (any other operator). Entries are
    /// concatenated in order.
    fn create_test_range_from_promql_series(input: &str) -> Vec<f64> {
        input.split(' ').map(parse_promql_series_entry).fold(
            Vec::new(),
            |mut acc, (start, end, step, operation)| {
                if operation.eq("+") {
                    // start, start + step, ..., start + step * end
                    let iter = (start..=((step * end) + start))
                        .step_by(step as usize)
                        .map(|x| x as f64);
                    acc.extend(iter);
                } else {
                    // Same span walked backwards: start, start - step, ..., start - step * end
                    let iter = (((-step * end) + start)..=start)
                        .rev()
                        .step_by(step as usize)
                        .map(|x| x as f64);
                    acc.extend(iter);
                };
                acc
            },
        )
    }

    /// Splits one `start±stepxcount` entry into `(start, count, step, operator)`.
    ///
    /// The operator is the first non-numeric run before the `x`. Panics on input
    /// that does not follow the notation.
    fn parse_promql_series_entry(input: &str) -> (i32, i32, i32, &str) {
        let mut parts = input.split('x');
        let start_operation_step = parts.next().unwrap();
        // Splitting on digits leaves only the operator as a non-empty piece.
        let operation = start_operation_step
            .split(char::is_numeric)
            .find(|&x| !x.is_empty())
            .unwrap();
        let start_step = start_operation_step
            .split(operation)
            .map(|s| s.parse::<i32>().unwrap())
            .collect::<Vec<_>>();
        let start = *start_step.first().unwrap();
        let step = *start_step.last().unwrap();
        let end = parts.next().unwrap().parse::<i32>().unwrap();
        (start, end, step, operation)
    }
}