common_function/scalars/math/
modulo.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt;
16use std::fmt::Display;
17
18use common_query::error;
19use common_query::error::{ArrowComputeSnafu, InvalidFuncArgsSnafu, Result};
20use common_query::prelude::{Signature, Volatility};
21use datatypes::arrow::compute;
22use datatypes::arrow::compute::kernels::numeric;
23use datatypes::arrow::datatypes::DataType as ArrowDataType;
24use datatypes::prelude::ConcreteDataType;
25use datatypes::vectors::{Helper, VectorRef};
26use snafu::{ensure, ResultExt};
27
28use crate::function::{Function, FunctionContext};
29
/// Name under which the remainder function is registered.
const NAME: &str = "mod";

/// Scalar function `mod(x, y)`: the element-wise remainder of `x` divided by `y`.
#[derive(Clone, Debug, Default)]
pub struct ModuloFunction;

impl Display for ModuloFunction {
    /// Renders the function name in upper case (`MOD`).
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let upper = NAME.to_ascii_uppercase();
        f.write_str(&upper)
    }
}
41
42impl Function for ModuloFunction {
43    fn name(&self) -> &str {
44        NAME
45    }
46
47    fn return_type(&self, input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
48        if input_types.iter().all(ConcreteDataType::is_signed) {
49            Ok(ConcreteDataType::int64_datatype())
50        } else if input_types.iter().all(ConcreteDataType::is_unsigned) {
51            Ok(ConcreteDataType::uint64_datatype())
52        } else {
53            Ok(ConcreteDataType::float64_datatype())
54        }
55    }
56
57    fn signature(&self) -> Signature {
58        Signature::uniform(2, ConcreteDataType::numerics(), Volatility::Immutable)
59    }
60
61    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
62        ensure!(
63            columns.len() == 2,
64            InvalidFuncArgsSnafu {
65                err_msg: format!(
66                    "The length of the args is not correct, expect exactly two, have: {}",
67                    columns.len()
68                ),
69            }
70        );
71        let nums = &columns[0];
72        let divs = &columns[1];
73        let nums_arrow_array = &nums.to_arrow_array();
74        let divs_arrow_array = &divs.to_arrow_array();
75        let array = numeric::rem(nums_arrow_array, divs_arrow_array).context(ArrowComputeSnafu)?;
76
77        let result = match nums.data_type() {
78            ConcreteDataType::Int8(_)
79            | ConcreteDataType::Int16(_)
80            | ConcreteDataType::Int32(_)
81            | ConcreteDataType::Int64(_) => compute::cast(&array, &ArrowDataType::Int64),
82            ConcreteDataType::UInt8(_)
83            | ConcreteDataType::UInt16(_)
84            | ConcreteDataType::UInt32(_)
85            | ConcreteDataType::UInt64(_) => compute::cast(&array, &ArrowDataType::UInt64),
86            ConcreteDataType::Float32(_) | ConcreteDataType::Float64(_) => {
87                compute::cast(&array, &ArrowDataType::Float64)
88            }
89            _ => unreachable!("unexpected datatype: {:?}", nums.data_type()),
90        }
91        .context(ArrowComputeSnafu)?;
92        Helper::try_into_vector(&result).context(error::FromArrowArraySnafu)
93    }
94}
95
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_error::ext::ErrorExt;
    use datatypes::value::Value;
    use datatypes::vectors::{Float64Vector, Int32Vector, StringVector, UInt32Vector};

    use super::*;

    #[test]
    fn test_mod_function_signed() {
        let function = ModuloFunction;
        assert_eq!("mod", function.name());
        for input in [
            ConcreteDataType::int64_datatype(),
            ConcreteDataType::int32_datatype(),
        ] {
            assert_eq!(
                ConcreteDataType::int64_datatype(),
                function.return_type(&[input]).unwrap()
            );
        }

        let dividends = vec![18, -17, 5, -6];
        let divisors = vec![4, 8, -5, -5];
        let args: Vec<VectorRef> = vec![
            Arc::new(Int32Vector::from_vec(dividends.clone())),
            Arc::new(Int32Vector::from_vec(divisors.clone())),
        ];
        let ctx = FunctionContext::default();
        let output = function.eval(&ctx, &args).unwrap();
        assert_eq!(output.len(), 4);
        for (idx, (n, d)) in dividends.iter().zip(divisors.iter()).enumerate() {
            let expected = i64::from(n % d);
            assert!(matches!(output.get(idx), Value::Int64(v) if v == expected));
        }
    }

    #[test]
    fn test_mod_function_unsigned() {
        let function = ModuloFunction;
        assert_eq!("mod", function.name());
        for input in [
            ConcreteDataType::uint64_datatype(),
            ConcreteDataType::uint32_datatype(),
        ] {
            assert_eq!(
                ConcreteDataType::uint64_datatype(),
                function.return_type(&[input]).unwrap()
            );
        }

        let dividends: Vec<u32> = vec![18, 17, 5, 6];
        let divisors: Vec<u32> = vec![4, 8, 5, 5];
        let args: Vec<VectorRef> = vec![
            Arc::new(UInt32Vector::from_vec(dividends.clone())),
            Arc::new(UInt32Vector::from_vec(divisors.clone())),
        ];
        let ctx = FunctionContext::default();
        let output = function.eval(&ctx, &args).unwrap();
        assert_eq!(output.len(), 4);
        for (idx, (n, d)) in dividends.iter().zip(divisors.iter()).enumerate() {
            let expected = u64::from(n % d);
            assert!(matches!(output.get(idx), Value::UInt64(v) if v == expected));
        }
    }

    #[test]
    fn test_mod_function_float() {
        let function = ModuloFunction;
        assert_eq!("mod", function.name());
        for input in [
            ConcreteDataType::float64_datatype(),
            ConcreteDataType::float32_datatype(),
        ] {
            assert_eq!(
                ConcreteDataType::float64_datatype(),
                function.return_type(&[input]).unwrap()
            );
        }

        let dividends = vec![18.0, 17.0, 5.0, 6.0];
        let divisors = vec![4.0, 8.0, 5.0, 5.0];
        let args: Vec<VectorRef> = vec![
            Arc::new(Float64Vector::from_vec(dividends.clone())),
            Arc::new(Float64Vector::from_vec(divisors.clone())),
        ];
        let ctx = FunctionContext::default();
        let output = function.eval(&ctx, &args).unwrap();
        assert_eq!(output.len(), 4);
        for (idx, (n, d)) in dividends.iter().zip(divisors.iter()).enumerate() {
            let expected: f64 = n % d;
            assert!(matches!(output.get(idx), Value::Float64(v) if v == expected));
        }
    }

    #[test]
    fn test_mod_function_errors() {
        let function = ModuloFunction;
        assert_eq!("mod", function.name());
        let ctx = FunctionContext::default();

        // A zero divisor is rejected by the Arrow kernel.
        let zero_divisor: Vec<VectorRef> = vec![
            Arc::new(Int32Vector::from_vec(vec![27])),
            Arc::new(Int32Vector::from_vec(vec![0])),
        ];
        let outcome = function.eval(&ctx, &zero_divisor);
        assert!(outcome.is_err());
        assert_eq!(
            outcome.unwrap_err().output_msg(),
            "Failed to perform compute operation on arrow arrays: Divide by zero error"
        );

        // Exactly two arguments are required.
        let single_arg: Vec<VectorRef> = vec![Arc::new(Int32Vector::from_vec(vec![27]))];
        let outcome = function.eval(&ctx, &single_arg);
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().output_msg();
        assert!(
            message.contains("The length of the args is not correct, expect exactly two, have: 1")
        );

        // Non-numeric inputs cannot take part in a remainder.
        let strings: Vec<VectorRef> = vec![
            Arc::new(StringVector::from(vec!["27"])),
            Arc::new(StringVector::from(vec!["4"])),
        ];
        let outcome = function.eval(&ctx, &strings);
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().output_msg();
        assert!(message.contains("Invalid arithmetic operation"));
    }
}