common_function/scalars/math/
clamp.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::{self, Display};
16use std::sync::Arc;
17
18use datafusion::arrow::array::{Array, ArrayRef, AsArray, PrimitiveArray};
19use datafusion::arrow::datatypes::DataType as ArrowDataType;
20use datafusion::logical_expr::{ColumnarValue, Volatility};
21use datafusion_common::{DataFusionError, ScalarValue, utils};
22use datafusion_expr::type_coercion::aggregates::NUMERICS;
23use datafusion_expr::{ScalarFunctionArgs, Signature};
24
25use crate::function::Function;
26
27#[derive(Clone, Debug)]
28pub struct ClampFunction {
29    signature: Signature,
30}
31
32impl Default for ClampFunction {
33    fn default() -> Self {
34        Self {
35            // input, min, max
36            signature: Signature::uniform(3, NUMERICS.to_vec(), Volatility::Immutable),
37        }
38    }
39}
40
41const CLAMP_NAME: &str = "clamp";
42
43impl Function for ClampFunction {
44    fn name(&self) -> &str {
45        CLAMP_NAME
46    }
47
48    fn return_type(
49        &self,
50        input_types: &[ArrowDataType],
51    ) -> datafusion_common::Result<ArrowDataType> {
52        // Type check is done by `signature`
53        Ok(input_types[0].clone())
54    }
55
56    fn signature(&self) -> &Signature {
57        &self.signature
58    }
59
60    fn invoke_with_args(
61        &self,
62        args: ScalarFunctionArgs,
63    ) -> datafusion_common::Result<ColumnarValue> {
64        let [col, min, max] = utils::take_function_args(self.name(), args.args)?;
65        clamp_impl(col, min, max)
66    }
67}
68
69impl Display for ClampFunction {
70    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
71        write!(f, "{}", CLAMP_NAME.to_ascii_uppercase())
72    }
73}
74
75fn clamp_impl(
76    col: ColumnarValue,
77    min: ColumnarValue,
78    max: ColumnarValue,
79) -> datafusion_common::Result<ColumnarValue> {
80    if col.data_type() != min.data_type() || min.data_type() != max.data_type() {
81        return Err(DataFusionError::Execution(format!(
82            "argument data types mismatch: {}, {}, {}",
83            col.data_type(),
84            min.data_type(),
85            max.data_type(),
86        )));
87    }
88
89    macro_rules! with_match_numerics_types {
90        ($data_type:expr, | $_:tt $T:ident | $body:tt) => {{
91            macro_rules! __with_ty__ {
92                ( $_ $T:ident ) => {
93                    $body
94                };
95            }
96
97            use datafusion::arrow::datatypes::{
98                Float32Type, Float64Type, Int8Type, Int16Type, Int32Type, Int64Type, UInt8Type,
99                UInt16Type, UInt32Type, UInt64Type,
100            };
101
102            match $data_type {
103                ArrowDataType::Int8 => Ok(__with_ty__! { Int8Type }),
104                ArrowDataType::Int16 => Ok(__with_ty__! { Int16Type }),
105                ArrowDataType::Int32 => Ok(__with_ty__! { Int32Type }),
106                ArrowDataType::Int64 => Ok(__with_ty__! { Int64Type }),
107                ArrowDataType::UInt8 => Ok(__with_ty__! { UInt8Type }),
108                ArrowDataType::UInt16 => Ok(__with_ty__! { UInt16Type }),
109                ArrowDataType::UInt32 => Ok(__with_ty__! { UInt32Type }),
110                ArrowDataType::UInt64 => Ok(__with_ty__! { UInt64Type }),
111                ArrowDataType::Float32 => Ok(__with_ty__! { Float32Type }),
112                ArrowDataType::Float64 => Ok(__with_ty__! { Float64Type }),
113                _ => Err(DataFusionError::Execution(format!(
114                    "unsupported numeric data type: '{}'",
115                    $data_type
116                ))),
117            }
118        }};
119    }
120
121    macro_rules! clamp {
122        ($v: ident, $min: ident, $max: ident) => {
123            if $v < $min {
124                $min
125            } else if $v > $max {
126                $max
127            } else {
128                $v
129            }
130        };
131    }
132
133    match (col, min, max) {
134        (ColumnarValue::Scalar(col), ColumnarValue::Scalar(min), ColumnarValue::Scalar(max)) => {
135            if min > max {
136                return Err(DataFusionError::Execution(format!(
137                    "min '{}' > max '{}'",
138                    min, max
139                )));
140            }
141            Ok(ColumnarValue::Scalar(clamp!(col, min, max)))
142        }
143
144        (ColumnarValue::Array(col), ColumnarValue::Array(min), ColumnarValue::Array(max)) => {
145            if col.len() != min.len() || col.len() != max.len() {
146                return Err(DataFusionError::Internal(
147                    "arguments not of same length".to_string(),
148                ));
149            }
150            let result = with_match_numerics_types!(
151                col.data_type(),
152                |$S| {
153                    let col = col.as_primitive::<$S>();
154                    let min = min.as_primitive::<$S>();
155                    let max = max.as_primitive::<$S>();
156                    Arc::new(PrimitiveArray::<$S>::from(
157                        (0..col.len())
158                            .map(|i| {
159                                let v = col.is_valid(i).then(|| col.value(i));
160                                // Index safety: checked above, all have same length.
161                                let min = min.is_valid(i).then(|| min.value(i));
162                                let max = max.is_valid(i).then(|| max.value(i));
163                                Ok(match (v, min, max) {
164                                    (Some(v), Some(min), Some(max)) => {
165                                        if min > max {
166                                            return Err(DataFusionError::Execution(format!(
167                                                "min '{}' > max '{}'",
168                                                min, max
169                                            )));
170                                        }
171                                        Some(clamp!(v, min, max))
172                                    },
173                                    _ => None,
174                                })
175                            })
176                            .collect::<datafusion_common::Result<Vec<_>>>()?,
177                        )
178                    ) as ArrayRef
179                }
180            )?;
181            Ok(ColumnarValue::Array(result))
182        }
183
184        (ColumnarValue::Array(col), ColumnarValue::Scalar(min), ColumnarValue::Scalar(max)) => {
185            if min.is_null() || max.is_null() {
186                return Err(DataFusionError::Execution(
187                    "argument 'min' or 'max' is null".to_string(),
188                ));
189            }
190            let min = min.to_array()?;
191            let max = max.to_array()?;
192            let result = with_match_numerics_types!(
193                col.data_type(),
194                |$S| {
195                    let col = col.as_primitive::<$S>();
196                    // Index safety: checked above, both are not nulls.
197                    let min = min.as_primitive::<$S>().value(0);
198                    let max = max.as_primitive::<$S>().value(0);
199                    if min > max {
200                        return Err(DataFusionError::Execution(format!(
201                            "min '{}' > max '{}'",
202                            min, max
203                        )));
204                    }
205                    Arc::new(PrimitiveArray::<$S>::from(
206                        (0..col.len())
207                            .map(|x| {
208                                col.is_valid(x).then(|| {
209                                    let v = col.value(x);
210                                    clamp!(v, min, max)
211                                })
212                            })
213                            .collect::<Vec<_>>(),
214                        )
215                    ) as ArrayRef
216                }
217            )?;
218            Ok(ColumnarValue::Array(result))
219        }
220        _ => Err(DataFusionError::Internal(
221            "argument column types mismatch".to_string(),
222        )),
223    }
224}
225
226#[derive(Clone, Debug)]
227pub struct ClampMinFunction {
228    signature: Signature,
229}
230
231impl Default for ClampMinFunction {
232    fn default() -> Self {
233        Self {
234            // input, min
235            signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
236        }
237    }
238}
239
240const CLAMP_MIN_NAME: &str = "clamp_min";
241
242impl Function for ClampMinFunction {
243    fn name(&self) -> &str {
244        CLAMP_MIN_NAME
245    }
246
247    fn return_type(
248        &self,
249        input_types: &[ArrowDataType],
250    ) -> datafusion_common::Result<ArrowDataType> {
251        Ok(input_types[0].clone())
252    }
253
254    fn signature(&self) -> &Signature {
255        &self.signature
256    }
257
258    fn invoke_with_args(
259        &self,
260        args: ScalarFunctionArgs,
261    ) -> datafusion_common::Result<ColumnarValue> {
262        let [col, min] = utils::take_function_args(self.name(), args.args)?;
263
264        let Some(max) = ScalarValue::max(&min.data_type()) else {
265            return Err(DataFusionError::Internal(format!(
266                "cannot find a max value for numeric data type {}",
267                min.data_type()
268            )));
269        };
270        clamp_impl(col, min, ColumnarValue::Scalar(max))
271    }
272}
273
274impl Display for ClampMinFunction {
275    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
276        write!(f, "{}", CLAMP_MIN_NAME.to_ascii_uppercase())
277    }
278}
279
280#[derive(Clone, Debug)]
281pub struct ClampMaxFunction {
282    signature: Signature,
283}
284
285impl Default for ClampMaxFunction {
286    fn default() -> Self {
287        Self {
288            // input, max
289            signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
290        }
291    }
292}
293
294const CLAMP_MAX_NAME: &str = "clamp_max";
295
296impl Function for ClampMaxFunction {
297    fn name(&self) -> &str {
298        CLAMP_MAX_NAME
299    }
300
301    fn return_type(
302        &self,
303        input_types: &[ArrowDataType],
304    ) -> datafusion_common::Result<ArrowDataType> {
305        Ok(input_types[0].clone())
306    }
307
308    fn signature(&self) -> &Signature {
309        &self.signature
310    }
311
312    fn invoke_with_args(
313        &self,
314        args: ScalarFunctionArgs,
315    ) -> datafusion_common::Result<ColumnarValue> {
316        let [col, max] = utils::take_function_args(self.name(), args.args)?;
317
318        let Some(min) = ScalarValue::min(&max.data_type()) else {
319            return Err(DataFusionError::Internal(format!(
320                "cannot find a min value for numeric data type {}",
321                max.data_type()
322            )));
323        };
324        clamp_impl(col, ColumnarValue::Scalar(min), max)
325    }
326}
327
328impl Display for ClampMaxFunction {
329    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
330        write!(f, "{}", CLAMP_MAX_NAME.to_ascii_uppercase())
331    }
332}
333
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion_common::config::ConfigOptions;
    use datatypes::arrow::array::{ArrayRef, Float64Array, Int64Array, UInt64Array};
    use datatypes::arrow_array::StringArray;

    use super::*;

    /// Invokes `func` with `args` for a batch of `number_rows` rows and converts the
    /// returned value (scalar or array) into an array.
    fn eval<F: Function>(
        func: &F,
        args: Vec<ColumnarValue>,
        number_rows: usize,
    ) -> datafusion_common::Result<ArrayRef> {
        // The return field reuses the first argument's data type.
        let return_field = Arc::new(Field::new("x", args[0].data_type(), false));
        let output = func.invoke_with_args(ScalarFunctionArgs {
            args,
            arg_fields: vec![],
            number_rows,
            return_field,
            config_options: Arc::new(ConfigOptions::new()),
        })?;
        let mut arrays = ColumnarValue::values_to_arrays(&[output])?;
        Ok(arrays.remove(0))
    }

    /// Evaluates `func(column, scalars...)` over every row of `column`.
    fn eval_column<F: Function>(
        func: &F,
        column: ArrayRef,
        scalars: Vec<ScalarValue>,
    ) -> datafusion_common::Result<ArrayRef> {
        let number_rows = column.len();
        let mut args = Vec::with_capacity(scalars.len() + 1);
        args.push(ColumnarValue::Array(column));
        args.extend(scalars.into_iter().map(ColumnarValue::Scalar));
        eval(func, args, number_rows)
    }

    /// Asserts that `func(column, scalars...)` succeeds and equals `expected`.
    fn assert_eval_eq<F: Function>(
        func: &F,
        column: ArrayRef,
        scalars: Vec<ScalarValue>,
        expected: ArrayRef,
    ) {
        let actual = eval_column(func, column, scalars).unwrap();
        assert_eq!(expected.as_ref(), actual.as_ref());
    }

    /// Float input shared by the error-path tests.
    fn float_input() -> ArrayRef {
        Arc::new(Float64Array::from(vec![
            Some(-3.0),
            Some(-2.0),
            Some(-1.0),
            Some(0.0),
            Some(1.0),
        ]))
    }

    #[test]
    fn clamp_i64() {
        let cases: [(Vec<Option<i64>>, i64, i64, Vec<Option<i64>>); 4] = [
            (
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
                -1,
                10,
                vec![Some(-1), Some(-1), Some(-1), Some(0), Some(1), Some(2)],
            ),
            (
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
                0,
                0,
                vec![Some(0); 6],
            ),
            (
                vec![Some(-3), None, Some(-1), None, None, Some(2)],
                -2,
                1,
                vec![Some(-2), None, Some(-1), None, None, Some(1)],
            ),
            (vec![None; 5], 0, 1, vec![None; 5]),
        ];

        let func = ClampFunction::default();
        for (input, lo, hi, expected) in cases {
            assert_eval_eq(
                &func,
                Arc::new(Int64Array::from(input)),
                vec![ScalarValue::from(lo), ScalarValue::from(hi)],
                Arc::new(Int64Array::from(expected)),
            );
        }
    }

    #[test]
    fn clamp_u64() {
        let cases: [(Vec<Option<u64>>, u64, u64, Vec<Option<u64>>); 4] = [
            (
                vec![Some(0), Some(1), Some(2), Some(3), Some(4), Some(5)],
                1,
                3,
                vec![Some(1), Some(1), Some(2), Some(3), Some(3), Some(3)],
            ),
            (
                vec![Some(0), Some(1), Some(2), Some(3), Some(4), Some(5)],
                0,
                0,
                vec![Some(0); 6],
            ),
            (
                vec![Some(0), None, Some(2), None, None, Some(5)],
                1,
                3,
                vec![Some(1), None, Some(2), None, None, Some(3)],
            ),
            (vec![None; 5], 0, 1, vec![None; 5]),
        ];

        let func = ClampFunction::default();
        for (input, lo, hi, expected) in cases {
            assert_eval_eq(
                &func,
                Arc::new(UInt64Array::from(input)),
                vec![ScalarValue::from(lo), ScalarValue::from(hi)],
                Arc::new(UInt64Array::from(expected)),
            );
        }
    }

    #[test]
    fn clamp_f64() {
        let cases: [(Vec<Option<f64>>, f64, f64, Vec<Option<f64>>); 4] = [
            (
                vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)],
                -1.0,
                10.0,
                vec![Some(-1.0), Some(-1.0), Some(-1.0), Some(0.0), Some(1.0)],
            ),
            (
                vec![Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)],
                0.0,
                0.0,
                vec![Some(0.0); 4],
            ),
            (
                vec![Some(-3.0), None, Some(-1.0), None, None, Some(2.0)],
                -2.0,
                1.0,
                vec![Some(-2.0), None, Some(-1.0), None, None, Some(1.0)],
            ),
            (vec![None; 5], 0.0, 1.0, vec![None; 5]),
        ];

        let func = ClampFunction::default();
        for (input, lo, hi, expected) in cases {
            assert_eval_eq(
                &func,
                Arc::new(Float64Array::from(input)),
                vec![ScalarValue::from(lo), ScalarValue::from(hi)],
                Arc::new(Float64Array::from(expected)),
            );
        }
    }

    #[test]
    fn clamp_invalid_min_max() {
        // min (10.0) is greater than max (-1.0).
        let result = eval_column(
            &ClampFunction::default(),
            float_input(),
            vec![ScalarValue::from(10.0f64), ScalarValue::from(-1.0f64)],
        );
        assert!(result.is_err());
    }

    #[test]
    fn clamp_type_not_match() {
        let result = eval_column(
            &ClampFunction::default(),
            float_input(),
            vec![ScalarValue::from(-1i64), ScalarValue::from(10u64)],
        );
        assert!(result.is_err());
    }

    #[test]
    fn clamp_min_is_not_scalar() {
        // Bound arrays are shorter than the input column.
        let (lo, hi) = (-10.0f64, 1.0f64);
        let input = float_input();
        let number_rows = input.len();
        let args = vec![
            ColumnarValue::Array(input),
            ColumnarValue::Array(Arc::new(Float64Array::from(vec![lo, hi]))),
            ColumnarValue::Array(Arc::new(Float64Array::from(vec![hi, lo]))),
        ];
        assert!(eval(&ClampFunction::default(), args, number_rows).is_err());
    }

    #[test]
    fn clamp_no_max() {
        let result = eval_column(
            &ClampFunction::default(),
            float_input(),
            vec![ScalarValue::from(-10.0f64)],
        );
        assert!(result.is_err());
    }

    #[test]
    fn clamp_on_string() {
        let input: ArrayRef = Arc::new(StringArray::from(vec![Some("foo"); 4]));
        let result = eval_column(
            &ClampFunction::default(),
            input,
            vec![ScalarValue::from("bar"), ScalarValue::from("baz")],
        );
        assert!(result.is_err());
    }

    #[test]
    fn clamp_min_i64() {
        let cases: [(Vec<Option<i64>>, i64, Vec<Option<i64>>); 2] = [
            (
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
                -1,
                vec![Some(-1), Some(-1), Some(-1), Some(0), Some(1), Some(2)],
            ),
            (
                vec![Some(-3), None, Some(-1), None, None, Some(2)],
                -2,
                vec![Some(-2), None, Some(-1), None, None, Some(2)],
            ),
        ];

        let func = ClampMinFunction::default();
        for (input, lo, expected) in cases {
            assert_eval_eq(
                &func,
                Arc::new(Int64Array::from(input)),
                vec![ScalarValue::from(lo)],
                Arc::new(Int64Array::from(expected)),
            );
        }
    }

    #[test]
    fn clamp_max_i64() {
        let cases: [(Vec<Option<i64>>, i64, Vec<Option<i64>>); 2] = [
            (
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
                1,
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(1)],
            ),
            (
                vec![Some(-3), None, Some(-1), None, None, Some(2)],
                0,
                vec![Some(-3), None, Some(-1), None, None, Some(0)],
            ),
        ];

        let func = ClampMaxFunction::default();
        for (input, hi, expected) in cases {
            assert_eval_eq(
                &func,
                Arc::new(Int64Array::from(input)),
                vec![ScalarValue::from(hi)],
                Arc::new(Int64Array::from(expected)),
            );
        }
    }

    #[test]
    fn clamp_min_f64() {
        assert_eval_eq(
            &ClampMinFunction::default(),
            float_input(),
            vec![ScalarValue::from(-1.0f64)],
            Arc::new(Float64Array::from(vec![
                Some(-1.0),
                Some(-1.0),
                Some(-1.0),
                Some(0.0),
                Some(1.0),
            ])),
        );
    }

    #[test]
    fn clamp_max_f64() {
        assert_eval_eq(
            &ClampMaxFunction::default(),
            float_input(),
            vec![ScalarValue::from(0.0f64)],
            Arc::new(Float64Array::from(vec![
                Some(-3.0),
                Some(-2.0),
                Some(-1.0),
                Some(0.0),
                Some(0.0),
            ])),
        );
    }

    #[test]
    fn clamp_min_type_not_match() {
        let result = eval_column(
            &ClampMinFunction::default(),
            float_input(),
            vec![ScalarValue::from(-1i64)],
        );
        assert!(result.is_err());
    }

    #[test]
    fn clamp_max_type_not_match() {
        let result = eval_column(
            &ClampMaxFunction::default(),
            float_input(),
            vec![ScalarValue::from(1i64)],
        );
        assert!(result.is_err());
    }
}