common_function/scalars/expression/
if_func.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt;
16use std::fmt::Display;
17
18use arrow::array::ArrowNativeTypeOp;
19use arrow::datatypes::ArrowPrimitiveType;
20use datafusion::arrow::array::{Array, ArrayRef, AsArray, BooleanArray, PrimitiveArray};
21use datafusion::arrow::compute::kernels::zip::zip;
22use datafusion::arrow::datatypes::DataType;
23use datafusion_common::DataFusionError;
24use datafusion_expr::type_coercion::binary::comparison_coercion;
25use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
26
27use crate::function::Function;
28
29const NAME: &str = "if";
30
31/// MySQL-compatible IF function: IF(condition, true_value, false_value)
32///
33/// Returns true_value if condition is TRUE (not NULL and not 0),
34/// otherwise returns false_value.
35///
36/// MySQL truthy rules:
37/// - NULL -> false
38/// - 0 (numeric zero) -> false
39/// - Any non-zero numeric -> true
40/// - Boolean true/false -> use directly
41#[derive(Clone, Debug)]
42pub struct IfFunction {
43    signature: Signature,
44}
45
46impl Default for IfFunction {
47    fn default() -> Self {
48        Self {
49            signature: Signature::any(3, Volatility::Immutable),
50        }
51    }
52}
53
54impl Display for IfFunction {
55    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
56        write!(f, "{}", NAME.to_ascii_uppercase())
57    }
58}
59
60impl Function for IfFunction {
61    fn name(&self) -> &str {
62        NAME
63    }
64
65    fn return_type(&self, input_types: &[DataType]) -> datafusion_common::Result<DataType> {
66        // Return the common type of true_value and false_value (args[1] and args[2])
67        if input_types.len() < 3 {
68            return Err(DataFusionError::Plan(format!(
69                "{} requires 3 arguments, got {}",
70                NAME,
71                input_types.len()
72            )));
73        }
74        let true_type = &input_types[1];
75        let false_type = &input_types[2];
76
77        // Use comparison_coercion to find common type
78        comparison_coercion(true_type, false_type).ok_or_else(|| {
79            DataFusionError::Plan(format!(
80                "Cannot find common type for IF function between {:?} and {:?}",
81                true_type, false_type
82            ))
83        })
84    }
85
86    fn signature(&self) -> &Signature {
87        &self.signature
88    }
89
90    fn invoke_with_args(
91        &self,
92        args: ScalarFunctionArgs,
93    ) -> datafusion_common::Result<ColumnarValue> {
94        if args.args.len() != 3 {
95            return Err(DataFusionError::Plan(format!(
96                "{} requires exactly 3 arguments, got {}",
97                NAME,
98                args.args.len()
99            )));
100        }
101
102        let condition = &args.args[0];
103        let true_value = &args.args[1];
104        let false_value = &args.args[2];
105
106        // Convert condition to boolean array using MySQL truthy rules
107        let bool_array = to_boolean_array(condition, args.number_rows)?;
108
109        // Convert true and false values to arrays
110        let true_array = true_value.to_array(args.number_rows)?;
111        let false_array = false_value.to_array(args.number_rows)?;
112
113        // Use zip to select values based on condition
114        // zip expects &dyn Datum, and ArrayRef (Arc<dyn Array>) implements Datum
115        let result = zip(&bool_array, &true_array, &false_array)?;
116        Ok(ColumnarValue::Array(result))
117    }
118}
119
120/// Convert a ColumnarValue to a BooleanArray using MySQL truthy rules:
121/// - NULL -> false
122/// - 0 (any numeric zero) -> false
123/// - Non-zero numeric -> true
124/// - Boolean -> use directly
125fn to_boolean_array(
126    value: &ColumnarValue,
127    num_rows: usize,
128) -> datafusion_common::Result<BooleanArray> {
129    let array = value.to_array(num_rows)?;
130    array_to_bool(array)
131}
132
133/// Convert an integer PrimitiveArray to BooleanArray using MySQL truthy rules:
134/// NULL -> false, 0 -> false, non-zero -> true
135fn int_array_to_bool<T>(array: &PrimitiveArray<T>) -> BooleanArray
136where
137    T: ArrowPrimitiveType,
138    T::Native: ArrowNativeTypeOp,
139{
140    BooleanArray::from_iter(
141        array
142            .iter()
143            .map(|opt| Some(opt.is_some_and(|v| !v.is_zero()))),
144    )
145}
146
147/// Convert a float PrimitiveArray to BooleanArray using MySQL truthy rules:
148/// NULL -> false, 0 (including -0.0) -> false, NaN -> true, other non-zero -> true
149fn float_array_to_bool<T>(array: &PrimitiveArray<T>) -> BooleanArray
150where
151    T: ArrowPrimitiveType,
152    T::Native: ArrowNativeTypeOp + num_traits::Float,
153{
154    use num_traits::Float;
155    BooleanArray::from_iter(
156        array
157            .iter()
158            .map(|opt| Some(opt.is_some_and(|v| v.is_nan() || !v.is_zero()))),
159    )
160}
161
162/// Convert an Array to BooleanArray using MySQL truthy rules
163fn array_to_bool(array: ArrayRef) -> datafusion_common::Result<BooleanArray> {
164    use arrow::datatypes::*;
165
166    match array.data_type() {
167        DataType::Boolean => {
168            let bool_array = array.as_boolean();
169            Ok(BooleanArray::from_iter(
170                bool_array.iter().map(|opt| Some(opt.unwrap_or(false))),
171            ))
172        }
173        DataType::Int8 => Ok(int_array_to_bool(array.as_primitive::<Int8Type>())),
174        DataType::Int16 => Ok(int_array_to_bool(array.as_primitive::<Int16Type>())),
175        DataType::Int32 => Ok(int_array_to_bool(array.as_primitive::<Int32Type>())),
176        DataType::Int64 => Ok(int_array_to_bool(array.as_primitive::<Int64Type>())),
177        DataType::UInt8 => Ok(int_array_to_bool(array.as_primitive::<UInt8Type>())),
178        DataType::UInt16 => Ok(int_array_to_bool(array.as_primitive::<UInt16Type>())),
179        DataType::UInt32 => Ok(int_array_to_bool(array.as_primitive::<UInt32Type>())),
180        DataType::UInt64 => Ok(int_array_to_bool(array.as_primitive::<UInt64Type>())),
181        // Float16 needs special handling since half::f16 doesn't implement num_traits::Float
182        DataType::Float16 => {
183            let typed_array = array.as_primitive::<Float16Type>();
184            Ok(BooleanArray::from_iter(typed_array.iter().map(|opt| {
185                Some(opt.is_some_and(|v| {
186                    let f = v.to_f32();
187                    f.is_nan() || !f.is_zero()
188                }))
189            })))
190        }
191        DataType::Float32 => Ok(float_array_to_bool(array.as_primitive::<Float32Type>())),
192        DataType::Float64 => Ok(float_array_to_bool(array.as_primitive::<Float64Type>())),
193        // Null type is always false.
194        // Note: NullArray::is_null() returns false (physical null), so we must handle it explicitly.
195        // See: https://github.com/apache/arrow-rs/issues/4840
196        DataType::Null => Ok(BooleanArray::from(vec![false; array.len()])),
197        // For other types, treat non-null as true
198        _ => {
199            let len = array.len();
200            Ok(BooleanArray::from_iter(
201                (0..len).map(|i| Some(!array.is_null(i))),
202            ))
203        }
204    }
205}
206
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion_common::ScalarValue;
    use datafusion_common::arrow::array::{AsArray, Int32Array, StringArray};

    use super::*;

    /// Invokes `if_func` on `args` over `number_rows` rows with a Utf8 return
    /// field and returns the resulting array; panics if a scalar comes back.
    fn invoke_utf8(
        if_func: &IfFunction,
        args: Vec<ColumnarValue>,
        number_rows: usize,
    ) -> ArrayRef {
        let result = if_func
            .invoke_with_args(ScalarFunctionArgs {
                args,
                arg_fields: vec![],
                number_rows,
                return_field: Arc::new(Field::new("", DataType::Utf8, true)),
                config_options: Arc::new(Default::default()),
            })
            .unwrap();

        match result {
            ColumnarValue::Array(arr) => arr,
            ColumnarValue::Scalar(_) => panic!("Expected Array result"),
        }
    }

    /// Wraps a string literal as a Utf8 scalar argument.
    fn utf8(text: &str) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::Utf8(Some(text.to_string())))
    }

    /// Evaluates `IF(condition, 'yes', 'no')` on one row and returns the chosen string.
    fn pick_yes_no(if_func: &IfFunction, condition: ScalarValue) -> String {
        let arr = invoke_utf8(
            if_func,
            vec![ColumnarValue::Scalar(condition), utf8("yes"), utf8("no")],
            1,
        );
        arr.as_string::<i32>().value(0).to_string()
    }

    #[test]
    fn test_if_function_basic() {
        let if_func = IfFunction::default();
        assert_eq!("if", if_func.name());

        // IF(true, 'yes', 'no') -> 'yes'
        assert_eq!(pick_yes_no(&if_func, ScalarValue::Boolean(Some(true))), "yes");
    }

    #[test]
    fn test_if_function_false() {
        let if_func = IfFunction::default();

        // IF(false, 'yes', 'no') -> 'no'
        assert_eq!(pick_yes_no(&if_func, ScalarValue::Boolean(Some(false))), "no");
    }

    #[test]
    fn test_if_function_null_is_false() {
        let if_func = IfFunction::default();

        // Typed NULL condition (Boolean(None)) is treated as false.
        assert_eq!(pick_yes_no(&if_func, ScalarValue::Boolean(None)), "no");

        // Untyped NULL, as produced by a bare SQL NULL literal, is also false.
        assert_eq!(pick_yes_no(&if_func, ScalarValue::Null), "no");
    }

    #[test]
    fn test_if_function_numeric_truthy() {
        let if_func = IfFunction::default();

        // Non-zero is true, zero is false.
        assert_eq!(pick_yes_no(&if_func, ScalarValue::Int32(Some(1))), "yes");
        assert_eq!(pick_yes_no(&if_func, ScalarValue::Int32(Some(0))), "no");
    }

    #[test]
    fn test_if_function_with_arrays() {
        let if_func = IfFunction::default();

        let condition = Int32Array::from(vec![Some(1), Some(0), None, Some(5)]);
        let true_val = StringArray::from(vec!["yes"; 4]);
        let false_val = StringArray::from(vec!["no"; 4]);

        let arr = invoke_utf8(
            &if_func,
            vec![
                ColumnarValue::Array(Arc::new(condition)),
                ColumnarValue::Array(Arc::new(true_val)),
                ColumnarValue::Array(Arc::new(false_val)),
            ],
            4,
        );

        let str_arr = arr.as_string::<i32>();
        // 1 is true, 0 is false, NULL is false, 5 is true.
        let expected = ["yes", "no", "no", "yes"];
        for (row, want) in expected.iter().enumerate() {
            assert_eq!(str_arr.value(row), *want);
        }
    }
}