common_function/scalars/math/
modulo.rs1use std::fmt;
16use std::fmt::Display;
17
18use datafusion_common::arrow::compute;
19use datafusion_common::arrow::compute::kernels::numeric;
20use datafusion_common::arrow::datatypes::DataType;
21use datafusion_expr::type_coercion::aggregates::NUMERICS;
22use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
23
24use crate::function::{Function, extract_args};
25
26const NAME: &str = "mod";
27
28#[derive(Clone, Debug)]
30pub(crate) struct ModuloFunction {
31 signature: Signature,
32}
33
34impl Default for ModuloFunction {
35 fn default() -> Self {
36 Self {
37 signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
38 }
39 }
40}
41
42impl Display for ModuloFunction {
43 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
44 write!(f, "{}", NAME.to_ascii_uppercase())
45 }
46}
47
48impl Function for ModuloFunction {
49 fn name(&self) -> &str {
50 NAME
51 }
52
53 fn return_type(&self, input_types: &[DataType]) -> datafusion_common::Result<DataType> {
54 if input_types.iter().all(DataType::is_signed_integer) {
55 Ok(DataType::Int64)
56 } else if input_types.iter().all(DataType::is_unsigned_integer) {
57 Ok(DataType::UInt64)
58 } else {
59 Ok(DataType::Float64)
60 }
61 }
62
63 fn signature(&self) -> &Signature {
64 &self.signature
65 }
66
67 fn invoke_with_args(
68 &self,
69 args: ScalarFunctionArgs,
70 ) -> datafusion_common::Result<ColumnarValue> {
71 let [nums, divs] = extract_args(self.name(), &args)?;
72 let array = numeric::rem(&nums, &divs)?;
73
74 let result = match nums.data_type() {
75 DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
76 compute::cast(&array, &DataType::Int64)
77 }
78 DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
79 compute::cast(&array, &DataType::UInt64)
80 }
81 DataType::Float32 | DataType::Float64 => compute::cast(&array, &DataType::Float64),
82 _ => unreachable!("unexpected datatype: {:?}", nums.data_type()),
83 }?;
84 Ok(ColumnarValue::Array(result))
85 }
86}
87
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion_common::arrow::array::{
        Array, AsArray, Float64Array, Int32Array, StringViewArray, UInt32Array,
    };
    use datafusion_common::arrow::datatypes::{Float64Type, Int64Type, UInt64Type};

    use super::*;

    /// Wraps a concrete Arrow array as a columnar function argument.
    fn array_arg<A: Array + 'static>(array: A) -> ColumnarValue {
        ColumnarValue::Array(Arc::new(array))
    }

    /// Builds the invocation payload shared by every test. It has no argument
    /// fields, uses the default config, and has a non-nullable return field `x`.
    fn call_args(
        args: Vec<ColumnarValue>,
        number_rows: usize,
        return_type: DataType,
    ) -> ScalarFunctionArgs {
        ScalarFunctionArgs {
            args,
            arg_fields: vec![],
            number_rows,
            return_field: Arc::new(Field::new("x", return_type, false)),
            config_options: Arc::new(Default::default()),
        }
    }

    #[test]
    fn test_mod_function_signed() {
        let function = ModuloFunction::default();
        assert_eq!("mod", function.name());
        for input in [DataType::Int64, DataType::Int32] {
            assert_eq!(DataType::Int64, function.return_type(&[input]).unwrap());
        }

        let nums: Vec<i32> = vec![18, -17, 5, -6];
        let divs: Vec<i32> = vec![4, 8, -5, -5];
        let args = call_args(
            vec![
                array_arg(Int32Array::from(nums.clone())),
                array_arg(Int32Array::from(divs.clone())),
            ],
            4,
            DataType::Int64,
        );

        let output = function.invoke_with_args(args).unwrap().to_array(4).unwrap();
        let output = output.as_primitive::<Int64Type>();
        assert_eq!(output.len(), 4);
        for (i, (n, d)) in nums.iter().zip(&divs).enumerate() {
            assert_eq!(output.value(i), i64::from(n % d));
        }
    }

    #[test]
    fn test_mod_function_unsigned() {
        let function = ModuloFunction::default();
        assert_eq!("mod", function.name());
        for input in [DataType::UInt64, DataType::UInt32] {
            assert_eq!(DataType::UInt64, function.return_type(&[input]).unwrap());
        }

        let nums: Vec<u32> = vec![18, 17, 5, 6];
        let divs: Vec<u32> = vec![4, 8, 5, 5];
        let args = call_args(
            vec![
                array_arg(UInt32Array::from(nums.clone())),
                array_arg(UInt32Array::from(divs.clone())),
            ],
            4,
            DataType::UInt64,
        );

        let output = function.invoke_with_args(args).unwrap().to_array(4).unwrap();
        let output = output.as_primitive::<UInt64Type>();
        assert_eq!(output.len(), 4);
        for (i, (n, d)) in nums.iter().zip(&divs).enumerate() {
            assert_eq!(output.value(i), u64::from(n % d));
        }
    }

    #[test]
    fn test_mod_function_float() {
        let function = ModuloFunction::default();
        assert_eq!("mod", function.name());
        for input in [DataType::Float64, DataType::Float32] {
            assert_eq!(DataType::Float64, function.return_type(&[input]).unwrap());
        }

        let nums: Vec<f64> = vec![18.0, 17.0, 5.0, 6.0];
        let divs: Vec<f64> = vec![4.0, 8.0, 5.0, 5.0];
        let args = call_args(
            vec![
                array_arg(Float64Array::from(nums.clone())),
                array_arg(Float64Array::from(divs.clone())),
            ],
            4,
            DataType::Float64,
        );

        let output = function.invoke_with_args(args).unwrap().to_array(4).unwrap();
        let output = output.as_primitive::<Float64Type>();
        assert_eq!(output.len(), 4);
        for (i, (n, d)) in nums.iter().zip(&divs).enumerate() {
            assert_eq!(output.value(i), n % d);
        }
    }

    #[test]
    fn test_mod_function_errors() {
        let function = ModuloFunction::default();
        assert_eq!("mod", function.name());

        // Integer division by zero is rejected by the Arrow kernel.
        let args = call_args(
            vec![
                array_arg(Int32Array::from(vec![27])),
                array_arg(Int32Array::from(vec![0])),
            ],
            1,
            DataType::Int64,
        );
        let err = function.invoke_with_args(args).unwrap_err();
        assert_eq!(err.to_string(), "Arrow error: Divide by zero error");

        // A single argument fails the arity check.
        let args = call_args(vec![array_arg(Int32Array::from(vec![27]))], 1, DataType::Int64);
        let err = function.invoke_with_args(args).unwrap_err();
        assert_eq!(
            err.to_string(),
            "Execution error: mod function requires 2 arguments, got 1"
        );

        // String operands are not valid for the arithmetic kernel.
        let args = call_args(
            vec![
                array_arg(StringViewArray::from(vec!["27"])),
                array_arg(StringViewArray::from(vec!["4"])),
            ],
            1,
            DataType::Int64,
        );
        let err = function.invoke_with_args(args).unwrap_err();
        assert!(err.to_string().contains("Invalid arithmetic operation"));
    }
}