common_function/scalars/vector/
vector_kth_elem.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Display;
16
17use datafusion::logical_expr::ColumnarValue;
18use datafusion_common::{DataFusionError, ScalarValue};
19use datafusion_expr::{ScalarFunctionArgs, Signature};
20use datatypes::arrow::datatypes::DataType;
21
22use crate::function::Function;
23use crate::helper;
24use crate::scalars::vector::VectorCalculator;
25use crate::scalars::vector::impl_conv::as_veclit;
26
27const NAME: &str = "vec_kth_elem";
28
29/// Returns the k-th(0-based index) element of the vector.
30///
31/// # Example
32///
33/// ```sql
34/// SELECT vec_kth_elem("[2, 4, 6]",1) as result;
35///
36/// +---------+
37/// | result  |
38/// +---------+
39/// | 4 |
40/// +---------+
41///
42/// ```
43///
44
45#[derive(Debug, Clone)]
46pub(crate) struct VectorKthElemFunction {
47    signature: Signature,
48}
49
50impl Default for VectorKthElemFunction {
51    fn default() -> Self {
52        Self {
53            signature: helper::one_of_sigs2(
54                vec![DataType::Utf8, DataType::Binary],
55                vec![DataType::Int64],
56            ),
57        }
58    }
59}
60
61impl Function for VectorKthElemFunction {
62    fn name(&self) -> &str {
63        NAME
64    }
65
66    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
67        Ok(DataType::Float32)
68    }
69
70    fn signature(&self) -> &Signature {
71        &self.signature
72    }
73
74    fn invoke_with_args(
75        &self,
76        args: ScalarFunctionArgs,
77    ) -> datafusion_common::Result<ColumnarValue> {
78        let body = |v0: &ScalarValue, v1: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
79            let v0 = as_veclit(v0)?;
80
81            let v1 = match v1 {
82                ScalarValue::Int64(None) => return Ok(ScalarValue::Float32(None)),
83                ScalarValue::Int64(Some(v1)) if *v1 >= 0 => *v1 as usize,
84                _ => {
85                    return Err(DataFusionError::Execution(format!(
86                        "2nd argument not a valid index or expected datatype: {}",
87                        self.name()
88                    )));
89                }
90            };
91
92            let result = v0
93                .map(|v0| {
94                    if v1 >= v0.len() {
95                        Err(DataFusionError::Execution(format!(
96                            "index out of bound: {}",
97                            self.name()
98                        )))
99                    } else {
100                        Ok(v0[v1])
101                    }
102                })
103                .transpose()?;
104            Ok(ScalarValue::Float32(result))
105        };
106
107        let calculator = VectorCalculator {
108            name: self.name(),
109            func: body,
110        };
111        calculator.invoke_with_args(args)
112    }
113}
114
115impl Display for VectorKthElemFunction {
116    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
117        write!(f, "{}", NAME.to_ascii_uppercase())
118    }
119}
120
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion::arrow::array::{Array, ArrayRef, AsArray, Int64Array, StringViewArray};
    use datafusion::arrow::datatypes::Float32Type;
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Assembles call arguments from a column of vector literals and a column
    /// of indices, both passed as arrays.
    fn build_args(
        vectors: &[Option<&str>],
        indices: &[Option<i64>],
        number_rows: usize,
    ) -> ScalarFunctionArgs {
        let literals: Vec<Option<String>> = vectors
            .iter()
            .map(|literal| literal.map(str::to_string))
            .collect();
        let vector_column: ArrayRef = Arc::new(StringViewArray::from(literals));
        let index_column: ArrayRef = Arc::new(Int64Array::from(indices.to_vec()));

        ScalarFunctionArgs {
            args: vec![
                ColumnarValue::Array(vector_column),
                ColumnarValue::Array(index_column),
            ],
            arg_fields: vec![],
            number_rows,
            return_field: Arc::new(Field::new("x", DataType::Float32, false)),
            config_options: Arc::new(ConfigOptions::new()),
        }
    }

    #[test]
    fn test_vec_kth_elem() {
        let func = VectorKthElemFunction::default();

        // Valid lookups, plus a NULL index (row 2) and a NULL vector (row 3).
        let args = build_args(
            &[
                Some("[1.0,2.0,3.0]"),
                Some("[4.0,5.0,6.0]"),
                Some("[7.0,8.0,9.0]"),
                None,
            ],
            &[Some(0), Some(2), None, Some(1)],
            4,
        );
        let output = func
            .invoke_with_args(args)
            .and_then(|value| value.to_array(4))
            .unwrap();

        let elements = output.as_primitive::<Float32Type>();
        assert_eq!(elements.len(), 4);
        assert_eq!(elements.value(0), 1.0);
        assert_eq!(elements.value(1), 6.0);
        assert!(elements.is_null(2));
        assert!(elements.is_null(3));

        // Index equal to the vector length is rejected.
        let args = build_args(&[Some("[1.0,2.0,3.0]")], &[Some(3)], 3);
        let err = func.invoke_with_args(args).unwrap_err();
        assert!(
            err.to_string()
                .starts_with("Execution error: index out of bound: vec_kth_elem")
        );

        // Negative index is rejected with the "not a valid index" message.
        let args = build_args(&[Some("[1.0,2.0,3.0]")], &[Some(-1)], 3);
        let err = func.invoke_with_args(args).unwrap_err();
        assert!(err.to_string().starts_with(
            "Execution error: 2nd argument not a valid index or expected datatype: vec_kth_elem"
        ));
    }
}