1use std::fmt::{self, Display};
16use std::sync::Arc;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use common_query::prelude::Signature;
20use datafusion::arrow::array::{ArrayIter, PrimitiveArray};
21use datafusion::logical_expr::Volatility;
22use datatypes::data_type::{ConcreteDataType, DataType};
23use datatypes::prelude::VectorRef;
24use datatypes::types::LogicalPrimitiveType;
25use datatypes::value::TryAsPrimitive;
26use datatypes::vectors::PrimitiveVector;
27use datatypes::with_match_primitive_type_id;
28use snafu::{ensure, OptionExt};
29
30use crate::function::{Function, FunctionContext};
31
/// Scalar function `clamp(value, min, max)`.
///
/// Limits every non-null value of a numeric column to the inclusive range
/// `[min, max]`; null inputs stay null. `min` and `max` must be scalars of the
/// same type as the input column.
#[derive(Clone, Debug, Default)]
pub struct ClampFunction;
34
/// Name reported by [`ClampFunction`]; its `Display` output is the upper-cased form.
const CLAMP_NAME: &str = "clamp";
36
37impl Function for ClampFunction {
38 fn name(&self) -> &str {
39 CLAMP_NAME
40 }
41
42 fn return_type(&self, input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
43 Ok(input_types[0].clone())
45 }
46
47 fn signature(&self) -> Signature {
48 Signature::uniform(3, ConcreteDataType::numerics(), Volatility::Immutable)
50 }
51
52 fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
53 ensure!(
54 columns.len() == 3,
55 InvalidFuncArgsSnafu {
56 err_msg: format!(
57 "The length of the args is not correct, expect exactly 3, have: {}",
58 columns.len()
59 ),
60 }
61 );
62 ensure!(
63 columns[0].data_type().is_numeric(),
64 InvalidFuncArgsSnafu {
65 err_msg: format!(
66 "The first arg's type is not numeric, have: {}",
67 columns[0].data_type()
68 ),
69 }
70 );
71 ensure!(
72 columns[0].data_type() == columns[1].data_type()
73 && columns[1].data_type() == columns[2].data_type(),
74 InvalidFuncArgsSnafu {
75 err_msg: format!(
76 "Arguments don't have identical types: {}, {}, {}",
77 columns[0].data_type(),
78 columns[1].data_type(),
79 columns[2].data_type()
80 ),
81 }
82 );
83 ensure!(
84 (columns[1].len() == 1 || columns[1].is_const())
85 && (columns[2].len() == 1 || columns[2].is_const()),
86 InvalidFuncArgsSnafu {
87 err_msg: format!(
88 "The second and third args should be scalar, have: {:?}, {:?}",
89 columns[1], columns[2]
90 ),
91 }
92 );
93
94 with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| {
95 let input_array = columns[0].to_arrow_array();
96 let input = input_array
97 .as_any()
98 .downcast_ref::<PrimitiveArray<<$S as LogicalPrimitiveType>::ArrowPrimitive>>()
99 .unwrap();
100
101 let min = TryAsPrimitive::<$S>::try_as_primitive(&columns[1].get(0))
102 .with_context(|| {
103 InvalidFuncArgsSnafu {
104 err_msg: "The second arg should not be none",
105 }
106 })?;
107 let max = TryAsPrimitive::<$S>::try_as_primitive(&columns[2].get(0))
108 .with_context(|| {
109 InvalidFuncArgsSnafu {
110 err_msg: "The third arg should not be none",
111 }
112 })?;
113
114 ensure!(
116 min <= max,
117 InvalidFuncArgsSnafu {
118 err_msg: format!(
119 "The second arg should be less than or equal to the third arg, have: {:?}, {:?}",
120 columns[1], columns[2]
121 ),
122 }
123 );
124
125 clamp_impl::<$S, true, true>(input, min, max)
126 },{
127 unreachable!()
128 })
129 }
130}
131
132impl Display for ClampFunction {
133 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
134 write!(f, "{}", CLAMP_NAME.to_ascii_uppercase())
135 }
136}
137
138fn clamp_impl<T: LogicalPrimitiveType, const CLAMP_MIN: bool, const CLAMP_MAX: bool>(
139 input: &PrimitiveArray<T::ArrowPrimitive>,
140 min: T::Native,
141 max: T::Native,
142) -> Result<VectorRef> {
143 let iter = ArrayIter::new(input);
144 let result = iter.map(|x| {
145 x.map(|x| {
146 if CLAMP_MIN && x < min {
147 min
148 } else if CLAMP_MAX && x > max {
149 max
150 } else {
151 x
152 }
153 })
154 });
155 let result = PrimitiveArray::<T::ArrowPrimitive>::from_iter(result);
156 Ok(Arc::new(PrimitiveVector::<T>::from(result)))
157}
158
/// Scalar function `clamp_min(value, min)`.
///
/// Raises every non-null value of a numeric column that is below `min` up to
/// `min`; null inputs stay null. `min` must be a scalar of the same type as
/// the input column.
#[derive(Clone, Debug, Default)]
pub struct ClampMinFunction;
161
/// Name reported by [`ClampMinFunction`]; its `Display` output is the upper-cased form.
const CLAMP_MIN_NAME: &str = "clamp_min";
163
164impl Function for ClampMinFunction {
165 fn name(&self) -> &str {
166 CLAMP_MIN_NAME
167 }
168
169 fn return_type(&self, input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
170 Ok(input_types[0].clone())
171 }
172
173 fn signature(&self) -> Signature {
174 Signature::uniform(2, ConcreteDataType::numerics(), Volatility::Immutable)
176 }
177
178 fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
179 ensure!(
180 columns.len() == 2,
181 InvalidFuncArgsSnafu {
182 err_msg: format!(
183 "The length of the args is not correct, expect exactly 2, have: {}",
184 columns.len()
185 ),
186 }
187 );
188 ensure!(
189 columns[0].data_type().is_numeric(),
190 InvalidFuncArgsSnafu {
191 err_msg: format!(
192 "The first arg's type is not numeric, have: {}",
193 columns[0].data_type()
194 ),
195 }
196 );
197 ensure!(
198 columns[0].data_type() == columns[1].data_type(),
199 InvalidFuncArgsSnafu {
200 err_msg: format!(
201 "Arguments don't have identical types: {}, {}",
202 columns[0].data_type(),
203 columns[1].data_type()
204 ),
205 }
206 );
207 ensure!(
208 columns[1].len() == 1 || columns[1].is_const(),
209 InvalidFuncArgsSnafu {
210 err_msg: format!(
211 "The second arg (min) should be scalar, have: {:?}",
212 columns[1]
213 ),
214 }
215 );
216
217 with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| {
218 let input_array = columns[0].to_arrow_array();
219 let input = input_array
220 .as_any()
221 .downcast_ref::<PrimitiveArray<<$S as LogicalPrimitiveType>::ArrowPrimitive>>()
222 .unwrap();
223
224 let min = TryAsPrimitive::<$S>::try_as_primitive(&columns[1].get(0))
225 .with_context(|| {
226 InvalidFuncArgsSnafu {
227 err_msg: "The second arg (min) should not be none",
228 }
229 })?;
230 let max_dummy = <$S as LogicalPrimitiveType>::Native::default();
233
234 clamp_impl::<$S, true, false>(input, min, max_dummy)
235 },{
236 unreachable!()
237 })
238 }
239}
240
241impl Display for ClampMinFunction {
242 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
243 write!(f, "{}", CLAMP_MIN_NAME.to_ascii_uppercase())
244 }
245}
246
/// Scalar function `clamp_max(value, max)`.
///
/// Lowers every non-null value of a numeric column that is above `max` down
/// to `max`; null inputs stay null. `max` must be a scalar of the same type as
/// the input column.
#[derive(Clone, Debug, Default)]
pub struct ClampMaxFunction;
249
/// Name reported by [`ClampMaxFunction`]; its `Display` output is the upper-cased form.
const CLAMP_MAX_NAME: &str = "clamp_max";
251
252impl Function for ClampMaxFunction {
253 fn name(&self) -> &str {
254 CLAMP_MAX_NAME
255 }
256
257 fn return_type(&self, input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
258 Ok(input_types[0].clone())
259 }
260
261 fn signature(&self) -> Signature {
262 Signature::uniform(2, ConcreteDataType::numerics(), Volatility::Immutable)
264 }
265
266 fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
267 ensure!(
268 columns.len() == 2,
269 InvalidFuncArgsSnafu {
270 err_msg: format!(
271 "The length of the args is not correct, expect exactly 2, have: {}",
272 columns.len()
273 ),
274 }
275 );
276 ensure!(
277 columns[0].data_type().is_numeric(),
278 InvalidFuncArgsSnafu {
279 err_msg: format!(
280 "The first arg's type is not numeric, have: {}",
281 columns[0].data_type()
282 ),
283 }
284 );
285 ensure!(
286 columns[0].data_type() == columns[1].data_type(),
287 InvalidFuncArgsSnafu {
288 err_msg: format!(
289 "Arguments don't have identical types: {}, {}",
290 columns[0].data_type(),
291 columns[1].data_type()
292 ),
293 }
294 );
295 ensure!(
296 columns[1].len() == 1 || columns[1].is_const(),
297 InvalidFuncArgsSnafu {
298 err_msg: format!(
299 "The second arg (max) should be scalar, have: {:?}",
300 columns[1]
301 ),
302 }
303 );
304
305 with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| {
306 let input_array = columns[0].to_arrow_array();
307 let input = input_array
308 .as_any()
309 .downcast_ref::<PrimitiveArray<<$S as LogicalPrimitiveType>::ArrowPrimitive>>()
310 .unwrap();
311
312 let max = TryAsPrimitive::<$S>::try_as_primitive(&columns[1].get(0))
313 .with_context(|| {
314 InvalidFuncArgsSnafu {
315 err_msg: "The second arg (max) should not be none",
316 }
317 })?;
318 let min_dummy = <$S as LogicalPrimitiveType>::Native::default();
321
322 clamp_impl::<$S, false, true>(input, min_dummy, max)
323 },{
324 unreachable!()
325 })
326 }
327}
328
329impl Display for ClampMaxFunction {
330 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
331 write!(f, "{}", CLAMP_MAX_NAME.to_ascii_uppercase())
332 }
333}
334
#[cfg(test)]
mod test {

    use std::sync::Arc;

    use datatypes::prelude::ScalarVector;
    use datatypes::vectors::{
        ConstantVector, Float64Vector, Int64Vector, StringVector, UInt64Vector,
    };

    use super::*;
    use crate::function::FunctionContext;

    /// Evaluates `func` over `args` with a default [`FunctionContext`].
    fn call(func: &impl Function, args: &[VectorRef]) -> Result<VectorRef> {
        func.eval(&FunctionContext::default(), args)
    }

    /// Nullable `i64` column.
    fn i64_column(values: Vec<Option<i64>>) -> VectorRef {
        Arc::new(Int64Vector::from(values))
    }

    /// Single-row `i64` vector, used as a scalar bound.
    fn i64_scalar(value: i64) -> VectorRef {
        Arc::new(Int64Vector::from_vec(vec![value]))
    }

    /// Nullable `u64` column.
    fn u64_column(values: Vec<Option<u64>>) -> VectorRef {
        Arc::new(UInt64Vector::from(values))
    }

    /// Single-row `u64` vector, used as a scalar bound.
    fn u64_scalar(value: u64) -> VectorRef {
        Arc::new(UInt64Vector::from_vec(vec![value]))
    }

    /// Nullable `f64` column.
    fn f64_column(values: Vec<Option<f64>>) -> VectorRef {
        Arc::new(Float64Vector::from(values))
    }

    /// Single-row `f64` vector, used as a scalar bound.
    fn f64_scalar(value: f64) -> VectorRef {
        Arc::new(Float64Vector::from_vec(vec![value]))
    }

    #[test]
    fn clamp_i64() {
        // (input, min, max, expected)
        let cases = [
            (
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
                -1,
                10,
                vec![Some(-1), Some(-1), Some(-1), Some(0), Some(1), Some(2)],
            ),
            (
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
                0,
                0,
                vec![Some(0), Some(0), Some(0), Some(0), Some(0), Some(0)],
            ),
            (
                vec![Some(-3), None, Some(-1), None, None, Some(2)],
                -2,
                1,
                vec![Some(-2), None, Some(-1), None, None, Some(1)],
            ),
            (
                vec![None, None, None, None, None],
                0,
                1,
                vec![None, None, None, None, None],
            ),
        ];

        for (values, lower, upper, want) in cases {
            let args = [i64_column(values), i64_scalar(lower), i64_scalar(upper)];
            let got = call(&ClampFunction, &args).unwrap();
            assert_eq!(i64_column(want), got);
        }
    }

    #[test]
    fn clamp_u64() {
        // (input, min, max, expected)
        let cases = [
            (
                vec![Some(0), Some(1), Some(2), Some(3), Some(4), Some(5)],
                1,
                3,
                vec![Some(1), Some(1), Some(2), Some(3), Some(3), Some(3)],
            ),
            (
                vec![Some(0), Some(1), Some(2), Some(3), Some(4), Some(5)],
                0,
                0,
                vec![Some(0), Some(0), Some(0), Some(0), Some(0), Some(0)],
            ),
            (
                vec![Some(0), None, Some(2), None, None, Some(5)],
                1,
                3,
                vec![Some(1), None, Some(2), None, None, Some(3)],
            ),
            (
                vec![None, None, None, None, None],
                0,
                1,
                vec![None, None, None, None, None],
            ),
        ];

        for (values, lower, upper, want) in cases {
            let args = [u64_column(values), u64_scalar(lower), u64_scalar(upper)];
            let got = call(&ClampFunction, &args).unwrap();
            assert_eq!(u64_column(want), got);
        }
    }

    #[test]
    fn clamp_f64() {
        // (input, min, max, expected)
        let cases = [
            (
                vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)],
                -1.0,
                10.0,
                vec![Some(-1.0), Some(-1.0), Some(-1.0), Some(0.0), Some(1.0)],
            ),
            (
                vec![Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)],
                0.0,
                0.0,
                vec![Some(0.0), Some(0.0), Some(0.0), Some(0.0)],
            ),
            (
                vec![Some(-3.0), None, Some(-1.0), None, None, Some(2.0)],
                -2.0,
                1.0,
                vec![Some(-2.0), None, Some(-1.0), None, None, Some(1.0)],
            ),
            (
                vec![None, None, None, None, None],
                0.0,
                1.0,
                vec![None, None, None, None, None],
            ),
        ];

        for (values, lower, upper, want) in cases {
            let args = [f64_column(values), f64_scalar(lower), f64_scalar(upper)];
            let got = call(&ClampFunction, &args).unwrap();
            assert_eq!(f64_column(want), got);
        }
    }

    #[test]
    fn clamp_const_i32() {
        // The input column is a constant vector wrapping a single value.
        let constant_input: VectorRef =
            Arc::new(ConstantVector::new(i64_column(vec![Some(5)]), 1));

        let args = [constant_input, i64_scalar(2), i64_scalar(4)];
        let got = call(&ClampFunction, &args).unwrap();
        assert_eq!(i64_column(vec![Some(4)]), got);
    }

    #[test]
    fn clamp_invalid_min_max() {
        // min > max must be rejected.
        let args = [
            f64_column(vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)]),
            f64_scalar(10.0),
            f64_scalar(-1.0),
        ];
        assert!(call(&ClampFunction, &args).is_err());
    }

    #[test]
    fn clamp_type_not_match() {
        // Three different numeric types must be rejected.
        let args = [
            f64_column(vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)]),
            i64_scalar(-1),
            u64_scalar(10),
        ];
        assert!(call(&ClampFunction, &args).is_err());
    }

    #[test]
    fn clamp_min_is_not_scalar() {
        // A two-row, non-constant `min` is not a scalar.
        let two_row_min: VectorRef = Arc::new(Float64Vector::from_vec(vec![-10.0, -10.0]));

        let args = [
            f64_column(vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)]),
            two_row_min,
            f64_scalar(1.0),
        ];
        assert!(call(&ClampFunction, &args).is_err());
    }

    #[test]
    fn clamp_no_max() {
        // Only two arguments are supplied to the three-argument function.
        let args = [
            f64_column(vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)]),
            f64_scalar(-10.0),
        ];
        assert!(call(&ClampFunction, &args).is_err());
    }

    #[test]
    fn clamp_on_string() {
        // Non-numeric input must be rejected.
        let words = vec![Some("foo"), Some("foo"), Some("foo"), Some("foo")];

        let args: [VectorRef; 3] = [
            Arc::new(StringVector::from(words)),
            Arc::new(StringVector::from_vec(vec!["bar"])),
            Arc::new(StringVector::from_vec(vec!["baz"])),
        ];
        assert!(call(&ClampFunction, &args).is_err());
    }

    #[test]
    fn clamp_min_i64() {
        // (input, min, expected)
        let cases = [
            (
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
                -1,
                vec![Some(-1), Some(-1), Some(-1), Some(0), Some(1), Some(2)],
            ),
            (
                vec![Some(-3), None, Some(-1), None, None, Some(2)],
                -2,
                vec![Some(-2), None, Some(-1), None, None, Some(2)],
            ),
        ];

        for (values, lower, want) in cases {
            let args = [i64_column(values), i64_scalar(lower)];
            let got = call(&ClampMinFunction, &args).unwrap();
            assert_eq!(i64_column(want), got);
        }
    }

    #[test]
    fn clamp_max_i64() {
        // (input, max, expected)
        let cases = [
            (
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
                1,
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(1)],
            ),
            (
                vec![Some(-3), None, Some(-1), None, None, Some(2)],
                0,
                vec![Some(-3), None, Some(-1), None, None, Some(0)],
            ),
        ];

        for (values, upper, want) in cases {
            let args = [i64_column(values), i64_scalar(upper)];
            let got = call(&ClampMaxFunction, &args).unwrap();
            assert_eq!(i64_column(want), got);
        }
    }

    #[test]
    fn clamp_min_f64() {
        // (input, min, expected)
        let cases = [(
            vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)],
            -1.0,
            vec![Some(-1.0), Some(-1.0), Some(-1.0), Some(0.0), Some(1.0)],
        )];

        for (values, lower, want) in cases {
            let args = [f64_column(values), f64_scalar(lower)];
            let got = call(&ClampMinFunction, &args).unwrap();
            assert_eq!(f64_column(want), got);
        }
    }

    #[test]
    fn clamp_max_f64() {
        // (input, max, expected)
        let cases = [(
            vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)],
            0.0,
            vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(0.0)],
        )];

        for (values, upper, want) in cases {
            let args = [f64_column(values), f64_scalar(upper)];
            let got = call(&ClampMaxFunction, &args).unwrap();
            assert_eq!(f64_column(want), got);
        }
    }

    #[test]
    fn clamp_min_type_not_match() {
        // Float input with an integer bound must be rejected.
        let args = [
            f64_column(vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)]),
            i64_scalar(-1),
        ];
        assert!(call(&ClampMinFunction, &args).is_err());
    }

    #[test]
    fn clamp_max_type_not_match() {
        // Float input with an integer bound must be rejected.
        let args = [
            f64_column(vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)]),
            i64_scalar(1),
        ];
        assert!(call(&ClampMaxFunction, &args).is_err());
    }
}