common_function/scalars/string/
field.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! MySQL-compatible FIELD function implementation.
16//!
17//! FIELD(str, str1, str2, str3, ...) - Returns the 1-based index of str in the list.
18//! Returns 0 if str is not found or is NULL.
19
20use std::fmt;
21use std::sync::Arc;
22
23use datafusion_common::DataFusionError;
24use datafusion_common::arrow::array::{Array, ArrayRef, AsArray, Int64Builder};
25use datafusion_common::arrow::compute::cast;
26use datafusion_common::arrow::datatypes::DataType;
27use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
28
29use crate::function::Function;
30use crate::function_registry::FunctionRegistry;
31
32const NAME: &str = "field";
33
34/// MySQL-compatible FIELD function.
35///
36/// Syntax: FIELD(str, str1, str2, str3, ...)
37/// Returns the 1-based index of str in the argument list (str1, str2, str3, ...).
38/// Returns 0 if str is not found or is NULL.
39#[derive(Debug)]
40pub struct FieldFunction {
41    signature: Signature,
42}
43
44impl FieldFunction {
45    pub fn register(registry: &FunctionRegistry) {
46        registry.register_scalar(FieldFunction::default());
47    }
48}
49
50impl Default for FieldFunction {
51    fn default() -> Self {
52        Self {
53            // FIELD takes a variable number of arguments: (String, String, String, ...)
54            signature: Signature::variadic_any(Volatility::Immutable),
55        }
56    }
57}
58
59impl fmt::Display for FieldFunction {
60    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
61        write!(f, "{}", NAME.to_ascii_uppercase())
62    }
63}
64
65impl Function for FieldFunction {
66    fn name(&self) -> &str {
67        NAME
68    }
69
70    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
71        Ok(DataType::Int64)
72    }
73
74    fn signature(&self) -> &Signature {
75        &self.signature
76    }
77
78    fn invoke_with_args(
79        &self,
80        args: ScalarFunctionArgs,
81    ) -> datafusion_common::Result<ColumnarValue> {
82        if args.args.len() < 2 {
83            return Err(DataFusionError::Execution(
84                "FIELD requires at least 2 arguments: FIELD(str, str1, ...)".to_string(),
85            ));
86        }
87
88        let arrays = ColumnarValue::values_to_arrays(&args.args)?;
89        let len = arrays[0].len();
90
91        // Cast all arguments to LargeUtf8
92        let string_arrays: Vec<ArrayRef> = arrays
93            .iter()
94            .enumerate()
95            .map(|(i, arr)| {
96                cast(arr.as_ref(), &DataType::LargeUtf8).map_err(|e| {
97                    DataFusionError::Execution(format!("FIELD: argument {} cast failed: {}", i, e))
98                })
99            })
100            .collect::<datafusion_common::Result<Vec<_>>>()?;
101
102        let search_str = string_arrays[0].as_string::<i64>();
103        let mut builder = Int64Builder::with_capacity(len);
104
105        for i in 0..len {
106            // If search string is NULL, return 0
107            if search_str.is_null(i) {
108                builder.append_value(0);
109                continue;
110            }
111
112            let needle = search_str.value(i);
113            let mut found_idx = 0i64;
114
115            // Search through the list (starting from index 1 in string_arrays)
116            for (j, str_arr) in string_arrays[1..].iter().enumerate() {
117                let str_array = str_arr.as_string::<i64>();
118                if !str_array.is_null(i) && str_array.value(i) == needle {
119                    found_idx = (j + 1) as i64; // 1-based index
120                    break;
121                }
122            }
123
124            builder.append_value(found_idx);
125        }
126
127        Ok(ColumnarValue::Array(Arc::new(builder.finish())))
128    }
129}
130
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datafusion_common::arrow::array::StringArray;
    use datafusion_common::arrow::datatypes::{Field, Int64Type};
    use datafusion_expr::ScalarFunctionArgs;

    use super::*;

    /// Wraps `arrays` into `ScalarFunctionArgs`, deriving one nullable field per
    /// argument and taking the row count from the first array.
    fn create_args(arrays: Vec<ArrayRef>) -> ScalarFunctionArgs {
        let arg_fields: Vec<_> = arrays
            .iter()
            .enumerate()
            .map(|(i, arr)| {
                Arc::new(Field::new(
                    format!("arg_{}", i),
                    arr.data_type().clone(),
                    true,
                ))
            })
            .collect();

        ScalarFunctionArgs {
            args: arrays.iter().cloned().map(ColumnarValue::Array).collect(),
            arg_fields,
            return_field: Arc::new(Field::new("result", DataType::Int64, true)),
            number_rows: arrays[0].len(),
            config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
        }
    }

    /// Builds a non-null UTF-8 column from string literals.
    fn utf8(values: &[&str]) -> ArrayRef {
        Arc::new(StringArray::from(values.to_vec()))
    }

    /// Builds a UTF-8 column in which `None` stands for NULL.
    fn nullable_utf8(values: &[Option<&str>]) -> ArrayRef {
        Arc::new(StringArray::from(values.to_vec()))
    }

    /// Runs FIELD over `arrays` and returns the per-row positions.
    fn eval_field(arrays: Vec<ArrayRef>) -> Vec<i64> {
        let result = FieldFunction::default()
            .invoke_with_args(create_args(arrays))
            .unwrap();

        match result {
            ColumnarValue::Array(array) => array.as_primitive::<Int64Type>().values().to_vec(),
            _ => panic!("Expected array result"),
        }
    }

    #[test]
    fn test_field_basic() {
        let positions = eval_field(vec![
            utf8(&["b", "d", "a"]),
            utf8(&["a", "a", "a"]),
            utf8(&["b", "b", "b"]),
            utf8(&["c", "c", "c"]),
        ]);

        // "b" is at index 2, "d" is not found, "a" is at index 1.
        assert_eq!(positions, vec![2, 0, 1]);
    }

    #[test]
    fn test_field_with_null_search() {
        let positions = eval_field(vec![
            nullable_utf8(&[Some("a"), None]),
            utf8(&["a", "a"]),
            utf8(&["b", "b"]),
        ]);

        // "a" is found at index 1; a NULL search value returns 0.
        assert_eq!(positions, vec![1, 0]);
    }

    #[test]
    fn test_field_case_sensitive() {
        let positions = eval_field(vec![
            utf8(&["A", "a"]),
            utf8(&["a", "a"]),
            utf8(&["A", "A"]),
        ]);

        // "A" matches at index 2, "a" matches at index 1.
        assert_eq!(positions, vec![2, 1]);
    }

    #[test]
    fn test_field_skips_null_list_entries() {
        let positions = eval_field(vec![
            utf8(&["a", "a"]),
            nullable_utf8(&[None, Some("a")]),
            nullable_utf8(&[Some("a"), None]),
        ]);

        // Row 0: the NULL in str1 is skipped and str2 matches.
        // Row 1: str1 matches before the NULL in str2 is reached.
        assert_eq!(positions, vec![2, 1]);
    }

    #[test]
    fn test_field_requires_two_arguments() {
        let args = create_args(vec![utf8(&["a"])]);

        assert!(FieldFunction::default().invoke_with_args(args).is_err());
    }
}