common_function/scalars/math/
pow.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt;
16use std::sync::Arc;
17
18use common_query::error::Result;
19use common_query::prelude::{Signature, Volatility};
20use datatypes::data_type::DataType;
21use datatypes::prelude::ConcreteDataType;
22use datatypes::types::LogicalPrimitiveType;
23use datatypes::vectors::VectorRef;
24use datatypes::with_match_primitive_type_id;
25use num::traits::Pow;
26use num_traits::AsPrimitive;
27
28use crate::function::{Function, FunctionContext};
29use crate::scalars::expression::{scalar_binary_op, EvalContext};
30
/// Scalar function named `pow`: raises its first argument to the power given
/// by its second argument and produces a `Float64` result.
#[derive(Debug, Default, Clone)]
pub struct PowFunction;
33
34impl Function for PowFunction {
35    fn name(&self) -> &str {
36        "pow"
37    }
38
39    fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
40        Ok(ConcreteDataType::float64_datatype())
41    }
42
43    fn signature(&self) -> Signature {
44        Signature::uniform(2, ConcreteDataType::numerics(), Volatility::Immutable)
45    }
46
47    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
48        with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| {
49            with_match_primitive_type_id!(columns[1].data_type().logical_type_id(), |$T| {
50                let col = scalar_binary_op::<<$S as LogicalPrimitiveType>::Native, <$T as LogicalPrimitiveType>::Native, f64, _>(&columns[0], &columns[1], scalar_pow, &mut EvalContext::default())?;
51                Ok(Arc::new(col))
52            },{
53                unreachable!()
54            })
55        },{
56            unreachable!()
57        })
58    }
59}
60
61#[inline]
62fn scalar_pow<S, T>(value: Option<S>, base: Option<T>, _ctx: &mut EvalContext) -> Option<f64>
63where
64    S: AsPrimitive<f64>,
65    T: AsPrimitive<f64>,
66{
67    match (value, base) {
68        (Some(value), Some(base)) => Some(value.as_().pow(base.as_())),
69        _ => None,
70    }
71}
72
73impl fmt::Display for PowFunction {
74    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
75        write!(f, "POW")
76    }
77}
78
#[cfg(test)]
mod tests {
    use common_query::prelude::TypeSignature;
    use datatypes::value::Value;
    use datatypes::vectors::{Float32Vector, Int8Vector};

    use super::*;
    use crate::function::FunctionContext;

    /// Checks the metadata reported by `PowFunction`, then evaluates it on a
    /// `Float32` column paired with an `Int8` column of exponents.
    #[test]
    fn test_pow_function() {
        let pow = PowFunction;

        // Metadata: name, return type and signature.
        assert_eq!("pow", pow.name());
        assert_eq!(
            ConcreteDataType::float64_datatype(),
            pow.return_type(&[]).unwrap()
        );
        assert!(matches!(
            pow.signature(),
            Signature {
                type_signature: TypeSignature::Uniform(2, valid_types),
                volatility: Volatility::Immutable
            } if valid_types == ConcreteDataType::numerics()
        ));

        // Element-wise evaluation, covering a zero, a negative and a positive exponent.
        let values = vec![1.0, 2.0, 3.0];
        let exponents = vec![0i8, -1i8, 3i8];

        let value_column: VectorRef = Arc::new(Float32Vector::from_vec(values.clone()));
        let exponent_column: VectorRef = Arc::new(Int8Vector::from_vec(exponents.clone()));
        let args = vec![value_column, exponent_column];

        let vector = pow.eval(&FunctionContext::default(), &args).unwrap();
        assert_eq!(3, vector.len());

        for (i, (value, exponent)) in values.iter().zip(exponents.iter()).enumerate() {
            let expected: f64 = (*value as f64).pow(*exponent as f64);
            assert!(matches!(vector.get(i), Value::Float64(v) if v == expected));
        }
    }
}