common_function/scalars/expression/
if_func.rs1use std::fmt;
16use std::fmt::Display;
17
18use arrow::array::ArrowNativeTypeOp;
19use arrow::datatypes::ArrowPrimitiveType;
20use datafusion::arrow::array::{Array, ArrayRef, AsArray, BooleanArray, PrimitiveArray};
21use datafusion::arrow::compute::kernels::zip::zip;
22use datafusion::arrow::datatypes::DataType;
23use datafusion_common::DataFusionError;
24use datafusion_expr::type_coercion::binary::comparison_coercion;
25use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
26
27use crate::function::Function;
28
/// Name under which the `IF` scalar function is registered (see `Function::name`).
const NAME: &str = "if";
30
31#[derive(Clone, Debug)]
42pub struct IfFunction {
43 signature: Signature,
44}
45
46impl Default for IfFunction {
47 fn default() -> Self {
48 Self {
49 signature: Signature::any(3, Volatility::Immutable),
50 }
51 }
52}
53
54impl Display for IfFunction {
55 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
56 write!(f, "{}", NAME.to_ascii_uppercase())
57 }
58}
59
60impl Function for IfFunction {
61 fn name(&self) -> &str {
62 NAME
63 }
64
65 fn return_type(&self, input_types: &[DataType]) -> datafusion_common::Result<DataType> {
66 if input_types.len() < 3 {
68 return Err(DataFusionError::Plan(format!(
69 "{} requires 3 arguments, got {}",
70 NAME,
71 input_types.len()
72 )));
73 }
74 let true_type = &input_types[1];
75 let false_type = &input_types[2];
76
77 comparison_coercion(true_type, false_type).ok_or_else(|| {
79 DataFusionError::Plan(format!(
80 "Cannot find common type for IF function between {:?} and {:?}",
81 true_type, false_type
82 ))
83 })
84 }
85
86 fn signature(&self) -> &Signature {
87 &self.signature
88 }
89
90 fn invoke_with_args(
91 &self,
92 args: ScalarFunctionArgs,
93 ) -> datafusion_common::Result<ColumnarValue> {
94 if args.args.len() != 3 {
95 return Err(DataFusionError::Plan(format!(
96 "{} requires exactly 3 arguments, got {}",
97 NAME,
98 args.args.len()
99 )));
100 }
101
102 let condition = &args.args[0];
103 let true_value = &args.args[1];
104 let false_value = &args.args[2];
105
106 let bool_array = to_boolean_array(condition, args.number_rows)?;
108
109 let true_array = true_value.to_array(args.number_rows)?;
111 let false_array = false_value.to_array(args.number_rows)?;
112
113 let result = zip(&bool_array, &true_array, &false_array)?;
116 Ok(ColumnarValue::Array(result))
117 }
118}
119
120fn to_boolean_array(
126 value: &ColumnarValue,
127 num_rows: usize,
128) -> datafusion_common::Result<BooleanArray> {
129 let array = value.to_array(num_rows)?;
130 array_to_bool(array)
131}
132
133fn int_array_to_bool<T>(array: &PrimitiveArray<T>) -> BooleanArray
136where
137 T: ArrowPrimitiveType,
138 T::Native: ArrowNativeTypeOp,
139{
140 BooleanArray::from_iter(
141 array
142 .iter()
143 .map(|opt| Some(opt.is_some_and(|v| !v.is_zero()))),
144 )
145}
146
147fn float_array_to_bool<T>(array: &PrimitiveArray<T>) -> BooleanArray
150where
151 T: ArrowPrimitiveType,
152 T::Native: ArrowNativeTypeOp + num_traits::Float,
153{
154 use num_traits::Float;
155 BooleanArray::from_iter(
156 array
157 .iter()
158 .map(|opt| Some(opt.is_some_and(|v| v.is_nan() || !v.is_zero()))),
159 )
160}
161
162fn array_to_bool(array: ArrayRef) -> datafusion_common::Result<BooleanArray> {
164 use arrow::datatypes::*;
165
166 match array.data_type() {
167 DataType::Boolean => {
168 let bool_array = array.as_boolean();
169 Ok(BooleanArray::from_iter(
170 bool_array.iter().map(|opt| Some(opt.unwrap_or(false))),
171 ))
172 }
173 DataType::Int8 => Ok(int_array_to_bool(array.as_primitive::<Int8Type>())),
174 DataType::Int16 => Ok(int_array_to_bool(array.as_primitive::<Int16Type>())),
175 DataType::Int32 => Ok(int_array_to_bool(array.as_primitive::<Int32Type>())),
176 DataType::Int64 => Ok(int_array_to_bool(array.as_primitive::<Int64Type>())),
177 DataType::UInt8 => Ok(int_array_to_bool(array.as_primitive::<UInt8Type>())),
178 DataType::UInt16 => Ok(int_array_to_bool(array.as_primitive::<UInt16Type>())),
179 DataType::UInt32 => Ok(int_array_to_bool(array.as_primitive::<UInt32Type>())),
180 DataType::UInt64 => Ok(int_array_to_bool(array.as_primitive::<UInt64Type>())),
181 DataType::Float16 => {
183 let typed_array = array.as_primitive::<Float16Type>();
184 Ok(BooleanArray::from_iter(typed_array.iter().map(|opt| {
185 Some(opt.is_some_and(|v| {
186 let f = v.to_f32();
187 f.is_nan() || !f.is_zero()
188 }))
189 })))
190 }
191 DataType::Float32 => Ok(float_array_to_bool(array.as_primitive::<Float32Type>())),
192 DataType::Float64 => Ok(float_array_to_bool(array.as_primitive::<Float64Type>())),
193 DataType::Null => Ok(BooleanArray::from(vec![false; array.len()])),
197 _ => {
199 let len = array.len();
200 Ok(BooleanArray::from_iter(
201 (0..len).map(|i| Some(!array.is_null(i))),
202 ))
203 }
204 }
205}
206
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion_common::ScalarValue;
    use datafusion_common::arrow::array::{AsArray, Int32Array, StringArray};

    use super::*;

    /// Wraps a string as a non-null scalar `Utf8` argument.
    fn utf8(value: &str) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::Utf8(Some(value.to_string())))
    }

    /// Invokes `if_func` with `args` over `number_rows` rows (using a `Utf8`
    /// return field) and returns the array result, panicking on a scalar one.
    fn invoke_if(if_func: &IfFunction, args: Vec<ColumnarValue>, number_rows: usize) -> ArrayRef {
        let result = if_func
            .invoke_with_args(ScalarFunctionArgs {
                args,
                arg_fields: vec![],
                number_rows,
                return_field: Arc::new(Field::new("", DataType::Utf8, true)),
                config_options: Arc::new(Default::default()),
            })
            .unwrap();

        match result {
            ColumnarValue::Array(arr) => arr,
            _ => panic!("Expected Array result"),
        }
    }

    /// Evaluates `IF(condition, 'yes', 'no')` on one row and returns the first value.
    fn if_yes_no(if_func: &IfFunction, condition: ScalarValue) -> String {
        let arr = invoke_if(
            if_func,
            vec![ColumnarValue::Scalar(condition), utf8("yes"), utf8("no")],
            1,
        );
        arr.as_string::<i32>().value(0).to_string()
    }

    #[test]
    fn test_if_function_basic() {
        let if_func = IfFunction::default();
        assert_eq!("if", if_func.name());
        assert_eq!(if_yes_no(&if_func, ScalarValue::Boolean(Some(true))), "yes");
    }

    #[test]
    fn test_if_function_false() {
        let if_func = IfFunction::default();
        assert_eq!(if_yes_no(&if_func, ScalarValue::Boolean(Some(false))), "no");
    }

    #[test]
    fn test_if_function_null_is_false() {
        let if_func = IfFunction::default();
        assert_eq!(if_yes_no(&if_func, ScalarValue::Boolean(None)), "no");
        assert_eq!(if_yes_no(&if_func, ScalarValue::Null), "no");
    }

    #[test]
    fn test_if_function_numeric_truthy() {
        let if_func = IfFunction::default();
        assert_eq!(if_yes_no(&if_func, ScalarValue::Int32(Some(1))), "yes");
        assert_eq!(if_yes_no(&if_func, ScalarValue::Int32(Some(0))), "no");
    }

    #[test]
    fn test_if_function_with_arrays() {
        let if_func = IfFunction::default();

        let condition = Int32Array::from(vec![Some(1), Some(0), None, Some(5)]);
        let true_val = StringArray::from(vec!["yes"; 4]);
        let false_val = StringArray::from(vec!["no"; 4]);

        let arr = invoke_if(
            &if_func,
            vec![
                ColumnarValue::Array(Arc::new(condition)),
                ColumnarValue::Array(Arc::new(true_val)),
                ColumnarValue::Array(Arc::new(false_val)),
            ],
            4,
        );

        let str_arr = arr.as_string::<i32>();
        let expected = ["yes", "no", "no", "yes"];
        for (i, want) in expected.iter().enumerate() {
            assert_eq!(str_arr.value(i), *want);
        }
    }
}