common_function/scalars/string/
insert.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! MySQL-compatible INSERT function implementation.
16//!
17//! INSERT(str, pos, len, newstr) - Inserts newstr into str at position pos,
18//! replacing len characters.
19
20use std::fmt;
21use std::sync::Arc;
22
23use datafusion_common::DataFusionError;
24use datafusion_common::arrow::array::{Array, ArrayRef, AsArray, LargeStringBuilder};
25use datafusion_common::arrow::compute::cast;
26use datafusion_common::arrow::datatypes::DataType;
27use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, TypeSignature, Volatility};
28
29use crate::function::Function;
30use crate::function_registry::FunctionRegistry;
31
/// Name under which the function is registered; shown upper-cased by `Display`.
const NAME: &str = "insert";
33
34/// MySQL-compatible INSERT function.
35///
36/// Syntax: INSERT(str, pos, len, newstr)
37/// Returns str with the substring beginning at position pos and len characters long
38/// replaced by newstr.
39///
40/// - pos is 1-based
41/// - If pos is out of range, returns the original string
42/// - If len is out of range, replaces from pos to end of string
43#[derive(Debug)]
44pub struct InsertFunction {
45    signature: Signature,
46}
47
48impl InsertFunction {
49    pub fn register(registry: &FunctionRegistry) {
50        registry.register_scalar(InsertFunction::default());
51    }
52}
53
54impl Default for InsertFunction {
55    fn default() -> Self {
56        let mut signatures = Vec::new();
57        let string_types = [DataType::Utf8, DataType::LargeUtf8, DataType::Utf8View];
58        let int_types = [
59            DataType::Int64,
60            DataType::Int32,
61            DataType::Int16,
62            DataType::Int8,
63            DataType::UInt64,
64            DataType::UInt32,
65            DataType::UInt16,
66            DataType::UInt8,
67        ];
68
69        for str_type in &string_types {
70            for newstr_type in &string_types {
71                for pos_type in &int_types {
72                    for len_type in &int_types {
73                        signatures.push(TypeSignature::Exact(vec![
74                            str_type.clone(),
75                            pos_type.clone(),
76                            len_type.clone(),
77                            newstr_type.clone(),
78                        ]));
79                    }
80                }
81            }
82        }
83
84        Self {
85            signature: Signature::one_of(signatures, Volatility::Immutable),
86        }
87    }
88}
89
90impl fmt::Display for InsertFunction {
91    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
92        write!(f, "{}", NAME.to_ascii_uppercase())
93    }
94}
95
96impl Function for InsertFunction {
97    fn name(&self) -> &str {
98        NAME
99    }
100
101    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
102        Ok(DataType::LargeUtf8)
103    }
104
105    fn signature(&self) -> &Signature {
106        &self.signature
107    }
108
109    fn invoke_with_args(
110        &self,
111        args: ScalarFunctionArgs,
112    ) -> datafusion_common::Result<ColumnarValue> {
113        if args.args.len() != 4 {
114            return Err(DataFusionError::Execution(
115                "INSERT requires exactly 4 arguments: INSERT(str, pos, len, newstr)".to_string(),
116            ));
117        }
118
119        let arrays = ColumnarValue::values_to_arrays(&args.args)?;
120        let len = arrays[0].len();
121
122        // Cast string arguments to LargeUtf8
123        let str_array = cast_to_large_utf8(&arrays[0], "str")?;
124        let newstr_array = cast_to_large_utf8(&arrays[3], "newstr")?;
125        let pos_array = cast_to_int64(&arrays[1], "pos")?;
126        let replace_len_array = cast_to_int64(&arrays[2], "len")?;
127
128        let str_arr = str_array.as_string::<i64>();
129        let pos_arr = pos_array.as_primitive::<datafusion_common::arrow::datatypes::Int64Type>();
130        let len_arr =
131            replace_len_array.as_primitive::<datafusion_common::arrow::datatypes::Int64Type>();
132        let newstr_arr = newstr_array.as_string::<i64>();
133
134        let mut builder = LargeStringBuilder::with_capacity(len, len * 32);
135
136        for i in 0..len {
137            // Check for NULLs
138            if str_arr.is_null(i)
139                || pos_array.is_null(i)
140                || replace_len_array.is_null(i)
141                || newstr_arr.is_null(i)
142            {
143                builder.append_null();
144                continue;
145            }
146
147            let original = str_arr.value(i);
148            let pos = pos_arr.value(i);
149            let replace_len = len_arr.value(i);
150            let new_str = newstr_arr.value(i);
151
152            let result = insert_string(original, pos, replace_len, new_str);
153            builder.append_value(&result);
154        }
155
156        Ok(ColumnarValue::Array(Arc::new(builder.finish())))
157    }
158}
159
160/// Cast array to LargeUtf8 for uniform string access.
161fn cast_to_large_utf8(array: &ArrayRef, name: &str) -> datafusion_common::Result<ArrayRef> {
162    cast(array.as_ref(), &DataType::LargeUtf8)
163        .map_err(|e| DataFusionError::Execution(format!("INSERT: {} cast failed: {}", name, e)))
164}
165
166fn cast_to_int64(array: &ArrayRef, name: &str) -> datafusion_common::Result<ArrayRef> {
167    cast(array.as_ref(), &DataType::Int64)
168        .map_err(|e| DataFusionError::Execution(format!("INSERT: {} cast failed: {}", name, e)))
169}
170
171/// Perform the INSERT string operation.
172/// pos is 1-based. If pos < 1 or pos > len(str) + 1, returns original string.
173fn insert_string(original: &str, pos: i64, replace_len: i64, new_str: &str) -> String {
174    let char_count = original.chars().count();
175
176    // MySQL behavior: if pos < 1 or pos > string length + 1, return original
177    if pos < 1 || pos as usize > char_count + 1 {
178        return original.to_string();
179    }
180
181    let start_idx = (pos - 1) as usize; // Convert to 0-based
182
183    // Calculate end index for replacement
184    let replace_len = if replace_len < 0 {
185        0
186    } else {
187        replace_len as usize
188    };
189    let end_idx = (start_idx + replace_len).min(char_count);
190
191    let start_byte = char_to_byte_idx(original, start_idx);
192    let end_byte = char_to_byte_idx(original, end_idx);
193
194    let mut result = String::with_capacity(original.len() + new_str.len());
195    result.push_str(&original[..start_byte]);
196    result.push_str(new_str);
197    result.push_str(&original[end_byte..]);
198    result
199}
200
/// Translates a character index into the byte offset where that character starts.
///
/// Indices at or past the end of `s` map to `s.len()`.
fn char_to_byte_idx(s: &str, char_idx: usize) -> usize {
    match s.char_indices().nth(char_idx) {
        Some((byte_offset, _)) => byte_offset,
        None => s.len(),
    }
}
207
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datafusion_common::arrow::array::{Int64Array, StringArray};
    use datafusion_common::arrow::datatypes::Field;
    use datafusion_expr::ScalarFunctionArgs;

    use super::*;

    /// Wraps `arrays` as array-valued `ScalarFunctionArgs`, one nullable field per argument.
    fn create_args(arrays: Vec<ArrayRef>) -> ScalarFunctionArgs {
        let number_rows = arrays[0].len();
        let mut arg_fields = Vec::with_capacity(arrays.len());
        for (idx, array) in arrays.iter().enumerate() {
            arg_fields.push(Arc::new(Field::new(
                format!("arg_{idx}"),
                array.data_type().clone(),
                true,
            )));
        }

        ScalarFunctionArgs {
            args: arrays.into_iter().map(ColumnarValue::Array).collect(),
            arg_fields,
            return_field: Arc::new(Field::new("result", DataType::LargeUtf8, true)),
            number_rows,
            config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
        }
    }

    /// Evaluates INSERT over the given argument columns and returns the array result.
    fn run_insert(
        strs: Vec<Option<&str>>,
        positions: Vec<i64>,
        lengths: Vec<i64>,
        new_strs: Vec<&str>,
    ) -> ArrayRef {
        let str_col: ArrayRef = Arc::new(StringArray::from(strs));
        let pos_col: ArrayRef = Arc::new(Int64Array::from(positions));
        let len_col: ArrayRef = Arc::new(Int64Array::from(lengths));
        let newstr_col: ArrayRef = Arc::new(StringArray::from(new_strs));

        let args = create_args(vec![str_col, pos_col, len_col, newstr_col]);
        match InsertFunction::default().invoke_with_args(args).unwrap() {
            ColumnarValue::Array(array) => array,
            ColumnarValue::Scalar(_) => panic!("Expected array result"),
        }
    }

    #[test]
    fn test_insert_basic() {
        // INSERT('Quadratic', 3, 4, 'What') => 'QuWhattic'
        let array = run_insert(vec![Some("Quadratic")], vec![3], vec![4], vec!["What"]);
        assert_eq!(array.as_string::<i64>().value(0), "QuWhattic");
    }

    #[test]
    fn test_insert_out_of_range_pos() {
        // INSERT('Quadratic', 0, 4, 'What') and INSERT('Quadratic', 100, 4, 'What')
        // both leave 'Quadratic' untouched.
        let array = run_insert(
            vec![Some("Quadratic"), Some("Quadratic")],
            vec![0, 100],
            vec![4, 4],
            vec!["What", "What"],
        );
        let values = array.as_string::<i64>();
        assert_eq!(values.value(0), "Quadratic"); // pos < 1
        assert_eq!(values.value(1), "Quadratic"); // pos > length
    }

    #[test]
    fn test_insert_replace_to_end() {
        // INSERT('Quadratic', 3, 100, 'What') => 'QuWhat' (len exceeds remaining)
        let array = run_insert(vec![Some("Quadratic")], vec![3], vec![100], vec!["What"]);
        assert_eq!(array.as_string::<i64>().value(0), "QuWhat");
    }

    #[test]
    fn test_insert_unicode() {
        // INSERT('hello世界', 6, 1, 'の') => 'helloの界'
        let array = run_insert(vec![Some("hello世界")], vec![6], vec![1], vec!["の"]);
        assert_eq!(array.as_string::<i64>().value(0), "helloの界");
    }

    #[test]
    fn test_insert_with_nulls() {
        let array = run_insert(
            vec![Some("hello"), None],
            vec![1, 1],
            vec![1, 1],
            vec!["X", "X"],
        );
        let values = array.as_string::<i64>();
        assert_eq!(values.value(0), "Xello");
        assert!(values.is_null(1));
    }
}