common_function/scalars/math/modulo.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15use std::fmt;
16use std::fmt::Display;
17
18use common_query::error;
19use common_query::error::{ArrowComputeSnafu, InvalidFuncArgsSnafu, Result};
20use datafusion::arrow::datatypes::DataType;
21use datafusion_expr::type_coercion::aggregates::NUMERICS;
22use datafusion_expr::{Signature, Volatility};
23use datatypes::arrow::compute;
24use datatypes::arrow::compute::kernels::numeric;
25use datatypes::arrow::datatypes::DataType as ArrowDataType;
26use datatypes::prelude::ConcreteDataType;
27use datatypes::vectors::{Helper, VectorRef};
28use snafu::{ResultExt, ensure};
29
30use crate::function::{Function, FunctionContext};
31
/// Name returned by `ModuloFunction::name`; its upper-cased form is what
/// the `Display` impl prints.
const NAME: &str = "mod";

/// Scalar function `mod(x, y)`: the element-wise remainder of dividing the
/// first numeric column by the second.
#[derive(Debug, Default, Clone)]
pub struct ModuloFunction;
37
38impl Display for ModuloFunction {
39    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
40        write!(f, "{}", NAME.to_ascii_uppercase())
41    }
42}
43
44impl Function for ModuloFunction {
45    fn name(&self) -> &str {
46        NAME
47    }
48
49    fn return_type(&self, input_types: &[DataType]) -> Result<DataType> {
50        if input_types.iter().all(DataType::is_signed_integer) {
51            Ok(DataType::Int64)
52        } else if input_types.iter().all(DataType::is_unsigned_integer) {
53            Ok(DataType::UInt64)
54        } else {
55            Ok(DataType::Float64)
56        }
57    }
58
59    fn signature(&self) -> Signature {
60        Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
61    }
62
63    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
64        ensure!(
65            columns.len() == 2,
66            InvalidFuncArgsSnafu {
67                err_msg: format!(
68                    "The length of the args is not correct, expect exactly two, have: {}",
69                    columns.len()
70                ),
71            }
72        );
73        let nums = &columns[0];
74        let divs = &columns[1];
75        let nums_arrow_array = &nums.to_arrow_array();
76        let divs_arrow_array = &divs.to_arrow_array();
77        let array = numeric::rem(nums_arrow_array, divs_arrow_array).context(ArrowComputeSnafu)?;
78
79        let result = match nums.data_type() {
80            ConcreteDataType::Int8(_)
81            | ConcreteDataType::Int16(_)
82            | ConcreteDataType::Int32(_)
83            | ConcreteDataType::Int64(_) => compute::cast(&array, &ArrowDataType::Int64),
84            ConcreteDataType::UInt8(_)
85            | ConcreteDataType::UInt16(_)
86            | ConcreteDataType::UInt32(_)
87            | ConcreteDataType::UInt64(_) => compute::cast(&array, &ArrowDataType::UInt64),
88            ConcreteDataType::Float32(_) | ConcreteDataType::Float64(_) => {
89                compute::cast(&array, &ArrowDataType::Float64)
90            }
91            _ => unreachable!("unexpected datatype: {:?}", nums.data_type()),
92        }
93        .context(ArrowComputeSnafu)?;
94        Helper::try_into_vector(&result).context(error::FromArrowArraySnafu)
95    }
96}
97
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_error::ext::ErrorExt;
    use datatypes::value::Value;
    use datatypes::vectors::{Float64Vector, Int32Vector, StringVector, UInt32Vector};

    use super::*;

    #[test]
    fn test_mod_function_signed() {
        let function = ModuloFunction;
        assert_eq!("mod", function.name());
        for input in [DataType::Int64, DataType::Int32] {
            assert_eq!(DataType::Int64, function.return_type(&[input]).unwrap());
        }

        let dividends = vec![18, -17, 5, -6];
        let divisors = vec![4, 8, -5, -5];
        let args: Vec<VectorRef> = vec![
            Arc::new(Int32Vector::from_vec(dividends.clone())),
            Arc::new(Int32Vector::from_vec(divisors.clone())),
        ];

        let output = function.eval(&FunctionContext::default(), &args).unwrap();
        assert_eq!(output.len(), 4);
        for (row, (a, b)) in dividends.iter().zip(&divisors).enumerate() {
            let expected = i64::from(a % b);
            assert!(matches!(output.get(row), Value::Int64(v) if v == expected));
        }
    }

    #[test]
    fn test_mod_function_unsigned() {
        let function = ModuloFunction;
        assert_eq!("mod", function.name());
        for input in [DataType::UInt64, DataType::UInt32] {
            assert_eq!(DataType::UInt64, function.return_type(&[input]).unwrap());
        }

        let dividends: Vec<u32> = vec![18, 17, 5, 6];
        let divisors: Vec<u32> = vec![4, 8, 5, 5];
        let args: Vec<VectorRef> = vec![
            Arc::new(UInt32Vector::from_vec(dividends.clone())),
            Arc::new(UInt32Vector::from_vec(divisors.clone())),
        ];

        let output = function.eval(&FunctionContext::default(), &args).unwrap();
        assert_eq!(output.len(), 4);
        for (row, (a, b)) in dividends.iter().zip(&divisors).enumerate() {
            let expected = u64::from(a % b);
            assert!(matches!(output.get(row), Value::UInt64(v) if v == expected));
        }
    }

    #[test]
    fn test_mod_function_float() {
        let function = ModuloFunction;
        assert_eq!("mod", function.name());
        for input in [DataType::Float64, DataType::Float32] {
            assert_eq!(DataType::Float64, function.return_type(&[input]).unwrap());
        }

        let dividends: Vec<f64> = vec![18.0, 17.0, 5.0, 6.0];
        let divisors: Vec<f64> = vec![4.0, 8.0, 5.0, 5.0];
        let args: Vec<VectorRef> = vec![
            Arc::new(Float64Vector::from_vec(dividends.clone())),
            Arc::new(Float64Vector::from_vec(divisors.clone())),
        ];

        let output = function.eval(&FunctionContext::default(), &args).unwrap();
        assert_eq!(output.len(), 4);
        for (row, (a, b)) in dividends.iter().zip(&divisors).enumerate() {
            let expected: f64 = a % b;
            assert!(matches!(output.get(row), Value::Float64(v) if v == expected));
        }
    }

    #[test]
    fn test_mod_function_errors() {
        let function = ModuloFunction;
        assert_eq!("mod", function.name());
        let ctx = FunctionContext::default();

        // Integer division by zero is rejected by the Arrow kernel.
        let args: Vec<VectorRef> = vec![
            Arc::new(Int32Vector::from_vec(vec![27])),
            Arc::new(Int32Vector::from_vec(vec![0])),
        ];
        let err = function.eval(&ctx, &args).unwrap_err();
        assert_eq!(
            err.output_msg(),
            "Failed to perform compute operation on arrow arrays: Divide by zero error"
        );

        // Exactly two arguments are required.
        let args: Vec<VectorRef> = vec![Arc::new(Int32Vector::from_vec(vec![27]))];
        let err = function.eval(&ctx, &args).unwrap_err();
        assert!(
            err.output_msg()
                .contains("The length of the args is not correct, expect exactly two, have: 1")
        );

        // Non-numeric inputs are rejected by the Arrow kernel.
        let args: Vec<VectorRef> = vec![
            Arc::new(StringVector::from(vec!["27"])),
            Arc::new(StringVector::from(vec!["4"])),
        ];
        let err = function.eval(&ctx, &args).unwrap_err();
        assert!(err.output_msg().contains("Invalid arithmetic operation"));
    }
}