common_function/scalars/math/
clamp.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::fmt::{self, Display};
16use std::sync::Arc;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use common_query::prelude::Signature;
20use datafusion::arrow::array::{ArrayIter, PrimitiveArray};
21use datafusion::logical_expr::Volatility;
22use datatypes::data_type::{ConcreteDataType, DataType};
23use datatypes::prelude::VectorRef;
24use datatypes::types::LogicalPrimitiveType;
25use datatypes::value::TryAsPrimitive;
26use datatypes::vectors::PrimitiveVector;
27use datatypes::with_match_primitive_type_id;
28use snafu::{ensure, OptionExt};
29
30use crate::function::{Function, FunctionContext};
31
/// Scalar function `clamp(input, min, max)`: limits every value of `input`
/// to the closed range `[min, max]`.
#[derive(Debug, Default, Clone)]
pub struct ClampFunction;

/// SQL-facing name of [`ClampFunction`].
const CLAMP_NAME: &str = "clamp";
36
37impl Function for ClampFunction {
38    fn name(&self) -> &str {
39        CLAMP_NAME
40    }
41
42    fn return_type(&self, input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
43        // Type check is done by `signature`
44        Ok(input_types[0].clone())
45    }
46
47    fn signature(&self) -> Signature {
48        // input, min, max
49        Signature::uniform(3, ConcreteDataType::numerics(), Volatility::Immutable)
50    }
51
52    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
53        ensure!(
54            columns.len() == 3,
55            InvalidFuncArgsSnafu {
56                err_msg: format!(
57                    "The length of the args is not correct, expect exactly 3, have: {}",
58                    columns.len()
59                ),
60            }
61        );
62        ensure!(
63            columns[0].data_type().is_numeric(),
64            InvalidFuncArgsSnafu {
65                err_msg: format!(
66                    "The first arg's type is not numeric, have: {}",
67                    columns[0].data_type()
68                ),
69            }
70        );
71        ensure!(
72            columns[0].data_type() == columns[1].data_type()
73                && columns[1].data_type() == columns[2].data_type(),
74            InvalidFuncArgsSnafu {
75                err_msg: format!(
76                    "Arguments don't have identical types: {}, {}, {}",
77                    columns[0].data_type(),
78                    columns[1].data_type(),
79                    columns[2].data_type()
80                ),
81            }
82        );
83        ensure!(
84            columns[1].len() == 1 && columns[2].len() == 1,
85            InvalidFuncArgsSnafu {
86                err_msg: format!(
87                    "The second and third args should be scalar, have: {:?}, {:?}",
88                    columns[1], columns[2]
89                ),
90            }
91        );
92
93        with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| {
94            let input_array = columns[0].to_arrow_array();
95            let input = input_array
96                    .as_any()
97                    .downcast_ref::<PrimitiveArray<<$S as LogicalPrimitiveType>::ArrowPrimitive>>()
98                    .unwrap();
99
100            let min = TryAsPrimitive::<$S>::try_as_primitive(&columns[1].get(0))
101                .with_context(|| {
102                    InvalidFuncArgsSnafu {
103                        err_msg: "The second arg should not be none",
104                    }
105                })?;
106            let max = TryAsPrimitive::<$S>::try_as_primitive(&columns[2].get(0))
107                .with_context(|| {
108                    InvalidFuncArgsSnafu {
109                        err_msg: "The third arg should not be none",
110                    }
111                })?;
112
113            // ensure min <= max
114            ensure!(
115                min <= max,
116                    InvalidFuncArgsSnafu {
117                        err_msg: format!(
118                        "The second arg should be less than or equal to the third arg, have: {:?}, {:?}",
119                        columns[1], columns[2]
120                    ),
121                }
122            );
123
124            clamp_impl::<$S, true, true>(input, min, max)
125        },{
126            unreachable!()
127        })
128    }
129}
130
131impl Display for ClampFunction {
132    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
133        write!(f, "{}", CLAMP_NAME.to_ascii_uppercase())
134    }
135}
136
137fn clamp_impl<T: LogicalPrimitiveType, const CLAMP_MIN: bool, const CLAMP_MAX: bool>(
138    input: &PrimitiveArray<T::ArrowPrimitive>,
139    min: T::Native,
140    max: T::Native,
141) -> Result<VectorRef> {
142    let iter = ArrayIter::new(input);
143    let result = iter.map(|x| {
144        x.map(|x| {
145            if CLAMP_MIN && x < min {
146                min
147            } else if CLAMP_MAX && x > max {
148                max
149            } else {
150                x
151            }
152        })
153    });
154    let result = PrimitiveArray::<T::ArrowPrimitive>::from_iter(result);
155    Ok(Arc::new(PrimitiveVector::<T>::from(result)))
156}
157
/// Scalar function `clamp_min(input, min)`: raises every value of `input`
/// that is below `min` up to `min`.
#[derive(Debug, Default, Clone)]
pub struct ClampMinFunction;

/// SQL-facing name of [`ClampMinFunction`].
const CLAMP_MIN_NAME: &str = "clamp_min";
162
163impl Function for ClampMinFunction {
164    fn name(&self) -> &str {
165        CLAMP_MIN_NAME
166    }
167
168    fn return_type(&self, input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
169        Ok(input_types[0].clone())
170    }
171
172    fn signature(&self) -> Signature {
173        // input, min
174        Signature::uniform(2, ConcreteDataType::numerics(), Volatility::Immutable)
175    }
176
177    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
178        ensure!(
179            columns.len() == 2,
180            InvalidFuncArgsSnafu {
181                err_msg: format!(
182                    "The length of the args is not correct, expect exactly 2, have: {}",
183                    columns.len()
184                ),
185            }
186        );
187        ensure!(
188            columns[0].data_type().is_numeric(),
189            InvalidFuncArgsSnafu {
190                err_msg: format!(
191                    "The first arg's type is not numeric, have: {}",
192                    columns[0].data_type()
193                ),
194            }
195        );
196        ensure!(
197            columns[0].data_type() == columns[1].data_type(),
198            InvalidFuncArgsSnafu {
199                err_msg: format!(
200                    "Arguments don't have identical types: {}, {}",
201                    columns[0].data_type(),
202                    columns[1].data_type()
203                ),
204            }
205        );
206        ensure!(
207            columns[1].len() == 1,
208            InvalidFuncArgsSnafu {
209                err_msg: format!(
210                    "The second arg (min) should be scalar, have: {:?}",
211                    columns[1]
212                ),
213            }
214        );
215
216        with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| {
217            let input_array = columns[0].to_arrow_array();
218            let input = input_array
219                .as_any()
220                .downcast_ref::<PrimitiveArray<<$S as LogicalPrimitiveType>::ArrowPrimitive>>()
221                .unwrap();
222
223            let min = TryAsPrimitive::<$S>::try_as_primitive(&columns[1].get(0))
224                .with_context(|| {
225                    InvalidFuncArgsSnafu {
226                        err_msg: "The second arg (min) should not be none",
227                    }
228                })?;
229            // For clamp_min, max is effectively infinity, so we don't use it in the clamp_impl logic.
230            // We pass a default/dummy value for max.
231            let max_dummy = <$S as LogicalPrimitiveType>::Native::default();
232
233            clamp_impl::<$S, true, false>(input, min, max_dummy)
234        },{
235            unreachable!()
236        })
237    }
238}
239
240impl Display for ClampMinFunction {
241    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
242        write!(f, "{}", CLAMP_MIN_NAME.to_ascii_uppercase())
243    }
244}
245
/// Scalar function `clamp_max(input, max)`: lowers every value of `input`
/// that is above `max` down to `max`.
#[derive(Debug, Default, Clone)]
pub struct ClampMaxFunction;

/// SQL-facing name of [`ClampMaxFunction`].
const CLAMP_MAX_NAME: &str = "clamp_max";
250
251impl Function for ClampMaxFunction {
252    fn name(&self) -> &str {
253        CLAMP_MAX_NAME
254    }
255
256    fn return_type(&self, input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
257        Ok(input_types[0].clone())
258    }
259
260    fn signature(&self) -> Signature {
261        // input, max
262        Signature::uniform(2, ConcreteDataType::numerics(), Volatility::Immutable)
263    }
264
265    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
266        ensure!(
267            columns.len() == 2,
268            InvalidFuncArgsSnafu {
269                err_msg: format!(
270                    "The length of the args is not correct, expect exactly 2, have: {}",
271                    columns.len()
272                ),
273            }
274        );
275        ensure!(
276            columns[0].data_type().is_numeric(),
277            InvalidFuncArgsSnafu {
278                err_msg: format!(
279                    "The first arg's type is not numeric, have: {}",
280                    columns[0].data_type()
281                ),
282            }
283        );
284        ensure!(
285            columns[0].data_type() == columns[1].data_type(),
286            InvalidFuncArgsSnafu {
287                err_msg: format!(
288                    "Arguments don't have identical types: {}, {}",
289                    columns[0].data_type(),
290                    columns[1].data_type()
291                ),
292            }
293        );
294        ensure!(
295            columns[1].len() == 1,
296            InvalidFuncArgsSnafu {
297                err_msg: format!(
298                    "The second arg (max) should be scalar, have: {:?}",
299                    columns[1]
300                ),
301            }
302        );
303
304        with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| {
305            let input_array = columns[0].to_arrow_array();
306            let input = input_array
307                .as_any()
308                .downcast_ref::<PrimitiveArray<<$S as LogicalPrimitiveType>::ArrowPrimitive>>()
309                .unwrap();
310
311            let max = TryAsPrimitive::<$S>::try_as_primitive(&columns[1].get(0))
312                .with_context(|| {
313                    InvalidFuncArgsSnafu {
314                        err_msg: "The second arg (max) should not be none",
315                    }
316                })?;
317            // For clamp_max, min is effectively -infinity, so we don't use it in the clamp_impl logic.
318            // We pass a default/dummy value for min.
319            let min_dummy = <$S as LogicalPrimitiveType>::Native::default();
320
321            clamp_impl::<$S, false, true>(input, min_dummy, max)
322        },{
323            unreachable!()
324        })
325    }
326}
327
328impl Display for ClampMaxFunction {
329    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
330        write!(f, "{}", CLAMP_MAX_NAME.to_ascii_uppercase())
331    }
332}
333
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use datatypes::prelude::ScalarVector;
    use datatypes::vectors::{
        ConstantVector, Float64Vector, Int64Vector, StringVector, UInt64Vector,
    };

    use super::*;
    use crate::function::FunctionContext;

    #[test]
    fn clamp_i64() {
        let inputs = [
            (
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
                -1,
                10,
                vec![Some(-1), Some(-1), Some(-1), Some(0), Some(1), Some(2)],
            ),
            (
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
                0,
                0,
                vec![Some(0), Some(0), Some(0), Some(0), Some(0), Some(0)],
            ),
            (
                vec![Some(-3), None, Some(-1), None, None, Some(2)],
                -2,
                1,
                vec![Some(-2), None, Some(-1), None, None, Some(1)],
            ),
            (
                vec![None, None, None, None, None],
                0,
                1,
                vec![None, None, None, None, None],
            ),
        ];

        let func = ClampFunction;
        for (in_data, min, max, expected) in inputs {
            let args = [
                Arc::new(Int64Vector::from(in_data)) as _,
                Arc::new(Int64Vector::from_vec(vec![min])) as _,
                Arc::new(Int64Vector::from_vec(vec![max])) as _,
            ];
            let result = func
                .eval(&FunctionContext::default(), args.as_slice())
                .unwrap();
            let expected: VectorRef = Arc::new(Int64Vector::from(expected));
            assert_eq!(expected, result);
        }
    }

    #[test]
    fn clamp_u64() {
        let inputs = [
            (
                vec![Some(0), Some(1), Some(2), Some(3), Some(4), Some(5)],
                1,
                3,
                vec![Some(1), Some(1), Some(2), Some(3), Some(3), Some(3)],
            ),
            (
                vec![Some(0), Some(1), Some(2), Some(3), Some(4), Some(5)],
                0,
                0,
                vec![Some(0), Some(0), Some(0), Some(0), Some(0), Some(0)],
            ),
            (
                vec![Some(0), None, Some(2), None, None, Some(5)],
                1,
                3,
                vec![Some(1), None, Some(2), None, None, Some(3)],
            ),
            (
                vec![None, None, None, None, None],
                0,
                1,
                vec![None, None, None, None, None],
            ),
        ];

        let func = ClampFunction;
        for (in_data, min, max, expected) in inputs {
            let args = [
                Arc::new(UInt64Vector::from(in_data)) as _,
                Arc::new(UInt64Vector::from_vec(vec![min])) as _,
                Arc::new(UInt64Vector::from_vec(vec![max])) as _,
            ];
            let result = func
                .eval(&FunctionContext::default(), args.as_slice())
                .unwrap();
            let expected: VectorRef = Arc::new(UInt64Vector::from(expected));
            assert_eq!(expected, result);
        }
    }

    #[test]
    fn clamp_f64() {
        let inputs = [
            (
                vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)],
                -1.0,
                10.0,
                vec![Some(-1.0), Some(-1.0), Some(-1.0), Some(0.0), Some(1.0)],
            ),
            (
                vec![Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)],
                0.0,
                0.0,
                vec![Some(0.0), Some(0.0), Some(0.0), Some(0.0)],
            ),
            (
                vec![Some(-3.0), None, Some(-1.0), None, None, Some(2.0)],
                -2.0,
                1.0,
                vec![Some(-2.0), None, Some(-1.0), None, None, Some(1.0)],
            ),
            (
                vec![None, None, None, None, None],
                0.0,
                1.0,
                vec![None, None, None, None, None],
            ),
        ];

        let func = ClampFunction;
        for (in_data, min, max, expected) in inputs {
            let args = [
                Arc::new(Float64Vector::from(in_data)) as _,
                Arc::new(Float64Vector::from_vec(vec![min])) as _,
                Arc::new(Float64Vector::from_vec(vec![max])) as _,
            ];
            let result = func
                .eval(&FunctionContext::default(), args.as_slice())
                .unwrap();
            let expected: VectorRef = Arc::new(Float64Vector::from(expected));
            assert_eq!(expected, result);
        }
    }

    // The input is an `Int64Vector` wrapped in a `ConstantVector`, hence `i64`.
    #[test]
    fn clamp_const_i64() {
        let input = vec![Some(5)];
        let min = 2;
        let max = 4;

        let func = ClampFunction;
        let args = [
            Arc::new(ConstantVector::new(Arc::new(Int64Vector::from(input)), 1)) as _,
            Arc::new(Int64Vector::from_vec(vec![min])) as _,
            Arc::new(Int64Vector::from_vec(vec![max])) as _,
        ];
        let result = func
            .eval(&FunctionContext::default(), args.as_slice())
            .unwrap();
        let expected: VectorRef = Arc::new(Int64Vector::from(vec![Some(4)]));
        assert_eq!(expected, result);
    }

    #[test]
    fn clamp_invalid_min_max() {
        let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
        let min = 10.0;
        let max = -1.0;

        let func = ClampFunction;
        let args = [
            Arc::new(Float64Vector::from(input)) as _,
            Arc::new(Float64Vector::from_vec(vec![min])) as _,
            Arc::new(Float64Vector::from_vec(vec![max])) as _,
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        assert!(result.is_err());
    }

    #[test]
    fn clamp_type_not_match() {
        let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
        let min = -1;
        let max = 10;

        let func = ClampFunction;
        let args = [
            Arc::new(Float64Vector::from(input)) as _,
            Arc::new(Int64Vector::from_vec(vec![min])) as _,
            Arc::new(UInt64Vector::from_vec(vec![max])) as _,
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        assert!(result.is_err());
    }

    #[test]
    fn clamp_min_is_not_scalar() {
        let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
        let min = -10.0;
        let max = 1.0;

        let func = ClampFunction;
        let args = [
            Arc::new(Float64Vector::from(input)) as _,
            Arc::new(Float64Vector::from_vec(vec![min, min])) as _,
            Arc::new(Float64Vector::from_vec(vec![max])) as _,
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        assert!(result.is_err());
    }

    #[test]
    fn clamp_null_bound() {
        let func = ClampFunction;
        let args = [
            Arc::new(Int64Vector::from(vec![Some(1), Some(2)])) as _,
            Arc::new(Int64Vector::from(vec![None::<i64>])) as _,
            Arc::new(Int64Vector::from_vec(vec![3])) as _,
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        assert!(result.is_err());
    }

    #[test]
    fn clamp_no_max() {
        let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
        let min = -10.0;

        let func = ClampFunction;
        let args = [
            Arc::new(Float64Vector::from(input)) as _,
            Arc::new(Float64Vector::from_vec(vec![min])) as _,
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        assert!(result.is_err());
    }

    #[test]
    fn clamp_on_string() {
        let input = vec![Some("foo"), Some("foo"), Some("foo"), Some("foo")];

        let func = ClampFunction;
        let args = [
            Arc::new(StringVector::from(input)) as _,
            Arc::new(StringVector::from_vec(vec!["bar"])) as _,
            Arc::new(StringVector::from_vec(vec!["baz"])) as _,
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        assert!(result.is_err());
    }

    #[test]
    fn clamp_min_i64() {
        let inputs = [
            (
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
                -1,
                vec![Some(-1), Some(-1), Some(-1), Some(0), Some(1), Some(2)],
            ),
            (
                vec![Some(-3), None, Some(-1), None, None, Some(2)],
                -2,
                vec![Some(-2), None, Some(-1), None, None, Some(2)],
            ),
        ];

        let func = ClampMinFunction;
        for (in_data, min, expected) in inputs {
            let args = [
                Arc::new(Int64Vector::from(in_data)) as _,
                Arc::new(Int64Vector::from_vec(vec![min])) as _,
            ];
            let result = func
                .eval(&FunctionContext::default(), args.as_slice())
                .unwrap();
            let expected: VectorRef = Arc::new(Int64Vector::from(expected));
            assert_eq!(expected, result);
        }
    }

    #[test]
    fn clamp_max_i64() {
        let inputs = [
            (
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
                1,
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(1)],
            ),
            (
                vec![Some(-3), None, Some(-1), None, None, Some(2)],
                0,
                vec![Some(-3), None, Some(-1), None, None, Some(0)],
            ),
        ];

        let func = ClampMaxFunction;
        for (in_data, max, expected) in inputs {
            let args = [
                Arc::new(Int64Vector::from(in_data)) as _,
                Arc::new(Int64Vector::from_vec(vec![max])) as _,
            ];
            let result = func
                .eval(&FunctionContext::default(), args.as_slice())
                .unwrap();
            let expected: VectorRef = Arc::new(Int64Vector::from(expected));
            assert_eq!(expected, result);
        }
    }

    #[test]
    fn clamp_min_f64() {
        let inputs = [(
            vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)],
            -1.0,
            vec![Some(-1.0), Some(-1.0), Some(-1.0), Some(0.0), Some(1.0)],
        )];

        let func = ClampMinFunction;
        for (in_data, min, expected) in inputs {
            let args = [
                Arc::new(Float64Vector::from(in_data)) as _,
                Arc::new(Float64Vector::from_vec(vec![min])) as _,
            ];
            let result = func
                .eval(&FunctionContext::default(), args.as_slice())
                .unwrap();
            let expected: VectorRef = Arc::new(Float64Vector::from(expected));
            assert_eq!(expected, result);
        }
    }

    #[test]
    fn clamp_max_f64() {
        let inputs = [(
            vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)],
            0.0,
            vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(0.0)],
        )];

        let func = ClampMaxFunction;
        for (in_data, max, expected) in inputs {
            let args = [
                Arc::new(Float64Vector::from(in_data)) as _,
                Arc::new(Float64Vector::from_vec(vec![max])) as _,
            ];
            let result = func
                .eval(&FunctionContext::default(), args.as_slice())
                .unwrap();
            let expected: VectorRef = Arc::new(Float64Vector::from(expected));
            assert_eq!(expected, result);
        }
    }

    #[test]
    fn clamp_min_type_not_match() {
        let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
        let min = -1;

        let func = ClampMinFunction;
        let args = [
            Arc::new(Float64Vector::from(input)) as _,
            Arc::new(Int64Vector::from_vec(vec![min])) as _,
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        assert!(result.is_err());
    }

    #[test]
    fn clamp_max_type_not_match() {
        let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
        let max = 1;

        let func = ClampMaxFunction;
        let args = [
            Arc::new(Float64Vector::from(input)) as _,
            Arc::new(Int64Vector::from_vec(vec![max])) as _,
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        assert!(result.is_err());
    }

    #[test]
    fn clamp_min_null_bound() {
        let func = ClampMinFunction;
        let args = [
            Arc::new(Int64Vector::from(vec![Some(1), Some(2)])) as _,
            Arc::new(Int64Vector::from(vec![None::<i64>])) as _,
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        assert!(result.is_err());
    }

    #[test]
    fn clamp_max_null_bound() {
        let func = ClampMaxFunction;
        let args = [
            Arc::new(Int64Vector::from(vec![Some(1), Some(2)])) as _,
            Arc::new(Int64Vector::from(vec![None::<i64>])) as _,
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        assert!(result.is_err());
    }

    #[test]
    fn clamp_min_bound_not_scalar() {
        let func = ClampMinFunction;
        let args = [
            Arc::new(Int64Vector::from(vec![Some(1), Some(2)])) as _,
            Arc::new(Int64Vector::from_vec(vec![0, 0])) as _,
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        assert!(result.is_err());
    }

    #[test]
    fn clamp_max_bound_not_scalar() {
        let func = ClampMaxFunction;
        let args = [
            Arc::new(Int64Vector::from(vec![Some(1), Some(2)])) as _,
            Arc::new(Int64Vector::from_vec(vec![0, 0])) as _,
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        assert!(result.is_err());
    }

    #[test]
    fn clamp_min_wrong_arg_count() {
        let func = ClampMinFunction;
        let args = [
            Arc::new(Int64Vector::from(vec![Some(1), Some(2)])) as _,
            Arc::new(Int64Vector::from_vec(vec![0])) as _,
            Arc::new(Int64Vector::from_vec(vec![5])) as _,
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        assert!(result.is_err());
    }

    #[test]
    fn clamp_max_on_string() {
        let func = ClampMaxFunction;
        let args = [
            Arc::new(StringVector::from(vec![Some("foo"), Some("bar")])) as _,
            Arc::new(StringVector::from_vec(vec!["baz"])) as _,
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        assert!(result.is_err());
    }

    #[test]
    fn function_names() {
        assert_eq!("clamp", ClampFunction.name());
        assert_eq!("clamp_min", ClampMinFunction.name());
        assert_eq!("clamp_max", ClampMaxFunction.name());
        assert_eq!("CLAMP", ClampFunction.to_string());
        assert_eq!("CLAMP_MIN", ClampMinFunction.to_string());
        assert_eq!("CLAMP_MAX", ClampMaxFunction.to_string());
    }
}