common_function/scalars/math/
clamp.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::{self, Display};
16use std::sync::Arc;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use common_query::prelude::Signature;
20use datafusion::arrow::array::{ArrayIter, PrimitiveArray};
21use datafusion::logical_expr::Volatility;
22use datatypes::data_type::{ConcreteDataType, DataType};
23use datatypes::prelude::VectorRef;
24use datatypes::types::LogicalPrimitiveType;
25use datatypes::value::TryAsPrimitive;
26use datatypes::vectors::PrimitiveVector;
27use datatypes::with_match_primitive_type_id;
28use snafu::{ensure, OptionExt};
29
30use crate::function::{Function, FunctionContext};
31
/// Name reported by [`ClampFunction`]; its `Display` impl prints the upper-cased form.
const CLAMP_NAME: &str = "clamp";

/// Stateless unit type that implements the `clamp(input, min, max)` scalar function.
#[derive(Clone, Debug, Default)]
pub struct ClampFunction;
36
37impl Function for ClampFunction {
38    fn name(&self) -> &str {
39        CLAMP_NAME
40    }
41
42    fn return_type(&self, input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
43        // Type check is done by `signature`
44        Ok(input_types[0].clone())
45    }
46
47    fn signature(&self) -> Signature {
48        // input, min, max
49        Signature::uniform(3, ConcreteDataType::numerics(), Volatility::Immutable)
50    }
51
52    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
53        ensure!(
54            columns.len() == 3,
55            InvalidFuncArgsSnafu {
56                err_msg: format!(
57                    "The length of the args is not correct, expect exactly 3, have: {}",
58                    columns.len()
59                ),
60            }
61        );
62        ensure!(
63            columns[0].data_type().is_numeric(),
64            InvalidFuncArgsSnafu {
65                err_msg: format!(
66                    "The first arg's type is not numeric, have: {}",
67                    columns[0].data_type()
68                ),
69            }
70        );
71        ensure!(
72            columns[0].data_type() == columns[1].data_type()
73                && columns[1].data_type() == columns[2].data_type(),
74            InvalidFuncArgsSnafu {
75                err_msg: format!(
76                    "Arguments don't have identical types: {}, {}, {}",
77                    columns[0].data_type(),
78                    columns[1].data_type(),
79                    columns[2].data_type()
80                ),
81            }
82        );
83        ensure!(
84            columns[1].len() == 1 && columns[2].len() == 1,
85            InvalidFuncArgsSnafu {
86                err_msg: format!(
87                    "The second and third args should be scalar, have: {:?}, {:?}",
88                    columns[1], columns[2]
89                ),
90            }
91        );
92
93        with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| {
94            let input_array = columns[0].to_arrow_array();
95            let input = input_array
96                    .as_any()
97                    .downcast_ref::<PrimitiveArray<<$S as LogicalPrimitiveType>::ArrowPrimitive>>()
98                    .unwrap();
99
100            let min = TryAsPrimitive::<$S>::try_as_primitive(&columns[1].get(0))
101                .with_context(|| {
102                    InvalidFuncArgsSnafu {
103                        err_msg: "The second arg should not be none",
104                    }
105                })?;
106            let max = TryAsPrimitive::<$S>::try_as_primitive(&columns[2].get(0))
107                .with_context(|| {
108                    InvalidFuncArgsSnafu {
109                        err_msg: "The third arg should not be none",
110                    }
111                })?;
112
113            // ensure min <= max
114            ensure!(
115                min <= max,
116                    InvalidFuncArgsSnafu {
117                        err_msg: format!(
118                        "The second arg should be less than or equal to the third arg, have: {:?}, {:?}",
119                        columns[1], columns[2]
120                    ),
121                }
122            );
123
124            clamp_impl::<$S, true, true>(input, min, max)
125        },{
126            unreachable!()
127        })
128    }
129}
130
131impl Display for ClampFunction {
132    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
133        write!(f, "{}", CLAMP_NAME.to_ascii_uppercase())
134    }
135}
136
137fn clamp_impl<T: LogicalPrimitiveType, const CLAMP_MIN: bool, const CLAMP_MAX: bool>(
138    input: &PrimitiveArray<T::ArrowPrimitive>,
139    min: T::Native,
140    max: T::Native,
141) -> Result<VectorRef> {
142    let iter = ArrayIter::new(input);
143    let result = iter.map(|x| {
144        x.map(|x| {
145            if CLAMP_MIN && x < min {
146                min
147            } else if CLAMP_MAX && x > max {
148                max
149            } else {
150                x
151            }
152        })
153    });
154    let result = PrimitiveArray::<T::ArrowPrimitive>::from_iter(result);
155    Ok(Arc::new(PrimitiveVector::<T>::from(result)))
156}
157
#[cfg(test)]
mod test {

    use std::sync::Arc;

    use datatypes::prelude::ScalarVector;
    use datatypes::vectors::{
        ConstantVector, Float64Vector, Int64Vector, StringVector, UInt64Vector,
    };

    use super::*;
    use crate::function::FunctionContext;

    /// Evaluates `clamp` over `args` with a default function context.
    fn eval_clamp(args: &[VectorRef]) -> Result<VectorRef> {
        ClampFunction.eval(&FunctionContext::default(), args)
    }

    /// Asserts that evaluating `clamp` over `args` is rejected with an error.
    fn assert_clamp_fails(args: &[VectorRef]) {
        assert!(eval_clamp(args).is_err());
    }

    #[test]
    fn clamp_i64() {
        let inputs = [
            (
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
                -1,
                10,
                vec![Some(-1), Some(-1), Some(-1), Some(0), Some(1), Some(2)],
            ),
            (
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
                0,
                0,
                vec![Some(0), Some(0), Some(0), Some(0), Some(0), Some(0)],
            ),
            (
                vec![Some(-3), None, Some(-1), None, None, Some(2)],
                -2,
                1,
                vec![Some(-2), None, Some(-1), None, None, Some(1)],
            ),
            (
                vec![None, None, None, None, None],
                0,
                1,
                vec![None, None, None, None, None],
            ),
        ];

        for (in_data, min, max, expected) in inputs {
            let args = [
                Arc::new(Int64Vector::from(in_data)) as VectorRef,
                Arc::new(Int64Vector::from_vec(vec![min])) as VectorRef,
                Arc::new(Int64Vector::from_vec(vec![max])) as VectorRef,
            ];
            let result = eval_clamp(&args).unwrap();
            let expected: VectorRef = Arc::new(Int64Vector::from(expected));
            assert_eq!(expected, result);
        }
    }

    #[test]
    fn clamp_u64() {
        let inputs = [
            (
                vec![Some(0), Some(1), Some(2), Some(3), Some(4), Some(5)],
                1,
                3,
                vec![Some(1), Some(1), Some(2), Some(3), Some(3), Some(3)],
            ),
            (
                vec![Some(0), Some(1), Some(2), Some(3), Some(4), Some(5)],
                0,
                0,
                vec![Some(0), Some(0), Some(0), Some(0), Some(0), Some(0)],
            ),
            (
                vec![Some(0), None, Some(2), None, None, Some(5)],
                1,
                3,
                vec![Some(1), None, Some(2), None, None, Some(3)],
            ),
            (
                vec![None, None, None, None, None],
                0,
                1,
                vec![None, None, None, None, None],
            ),
        ];

        for (in_data, min, max, expected) in inputs {
            let args = [
                Arc::new(UInt64Vector::from(in_data)) as VectorRef,
                Arc::new(UInt64Vector::from_vec(vec![min])) as VectorRef,
                Arc::new(UInt64Vector::from_vec(vec![max])) as VectorRef,
            ];
            let result = eval_clamp(&args).unwrap();
            let expected: VectorRef = Arc::new(UInt64Vector::from(expected));
            assert_eq!(expected, result);
        }
    }

    #[test]
    fn clamp_f64() {
        let inputs = [
            (
                vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)],
                -1.0,
                10.0,
                vec![Some(-1.0), Some(-1.0), Some(-1.0), Some(0.0), Some(1.0)],
            ),
            (
                vec![Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)],
                0.0,
                0.0,
                vec![Some(0.0), Some(0.0), Some(0.0), Some(0.0)],
            ),
            (
                vec![Some(-3.0), None, Some(-1.0), None, None, Some(2.0)],
                -2.0,
                1.0,
                vec![Some(-2.0), None, Some(-1.0), None, None, Some(1.0)],
            ),
            (
                vec![None, None, None, None, None],
                0.0,
                1.0,
                vec![None, None, None, None, None],
            ),
        ];

        for (in_data, min, max, expected) in inputs {
            let args = [
                Arc::new(Float64Vector::from(in_data)) as VectorRef,
                Arc::new(Float64Vector::from_vec(vec![min])) as VectorRef,
                Arc::new(Float64Vector::from_vec(vec![max])) as VectorRef,
            ];
            let result = eval_clamp(&args).unwrap();
            let expected: VectorRef = Arc::new(Float64Vector::from(expected));
            assert_eq!(expected, result);
        }
    }

    // The input column is a `ConstantVector` wrapping an `Int64Vector`.
    #[test]
    fn clamp_const_i64() {
        let input = vec![Some(5)];
        let min = 2;
        let max = 4;

        let args = [
            Arc::new(ConstantVector::new(Arc::new(Int64Vector::from(input)), 1)) as VectorRef,
            Arc::new(Int64Vector::from_vec(vec![min])) as VectorRef,
            Arc::new(Int64Vector::from_vec(vec![max])) as VectorRef,
        ];
        let result = eval_clamp(&args).unwrap();
        let expected: VectorRef = Arc::new(Int64Vector::from(vec![Some(4)]));
        assert_eq!(expected, result);
    }

    #[test]
    fn clamp_invalid_min_max() {
        let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
        let min = 10.0;
        let max = -1.0;

        let args = [
            Arc::new(Float64Vector::from(input)) as VectorRef,
            Arc::new(Float64Vector::from_vec(vec![min])) as VectorRef,
            Arc::new(Float64Vector::from_vec(vec![max])) as VectorRef,
        ];
        assert_clamp_fails(&args);
    }

    #[test]
    fn clamp_type_not_match() {
        let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
        let min = -1;
        let max = 10;

        let args = [
            Arc::new(Float64Vector::from(input)) as VectorRef,
            Arc::new(Int64Vector::from_vec(vec![min])) as VectorRef,
            Arc::new(UInt64Vector::from_vec(vec![max])) as VectorRef,
        ];
        assert_clamp_fails(&args);
    }

    #[test]
    fn clamp_min_is_not_scalar() {
        let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
        let min = -10.0;
        let max = 1.0;

        let args = [
            Arc::new(Float64Vector::from(input)) as VectorRef,
            Arc::new(Float64Vector::from_vec(vec![min, min])) as VectorRef,
            Arc::new(Float64Vector::from_vec(vec![max])) as VectorRef,
        ];
        assert_clamp_fails(&args);
    }

    #[test]
    fn clamp_no_max() {
        let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
        let min = -10.0;

        let args = [
            Arc::new(Float64Vector::from(input)) as VectorRef,
            Arc::new(Float64Vector::from_vec(vec![min])) as VectorRef,
        ];
        assert_clamp_fails(&args);
    }

    // A null bound cannot be read as a primitive value and must be rejected.
    #[test]
    fn clamp_null_bound() {
        let null_min = [
            Arc::new(Int64Vector::from(vec![Some(1_i64), Some(2), Some(3)])) as VectorRef,
            Arc::new(Int64Vector::from(vec![None::<i64>])) as VectorRef,
            Arc::new(Int64Vector::from_vec(vec![2_i64])) as VectorRef,
        ];
        assert_clamp_fails(&null_min);

        let null_max = [
            Arc::new(Int64Vector::from(vec![Some(1_i64), Some(2), Some(3)])) as VectorRef,
            Arc::new(Int64Vector::from_vec(vec![2_i64])) as VectorRef,
            Arc::new(Int64Vector::from(vec![None::<i64>])) as VectorRef,
        ];
        assert_clamp_fails(&null_max);
    }

    #[test]
    fn clamp_on_string() {
        let input = vec![Some("foo"), Some("foo"), Some("foo"), Some("foo")];

        let args = [
            Arc::new(StringVector::from(input)) as VectorRef,
            Arc::new(StringVector::from_vec(vec!["bar"])) as VectorRef,
            Arc::new(StringVector::from_vec(vec!["baz"])) as VectorRef,
        ];
        assert_clamp_fails(&args);
    }
}