common_function/scalars/math/
pow.rs1use std::fmt;
16use std::sync::Arc;
17
18use common_query::error::Result;
19use common_query::prelude::{Signature, Volatility};
20use datatypes::data_type::DataType;
21use datatypes::prelude::ConcreteDataType;
22use datatypes::types::LogicalPrimitiveType;
23use datatypes::vectors::VectorRef;
24use datatypes::with_match_primitive_type_id;
25use num::traits::Pow;
26use num_traits::AsPrimitive;
27
28use crate::function::{Function, FunctionContext};
29use crate::scalars::expression::{scalar_binary_op, EvalContext};
30
/// Scalar function `pow(base, exponent)`.
///
/// Raises its first argument to the power of its second and always produces a
/// `Float64` result, regardless of the numeric types of the operands.
#[derive(Debug, Default, Clone)]
pub struct PowFunction;
33
34impl Function for PowFunction {
35 fn name(&self) -> &str {
36 "pow"
37 }
38
39 fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
40 Ok(ConcreteDataType::float64_datatype())
41 }
42
43 fn signature(&self) -> Signature {
44 Signature::uniform(2, ConcreteDataType::numerics(), Volatility::Immutable)
45 }
46
47 fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
48 with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| {
49 with_match_primitive_type_id!(columns[1].data_type().logical_type_id(), |$T| {
50 let col = scalar_binary_op::<<$S as LogicalPrimitiveType>::Native, <$T as LogicalPrimitiveType>::Native, f64, _>(&columns[0], &columns[1], scalar_pow, &mut EvalContext::default())?;
51 Ok(Arc::new(col))
52 },{
53 unreachable!()
54 })
55 },{
56 unreachable!()
57 })
58 }
59}
60
61#[inline]
62fn scalar_pow<S, T>(value: Option<S>, base: Option<T>, _ctx: &mut EvalContext) -> Option<f64>
63where
64 S: AsPrimitive<f64>,
65 T: AsPrimitive<f64>,
66{
67 match (value, base) {
68 (Some(value), Some(base)) => Some(value.as_().pow(base.as_())),
69 _ => None,
70 }
71}
72
73impl fmt::Display for PowFunction {
74 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
75 write!(f, "POW")
76 }
77}
78
#[cfg(test)]
mod tests {
    use common_query::prelude::TypeSignature;
    use datatypes::value::Value;
    use datatypes::vectors::{Float32Vector, Int8Vector};

    use super::*;
    use crate::function::FunctionContext;

    #[test]
    fn test_pow_function() {
        let pow = PowFunction;

        // Metadata: name, fixed Float64 return type, two-argument numeric signature.
        assert_eq!("pow", pow.name());
        assert_eq!(
            ConcreteDataType::float64_datatype(),
            pow.return_type(&[]).unwrap()
        );
        assert!(matches!(pow.signature(),
            Signature {
                type_signature: TypeSignature::Uniform(2, valid_types),
                volatility: Volatility::Immutable
            } if valid_types == ConcreteDataType::numerics()
        ));

        // Mixed operand types: f32 bases with i8 exponents (zero, negative, positive).
        let bases: Vec<f32> = vec![1.0, 2.0, 3.0];
        let exponents: Vec<i8> = vec![0, -1, 3];

        let args: Vec<VectorRef> = vec![
            Arc::new(Float32Vector::from_vec(bases.clone())),
            Arc::new(Int8Vector::from_vec(exponents.clone())),
        ];

        let result = pow.eval(&FunctionContext::default(), &args).unwrap();
        assert_eq!(3, result.len());

        for (row, (&base, &exponent)) in bases.iter().zip(exponents.iter()).enumerate() {
            let expected: f64 = (base as f64).pow(exponent as f64);
            assert!(matches!(result.get(row), Value::Float64(v) if v == expected));
        }
    }
}