common_function/scalars/string/
insert.rs1use std::fmt;
21use std::sync::Arc;
22
23use datafusion_common::DataFusionError;
24use datafusion_common::arrow::array::{Array, ArrayRef, AsArray, LargeStringBuilder};
25use datafusion_common::arrow::compute::cast;
26use datafusion_common::arrow::datatypes::DataType;
27use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, TypeSignature, Volatility};
28
29use crate::function::Function;
30use crate::function_registry::FunctionRegistry;
31
32const NAME: &str = "insert";
33
34#[derive(Debug)]
44pub struct InsertFunction {
45 signature: Signature,
46}
47
48impl InsertFunction {
49 pub fn register(registry: &FunctionRegistry) {
50 registry.register_scalar(InsertFunction::default());
51 }
52}
53
54impl Default for InsertFunction {
55 fn default() -> Self {
56 let mut signatures = Vec::new();
57 let string_types = [DataType::Utf8, DataType::LargeUtf8, DataType::Utf8View];
58 let int_types = [
59 DataType::Int64,
60 DataType::Int32,
61 DataType::Int16,
62 DataType::Int8,
63 DataType::UInt64,
64 DataType::UInt32,
65 DataType::UInt16,
66 DataType::UInt8,
67 ];
68
69 for str_type in &string_types {
70 for newstr_type in &string_types {
71 for pos_type in &int_types {
72 for len_type in &int_types {
73 signatures.push(TypeSignature::Exact(vec![
74 str_type.clone(),
75 pos_type.clone(),
76 len_type.clone(),
77 newstr_type.clone(),
78 ]));
79 }
80 }
81 }
82 }
83
84 Self {
85 signature: Signature::one_of(signatures, Volatility::Immutable),
86 }
87 }
88}
89
90impl fmt::Display for InsertFunction {
91 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
92 write!(f, "{}", NAME.to_ascii_uppercase())
93 }
94}
95
96impl Function for InsertFunction {
97 fn name(&self) -> &str {
98 NAME
99 }
100
101 fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
102 Ok(DataType::LargeUtf8)
103 }
104
105 fn signature(&self) -> &Signature {
106 &self.signature
107 }
108
109 fn invoke_with_args(
110 &self,
111 args: ScalarFunctionArgs,
112 ) -> datafusion_common::Result<ColumnarValue> {
113 if args.args.len() != 4 {
114 return Err(DataFusionError::Execution(
115 "INSERT requires exactly 4 arguments: INSERT(str, pos, len, newstr)".to_string(),
116 ));
117 }
118
119 let arrays = ColumnarValue::values_to_arrays(&args.args)?;
120 let len = arrays[0].len();
121
122 let str_array = cast_to_large_utf8(&arrays[0], "str")?;
124 let newstr_array = cast_to_large_utf8(&arrays[3], "newstr")?;
125 let pos_array = cast_to_int64(&arrays[1], "pos")?;
126 let replace_len_array = cast_to_int64(&arrays[2], "len")?;
127
128 let str_arr = str_array.as_string::<i64>();
129 let pos_arr = pos_array.as_primitive::<datafusion_common::arrow::datatypes::Int64Type>();
130 let len_arr =
131 replace_len_array.as_primitive::<datafusion_common::arrow::datatypes::Int64Type>();
132 let newstr_arr = newstr_array.as_string::<i64>();
133
134 let mut builder = LargeStringBuilder::with_capacity(len, len * 32);
135
136 for i in 0..len {
137 if str_arr.is_null(i)
139 || pos_array.is_null(i)
140 || replace_len_array.is_null(i)
141 || newstr_arr.is_null(i)
142 {
143 builder.append_null();
144 continue;
145 }
146
147 let original = str_arr.value(i);
148 let pos = pos_arr.value(i);
149 let replace_len = len_arr.value(i);
150 let new_str = newstr_arr.value(i);
151
152 let result = insert_string(original, pos, replace_len, new_str);
153 builder.append_value(&result);
154 }
155
156 Ok(ColumnarValue::Array(Arc::new(builder.finish())))
157 }
158}
159
160fn cast_to_large_utf8(array: &ArrayRef, name: &str) -> datafusion_common::Result<ArrayRef> {
162 cast(array.as_ref(), &DataType::LargeUtf8)
163 .map_err(|e| DataFusionError::Execution(format!("INSERT: {} cast failed: {}", name, e)))
164}
165
166fn cast_to_int64(array: &ArrayRef, name: &str) -> datafusion_common::Result<ArrayRef> {
167 cast(array.as_ref(), &DataType::Int64)
168 .map_err(|e| DataFusionError::Execution(format!("INSERT: {} cast failed: {}", name, e)))
169}
170
171fn insert_string(original: &str, pos: i64, replace_len: i64, new_str: &str) -> String {
174 let char_count = original.chars().count();
175
176 if pos < 1 || pos as usize > char_count + 1 {
178 return original.to_string();
179 }
180
181 let start_idx = (pos - 1) as usize; let replace_len = if replace_len < 0 {
185 0
186 } else {
187 replace_len as usize
188 };
189 let end_idx = (start_idx + replace_len).min(char_count);
190
191 let start_byte = char_to_byte_idx(original, start_idx);
192 let end_byte = char_to_byte_idx(original, end_idx);
193
194 let mut result = String::with_capacity(original.len() + new_str.len());
195 result.push_str(&original[..start_byte]);
196 result.push_str(new_str);
197 result.push_str(&original[end_byte..]);
198 result
199}
200
/// Returns the byte offset at which the `char_idx`-th character of `s` starts.
///
/// Returns `s.len()` when `char_idx` is at or past the number of characters in `s`.
fn char_to_byte_idx(s: &str, char_idx: usize) -> usize {
    match s.char_indices().nth(char_idx) {
        Some((byte_offset, _)) => byte_offset,
        None => s.len(),
    }
}
207
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datafusion_common::arrow::array::{Int64Array, StringArray};
    use datafusion_common::arrow::datatypes::Field;
    use datafusion_expr::ScalarFunctionArgs;

    use super::*;

    /// Wraps `arrays` into `ScalarFunctionArgs`, with one nullable field per argument.
    fn create_args(arrays: Vec<ArrayRef>) -> ScalarFunctionArgs {
        let arg_fields = arrays
            .iter()
            .enumerate()
            .map(|(idx, array)| {
                Arc::new(Field::new(
                    format!("arg_{}", idx),
                    array.data_type().clone(),
                    true,
                ))
            })
            .collect::<Vec<_>>();
        let number_rows = arrays[0].len();

        ScalarFunctionArgs {
            args: arrays.into_iter().map(ColumnarValue::Array).collect(),
            arg_fields,
            return_field: Arc::new(Field::new("result", DataType::LargeUtf8, true)),
            number_rows,
            config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
        }
    }

    /// Runs INSERT over the four argument columns and returns each output row.
    ///
    /// NULL rows come back as `None`.
    fn run_insert(
        str_arr: ArrayRef,
        pos: ArrayRef,
        len: ArrayRef,
        newstr: ArrayRef,
    ) -> Vec<Option<String>> {
        let output = InsertFunction::default()
            .invoke_with_args(create_args(vec![str_arr, pos, len, newstr]))
            .unwrap();
        match output {
            ColumnarValue::Array(array) => array
                .as_string::<i64>()
                .iter()
                .map(|value| value.map(str::to_string))
                .collect(),
            _ => panic!("Expected array result"),
        }
    }

    #[test]
    fn test_insert_basic() {
        let rows = run_insert(
            Arc::new(StringArray::from(vec!["Quadratic"])),
            Arc::new(Int64Array::from(vec![3])),
            Arc::new(Int64Array::from(vec![4])),
            Arc::new(StringArray::from(vec!["What"])),
        );
        assert_eq!(rows, vec![Some("QuWhattic".to_string())]);
    }

    #[test]
    fn test_insert_out_of_range_pos() {
        let rows = run_insert(
            Arc::new(StringArray::from(vec!["Quadratic", "Quadratic"])),
            Arc::new(Int64Array::from(vec![0, 100])),
            Arc::new(Int64Array::from(vec![4, 4])),
            Arc::new(StringArray::from(vec!["What", "What"])),
        );
        assert_eq!(rows, vec![Some("Quadratic".to_string()); 2]);
    }

    #[test]
    fn test_insert_replace_to_end() {
        let rows = run_insert(
            Arc::new(StringArray::from(vec!["Quadratic"])),
            Arc::new(Int64Array::from(vec![3])),
            Arc::new(Int64Array::from(vec![100])),
            Arc::new(StringArray::from(vec!["What"])),
        );
        assert_eq!(rows, vec![Some("QuWhat".to_string())]);
    }

    #[test]
    fn test_insert_unicode() {
        let rows = run_insert(
            Arc::new(StringArray::from(vec!["hello世界"])),
            Arc::new(Int64Array::from(vec![6])),
            Arc::new(Int64Array::from(vec![1])),
            Arc::new(StringArray::from(vec!["の"])),
        );
        assert_eq!(rows, vec![Some("helloの界".to_string())]);
    }

    #[test]
    fn test_insert_with_nulls() {
        let rows = run_insert(
            Arc::new(StringArray::from(vec![Some("hello"), None])),
            Arc::new(Int64Array::from(vec![1, 1])),
            Arc::new(Int64Array::from(vec![1, 1])),
            Arc::new(StringArray::from(vec!["X", "X"])),
        );
        assert_eq!(rows, vec![Some("Xello".to_string()), None]);
    }
}