1use std::sync::Arc;
19
20use datafusion::arrow::array::Float64Array;
21use datafusion::arrow::datatypes::TimeUnit;
22use datafusion::common::DataFusionError;
23use datafusion::logical_expr::{ScalarUDF, Volatility};
24use datafusion::physical_plan::ColumnarValue;
25use datafusion_common::ScalarValue;
26use datafusion_expr::create_udf;
27use datatypes::arrow::array::Array;
28use datatypes::arrow::datatypes::DataType;
29
30use crate::error;
31use crate::functions::extract_array;
32use crate::range_array::RangeArray;
33
/// Iterator that yields one `f64` factor per row from a [`ColumnarValue`].
///
/// A scalar input repeats the same value for every row; an array input is read
/// element by element. Anything that cannot be read as a `Float64` value is
/// reported as `f64::NAN` rather than as an error.
34struct FactorIterator<'a> {
    /// `true` when every row yields `scalar_val` instead of reading from `array`.
36 is_scalar: bool,
    /// Backing array for array inputs; `None` for scalar inputs or when the
    /// array is not a `Float64Array`.
37 array: Option<&'a Float64Array>,
    /// Value yielded for every row when `is_scalar` is set (`NaN` if the scalar
    /// was not a non-null `Float64`).
38 scalar_val: f64,
    /// Number of items yielded so far, i.e. the position of the next item.
39 index: usize,
    /// Total number of items this iterator yields.
40 len: usize,
41}
42
43impl<'a> FactorIterator<'a> {
44 fn new(value: &'a ColumnarValue, len: usize) -> Self {
45 let (is_scalar, array, scalar_val) = match value {
46 ColumnarValue::Array(arr) => {
47 (false, arr.as_any().downcast_ref::<Float64Array>(), f64::NAN)
48 }
49 ColumnarValue::Scalar(ScalarValue::Float64(Some(val))) => (true, None, *val),
50 _ => (true, None, f64::NAN),
51 };
52
53 Self {
54 is_scalar,
55 array,
56 scalar_val,
57 index: 0,
58 len,
59 }
60 }
61}
62
63impl<'a> Iterator for FactorIterator<'a> {
64 type Item = f64;
65
66 fn next(&mut self) -> Option<Self::Item> {
67 if self.index >= self.len {
68 return None;
69 }
70 self.index += 1;
71
72 if self.is_scalar {
73 return Some(self.scalar_val);
74 }
75
76 if let Some(array) = self.array {
77 if array.is_null(self.index - 1) {
78 Some(f64::NAN)
79 } else {
80 Some(array.value(self.index - 1))
81 }
82 } else {
83 Some(f64::NAN)
84 }
85 }
86}
87
/// Stateless marker type for the `prom_double_exponential_smoothing` scalar UDF.
///
/// Its inherent impl provides the function name, the argument/return types and
/// the [`ScalarUDF`] constructor.
88pub struct DoubleExponentialSmoothing;
103
104impl DoubleExponentialSmoothing {
105 pub const fn name() -> &'static str {
106 "prom_double_exponential_smoothing"
107 }
108
109 fn input_type() -> Vec<DataType> {
111 vec![
112 RangeArray::convert_data_type(DataType::Timestamp(TimeUnit::Millisecond, None)),
113 RangeArray::convert_data_type(DataType::Float64),
114 DataType::Float64,
116 DataType::Float64,
118 ]
119 }
120
121 fn return_type() -> DataType {
122 DataType::Float64
123 }
124
125 pub fn scalar_udf() -> ScalarUDF {
126 create_udf(
127 Self::name(),
128 Self::input_type(),
129 Self::return_type(),
130 Volatility::Volatile,
131 Arc::new(Self::double_exponential_smoothing) as _,
132 )
133 }
134
135 fn double_exponential_smoothing(
136 input: &[ColumnarValue],
137 ) -> Result<ColumnarValue, DataFusionError> {
138 error::ensure(
139 input.len() == 4,
140 DataFusionError::Plan(
141 "prom_double_exponential_smoothing function should have 4 inputs".to_string(),
142 ),
143 )?;
144
145 let ts_array = extract_array(&input[0])?;
146 let value_array = extract_array(&input[1])?;
147 let sf_col = &input[2];
148 let tf_col = &input[3];
149
150 let ts_range: RangeArray = RangeArray::try_new(ts_array.to_data().into())?;
151 let value_range: RangeArray = RangeArray::try_new(value_array.to_data().into())?;
152 let num_rows = ts_range.len();
153
154 error::ensure(
155 num_rows == value_range.len(),
156 DataFusionError::Execution(format!(
157 "{}: input arrays should have the same length, found {} and {}",
158 Self::name(),
159 num_rows,
160 value_range.len()
161 )),
162 )?;
163 error::ensure(
164 ts_range.value_type() == DataType::Timestamp(TimeUnit::Millisecond, None),
165 DataFusionError::Execution(format!(
166 "{}: expect TimestampMillisecond as time index array's type, found {}",
167 Self::name(),
168 ts_range.value_type()
169 )),
170 )?;
171 error::ensure(
172 value_range.value_type() == DataType::Float64,
173 DataFusionError::Execution(format!(
174 "{}: expect Float64 as value array's type, found {}",
175 Self::name(),
176 value_range.value_type()
177 )),
178 )?;
179
180 let mut result_array = Vec::with_capacity(ts_range.len());
182
183 let sf_iter = FactorIterator::new(sf_col, num_rows);
184 let tf_iter = FactorIterator::new(tf_col, num_rows);
185
186 let iter = (0..num_rows)
187 .map(|i| (ts_range.get(i), value_range.get(i)))
188 .zip(sf_iter.zip(tf_iter));
189
190 for ((timestamps, values), (sf, tf)) in iter {
191 let timestamps = timestamps.unwrap();
192 let values = values.unwrap();
193 let values = values
194 .as_any()
195 .downcast_ref::<Float64Array>()
196 .unwrap()
197 .values();
198 error::ensure(
199 timestamps.len() == values.len(),
200 DataFusionError::Execution(format!(
201 "{}: input arrays should have the same length, found {} and {}",
202 Self::name(),
203 timestamps.len(),
204 values.len()
205 )),
206 )?;
207
208 result_array.push(double_exponential_smoothing_impl(values, sf, tf));
209 }
210
211 let result = ColumnarValue::Array(Arc::new(Float64Array::from_iter(result_array)));
212 Ok(result)
213 }
214}
215
/// Returns the trend estimate for one smoothing step.
///
/// * `i` - zero-based step counter; on the first step (`i == 0`) the initial
///   trend `b` is returned unchanged.
/// * `tf` - trend factor weighting the latest change in the smoothed value.
/// * `s0`, `s1` - previous and current smoothed values.
/// * `b` - previous trend estimate.
fn calc_trend_value(i: usize, tf: f64, s0: f64, s1: f64, b: f64) -> f64 {
    if i == 0 {
        return b;
    }
    // Blend the latest change in the smoothed value with the previous trend.
    let x = tf * (s1 - s0);
    let y = (1.0 - tf) * b;
    x + y
}

/// Applies double exponential smoothing to `values` and returns the final
/// smoothed value.
///
/// * `values` - samples of one window, in order.
/// * `sf` - smoothing factor, expected in `[0, 1]`.
/// * `tf` - trend factor, expected in `[0, 1]`.
///
/// The result is always `Some`:
/// * `NaN` when either factor is `NaN` or fewer than three samples are given,
/// * `-inf` when either factor is negative,
/// * `+inf` when either factor is greater than `1`,
/// * the smoothed value of the last sample otherwise.
fn double_exponential_smoothing_impl(values: &[f64], sf: f64, tf: f64) -> Option<f64> {
    if sf.is_nan() || tf.is_nan() || values.is_empty() {
        return Some(f64::NAN);
    }
    if sf < 0.0 || tf < 0.0 {
        return Some(f64::NEG_INFINITY);
    }
    if sf > 1.0 || tf > 1.0 {
        return Some(f64::INFINITY);
    }

    // Smoothing needs more than two samples to produce a result.
    if values.len() <= 2 {
        return Some(f64::NAN);
    }

    // The slice is only read, so it is walked in place rather than copied.
    let mut s0 = 0.0;
    let mut s1 = values[0];
    let mut b = values[1] - values[0];

    for (i, &value) in values.iter().enumerate().skip(1) {
        // Scale the raw sample by the smoothing factor.
        let x = sf * value;
        // Update the trend, then scale the previous smoothed value plus trend.
        b = calc_trend_value(i - 1, tf, s0, s1, b);
        let y = (1.0 - sf) * (s1 + b);
        s0 = s1;
        s1 = x + y;
    }
    Some(s1)
}
260
#[cfg(test)]
mod tests {
    use datafusion::arrow::array::{Float64Array, TimestampMillisecondArray};

    use super::*;
    use crate::functions::test_util::simple_range_udf_runner;

    /// Asserts that smoothing `values` with the given factors produces NaN.
    fn assert_smooths_to_nan(values: &[f64], sf: f64, tf: f64) {
        let smoothed = double_exponential_smoothing_impl(values, sf, tf).unwrap();
        assert!(smoothed.is_nan());
    }

    /// Scalar factor arguments in the order the UDF expects them: `sf`, then `tf`.
    fn scalar_factors(sf: f64, tf: f64) -> Vec<ScalarValue> {
        vec![
            ScalarValue::Float64(Some(sf)),
            ScalarValue::Float64(Some(tf)),
        ]
    }

    /// Runs the UDF over a single window with `sf = 0.5` and `tf = 0.1` and
    /// checks the result against `expected`.
    fn check_single_window(
        timestamps: Vec<i64>,
        values: Vec<f64>,
        ranges: [(u32, u32); 1],
        expected: f64,
    ) {
        let ts_array = Arc::new(TimestampMillisecondArray::from_iter(
            timestamps.into_iter().map(Some),
        ));
        let values_array = Arc::new(Float64Array::from_iter(values));
        let ts_range_array = RangeArray::from_ranges(ts_array, ranges).unwrap();
        let value_range_array = RangeArray::from_ranges(values_array, ranges).unwrap();
        simple_range_udf_runner(
            DoubleExponentialSmoothing::scalar_udf(),
            ts_range_array,
            value_range_array,
            scalar_factors(0.5, 0.1),
            vec![Some(expected)],
        );
    }

    #[test]
    fn test_double_exponential_smoothing_impl_empty() {
        // No samples at all, then too few samples.
        assert_smooths_to_nan(&[], 0.5, 0.5);
        assert_smooths_to_nan(&[1.0, 2.0], 0.5, 0.5);
    }

    #[test]
    fn test_double_exponential_smoothing_impl_nan() {
        let values = [1.0, 2.0, 3.0];
        assert_smooths_to_nan(&values, f64::NAN, 0.5);
        assert_smooths_to_nan(&values, 0.5, f64::NAN);
    }

    #[test]
    fn test_double_exponential_smoothing_impl_validation_rules() {
        let values = [1.0, 2.0, 3.0];
        // (sf, tf, expected result)
        let cases = [
            (-0.5, 0.5, f64::NEG_INFINITY),
            (0.5, -0.5, f64::NEG_INFINITY),
            (1.5, 0.5, f64::INFINITY),
            (0.5, 1.5, f64::INFINITY),
        ];

        for (sf, tf, expected) in cases {
            assert_eq!(
                double_exponential_smoothing_impl(&values, sf, tf).unwrap(),
                expected
            );
        }
    }

    #[test]
    fn test_double_exponential_smoothing_impl() {
        let (sf, tf) = (0.5, 0.1);

        let monotonic = [1.0, 2.0, 3.0, 4.0, 5.0];
        assert_eq!(
            double_exponential_smoothing_impl(&monotonic, sf, tf),
            Some(5.0)
        );

        let non_monotonic = [50.0, 52.0, 95.0, 59.0, 52.0, 45.0, 38.0, 10.0, 47.0, 40.0];
        assert_eq!(
            double_exponential_smoothing_impl(&non_monotonic, sf, tf),
            Some(38.18119566835938)
        );
    }

    #[test]
    fn test_prom_double_exponential_smoothing_monotonic() {
        check_single_window(
            vec![1000, 3000, 5000, 7000, 9000, 11000, 13000, 15000, 17000],
            vec![1.0, 2.0, 3.0, 4.0, 5.0],
            [(0, 5)],
            5.0,
        );
    }

    #[test]
    fn test_prom_double_exponential_smoothing_non_monotonic() {
        check_single_window(
            vec![
                1000, 3000, 5000, 7000, 9000, 11000, 13000, 15000, 17000, 19000,
            ],
            vec![50.0, 52.0, 95.0, 59.0, 52.0, 45.0, 38.0, 10.0, 47.0, 40.0],
            [(0, 10)],
            38.18119566835938,
        );
    }

    #[test]
    fn test_promql_trends() {
        let ranges = vec![(0, 801)];

        // (series description, expected smoothed value)
        let cases = [
            ("0+10x1000 100+30x1000", 8000.0),
            ("0+20x1000 200+30x1000", 16000.0),
            ("0+30x1000 300+80x1000", 24000.0),
            ("0+40x2000", 32000.0),
            ("8000-10x1000", 0.0),
            ("0-20x1000", -16000.0),
            ("0+30x1000 300-80x1000", 24000.0),
            ("0-40x1000 0+40x1000", -32000.0),
        ];

        for (series, expected) in cases {
            let (ts_range_array, value_range_array) =
                create_ts_and_value_range_arrays(series, ranges.clone());
            simple_range_udf_runner(
                DoubleExponentialSmoothing::scalar_udf(),
                ts_range_array,
                value_range_array,
                scalar_factors(0.01, 0.1),
                vec![Some(expected)],
            );
        }
    }

    /// Expands a series description into timestamp/value range arrays.
    ///
    /// Timestamps are simply `0, 1, 2, ...`, one per generated sample.
    fn create_ts_and_value_range_arrays(
        input: &str,
        ranges: Vec<(u32, u32)>,
    ) -> (RangeArray, RangeArray) {
        let samples = create_test_range_from_promql_series(input);
        let sample_count = samples.len() as i64;

        let ts_array = Arc::new(TimestampMillisecondArray::from_iter(
            (0..sample_count).map(Some),
        ));
        let values_array = Arc::new(Float64Array::from_iter(samples));

        let ts_range_array = RangeArray::from_ranges(ts_array, ranges.clone()).unwrap();
        let value_range_array = RangeArray::from_ranges(values_array, ranges).unwrap();
        (ts_range_array, value_range_array)
    }

    /// Expands a space-separated list of `<start><+|-><step>x<count>` entries
    /// into the concatenated sample values.
    fn create_test_range_from_promql_series(input: &str) -> Vec<f64> {
        let mut samples = Vec::new();
        for (start, end, step, operation) in input.split(' ').map(parse_promql_series_entry) {
            if operation.eq("+") {
                // Ascending: start, start + step, ..., start + step * end.
                let ascending = (start..=((step * end) + start))
                    .step_by(step as usize)
                    .map(|x| x as f64);
                samples.extend(ascending);
            } else {
                // Descending: start, start - step, ..., start - step * end.
                let descending = (((-step * end) + start)..=start)
                    .rev()
                    .step_by(step as usize)
                    .map(|x| x as f64);
                samples.extend(descending);
            }
        }
        samples
    }

    /// Splits one `<start><op><step>x<count>` entry into
    /// `(start, count, step, op)`.
    fn parse_promql_series_entry(input: &str) -> (i32, i32, i32, &str) {
        let mut fields = input.split('x');
        let head = fields.next().unwrap();

        // The operator is whatever non-empty text sits between the digits.
        let operation = head
            .split(char::is_numeric)
            .find(|piece| !piece.is_empty())
            .unwrap();

        let numbers: Vec<i32> = head
            .split(operation)
            .map(|number| number.parse::<i32>().unwrap())
            .collect();
        let start = *numbers.first().unwrap();
        let step = *numbers.last().unwrap();

        let end = fields.next().unwrap().parse::<i32>().unwrap();
        (start, end, step, operation)
    }
}