common_function/scalars/math/
modulo.rs1use std::fmt;
16use std::fmt::Display;
17
18use common_query::error;
19use common_query::error::{ArrowComputeSnafu, InvalidFuncArgsSnafu, Result};
20use datafusion::arrow::datatypes::DataType;
21use datafusion_expr::type_coercion::aggregates::NUMERICS;
22use datafusion_expr::{Signature, Volatility};
23use datatypes::arrow::compute;
24use datatypes::arrow::compute::kernels::numeric;
25use datatypes::arrow::datatypes::DataType as ArrowDataType;
26use datatypes::prelude::ConcreteDataType;
27use datatypes::vectors::{Helper, VectorRef};
28use snafu::{ResultExt, ensure};
29
30use crate::function::{Function, FunctionContext};
31
/// Registered (lower-case) name of the modulo function.
const NAME: &str = "mod";

/// Scalar function returning the remainder of its first argument divided by its second.
#[derive(Clone, Debug, Default)]
pub struct ModuloFunction;

impl Display for ModuloFunction {
    /// Writes the function name in upper case (`MOD`).
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let upper_name = NAME.to_ascii_uppercase();
        f.write_str(&upper_name)
    }
}
43
44impl Function for ModuloFunction {
45 fn name(&self) -> &str {
46 NAME
47 }
48
49 fn return_type(&self, input_types: &[DataType]) -> Result<DataType> {
50 if input_types.iter().all(DataType::is_signed_integer) {
51 Ok(DataType::Int64)
52 } else if input_types.iter().all(DataType::is_unsigned_integer) {
53 Ok(DataType::UInt64)
54 } else {
55 Ok(DataType::Float64)
56 }
57 }
58
59 fn signature(&self) -> Signature {
60 Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
61 }
62
63 fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
64 ensure!(
65 columns.len() == 2,
66 InvalidFuncArgsSnafu {
67 err_msg: format!(
68 "The length of the args is not correct, expect exactly two, have: {}",
69 columns.len()
70 ),
71 }
72 );
73 let nums = &columns[0];
74 let divs = &columns[1];
75 let nums_arrow_array = &nums.to_arrow_array();
76 let divs_arrow_array = &divs.to_arrow_array();
77 let array = numeric::rem(nums_arrow_array, divs_arrow_array).context(ArrowComputeSnafu)?;
78
79 let result = match nums.data_type() {
80 ConcreteDataType::Int8(_)
81 | ConcreteDataType::Int16(_)
82 | ConcreteDataType::Int32(_)
83 | ConcreteDataType::Int64(_) => compute::cast(&array, &ArrowDataType::Int64),
84 ConcreteDataType::UInt8(_)
85 | ConcreteDataType::UInt16(_)
86 | ConcreteDataType::UInt32(_)
87 | ConcreteDataType::UInt64(_) => compute::cast(&array, &ArrowDataType::UInt64),
88 ConcreteDataType::Float32(_) | ConcreteDataType::Float64(_) => {
89 compute::cast(&array, &ArrowDataType::Float64)
90 }
91 _ => unreachable!("unexpected datatype: {:?}", nums.data_type()),
92 }
93 .context(ArrowComputeSnafu)?;
94 Helper::try_into_vector(&result).context(error::FromArrowArraySnafu)
95 }
96}
97
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_error::ext::ErrorExt;
    use datatypes::value::Value;
    use datatypes::vectors::{Float64Vector, Int32Vector, StringVector, UInt32Vector};

    use super::*;

    /// Evaluates `mod` over `args` with a default function context.
    fn eval_mod(args: &[VectorRef]) -> common_query::error::Result<VectorRef> {
        ModuloFunction.eval(&FunctionContext::default(), args)
    }

    #[test]
    fn test_mod_function_signed() {
        let function = ModuloFunction;
        assert_eq!("mod", function.name());
        for input in [DataType::Int64, DataType::Int32] {
            assert_eq!(DataType::Int64, function.return_type(&[input]).unwrap());
        }

        let nums: Vec<i32> = vec![18, -17, 5, -6];
        let divs: Vec<i32> = vec![4, 8, -5, -5];
        let args: Vec<VectorRef> = vec![
            Arc::new(Int32Vector::from_vec(nums.clone())),
            Arc::new(Int32Vector::from_vec(divs.clone())),
        ];

        let result = eval_mod(&args).unwrap();
        assert_eq!(result.len(), nums.len());
        for (i, (n, d)) in nums.iter().zip(&divs).enumerate() {
            let expected = i64::from(n % d);
            assert!(matches!(result.get(i), Value::Int64(v) if v == expected));
        }
    }

    #[test]
    fn test_mod_function_unsigned() {
        let function = ModuloFunction;
        assert_eq!("mod", function.name());
        for input in [DataType::UInt64, DataType::UInt32] {
            assert_eq!(DataType::UInt64, function.return_type(&[input]).unwrap());
        }

        let nums: Vec<u32> = vec![18, 17, 5, 6];
        let divs: Vec<u32> = vec![4, 8, 5, 5];
        let args: Vec<VectorRef> = vec![
            Arc::new(UInt32Vector::from_vec(nums.clone())),
            Arc::new(UInt32Vector::from_vec(divs.clone())),
        ];

        let result = eval_mod(&args).unwrap();
        assert_eq!(result.len(), nums.len());
        for (i, (n, d)) in nums.iter().zip(&divs).enumerate() {
            let expected = u64::from(n % d);
            assert!(matches!(result.get(i), Value::UInt64(v) if v == expected));
        }
    }

    #[test]
    fn test_mod_function_float() {
        let function = ModuloFunction;
        assert_eq!("mod", function.name());
        for input in [DataType::Float64, DataType::Float32] {
            assert_eq!(DataType::Float64, function.return_type(&[input]).unwrap());
        }

        let nums: Vec<f64> = vec![18.0, 17.0, 5.0, 6.0];
        let divs: Vec<f64> = vec![4.0, 8.0, 5.0, 5.0];
        let args: Vec<VectorRef> = vec![
            Arc::new(Float64Vector::from_vec(nums.clone())),
            Arc::new(Float64Vector::from_vec(divs.clone())),
        ];

        let result = eval_mod(&args).unwrap();
        assert_eq!(result.len(), nums.len());
        for (i, (n, d)) in nums.iter().zip(&divs).enumerate() {
            let expected: f64 = n % d;
            assert!(matches!(result.get(i), Value::Float64(v) if v == expected));
        }
    }

    #[test]
    fn test_mod_function_errors() {
        let function = ModuloFunction;
        assert_eq!("mod", function.name());

        // Division by zero is reported by the arrow compute kernel.
        let args: Vec<VectorRef> = vec![
            Arc::new(Int32Vector::from_vec(vec![27])),
            Arc::new(Int32Vector::from_vec(vec![0])),
        ];
        let result = eval_mod(&args);
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().output_msg(),
            "Failed to perform compute operation on arrow arrays: Divide by zero error"
        );

        // A single argument fails the arity check.
        let args: Vec<VectorRef> = vec![Arc::new(Int32Vector::from_vec(vec![27]))];
        let result = eval_mod(&args);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .output_msg()
                .contains("The length of the args is not correct, expect exactly two, have: 1")
        );

        // String operands are rejected by the arrow `rem` kernel.
        let args: Vec<VectorRef> = vec![
            Arc::new(StringVector::from(vec!["27"])),
            Arc::new(StringVector::from(vec!["4"])),
        ];
        let result = eval_mod(&args);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .output_msg()
                .contains("Invalid arithmetic operation")
        );
    }
}