common_function/scalars/math/
rate.rs1use std::fmt;
16
17use common_query::error;
18use datafusion::arrow::compute::kernels::numeric;
19use datafusion_common::arrow::compute::kernels::cast;
20use datafusion_common::arrow::datatypes::DataType;
21use datafusion_expr::type_coercion::aggregates::NUMERICS;
22use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
23use snafu::ResultExt;
24
25use crate::function::{Function, extract_args};
26
27#[derive(Clone, Debug)]
29pub(crate) struct RateFunction {
30 signature: Signature,
31}
32
33impl Default for RateFunction {
34 fn default() -> Self {
35 Self {
36 signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
37 }
38 }
39}
40
41impl fmt::Display for RateFunction {
42 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
43 write!(f, "RATE")
44 }
45}
46
47impl Function for RateFunction {
48 fn name(&self) -> &str {
49 "rate"
50 }
51
52 fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
53 Ok(DataType::Float64)
54 }
55
56 fn signature(&self) -> &Signature {
57 &self.signature
58 }
59
60 fn invoke_with_args(
61 &self,
62 args: ScalarFunctionArgs,
63 ) -> datafusion_common::Result<ColumnarValue> {
64 let [val, ts] = extract_args(self.name(), &args)?;
65 let val_0 = val.slice(0, val.len() - 1);
66 let val_1 = val.slice(1, val.len() - 1);
67 let dv = numeric::sub(&val_1, &val_0).context(error::ArrowComputeSnafu)?;
68 let ts_0 = ts.slice(0, ts.len() - 1);
69 let ts_1 = ts.slice(1, ts.len() - 1);
70 let dt = numeric::sub(&ts_1, &ts_0).context(error::ArrowComputeSnafu)?;
71
72 let dv = cast::cast(&dv, &DataType::Float64).context(error::TypeCastSnafu {
73 typ: DataType::Float64,
74 })?;
75 let dt = cast::cast(&dt, &DataType::Float64).context(error::TypeCastSnafu {
76 typ: DataType::Float64,
77 })?;
78 let rate = numeric::div(&dv, &dt).context(error::ArrowComputeSnafu)?;
79
80 Ok(ColumnarValue::Array(rate))
81 }
82}
83
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion_common::arrow::array::{AsArray, Float32Array, Float64Array, Int64Array};
    use datafusion_common::arrow::datatypes::Float64Type;
    use datafusion_expr::TypeSignature;

    use super::*;

    /// Checks the function metadata, then the per-step rate computed for a
    /// small three-row input.
    #[test]
    fn test_rate_function() {
        let rate = RateFunction::default();

        // Metadata: name, return type and a two-argument numeric signature.
        assert_eq!("rate", rate.name());
        assert_eq!(DataType::Float64, rate.return_type(&[]).unwrap());
        let signature_ok = matches!(
            rate.signature(),
            Signature {
                type_signature: TypeSignature::Uniform(2, valid_types),
                volatility: Volatility::Immutable
            } if valid_types == NUMERICS
        );
        assert!(signature_ok);

        // Values grow by 2 and then by 3 over unit time steps.
        let values = Float32Array::from(vec![1.0, 3.0, 6.0]);
        let timestamps = Int64Array::from(vec![0, 1, 2]);
        let args = ScalarFunctionArgs {
            args: vec![
                ColumnarValue::Array(Arc::new(values)),
                ColumnarValue::Array(Arc::new(timestamps)),
            ],
            arg_fields: vec![],
            number_rows: 3,
            return_field: Arc::new(Field::new("x", DataType::Float64, false)),
            config_options: Arc::new(Default::default()),
        };

        let output = rate.invoke_with_args(args).unwrap();
        let result = output.to_array(2).unwrap();
        let expected = Float64Array::from(vec![2.0, 3.0]);
        assert_eq!(&expected, result.as_primitive::<Float64Type>());
    }
}