promql/functions/
round.rs1use std::sync::Arc;
16
17use datafusion::error::DataFusionError;
18use datafusion_common::ScalarValue;
19use datafusion_expr::{create_udf, ColumnarValue, ScalarUDF, Volatility};
20use datatypes::arrow::array::{AsArray, Float64Array, PrimitiveArray};
21use datatypes::arrow::datatypes::{DataType, Float64Type};
22use datatypes::arrow::error::ArrowError;
23
24use crate::error;
25use crate::functions::extract_array;
26
27pub struct Round;
28
29impl Round {
30 pub const fn name() -> &'static str {
31 "prom_round"
32 }
33
34 fn input_type() -> Vec<DataType> {
35 vec![DataType::Float64, DataType::Float64]
36 }
37
38 pub fn return_type() -> DataType {
39 DataType::Float64
40 }
41
42 pub fn scalar_udf() -> ScalarUDF {
43 create_udf(
44 Self::name(),
45 Self::input_type(),
46 Self::return_type(),
47 Volatility::Volatile,
48 Arc::new(Self::round) as _,
49 )
50 }
51
52 fn round(input: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
53 error::ensure(
54 input.len() == 2,
55 DataFusionError::Plan("prom_round function should have 2 inputs".to_string()),
56 )?;
57
58 let value_array = extract_array(&input[0])?;
59 let nearest_col = &input[1];
60
61 match nearest_col {
62 ColumnarValue::Scalar(nearest_scalar) => {
63 let nearest = if let ScalarValue::Float64(Some(val)) = nearest_scalar {
64 *val
65 } else {
66 let null_array = Float64Array::new_null(value_array.len());
67 return Ok(ColumnarValue::Array(Arc::new(null_array)));
68 };
69 let op = |a: f64| {
70 if nearest == 0.0 {
71 a.round()
72 } else {
73 (a / nearest).round() * nearest
74 }
75 };
76 let result: PrimitiveArray<Float64Type> =
77 value_array.as_primitive::<Float64Type>().unary(op);
78 Ok(ColumnarValue::Array(Arc::new(result) as _))
79 }
80 ColumnarValue::Array(nearest_array) => {
81 let value_array = value_array.as_primitive::<Float64Type>();
82 let nearest_array = nearest_array.as_primitive::<Float64Type>();
83 error::ensure(
84 value_array.len() == nearest_array.len(),
85 DataFusionError::Execution(format!(
86 "input arrays should have the same length, found {} and {}",
87 value_array.len(),
88 nearest_array.len()
89 )),
90 )?;
91
92 let result: PrimitiveArray<Float64Type> =
93 datatypes::arrow::compute::binary(value_array, nearest_array, |a, nearest| {
94 if nearest == 0.0 {
95 a.round()
96 } else {
97 (a / nearest).round() * nearest
98 }
99 })
100 .map_err(|err: ArrowError| DataFusionError::ArrowError(err, None))?;
101
102 Ok(ColumnarValue::Array(Arc::new(result) as _))
103 }
104 }
105 }
106}
107
#[cfg(test)]
mod tests {
    use datafusion_expr::ScalarFunctionArgs;
    use datatypes::arrow::array::Float64Array;

    use super::*;

    /// Invokes the `prom_round` UDF on `value` with the given `nearest`
    /// argument and returns the raw result.
    fn invoke_round(value: Vec<f64>, nearest: ColumnarValue) -> ColumnarValue {
        // Captured before `value` is moved into the array.
        let number_rows = value.len();
        let args = ScalarFunctionArgs {
            args: vec![
                ColumnarValue::Array(Arc::new(Float64Array::from(value))),
                nearest,
            ],
            number_rows,
            return_type: &DataType::Float64,
        };
        Round::scalar_udf().invoke_with_args(args).unwrap()
    }

    /// Rounds every element of `value` with a scalar `nearest` and checks the
    /// result against `expected`, element by element.
    fn test_round_f64(value: Vec<f64>, nearest: f64, expected: Vec<f64>) {
        let result = invoke_round(
            value,
            ColumnarValue::Scalar(ScalarValue::Float64(Some(nearest))),
        );
        let result_array = extract_array(&result).unwrap();
        assert_eq!(result_array.len(), expected.len());
        assert_eq!(
            result_array.as_primitive::<Float64Type>().values(),
            &expected
        );
    }

    #[test]
    fn test_round() {
        test_round_f64(vec![123.456], 0.001, vec![123.456]);
        test_round_f64(vec![123.456], 0.01, vec![123.46000000000001]);
        test_round_f64(vec![123.456], 0.1, vec![123.5]);
        test_round_f64(vec![123.456], 0.0, vec![123.0]);
        test_round_f64(vec![123.456], 1.0, vec![123.0]);
        test_round_f64(vec![123.456], 10.0, vec![120.0]);
        test_round_f64(vec![123.456], 100.0, vec![100.0]);
        test_round_f64(vec![123.456], 105.0, vec![105.0]);
        test_round_f64(vec![123.456], 1000.0, vec![0.0]);
    }

    #[test]
    fn test_round_multiple_rows() {
        test_round_f64(vec![123.456, 0.5, -123.456], 1.0, vec![123.0, 1.0, -123.0]);
        test_round_f64(vec![], 1.0, vec![]);
    }

    #[test]
    fn test_round_array_nearest() {
        let nearest = ColumnarValue::Array(Arc::new(Float64Array::from(vec![0.0, 10.0, 100.0])));
        let result = invoke_round(vec![123.456, 123.456, 123.456], nearest);
        let result_array = extract_array(&result).unwrap();
        assert_eq!(
            result_array.as_primitive::<Float64Type>().values(),
            &vec![123.0, 120.0, 100.0]
        );
    }

    #[test]
    fn test_round_null_nearest() {
        let result = invoke_round(
            vec![1.5, 2.5],
            ColumnarValue::Scalar(ScalarValue::Float64(None)),
        );
        let result_array = extract_array(&result).unwrap();
        assert_eq!(result_array.len(), 2);
        assert_eq!(result_array.null_count(), 2);
    }
}