1use std::sync::Arc;
19
20use datafusion::arrow::array::{Float64Array, Float64Builder, TimestampMillisecondArray};
21use datafusion::arrow::datatypes::TimeUnit;
22use datafusion::common::DataFusionError;
23use datafusion::logical_expr::{ScalarUDF, Volatility};
24use datafusion::physical_plan::ColumnarValue;
25use datafusion_common::ScalarValue;
26use datafusion_expr::create_udf;
27use datatypes::arrow::array::Array;
28use datatypes::arrow::datatypes::DataType;
29
30use crate::error;
31use crate::functions::{extract_range_array, linear_regression_slices};
32use crate::range_array::RangeArray;
33
34pub struct PredictLinear;
35
36impl PredictLinear {
37 pub const fn name() -> &'static str {
38 "prom_predict_linear"
39 }
40
41 pub fn scalar_udf() -> ScalarUDF {
42 let input_types = vec![
43 RangeArray::convert_data_type(DataType::Timestamp(TimeUnit::Millisecond, None)),
45 RangeArray::convert_data_type(DataType::Float64),
47 DataType::Int64,
49 ];
50 create_udf(
51 Self::name(),
52 input_types,
53 DataType::Float64,
54 Volatility::Volatile,
55 Arc::new(Self::predict_linear) as _,
56 )
57 }
58
59 fn predict_linear(input: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
60 error::ensure(
61 input.len() == 3,
62 DataFusionError::Plan("prom_predict_linear function should have 3 inputs".to_string()),
63 )?;
64
65 let t_col = &input[2];
66
67 let ts_range = extract_range_array(&input[0])?;
68 let value_range = extract_range_array(&input[1])?;
69 error::ensure(
70 ts_range.len() == value_range.len(),
71 DataFusionError::Execution(format!(
72 "{}: input arrays should have the same length, found {} and {}",
73 Self::name(),
74 ts_range.len(),
75 value_range.len()
76 )),
77 )?;
78 error::ensure(
79 ts_range.value_type() == DataType::Timestamp(TimeUnit::Millisecond, None),
80 DataFusionError::Execution(format!(
81 "{}: expect TimestampMillisecond as time index array's type, found {}",
82 Self::name(),
83 ts_range.value_type()
84 )),
85 )?;
86 error::ensure(
87 value_range.value_type() == DataType::Float64,
88 DataFusionError::Execution(format!(
89 "{}: expect Float64 as value array's type, found {}",
90 Self::name(),
91 value_range.value_type()
92 )),
93 )?;
94
95 let t_iter: Box<dyn Iterator<Item = Option<i64>>> = match t_col {
96 ColumnarValue::Scalar(t_scalar) => {
97 let t = if let ScalarValue::Int64(Some(t_val)) = t_scalar {
98 *t_val
99 } else {
100 let null_array = Float64Array::new_null(ts_range.len());
103 return Ok(ColumnarValue::Array(Arc::new(null_array)));
104 };
105 Box::new((0..ts_range.len()).map(move |_| Some(t)))
106 }
107 ColumnarValue::Array(t_array) => {
108 let t_array = t_array
109 .as_any()
110 .downcast_ref::<datafusion::arrow::array::Int64Array>()
111 .ok_or_else(|| {
112 DataFusionError::Execution(format!(
113 "{}: expect Int64 as t array's type, found {}",
114 Self::name(),
115 t_array.data_type()
116 ))
117 })?;
118 error::ensure(
119 t_array.len() == ts_range.len(),
120 DataFusionError::Execution(format!(
121 "{}: t array should have the same length as other columns, found {} and {}",
122 Self::name(),
123 t_array.len(),
124 ts_range.len()
125 )),
126 )?;
127
128 Box::new(t_array.iter())
129 }
130 };
131 let all_timestamps = ts_range
132 .values()
133 .as_any()
134 .downcast_ref::<TimestampMillisecondArray>()
135 .unwrap()
136 .values();
137 let all_values = value_range
138 .values()
139 .as_any()
140 .downcast_ref::<Float64Array>()
141 .unwrap();
142 let mut result_builder = Float64Builder::with_capacity(ts_range.len());
143 for (index, t) in t_iter.enumerate() {
144 match predict_linear_impl(
145 &ts_range,
146 &value_range,
147 all_timestamps,
148 all_values,
149 index,
150 t.unwrap(),
151 Self::name(),
152 )? {
153 Some(value) => result_builder.append_value(value),
154 None => result_builder.append_null(),
155 }
156 }
157
158 let result = ColumnarValue::Array(Arc::new(result_builder.finish()));
159 Ok(result)
160 }
161}
162
163fn predict_linear_impl(
164 ts_range: &RangeArray,
165 value_range: &RangeArray,
166 all_timestamps: &[i64],
167 all_values: &Float64Array,
168 index: usize,
169 t: i64,
170 func_name: &str,
171) -> Result<Option<f64>, DataFusionError> {
172 let (ts_offset, ts_len) = ts_range.get_offset_length(index).unwrap();
173 let (value_offset, value_len) = value_range.get_offset_length(index).unwrap();
174 error::ensure(
175 ts_len == value_len,
176 DataFusionError::Execution(format!(
177 "{}: time and value arrays in a group should have the same length, found {} and {}",
178 func_name, ts_len, value_len
179 )),
180 )?;
181 if ts_len < 2 {
182 return Ok(None);
183 }
184
185 let evaluate_ts = all_timestamps[ts_offset + ts_len - 1];
187 let (slope, intercept) = linear_regression_slices(
188 all_timestamps,
189 ts_offset,
190 all_values,
191 value_offset,
192 value_len,
193 evaluate_ts,
194 );
195
196 if slope.is_none() || intercept.is_none() {
197 return Ok(None);
198 }
199
200 Ok(Some(slope.unwrap() * t as f64 + intercept.unwrap()))
201}
202
#[cfg(test)]
mod test {
    use std::vec;

    use datafusion::arrow::array::{DictionaryArray, Int64Array};
    use datatypes::arrow::datatypes::Int64Type;

    use super::*;
    use crate::functions::test_util::simple_range_udf_runner;

    /// Shared fixture: eleven samples spaced 300 ms apart (0..=3000), all in a single range,
    /// with a value series that climbs to 40, drops back to 0 and climbs to 50.
    fn build_test_range_arrays() -> (RangeArray, RangeArray) {
        let timestamps = Arc::new(TimestampMillisecondArray::from_iter(
            (0i64..=10).map(|step| Some(step * 300)),
        ));
        let samples = Arc::new(Float64Array::from_iter([
            0.0, 10.0, 20.0, 30.0, 40.0, 0.0, 10.0, 20.0, 30.0, 40.0, 50.0,
        ]));
        let ranges = [(0, 11)];

        (
            RangeArray::from_ranges(timestamps, ranges).unwrap(),
            RangeArray::from_ranges(samples, ranges).unwrap(),
        )
    }

    /// Runs the UDF over the shared fixture with the scalar `t` and checks the single
    /// resulting row against `expected`.
    fn assert_fixture_prediction(t: i64, expected: f64) {
        let (ts_array, value_array) = build_test_range_arrays();
        simple_range_udf_runner(
            PredictLinear::scalar_udf(),
            ts_array,
            value_array,
            vec![ScalarValue::Int64(Some(t))],
            vec![Some(expected)],
        );
    }

    /// Ranges with zero or one sample cannot be fitted and must yield NULL.
    #[test]
    fn calculate_predict_linear_none() {
        let ranges = [(0, 0), (0, 1)];
        let timestamps = Arc::new(TimestampMillisecondArray::from_iter(
            [0i64].into_iter().map(Some),
        ));
        let samples = Arc::new(Float64Array::from_iter([0.0]));
        let ts_array = RangeArray::from_ranges(timestamps, ranges).unwrap();
        let value_array = RangeArray::from_ranges(samples, ranges).unwrap();

        simple_range_udf_runner(
            PredictLinear::scalar_udf(),
            ts_array,
            value_array,
            vec![ScalarValue::Int64(Some(0))],
            vec![None, None],
        );
    }

    #[test]
    fn calculate_predict_linear_test1() {
        assert_fixture_prediction(0, 38.63636363636364);
    }

    #[test]
    fn calculate_predict_linear_test2() {
        assert_fixture_prediction(3000, 31856.818181818187);
    }

    #[test]
    fn calculate_predict_linear_test3() {
        assert_fixture_prediction(4200, 44584.09090909091);
    }

    #[test]
    fn calculate_predict_linear_test4() {
        assert_fixture_prediction(6600, 70038.63636363638);
    }

    #[test]
    fn calculate_predict_linear_test5() {
        assert_fixture_prediction(7800, 82765.9090909091);
    }

    /// The time and value ranges start at different offsets in their backing buffers; each
    /// side must be read through its own offset.
    #[test]
    fn calculate_predict_linear_with_misaligned_offsets() {
        let timestamps = Arc::new(TimestampMillisecondArray::from_iter(
            [0i64, 1000, 2000, 3000].into_iter().map(Some),
        ));
        let samples = Arc::new(Float64Array::from_iter([10.0, 20.0, 30.0]));
        let ts_array = RangeArray::from_ranges(timestamps, [(1, 3)]).unwrap();
        let value_array = RangeArray::from_ranges(samples, [(0, 3)]).unwrap();

        simple_range_udf_runner(
            PredictLinear::scalar_udf(),
            ts_array,
            value_array,
            vec![ScalarValue::Int64(Some(0))],
            vec![Some(30.0)],
        );
    }

    /// Hand-built dictionary inputs (here with a null key on the time side) are rejected
    /// with an error instead of being evaluated.
    #[test]
    fn predict_linear_rejects_external_dictionary_with_null_keys() {
        let ts_dict = DictionaryArray::<Int64Type>::try_new(
            Int64Array::from_iter([Some(0), None]),
            Arc::new(TimestampMillisecondArray::from_iter(
                [0i64, 1000].into_iter().map(Some),
            )),
        )
        .unwrap();
        let value_dict = DictionaryArray::<Int64Type>::try_new(
            Int64Array::from_iter([Some(0), Some(1)]),
            Arc::new(Float64Array::from_iter([1.0, 2.0])),
        )
        .unwrap();

        let args = [
            ColumnarValue::Array(Arc::new(ts_dict)),
            ColumnarValue::Array(Arc::new(value_dict)),
            ColumnarValue::Scalar(ScalarValue::Int64(Some(0))),
        ];
        let err = PredictLinear::predict_linear(&args).unwrap_err();

        assert!(err.to_string().contains("Empty range is not expected"));
    }
}