common_function/scalars/math/
rate.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt;
16
17use common_query::error;
18use datafusion::arrow::compute::kernels::numeric;
19use datafusion_common::arrow::compute::kernels::cast;
20use datafusion_common::arrow::datatypes::DataType;
21use datafusion_expr::type_coercion::aggregates::NUMERICS;
22use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
23use snafu::ResultExt;
24
25use crate::function::{Function, extract_args};
26
27/// generates rates from a sequence of adjacent data points.
28#[derive(Clone, Debug)]
29pub(crate) struct RateFunction {
30    signature: Signature,
31}
32
33impl Default for RateFunction {
34    fn default() -> Self {
35        Self {
36            signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
37        }
38    }
39}
40
41impl fmt::Display for RateFunction {
42    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
43        write!(f, "RATE")
44    }
45}
46
47impl Function for RateFunction {
48    fn name(&self) -> &str {
49        "rate"
50    }
51
52    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
53        Ok(DataType::Float64)
54    }
55
56    fn signature(&self) -> &Signature {
57        &self.signature
58    }
59
60    fn invoke_with_args(
61        &self,
62        args: ScalarFunctionArgs,
63    ) -> datafusion_common::Result<ColumnarValue> {
64        let [val, ts] = extract_args(self.name(), &args)?;
65        let val_0 = val.slice(0, val.len() - 1);
66        let val_1 = val.slice(1, val.len() - 1);
67        let dv = numeric::sub(&val_1, &val_0).context(error::ArrowComputeSnafu)?;
68        let ts_0 = ts.slice(0, ts.len() - 1);
69        let ts_1 = ts.slice(1, ts.len() - 1);
70        let dt = numeric::sub(&ts_1, &ts_0).context(error::ArrowComputeSnafu)?;
71
72        let dv = cast::cast(&dv, &DataType::Float64).context(error::TypeCastSnafu {
73            typ: DataType::Float64,
74        })?;
75        let dt = cast::cast(&dt, &DataType::Float64).context(error::TypeCastSnafu {
76            typ: DataType::Float64,
77        })?;
78        let rate = numeric::div(&dv, &dt).context(error::ArrowComputeSnafu)?;
79
80        Ok(ColumnarValue::Array(rate))
81    }
82}
83
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion_common::arrow::array::{AsArray, Float32Array, Float64Array, Int64Array};
    use datafusion_common::arrow::datatypes::Float64Type;
    use datafusion_expr::TypeSignature;

    use super::*;

    #[test]
    fn test_rate_function() {
        let func = RateFunction::default();

        // Metadata: name, return type and signature.
        assert_eq!("rate", func.name());
        assert_eq!(DataType::Float64, func.return_type(&[]).unwrap());
        assert!(matches!(
            func.signature(),
            Signature {
                type_signature: TypeSignature::Uniform(2, valid_types),
                volatility: Volatility::Immutable
            } if valid_types == NUMERICS
        ));

        // Values grow by 2.0 and then 3.0 over unit time steps,
        // so the expected rates are [2.0, 3.0].
        let call_args = ScalarFunctionArgs {
            args: vec![
                ColumnarValue::Array(Arc::new(Float32Array::from(vec![1.0, 3.0, 6.0]))),
                ColumnarValue::Array(Arc::new(Int64Array::from(vec![0, 1, 2]))),
            ],
            arg_fields: vec![],
            number_rows: 3,
            return_field: Arc::new(Field::new("x", DataType::Float64, false)),
            config_options: Arc::new(Default::default()),
        };

        let output = func.invoke_with_args(call_args).unwrap().to_array(2).unwrap();
        let rates = output.as_primitive::<Float64Type>();
        assert_eq!(&Float64Array::from(vec![2.0, 3.0]), rates);
    }
}
126}