1use std::fmt::{self, Display};
16use std::sync::Arc;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use common_query::prelude::Signature;
20use datafusion::arrow::array::{ArrayIter, PrimitiveArray};
21use datafusion::logical_expr::Volatility;
22use datatypes::data_type::{ConcreteDataType, DataType};
23use datatypes::prelude::VectorRef;
24use datatypes::types::LogicalPrimitiveType;
25use datatypes::value::TryAsPrimitive;
26use datatypes::vectors::PrimitiveVector;
27use datatypes::with_match_primitive_type_id;
28use snafu::{ensure, OptionExt};
29
30use crate::function::{Function, FunctionContext};
31
/// Scalar function `clamp(input, min, max)`.
///
/// A stateless marker type; all behaviour lives in its `Function` and
/// `Display` implementations.
#[derive(Debug, Default, Clone)]
pub struct ClampFunction;
34
/// Lower-case name returned by `ClampFunction::name`; `Display` prints it upper-cased.
const CLAMP_NAME: &'static str = "clamp";
36
37impl Function for ClampFunction {
38 fn name(&self) -> &str {
39 CLAMP_NAME
40 }
41
42 fn return_type(&self, input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
43 Ok(input_types[0].clone())
45 }
46
47 fn signature(&self) -> Signature {
48 Signature::uniform(3, ConcreteDataType::numerics(), Volatility::Immutable)
50 }
51
52 fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
53 ensure!(
54 columns.len() == 3,
55 InvalidFuncArgsSnafu {
56 err_msg: format!(
57 "The length of the args is not correct, expect exactly 3, have: {}",
58 columns.len()
59 ),
60 }
61 );
62 ensure!(
63 columns[0].data_type().is_numeric(),
64 InvalidFuncArgsSnafu {
65 err_msg: format!(
66 "The first arg's type is not numeric, have: {}",
67 columns[0].data_type()
68 ),
69 }
70 );
71 ensure!(
72 columns[0].data_type() == columns[1].data_type()
73 && columns[1].data_type() == columns[2].data_type(),
74 InvalidFuncArgsSnafu {
75 err_msg: format!(
76 "Arguments don't have identical types: {}, {}, {}",
77 columns[0].data_type(),
78 columns[1].data_type(),
79 columns[2].data_type()
80 ),
81 }
82 );
83 ensure!(
84 columns[1].len() == 1 && columns[2].len() == 1,
85 InvalidFuncArgsSnafu {
86 err_msg: format!(
87 "The second and third args should be scalar, have: {:?}, {:?}",
88 columns[1], columns[2]
89 ),
90 }
91 );
92
93 with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| {
94 let input_array = columns[0].to_arrow_array();
95 let input = input_array
96 .as_any()
97 .downcast_ref::<PrimitiveArray<<$S as LogicalPrimitiveType>::ArrowPrimitive>>()
98 .unwrap();
99
100 let min = TryAsPrimitive::<$S>::try_as_primitive(&columns[1].get(0))
101 .with_context(|| {
102 InvalidFuncArgsSnafu {
103 err_msg: "The second arg should not be none",
104 }
105 })?;
106 let max = TryAsPrimitive::<$S>::try_as_primitive(&columns[2].get(0))
107 .with_context(|| {
108 InvalidFuncArgsSnafu {
109 err_msg: "The third arg should not be none",
110 }
111 })?;
112
113 ensure!(
115 min <= max,
116 InvalidFuncArgsSnafu {
117 err_msg: format!(
118 "The second arg should be less than or equal to the third arg, have: {:?}, {:?}",
119 columns[1], columns[2]
120 ),
121 }
122 );
123
124 clamp_impl::<$S, true, true>(input, min, max)
125 },{
126 unreachable!()
127 })
128 }
129}
130
131impl Display for ClampFunction {
132 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
133 write!(f, "{}", CLAMP_NAME.to_ascii_uppercase())
134 }
135}
136
137fn clamp_impl<T: LogicalPrimitiveType, const CLAMP_MIN: bool, const CLAMP_MAX: bool>(
138 input: &PrimitiveArray<T::ArrowPrimitive>,
139 min: T::Native,
140 max: T::Native,
141) -> Result<VectorRef> {
142 let iter = ArrayIter::new(input);
143 let result = iter.map(|x| {
144 x.map(|x| {
145 if CLAMP_MIN && x < min {
146 min
147 } else if CLAMP_MAX && x > max {
148 max
149 } else {
150 x
151 }
152 })
153 });
154 let result = PrimitiveArray::<T::ArrowPrimitive>::from_iter(result);
155 Ok(Arc::new(PrimitiveVector::<T>::from(result)))
156}
157
#[cfg(test)]
mod test {

    use std::sync::Arc;

    use datatypes::prelude::ScalarVector;
    use datatypes::vectors::{
        ConstantVector, Float64Vector, Int32Vector, Int64Vector, StringVector, UInt64Vector,
    };

    use super::*;
    use crate::function::FunctionContext;

    /// Signed 64-bit inputs: values below/above the bounds are replaced, nulls are kept.
    #[test]
    fn clamp_i64() {
        let inputs = [
            (
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
                -1,
                10,
                vec![Some(-1), Some(-1), Some(-1), Some(0), Some(1), Some(2)],
            ),
            (
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
                0,
                0,
                vec![Some(0), Some(0), Some(0), Some(0), Some(0), Some(0)],
            ),
            (
                vec![Some(-3), None, Some(-1), None, None, Some(2)],
                -2,
                1,
                vec![Some(-2), None, Some(-1), None, None, Some(1)],
            ),
            (
                vec![None, None, None, None, None],
                0,
                1,
                vec![None, None, None, None, None],
            ),
        ];

        let func = ClampFunction;
        for (in_data, min, max, expected) in inputs {
            let args = [
                Arc::new(Int64Vector::from(in_data)) as _,
                Arc::new(Int64Vector::from_vec(vec![min])) as _,
                Arc::new(Int64Vector::from_vec(vec![max])) as _,
            ];
            let result = func
                .eval(&FunctionContext::default(), args.as_slice())
                .unwrap();
            let expected: VectorRef = Arc::new(Int64Vector::from(expected));
            assert_eq!(expected, result);
        }
    }

    /// Unsigned 64-bit inputs, including the degenerate `min == max` range.
    #[test]
    fn clamp_u64() {
        let inputs = [
            (
                vec![Some(0), Some(1), Some(2), Some(3), Some(4), Some(5)],
                1,
                3,
                vec![Some(1), Some(1), Some(2), Some(3), Some(3), Some(3)],
            ),
            (
                vec![Some(0), Some(1), Some(2), Some(3), Some(4), Some(5)],
                0,
                0,
                vec![Some(0), Some(0), Some(0), Some(0), Some(0), Some(0)],
            ),
            (
                vec![Some(0), None, Some(2), None, None, Some(5)],
                1,
                3,
                vec![Some(1), None, Some(2), None, None, Some(3)],
            ),
            (
                vec![None, None, None, None, None],
                0,
                1,
                vec![None, None, None, None, None],
            ),
        ];

        let func = ClampFunction;
        for (in_data, min, max, expected) in inputs {
            let args = [
                Arc::new(UInt64Vector::from(in_data)) as _,
                Arc::new(UInt64Vector::from_vec(vec![min])) as _,
                Arc::new(UInt64Vector::from_vec(vec![max])) as _,
            ];
            let result = func
                .eval(&FunctionContext::default(), args.as_slice())
                .unwrap();
            let expected: VectorRef = Arc::new(UInt64Vector::from(expected));
            assert_eq!(expected, result);
        }
    }

    /// 64-bit float inputs with the same case layout as the integer tests.
    #[test]
    fn clamp_f64() {
        let inputs = [
            (
                vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)],
                -1.0,
                10.0,
                vec![Some(-1.0), Some(-1.0), Some(-1.0), Some(0.0), Some(1.0)],
            ),
            (
                vec![Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)],
                0.0,
                0.0,
                vec![Some(0.0), Some(0.0), Some(0.0), Some(0.0)],
            ),
            (
                vec![Some(-3.0), None, Some(-1.0), None, None, Some(2.0)],
                -2.0,
                1.0,
                vec![Some(-2.0), None, Some(-1.0), None, None, Some(1.0)],
            ),
            (
                vec![None, None, None, None, None],
                0.0,
                1.0,
                vec![None, None, None, None, None],
            ),
        ];

        let func = ClampFunction;
        for (in_data, min, max, expected) in inputs {
            let args = [
                Arc::new(Float64Vector::from(in_data)) as _,
                Arc::new(Float64Vector::from_vec(vec![min])) as _,
                Arc::new(Float64Vector::from_vec(vec![max])) as _,
            ];
            let result = func
                .eval(&FunctionContext::default(), args.as_slice())
                .unwrap();
            let expected: VectorRef = Arc::new(Float64Vector::from(expected));
            assert_eq!(expected, result);
        }
    }

    /// A constant (single repeated value) 32-bit input is clamped like a regular vector.
    ///
    /// Uses `Int32Vector` so the test really exercises the i32 path its name promises.
    #[test]
    fn clamp_const_i32() {
        let input = vec![Some(5)];
        let min = 2;
        let max = 4;

        let func = ClampFunction;
        let args = [
            Arc::new(ConstantVector::new(Arc::new(Int32Vector::from(input)), 1)) as _,
            Arc::new(Int32Vector::from_vec(vec![min])) as _,
            Arc::new(Int32Vector::from_vec(vec![max])) as _,
        ];
        let result = func
            .eval(&FunctionContext::default(), args.as_slice())
            .unwrap();
        let expected: VectorRef = Arc::new(Int32Vector::from(vec![Some(4)]));
        assert_eq!(expected, result);
    }

    /// `min > max` is rejected.
    #[test]
    fn clamp_invalid_min_max() {
        let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
        let min = 10.0;
        let max = -1.0;

        let func = ClampFunction;
        let args = [
            Arc::new(Float64Vector::from(input)) as _,
            Arc::new(Float64Vector::from_vec(vec![min])) as _,
            Arc::new(Float64Vector::from_vec(vec![max])) as _,
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        assert!(result.is_err());
    }

    /// Arguments of differing numeric types are rejected.
    #[test]
    fn clamp_type_not_match() {
        let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
        let min = -1;
        let max = 10;

        let func = ClampFunction;
        let args = [
            Arc::new(Float64Vector::from(input)) as _,
            Arc::new(Int64Vector::from_vec(vec![min])) as _,
            Arc::new(UInt64Vector::from_vec(vec![max])) as _,
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        assert!(result.is_err());
    }

    /// A lower bound with more than one element is rejected.
    #[test]
    fn clamp_min_is_not_scalar() {
        let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
        let min = -10.0;
        let max = 1.0;

        let func = ClampFunction;
        let args = [
            Arc::new(Float64Vector::from(input)) as _,
            Arc::new(Float64Vector::from_vec(vec![min, min])) as _,
            Arc::new(Float64Vector::from_vec(vec![max])) as _,
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        assert!(result.is_err());
    }

    /// Calling with only two arguments is rejected.
    #[test]
    fn clamp_no_max() {
        let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
        let min = -10.0;

        let func = ClampFunction;
        let args = [
            Arc::new(Float64Vector::from(input)) as _,
            Arc::new(Float64Vector::from_vec(vec![min])) as _,
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        assert!(result.is_err());
    }

    /// Non-numeric (string) arguments are rejected.
    #[test]
    fn clamp_on_string() {
        let input = vec![Some("foo"), Some("foo"), Some("foo"), Some("foo")];

        let func = ClampFunction;
        let args = [
            Arc::new(StringVector::from(input)) as _,
            Arc::new(StringVector::from_vec(vec!["bar"])) as _,
            Arc::new(StringVector::from_vec(vec!["baz"])) as _,
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        assert!(result.is_err());
    }
}