promql/functions/
round.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use datafusion::error::DataFusionError;
18use datafusion_common::ScalarValue;
19use datafusion_expr::{create_udf, ColumnarValue, ScalarUDF, Volatility};
20use datatypes::arrow::array::AsArray;
21use datatypes::arrow::datatypes::{DataType, Float64Type};
22use datatypes::compute;
23
24use crate::functions::extract_array;
25
/// PromQL `round` function state.
///
/// Holds the step that sample values are rounded to a multiple of.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Round {
    /// Rounding step; results are multiples of this value.
    nearest: f64,
}
29
30impl Round {
31    fn new(nearest: f64) -> Self {
32        Self { nearest }
33    }
34
35    pub const fn name() -> &'static str {
36        "prom_round"
37    }
38
39    fn input_type() -> Vec<DataType> {
40        vec![DataType::Float64, DataType::Float64]
41    }
42
43    pub fn return_type() -> DataType {
44        DataType::Float64
45    }
46
47    pub fn scalar_udf() -> ScalarUDF {
48        create_udf(
49            Self::name(),
50            Self::input_type(),
51            Self::return_type(),
52            Volatility::Volatile,
53            Arc::new(move |input: &_| Self::create_function(input)?.calc(input)) as _,
54        )
55    }
56
57    fn create_function(inputs: &[ColumnarValue]) -> Result<Self, DataFusionError> {
58        if inputs.len() != 2 {
59            return Err(DataFusionError::Plan(
60                "Round function should have 2 inputs".to_string(),
61            ));
62        }
63        let ColumnarValue::Scalar(ScalarValue::Float64(Some(nearest))) = inputs[1] else {
64            return Err(DataFusionError::Plan(
65                "Round function's second input should be a scalar float64".to_string(),
66            ));
67        };
68        Ok(Self::new(nearest))
69    }
70
71    fn calc(&self, input: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
72        assert_eq!(input.len(), 2);
73
74        let value_array = extract_array(&input[0])?;
75
76        if self.nearest == 0.0 {
77            let values = value_array.as_primitive::<Float64Type>();
78            let result = compute::unary::<_, _, Float64Type>(values, |a| a.round());
79            Ok(ColumnarValue::Array(Arc::new(result) as _))
80        } else {
81            let values = value_array.as_primitive::<Float64Type>();
82            let nearest = self.nearest;
83            let result =
84                compute::unary::<_, _, Float64Type>(values, |a| ((a / nearest).round() * nearest));
85            Ok(ColumnarValue::Array(Arc::new(result) as _))
86        }
87    }
88}
89
#[cfg(test)]
mod tests {
    use datafusion_expr::ScalarFunctionArgs;
    use datatypes::arrow::array::Float64Array;

    use super::*;

    /// Invokes the `prom_round` UDF on `value` with step `nearest` and asserts
    /// that the output matches `expected` row for row.
    fn test_round_f64(value: Vec<f64>, nearest: f64, expected: Vec<f64>) {
        // Derive the row count from the input rather than assuming a single row.
        let number_rows = value.len();
        let round_udf = Round::scalar_udf();
        let args = ScalarFunctionArgs {
            args: vec![
                ColumnarValue::Array(Arc::new(Float64Array::from(value))),
                ColumnarValue::Scalar(ScalarValue::Float64(Some(nearest))),
            ],
            number_rows,
            return_type: &DataType::Float64,
        };
        let result = round_udf.invoke_with_args(args).unwrap();
        let result_array = extract_array(&result).unwrap();
        assert_eq!(result_array.len(), expected.len());
        assert_eq!(
            result_array.as_primitive::<Float64Type>().values(),
            &expected
        );
    }

    #[test]
    fn test_round() {
        test_round_f64(vec![123.456], 0.001, vec![123.456]);
        test_round_f64(vec![123.456], 0.01, vec![123.46000000000001]);
        test_round_f64(vec![123.456], 0.1, vec![123.5]);
        test_round_f64(vec![123.456], 0.0, vec![123.0]);
        test_round_f64(vec![123.456], 1.0, vec![123.0]);
        test_round_f64(vec![123.456], 10.0, vec![120.0]);
        test_round_f64(vec![123.456], 100.0, vec![100.0]);
        test_round_f64(vec![123.456], 105.0, vec![105.0]);
        test_round_f64(vec![123.456], 1000.0, vec![0.0]);
        // Several rows in a single invocation.
        test_round_f64(vec![1.2, 3.7], 1.0, vec![1.0, 4.0]);
    }
}