common_function/scalars/math/clamp.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::{self, Display};
16use std::sync::Arc;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use datafusion::arrow::array::{ArrayIter, PrimitiveArray};
20use datafusion::arrow::datatypes::DataType as ArrowDataType;
21use datafusion::logical_expr::Volatility;
22use datafusion_expr::Signature;
23use datafusion_expr::type_coercion::aggregates::NUMERICS;
24use datatypes::data_type::DataType;
25use datatypes::prelude::VectorRef;
26use datatypes::types::LogicalPrimitiveType;
27use datatypes::value::TryAsPrimitive;
28use datatypes::vectors::PrimitiveVector;
29use datatypes::with_match_primitive_type_id;
30use snafu::{OptionExt, ensure};
31
32use crate::function::{Function, FunctionContext};
33
/// `clamp(input, min, max)`: bounds every value of `input` to the range `[min, max]`.
///
/// `min` and `max` must have the same type as `input` and hold a single repeated value.
#[derive(Debug, Default, Clone)]
pub struct ClampFunction;

/// Function name returned by `ClampFunction::name`; `Display` prints it upper-cased.
const CLAMP_NAME: &str = "clamp";
38
39/// Ensure the vector is constant and not empty (i.e., all values are identical)
40fn ensure_constant_vector(vector: &VectorRef) -> Result<()> {
41    ensure!(
42        !vector.is_empty(),
43        InvalidFuncArgsSnafu {
44            err_msg: "Expect at least one value",
45        }
46    );
47
48    if vector.is_const() {
49        return Ok(());
50    }
51
52    let first = vector.get_ref(0);
53    for i in 1..vector.len() {
54        let v = vector.get_ref(i);
55        if first != v {
56            return InvalidFuncArgsSnafu {
57                err_msg: "All values in min/max argument must be identical",
58            }
59            .fail();
60        }
61    }
62
63    Ok(())
64}
65
66impl Function for ClampFunction {
67    fn name(&self) -> &str {
68        CLAMP_NAME
69    }
70
71    fn return_type(&self, input_types: &[ArrowDataType]) -> Result<ArrowDataType> {
72        // Type check is done by `signature`
73        Ok(input_types[0].clone())
74    }
75
76    fn signature(&self) -> Signature {
77        // input, min, max
78        Signature::uniform(3, NUMERICS.to_vec(), Volatility::Immutable)
79    }
80
81    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
82        ensure!(
83            columns.len() == 3,
84            InvalidFuncArgsSnafu {
85                err_msg: format!(
86                    "The length of the args is not correct, expect exactly 3, have: {}",
87                    columns.len()
88                ),
89            }
90        );
91        ensure!(
92            columns[0].data_type().is_numeric(),
93            InvalidFuncArgsSnafu {
94                err_msg: format!(
95                    "The first arg's type is not numeric, have: {}",
96                    columns[0].data_type()
97                ),
98            }
99        );
100        ensure!(
101            columns[0].data_type() == columns[1].data_type()
102                && columns[1].data_type() == columns[2].data_type(),
103            InvalidFuncArgsSnafu {
104                err_msg: format!(
105                    "Arguments don't have identical types: {}, {}, {}",
106                    columns[0].data_type(),
107                    columns[1].data_type(),
108                    columns[2].data_type()
109                ),
110            }
111        );
112
113        ensure_constant_vector(&columns[1])?;
114        ensure_constant_vector(&columns[2])?;
115
116        with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| {
117            let input_array = columns[0].to_arrow_array();
118            let input = input_array
119                    .as_any()
120                    .downcast_ref::<PrimitiveArray<<$S as LogicalPrimitiveType>::ArrowPrimitive>>()
121                    .unwrap();
122
123            let min = TryAsPrimitive::<$S>::try_as_primitive(&columns[1].get(0))
124                .with_context(|| {
125                    InvalidFuncArgsSnafu {
126                        err_msg: "The second arg should not be none",
127                    }
128                })?;
129            let max = TryAsPrimitive::<$S>::try_as_primitive(&columns[2].get(0))
130                .with_context(|| {
131                    InvalidFuncArgsSnafu {
132                        err_msg: "The third arg should not be none",
133                    }
134                })?;
135
136            // ensure min <= max
137            ensure!(
138                min <= max,
139                    InvalidFuncArgsSnafu {
140                        err_msg: format!(
141                        "The second arg should be less than or equal to the third arg, have: {:?}, {:?}",
142                        columns[1], columns[2]
143                    ),
144                }
145            );
146
147            clamp_impl::<$S, true, true>(input, min, max)
148        },{
149            unreachable!()
150        })
151    }
152}
153
154impl Display for ClampFunction {
155    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
156        write!(f, "{}", CLAMP_NAME.to_ascii_uppercase())
157    }
158}
159
160fn clamp_impl<T: LogicalPrimitiveType, const CLAMP_MIN: bool, const CLAMP_MAX: bool>(
161    input: &PrimitiveArray<T::ArrowPrimitive>,
162    min: T::Native,
163    max: T::Native,
164) -> Result<VectorRef> {
165    let iter = ArrayIter::new(input);
166    let result = iter.map(|x| {
167        x.map(|x| {
168            if CLAMP_MIN && x < min {
169                min
170            } else if CLAMP_MAX && x > max {
171                max
172            } else {
173                x
174            }
175        })
176    });
177    let result = PrimitiveArray::<T::ArrowPrimitive>::from_iter(result);
178    Ok(Arc::new(PrimitiveVector::<T>::from(result)))
179}
180
/// `clamp_min(input, min)`: raises every value of `input` that is below `min` to `min`.
///
/// `min` must have the same type as `input` and hold a single repeated value.
#[derive(Debug, Default, Clone)]
pub struct ClampMinFunction;

/// Function name returned by `ClampMinFunction::name`; `Display` prints it upper-cased.
const CLAMP_MIN_NAME: &str = "clamp_min";
185
186impl Function for ClampMinFunction {
187    fn name(&self) -> &str {
188        CLAMP_MIN_NAME
189    }
190
191    fn return_type(&self, input_types: &[ArrowDataType]) -> Result<ArrowDataType> {
192        Ok(input_types[0].clone())
193    }
194
195    fn signature(&self) -> Signature {
196        // input, min
197        Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
198    }
199
200    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
201        ensure!(
202            columns.len() == 2,
203            InvalidFuncArgsSnafu {
204                err_msg: format!(
205                    "The length of the args is not correct, expect exactly 2, have: {}",
206                    columns.len()
207                ),
208            }
209        );
210        ensure!(
211            columns[0].data_type().is_numeric(),
212            InvalidFuncArgsSnafu {
213                err_msg: format!(
214                    "The first arg's type is not numeric, have: {}",
215                    columns[0].data_type()
216                ),
217            }
218        );
219        ensure!(
220            columns[0].data_type() == columns[1].data_type(),
221            InvalidFuncArgsSnafu {
222                err_msg: format!(
223                    "Arguments don't have identical types: {}, {}",
224                    columns[0].data_type(),
225                    columns[1].data_type()
226                ),
227            }
228        );
229
230        ensure_constant_vector(&columns[1])?;
231
232        with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| {
233            let input_array = columns[0].to_arrow_array();
234            let input = input_array
235                .as_any()
236                .downcast_ref::<PrimitiveArray<<$S as LogicalPrimitiveType>::ArrowPrimitive>>()
237                .unwrap();
238
239            let min = TryAsPrimitive::<$S>::try_as_primitive(&columns[1].get(0))
240                .with_context(|| {
241                    InvalidFuncArgsSnafu {
242                        err_msg: "The second arg (min) should not be none",
243                    }
244                })?;
245            // For clamp_min, max is effectively infinity, so we don't use it in the clamp_impl logic.
246            // We pass a default/dummy value for max.
247            let max_dummy = <$S as LogicalPrimitiveType>::Native::default();
248
249            clamp_impl::<$S, true, false>(input, min, max_dummy)
250        },{
251            unreachable!()
252        })
253    }
254}
255
256impl Display for ClampMinFunction {
257    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
258        write!(f, "{}", CLAMP_MIN_NAME.to_ascii_uppercase())
259    }
260}
261
/// `clamp_max(input, max)`: lowers every value of `input` that is above `max` to `max`.
///
/// `max` must have the same type as `input` and hold a single repeated value.
#[derive(Debug, Default, Clone)]
pub struct ClampMaxFunction;

/// Function name returned by `ClampMaxFunction::name`; `Display` prints it upper-cased.
const CLAMP_MAX_NAME: &str = "clamp_max";
266
267impl Function for ClampMaxFunction {
268    fn name(&self) -> &str {
269        CLAMP_MAX_NAME
270    }
271
272    fn return_type(&self, input_types: &[ArrowDataType]) -> Result<ArrowDataType> {
273        Ok(input_types[0].clone())
274    }
275
276    fn signature(&self) -> Signature {
277        // input, max
278        Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
279    }
280
281    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
282        ensure!(
283            columns.len() == 2,
284            InvalidFuncArgsSnafu {
285                err_msg: format!(
286                    "The length of the args is not correct, expect exactly 2, have: {}",
287                    columns.len()
288                ),
289            }
290        );
291        ensure!(
292            columns[0].data_type().is_numeric(),
293            InvalidFuncArgsSnafu {
294                err_msg: format!(
295                    "The first arg's type is not numeric, have: {}",
296                    columns[0].data_type()
297                ),
298            }
299        );
300        ensure!(
301            columns[0].data_type() == columns[1].data_type(),
302            InvalidFuncArgsSnafu {
303                err_msg: format!(
304                    "Arguments don't have identical types: {}, {}",
305                    columns[0].data_type(),
306                    columns[1].data_type()
307                ),
308            }
309        );
310
311        ensure_constant_vector(&columns[1])?;
312
313        with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| {
314            let input_array = columns[0].to_arrow_array();
315            let input = input_array
316                .as_any()
317                .downcast_ref::<PrimitiveArray<<$S as LogicalPrimitiveType>::ArrowPrimitive>>()
318                .unwrap();
319
320            let max = TryAsPrimitive::<$S>::try_as_primitive(&columns[1].get(0))
321                .with_context(|| {
322                    InvalidFuncArgsSnafu {
323                        err_msg: "The second arg (max) should not be none",
324                    }
325                })?;
326            // For clamp_max, min is effectively -infinity, so we don't use it in the clamp_impl logic.
327            // We pass a default/dummy value for min.
328            let min_dummy = <$S as LogicalPrimitiveType>::Native::default();
329
330            clamp_impl::<$S, false, true>(input, min_dummy, max)
331        },{
332            unreachable!()
333        })
334    }
335}
336
337impl Display for ClampMaxFunction {
338    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
339        write!(f, "{}", CLAMP_MAX_NAME.to_ascii_uppercase())
340    }
341}
342
#[cfg(test)]
mod test {
    //! Unit tests for `clamp`, `clamp_min` and `clamp_max`.

    use std::sync::Arc;

    use datatypes::prelude::ScalarVector;
    use datatypes::vectors::{
        ConstantVector, Float64Vector, Int64Vector, StringVector, UInt64Vector,
    };

    use super::*;
    use crate::function::FunctionContext;

    #[test]
    fn clamp_i64() {
        let inputs = [
            (
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
                -1,
                10,
                vec![Some(-1), Some(-1), Some(-1), Some(0), Some(1), Some(2)],
            ),
            (
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
                0,
                0,
                vec![Some(0), Some(0), Some(0), Some(0), Some(0), Some(0)],
            ),
            (
                vec![Some(-3), None, Some(-1), None, None, Some(2)],
                -2,
                1,
                vec![Some(-2), None, Some(-1), None, None, Some(1)],
            ),
            (
                vec![None, None, None, None, None],
                0,
                1,
                vec![None, None, None, None, None],
            ),
        ];

        let func = ClampFunction;
        for (in_data, min, max, expected) in inputs {
            let args = [
                Arc::new(Int64Vector::from(in_data)) as _,
                Arc::new(Int64Vector::from_vec(vec![min])) as _,
                Arc::new(Int64Vector::from_vec(vec![max])) as _,
            ];
            let result = func
                .eval(&FunctionContext::default(), args.as_slice())
                .unwrap();
            let expected: VectorRef = Arc::new(Int64Vector::from(expected));
            assert_eq!(expected, result);
        }
    }

    #[test]
    fn clamp_u64() {
        let inputs = [
            (
                vec![Some(0), Some(1), Some(2), Some(3), Some(4), Some(5)],
                1,
                3,
                vec![Some(1), Some(1), Some(2), Some(3), Some(3), Some(3)],
            ),
            (
                vec![Some(0), Some(1), Some(2), Some(3), Some(4), Some(5)],
                0,
                0,
                vec![Some(0), Some(0), Some(0), Some(0), Some(0), Some(0)],
            ),
            (
                vec![Some(0), None, Some(2), None, None, Some(5)],
                1,
                3,
                vec![Some(1), None, Some(2), None, None, Some(3)],
            ),
            (
                vec![None, None, None, None, None],
                0,
                1,
                vec![None, None, None, None, None],
            ),
        ];

        let func = ClampFunction;
        for (in_data, min, max, expected) in inputs {
            let args = [
                Arc::new(UInt64Vector::from(in_data)) as _,
                Arc::new(UInt64Vector::from_vec(vec![min])) as _,
                Arc::new(UInt64Vector::from_vec(vec![max])) as _,
            ];
            let result = func
                .eval(&FunctionContext::default(), args.as_slice())
                .unwrap();
            let expected: VectorRef = Arc::new(UInt64Vector::from(expected));
            assert_eq!(expected, result);
        }
    }

    #[test]
    fn clamp_f64() {
        let inputs = [
            (
                vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)],
                -1.0,
                10.0,
                vec![Some(-1.0), Some(-1.0), Some(-1.0), Some(0.0), Some(1.0)],
            ),
            (
                vec![Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)],
                0.0,
                0.0,
                vec![Some(0.0), Some(0.0), Some(0.0), Some(0.0)],
            ),
            (
                vec![Some(-3.0), None, Some(-1.0), None, None, Some(2.0)],
                -2.0,
                1.0,
                vec![Some(-2.0), None, Some(-1.0), None, None, Some(1.0)],
            ),
            (
                vec![None, None, None, None, None],
                0.0,
                1.0,
                vec![None, None, None, None, None],
            ),
        ];

        let func = ClampFunction;
        for (in_data, min, max, expected) in inputs {
            let args = [
                Arc::new(Float64Vector::from(in_data)) as _,
                Arc::new(Float64Vector::from_vec(vec![min])) as _,
                Arc::new(Float64Vector::from_vec(vec![max])) as _,
            ];
            let result = func
                .eval(&FunctionContext::default(), args.as_slice())
                .unwrap();
            let expected: VectorRef = Arc::new(Float64Vector::from(expected));
            assert_eq!(expected, result);
        }
    }

    // Input is a constant (Int64) vector; the name previously said `i32`.
    #[test]
    fn clamp_const_i64() {
        let input = vec![Some(5)];
        let min = 2;
        let max = 4;

        let func = ClampFunction;
        let args = [
            Arc::new(ConstantVector::new(Arc::new(Int64Vector::from(input)), 1)) as _,
            Arc::new(Int64Vector::from_vec(vec![min])) as _,
            Arc::new(Int64Vector::from_vec(vec![max])) as _,
        ];
        let result = func
            .eval(&FunctionContext::default(), args.as_slice())
            .unwrap();
        let expected: VectorRef = Arc::new(Int64Vector::from(vec![Some(4)]));
        assert_eq!(expected, result);
    }

    #[test]
    fn clamp_invalid_min_max() {
        let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
        let min = 10.0;
        let max = -1.0;

        let func = ClampFunction;
        let args = [
            Arc::new(Float64Vector::from(input)) as _,
            Arc::new(Float64Vector::from_vec(vec![min])) as _,
            Arc::new(Float64Vector::from_vec(vec![max])) as _,
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        assert!(result.is_err());
    }

    #[test]
    fn clamp_type_not_match() {
        let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
        let min = -1;
        let max = 10;

        let func = ClampFunction;
        let args = [
            Arc::new(Float64Vector::from(input)) as _,
            Arc::new(Int64Vector::from_vec(vec![min])) as _,
            Arc::new(UInt64Vector::from_vec(vec![max])) as _,
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        assert!(result.is_err());
    }

    #[test]
    fn clamp_min_is_not_scalar() {
        let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
        let min = -10.0;
        let max = 1.0;

        let func = ClampFunction;
        let args = [
            Arc::new(Float64Vector::from(input)) as _,
            Arc::new(Float64Vector::from_vec(vec![min, max])) as _,
            Arc::new(Float64Vector::from_vec(vec![max, min])) as _,
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        assert!(result.is_err());
    }

    #[test]
    fn clamp_null_bound() {
        let input = vec![Some(-3), Some(0), Some(3)];

        let func = ClampFunction;
        let args = [
            Arc::new(Int64Vector::from(input)) as _,
            Arc::new(Int64Vector::from(vec![None::<i64>])) as _,
            Arc::new(Int64Vector::from_vec(vec![1])) as _,
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        assert!(result.is_err());
    }

    #[test]
    fn clamp_no_max() {
        let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
        let min = -10.0;

        let func = ClampFunction;
        let args = [
            Arc::new(Float64Vector::from(input)) as _,
            Arc::new(Float64Vector::from_vec(vec![min])) as _,
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        assert!(result.is_err());
    }

    #[test]
    fn clamp_on_string() {
        let input = vec![Some("foo"), Some("foo"), Some("foo"), Some("foo")];

        let func = ClampFunction;
        let args = [
            Arc::new(StringVector::from(input)) as _,
            Arc::new(StringVector::from_vec(vec!["bar"])) as _,
            Arc::new(StringVector::from_vec(vec!["baz"])) as _,
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        assert!(result.is_err());
    }

    #[test]
    fn clamp_min_i64() {
        let inputs = [
            (
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
                -1,
                vec![Some(-1), Some(-1), Some(-1), Some(0), Some(1), Some(2)],
            ),
            (
                vec![Some(-3), None, Some(-1), None, None, Some(2)],
                -2,
                vec![Some(-2), None, Some(-1), None, None, Some(2)],
            ),
        ];

        let func = ClampMinFunction;
        for (in_data, min, expected) in inputs {
            let args = [
                Arc::new(Int64Vector::from(in_data)) as _,
                Arc::new(Int64Vector::from_vec(vec![min])) as _,
            ];
            let result = func
                .eval(&FunctionContext::default(), args.as_slice())
                .unwrap();
            let expected: VectorRef = Arc::new(Int64Vector::from(expected));
            assert_eq!(expected, result);
        }
    }

    #[test]
    fn clamp_max_i64() {
        let inputs = [
            (
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
                1,
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(1)],
            ),
            (
                vec![Some(-3), None, Some(-1), None, None, Some(2)],
                0,
                vec![Some(-3), None, Some(-1), None, None, Some(0)],
            ),
        ];

        let func = ClampMaxFunction;
        for (in_data, max, expected) in inputs {
            let args = [
                Arc::new(Int64Vector::from(in_data)) as _,
                Arc::new(Int64Vector::from_vec(vec![max])) as _,
            ];
            let result = func
                .eval(&FunctionContext::default(), args.as_slice())
                .unwrap();
            let expected: VectorRef = Arc::new(Int64Vector::from(expected));
            assert_eq!(expected, result);
        }
    }

    #[test]
    fn clamp_max_u64() {
        let in_data = vec![Some(0), Some(5), None, Some(10)];
        let max = 5;

        let func = ClampMaxFunction;
        let args = [
            Arc::new(UInt64Vector::from(in_data)) as _,
            Arc::new(UInt64Vector::from_vec(vec![max])) as _,
        ];
        let result = func
            .eval(&FunctionContext::default(), args.as_slice())
            .unwrap();
        let expected: VectorRef =
            Arc::new(UInt64Vector::from(vec![Some(0), Some(5), None, Some(5)]));
        assert_eq!(expected, result);
    }

    #[test]
    fn clamp_min_f64() {
        let inputs = [(
            vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)],
            -1.0,
            vec![Some(-1.0), Some(-1.0), Some(-1.0), Some(0.0), Some(1.0)],
        )];

        let func = ClampMinFunction;
        for (in_data, min, expected) in inputs {
            let args = [
                Arc::new(Float64Vector::from(in_data)) as _,
                Arc::new(Float64Vector::from_vec(vec![min])) as _,
            ];
            let result = func
                .eval(&FunctionContext::default(), args.as_slice())
                .unwrap();
            let expected: VectorRef = Arc::new(Float64Vector::from(expected));
            assert_eq!(expected, result);
        }
    }

    #[test]
    fn clamp_max_f64() {
        let inputs = [(
            vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)],
            0.0,
            vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(0.0)],
        )];

        let func = ClampMaxFunction;
        for (in_data, max, expected) in inputs {
            let args = [
                Arc::new(Float64Vector::from(in_data)) as _,
                Arc::new(Float64Vector::from_vec(vec![max])) as _,
            ];
            let result = func
                .eval(&FunctionContext::default(), args.as_slice())
                .unwrap();
            let expected: VectorRef = Arc::new(Float64Vector::from(expected));
            assert_eq!(expected, result);
        }
    }

    #[test]
    fn clamp_min_type_not_match() {
        let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
        let min = -1;

        let func = ClampMinFunction;
        let args = [
            Arc::new(Float64Vector::from(input)) as _,
            Arc::new(Int64Vector::from_vec(vec![min])) as _,
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        assert!(result.is_err());
    }

    #[test]
    fn clamp_max_type_not_match() {
        let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
        let max = 1;

        let func = ClampMaxFunction;
        let args = [
            Arc::new(Float64Vector::from(input)) as _,
            Arc::new(Int64Vector::from_vec(vec![max])) as _,
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        assert!(result.is_err());
    }

    #[test]
    fn clamp_min_bound_is_not_scalar() {
        let input = vec![Some(-3.0), Some(-2.0), Some(-1.0)];

        let func = ClampMinFunction;
        let args = [
            Arc::new(Float64Vector::from(input)) as _,
            Arc::new(Float64Vector::from_vec(vec![-1.0, 0.0, 1.0])) as _,
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        assert!(result.is_err());
    }

    #[test]
    fn clamp_min_wrong_arg_count() {
        let input = vec![Some(-3), Some(0), Some(3)];

        let func = ClampMinFunction;
        let args = [
            Arc::new(Int64Vector::from(input)) as _,
            Arc::new(Int64Vector::from_vec(vec![0])) as _,
            Arc::new(Int64Vector::from_vec(vec![1])) as _,
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        assert!(result.is_err());
    }

    #[test]
    fn clamp_max_null_bound() {
        let input = vec![Some(1), Some(2)];

        let func = ClampMaxFunction;
        let args = [
            Arc::new(Int64Vector::from(input)) as _,
            Arc::new(Int64Vector::from(vec![None::<i64>])) as _,
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        assert!(result.is_err());
    }

    #[test]
    fn clamp_max_on_string() {
        let input = vec![Some("foo"), Some("bar")];

        let func = ClampMaxFunction;
        let args = [
            Arc::new(StringVector::from(input)) as _,
            Arc::new(StringVector::from_vec(vec!["baz"])) as _,
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        assert!(result.is_err());
    }
}