1use std::fmt::{self, Display};
16use std::sync::Arc;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use datafusion::arrow::array::{ArrayIter, PrimitiveArray};
20use datafusion::arrow::datatypes::DataType as ArrowDataType;
21use datafusion::logical_expr::Volatility;
22use datafusion_expr::Signature;
23use datafusion_expr::type_coercion::aggregates::NUMERICS;
24use datatypes::data_type::DataType;
25use datatypes::prelude::VectorRef;
26use datatypes::types::LogicalPrimitiveType;
27use datatypes::value::TryAsPrimitive;
28use datatypes::vectors::PrimitiveVector;
29use datatypes::with_match_primitive_type_id;
30use snafu::{OptionExt, ensure};
31
32use crate::function::{Function, FunctionContext};
33
/// Function name reported by [`ClampFunction`]'s `name` method.
const CLAMP_NAME: &str = "clamp";

/// Scalar function `clamp(value, min, max)`: limits each non-null numeric `value` to the
/// closed range `[min, max]`.
#[derive(Clone, Debug, Default)]
pub struct ClampFunction;
38
39fn ensure_constant_vector(vector: &VectorRef) -> Result<()> {
41 ensure!(
42 !vector.is_empty(),
43 InvalidFuncArgsSnafu {
44 err_msg: "Expect at least one value",
45 }
46 );
47
48 if vector.is_const() {
49 return Ok(());
50 }
51
52 let first = vector.get_ref(0);
53 for i in 1..vector.len() {
54 let v = vector.get_ref(i);
55 if first != v {
56 return InvalidFuncArgsSnafu {
57 err_msg: "All values in min/max argument must be identical",
58 }
59 .fail();
60 }
61 }
62
63 Ok(())
64}
65
66impl Function for ClampFunction {
67 fn name(&self) -> &str {
68 CLAMP_NAME
69 }
70
71 fn return_type(&self, input_types: &[ArrowDataType]) -> Result<ArrowDataType> {
72 Ok(input_types[0].clone())
74 }
75
76 fn signature(&self) -> Signature {
77 Signature::uniform(3, NUMERICS.to_vec(), Volatility::Immutable)
79 }
80
81 fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
82 ensure!(
83 columns.len() == 3,
84 InvalidFuncArgsSnafu {
85 err_msg: format!(
86 "The length of the args is not correct, expect exactly 3, have: {}",
87 columns.len()
88 ),
89 }
90 );
91 ensure!(
92 columns[0].data_type().is_numeric(),
93 InvalidFuncArgsSnafu {
94 err_msg: format!(
95 "The first arg's type is not numeric, have: {}",
96 columns[0].data_type()
97 ),
98 }
99 );
100 ensure!(
101 columns[0].data_type() == columns[1].data_type()
102 && columns[1].data_type() == columns[2].data_type(),
103 InvalidFuncArgsSnafu {
104 err_msg: format!(
105 "Arguments don't have identical types: {}, {}, {}",
106 columns[0].data_type(),
107 columns[1].data_type(),
108 columns[2].data_type()
109 ),
110 }
111 );
112
113 ensure_constant_vector(&columns[1])?;
114 ensure_constant_vector(&columns[2])?;
115
116 with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| {
117 let input_array = columns[0].to_arrow_array();
118 let input = input_array
119 .as_any()
120 .downcast_ref::<PrimitiveArray<<$S as LogicalPrimitiveType>::ArrowPrimitive>>()
121 .unwrap();
122
123 let min = TryAsPrimitive::<$S>::try_as_primitive(&columns[1].get(0))
124 .with_context(|| {
125 InvalidFuncArgsSnafu {
126 err_msg: "The second arg should not be none",
127 }
128 })?;
129 let max = TryAsPrimitive::<$S>::try_as_primitive(&columns[2].get(0))
130 .with_context(|| {
131 InvalidFuncArgsSnafu {
132 err_msg: "The third arg should not be none",
133 }
134 })?;
135
136 ensure!(
138 min <= max,
139 InvalidFuncArgsSnafu {
140 err_msg: format!(
141 "The second arg should be less than or equal to the third arg, have: {:?}, {:?}",
142 columns[1], columns[2]
143 ),
144 }
145 );
146
147 clamp_impl::<$S, true, true>(input, min, max)
148 },{
149 unreachable!()
150 })
151 }
152}
153
154impl Display for ClampFunction {
155 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
156 write!(f, "{}", CLAMP_NAME.to_ascii_uppercase())
157 }
158}
159
160fn clamp_impl<T: LogicalPrimitiveType, const CLAMP_MIN: bool, const CLAMP_MAX: bool>(
161 input: &PrimitiveArray<T::ArrowPrimitive>,
162 min: T::Native,
163 max: T::Native,
164) -> Result<VectorRef> {
165 let iter = ArrayIter::new(input);
166 let result = iter.map(|x| {
167 x.map(|x| {
168 if CLAMP_MIN && x < min {
169 min
170 } else if CLAMP_MAX && x > max {
171 max
172 } else {
173 x
174 }
175 })
176 });
177 let result = PrimitiveArray::<T::ArrowPrimitive>::from_iter(result);
178 Ok(Arc::new(PrimitiveVector::<T>::from(result)))
179}
180
/// Function name reported by [`ClampMinFunction`]'s `name` method.
const CLAMP_MIN_NAME: &str = "clamp_min";

/// Scalar function `clamp_min(value, min)`: raises each non-null numeric `value` that is
/// below `min` up to `min`.
#[derive(Clone, Debug, Default)]
pub struct ClampMinFunction;
185
186impl Function for ClampMinFunction {
187 fn name(&self) -> &str {
188 CLAMP_MIN_NAME
189 }
190
191 fn return_type(&self, input_types: &[ArrowDataType]) -> Result<ArrowDataType> {
192 Ok(input_types[0].clone())
193 }
194
195 fn signature(&self) -> Signature {
196 Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
198 }
199
200 fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
201 ensure!(
202 columns.len() == 2,
203 InvalidFuncArgsSnafu {
204 err_msg: format!(
205 "The length of the args is not correct, expect exactly 2, have: {}",
206 columns.len()
207 ),
208 }
209 );
210 ensure!(
211 columns[0].data_type().is_numeric(),
212 InvalidFuncArgsSnafu {
213 err_msg: format!(
214 "The first arg's type is not numeric, have: {}",
215 columns[0].data_type()
216 ),
217 }
218 );
219 ensure!(
220 columns[0].data_type() == columns[1].data_type(),
221 InvalidFuncArgsSnafu {
222 err_msg: format!(
223 "Arguments don't have identical types: {}, {}",
224 columns[0].data_type(),
225 columns[1].data_type()
226 ),
227 }
228 );
229
230 ensure_constant_vector(&columns[1])?;
231
232 with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| {
233 let input_array = columns[0].to_arrow_array();
234 let input = input_array
235 .as_any()
236 .downcast_ref::<PrimitiveArray<<$S as LogicalPrimitiveType>::ArrowPrimitive>>()
237 .unwrap();
238
239 let min = TryAsPrimitive::<$S>::try_as_primitive(&columns[1].get(0))
240 .with_context(|| {
241 InvalidFuncArgsSnafu {
242 err_msg: "The second arg (min) should not be none",
243 }
244 })?;
245 let max_dummy = <$S as LogicalPrimitiveType>::Native::default();
248
249 clamp_impl::<$S, true, false>(input, min, max_dummy)
250 },{
251 unreachable!()
252 })
253 }
254}
255
256impl Display for ClampMinFunction {
257 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
258 write!(f, "{}", CLAMP_MIN_NAME.to_ascii_uppercase())
259 }
260}
261
/// Function name reported by [`ClampMaxFunction`]'s `name` method.
const CLAMP_MAX_NAME: &str = "clamp_max";

/// Scalar function `clamp_max(value, max)`: lowers each non-null numeric `value` that is
/// above `max` down to `max`.
#[derive(Clone, Debug, Default)]
pub struct ClampMaxFunction;
266
267impl Function for ClampMaxFunction {
268 fn name(&self) -> &str {
269 CLAMP_MAX_NAME
270 }
271
272 fn return_type(&self, input_types: &[ArrowDataType]) -> Result<ArrowDataType> {
273 Ok(input_types[0].clone())
274 }
275
276 fn signature(&self) -> Signature {
277 Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
279 }
280
281 fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
282 ensure!(
283 columns.len() == 2,
284 InvalidFuncArgsSnafu {
285 err_msg: format!(
286 "The length of the args is not correct, expect exactly 2, have: {}",
287 columns.len()
288 ),
289 }
290 );
291 ensure!(
292 columns[0].data_type().is_numeric(),
293 InvalidFuncArgsSnafu {
294 err_msg: format!(
295 "The first arg's type is not numeric, have: {}",
296 columns[0].data_type()
297 ),
298 }
299 );
300 ensure!(
301 columns[0].data_type() == columns[1].data_type(),
302 InvalidFuncArgsSnafu {
303 err_msg: format!(
304 "Arguments don't have identical types: {}, {}",
305 columns[0].data_type(),
306 columns[1].data_type()
307 ),
308 }
309 );
310
311 ensure_constant_vector(&columns[1])?;
312
313 with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| {
314 let input_array = columns[0].to_arrow_array();
315 let input = input_array
316 .as_any()
317 .downcast_ref::<PrimitiveArray<<$S as LogicalPrimitiveType>::ArrowPrimitive>>()
318 .unwrap();
319
320 let max = TryAsPrimitive::<$S>::try_as_primitive(&columns[1].get(0))
321 .with_context(|| {
322 InvalidFuncArgsSnafu {
323 err_msg: "The second arg (max) should not be none",
324 }
325 })?;
326 let min_dummy = <$S as LogicalPrimitiveType>::Native::default();
329
330 clamp_impl::<$S, false, true>(input, min_dummy, max)
331 },{
332 unreachable!()
333 })
334 }
335}
336
337impl Display for ClampMaxFunction {
338 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
339 write!(f, "{}", CLAMP_MAX_NAME.to_ascii_uppercase())
340 }
341}
342
#[cfg(test)]
mod test {

    use std::sync::Arc;

    use datatypes::prelude::ScalarVector;
    use datatypes::vectors::{
        ConstantVector, Float64Vector, Int64Vector, StringVector, UInt64Vector,
    };

    use super::*;
    use crate::function::FunctionContext;

    #[test]
    fn clamp_i64() {
        let inputs = [
            (
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
                -1,
                10,
                vec![Some(-1), Some(-1), Some(-1), Some(0), Some(1), Some(2)],
            ),
            (
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
                0,
                0,
                vec![Some(0), Some(0), Some(0), Some(0), Some(0), Some(0)],
            ),
            (
                vec![Some(-3), None, Some(-1), None, None, Some(2)],
                -2,
                1,
                vec![Some(-2), None, Some(-1), None, None, Some(1)],
            ),
            (
                vec![None, None, None, None, None],
                0,
                1,
                vec![None, None, None, None, None],
            ),
        ];

        let func = ClampFunction;
        for (in_data, min, max, expected) in inputs {
            let args = [
                Arc::new(Int64Vector::from(in_data)) as _,
                Arc::new(Int64Vector::from_vec(vec![min])) as _,
                Arc::new(Int64Vector::from_vec(vec![max])) as _,
            ];
            let result = func
                .eval(&FunctionContext::default(), args.as_slice())
                .unwrap();
            let expected: VectorRef = Arc::new(Int64Vector::from(expected));
            assert_eq!(expected, result);
        }
    }

    #[test]
    fn clamp_u64() {
        let inputs = [
            (
                vec![Some(0), Some(1), Some(2), Some(3), Some(4), Some(5)],
                1,
                3,
                vec![Some(1), Some(1), Some(2), Some(3), Some(3), Some(3)],
            ),
            (
                vec![Some(0), Some(1), Some(2), Some(3), Some(4), Some(5)],
                0,
                0,
                vec![Some(0), Some(0), Some(0), Some(0), Some(0), Some(0)],
            ),
            (
                vec![Some(0), None, Some(2), None, None, Some(5)],
                1,
                3,
                vec![Some(1), None, Some(2), None, None, Some(3)],
            ),
            (
                vec![None, None, None, None, None],
                0,
                1,
                vec![None, None, None, None, None],
            ),
        ];

        let func = ClampFunction;
        for (in_data, min, max, expected) in inputs {
            let args = [
                Arc::new(UInt64Vector::from(in_data)) as _,
                Arc::new(UInt64Vector::from_vec(vec![min])) as _,
                Arc::new(UInt64Vector::from_vec(vec![max])) as _,
            ];
            let result = func
                .eval(&FunctionContext::default(), args.as_slice())
                .unwrap();
            let expected: VectorRef = Arc::new(UInt64Vector::from(expected));
            assert_eq!(expected, result);
        }
    }

    #[test]
    fn clamp_f64() {
        let inputs = [
            (
                vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)],
                -1.0,
                10.0,
                vec![Some(-1.0), Some(-1.0), Some(-1.0), Some(0.0), Some(1.0)],
            ),
            (
                vec![Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)],
                0.0,
                0.0,
                vec![Some(0.0), Some(0.0), Some(0.0), Some(0.0)],
            ),
            (
                vec![Some(-3.0), None, Some(-1.0), None, None, Some(2.0)],
                -2.0,
                1.0,
                vec![Some(-2.0), None, Some(-1.0), None, None, Some(1.0)],
            ),
            (
                vec![None, None, None, None, None],
                0.0,
                1.0,
                vec![None, None, None, None, None],
            ),
        ];

        let func = ClampFunction;
        for (in_data, min, max, expected) in inputs {
            let args = [
                Arc::new(Float64Vector::from(in_data)) as _,
                Arc::new(Float64Vector::from_vec(vec![min])) as _,
                Arc::new(Float64Vector::from_vec(vec![max])) as _,
            ];
            let result = func
                .eval(&FunctionContext::default(), args.as_slice())
                .unwrap();
            let expected: VectorRef = Arc::new(Float64Vector::from(expected));
            assert_eq!(expected, result);
        }
    }

    // Constant (broadcast) first argument; the vectors are Int64, hence the `_i64` suffix.
    #[test]
    fn clamp_const_i64() {
        let input = vec![Some(5)];
        let min = 2;
        let max = 4;

        let func = ClampFunction;
        let args = [
            Arc::new(ConstantVector::new(Arc::new(Int64Vector::from(input)), 1)) as _,
            Arc::new(Int64Vector::from_vec(vec![min])) as _,
            Arc::new(Int64Vector::from_vec(vec![max])) as _,
        ];
        let result = func
            .eval(&FunctionContext::default(), args.as_slice())
            .unwrap();
        let expected: VectorRef = Arc::new(Int64Vector::from(vec![Some(4)]));
        assert_eq!(expected, result);
    }

    // Bounds supplied as constant vectors take the `is_const` shortcut in
    // `ensure_constant_vector`.
    #[test]
    fn clamp_with_constant_bounds() {
        let func = ClampFunction;
        let args = [
            Arc::new(Int64Vector::from(vec![Some(-3), None, Some(0), Some(3)])) as _,
            Arc::new(ConstantVector::new(Arc::new(Int64Vector::from_vec(vec![-1])), 4)) as _,
            Arc::new(ConstantVector::new(Arc::new(Int64Vector::from_vec(vec![1])), 4)) as _,
        ];
        let result = func
            .eval(&FunctionContext::default(), args.as_slice())
            .unwrap();
        let expected: VectorRef =
            Arc::new(Int64Vector::from(vec![Some(-1), None, Some(0), Some(1)]));
        assert_eq!(expected, result);
    }

    #[test]
    fn clamp_invalid_min_max() {
        let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
        let min = 10.0;
        let max = -1.0;

        let func = ClampFunction;
        let args = [
            Arc::new(Float64Vector::from(input)) as _,
            Arc::new(Float64Vector::from_vec(vec![min])) as _,
            Arc::new(Float64Vector::from_vec(vec![max])) as _,
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        assert!(result.is_err());
    }

    #[test]
    fn clamp_type_not_match() {
        let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
        let min = -1;
        let max = 10;

        let func = ClampFunction;
        let args = [
            Arc::new(Float64Vector::from(input)) as _,
            Arc::new(Int64Vector::from_vec(vec![min])) as _,
            Arc::new(UInt64Vector::from_vec(vec![max])) as _,
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        assert!(result.is_err());
    }

    #[test]
    fn clamp_min_is_not_scalar() {
        let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
        let min = -10.0;
        let max = 1.0;

        let func = ClampFunction;
        let args = [
            Arc::new(Float64Vector::from(input)) as _,
            Arc::new(Float64Vector::from_vec(vec![min, max])) as _,
            Arc::new(Float64Vector::from_vec(vec![max, min])) as _,
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        assert!(result.is_err());
    }

    // A null bound must be rejected ("should not be none").
    #[test]
    fn clamp_null_bound() {
        let func = ClampFunction;
        let args = [
            Arc::new(Int64Vector::from_vec(vec![1, 2, 3])) as _,
            Arc::new(Int64Vector::from(vec![None::<i64>])) as _,
            Arc::new(Int64Vector::from_vec(vec![2])) as _,
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        assert!(result.is_err());
    }

    // An empty bound vector must be rejected ("Expect at least one value").
    #[test]
    fn clamp_empty_bound() {
        let func = ClampFunction;
        let args = [
            Arc::new(Int64Vector::from_vec(vec![1, 2, 3])) as _,
            Arc::new(Int64Vector::from_vec(Vec::<i64>::new())) as _,
            Arc::new(Int64Vector::from_vec(vec![2])) as _,
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        assert!(result.is_err());
    }

    #[test]
    fn clamp_no_max() {
        let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
        let min = -10.0;

        let func = ClampFunction;
        let args = [
            Arc::new(Float64Vector::from(input)) as _,
            Arc::new(Float64Vector::from_vec(vec![min])) as _,
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        assert!(result.is_err());
    }

    #[test]
    fn clamp_on_string() {
        let input = vec![Some("foo"), Some("foo"), Some("foo"), Some("foo")];

        let func = ClampFunction;
        let args = [
            Arc::new(StringVector::from(input)) as _,
            Arc::new(StringVector::from_vec(vec!["bar"])) as _,
            Arc::new(StringVector::from_vec(vec!["baz"])) as _,
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        assert!(result.is_err());
    }

    #[test]
    fn clamp_min_i64() {
        let inputs = [
            (
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
                -1,
                vec![Some(-1), Some(-1), Some(-1), Some(0), Some(1), Some(2)],
            ),
            (
                vec![Some(-3), None, Some(-1), None, None, Some(2)],
                -2,
                vec![Some(-2), None, Some(-1), None, None, Some(2)],
            ),
        ];

        let func = ClampMinFunction;
        for (in_data, min, expected) in inputs {
            let args = [
                Arc::new(Int64Vector::from(in_data)) as _,
                Arc::new(Int64Vector::from_vec(vec![min])) as _,
            ];
            let result = func
                .eval(&FunctionContext::default(), args.as_slice())
                .unwrap();
            let expected: VectorRef = Arc::new(Int64Vector::from(expected));
            assert_eq!(expected, result);
        }
    }

    #[test]
    fn clamp_max_i64() {
        let inputs = [
            (
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
                1,
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(1)],
            ),
            (
                vec![Some(-3), None, Some(-1), None, None, Some(2)],
                0,
                vec![Some(-3), None, Some(-1), None, None, Some(0)],
            ),
        ];

        let func = ClampMaxFunction;
        for (in_data, max, expected) in inputs {
            let args = [
                Arc::new(Int64Vector::from(in_data)) as _,
                Arc::new(Int64Vector::from_vec(vec![max])) as _,
            ];
            let result = func
                .eval(&FunctionContext::default(), args.as_slice())
                .unwrap();
            let expected: VectorRef = Arc::new(Int64Vector::from(expected));
            assert_eq!(expected, result);
        }
    }

    #[test]
    fn clamp_min_f64() {
        let inputs = [(
            vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)],
            -1.0,
            vec![Some(-1.0), Some(-1.0), Some(-1.0), Some(0.0), Some(1.0)],
        )];

        let func = ClampMinFunction;
        for (in_data, min, expected) in inputs {
            let args = [
                Arc::new(Float64Vector::from(in_data)) as _,
                Arc::new(Float64Vector::from_vec(vec![min])) as _,
            ];
            let result = func
                .eval(&FunctionContext::default(), args.as_slice())
                .unwrap();
            let expected: VectorRef = Arc::new(Float64Vector::from(expected));
            assert_eq!(expected, result);
        }
    }

    #[test]
    fn clamp_max_f64() {
        let inputs = [(
            vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)],
            0.0,
            vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(0.0)],
        )];

        let func = ClampMaxFunction;
        for (in_data, max, expected) in inputs {
            let args = [
                Arc::new(Float64Vector::from(in_data)) as _,
                Arc::new(Float64Vector::from_vec(vec![max])) as _,
            ];
            let result = func
                .eval(&FunctionContext::default(), args.as_slice())
                .unwrap();
            let expected: VectorRef = Arc::new(Float64Vector::from(expected));
            assert_eq!(expected, result);
        }
    }

    #[test]
    fn clamp_max_with_constant_bound() {
        let func = ClampMaxFunction;
        let args = [
            Arc::new(Float64Vector::from(vec![Some(-1.5), None, Some(2.5)])) as _,
            Arc::new(ConstantVector::new(Arc::new(Float64Vector::from_vec(vec![0.5])), 3)) as _,
        ];
        let result = func
            .eval(&FunctionContext::default(), args.as_slice())
            .unwrap();
        let expected: VectorRef = Arc::new(Float64Vector::from(vec![Some(-1.5), None, Some(0.5)]));
        assert_eq!(expected, result);
    }

    #[test]
    fn clamp_min_type_not_match() {
        let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
        let min = -1;

        let func = ClampMinFunction;
        let args = [
            Arc::new(Float64Vector::from(input)) as _,
            Arc::new(Int64Vector::from_vec(vec![min])) as _,
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        assert!(result.is_err());
    }

    #[test]
    fn clamp_max_type_not_match() {
        let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
        let max = 1;

        let func = ClampMaxFunction;
        let args = [
            Arc::new(Float64Vector::from(input)) as _,
            Arc::new(Int64Vector::from_vec(vec![max])) as _,
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        assert!(result.is_err());
    }

    // `clamp_min` must reject a bound column whose rows differ.
    #[test]
    fn clamp_min_func_bound_not_scalar() {
        let func = ClampMinFunction;
        let args = [
            Arc::new(Int64Vector::from_vec(vec![1, 2, 3])) as _,
            Arc::new(Int64Vector::from_vec(vec![0, 1, 2])) as _,
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        assert!(result.is_err());
    }

    // `clamp_max` must reject a null bound ("should not be none").
    #[test]
    fn clamp_max_func_null_bound() {
        let func = ClampMaxFunction;
        let args = [
            Arc::new(Int64Vector::from_vec(vec![1, 2, 3])) as _,
            Arc::new(Int64Vector::from(vec![None::<i64>])) as _,
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        assert!(result.is_err());
    }
}