1use std::sync::Arc;
19
20use datafusion::arrow::array::{Float64Array, TimestampMillisecondArray};
21use datafusion::arrow::datatypes::TimeUnit;
22use datafusion::common::DataFusionError;
23use datafusion::logical_expr::{ScalarUDF, Volatility};
24use datafusion::physical_plan::ColumnarValue;
25use datafusion_common::ScalarValue;
26use datafusion_expr::create_udf;
27use datatypes::arrow::array::Array;
28use datatypes::arrow::datatypes::DataType;
29
30use crate::error;
31use crate::functions::{extract_array, linear_regression};
32use crate::range_array::RangeArray;
33
/// Per-invocation state of the `prom_predict_linear` scalar UDF.
pub struct PredictLinear {
    /// The UDF's scalar third argument. The prediction is computed as
    /// `slope * t + intercept`, so this is the offset at which the fitted
    /// line is evaluated.
    t: i64,
}
38
39impl PredictLinear {
40 fn new(t: i64) -> Self {
41 Self { t }
42 }
43
44 pub const fn name() -> &'static str {
45 "prom_predict_linear"
46 }
47
48 pub fn scalar_udf() -> ScalarUDF {
49 let input_types = vec![
50 RangeArray::convert_data_type(DataType::Timestamp(TimeUnit::Millisecond, None)),
52 RangeArray::convert_data_type(DataType::Float64),
54 DataType::Int64,
56 ];
57 create_udf(
58 Self::name(),
59 input_types,
60 DataType::Float64,
61 Volatility::Volatile,
62 Arc::new(move |input: &_| Self::create_function(input)?.predict_linear(input)) as _,
63 )
64 }
65
66 fn create_function(inputs: &[ColumnarValue]) -> Result<Self, DataFusionError> {
67 if inputs.len() != 3 {
68 return Err(DataFusionError::Plan(
69 "PredictLinear function should have 3 inputs".to_string(),
70 ));
71 }
72 let ColumnarValue::Scalar(ScalarValue::Int64(Some(t))) = inputs[2] else {
73 return Err(DataFusionError::Plan(
74 "PredictLinear function's third input should be a scalar int64".to_string(),
75 ));
76 };
77 Ok(Self::new(t))
78 }
79
80 fn predict_linear(&self, input: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
81 assert_eq!(input.len(), 3);
83 let ts_array = extract_array(&input[0])?;
84 let value_array = extract_array(&input[1])?;
85
86 let ts_range: RangeArray = RangeArray::try_new(ts_array.to_data().into())?;
87 let value_range: RangeArray = RangeArray::try_new(value_array.to_data().into())?;
88 error::ensure(
89 ts_range.len() == value_range.len(),
90 DataFusionError::Execution(format!(
91 "{}: input arrays should have the same length, found {} and {}",
92 Self::name(),
93 ts_range.len(),
94 value_range.len()
95 )),
96 )?;
97 error::ensure(
98 ts_range.value_type() == DataType::Timestamp(TimeUnit::Millisecond, None),
99 DataFusionError::Execution(format!(
100 "{}: expect TimestampMillisecond as time index array's type, found {}",
101 Self::name(),
102 ts_range.value_type()
103 )),
104 )?;
105 error::ensure(
106 value_range.value_type() == DataType::Float64,
107 DataFusionError::Execution(format!(
108 "{}: expect Float64 as value array's type, found {}",
109 Self::name(),
110 value_range.value_type()
111 )),
112 )?;
113
114 let mut result_array = Vec::with_capacity(ts_range.len());
116
117 for index in 0..ts_range.len() {
118 let timestamps = ts_range
119 .get(index)
120 .unwrap()
121 .as_any()
122 .downcast_ref::<TimestampMillisecondArray>()
123 .unwrap()
124 .clone();
125 let values = value_range
126 .get(index)
127 .unwrap()
128 .as_any()
129 .downcast_ref::<Float64Array>()
130 .unwrap()
131 .clone();
132 error::ensure(
133 timestamps.len() == values.len(),
134 DataFusionError::Execution(format!(
135 "{}: input arrays should have the same length, found {} and {}",
136 Self::name(),
137 timestamps.len(),
138 values.len()
139 )),
140 )?;
141
142 let ret = predict_linear_impl(×tamps, &values, self.t);
143
144 result_array.push(ret);
145 }
146
147 let result = ColumnarValue::Array(Arc::new(Float64Array::from_iter(result_array)));
148 Ok(result)
149 }
150}
151
152fn predict_linear_impl(
153 timestamps: &TimestampMillisecondArray,
154 values: &Float64Array,
155 t: i64,
156) -> Option<f64> {
157 if timestamps.len() < 2 {
158 return None;
159 }
160
161 let evaluate_ts = timestamps.value(timestamps.len() - 1);
163 let (slope, intercept) = linear_regression(timestamps, values, evaluate_ts);
164
165 if slope.is_none() || intercept.is_none() {
166 return None;
167 }
168
169 Some(slope.unwrap() * t as f64 + intercept.unwrap())
170}
171
#[cfg(test)]
mod test {
    use std::vec;

    use super::*;
    use crate::functions::test_util::simple_range_udf_runner;

    /// Builds the fixture shared by the numbered tests: eleven samples with
    /// timestamps `0, 300, ..., 3000`, exposed through the single range `(0, 11)`.
    fn build_test_range_arrays() -> (RangeArray, RangeArray) {
        let ranges = [(0, 11)];

        let raw_timestamps = Arc::new(TimestampMillisecondArray::from_iter(
            (0..=10i64).map(|step| Some(step * 300)),
        ));
        let raw_values = Arc::new(Float64Array::from_iter([
            0.0, 10.0, 20.0, 30.0, 40.0, 0.0, 10.0, 20.0, 30.0, 40.0, 50.0,
        ]));

        (
            RangeArray::from_ranges(raw_timestamps, ranges).unwrap(),
            RangeArray::from_ranges(raw_values, ranges).unwrap(),
        )
    }

    /// Runs the UDF over the shared fixture with scalar argument `t` and checks
    /// that the single window produces `expected`.
    fn check_shared_fixture(t: i64, expected: f64) {
        let (ts_array, value_array) = build_test_range_arrays();
        simple_range_udf_runner(
            PredictLinear::scalar_udf(),
            ts_array,
            value_array,
            vec![ScalarValue::Int64(Some(t))],
            vec![Some(expected)],
        );
    }

    /// Windows holding fewer than two samples must produce nulls.
    #[test]
    fn calculate_predict_linear_none() {
        let ranges = [(0, 0), (0, 1)];
        let raw_timestamps = Arc::new(TimestampMillisecondArray::from_iter([Some(0i64)]));
        let raw_values = Arc::new(Float64Array::from_iter([0.0]));
        let ts_array = RangeArray::from_ranges(raw_timestamps, ranges).unwrap();
        let value_array = RangeArray::from_ranges(raw_values, ranges).unwrap();
        simple_range_udf_runner(
            PredictLinear::scalar_udf(),
            ts_array,
            value_array,
            vec![ScalarValue::Int64(Some(0))],
            vec![None, None],
        );
    }

    // value at t = 0
    #[test]
    fn calculate_predict_linear_test1() {
        check_shared_fixture(0, 38.63636363636364);
    }

    // value at t = 3000
    #[test]
    fn calculate_predict_linear_test2() {
        check_shared_fixture(3000, 31856.818181818187);
    }

    // value at t = 4200
    #[test]
    fn calculate_predict_linear_test3() {
        check_shared_fixture(4200, 44584.09090909091);
    }

    // value at t = 6600
    #[test]
    fn calculate_predict_linear_test4() {
        check_shared_fixture(6600, 70038.63636363638);
    }

    // value at t = 7800
    #[test]
    fn calculate_predict_linear_test5() {
        check_shared_fixture(7800, 82765.9090909091);
    }
}