common_function/scalars/string/
elt.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! MySQL-compatible ELT function implementation.
16//!
//! ELT(N, str1, str2, str3, ...) - Returns the Nth string from the list (N is 1-based).
//! Returns NULL if N is NULL, N < 1, N > number of strings, or the selected string is NULL.
19
20use std::fmt;
21use std::sync::Arc;
22
23use datafusion_common::DataFusionError;
24use datafusion_common::arrow::array::{Array, ArrayRef, AsArray, LargeStringBuilder};
25use datafusion_common::arrow::compute::cast;
26use datafusion_common::arrow::datatypes::DataType;
27use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
28
29use crate::function::Function;
30use crate::function_registry::FunctionRegistry;
31
32const NAME: &str = "elt";
33
34/// MySQL-compatible ELT function.
35///
36/// Syntax: ELT(N, str1, str2, str3, ...)
37/// Returns the Nth string argument. N is 1-based.
38/// Returns NULL if N is NULL, N < 1, or N > number of string arguments.
39#[derive(Debug)]
40pub struct EltFunction {
41    signature: Signature,
42}
43
44impl EltFunction {
45    pub fn register(registry: &FunctionRegistry) {
46        registry.register_scalar(EltFunction::default());
47    }
48}
49
50impl Default for EltFunction {
51    fn default() -> Self {
52        Self {
53            // ELT takes a variable number of arguments: (Int64, String, String, ...)
54            signature: Signature::variadic_any(Volatility::Immutable),
55        }
56    }
57}
58
59impl fmt::Display for EltFunction {
60    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
61        write!(f, "{}", NAME.to_ascii_uppercase())
62    }
63}
64
65impl Function for EltFunction {
66    fn name(&self) -> &str {
67        NAME
68    }
69
70    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
71        Ok(DataType::LargeUtf8)
72    }
73
74    fn signature(&self) -> &Signature {
75        &self.signature
76    }
77
78    fn invoke_with_args(
79        &self,
80        args: ScalarFunctionArgs,
81    ) -> datafusion_common::Result<ColumnarValue> {
82        if args.args.len() < 2 {
83            return Err(DataFusionError::Execution(
84                "ELT requires at least 2 arguments: ELT(N, str1, ...)".to_string(),
85            ));
86        }
87
88        let arrays = ColumnarValue::values_to_arrays(&args.args)?;
89        let len = arrays[0].len();
90        let num_strings = arrays.len() - 1;
91
92        // First argument is the index (N) - try to cast to Int64
93        let index_array = if arrays[0].data_type() == &DataType::Null {
94            // All NULLs - return all NULLs
95            let mut builder = LargeStringBuilder::with_capacity(len, 0);
96            for _ in 0..len {
97                builder.append_null();
98            }
99            return Ok(ColumnarValue::Array(Arc::new(builder.finish())));
100        } else {
101            cast(arrays[0].as_ref(), &DataType::Int64).map_err(|e| {
102                DataFusionError::Execution(format!("ELT: index argument cast failed: {}", e))
103            })?
104        };
105
106        // Cast string arguments to LargeUtf8
107        let string_arrays: Vec<ArrayRef> = arrays[1..]
108            .iter()
109            .enumerate()
110            .map(|(i, arr)| {
111                cast(arr.as_ref(), &DataType::LargeUtf8).map_err(|e| {
112                    DataFusionError::Execution(format!(
113                        "ELT: string argument {} cast failed: {}",
114                        i + 1,
115                        e
116                    ))
117                })
118            })
119            .collect::<datafusion_common::Result<Vec<_>>>()?;
120
121        let mut builder = LargeStringBuilder::with_capacity(len, len * 32);
122
123        for i in 0..len {
124            if index_array.is_null(i) {
125                builder.append_null();
126                continue;
127            }
128
129            let n = index_array
130                .as_primitive::<datafusion_common::arrow::datatypes::Int64Type>()
131                .value(i);
132
133            // N is 1-based, check bounds
134            if n < 1 || n as usize > num_strings {
135                builder.append_null();
136                continue;
137            }
138
139            let str_idx = (n - 1) as usize;
140            let str_array = string_arrays[str_idx].as_string::<i64>();
141
142            if str_array.is_null(i) {
143                builder.append_null();
144            } else {
145                builder.append_value(str_array.value(i));
146            }
147        }
148
149        Ok(ColumnarValue::Array(Arc::new(builder.finish())))
150    }
151}
152
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datafusion_common::arrow::array::{Int64Array, NullArray, StringArray};
    use datafusion_common::arrow::datatypes::Field;
    use datafusion_expr::ScalarFunctionArgs;

    use super::*;

    /// Wraps `arrays` as array-valued `ScalarFunctionArgs`, with nullable fields
    /// named `arg_<i>` and a nullable `LargeUtf8` return field.
    fn create_args(arrays: Vec<ArrayRef>) -> ScalarFunctionArgs {
        let arg_fields: Vec<_> = arrays
            .iter()
            .enumerate()
            .map(|(i, arr)| {
                Arc::new(Field::new(
                    format!("arg_{}", i),
                    arr.data_type().clone(),
                    true,
                ))
            })
            .collect();

        ScalarFunctionArgs {
            args: arrays.iter().cloned().map(ColumnarValue::Array).collect(),
            arg_fields,
            return_field: Arc::new(Field::new("result", DataType::LargeUtf8, true)),
            number_rows: arrays[0].len(),
            config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
        }
    }

    #[test]
    fn test_elt_basic() {
        let function = EltFunction::default();

        let n = Arc::new(Int64Array::from(vec![1, 2, 3]));
        let s1 = Arc::new(StringArray::from(vec!["a", "a", "a"]));
        let s2 = Arc::new(StringArray::from(vec!["b", "b", "b"]));
        let s3 = Arc::new(StringArray::from(vec!["c", "c", "c"]));

        let args = create_args(vec![n, s1, s2, s3]);
        let result = function.invoke_with_args(args).unwrap();

        if let ColumnarValue::Array(array) = result {
            let str_array = array.as_string::<i64>();
            assert_eq!(str_array.value(0), "a");
            assert_eq!(str_array.value(1), "b");
            assert_eq!(str_array.value(2), "c");
        } else {
            panic!("Expected array result");
        }
    }

    #[test]
    fn test_elt_out_of_bounds() {
        let function = EltFunction::default();

        let n = Arc::new(Int64Array::from(vec![0, 4, -1]));
        let s1 = Arc::new(StringArray::from(vec!["a", "a", "a"]));
        let s2 = Arc::new(StringArray::from(vec!["b", "b", "b"]));
        let s3 = Arc::new(StringArray::from(vec!["c", "c", "c"]));

        let args = create_args(vec![n, s1, s2, s3]);
        let result = function.invoke_with_args(args).unwrap();

        if let ColumnarValue::Array(array) = result {
            let str_array = array.as_string::<i64>();
            assert!(str_array.is_null(0)); // 0 is out of bounds
            assert!(str_array.is_null(1)); // 4 is out of bounds
            assert!(str_array.is_null(2)); // -1 is out of bounds
        } else {
            panic!("Expected array result");
        }
    }

    #[test]
    fn test_elt_extreme_indices() {
        let function = EltFunction::default();

        // Extreme i64 values must be treated as out of range, never wrapped into a
        // valid slot (2^32 + 1 would wrap to 1 under an `as usize` cast on 32-bit).
        let n = Arc::new(Int64Array::from(vec![i64::MAX, i64::MIN, (1_i64 << 32) + 1]));
        let s1 = Arc::new(StringArray::from(vec!["a", "a", "a"]));
        let s2 = Arc::new(StringArray::from(vec!["b", "b", "b"]));

        let args = create_args(vec![n, s1, s2]);
        let result = function.invoke_with_args(args).unwrap();

        if let ColumnarValue::Array(array) = result {
            let str_array = array.as_string::<i64>();
            assert_eq!(str_array.len(), 3);
            assert!(str_array.is_null(0));
            assert!(str_array.is_null(1));
            assert!(str_array.is_null(2));
        } else {
            panic!("Expected array result");
        }
    }

    #[test]
    fn test_elt_with_nulls() {
        let function = EltFunction::default();

        // Row 0: n=1, select s1="a" -> "a"
        // Row 1: n=NULL -> NULL
        // Row 2: n=1, select s1=NULL -> NULL
        let n = Arc::new(Int64Array::from(vec![Some(1), None, Some(1)]));
        let s1 = Arc::new(StringArray::from(vec![Some("a"), Some("a"), None]));
        let s2 = Arc::new(StringArray::from(vec![Some("b"), Some("b"), Some("b")]));

        let args = create_args(vec![n, s1, s2]);
        let result = function.invoke_with_args(args).unwrap();

        if let ColumnarValue::Array(array) = result {
            let str_array = array.as_string::<i64>();
            assert_eq!(str_array.value(0), "a");
            assert!(str_array.is_null(1)); // N is NULL
            assert!(str_array.is_null(2)); // Selected string is NULL
        } else {
            panic!("Expected array result");
        }
    }

    #[test]
    fn test_elt_null_typed_index() {
        let function = EltFunction::default();

        // An index column of DataType::Null yields NULL for every row.
        let n = Arc::new(NullArray::new(2));
        let s1 = Arc::new(StringArray::from(vec!["a", "a"]));

        let args = create_args(vec![n, s1]);
        let result = function.invoke_with_args(args).unwrap();

        if let ColumnarValue::Array(array) = result {
            let str_array = array.as_string::<i64>();
            assert_eq!(str_array.len(), 2);
            assert!(str_array.is_null(0));
            assert!(str_array.is_null(1));
        } else {
            panic!("Expected array result");
        }
    }
}