promql/functions/
round.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use datafusion::error::DataFusionError;
18use datafusion_common::ScalarValue;
19use datafusion_expr::{create_udf, ColumnarValue, ScalarUDF, Volatility};
20use datatypes::arrow::array::{AsArray, Float64Array, PrimitiveArray};
21use datatypes::arrow::datatypes::{DataType, Float64Type};
22use datatypes::arrow::error::ArrowError;
23
24use crate::error;
25use crate::functions::extract_array;
26
27pub struct Round;
28
29impl Round {
30    pub const fn name() -> &'static str {
31        "prom_round"
32    }
33
34    fn input_type() -> Vec<DataType> {
35        vec![DataType::Float64, DataType::Float64]
36    }
37
38    pub fn return_type() -> DataType {
39        DataType::Float64
40    }
41
42    pub fn scalar_udf() -> ScalarUDF {
43        create_udf(
44            Self::name(),
45            Self::input_type(),
46            Self::return_type(),
47            Volatility::Volatile,
48            Arc::new(Self::round) as _,
49        )
50    }
51
52    fn round(input: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
53        error::ensure(
54            input.len() == 2,
55            DataFusionError::Plan("prom_round function should have 2 inputs".to_string()),
56        )?;
57
58        let value_array = extract_array(&input[0])?;
59        let nearest_col = &input[1];
60
61        match nearest_col {
62            ColumnarValue::Scalar(nearest_scalar) => {
63                let nearest = if let ScalarValue::Float64(Some(val)) = nearest_scalar {
64                    *val
65                } else {
66                    let null_array = Float64Array::new_null(value_array.len());
67                    return Ok(ColumnarValue::Array(Arc::new(null_array)));
68                };
69                let op = |a: f64| {
70                    if nearest == 0.0 {
71                        a.round()
72                    } else {
73                        (a / nearest).round() * nearest
74                    }
75                };
76                let result: PrimitiveArray<Float64Type> =
77                    value_array.as_primitive::<Float64Type>().unary(op);
78                Ok(ColumnarValue::Array(Arc::new(result) as _))
79            }
80            ColumnarValue::Array(nearest_array) => {
81                let value_array = value_array.as_primitive::<Float64Type>();
82                let nearest_array = nearest_array.as_primitive::<Float64Type>();
83                error::ensure(
84                    value_array.len() == nearest_array.len(),
85                    DataFusionError::Execution(format!(
86                        "input arrays should have the same length, found {} and {}",
87                        value_array.len(),
88                        nearest_array.len()
89                    )),
90                )?;
91
92                let result: PrimitiveArray<Float64Type> =
93                    datatypes::arrow::compute::binary(value_array, nearest_array, |a, nearest| {
94                        if nearest == 0.0 {
95                            a.round()
96                        } else {
97                            (a / nearest).round() * nearest
98                        }
99                    })
100                    .map_err(|err: ArrowError| DataFusionError::ArrowError(Box::new(err), None))?;
101
102                Ok(ColumnarValue::Array(Arc::new(result) as _))
103            }
104        }
105    }
106}
107
#[cfg(test)]
mod tests {
    use datafusion_common::config::ConfigOptions;
    use datafusion_expr::ScalarFunctionArgs;
    use datatypes::arrow::array::Float64Array;
    use datatypes::arrow::datatypes::Field;

    use super::*;

    /// Invokes `prom_round` on `value` with a scalar `nearest` and checks the
    /// output array element-wise against `expected`.
    fn test_round_f64(value: Vec<f64>, nearest: f64, expected: Vec<f64>) {
        // Derive the row count from the input so multi-row cases are exercised
        // instead of silently assuming a single row.
        let num_rows = value.len();
        let round_udf = Round::scalar_udf();
        let input = vec![
            ColumnarValue::Array(Arc::new(Float64Array::from(value))),
            ColumnarValue::Scalar(ScalarValue::Float64(Some(nearest))),
        ];
        let arg_fields = vec![
            Arc::new(Field::new("a", input[0].data_type(), false)),
            Arc::new(Field::new("b", input[1].data_type(), false)),
        ];
        let return_field = Arc::new(Field::new("c", DataType::Float64, false));
        let args = ScalarFunctionArgs {
            args: input,
            arg_fields,
            number_rows: num_rows,
            return_field,
            config_options: Arc::new(ConfigOptions::default()),
        };
        let result = round_udf.invoke_with_args(args).unwrap();
        let result_array = extract_array(&result).unwrap();
        assert_eq!(result_array.len(), num_rows);
        assert_eq!(
            result_array.as_primitive::<Float64Type>().values(),
            &expected
        );
    }

    #[test]
    fn test_round() {
        test_round_f64(vec![123.456], 0.001, vec![123.456]);
        test_round_f64(vec![123.456], 0.01, vec![123.46000000000001]);
        test_round_f64(vec![123.456], 0.1, vec![123.5]);
        test_round_f64(vec![123.456], 0.0, vec![123.0]);
        test_round_f64(vec![123.456], 1.0, vec![123.0]);
        test_round_f64(vec![123.456], 10.0, vec![120.0]);
        test_round_f64(vec![123.456], 100.0, vec![100.0]);
        test_round_f64(vec![123.456], 105.0, vec![105.0]);
        test_round_f64(vec![123.456], 1000.0, vec![0.0]);
    }

    #[test]
    fn test_round_multiple_rows() {
        test_round_f64(
            vec![123.456, -123.456, 0.0],
            1.0,
            vec![123.0, -123.0, 0.0],
        );
        test_round_f64(vec![123.456, -123.456], 10.0, vec![120.0, -120.0]);
    }
}