promql/functions/round.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::sync::Arc;
16
17use datafusion::error::DataFusionError;
18use datafusion_common::ScalarValue;
19use datafusion_expr::{create_udf, ColumnarValue, ScalarUDF, Volatility};
20use datatypes::arrow::array::{AsArray, Float64Array, PrimitiveArray};
21use datatypes::arrow::datatypes::{DataType, Float64Type};
22use datatypes::arrow::error::ArrowError;
23
24use crate::error;
25use crate::functions::extract_array;
26
27pub struct Round;
28
29impl Round {
30    pub const fn name() -> &'static str {
31        "prom_round"
32    }
33
34    fn input_type() -> Vec<DataType> {
35        vec![DataType::Float64, DataType::Float64]
36    }
37
38    pub fn return_type() -> DataType {
39        DataType::Float64
40    }
41
42    pub fn scalar_udf() -> ScalarUDF {
43        create_udf(
44            Self::name(),
45            Self::input_type(),
46            Self::return_type(),
47            Volatility::Volatile,
48            Arc::new(Self::round) as _,
49        )
50    }
51
52    fn round(input: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
53        error::ensure(
54            input.len() == 2,
55            DataFusionError::Plan("prom_round function should have 2 inputs".to_string()),
56        )?;
57
58        let value_array = extract_array(&input[0])?;
59        let nearest_col = &input[1];
60
61        match nearest_col {
62            ColumnarValue::Scalar(nearest_scalar) => {
63                let nearest = if let ScalarValue::Float64(Some(val)) = nearest_scalar {
64                    *val
65                } else {
66                    let null_array = Float64Array::new_null(value_array.len());
67                    return Ok(ColumnarValue::Array(Arc::new(null_array)));
68                };
69                let op = |a: f64| {
70                    if nearest == 0.0 {
71                        a.round()
72                    } else {
73                        (a / nearest).round() * nearest
74                    }
75                };
76                let result: PrimitiveArray<Float64Type> =
77                    value_array.as_primitive::<Float64Type>().unary(op);
78                Ok(ColumnarValue::Array(Arc::new(result) as _))
79            }
80            ColumnarValue::Array(nearest_array) => {
81                let value_array = value_array.as_primitive::<Float64Type>();
82                let nearest_array = nearest_array.as_primitive::<Float64Type>();
83                error::ensure(
84                    value_array.len() == nearest_array.len(),
85                    DataFusionError::Execution(format!(
86                        "input arrays should have the same length, found {} and {}",
87                        value_array.len(),
88                        nearest_array.len()
89                    )),
90                )?;
91
92                let result: PrimitiveArray<Float64Type> =
93                    datatypes::arrow::compute::binary(value_array, nearest_array, |a, nearest| {
94                        if nearest == 0.0 {
95                            a.round()
96                        } else {
97                            (a / nearest).round() * nearest
98                        }
99                    })
100                    .map_err(|err: ArrowError| DataFusionError::ArrowError(err, None))?;
101
102                Ok(ColumnarValue::Array(Arc::new(result) as _))
103            }
104        }
105    }
106}
107
#[cfg(test)]
mod tests {
    use datafusion_expr::ScalarFunctionArgs;
    use datatypes::arrow::array::Float64Array;

    use super::*;

    /// Invokes the registered `prom_round` UDF on `value` with the given `to_nearest` argument.
    fn invoke_round(
        value: Vec<f64>,
        nearest: ColumnarValue,
    ) -> Result<ColumnarValue, DataFusionError> {
        let number_rows = value.len();
        let args = ScalarFunctionArgs {
            args: vec![
                ColumnarValue::Array(Arc::new(Float64Array::from(value))),
                nearest,
            ],
            number_rows,
            return_type: &DataType::Float64,
        };
        Round::scalar_udf().invoke_with_args(args)
    }

    /// Asserts that `result` is a `Float64` array holding exactly `expected`.
    fn assert_round_values(result: ColumnarValue, expected: Vec<f64>) {
        let result_array = extract_array(&result).unwrap();
        assert_eq!(result_array.len(), expected.len());
        assert_eq!(
            result_array.as_primitive::<Float64Type>().values(),
            &expected
        );
    }

    /// Rounds `value` with a scalar `to_nearest` and checks the output.
    fn test_round_f64(value: Vec<f64>, nearest: f64, expected: Vec<f64>) {
        let nearest = ColumnarValue::Scalar(ScalarValue::Float64(Some(nearest)));
        let result = invoke_round(value, nearest).unwrap();
        assert_round_values(result, expected);
    }

    /// Rounds `value` with a per-row `to_nearest` array and checks the output.
    fn test_round_f64_array(value: Vec<f64>, nearest: Vec<f64>, expected: Vec<f64>) {
        let nearest = ColumnarValue::Array(Arc::new(Float64Array::from(nearest)));
        let result = invoke_round(value, nearest).unwrap();
        assert_round_values(result, expected);
    }

    #[test]
    fn test_round() {
        test_round_f64(vec![123.456], 0.001, vec![123.456]);
        test_round_f64(vec![123.456], 0.01, vec![123.46000000000001]);
        test_round_f64(vec![123.456], 0.1, vec![123.5]);
        test_round_f64(vec![123.456], 0.0, vec![123.0]);
        test_round_f64(vec![123.456], 1.0, vec![123.0]);
        test_round_f64(vec![123.456], 10.0, vec![120.0]);
        test_round_f64(vec![123.456], 100.0, vec![100.0]);
        test_round_f64(vec![123.456], 105.0, vec![105.0]);
        test_round_f64(vec![123.456], 1000.0, vec![0.0]);
        test_round_f64(vec![-123.456], 10.0, vec![-120.0]);
        test_round_f64(vec![1.4, 2.6, 123.456], 1.0, vec![1.0, 3.0, 123.0]);
    }

    #[test]
    fn test_round_array_nearest() {
        test_round_f64_array(
            vec![123.456, 123.456, 123.456, 123.456],
            vec![0.1, 0.0, 10.0, 1000.0],
            vec![123.5, 123.0, 120.0, 0.0],
        );
    }

    #[test]
    fn test_round_null_nearest() {
        let nearest = ColumnarValue::Scalar(ScalarValue::Float64(None));
        let result = invoke_round(vec![1.5, 2.5], nearest).unwrap();
        let result_array = extract_array(&result).unwrap();
        assert_eq!(result_array.len(), 2);
        assert_eq!(result_array.null_count(), 2);
    }

    #[test]
    fn test_round_array_length_mismatch() {
        let nearest = ColumnarValue::Array(Arc::new(Float64Array::from(vec![1.0, 1.0, 1.0])));
        assert!(invoke_round(vec![1.0, 2.0], nearest).is_err());
    }
}