common_function/scalars/string/
locate.rs1use std::fmt;
23use std::sync::Arc;
24
25use datafusion_common::DataFusionError;
26use datafusion_common::arrow::array::{Array, ArrayRef, AsArray, Int64Builder};
27use datafusion_common::arrow::compute::cast;
28use datafusion_common::arrow::datatypes::DataType;
29use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, TypeSignature, Volatility};
30
31use crate::function::Function;
32use crate::function_registry::FunctionRegistry;
33
/// SQL-visible name under which the function is exposed.
const NAME: &str = "locate";
35
36#[derive(Debug)]
42pub struct LocateFunction {
43 signature: Signature,
44}
45
46impl LocateFunction {
47 pub fn register(registry: &FunctionRegistry) {
48 registry.register_scalar(LocateFunction::default());
49 }
50}
51
52impl Default for LocateFunction {
53 fn default() -> Self {
54 let mut signatures = Vec::new();
56 let string_types = [DataType::Utf8, DataType::LargeUtf8, DataType::Utf8View];
57 let int_types = [
58 DataType::Int64,
59 DataType::Int32,
60 DataType::Int16,
61 DataType::Int8,
62 DataType::UInt64,
63 DataType::UInt32,
64 DataType::UInt16,
65 DataType::UInt8,
66 ];
67
68 for substr_type in &string_types {
70 for str_type in &string_types {
71 signatures.push(TypeSignature::Exact(vec![
72 substr_type.clone(),
73 str_type.clone(),
74 ]));
75 }
76 }
77
78 for substr_type in &string_types {
80 for str_type in &string_types {
81 for pos_type in &int_types {
82 signatures.push(TypeSignature::Exact(vec![
83 substr_type.clone(),
84 str_type.clone(),
85 pos_type.clone(),
86 ]));
87 }
88 }
89 }
90
91 Self {
92 signature: Signature::one_of(signatures, Volatility::Immutable),
93 }
94 }
95}
96
97impl fmt::Display for LocateFunction {
98 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
99 write!(f, "{}", NAME.to_ascii_uppercase())
100 }
101}
102
103impl Function for LocateFunction {
104 fn name(&self) -> &str {
105 NAME
106 }
107
108 fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
109 Ok(DataType::Int64)
110 }
111
112 fn signature(&self) -> &Signature {
113 &self.signature
114 }
115
116 fn invoke_with_args(
117 &self,
118 args: ScalarFunctionArgs,
119 ) -> datafusion_common::Result<ColumnarValue> {
120 let arg_count = args.args.len();
121 if !(2..=3).contains(&arg_count) {
122 return Err(DataFusionError::Execution(
123 "LOCATE requires 2 or 3 arguments: LOCATE(substr, str) or LOCATE(substr, str, pos)"
124 .to_string(),
125 ));
126 }
127
128 let arrays = ColumnarValue::values_to_arrays(&args.args)?;
129
130 let substr_array = cast_to_large_utf8(&arrays[0], "substr")?;
132 let str_array = cast_to_large_utf8(&arrays[1], "str")?;
133
134 let substr = substr_array.as_string::<i64>();
135 let str_arr = str_array.as_string::<i64>();
136 let len = substr.len();
137
138 let pos_array: Option<ArrayRef> = if arg_count == 3 {
140 Some(cast_to_int64(&arrays[2], "pos")?)
141 } else {
142 None
143 };
144
145 let mut builder = Int64Builder::with_capacity(len);
146
147 for i in 0..len {
148 if substr.is_null(i) || str_arr.is_null(i) {
149 builder.append_null();
150 continue;
151 }
152
153 let needle = substr.value(i);
154 let haystack = str_arr.value(i);
155
156 let start_pos = if let Some(ref pos_arr) = pos_array {
158 if pos_arr.is_null(i) {
159 builder.append_null();
160 continue;
161 }
162 let pos = pos_arr
163 .as_primitive::<datafusion_common::arrow::datatypes::Int64Type>()
164 .value(i);
165 if pos < 1 {
166 builder.append_value(0);
168 continue;
169 }
170 (pos - 1) as usize
171 } else {
172 0
173 };
174
175 let result = locate_substr(haystack, needle, start_pos);
177 builder.append_value(result);
178 }
179
180 Ok(ColumnarValue::Array(Arc::new(builder.finish())))
181 }
182}
183
184fn cast_to_large_utf8(array: &ArrayRef, name: &str) -> datafusion_common::Result<ArrayRef> {
186 cast(array.as_ref(), &DataType::LargeUtf8)
187 .map_err(|e| DataFusionError::Execution(format!("LOCATE: {} cast failed: {}", name, e)))
188}
189
190fn cast_to_int64(array: &ArrayRef, name: &str) -> datafusion_common::Result<ArrayRef> {
191 cast(array.as_ref(), &DataType::Int64)
192 .map_err(|e| DataFusionError::Execution(format!("LOCATE: {} cast failed: {}", name, e)))
193}
194
/// Finds `needle` in `haystack`, searching from the 0-based character offset
/// `start_pos`, and returns the 1-based character position of the first match.
///
/// Positions are counted in Unicode scalar values (`char`s), not bytes.
///
/// * Returns `0` when there is no match or when `start_pos` is at or beyond
///   the end of `haystack`.
/// * An empty `needle` matches at the starting position itself, including the
///   position just past the last character; further out it returns `0`.
fn locate_substr(haystack: &str, needle: &str, start_pos: usize) -> i64 {
    if needle.is_empty() {
        // `start_pos` is valid when it does not exceed the character count,
        // i.e. when it is 0 or the character at `start_pos - 1` exists. This
        // only walks the first `start_pos` characters rather than all of them.
        let within_bounds = start_pos == 0 || haystack.chars().nth(start_pos - 1).is_some();
        return if within_bounds {
            (start_pos + 1) as i64
        } else {
            0
        };
    }

    // Translate the character offset into a byte offset; `None` means the
    // offset lies at or past the end, where a non-empty needle cannot match.
    let byte_start = match haystack.char_indices().nth(start_pos) {
        Some((idx, _)) => idx,
        None => return 0,
    };

    let tail = &haystack[byte_start..];
    match tail.find(needle) {
        Some(byte_pos) => {
            // Convert the byte offset of the match back into characters.
            let chars_before = tail[..byte_pos].chars().count();
            (start_pos + chars_before + 1) as i64
        }
        None => 0,
    }
}
230
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datafusion_common::arrow::array::StringArray;
    use datafusion_common::arrow::datatypes::Field;
    use datafusion_expr::ScalarFunctionArgs;

    use super::*;

    /// Wraps `arrays` into `ScalarFunctionArgs`, deriving one nullable field
    /// per argument and taking the row count from the first array.
    fn create_args(arrays: Vec<ArrayRef>) -> ScalarFunctionArgs {
        let arg_fields: Vec<_> = arrays
            .iter()
            .enumerate()
            .map(|(i, arr)| {
                Arc::new(Field::new(
                    format!("arg_{}", i),
                    arr.data_type().clone(),
                    true,
                ))
            })
            .collect();

        ScalarFunctionArgs {
            args: arrays.iter().cloned().map(ColumnarValue::Array).collect(),
            arg_fields,
            return_field: Arc::new(Field::new("result", DataType::Int64, true)),
            number_rows: arrays[0].len(),
            config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
        }
    }

    /// Invokes `LOCATE` on `arrays` and returns the result rows, with NULL
    /// represented as `None` so that a NULL can never be mistaken for `0`.
    fn run_locate(arrays: Vec<ArrayRef>) -> Vec<Option<i64>> {
        let function = LocateFunction::default();
        let result = function.invoke_with_args(create_args(arrays)).unwrap();

        match result {
            ColumnarValue::Array(array) => array
                .as_primitive::<datafusion_common::arrow::datatypes::Int64Type>()
                .iter()
                .collect(),
            _ => panic!("Expected array result"),
        }
    }

    #[test]
    fn test_locate_basic() {
        let substr: ArrayRef = Arc::new(StringArray::from(vec!["world", "xyz", "hello"]));
        let str_arr: ArrayRef = Arc::new(StringArray::from(vec![
            "hello world",
            "hello world",
            "hello world",
        ]));

        let rows = run_locate(vec![substr, str_arr]);

        // Match in the middle, no match, match at the very start.
        assert_eq!(rows, vec![Some(7), Some(0), Some(1)]);
    }

    #[test]
    fn test_locate_with_position() {
        let substr: ArrayRef = Arc::new(StringArray::from(vec!["o", "o", "o"]));
        let str_arr: ArrayRef = Arc::new(StringArray::from(vec![
            "hello world",
            "hello world",
            "hello world",
        ]));
        let pos: ArrayRef = Arc::new(datafusion_common::arrow::array::Int64Array::from(vec![
            1, 5, 8,
        ]));

        let rows = run_locate(vec![substr, str_arr, pos]);

        // The starting position itself is included in the search.
        assert_eq!(rows, vec![Some(5), Some(5), Some(8)]);
    }

    #[test]
    fn test_locate_unicode() {
        let substr: ArrayRef = Arc::new(StringArray::from(vec!["世", "界"]));
        let str_arr: ArrayRef = Arc::new(StringArray::from(vec!["hello世界", "hello世界"]));

        let rows = run_locate(vec![substr, str_arr]);

        // Positions are counted in characters, not bytes.
        assert_eq!(rows, vec![Some(6), Some(7)]);
    }

    #[test]
    fn test_locate_empty_needle() {
        let substr: ArrayRef = Arc::new(StringArray::from(vec!["", ""]));
        let str_arr: ArrayRef = Arc::new(StringArray::from(vec!["hello", "hello"]));
        let pos: ArrayRef = Arc::new(datafusion_common::arrow::array::Int64Array::from(vec![
            1, 3,
        ]));

        let rows = run_locate(vec![substr, str_arr, pos]);

        // An empty needle matches at the requested starting position.
        assert_eq!(rows, vec![Some(1), Some(3)]);
    }

    #[test]
    fn test_locate_with_nulls() {
        let substr: ArrayRef = Arc::new(StringArray::from(vec![Some("o"), None]));
        let str_arr: ArrayRef =
            Arc::new(StringArray::from(vec![Some("hello"), Some("hello")]));

        let rows = run_locate(vec![substr, str_arr]);

        // A NULL argument yields a NULL result.
        assert_eq!(rows, vec![Some(5), None]);
    }
}