1use std::fmt::{self, Display};
16use std::sync::Arc;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use common_query::prelude::Signature;
20use datafusion::arrow::array::{ArrayIter, PrimitiveArray};
21use datafusion::logical_expr::Volatility;
22use datatypes::data_type::{ConcreteDataType, DataType};
23use datatypes::prelude::VectorRef;
24use datatypes::types::LogicalPrimitiveType;
25use datatypes::value::TryAsPrimitive;
26use datatypes::vectors::PrimitiveVector;
27use datatypes::with_match_primitive_type_id;
28use snafu::{ensure, OptionExt};
29
30use crate::function::{Function, FunctionContext};
31
/// Name returned by `ClampFunction::name` and upper-cased for its `Display` output.
const CLAMP_NAME: &str = "clamp";

/// Scalar function `clamp(value, min, max)`: limits numeric values to the
/// closed range `[min, max]`.
#[derive(Debug, Default, Clone)]
pub struct ClampFunction;
36
37impl Function for ClampFunction {
38 fn name(&self) -> &str {
39 CLAMP_NAME
40 }
41
42 fn return_type(&self, input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
43 Ok(input_types[0].clone())
45 }
46
47 fn signature(&self) -> Signature {
48 Signature::uniform(3, ConcreteDataType::numerics(), Volatility::Immutable)
50 }
51
52 fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
53 ensure!(
54 columns.len() == 3,
55 InvalidFuncArgsSnafu {
56 err_msg: format!(
57 "The length of the args is not correct, expect exactly 3, have: {}",
58 columns.len()
59 ),
60 }
61 );
62 ensure!(
63 columns[0].data_type().is_numeric(),
64 InvalidFuncArgsSnafu {
65 err_msg: format!(
66 "The first arg's type is not numeric, have: {}",
67 columns[0].data_type()
68 ),
69 }
70 );
71 ensure!(
72 columns[0].data_type() == columns[1].data_type()
73 && columns[1].data_type() == columns[2].data_type(),
74 InvalidFuncArgsSnafu {
75 err_msg: format!(
76 "Arguments don't have identical types: {}, {}, {}",
77 columns[0].data_type(),
78 columns[1].data_type(),
79 columns[2].data_type()
80 ),
81 }
82 );
83 ensure!(
84 columns[1].len() == 1 && columns[2].len() == 1,
85 InvalidFuncArgsSnafu {
86 err_msg: format!(
87 "The second and third args should be scalar, have: {:?}, {:?}",
88 columns[1], columns[2]
89 ),
90 }
91 );
92
93 with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| {
94 let input_array = columns[0].to_arrow_array();
95 let input = input_array
96 .as_any()
97 .downcast_ref::<PrimitiveArray<<$S as LogicalPrimitiveType>::ArrowPrimitive>>()
98 .unwrap();
99
100 let min = TryAsPrimitive::<$S>::try_as_primitive(&columns[1].get(0))
101 .with_context(|| {
102 InvalidFuncArgsSnafu {
103 err_msg: "The second arg should not be none",
104 }
105 })?;
106 let max = TryAsPrimitive::<$S>::try_as_primitive(&columns[2].get(0))
107 .with_context(|| {
108 InvalidFuncArgsSnafu {
109 err_msg: "The third arg should not be none",
110 }
111 })?;
112
113 ensure!(
115 min <= max,
116 InvalidFuncArgsSnafu {
117 err_msg: format!(
118 "The second arg should be less than or equal to the third arg, have: {:?}, {:?}",
119 columns[1], columns[2]
120 ),
121 }
122 );
123
124 clamp_impl::<$S, true, true>(input, min, max)
125 },{
126 unreachable!()
127 })
128 }
129}
130
131impl Display for ClampFunction {
132 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
133 write!(f, "{}", CLAMP_NAME.to_ascii_uppercase())
134 }
135}
136
137fn clamp_impl<T: LogicalPrimitiveType, const CLAMP_MIN: bool, const CLAMP_MAX: bool>(
138 input: &PrimitiveArray<T::ArrowPrimitive>,
139 min: T::Native,
140 max: T::Native,
141) -> Result<VectorRef> {
142 let iter = ArrayIter::new(input);
143 let result = iter.map(|x| {
144 x.map(|x| {
145 if CLAMP_MIN && x < min {
146 min
147 } else if CLAMP_MAX && x > max {
148 max
149 } else {
150 x
151 }
152 })
153 });
154 let result = PrimitiveArray::<T::ArrowPrimitive>::from_iter(result);
155 Ok(Arc::new(PrimitiveVector::<T>::from(result)))
156}
157
/// Name returned by `ClampMinFunction::name` and upper-cased for its `Display` output.
const CLAMP_MIN_NAME: &str = "clamp_min";

/// Scalar function `clamp_min(value, min)`: raises numeric values that are
/// below `min` up to `min`.
#[derive(Debug, Default, Clone)]
pub struct ClampMinFunction;
162
163impl Function for ClampMinFunction {
164 fn name(&self) -> &str {
165 CLAMP_MIN_NAME
166 }
167
168 fn return_type(&self, input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
169 Ok(input_types[0].clone())
170 }
171
172 fn signature(&self) -> Signature {
173 Signature::uniform(2, ConcreteDataType::numerics(), Volatility::Immutable)
175 }
176
177 fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
178 ensure!(
179 columns.len() == 2,
180 InvalidFuncArgsSnafu {
181 err_msg: format!(
182 "The length of the args is not correct, expect exactly 2, have: {}",
183 columns.len()
184 ),
185 }
186 );
187 ensure!(
188 columns[0].data_type().is_numeric(),
189 InvalidFuncArgsSnafu {
190 err_msg: format!(
191 "The first arg's type is not numeric, have: {}",
192 columns[0].data_type()
193 ),
194 }
195 );
196 ensure!(
197 columns[0].data_type() == columns[1].data_type(),
198 InvalidFuncArgsSnafu {
199 err_msg: format!(
200 "Arguments don't have identical types: {}, {}",
201 columns[0].data_type(),
202 columns[1].data_type()
203 ),
204 }
205 );
206 ensure!(
207 columns[1].len() == 1,
208 InvalidFuncArgsSnafu {
209 err_msg: format!(
210 "The second arg (min) should be scalar, have: {:?}",
211 columns[1]
212 ),
213 }
214 );
215
216 with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| {
217 let input_array = columns[0].to_arrow_array();
218 let input = input_array
219 .as_any()
220 .downcast_ref::<PrimitiveArray<<$S as LogicalPrimitiveType>::ArrowPrimitive>>()
221 .unwrap();
222
223 let min = TryAsPrimitive::<$S>::try_as_primitive(&columns[1].get(0))
224 .with_context(|| {
225 InvalidFuncArgsSnafu {
226 err_msg: "The second arg (min) should not be none",
227 }
228 })?;
229 let max_dummy = <$S as LogicalPrimitiveType>::Native::default();
232
233 clamp_impl::<$S, true, false>(input, min, max_dummy)
234 },{
235 unreachable!()
236 })
237 }
238}
239
240impl Display for ClampMinFunction {
241 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
242 write!(f, "{}", CLAMP_MIN_NAME.to_ascii_uppercase())
243 }
244}
245
/// Name returned by `ClampMaxFunction::name` and upper-cased for its `Display` output.
const CLAMP_MAX_NAME: &str = "clamp_max";

/// Scalar function `clamp_max(value, max)`: lowers numeric values that are
/// above `max` down to `max`.
#[derive(Debug, Default, Clone)]
pub struct ClampMaxFunction;
250
251impl Function for ClampMaxFunction {
252 fn name(&self) -> &str {
253 CLAMP_MAX_NAME
254 }
255
256 fn return_type(&self, input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
257 Ok(input_types[0].clone())
258 }
259
260 fn signature(&self) -> Signature {
261 Signature::uniform(2, ConcreteDataType::numerics(), Volatility::Immutable)
263 }
264
265 fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
266 ensure!(
267 columns.len() == 2,
268 InvalidFuncArgsSnafu {
269 err_msg: format!(
270 "The length of the args is not correct, expect exactly 2, have: {}",
271 columns.len()
272 ),
273 }
274 );
275 ensure!(
276 columns[0].data_type().is_numeric(),
277 InvalidFuncArgsSnafu {
278 err_msg: format!(
279 "The first arg's type is not numeric, have: {}",
280 columns[0].data_type()
281 ),
282 }
283 );
284 ensure!(
285 columns[0].data_type() == columns[1].data_type(),
286 InvalidFuncArgsSnafu {
287 err_msg: format!(
288 "Arguments don't have identical types: {}, {}",
289 columns[0].data_type(),
290 columns[1].data_type()
291 ),
292 }
293 );
294 ensure!(
295 columns[1].len() == 1,
296 InvalidFuncArgsSnafu {
297 err_msg: format!(
298 "The second arg (max) should be scalar, have: {:?}",
299 columns[1]
300 ),
301 }
302 );
303
304 with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| {
305 let input_array = columns[0].to_arrow_array();
306 let input = input_array
307 .as_any()
308 .downcast_ref::<PrimitiveArray<<$S as LogicalPrimitiveType>::ArrowPrimitive>>()
309 .unwrap();
310
311 let max = TryAsPrimitive::<$S>::try_as_primitive(&columns[1].get(0))
312 .with_context(|| {
313 InvalidFuncArgsSnafu {
314 err_msg: "The second arg (max) should not be none",
315 }
316 })?;
317 let min_dummy = <$S as LogicalPrimitiveType>::Native::default();
320
321 clamp_impl::<$S, false, true>(input, min_dummy, max)
322 },{
323 unreachable!()
324 })
325 }
326}
327
328impl Display for ClampMaxFunction {
329 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
330 write!(f, "{}", CLAMP_MAX_NAME.to_ascii_uppercase())
331 }
332}
333
#[cfg(test)]
mod test {

    use std::sync::Arc;

    use datatypes::prelude::ScalarVector;
    use datatypes::vectors::{
        ConstantVector, Float64Vector, Int64Vector, StringVector, UInt64Vector,
    };

    use super::*;
    use crate::function::FunctionContext;

    // Small builders that keep the individual tests focused on inputs and outputs.

    /// Nullable `i64` column.
    fn i64_column(values: Vec<Option<i64>>) -> VectorRef {
        Arc::new(Int64Vector::from(values))
    }

    /// Single-row `i64` vector, used as a bound argument.
    fn i64_scalar(value: i64) -> VectorRef {
        Arc::new(Int64Vector::from_vec(vec![value]))
    }

    /// Nullable `u64` column.
    fn u64_column(values: Vec<Option<u64>>) -> VectorRef {
        Arc::new(UInt64Vector::from(values))
    }

    /// Single-row `u64` vector, used as a bound argument.
    fn u64_scalar(value: u64) -> VectorRef {
        Arc::new(UInt64Vector::from_vec(vec![value]))
    }

    /// Nullable `f64` column.
    fn f64_column(values: Vec<Option<f64>>) -> VectorRef {
        Arc::new(Float64Vector::from(values))
    }

    /// Single-row `f64` vector, used as a bound argument.
    fn f64_scalar(value: f64) -> VectorRef {
        Arc::new(Float64Vector::from_vec(vec![value]))
    }

    #[test]
    fn clamp_i64() {
        let cases = [
            (
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
                -1,
                10,
                vec![Some(-1), Some(-1), Some(-1), Some(0), Some(1), Some(2)],
            ),
            (
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
                0,
                0,
                vec![Some(0), Some(0), Some(0), Some(0), Some(0), Some(0)],
            ),
            (
                vec![Some(-3), None, Some(-1), None, None, Some(2)],
                -2,
                1,
                vec![Some(-2), None, Some(-1), None, None, Some(1)],
            ),
            (
                vec![None, None, None, None, None],
                0,
                1,
                vec![None, None, None, None, None],
            ),
        ];

        let ctx = FunctionContext::default();
        for (values, lower, upper, want) in cases {
            let args = [i64_column(values), i64_scalar(lower), i64_scalar(upper)];
            let got = ClampFunction.eval(&ctx, &args).unwrap();
            assert_eq!(i64_column(want), got);
        }
    }

    #[test]
    fn clamp_u64() {
        let cases = [
            (
                vec![Some(0), Some(1), Some(2), Some(3), Some(4), Some(5)],
                1,
                3,
                vec![Some(1), Some(1), Some(2), Some(3), Some(3), Some(3)],
            ),
            (
                vec![Some(0), Some(1), Some(2), Some(3), Some(4), Some(5)],
                0,
                0,
                vec![Some(0), Some(0), Some(0), Some(0), Some(0), Some(0)],
            ),
            (
                vec![Some(0), None, Some(2), None, None, Some(5)],
                1,
                3,
                vec![Some(1), None, Some(2), None, None, Some(3)],
            ),
            (
                vec![None, None, None, None, None],
                0,
                1,
                vec![None, None, None, None, None],
            ),
        ];

        let ctx = FunctionContext::default();
        for (values, lower, upper, want) in cases {
            let args = [u64_column(values), u64_scalar(lower), u64_scalar(upper)];
            let got = ClampFunction.eval(&ctx, &args).unwrap();
            assert_eq!(u64_column(want), got);
        }
    }

    #[test]
    fn clamp_f64() {
        let cases = [
            (
                vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)],
                -1.0,
                10.0,
                vec![Some(-1.0), Some(-1.0), Some(-1.0), Some(0.0), Some(1.0)],
            ),
            (
                vec![Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)],
                0.0,
                0.0,
                vec![Some(0.0), Some(0.0), Some(0.0), Some(0.0)],
            ),
            (
                vec![Some(-3.0), None, Some(-1.0), None, None, Some(2.0)],
                -2.0,
                1.0,
                vec![Some(-2.0), None, Some(-1.0), None, None, Some(1.0)],
            ),
            (
                vec![None, None, None, None, None],
                0.0,
                1.0,
                vec![None, None, None, None, None],
            ),
        ];

        let ctx = FunctionContext::default();
        for (values, lower, upper, want) in cases {
            let args = [f64_column(values), f64_scalar(lower), f64_scalar(upper)];
            let got = ClampFunction.eval(&ctx, &args).unwrap();
            assert_eq!(f64_column(want), got);
        }
    }

    // The value column is a `ConstantVector` wrapping a one-row `Int64Vector`.
    #[test]
    fn clamp_const_i32() {
        let constant: VectorRef = Arc::new(ConstantVector::new(i64_column(vec![Some(5)]), 1));
        let args = [constant, i64_scalar(2), i64_scalar(4)];

        let got = ClampFunction
            .eval(&FunctionContext::default(), &args)
            .unwrap();
        assert_eq!(i64_column(vec![Some(4)]), got);
    }

    // `min` greater than `max` must be rejected.
    #[test]
    fn clamp_invalid_min_max() {
        let values = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
        let args = [f64_column(values), f64_scalar(10.0), f64_scalar(-1.0)];

        let outcome = ClampFunction.eval(&FunctionContext::default(), &args);
        assert!(outcome.is_err());
    }

    // Mixed f64 / i64 / u64 arguments must be rejected.
    #[test]
    fn clamp_type_not_match() {
        let values = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
        let args = [f64_column(values), i64_scalar(-1), u64_scalar(10)];

        let outcome = ClampFunction.eval(&FunctionContext::default(), &args);
        assert!(outcome.is_err());
    }

    // A two-row `min` vector is not a scalar and must be rejected.
    #[test]
    fn clamp_min_is_not_scalar() {
        let values = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
        let two_row_min: VectorRef = Arc::new(Float64Vector::from_vec(vec![-10.0, -10.0]));
        let args = [f64_column(values), two_row_min, f64_scalar(1.0)];

        let outcome = ClampFunction.eval(&FunctionContext::default(), &args);
        assert!(outcome.is_err());
    }

    // Only two arguments are supplied to the three-argument function.
    #[test]
    fn clamp_no_max() {
        let values = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
        let args = [f64_column(values), f64_scalar(-10.0)];

        let outcome = ClampFunction.eval(&FunctionContext::default(), &args);
        assert!(outcome.is_err());
    }

    // String columns are not numeric and must be rejected.
    #[test]
    fn clamp_on_string() {
        let text = vec![Some("foo"), Some("foo"), Some("foo"), Some("foo")];
        let args: [VectorRef; 3] = [
            Arc::new(StringVector::from(text)),
            Arc::new(StringVector::from_vec(vec!["bar"])),
            Arc::new(StringVector::from_vec(vec!["baz"])),
        ];

        let outcome = ClampFunction.eval(&FunctionContext::default(), &args);
        assert!(outcome.is_err());
    }

    #[test]
    fn clamp_min_i64() {
        let cases = [
            (
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
                -1,
                vec![Some(-1), Some(-1), Some(-1), Some(0), Some(1), Some(2)],
            ),
            (
                vec![Some(-3), None, Some(-1), None, None, Some(2)],
                -2,
                vec![Some(-2), None, Some(-1), None, None, Some(2)],
            ),
        ];

        let ctx = FunctionContext::default();
        for (values, lower, want) in cases {
            let args = [i64_column(values), i64_scalar(lower)];
            let got = ClampMinFunction.eval(&ctx, &args).unwrap();
            assert_eq!(i64_column(want), got);
        }
    }

    #[test]
    fn clamp_max_i64() {
        let cases = [
            (
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
                1,
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(1)],
            ),
            (
                vec![Some(-3), None, Some(-1), None, None, Some(2)],
                0,
                vec![Some(-3), None, Some(-1), None, None, Some(0)],
            ),
        ];

        let ctx = FunctionContext::default();
        for (values, upper, want) in cases {
            let args = [i64_column(values), i64_scalar(upper)];
            let got = ClampMaxFunction.eval(&ctx, &args).unwrap();
            assert_eq!(i64_column(want), got);
        }
    }

    #[test]
    fn clamp_min_f64() {
        let cases = [(
            vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)],
            -1.0,
            vec![Some(-1.0), Some(-1.0), Some(-1.0), Some(0.0), Some(1.0)],
        )];

        let ctx = FunctionContext::default();
        for (values, lower, want) in cases {
            let args = [f64_column(values), f64_scalar(lower)];
            let got = ClampMinFunction.eval(&ctx, &args).unwrap();
            assert_eq!(f64_column(want), got);
        }
    }

    #[test]
    fn clamp_max_f64() {
        let cases = [(
            vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)],
            0.0,
            vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(0.0)],
        )];

        let ctx = FunctionContext::default();
        for (values, upper, want) in cases {
            let args = [f64_column(values), f64_scalar(upper)];
            let got = ClampMaxFunction.eval(&ctx, &args).unwrap();
            assert_eq!(f64_column(want), got);
        }
    }

    // An i64 bound on an f64 column must be rejected.
    #[test]
    fn clamp_min_type_not_match() {
        let values = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
        let args = [f64_column(values), i64_scalar(-1)];

        let outcome = ClampMinFunction.eval(&FunctionContext::default(), &args);
        assert!(outcome.is_err());
    }

    // An i64 bound on an f64 column must be rejected.
    #[test]
    fn clamp_max_type_not_match() {
        let values = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
        let args = [f64_column(values), i64_scalar(1)];

        let outcome = ClampMaxFunction.eval(&FunctionContext::default(), &args);
        assert!(outcome.is_err());
    }
}