common_function/scalars/string/
field.rs1use std::fmt;
21use std::sync::Arc;
22
23use datafusion_common::DataFusionError;
24use datafusion_common::arrow::array::{Array, ArrayRef, AsArray, Int64Builder};
25use datafusion_common::arrow::compute::cast;
26use datafusion_common::arrow::datatypes::DataType;
27use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
28
29use crate::function::Function;
30use crate::function_registry::FunctionRegistry;
31
32const NAME: &str = "field";
33
34#[derive(Debug)]
40pub struct FieldFunction {
41 signature: Signature,
42}
43
44impl FieldFunction {
45 pub fn register(registry: &FunctionRegistry) {
46 registry.register_scalar(FieldFunction::default());
47 }
48}
49
50impl Default for FieldFunction {
51 fn default() -> Self {
52 Self {
53 signature: Signature::variadic_any(Volatility::Immutable),
55 }
56 }
57}
58
59impl fmt::Display for FieldFunction {
60 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
61 write!(f, "{}", NAME.to_ascii_uppercase())
62 }
63}
64
65impl Function for FieldFunction {
66 fn name(&self) -> &str {
67 NAME
68 }
69
70 fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
71 Ok(DataType::Int64)
72 }
73
74 fn signature(&self) -> &Signature {
75 &self.signature
76 }
77
78 fn invoke_with_args(
79 &self,
80 args: ScalarFunctionArgs,
81 ) -> datafusion_common::Result<ColumnarValue> {
82 if args.args.len() < 2 {
83 return Err(DataFusionError::Execution(
84 "FIELD requires at least 2 arguments: FIELD(str, str1, ...)".to_string(),
85 ));
86 }
87
88 let arrays = ColumnarValue::values_to_arrays(&args.args)?;
89 let len = arrays[0].len();
90
91 let string_arrays: Vec<ArrayRef> = arrays
93 .iter()
94 .enumerate()
95 .map(|(i, arr)| {
96 cast(arr.as_ref(), &DataType::LargeUtf8).map_err(|e| {
97 DataFusionError::Execution(format!("FIELD: argument {} cast failed: {}", i, e))
98 })
99 })
100 .collect::<datafusion_common::Result<Vec<_>>>()?;
101
102 let search_str = string_arrays[0].as_string::<i64>();
103 let mut builder = Int64Builder::with_capacity(len);
104
105 for i in 0..len {
106 if search_str.is_null(i) {
108 builder.append_value(0);
109 continue;
110 }
111
112 let needle = search_str.value(i);
113 let mut found_idx = 0i64;
114
115 for (j, str_arr) in string_arrays[1..].iter().enumerate() {
117 let str_array = str_arr.as_string::<i64>();
118 if !str_array.is_null(i) && str_array.value(i) == needle {
119 found_idx = (j + 1) as i64; break;
121 }
122 }
123
124 builder.append_value(found_idx);
125 }
126
127 Ok(ColumnarValue::Array(Arc::new(builder.finish())))
128 }
129}
130
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datafusion_common::arrow::array::StringArray;
    use datafusion_common::arrow::datatypes::Field;
    use datafusion_expr::ScalarFunctionArgs;

    use super::*;

    /// Wraps `arrays` as array-valued arguments with nullable fields named
    /// `arg_<index>` and an `Int64` return field; the row count is taken from
    /// the first array.
    fn create_args(arrays: Vec<ArrayRef>) -> ScalarFunctionArgs {
        let number_rows = arrays[0].len();

        let mut arg_fields = Vec::with_capacity(arrays.len());
        let mut args = Vec::with_capacity(arrays.len());
        for (i, arr) in arrays.iter().enumerate() {
            let field = Field::new(format!("arg_{}", i), arr.data_type().clone(), true);
            arg_fields.push(Arc::new(field));
            args.push(ColumnarValue::Array(arr.clone()));
        }

        ScalarFunctionArgs {
            args,
            arg_fields,
            return_field: Arc::new(Field::new("result", DataType::Int64, true)),
            number_rows,
            config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
        }
    }

    /// Runs `FIELD` over `columns` and returns the resulting array, panicking
    /// if the function fails or yields a scalar.
    fn eval_field(columns: Vec<ArrayRef>) -> ArrayRef {
        let function = FieldFunction::default();
        match function.invoke_with_args(create_args(columns)).unwrap() {
            ColumnarValue::Array(array) => array,
            _ => panic!("Expected array result"),
        }
    }

    /// Asserts that the leading rows of `array` hold exactly `expected`.
    fn assert_positions(array: &ArrayRef, expected: &[i64]) {
        let positions = array.as_primitive::<datafusion_common::arrow::datatypes::Int64Type>();
        for (row, want) in expected.iter().enumerate() {
            assert_eq!(positions.value(row), *want);
        }
    }

    #[test]
    fn test_field_basic() {
        let search: ArrayRef = Arc::new(StringArray::from(vec!["b", "d", "a"]));
        let s1: ArrayRef = Arc::new(StringArray::from(vec!["a", "a", "a"]));
        let s2: ArrayRef = Arc::new(StringArray::from(vec!["b", "b", "b"]));
        let s3: ArrayRef = Arc::new(StringArray::from(vec!["c", "c", "c"]));

        let result = eval_field(vec![search, s1, s2, s3]);

        // "b" is the 2nd candidate, "d" is absent, "a" is the 1st candidate.
        assert_positions(&result, &[2, 0, 1]);
    }

    #[test]
    fn test_field_with_null_search() {
        let search: ArrayRef = Arc::new(StringArray::from(vec![Some("a"), None]));
        let s1: ArrayRef = Arc::new(StringArray::from(vec!["a", "a"]));
        let s2: ArrayRef = Arc::new(StringArray::from(vec!["b", "b"]));

        let result = eval_field(vec![search, s1, s2]);

        // A NULL search value yields 0.
        assert_positions(&result, &[1, 0]);
    }

    #[test]
    fn test_field_case_sensitive() {
        let search: ArrayRef = Arc::new(StringArray::from(vec!["A", "a"]));
        let s1: ArrayRef = Arc::new(StringArray::from(vec!["a", "a"]));
        let s2: ArrayRef = Arc::new(StringArray::from(vec!["A", "A"]));

        let result = eval_field(vec![search, s1, s2]);

        // "A" only matches the upper-case candidate; "a" only the lower-case one.
        assert_positions(&result, &[2, 1]);
    }
}