promql/functions/
round.rs1use std::sync::Arc;
16
17use datafusion::error::DataFusionError;
18use datafusion_common::ScalarValue;
19use datafusion_expr::{create_udf, ColumnarValue, ScalarUDF, Volatility};
20use datatypes::arrow::array::AsArray;
21use datatypes::arrow::datatypes::{DataType, Float64Type};
22use datatypes::compute;
23
24use crate::functions::extract_array;
25
/// State for the `prom_round` scalar function.
///
/// Holds the rounding step taken from the function's second argument; input
/// values are rounded to a multiple of this step.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Round {
    /// Rounding step. A value of `0.0` is treated as "round to an integer".
    nearest: f64,
}
29
30impl Round {
31 fn new(nearest: f64) -> Self {
32 Self { nearest }
33 }
34
35 pub const fn name() -> &'static str {
36 "prom_round"
37 }
38
39 fn input_type() -> Vec<DataType> {
40 vec![DataType::Float64, DataType::Float64]
41 }
42
43 pub fn return_type() -> DataType {
44 DataType::Float64
45 }
46
47 pub fn scalar_udf() -> ScalarUDF {
48 create_udf(
49 Self::name(),
50 Self::input_type(),
51 Self::return_type(),
52 Volatility::Volatile,
53 Arc::new(move |input: &_| Self::create_function(input)?.calc(input)) as _,
54 )
55 }
56
57 fn create_function(inputs: &[ColumnarValue]) -> Result<Self, DataFusionError> {
58 if inputs.len() != 2 {
59 return Err(DataFusionError::Plan(
60 "Round function should have 2 inputs".to_string(),
61 ));
62 }
63 let ColumnarValue::Scalar(ScalarValue::Float64(Some(nearest))) = inputs[1] else {
64 return Err(DataFusionError::Plan(
65 "Round function's second input should be a scalar float64".to_string(),
66 ));
67 };
68 Ok(Self::new(nearest))
69 }
70
71 fn calc(&self, input: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
72 assert_eq!(input.len(), 2);
73
74 let value_array = extract_array(&input[0])?;
75
76 if self.nearest == 0.0 {
77 let values = value_array.as_primitive::<Float64Type>();
78 let result = compute::unary::<_, _, Float64Type>(values, |a| a.round());
79 Ok(ColumnarValue::Array(Arc::new(result) as _))
80 } else {
81 let values = value_array.as_primitive::<Float64Type>();
82 let nearest = self.nearest;
83 let result =
84 compute::unary::<_, _, Float64Type>(values, |a| ((a / nearest).round() * nearest));
85 Ok(ColumnarValue::Array(Arc::new(result) as _))
86 }
87 }
88}
89
#[cfg(test)]
mod tests {
    use datafusion_expr::ScalarFunctionArgs;
    use datatypes::arrow::array::Float64Array;

    use super::*;

    /// Invokes the `prom_round` UDF on `value` with the step `nearest` and
    /// checks that the output equals `expected`, element by element.
    fn test_round_f64(value: Vec<f64>, nearest: f64, expected: Vec<f64>) {
        // Derive the row count from the input so multi-row cases work too.
        let number_rows = value.len();
        let round_udf = Round::scalar_udf();
        let input = vec![
            ColumnarValue::Array(Arc::new(Float64Array::from(value))),
            ColumnarValue::Scalar(ScalarValue::Float64(Some(nearest))),
        ];
        let args = ScalarFunctionArgs {
            args: input,
            number_rows,
            return_type: &DataType::Float64,
        };
        let result = round_udf.invoke_with_args(args).unwrap();
        let result_array = extract_array(&result).unwrap();
        assert_eq!(result_array.len(), expected.len());
        assert_eq!(
            result_array.as_primitive::<Float64Type>().values(),
            &expected
        );
    }

    #[test]
    fn test_round() {
        test_round_f64(vec![123.456], 0.001, vec![123.456]);
        test_round_f64(vec![123.456], 0.01, vec![123.46000000000001]);
        test_round_f64(vec![123.456], 0.1, vec![123.5]);
        test_round_f64(vec![123.456], 0.0, vec![123.0]);
        test_round_f64(vec![123.456], 1.0, vec![123.0]);
        test_round_f64(vec![123.456], 10.0, vec![120.0]);
        test_round_f64(vec![123.456], 100.0, vec![100.0]);
        test_round_f64(vec![123.456], 105.0, vec![105.0]);
        test_round_f64(vec![123.456], 1000.0, vec![0.0]);
    }

    #[test]
    fn test_round_multiple_rows() {
        test_round_f64(vec![1.4, 2.6, 123.456], 1.0, vec![1.0, 3.0, 123.0]);
    }
}