1use std::fmt::{self, Display};
16use std::sync::Arc;
17
18use datafusion::arrow::array::{Array, ArrayRef, AsArray, PrimitiveArray};
19use datafusion::arrow::datatypes::DataType as ArrowDataType;
20use datafusion::logical_expr::{ColumnarValue, Volatility};
21use datafusion_common::{DataFusionError, ScalarValue, utils};
22use datafusion_expr::type_coercion::aggregates::NUMERICS;
23use datafusion_expr::{ScalarFunctionArgs, Signature};
24
25use crate::function::Function;
26
27#[derive(Clone, Debug)]
28pub struct ClampFunction {
29 signature: Signature,
30}
31
32impl Default for ClampFunction {
33 fn default() -> Self {
34 Self {
35 signature: Signature::uniform(3, NUMERICS.to_vec(), Volatility::Immutable),
37 }
38 }
39}
40
/// Name of the `clamp` function; returned by `name()` and upper-cased by the `Display` impl.
const CLAMP_NAME: &str = "clamp";
42
43impl Function for ClampFunction {
44 fn name(&self) -> &str {
45 CLAMP_NAME
46 }
47
48 fn return_type(
49 &self,
50 input_types: &[ArrowDataType],
51 ) -> datafusion_common::Result<ArrowDataType> {
52 Ok(input_types[0].clone())
54 }
55
56 fn signature(&self) -> &Signature {
57 &self.signature
58 }
59
60 fn invoke_with_args(
61 &self,
62 args: ScalarFunctionArgs,
63 ) -> datafusion_common::Result<ColumnarValue> {
64 let [col, min, max] = utils::take_function_args(self.name(), args.args)?;
65 clamp_impl(col, min, max)
66 }
67}
68
69impl Display for ClampFunction {
70 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
71 write!(f, "{}", CLAMP_NAME.to_ascii_uppercase())
72 }
73}
74
75fn clamp_impl(
76 col: ColumnarValue,
77 min: ColumnarValue,
78 max: ColumnarValue,
79) -> datafusion_common::Result<ColumnarValue> {
80 if col.data_type() != min.data_type() || min.data_type() != max.data_type() {
81 return Err(DataFusionError::Execution(format!(
82 "argument data types mismatch: {}, {}, {}",
83 col.data_type(),
84 min.data_type(),
85 max.data_type(),
86 )));
87 }
88
89 macro_rules! with_match_numerics_types {
90 ($data_type:expr, | $_:tt $T:ident | $body:tt) => {{
91 macro_rules! __with_ty__ {
92 ( $_ $T:ident ) => {
93 $body
94 };
95 }
96
97 use datafusion::arrow::datatypes::{
98 Float32Type, Float64Type, Int8Type, Int16Type, Int32Type, Int64Type, UInt8Type,
99 UInt16Type, UInt32Type, UInt64Type,
100 };
101
102 match $data_type {
103 ArrowDataType::Int8 => Ok(__with_ty__! { Int8Type }),
104 ArrowDataType::Int16 => Ok(__with_ty__! { Int16Type }),
105 ArrowDataType::Int32 => Ok(__with_ty__! { Int32Type }),
106 ArrowDataType::Int64 => Ok(__with_ty__! { Int64Type }),
107 ArrowDataType::UInt8 => Ok(__with_ty__! { UInt8Type }),
108 ArrowDataType::UInt16 => Ok(__with_ty__! { UInt16Type }),
109 ArrowDataType::UInt32 => Ok(__with_ty__! { UInt32Type }),
110 ArrowDataType::UInt64 => Ok(__with_ty__! { UInt64Type }),
111 ArrowDataType::Float32 => Ok(__with_ty__! { Float32Type }),
112 ArrowDataType::Float64 => Ok(__with_ty__! { Float64Type }),
113 _ => Err(DataFusionError::Execution(format!(
114 "unsupported numeric data type: '{}'",
115 $data_type
116 ))),
117 }
118 }};
119 }
120
121 macro_rules! clamp {
122 ($v: ident, $min: ident, $max: ident) => {
123 if $v < $min {
124 $min
125 } else if $v > $max {
126 $max
127 } else {
128 $v
129 }
130 };
131 }
132
133 match (col, min, max) {
134 (ColumnarValue::Scalar(col), ColumnarValue::Scalar(min), ColumnarValue::Scalar(max)) => {
135 if min > max {
136 return Err(DataFusionError::Execution(format!(
137 "min '{}' > max '{}'",
138 min, max
139 )));
140 }
141 Ok(ColumnarValue::Scalar(clamp!(col, min, max)))
142 }
143
144 (ColumnarValue::Array(col), ColumnarValue::Array(min), ColumnarValue::Array(max)) => {
145 if col.len() != min.len() || col.len() != max.len() {
146 return Err(DataFusionError::Internal(
147 "arguments not of same length".to_string(),
148 ));
149 }
150 let result = with_match_numerics_types!(
151 col.data_type(),
152 |$S| {
153 let col = col.as_primitive::<$S>();
154 let min = min.as_primitive::<$S>();
155 let max = max.as_primitive::<$S>();
156 Arc::new(PrimitiveArray::<$S>::from(
157 (0..col.len())
158 .map(|i| {
159 let v = col.is_valid(i).then(|| col.value(i));
160 let min = min.is_valid(i).then(|| min.value(i));
162 let max = max.is_valid(i).then(|| max.value(i));
163 Ok(match (v, min, max) {
164 (Some(v), Some(min), Some(max)) => {
165 if min > max {
166 return Err(DataFusionError::Execution(format!(
167 "min '{}' > max '{}'",
168 min, max
169 )));
170 }
171 Some(clamp!(v, min, max))
172 },
173 _ => None,
174 })
175 })
176 .collect::<datafusion_common::Result<Vec<_>>>()?,
177 )
178 ) as ArrayRef
179 }
180 )?;
181 Ok(ColumnarValue::Array(result))
182 }
183
184 (ColumnarValue::Array(col), ColumnarValue::Scalar(min), ColumnarValue::Scalar(max)) => {
185 if min.is_null() || max.is_null() {
186 return Err(DataFusionError::Execution(
187 "argument 'min' or 'max' is null".to_string(),
188 ));
189 }
190 let min = min.to_array()?;
191 let max = max.to_array()?;
192 let result = with_match_numerics_types!(
193 col.data_type(),
194 |$S| {
195 let col = col.as_primitive::<$S>();
196 let min = min.as_primitive::<$S>().value(0);
198 let max = max.as_primitive::<$S>().value(0);
199 if min > max {
200 return Err(DataFusionError::Execution(format!(
201 "min '{}' > max '{}'",
202 min, max
203 )));
204 }
205 Arc::new(PrimitiveArray::<$S>::from(
206 (0..col.len())
207 .map(|x| {
208 col.is_valid(x).then(|| {
209 let v = col.value(x);
210 clamp!(v, min, max)
211 })
212 })
213 .collect::<Vec<_>>(),
214 )
215 ) as ArrayRef
216 }
217 )?;
218 Ok(ColumnarValue::Array(result))
219 }
220 _ => Err(DataFusionError::Internal(
221 "argument column types mismatch".to_string(),
222 )),
223 }
224}
225
226#[derive(Clone, Debug)]
227pub struct ClampMinFunction {
228 signature: Signature,
229}
230
231impl Default for ClampMinFunction {
232 fn default() -> Self {
233 Self {
234 signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
236 }
237 }
238}
239
/// Name of the `clamp_min` function; returned by `name()` and upper-cased by the `Display` impl.
const CLAMP_MIN_NAME: &str = "clamp_min";
241
242impl Function for ClampMinFunction {
243 fn name(&self) -> &str {
244 CLAMP_MIN_NAME
245 }
246
247 fn return_type(
248 &self,
249 input_types: &[ArrowDataType],
250 ) -> datafusion_common::Result<ArrowDataType> {
251 Ok(input_types[0].clone())
252 }
253
254 fn signature(&self) -> &Signature {
255 &self.signature
256 }
257
258 fn invoke_with_args(
259 &self,
260 args: ScalarFunctionArgs,
261 ) -> datafusion_common::Result<ColumnarValue> {
262 let [col, min] = utils::take_function_args(self.name(), args.args)?;
263
264 let Some(max) = ScalarValue::max(&min.data_type()) else {
265 return Err(DataFusionError::Internal(format!(
266 "cannot find a max value for numeric data type {}",
267 min.data_type()
268 )));
269 };
270 clamp_impl(col, min, ColumnarValue::Scalar(max))
271 }
272}
273
274impl Display for ClampMinFunction {
275 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
276 write!(f, "{}", CLAMP_MIN_NAME.to_ascii_uppercase())
277 }
278}
279
280#[derive(Clone, Debug)]
281pub struct ClampMaxFunction {
282 signature: Signature,
283}
284
285impl Default for ClampMaxFunction {
286 fn default() -> Self {
287 Self {
288 signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
290 }
291 }
292}
293
/// Name of the `clamp_max` function; returned by `name()` and upper-cased by the `Display` impl.
const CLAMP_MAX_NAME: &str = "clamp_max";
295
296impl Function for ClampMaxFunction {
297 fn name(&self) -> &str {
298 CLAMP_MAX_NAME
299 }
300
301 fn return_type(
302 &self,
303 input_types: &[ArrowDataType],
304 ) -> datafusion_common::Result<ArrowDataType> {
305 Ok(input_types[0].clone())
306 }
307
308 fn signature(&self) -> &Signature {
309 &self.signature
310 }
311
312 fn invoke_with_args(
313 &self,
314 args: ScalarFunctionArgs,
315 ) -> datafusion_common::Result<ColumnarValue> {
316 let [col, max] = utils::take_function_args(self.name(), args.args)?;
317
318 let Some(min) = ScalarValue::min(&max.data_type()) else {
319 return Err(DataFusionError::Internal(format!(
320 "cannot find a min value for numeric data type {}",
321 max.data_type()
322 )));
323 };
324 clamp_impl(col, ColumnarValue::Scalar(min), max)
325 }
326}
327
328impl Display for ClampMaxFunction {
329 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
330 write!(f, "{}", CLAMP_MAX_NAME.to_ascii_uppercase())
331 }
332}
333
#[cfg(test)]
mod test {

    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion_common::config::ConfigOptions;
    use datatypes::arrow::array::{ArrayRef, Float64Array, Int64Array, UInt64Array};
    use datatypes::arrow_array::StringArray;

    use super::*;

    /// Adds a `test_eval` helper to a function type: it invokes the function with `args`
    /// and returns the output converted to a single array.
    macro_rules! impl_test_eval {
        ($func: ty) => {
            impl $func {
                fn test_eval(
                    &self,
                    args: Vec<ColumnarValue>,
                    number_rows: usize,
                ) -> datafusion_common::Result<ArrayRef> {
                    // The return field reuses the data type of the first argument.
                    let input_type = args[0].data_type();
                    let call = ScalarFunctionArgs {
                        args,
                        arg_fields: vec![],
                        number_rows,
                        return_field: Arc::new(Field::new("x", input_type, false)),
                        config_options: Arc::new(ConfigOptions::new()),
                    };
                    let output = self.invoke_with_args(call)?;
                    let mut arrays = ColumnarValue::values_to_arrays(&[output])?;
                    Ok(arrays.remove(0))
                }
            }
        };
    }

    impl_test_eval!(ClampFunction);
    impl_test_eval!(ClampMinFunction);
    impl_test_eval!(ClampMaxFunction);

    /// Evaluates `$func` on an array input followed by the scalar bound(s) and asserts that
    /// the output equals `$expected`. Input and expectation are both built as `$array`.
    macro_rules! assert_clamped {
        ($func: expr, $array: ty, $input: expr, [$($bound: expr),+], $expected: expr) => {{
            let number_rows = $input.len();
            let args = vec![
                ColumnarValue::Array(Arc::new(<$array>::from($input))),
                $(ColumnarValue::Scalar($bound.into()),)+
            ];
            let actual = $func.test_eval(args, number_rows).unwrap();
            let wanted: ArrayRef = Arc::new(<$array>::from($expected));
            assert_eq!(wanted.as_ref(), actual.as_ref());
        }};
    }

    #[test]
    fn clamp_i64() {
        // (input, min, max, expected)
        let cases = [
            (
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
                -1i64,
                10i64,
                vec![Some(-1), Some(-1), Some(-1), Some(0), Some(1), Some(2)],
            ),
            (
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
                0i64,
                0i64,
                vec![Some(0), Some(0), Some(0), Some(0), Some(0), Some(0)],
            ),
            (
                vec![Some(-3), None, Some(-1), None, None, Some(2)],
                -2i64,
                1i64,
                vec![Some(-2), None, Some(-1), None, None, Some(1)],
            ),
            (
                vec![None, None, None, None, None],
                0i64,
                1i64,
                vec![None, None, None, None, None],
            ),
        ];

        let func = ClampFunction::default();
        for (in_data, min, max, expected) in cases {
            assert_clamped!(func, Int64Array, in_data, [min, max], expected);
        }
    }

    #[test]
    fn clamp_u64() {
        // (input, min, max, expected)
        let cases = [
            (
                vec![Some(0), Some(1), Some(2), Some(3), Some(4), Some(5)],
                1u64,
                3u64,
                vec![Some(1), Some(1), Some(2), Some(3), Some(3), Some(3)],
            ),
            (
                vec![Some(0), Some(1), Some(2), Some(3), Some(4), Some(5)],
                0u64,
                0u64,
                vec![Some(0), Some(0), Some(0), Some(0), Some(0), Some(0)],
            ),
            (
                vec![Some(0), None, Some(2), None, None, Some(5)],
                1u64,
                3u64,
                vec![Some(1), None, Some(2), None, None, Some(3)],
            ),
            (
                vec![None, None, None, None, None],
                0u64,
                1u64,
                vec![None, None, None, None, None],
            ),
        ];

        let func = ClampFunction::default();
        for (in_data, min, max, expected) in cases {
            assert_clamped!(func, UInt64Array, in_data, [min, max], expected);
        }
    }

    #[test]
    fn clamp_f64() {
        // (input, min, max, expected)
        let cases = [
            (
                vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)],
                -1.0,
                10.0,
                vec![Some(-1.0), Some(-1.0), Some(-1.0), Some(0.0), Some(1.0)],
            ),
            (
                vec![Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)],
                0.0,
                0.0,
                vec![Some(0.0), Some(0.0), Some(0.0), Some(0.0)],
            ),
            (
                vec![Some(-3.0), None, Some(-1.0), None, None, Some(2.0)],
                -2.0,
                1.0,
                vec![Some(-2.0), None, Some(-1.0), None, None, Some(1.0)],
            ),
            (
                vec![None, None, None, None, None],
                0.0,
                1.0,
                vec![None, None, None, None, None],
            ),
        ];

        let func = ClampFunction::default();
        for (in_data, min, max, expected) in cases {
            assert_clamped!(func, Float64Array, in_data, [min, max], expected);
        }
    }

    #[test]
    fn clamp_invalid_min_max() {
        let values = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
        let number_rows = values.len();
        // The lower bound is greater than the upper bound.
        let min = 10.0;
        let max = -1.0;

        let args = vec![
            ColumnarValue::Array(Arc::new(Float64Array::from(values))),
            ColumnarValue::Scalar(min.into()),
            ColumnarValue::Scalar(max.into()),
        ];
        let outcome = ClampFunction::default().test_eval(args, number_rows);
        assert!(outcome.is_err());
    }

    #[test]
    fn clamp_type_not_match() {
        let values = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
        let number_rows = values.len();
        // Three different types: f64 input, i64 min, u64 max.
        let min = -1i64;
        let max = 10u64;

        let args = vec![
            ColumnarValue::Array(Arc::new(Float64Array::from(values))),
            ColumnarValue::Scalar(min.into()),
            ColumnarValue::Scalar(max.into()),
        ];
        let outcome = ClampFunction::default().test_eval(args, number_rows);
        assert!(outcome.is_err());
    }

    #[test]
    fn clamp_min_is_not_scalar() {
        let values = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
        let number_rows = values.len();
        let min = -10.0;
        let max = 1.0;

        // Both bounds are two-element arrays while the input has five rows.
        let args = vec![
            ColumnarValue::Array(Arc::new(Float64Array::from(values))),
            ColumnarValue::Array(Arc::new(Float64Array::from(vec![min, max]))),
            ColumnarValue::Array(Arc::new(Float64Array::from(vec![max, min]))),
        ];
        let outcome = ClampFunction::default().test_eval(args, number_rows);
        assert!(outcome.is_err());
    }

    #[test]
    fn clamp_no_max() {
        let values = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
        let number_rows = values.len();
        let min = -10.0;

        // Only two of the three required arguments are passed.
        let args = vec![
            ColumnarValue::Array(Arc::new(Float64Array::from(values))),
            ColumnarValue::Scalar(min.into()),
        ];
        let outcome = ClampFunction::default().test_eval(args, number_rows);
        assert!(outcome.is_err());
    }

    #[test]
    fn clamp_on_string() {
        let values = vec![Some("foo"), Some("foo"), Some("foo"), Some("foo")];
        let number_rows = values.len();

        let args = vec![
            ColumnarValue::Array(Arc::new(StringArray::from(values))),
            ColumnarValue::Scalar("bar".into()),
            ColumnarValue::Scalar("baz".into()),
        ];
        let outcome = ClampFunction::default().test_eval(args, number_rows);
        assert!(outcome.is_err());
    }

    #[test]
    fn clamp_min_i64() {
        // (input, min, expected)
        let cases = [
            (
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
                -1i64,
                vec![Some(-1), Some(-1), Some(-1), Some(0), Some(1), Some(2)],
            ),
            (
                vec![Some(-3), None, Some(-1), None, None, Some(2)],
                -2i64,
                vec![Some(-2), None, Some(-1), None, None, Some(2)],
            ),
        ];

        let func = ClampMinFunction::default();
        for (in_data, min, expected) in cases {
            assert_clamped!(func, Int64Array, in_data, [min], expected);
        }
    }

    #[test]
    fn clamp_max_i64() {
        // (input, max, expected)
        let cases = [
            (
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
                1i64,
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(1)],
            ),
            (
                vec![Some(-3), None, Some(-1), None, None, Some(2)],
                0i64,
                vec![Some(-3), None, Some(-1), None, None, Some(0)],
            ),
        ];

        let func = ClampMaxFunction::default();
        for (in_data, max, expected) in cases {
            assert_clamped!(func, Int64Array, in_data, [max], expected);
        }
    }

    #[test]
    fn clamp_min_f64() {
        // (input, min, expected)
        let cases = [(
            vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)],
            -1.0,
            vec![Some(-1.0), Some(-1.0), Some(-1.0), Some(0.0), Some(1.0)],
        )];

        let func = ClampMinFunction::default();
        for (in_data, min, expected) in cases {
            assert_clamped!(func, Float64Array, in_data, [min], expected);
        }
    }

    #[test]
    fn clamp_max_f64() {
        // (input, max, expected)
        let cases = [(
            vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)],
            0.0,
            vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(0.0)],
        )];

        let func = ClampMaxFunction::default();
        for (in_data, max, expected) in cases {
            assert_clamped!(func, Float64Array, in_data, [max], expected);
        }
    }

    #[test]
    fn clamp_min_type_not_match() {
        let values = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
        let number_rows = values.len();
        // f64 input with an i64 bound.
        let min = -1i64;

        let args = vec![
            ColumnarValue::Array(Arc::new(Float64Array::from(values))),
            ColumnarValue::Scalar(min.into()),
        ];
        let outcome = ClampMinFunction::default().test_eval(args, number_rows);
        assert!(outcome.is_err());
    }

    #[test]
    fn clamp_max_type_not_match() {
        let values = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
        let number_rows = values.len();
        // f64 input with an i64 bound.
        let max = 1i64;

        let args = vec![
            ColumnarValue::Array(Arc::new(Float64Array::from(values))),
            ColumnarValue::Scalar(max.into()),
        ];
        let outcome = ClampMaxFunction::default().test_eval(args, number_rows);
        assert!(outcome.is_err());
    }
}