common_function/scalars/math/
clamp.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::{self, Display};
16use std::sync::Arc;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use common_query::prelude::Signature;
20use datafusion::arrow::array::{ArrayIter, PrimitiveArray};
21use datafusion::logical_expr::Volatility;
22use datatypes::data_type::{ConcreteDataType, DataType};
23use datatypes::prelude::VectorRef;
24use datatypes::types::LogicalPrimitiveType;
25use datatypes::value::TryAsPrimitive;
26use datatypes::vectors::PrimitiveVector;
27use datatypes::with_match_primitive_type_id;
28use snafu::{ensure, OptionExt};
29
30use crate::function::{Function, FunctionContext};
31
/// Scalar function `clamp(input, min, max)`.
///
/// Limits every non-null value of a numeric input column to the inclusive range
/// `[min, max]`; `min` and `max` must be scalars of the same type as the input.
#[derive(Clone, Debug, Default)]
pub struct ClampFunction;
34
/// Name reported by `ClampFunction::name`; its upper-cased form is used by `Display`.
const CLAMP_NAME: &str = "clamp";
36
37impl Function for ClampFunction {
38    fn name(&self) -> &str {
39        CLAMP_NAME
40    }
41
42    fn return_type(&self, input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
43        // Type check is done by `signature`
44        Ok(input_types[0].clone())
45    }
46
47    fn signature(&self) -> Signature {
48        // input, min, max
49        Signature::uniform(3, ConcreteDataType::numerics(), Volatility::Immutable)
50    }
51
52    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
53        ensure!(
54            columns.len() == 3,
55            InvalidFuncArgsSnafu {
56                err_msg: format!(
57                    "The length of the args is not correct, expect exactly 3, have: {}",
58                    columns.len()
59                ),
60            }
61        );
62        ensure!(
63            columns[0].data_type().is_numeric(),
64            InvalidFuncArgsSnafu {
65                err_msg: format!(
66                    "The first arg's type is not numeric, have: {}",
67                    columns[0].data_type()
68                ),
69            }
70        );
71        ensure!(
72            columns[0].data_type() == columns[1].data_type()
73                && columns[1].data_type() == columns[2].data_type(),
74            InvalidFuncArgsSnafu {
75                err_msg: format!(
76                    "Arguments don't have identical types: {}, {}, {}",
77                    columns[0].data_type(),
78                    columns[1].data_type(),
79                    columns[2].data_type()
80                ),
81            }
82        );
83        ensure!(
84            (columns[1].len() == 1 || columns[1].is_const())
85                && (columns[2].len() == 1 || columns[2].is_const()),
86            InvalidFuncArgsSnafu {
87                err_msg: format!(
88                    "The second and third args should be scalar, have: {:?}, {:?}",
89                    columns[1], columns[2]
90                ),
91            }
92        );
93
94        with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| {
95            let input_array = columns[0].to_arrow_array();
96            let input = input_array
97                    .as_any()
98                    .downcast_ref::<PrimitiveArray<<$S as LogicalPrimitiveType>::ArrowPrimitive>>()
99                    .unwrap();
100
101            let min = TryAsPrimitive::<$S>::try_as_primitive(&columns[1].get(0))
102                .with_context(|| {
103                    InvalidFuncArgsSnafu {
104                        err_msg: "The second arg should not be none",
105                    }
106                })?;
107            let max = TryAsPrimitive::<$S>::try_as_primitive(&columns[2].get(0))
108                .with_context(|| {
109                    InvalidFuncArgsSnafu {
110                        err_msg: "The third arg should not be none",
111                    }
112                })?;
113
114            // ensure min <= max
115            ensure!(
116                min <= max,
117                    InvalidFuncArgsSnafu {
118                        err_msg: format!(
119                        "The second arg should be less than or equal to the third arg, have: {:?}, {:?}",
120                        columns[1], columns[2]
121                    ),
122                }
123            );
124
125            clamp_impl::<$S, true, true>(input, min, max)
126        },{
127            unreachable!()
128        })
129    }
130}
131
132impl Display for ClampFunction {
133    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
134        write!(f, "{}", CLAMP_NAME.to_ascii_uppercase())
135    }
136}
137
138fn clamp_impl<T: LogicalPrimitiveType, const CLAMP_MIN: bool, const CLAMP_MAX: bool>(
139    input: &PrimitiveArray<T::ArrowPrimitive>,
140    min: T::Native,
141    max: T::Native,
142) -> Result<VectorRef> {
143    let iter = ArrayIter::new(input);
144    let result = iter.map(|x| {
145        x.map(|x| {
146            if CLAMP_MIN && x < min {
147                min
148            } else if CLAMP_MAX && x > max {
149                max
150            } else {
151                x
152            }
153        })
154    });
155    let result = PrimitiveArray::<T::ArrowPrimitive>::from_iter(result);
156    Ok(Arc::new(PrimitiveVector::<T>::from(result)))
157}
158
/// Scalar function `clamp_min(input, min)`.
///
/// Raises every non-null value of a numeric input column that is below `min` up to `min`;
/// `min` must be a scalar of the same type as the input.
#[derive(Clone, Debug, Default)]
pub struct ClampMinFunction;
161
/// Name reported by `ClampMinFunction::name`; its upper-cased form is used by `Display`.
const CLAMP_MIN_NAME: &str = "clamp_min";
163
164impl Function for ClampMinFunction {
165    fn name(&self) -> &str {
166        CLAMP_MIN_NAME
167    }
168
169    fn return_type(&self, input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
170        Ok(input_types[0].clone())
171    }
172
173    fn signature(&self) -> Signature {
174        // input, min
175        Signature::uniform(2, ConcreteDataType::numerics(), Volatility::Immutable)
176    }
177
178    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
179        ensure!(
180            columns.len() == 2,
181            InvalidFuncArgsSnafu {
182                err_msg: format!(
183                    "The length of the args is not correct, expect exactly 2, have: {}",
184                    columns.len()
185                ),
186            }
187        );
188        ensure!(
189            columns[0].data_type().is_numeric(),
190            InvalidFuncArgsSnafu {
191                err_msg: format!(
192                    "The first arg's type is not numeric, have: {}",
193                    columns[0].data_type()
194                ),
195            }
196        );
197        ensure!(
198            columns[0].data_type() == columns[1].data_type(),
199            InvalidFuncArgsSnafu {
200                err_msg: format!(
201                    "Arguments don't have identical types: {}, {}",
202                    columns[0].data_type(),
203                    columns[1].data_type()
204                ),
205            }
206        );
207        ensure!(
208            columns[1].len() == 1 || columns[1].is_const(),
209            InvalidFuncArgsSnafu {
210                err_msg: format!(
211                    "The second arg (min) should be scalar, have: {:?}",
212                    columns[1]
213                ),
214            }
215        );
216
217        with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| {
218            let input_array = columns[0].to_arrow_array();
219            let input = input_array
220                .as_any()
221                .downcast_ref::<PrimitiveArray<<$S as LogicalPrimitiveType>::ArrowPrimitive>>()
222                .unwrap();
223
224            let min = TryAsPrimitive::<$S>::try_as_primitive(&columns[1].get(0))
225                .with_context(|| {
226                    InvalidFuncArgsSnafu {
227                        err_msg: "The second arg (min) should not be none",
228                    }
229                })?;
230            // For clamp_min, max is effectively infinity, so we don't use it in the clamp_impl logic.
231            // We pass a default/dummy value for max.
232            let max_dummy = <$S as LogicalPrimitiveType>::Native::default();
233
234            clamp_impl::<$S, true, false>(input, min, max_dummy)
235        },{
236            unreachable!()
237        })
238    }
239}
240
241impl Display for ClampMinFunction {
242    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
243        write!(f, "{}", CLAMP_MIN_NAME.to_ascii_uppercase())
244    }
245}
246
/// Scalar function `clamp_max(input, max)`.
///
/// Lowers every non-null value of a numeric input column that is above `max` down to `max`;
/// `max` must be a scalar of the same type as the input.
#[derive(Clone, Debug, Default)]
pub struct ClampMaxFunction;
249
/// Name reported by `ClampMaxFunction::name`; its upper-cased form is used by `Display`.
const CLAMP_MAX_NAME: &str = "clamp_max";
251
252impl Function for ClampMaxFunction {
253    fn name(&self) -> &str {
254        CLAMP_MAX_NAME
255    }
256
257    fn return_type(&self, input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
258        Ok(input_types[0].clone())
259    }
260
261    fn signature(&self) -> Signature {
262        // input, max
263        Signature::uniform(2, ConcreteDataType::numerics(), Volatility::Immutable)
264    }
265
266    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
267        ensure!(
268            columns.len() == 2,
269            InvalidFuncArgsSnafu {
270                err_msg: format!(
271                    "The length of the args is not correct, expect exactly 2, have: {}",
272                    columns.len()
273                ),
274            }
275        );
276        ensure!(
277            columns[0].data_type().is_numeric(),
278            InvalidFuncArgsSnafu {
279                err_msg: format!(
280                    "The first arg's type is not numeric, have: {}",
281                    columns[0].data_type()
282                ),
283            }
284        );
285        ensure!(
286            columns[0].data_type() == columns[1].data_type(),
287            InvalidFuncArgsSnafu {
288                err_msg: format!(
289                    "Arguments don't have identical types: {}, {}",
290                    columns[0].data_type(),
291                    columns[1].data_type()
292                ),
293            }
294        );
295        ensure!(
296            columns[1].len() == 1 || columns[1].is_const(),
297            InvalidFuncArgsSnafu {
298                err_msg: format!(
299                    "The second arg (max) should be scalar, have: {:?}",
300                    columns[1]
301                ),
302            }
303        );
304
305        with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| {
306            let input_array = columns[0].to_arrow_array();
307            let input = input_array
308                .as_any()
309                .downcast_ref::<PrimitiveArray<<$S as LogicalPrimitiveType>::ArrowPrimitive>>()
310                .unwrap();
311
312            let max = TryAsPrimitive::<$S>::try_as_primitive(&columns[1].get(0))
313                .with_context(|| {
314                    InvalidFuncArgsSnafu {
315                        err_msg: "The second arg (max) should not be none",
316                    }
317                })?;
318            // For clamp_max, min is effectively -infinity, so we don't use it in the clamp_impl logic.
319            // We pass a default/dummy value for min.
320            let min_dummy = <$S as LogicalPrimitiveType>::Native::default();
321
322            clamp_impl::<$S, false, true>(input, min_dummy, max)
323        },{
324            unreachable!()
325        })
326    }
327}
328
329impl Display for ClampMaxFunction {
330    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
331        write!(f, "{}", CLAMP_MAX_NAME.to_ascii_uppercase())
332    }
333}
334
#[cfg(test)]
mod test {

    use std::sync::Arc;

    use datatypes::prelude::ScalarVector;
    use datatypes::vectors::{
        ConstantVector, Float64Vector, Int64Vector, StringVector, UInt64Vector,
    };

    use super::*;
    use crate::function::FunctionContext;

    /// Evaluates `func` on `args` using a default [`FunctionContext`].
    fn evaluate<F: Function>(func: &F, args: &[VectorRef]) -> Result<VectorRef> {
        func.eval(&FunctionContext::default(), args)
    }

    /// Builds a nullable `i64` column.
    fn i64_column(values: Vec<Option<i64>>) -> VectorRef {
        Arc::new(Int64Vector::from(values))
    }

    /// Builds a one-row `i64` column, usable as a scalar bound.
    fn i64_scalar(value: i64) -> VectorRef {
        Arc::new(Int64Vector::from_vec(vec![value]))
    }

    /// Builds a nullable `u64` column.
    fn u64_column(values: Vec<Option<u64>>) -> VectorRef {
        Arc::new(UInt64Vector::from(values))
    }

    /// Builds a one-row `u64` column, usable as a scalar bound.
    fn u64_scalar(value: u64) -> VectorRef {
        Arc::new(UInt64Vector::from_vec(vec![value]))
    }

    /// Builds a nullable `f64` column.
    fn f64_column(values: Vec<Option<f64>>) -> VectorRef {
        Arc::new(Float64Vector::from(values))
    }

    /// Builds a one-row `f64` column, usable as a scalar bound.
    fn f64_scalar(value: f64) -> VectorRef {
        Arc::new(Float64Vector::from_vec(vec![value]))
    }

    /// The `f64` input shared by most of the error-path tests.
    fn sample_f64_input() -> VectorRef {
        f64_column(vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)])
    }

    #[test]
    fn clamp_i64() {
        let cases = [
            (
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
                -1,
                10,
                vec![Some(-1), Some(-1), Some(-1), Some(0), Some(1), Some(2)],
            ),
            (
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
                0,
                0,
                vec![Some(0), Some(0), Some(0), Some(0), Some(0), Some(0)],
            ),
            (
                vec![Some(-3), None, Some(-1), None, None, Some(2)],
                -2,
                1,
                vec![Some(-2), None, Some(-1), None, None, Some(1)],
            ),
            (
                vec![None, None, None, None, None],
                0,
                1,
                vec![None, None, None, None, None],
            ),
        ];

        for (data, lower, upper, want) in cases {
            let args = [i64_column(data), i64_scalar(lower), i64_scalar(upper)];
            let got = evaluate(&ClampFunction, &args).unwrap();
            assert_eq!(i64_column(want), got);
        }
    }

    #[test]
    fn clamp_u64() {
        let cases = [
            (
                vec![Some(0), Some(1), Some(2), Some(3), Some(4), Some(5)],
                1,
                3,
                vec![Some(1), Some(1), Some(2), Some(3), Some(3), Some(3)],
            ),
            (
                vec![Some(0), Some(1), Some(2), Some(3), Some(4), Some(5)],
                0,
                0,
                vec![Some(0), Some(0), Some(0), Some(0), Some(0), Some(0)],
            ),
            (
                vec![Some(0), None, Some(2), None, None, Some(5)],
                1,
                3,
                vec![Some(1), None, Some(2), None, None, Some(3)],
            ),
            (
                vec![None, None, None, None, None],
                0,
                1,
                vec![None, None, None, None, None],
            ),
        ];

        for (data, lower, upper, want) in cases {
            let args = [u64_column(data), u64_scalar(lower), u64_scalar(upper)];
            let got = evaluate(&ClampFunction, &args).unwrap();
            assert_eq!(u64_column(want), got);
        }
    }

    #[test]
    fn clamp_f64() {
        let cases = [
            (
                vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)],
                -1.0,
                10.0,
                vec![Some(-1.0), Some(-1.0), Some(-1.0), Some(0.0), Some(1.0)],
            ),
            (
                vec![Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)],
                0.0,
                0.0,
                vec![Some(0.0), Some(0.0), Some(0.0), Some(0.0)],
            ),
            (
                vec![Some(-3.0), None, Some(-1.0), None, None, Some(2.0)],
                -2.0,
                1.0,
                vec![Some(-2.0), None, Some(-1.0), None, None, Some(1.0)],
            ),
            (
                vec![None, None, None, None, None],
                0.0,
                1.0,
                vec![None, None, None, None, None],
            ),
        ];

        for (data, lower, upper, want) in cases {
            let args = [f64_column(data), f64_scalar(lower), f64_scalar(upper)];
            let got = evaluate(&ClampFunction, &args).unwrap();
            assert_eq!(f64_column(want), got);
        }
    }

    #[test]
    fn clamp_const_i32() {
        // The input itself is a constant vector wrapping a single value.
        let constant_input: VectorRef =
            Arc::new(ConstantVector::new(i64_column(vec![Some(5)]), 1));
        let args = [constant_input, i64_scalar(2), i64_scalar(4)];

        let got = evaluate(&ClampFunction, &args).unwrap();
        assert_eq!(i64_column(vec![Some(4)]), got);
    }

    #[test]
    fn clamp_invalid_min_max() {
        // min (10.0) is greater than max (-1.0).
        let args = [sample_f64_input(), f64_scalar(10.0), f64_scalar(-1.0)];
        assert!(evaluate(&ClampFunction, &args).is_err());
    }

    #[test]
    fn clamp_type_not_match() {
        let args = [sample_f64_input(), i64_scalar(-1), u64_scalar(10)];
        assert!(evaluate(&ClampFunction, &args).is_err());
    }

    #[test]
    fn clamp_min_is_not_scalar() {
        let two_row_min: VectorRef = Arc::new(Float64Vector::from_vec(vec![-10.0, -10.0]));
        let args = [sample_f64_input(), two_row_min, f64_scalar(1.0)];
        assert!(evaluate(&ClampFunction, &args).is_err());
    }

    #[test]
    fn clamp_no_max() {
        // Only two of the three required arguments are supplied.
        let args = [sample_f64_input(), f64_scalar(-10.0)];
        assert!(evaluate(&ClampFunction, &args).is_err());
    }

    #[test]
    fn clamp_on_string() {
        let args = [
            Arc::new(StringVector::from(vec![
                Some("foo"),
                Some("foo"),
                Some("foo"),
                Some("foo"),
            ])) as VectorRef,
            Arc::new(StringVector::from_vec(vec!["bar"])) as VectorRef,
            Arc::new(StringVector::from_vec(vec!["baz"])) as VectorRef,
        ];
        assert!(evaluate(&ClampFunction, &args).is_err());
    }

    #[test]
    fn clamp_min_i64() {
        let cases = [
            (
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
                -1,
                vec![Some(-1), Some(-1), Some(-1), Some(0), Some(1), Some(2)],
            ),
            (
                vec![Some(-3), None, Some(-1), None, None, Some(2)],
                -2,
                vec![Some(-2), None, Some(-1), None, None, Some(2)],
            ),
        ];

        for (data, lower, want) in cases {
            let args = [i64_column(data), i64_scalar(lower)];
            let got = evaluate(&ClampMinFunction, &args).unwrap();
            assert_eq!(i64_column(want), got);
        }
    }

    #[test]
    fn clamp_max_i64() {
        let cases = [
            (
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
                1,
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(1)],
            ),
            (
                vec![Some(-3), None, Some(-1), None, None, Some(2)],
                0,
                vec![Some(-3), None, Some(-1), None, None, Some(0)],
            ),
        ];

        for (data, upper, want) in cases {
            let args = [i64_column(data), i64_scalar(upper)];
            let got = evaluate(&ClampMaxFunction, &args).unwrap();
            assert_eq!(i64_column(want), got);
        }
    }

    #[test]
    fn clamp_min_f64() {
        let cases = [(
            vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)],
            -1.0,
            vec![Some(-1.0), Some(-1.0), Some(-1.0), Some(0.0), Some(1.0)],
        )];

        for (data, lower, want) in cases {
            let args = [f64_column(data), f64_scalar(lower)];
            let got = evaluate(&ClampMinFunction, &args).unwrap();
            assert_eq!(f64_column(want), got);
        }
    }

    #[test]
    fn clamp_max_f64() {
        let cases = [(
            vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)],
            0.0,
            vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(0.0)],
        )];

        for (data, upper, want) in cases {
            let args = [f64_column(data), f64_scalar(upper)];
            let got = evaluate(&ClampMaxFunction, &args).unwrap();
            assert_eq!(f64_column(want), got);
        }
    }

    #[test]
    fn clamp_min_type_not_match() {
        let args = [sample_f64_input(), i64_scalar(-1)];
        assert!(evaluate(&ClampMinFunction, &args).is_err());
    }

    #[test]
    fn clamp_max_type_not_match() {
        let args = [sample_f64_input(), i64_scalar(1)];
        assert!(evaluate(&ClampMaxFunction, &args).is_err());
    }
}
702        assert!(result.is_err());
703    }
704}