common_function/scalars/string/
locate.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! MySQL-compatible LOCATE function implementation.
16//!
17//! LOCATE(substr, str) - Returns the position of the first occurrence of substr in str (1-based).
18//! LOCATE(substr, str, pos) - Returns the position of the first occurrence of substr in str,
19//!                            starting from position pos.
20//! Returns 0 if substr is not found.
21
22use std::fmt;
23use std::sync::Arc;
24
25use datafusion_common::DataFusionError;
26use datafusion_common::arrow::array::{Array, ArrayRef, AsArray, Int64Builder};
27use datafusion_common::arrow::compute::cast;
28use datafusion_common::arrow::datatypes::DataType;
29use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, TypeSignature, Volatility};
30
31use crate::function::Function;
32use crate::function_registry::FunctionRegistry;
33
34const NAME: &str = "locate";
35
/// MySQL-compatible LOCATE function.
///
/// Syntax:
/// - LOCATE(substr, str) - Returns 1-based position of substr in str, or 0 if not found.
/// - LOCATE(substr, str, pos) - Same, but starts searching from position pos.
///
/// Positions are counted in characters rather than bytes, and the result
/// column is always `Int64`. A NULL in any argument yields a NULL result for
/// that row.
#[derive(Debug)]
pub struct LocateFunction {
    /// Accepted argument type combinations (2- and 3-argument forms); built by
    /// the `Default` implementation and handed out by `Function::signature`.
    signature: Signature,
}
45
46impl LocateFunction {
47    pub fn register(registry: &FunctionRegistry) {
48        registry.register_scalar(LocateFunction::default());
49    }
50}
51
52impl Default for LocateFunction {
53    fn default() -> Self {
54        // Support 2 or 3 arguments with various string types
55        let mut signatures = Vec::new();
56        let string_types = [DataType::Utf8, DataType::LargeUtf8, DataType::Utf8View];
57        let int_types = [
58            DataType::Int64,
59            DataType::Int32,
60            DataType::Int16,
61            DataType::Int8,
62            DataType::UInt64,
63            DataType::UInt32,
64            DataType::UInt16,
65            DataType::UInt8,
66        ];
67
68        // 2-argument form: LOCATE(substr, str)
69        for substr_type in &string_types {
70            for str_type in &string_types {
71                signatures.push(TypeSignature::Exact(vec![
72                    substr_type.clone(),
73                    str_type.clone(),
74                ]));
75            }
76        }
77
78        // 3-argument form: LOCATE(substr, str, pos)
79        for substr_type in &string_types {
80            for str_type in &string_types {
81                for pos_type in &int_types {
82                    signatures.push(TypeSignature::Exact(vec![
83                        substr_type.clone(),
84                        str_type.clone(),
85                        pos_type.clone(),
86                    ]));
87                }
88            }
89        }
90
91        Self {
92            signature: Signature::one_of(signatures, Volatility::Immutable),
93        }
94    }
95}
96
97impl fmt::Display for LocateFunction {
98    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
99        write!(f, "{}", NAME.to_ascii_uppercase())
100    }
101}
102
103impl Function for LocateFunction {
104    fn name(&self) -> &str {
105        NAME
106    }
107
108    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
109        Ok(DataType::Int64)
110    }
111
112    fn signature(&self) -> &Signature {
113        &self.signature
114    }
115
116    fn invoke_with_args(
117        &self,
118        args: ScalarFunctionArgs,
119    ) -> datafusion_common::Result<ColumnarValue> {
120        let arg_count = args.args.len();
121        if !(2..=3).contains(&arg_count) {
122            return Err(DataFusionError::Execution(
123                "LOCATE requires 2 or 3 arguments: LOCATE(substr, str) or LOCATE(substr, str, pos)"
124                    .to_string(),
125            ));
126        }
127
128        let arrays = ColumnarValue::values_to_arrays(&args.args)?;
129
130        // Cast string arguments to LargeUtf8 for uniform access
131        let substr_array = cast_to_large_utf8(&arrays[0], "substr")?;
132        let str_array = cast_to_large_utf8(&arrays[1], "str")?;
133
134        let substr = substr_array.as_string::<i64>();
135        let str_arr = str_array.as_string::<i64>();
136        let len = substr.len();
137
138        // Handle optional pos argument
139        let pos_array: Option<ArrayRef> = if arg_count == 3 {
140            Some(cast_to_int64(&arrays[2], "pos")?)
141        } else {
142            None
143        };
144
145        let mut builder = Int64Builder::with_capacity(len);
146
147        for i in 0..len {
148            if substr.is_null(i) || str_arr.is_null(i) {
149                builder.append_null();
150                continue;
151            }
152
153            let needle = substr.value(i);
154            let haystack = str_arr.value(i);
155
156            // Get starting position (1-based in MySQL, convert to 0-based)
157            let start_pos = if let Some(ref pos_arr) = pos_array {
158                if pos_arr.is_null(i) {
159                    builder.append_null();
160                    continue;
161                }
162                let pos = pos_arr
163                    .as_primitive::<datafusion_common::arrow::datatypes::Int64Type>()
164                    .value(i);
165                if pos < 1 {
166                    // MySQL returns 0 for pos < 1
167                    builder.append_value(0);
168                    continue;
169                }
170                (pos - 1) as usize
171            } else {
172                0
173            };
174
175            // Find position using character-based indexing (for Unicode support)
176            let result = locate_substr(haystack, needle, start_pos);
177            builder.append_value(result);
178        }
179
180        Ok(ColumnarValue::Array(Arc::new(builder.finish())))
181    }
182}
183
184/// Cast array to LargeUtf8 for uniform string access.
185fn cast_to_large_utf8(array: &ArrayRef, name: &str) -> datafusion_common::Result<ArrayRef> {
186    cast(array.as_ref(), &DataType::LargeUtf8)
187        .map_err(|e| DataFusionError::Execution(format!("LOCATE: {} cast failed: {}", name, e)))
188}
189
190fn cast_to_int64(array: &ArrayRef, name: &str) -> datafusion_common::Result<ArrayRef> {
191    cast(array.as_ref(), &DataType::Int64)
192        .map_err(|e| DataFusionError::Execution(format!("LOCATE: {} cast failed: {}", name, e)))
193}
194
/// Returns the 1-based character position of the first occurrence of `needle`
/// in `haystack` at or after `start_pos`, or 0 when there is none.
///
/// `start_pos` is a 0-based character (not byte) index. An empty `needle`
/// matches at `start_pos` itself as long as `start_pos` is not beyond the end
/// of `haystack` (one past the last character still counts).
fn locate_substr(haystack: &str, needle: &str, start_pos: usize) -> i64 {
    // Empty needle: it "occurs" at the starting position, provided that
    // position exists in the haystack.
    if needle.is_empty() {
        let total_chars = haystack.chars().count();
        if start_pos > total_chars {
            return 0;
        }
        return (start_pos + 1) as i64;
    }

    // Translate the character index into a byte offset and slice from there.
    // If the haystack has no character at `start_pos`, nothing can match.
    let tail = match haystack.char_indices().nth(start_pos) {
        Some((byte_offset, _)) => &haystack[byte_offset..],
        None => return 0,
    };

    match tail.find(needle) {
        Some(hit) => {
            // `hit` is a byte offset into `tail`; count the characters before
            // it and shift back to a 1-based index into the full haystack.
            let chars_before_hit = tail[..hit].chars().count();
            (start_pos + chars_before_hit + 1) as i64
        }
        None => 0,
    }
}
230
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datafusion_common::arrow::array::{Int64Array, StringArray};
    use datafusion_common::arrow::datatypes::{Field, Int64Type};
    use datafusion_expr::ScalarFunctionArgs;

    use super::*;

    /// Wraps the given columns into `ScalarFunctionArgs`, deriving one nullable
    /// field per column and taking the row count from the first column.
    fn create_args(arrays: Vec<ArrayRef>) -> ScalarFunctionArgs {
        let number_rows = arrays[0].len();

        let mut arg_fields = Vec::with_capacity(arrays.len());
        for (index, column) in arrays.iter().enumerate() {
            let field = Field::new(format!("arg_{}", index), column.data_type().clone(), true);
            arg_fields.push(Arc::new(field));
        }

        ScalarFunctionArgs {
            args: arrays.into_iter().map(ColumnarValue::Array).collect(),
            arg_fields,
            return_field: Arc::new(Field::new("result", DataType::Int64, true)),
            number_rows,
            config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
        }
    }

    /// Invokes LOCATE on the given columns and returns the `Int64` result
    /// column, panicking if the function fails or returns a non-array value.
    fn run_locate(arrays: Vec<ArrayRef>) -> Int64Array {
        let function = LocateFunction::default();
        match function.invoke_with_args(create_args(arrays)).unwrap() {
            ColumnarValue::Array(array) => array.as_primitive::<Int64Type>().clone(),
            _ => panic!("Expected array result"),
        }
    }

    #[test]
    fn test_locate_basic() {
        let needles = Arc::new(StringArray::from(vec!["world", "xyz", "hello"]));
        let haystacks = Arc::new(StringArray::from(vec!["hello world"; 3]));

        let found = run_locate(vec![needles, haystacks]);

        assert_eq!(found.value(0), 7); // "world" at position 7
        assert_eq!(found.value(1), 0); // "xyz" not found
        assert_eq!(found.value(2), 1); // "hello" at position 1
    }

    #[test]
    fn test_locate_with_position() {
        let needles = Arc::new(StringArray::from(vec!["o"; 3]));
        let haystacks = Arc::new(StringArray::from(vec!["hello world"; 3]));
        let starts = Arc::new(Int64Array::from(vec![1, 5, 8]));

        let found = run_locate(vec![needles, haystacks, starts]);

        assert_eq!(found.value(0), 5); // first 'o' at position 5
        assert_eq!(found.value(1), 5); // 'o' at position 5 (start from 5)
        assert_eq!(found.value(2), 8); // 'o' in "world" at position 8
    }

    #[test]
    fn test_locate_unicode() {
        let needles = Arc::new(StringArray::from(vec!["世", "界"]));
        let haystacks = Arc::new(StringArray::from(vec!["hello世界"; 2]));

        let found = run_locate(vec![needles, haystacks]);

        assert_eq!(found.value(0), 6); // "世" at position 6
        assert_eq!(found.value(1), 7); // "界" at position 7
    }

    #[test]
    fn test_locate_empty_needle() {
        let needles = Arc::new(StringArray::from(vec![""; 2]));
        let haystacks = Arc::new(StringArray::from(vec!["hello"; 2]));
        let starts = Arc::new(Int64Array::from(vec![1, 3]));

        let found = run_locate(vec![needles, haystacks, starts]);

        assert_eq!(found.value(0), 1); // empty string at pos 1
        assert_eq!(found.value(1), 3); // empty string at pos 3
    }

    #[test]
    fn test_locate_with_nulls() {
        let needles = Arc::new(StringArray::from(vec![Some("o"), None]));
        let haystacks = Arc::new(StringArray::from(vec![Some("hello"), Some("hello")]));

        let found = run_locate(vec![needles, haystacks]);

        assert_eq!(found.value(0), 5);
        assert!(found.is_null(1));
    }
}