common_function/scalars/string/
elt.rs1use std::fmt;
21use std::sync::Arc;
22
23use datafusion_common::DataFusionError;
24use datafusion_common::arrow::array::{Array, ArrayRef, AsArray, LargeStringBuilder};
25use datafusion_common::arrow::compute::cast;
26use datafusion_common::arrow::datatypes::DataType;
27use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
28
29use crate::function::Function;
30use crate::function_registry::FunctionRegistry;
31
32const NAME: &str = "elt";
33
34#[derive(Debug)]
40pub struct EltFunction {
41 signature: Signature,
42}
43
44impl EltFunction {
45 pub fn register(registry: &FunctionRegistry) {
46 registry.register_scalar(EltFunction::default());
47 }
48}
49
50impl Default for EltFunction {
51 fn default() -> Self {
52 Self {
53 signature: Signature::variadic_any(Volatility::Immutable),
55 }
56 }
57}
58
59impl fmt::Display for EltFunction {
60 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
61 write!(f, "{}", NAME.to_ascii_uppercase())
62 }
63}
64
65impl Function for EltFunction {
66 fn name(&self) -> &str {
67 NAME
68 }
69
70 fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
71 Ok(DataType::LargeUtf8)
72 }
73
74 fn signature(&self) -> &Signature {
75 &self.signature
76 }
77
78 fn invoke_with_args(
79 &self,
80 args: ScalarFunctionArgs,
81 ) -> datafusion_common::Result<ColumnarValue> {
82 if args.args.len() < 2 {
83 return Err(DataFusionError::Execution(
84 "ELT requires at least 2 arguments: ELT(N, str1, ...)".to_string(),
85 ));
86 }
87
88 let arrays = ColumnarValue::values_to_arrays(&args.args)?;
89 let len = arrays[0].len();
90 let num_strings = arrays.len() - 1;
91
92 let index_array = if arrays[0].data_type() == &DataType::Null {
94 let mut builder = LargeStringBuilder::with_capacity(len, 0);
96 for _ in 0..len {
97 builder.append_null();
98 }
99 return Ok(ColumnarValue::Array(Arc::new(builder.finish())));
100 } else {
101 cast(arrays[0].as_ref(), &DataType::Int64).map_err(|e| {
102 DataFusionError::Execution(format!("ELT: index argument cast failed: {}", e))
103 })?
104 };
105
106 let string_arrays: Vec<ArrayRef> = arrays[1..]
108 .iter()
109 .enumerate()
110 .map(|(i, arr)| {
111 cast(arr.as_ref(), &DataType::LargeUtf8).map_err(|e| {
112 DataFusionError::Execution(format!(
113 "ELT: string argument {} cast failed: {}",
114 i + 1,
115 e
116 ))
117 })
118 })
119 .collect::<datafusion_common::Result<Vec<_>>>()?;
120
121 let mut builder = LargeStringBuilder::with_capacity(len, len * 32);
122
123 for i in 0..len {
124 if index_array.is_null(i) {
125 builder.append_null();
126 continue;
127 }
128
129 let n = index_array
130 .as_primitive::<datafusion_common::arrow::datatypes::Int64Type>()
131 .value(i);
132
133 if n < 1 || n as usize > num_strings {
135 builder.append_null();
136 continue;
137 }
138
139 let str_idx = (n - 1) as usize;
140 let str_array = string_arrays[str_idx].as_string::<i64>();
141
142 if str_array.is_null(i) {
143 builder.append_null();
144 } else {
145 builder.append_value(str_array.value(i));
146 }
147 }
148
149 Ok(ColumnarValue::Array(Arc::new(builder.finish())))
150 }
151}
152
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datafusion_common::arrow::array::{Int64Array, NullArray, StringArray};
    use datafusion_common::arrow::datatypes::Field;
    use datafusion_expr::ScalarFunctionArgs;

    use super::*;

    /// Wraps `arrays` as array-valued arguments with one nullable field per argument.
    fn create_args(arrays: Vec<ArrayRef>) -> ScalarFunctionArgs {
        let arg_fields: Vec<_> = arrays
            .iter()
            .enumerate()
            .map(|(i, arr)| {
                Arc::new(Field::new(
                    format!("arg_{}", i),
                    arr.data_type().clone(),
                    true,
                ))
            })
            .collect();

        ScalarFunctionArgs {
            args: arrays.iter().cloned().map(ColumnarValue::Array).collect(),
            arg_fields,
            return_field: Arc::new(Field::new("result", DataType::LargeUtf8, true)),
            number_rows: arrays[0].len(),
            config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
        }
    }

    /// Runs `ELT` over `arrays`, expecting success and an array-valued result.
    fn run_elt(arrays: Vec<ArrayRef>) -> ArrayRef {
        let result = EltFunction::default()
            .invoke_with_args(create_args(arrays))
            .unwrap();
        match result {
            ColumnarValue::Array(array) => array,
            ColumnarValue::Scalar(_) => panic!("Expected array result"),
        }
    }

    #[test]
    fn test_elt_basic() {
        let n: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3]));
        let s1: ArrayRef = Arc::new(StringArray::from(vec!["a", "a", "a"]));
        let s2: ArrayRef = Arc::new(StringArray::from(vec!["b", "b", "b"]));
        let s3: ArrayRef = Arc::new(StringArray::from(vec!["c", "c", "c"]));

        let array = run_elt(vec![n, s1, s2, s3]);
        let str_array = array.as_string::<i64>();
        assert_eq!(str_array.value(0), "a");
        assert_eq!(str_array.value(1), "b");
        assert_eq!(str_array.value(2), "c");
    }

    #[test]
    fn test_elt_out_of_bounds() {
        // Zero, one past the last string, a negative value, and both i64 extremes.
        let n: ArrayRef = Arc::new(Int64Array::from(vec![0, 4, -1, i64::MAX, i64::MIN]));
        let s1: ArrayRef = Arc::new(StringArray::from(vec!["a"; 5]));
        let s2: ArrayRef = Arc::new(StringArray::from(vec!["b"; 5]));
        let s3: ArrayRef = Arc::new(StringArray::from(vec!["c"; 5]));

        let array = run_elt(vec![n, s1, s2, s3]);
        let str_array = array.as_string::<i64>();
        assert_eq!(str_array.len(), 5);
        for row in 0..str_array.len() {
            assert!(str_array.is_null(row), "row {row} should be NULL");
        }
    }

    #[test]
    fn test_elt_with_nulls() {
        let n: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), None, Some(1)]));
        let s1: ArrayRef = Arc::new(StringArray::from(vec![Some("a"), Some("a"), None]));
        let s2: ArrayRef = Arc::new(StringArray::from(vec![Some("b"), Some("b"), Some("b")]));

        let array = run_elt(vec![n, s1, s2]);
        let str_array = array.as_string::<i64>();
        // Row 0 selects "a"; row 1 has a NULL index; row 2 selects a NULL string.
        assert_eq!(str_array.value(0), "a");
        assert!(str_array.is_null(1));
        assert!(str_array.is_null(2));
    }

    #[test]
    fn test_elt_null_typed_index() {
        let n: ArrayRef = Arc::new(NullArray::new(2));
        let s1: ArrayRef = Arc::new(StringArray::from(vec!["a", "a"]));

        let array = run_elt(vec![n, s1]);
        assert_eq!(array.data_type(), &DataType::LargeUtf8);
        assert_eq!(array.len(), 2);
        assert_eq!(array.null_count(), 2);
    }

    #[test]
    fn test_elt_casts_string_index() {
        // "2" casts to 2; "x" is not a number and becomes a NULL index.
        let n: ArrayRef = Arc::new(StringArray::from(vec!["2", "x"]));
        let s1: ArrayRef = Arc::new(StringArray::from(vec!["a", "a"]));
        let s2: ArrayRef = Arc::new(StringArray::from(vec!["b", "b"]));

        let array = run_elt(vec![n, s1, s2]);
        let str_array = array.as_string::<i64>();
        assert_eq!(str_array.value(0), "b");
        assert!(str_array.is_null(1));
    }

    #[test]
    fn test_elt_requires_two_arguments() {
        let n: ArrayRef = Arc::new(Int64Array::from(vec![1]));
        let result = EltFunction::default().invoke_with_args(create_args(vec![n]));
        assert!(result.is_err());
    }
}