common_function/scalars/math/
modulo.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::fmt;
16use std::fmt::Display;
17
18use datafusion_common::arrow::compute;
19use datafusion_common::arrow::compute::kernels::numeric;
20use datafusion_common::arrow::datatypes::DataType;
21use datafusion_expr::type_coercion::aggregates::NUMERICS;
22use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
23
24use crate::function::{Function, extract_args};
25
26const NAME: &str = "mod";
27
28/// The function to find remainders
29#[derive(Clone, Debug)]
30pub(crate) struct ModuloFunction {
31    signature: Signature,
32}
33
34impl Default for ModuloFunction {
35    fn default() -> Self {
36        Self {
37            signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
38        }
39    }
40}
41
42impl Display for ModuloFunction {
43    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
44        write!(f, "{}", NAME.to_ascii_uppercase())
45    }
46}
47
48impl Function for ModuloFunction {
49    fn name(&self) -> &str {
50        NAME
51    }
52
53    fn return_type(&self, input_types: &[DataType]) -> datafusion_common::Result<DataType> {
54        if input_types.iter().all(DataType::is_signed_integer) {
55            Ok(DataType::Int64)
56        } else if input_types.iter().all(DataType::is_unsigned_integer) {
57            Ok(DataType::UInt64)
58        } else {
59            Ok(DataType::Float64)
60        }
61    }
62
63    fn signature(&self) -> &Signature {
64        &self.signature
65    }
66
67    fn invoke_with_args(
68        &self,
69        args: ScalarFunctionArgs,
70    ) -> datafusion_common::Result<ColumnarValue> {
71        let [nums, divs] = extract_args(self.name(), &args)?;
72        let array = numeric::rem(&nums, &divs)?;
73
74        let result = match nums.data_type() {
75            DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
76                compute::cast(&array, &DataType::Int64)
77            }
78            DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
79                compute::cast(&array, &DataType::UInt64)
80            }
81            DataType::Float32 | DataType::Float64 => compute::cast(&array, &DataType::Float64),
82            _ => unreachable!("unexpected datatype: {:?}", nums.data_type()),
83        }?;
84        Ok(ColumnarValue::Array(result))
85    }
86}
87
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion_common::arrow::array::{
        AsArray, Float64Array, Int32Array, StringViewArray, UInt32Array,
    };
    use datafusion_common::arrow::datatypes::{Float64Type, Int64Type, UInt64Type};

    use super::*;

    /// Wraps `columns` as invocation arguments for `number_rows` rows, declaring a
    /// non-nullable return field `x` of type `return_type`.
    fn build_args(
        columns: Vec<ColumnarValue>,
        number_rows: usize,
        return_type: DataType,
    ) -> ScalarFunctionArgs {
        ScalarFunctionArgs {
            args: columns,
            arg_fields: vec![],
            number_rows,
            return_field: Arc::new(Field::new("x", return_type, false)),
            config_options: Arc::new(Default::default()),
        }
    }

    #[test]
    fn test_mod_function_signed() {
        let function = ModuloFunction::default();
        assert_eq!("mod", function.name());
        for input in [DataType::Int64, DataType::Int32] {
            assert_eq!(DataType::Int64, function.return_type(&[input]).unwrap());
        }

        let nums: Vec<i32> = vec![18, -17, 5, -6];
        let divs: Vec<i32> = vec![4, 8, -5, -5];
        let args = build_args(
            vec![
                ColumnarValue::Array(Arc::new(Int32Array::from(nums.clone()))),
                ColumnarValue::Array(Arc::new(Int32Array::from(divs.clone()))),
            ],
            4,
            DataType::Int64,
        );

        let output = function.invoke_with_args(args).unwrap().to_array(4).unwrap();
        let values = output.as_primitive::<Int64Type>();
        assert_eq!(values.len(), 4);
        for (row, (n, d)) in nums.iter().zip(&divs).enumerate() {
            assert_eq!(values.value(row), i64::from(n % d));
        }
    }

    #[test]
    fn test_mod_function_unsigned() {
        let function = ModuloFunction::default();
        assert_eq!("mod", function.name());
        for input in [DataType::UInt64, DataType::UInt32] {
            assert_eq!(DataType::UInt64, function.return_type(&[input]).unwrap());
        }

        let nums: Vec<u32> = vec![18, 17, 5, 6];
        let divs: Vec<u32> = vec![4, 8, 5, 5];
        let args = build_args(
            vec![
                ColumnarValue::Array(Arc::new(UInt32Array::from(nums.clone()))),
                ColumnarValue::Array(Arc::new(UInt32Array::from(divs.clone()))),
            ],
            4,
            DataType::UInt64,
        );

        let output = function.invoke_with_args(args).unwrap().to_array(4).unwrap();
        let values = output.as_primitive::<UInt64Type>();
        assert_eq!(values.len(), 4);
        for (row, (n, d)) in nums.iter().zip(&divs).enumerate() {
            assert_eq!(values.value(row), u64::from(n % d));
        }
    }

    #[test]
    fn test_mod_function_float() {
        let function = ModuloFunction::default();
        assert_eq!("mod", function.name());
        for input in [DataType::Float64, DataType::Float32] {
            assert_eq!(DataType::Float64, function.return_type(&[input]).unwrap());
        }

        let nums: Vec<f64> = vec![18.0, 17.0, 5.0, 6.0];
        let divs: Vec<f64> = vec![4.0, 8.0, 5.0, 5.0];
        let args = build_args(
            vec![
                ColumnarValue::Array(Arc::new(Float64Array::from(nums.clone()))),
                ColumnarValue::Array(Arc::new(Float64Array::from(divs.clone()))),
            ],
            4,
            DataType::Float64,
        );

        let output = function.invoke_with_args(args).unwrap().to_array(4).unwrap();
        let values = output.as_primitive::<Float64Type>();
        assert_eq!(values.len(), 4);
        for (row, (n, d)) in nums.iter().zip(&divs).enumerate() {
            assert_eq!(values.value(row), n % d);
        }
    }

    #[test]
    fn test_mod_function_errors() {
        let function = ModuloFunction::default();
        assert_eq!("mod", function.name());

        // An integer divisor of zero is rejected by arrow.
        let args = build_args(
            vec![
                ColumnarValue::Array(Arc::new(Int32Array::from(vec![27_i32]))),
                ColumnarValue::Array(Arc::new(Int32Array::from(vec![0_i32]))),
            ],
            1,
            DataType::Int64,
        );
        let err = function.invoke_with_args(args).unwrap_err();
        assert_eq!(err.to_string(), "Arrow error: Divide by zero error");

        // A single argument fails the arity check.
        let args = build_args(
            vec![ColumnarValue::Array(Arc::new(Int32Array::from(vec![27_i32])))],
            1,
            DataType::Int64,
        );
        let err = function.invoke_with_args(args).unwrap_err();
        assert_eq!(
            err.to_string(),
            "Execution error: mod function requires 2 arguments, got 1"
        );

        // String operands are not valid for the remainder kernel.
        let args = build_args(
            vec![
                ColumnarValue::Array(Arc::new(StringViewArray::from(vec!["27"]))),
                ColumnarValue::Array(Arc::new(StringViewArray::from(vec!["4"]))),
            ],
            1,
            DataType::Int64,
        );
        let err = function.invoke_with_args(args).unwrap_err();
        assert!(err.to_string().contains("Invalid arithmetic operation"));
    }
}