1use std::sync::Arc;
19
20use datafusion::arrow::array::{Float64Array, TimestampMillisecondArray};
21use datafusion::arrow::datatypes::TimeUnit;
22use datafusion::common::DataFusionError;
23use datafusion::logical_expr::{ScalarUDF, Volatility};
24use datafusion::physical_plan::ColumnarValue;
25use datafusion_common::ScalarValue;
26use datafusion_expr::create_udf;
27use datatypes::arrow::array::Array;
28use datatypes::arrow::datatypes::DataType;
29
30use crate::error;
31use crate::functions::{extract_array, linear_regression};
32use crate::range_array::RangeArray;
33
34pub struct PredictLinear;
35
36impl PredictLinear {
37 pub const fn name() -> &'static str {
38 "prom_predict_linear"
39 }
40
41 pub fn scalar_udf() -> ScalarUDF {
42 let input_types = vec![
43 RangeArray::convert_data_type(DataType::Timestamp(TimeUnit::Millisecond, None)),
45 RangeArray::convert_data_type(DataType::Float64),
47 DataType::Int64,
49 ];
50 create_udf(
51 Self::name(),
52 input_types,
53 DataType::Float64,
54 Volatility::Volatile,
55 Arc::new(Self::predict_linear) as _,
56 )
57 }
58
59 fn predict_linear(input: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
60 error::ensure(
61 input.len() == 3,
62 DataFusionError::Plan("prom_predict_linear function should have 3 inputs".to_string()),
63 )?;
64
65 let ts_array = extract_array(&input[0])?;
66 let value_array = extract_array(&input[1])?;
67 let t_col = &input[2];
68
69 let ts_range: RangeArray = RangeArray::try_new(ts_array.to_data().into())?;
70 let value_range: RangeArray = RangeArray::try_new(value_array.to_data().into())?;
71 error::ensure(
72 ts_range.len() == value_range.len(),
73 DataFusionError::Execution(format!(
74 "{}: input arrays should have the same length, found {} and {}",
75 Self::name(),
76 ts_range.len(),
77 value_range.len()
78 )),
79 )?;
80 error::ensure(
81 ts_range.value_type() == DataType::Timestamp(TimeUnit::Millisecond, None),
82 DataFusionError::Execution(format!(
83 "{}: expect TimestampMillisecond as time index array's type, found {}",
84 Self::name(),
85 ts_range.value_type()
86 )),
87 )?;
88 error::ensure(
89 value_range.value_type() == DataType::Float64,
90 DataFusionError::Execution(format!(
91 "{}: expect Float64 as value array's type, found {}",
92 Self::name(),
93 value_range.value_type()
94 )),
95 )?;
96
97 let t_iter: Box<dyn Iterator<Item = Option<i64>>> = match t_col {
98 ColumnarValue::Scalar(t_scalar) => {
99 let t = if let ScalarValue::Int64(Some(t_val)) = t_scalar {
100 *t_val
101 } else {
102 let null_array = Float64Array::new_null(ts_range.len());
105 return Ok(ColumnarValue::Array(Arc::new(null_array)));
106 };
107 Box::new((0..ts_range.len()).map(move |_| Some(t)))
108 }
109 ColumnarValue::Array(t_array) => {
110 let t_array = t_array
111 .as_any()
112 .downcast_ref::<datafusion::arrow::array::Int64Array>()
113 .ok_or_else(|| {
114 DataFusionError::Execution(format!(
115 "{}: expect Int64 as t array's type, found {}",
116 Self::name(),
117 t_array.data_type()
118 ))
119 })?;
120 error::ensure(
121 t_array.len() == ts_range.len(),
122 DataFusionError::Execution(format!(
123 "{}: t array should have the same length as other columns, found {} and {}",
124 Self::name(),
125 t_array.len(),
126 ts_range.len()
127 )),
128 )?;
129
130 Box::new(t_array.iter())
131 }
132 };
133 let mut result_array = Vec::with_capacity(ts_range.len());
134 for (index, t) in t_iter.enumerate() {
135 let (timestamps, values) = get_ts_values(&ts_range, &value_range, index, Self::name())?;
136 let ret = predict_linear_impl(×tamps, &values, t.unwrap());
137 result_array.push(ret);
138 }
139
140 let result = ColumnarValue::Array(Arc::new(Float64Array::from_iter(result_array)));
141 Ok(result)
142 }
143}
144
145fn get_ts_values(
146 ts_range: &RangeArray,
147 value_range: &RangeArray,
148 index: usize,
149 func_name: &str,
150) -> Result<(TimestampMillisecondArray, Float64Array), DataFusionError> {
151 let timestamps = ts_range
152 .get(index)
153 .unwrap()
154 .as_any()
155 .downcast_ref::<TimestampMillisecondArray>()
156 .unwrap()
157 .clone();
158 let values = value_range
159 .get(index)
160 .unwrap()
161 .as_any()
162 .downcast_ref::<Float64Array>()
163 .unwrap()
164 .clone();
165 error::ensure(
166 timestamps.len() == values.len(),
167 DataFusionError::Execution(format!(
168 "{}: time and value arrays in a group should have the same length, found {} and {}",
169 func_name,
170 timestamps.len(),
171 values.len()
172 )),
173 )?;
174 Ok((timestamps, values))
175}
176
177fn predict_linear_impl(
178 timestamps: &TimestampMillisecondArray,
179 values: &Float64Array,
180 t: i64,
181) -> Option<f64> {
182 if timestamps.len() < 2 {
183 return None;
184 }
185
186 let evaluate_ts = timestamps.value(timestamps.len() - 1);
188 let (slope, intercept) = linear_regression(timestamps, values, evaluate_ts);
189
190 if slope.is_none() || intercept.is_none() {
191 return None;
192 }
193
194 Some(slope.unwrap() * t as f64 + intercept.unwrap())
195}
196
#[cfg(test)]
mod test {
    use super::*;
    use crate::functions::test_util::simple_range_udf_runner;

    /// A single range of 11 samples at 0, 300, ..., 3000 ms.
    ///
    /// Values climb 0..=40, drop back to 0, then climb to 50.
    fn build_test_range_arrays() -> (RangeArray, RangeArray) {
        let timestamps = [
            0i64, 300, 600, 900, 1200, 1500, 1800, 2100, 2400, 2700, 3000,
        ];
        let samples = [
            0.0, 10.0, 20.0, 30.0, 40.0, 0.0, 10.0, 20.0, 30.0, 40.0, 50.0,
        ];
        let ranges = [(0, 11)];

        let ts = Arc::new(TimestampMillisecondArray::from_iter(
            timestamps.into_iter().map(Some),
        ));
        let vals = Arc::new(Float64Array::from_iter(samples));

        (
            RangeArray::from_ranges(ts, ranges).unwrap(),
            RangeArray::from_ranges(vals, ranges).unwrap(),
        )
    }

    /// Runs the UDF on the shared fixture with a scalar `t`.
    ///
    /// Expects exactly one output row equal to `expected`.
    fn check_prediction(t: i64, expected: f64) {
        let (ts_ranges, value_ranges) = build_test_range_arrays();
        simple_range_udf_runner(
            PredictLinear::scalar_udf(),
            ts_ranges,
            value_ranges,
            vec![ScalarValue::Int64(Some(t))],
            vec![Some(expected)],
        );
    }

    #[test]
    fn calculate_predict_linear_none() {
        // An empty range and a single-sample range both have too few points to fit.
        let ranges = [(0, 0), (0, 1)];
        let ts = Arc::new(TimestampMillisecondArray::from_iter(
            [0i64].into_iter().map(Some),
        ));
        let vals = Arc::new(Float64Array::from_iter([0.0]));
        simple_range_udf_runner(
            PredictLinear::scalar_udf(),
            RangeArray::from_ranges(ts, ranges).unwrap(),
            RangeArray::from_ranges(vals, ranges).unwrap(),
            vec![ScalarValue::Int64(Some(0))],
            vec![None, None],
        );
    }

    #[test]
    fn calculate_predict_linear_test1() {
        check_prediction(0, 38.63636363636364);
    }

    #[test]
    fn calculate_predict_linear_test2() {
        check_prediction(3000, 31856.818181818187);
    }

    #[test]
    fn calculate_predict_linear_test3() {
        check_prediction(4200, 44584.09090909091);
    }

    #[test]
    fn calculate_predict_linear_test4() {
        check_prediction(6600, 70038.63636363638);
    }

    #[test]
    fn calculate_predict_linear_test5() {
        check_prediction(7800, 82765.9090909091);
    }
}