common_function/scalars/math/
modulo.rs1use std::fmt;
16use std::fmt::Display;
17
18use common_query::error;
19use common_query::error::{ArrowComputeSnafu, InvalidFuncArgsSnafu, Result};
20use common_query::prelude::{Signature, Volatility};
21use datatypes::arrow::compute;
22use datatypes::arrow::compute::kernels::numeric;
23use datatypes::arrow::datatypes::DataType as ArrowDataType;
24use datatypes::prelude::ConcreteDataType;
25use datatypes::vectors::{Helper, VectorRef};
26use snafu::{ensure, ResultExt};
27
28use crate::function::{Function, FunctionContext};
29
/// Name under which the modulo function is exposed (lower case).
const NAME: &str = "mod";

/// Scalar function computing the remainder of its first argument divided by its second.
#[derive(Debug, Default, Clone)]
pub struct ModuloFunction;
35
36impl Display for ModuloFunction {
37 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
38 write!(f, "{}", NAME.to_ascii_uppercase())
39 }
40}
41
42impl Function for ModuloFunction {
43 fn name(&self) -> &str {
44 NAME
45 }
46
47 fn return_type(&self, input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
48 if input_types.iter().all(ConcreteDataType::is_signed) {
49 Ok(ConcreteDataType::int64_datatype())
50 } else if input_types.iter().all(ConcreteDataType::is_unsigned) {
51 Ok(ConcreteDataType::uint64_datatype())
52 } else {
53 Ok(ConcreteDataType::float64_datatype())
54 }
55 }
56
57 fn signature(&self) -> Signature {
58 Signature::uniform(2, ConcreteDataType::numerics(), Volatility::Immutable)
59 }
60
61 fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
62 ensure!(
63 columns.len() == 2,
64 InvalidFuncArgsSnafu {
65 err_msg: format!(
66 "The length of the args is not correct, expect exactly two, have: {}",
67 columns.len()
68 ),
69 }
70 );
71 let nums = &columns[0];
72 let divs = &columns[1];
73 let nums_arrow_array = &nums.to_arrow_array();
74 let divs_arrow_array = &divs.to_arrow_array();
75 let array = numeric::rem(nums_arrow_array, divs_arrow_array).context(ArrowComputeSnafu)?;
76
77 let result = match nums.data_type() {
78 ConcreteDataType::Int8(_)
79 | ConcreteDataType::Int16(_)
80 | ConcreteDataType::Int32(_)
81 | ConcreteDataType::Int64(_) => compute::cast(&array, &ArrowDataType::Int64),
82 ConcreteDataType::UInt8(_)
83 | ConcreteDataType::UInt16(_)
84 | ConcreteDataType::UInt32(_)
85 | ConcreteDataType::UInt64(_) => compute::cast(&array, &ArrowDataType::UInt64),
86 ConcreteDataType::Float32(_) | ConcreteDataType::Float64(_) => {
87 compute::cast(&array, &ArrowDataType::Float64)
88 }
89 _ => unreachable!("unexpected datatype: {:?}", nums.data_type()),
90 }
91 .context(ArrowComputeSnafu)?;
92 Helper::try_into_vector(&result).context(error::FromArrowArraySnafu)
93 }
94}
95
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_error::ext::ErrorExt;
    use datatypes::value::Value;
    use datatypes::vectors::{Float64Vector, Int32Vector, StringVector, UInt32Vector};

    use super::*;

    /// Signed integer inputs yield `Int64` remainders matching Rust's `%`.
    #[test]
    fn test_mod_function_signed() {
        let func = ModuloFunction;
        assert_eq!("mod", func.name());
        for input in [
            ConcreteDataType::int64_datatype(),
            ConcreteDataType::int32_datatype(),
        ] {
            assert_eq!(
                ConcreteDataType::int64_datatype(),
                func.return_type(&[input]).unwrap()
            );
        }

        let dividends: Vec<i32> = vec![18, -17, 5, -6];
        let divisors: Vec<i32> = vec![4, 8, -5, -5];
        let args: Vec<VectorRef> = vec![
            Arc::new(Int32Vector::from_vec(dividends.clone())),
            Arc::new(Int32Vector::from_vec(divisors.clone())),
        ];
        let output = func.eval(&FunctionContext::default(), &args).unwrap();
        assert_eq!(output.len(), 4);
        for (row, (a, b)) in dividends.iter().zip(divisors.iter()).enumerate() {
            let expected = i64::from(a % b);
            assert!(matches!(output.get(row), Value::Int64(v) if v == expected));
        }
    }

    /// Unsigned integer inputs yield `UInt64` remainders matching Rust's `%`.
    #[test]
    fn test_mod_function_unsigned() {
        let func = ModuloFunction;
        assert_eq!("mod", func.name());
        for input in [
            ConcreteDataType::uint64_datatype(),
            ConcreteDataType::uint32_datatype(),
        ] {
            assert_eq!(
                ConcreteDataType::uint64_datatype(),
                func.return_type(&[input]).unwrap()
            );
        }

        let dividends: Vec<u32> = vec![18, 17, 5, 6];
        let divisors: Vec<u32> = vec![4, 8, 5, 5];
        let args: Vec<VectorRef> = vec![
            Arc::new(UInt32Vector::from_vec(dividends.clone())),
            Arc::new(UInt32Vector::from_vec(divisors.clone())),
        ];
        let output = func.eval(&FunctionContext::default(), &args).unwrap();
        assert_eq!(output.len(), 4);
        for (row, (a, b)) in dividends.iter().zip(divisors.iter()).enumerate() {
            let expected = u64::from(a % b);
            assert!(matches!(output.get(row), Value::UInt64(v) if v == expected));
        }
    }

    /// Floating-point inputs yield `Float64` remainders matching Rust's `%`.
    #[test]
    fn test_mod_function_float() {
        let func = ModuloFunction;
        assert_eq!("mod", func.name());
        for input in [
            ConcreteDataType::float64_datatype(),
            ConcreteDataType::float32_datatype(),
        ] {
            assert_eq!(
                ConcreteDataType::float64_datatype(),
                func.return_type(&[input]).unwrap()
            );
        }

        let dividends: Vec<f64> = vec![18.0, 17.0, 5.0, 6.0];
        let divisors: Vec<f64> = vec![4.0, 8.0, 5.0, 5.0];
        let args: Vec<VectorRef> = vec![
            Arc::new(Float64Vector::from_vec(dividends.clone())),
            Arc::new(Float64Vector::from_vec(divisors.clone())),
        ];
        let output = func.eval(&FunctionContext::default(), &args).unwrap();
        assert_eq!(output.len(), 4);
        for (row, (a, b)) in dividends.iter().zip(divisors.iter()).enumerate() {
            let expected: f64 = a % b;
            assert!(matches!(output.get(row), Value::Float64(v) if v == expected));
        }
    }

    /// Division by zero, wrong arity and non-numeric inputs are all reported as errors.
    #[test]
    fn test_mod_function_errors() {
        let func = ModuloFunction;
        assert_eq!("mod", func.name());
        let ctx = FunctionContext::default();

        let zero_divisor: Vec<VectorRef> = vec![
            Arc::new(Int32Vector::from_vec(vec![27])),
            Arc::new(Int32Vector::from_vec(vec![0])),
        ];
        let outcome = func.eval(&ctx, &zero_divisor);
        assert!(outcome.is_err());
        assert_eq!(
            outcome.unwrap_err().output_msg(),
            "Failed to perform compute operation on arrow arrays: Divide by zero error"
        );

        let single_arg: Vec<VectorRef> = vec![Arc::new(Int32Vector::from_vec(vec![27]))];
        let outcome = func.eval(&ctx, &single_arg);
        assert!(outcome.is_err());
        assert!(outcome
            .unwrap_err()
            .output_msg()
            .contains("The length of the args is not correct, expect exactly two, have: 1"));

        let strings: Vec<VectorRef> = vec![
            Arc::new(StringVector::from(vec!["27"])),
            Arc::new(StringVector::from(vec!["4"])),
        ];
        let outcome = func.eval(&ctx, &strings);
        assert!(outcome.is_err());
        assert!(outcome
            .unwrap_err()
            .output_msg()
            .contains("Invalid arithmetic operation"));
    }
}