common_function/scalars/vector/
vector_kth_elem.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::borrow::Cow;
16use std::fmt::Display;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use datafusion_expr::Signature;
20use datatypes::arrow::datatypes::DataType;
21use datatypes::scalars::ScalarVectorBuilder;
22use datatypes::vectors::{Float32VectorBuilder, MutableVector, VectorRef};
23use snafu::ensure;
24
25use crate::function::{Function, FunctionContext};
26use crate::helper;
27use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const};
28
/// SQL name of the function, returned by `Function::name`.
const NAME: &str = "vec_kth_elem";

/// Returns the k-th (0-based) element of a vector as a `Float32`.
///
/// The first argument is a vector given either as a string literal (e.g. `'[1.0, 2.0]'`) or as
/// binary data; the second argument is the index `k`. The result is `NULL` when either argument
/// is `NULL`. An error is returned when `k` is negative, not an integer, or not a valid index
/// into the vector.
///
/// # Example
///
/// ```sql
/// SELECT vec_kth_elem('[2, 4, 6]', 1) AS result;
///
/// +--------+
/// | result |
/// +--------+
/// | 4      |
/// +--------+
/// ```
#[derive(Debug, Clone, Default)]
pub struct VectorKthElemFunction;
49
50impl Function for VectorKthElemFunction {
51    fn name(&self) -> &str {
52        NAME
53    }
54
55    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
56        Ok(DataType::Float32)
57    }
58
59    fn signature(&self) -> Signature {
60        helper::one_of_sigs2(
61            vec![DataType::Utf8, DataType::Binary],
62            vec![DataType::Int64],
63        )
64    }
65
66    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
67        ensure!(
68            columns.len() == 2,
69            InvalidFuncArgsSnafu {
70                err_msg: format!(
71                    "The length of the args is not correct, expect exactly two, have: {}",
72                    columns.len()
73                ),
74            }
75        );
76
77        let arg0 = &columns[0];
78        let arg1 = &columns[1];
79
80        let len = arg0.len();
81        let mut result = Float32VectorBuilder::with_capacity(len);
82        if len == 0 {
83            return Ok(result.to_vector());
84        };
85
86        let arg0_const = as_veclit_if_const(arg0)?;
87
88        for i in 0..len {
89            let arg0 = match arg0_const.as_ref() {
90                Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
91                None => as_veclit(arg0.get_ref(i))?,
92            };
93            let Some(arg0) = arg0 else {
94                result.push_null();
95                continue;
96            };
97
98            let arg1 = arg1.get(i).as_f64_lossy();
99            let Some(arg1) = arg1 else {
100                result.push_null();
101                continue;
102            };
103
104            ensure!(
105                arg1 >= 0.0 && arg1.fract() == 0.0,
106                InvalidFuncArgsSnafu {
107                    err_msg: format!(
108                        "Invalid argument: k must be a non-negative integer, but got k = {}.",
109                        arg1
110                    ),
111                }
112            );
113
114            let k = arg1 as usize;
115
116            ensure!(
117                k < arg0.len(),
118                InvalidFuncArgsSnafu {
119                    err_msg: format!(
120                        "Out of range: k must be in the range [0, {}], but got k = {}.",
121                        arg0.len() - 1,
122                        k
123                    ),
124                }
125            );
126
127            let value = arg0[k];
128
129            result.push(Some(value));
130        }
131        Ok(result.to_vector())
132    }
133}
134
135impl Display for VectorKthElemFunction {
136    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
137        write!(f, "{}", NAME.to_ascii_uppercase())
138    }
139}
140
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_query::error;
    use datatypes::vectors::{Int64Vector, StringVector};

    use super::*;

    /// Asserts that `err` is an `InvalidFuncArgs` error whose message is exactly `expected`,
    /// panicking with the actual error otherwise.
    fn assert_invalid_func_args(err: error::Error, expected: &str) {
        match err {
            error::Error::InvalidFuncArgs { err_msg, .. } => assert_eq!(err_msg, expected),
            other => panic!("expected InvalidFuncArgs, got: {other:?}"),
        }
    }

    #[test]
    fn test_vec_kth_elem() {
        let func = VectorKthElemFunction;

        // Regular lookups, a NULL `k`, and a NULL vector.
        let input0 = Arc::new(StringVector::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            Some("[7.0,8.0,9.0]".to_string()),
            None,
        ]));
        let input1 = Arc::new(Int64Vector::from(vec![Some(0), Some(2), None, Some(1)]));

        let result = func
            .eval(&FunctionContext::default(), &[input0, input1])
            .unwrap();

        let result = result.as_ref();
        assert_eq!(result.len(), 4);
        assert_eq!(result.get_ref(0).as_f32().unwrap(), Some(1.0));
        assert_eq!(result.get_ref(1).as_f32().unwrap(), Some(6.0));
        assert!(result.get_ref(2).is_null());
        assert!(result.get_ref(3).is_null());

        // `k` equal to the vector length is out of range.
        let input0 = Arc::new(StringVector::from(vec![Some("[1.0,2.0,3.0]".to_string())]));
        let input1 = Arc::new(Int64Vector::from(vec![Some(3)]));

        let err = func
            .eval(&FunctionContext::default(), &[input0, input1])
            .unwrap_err();
        assert_invalid_func_args(
            err,
            "Out of range: k must be in the range [0, 2], but got k = 3.",
        );

        // Negative `k` is rejected.
        let input0 = Arc::new(StringVector::from(vec![Some("[1.0,2.0,3.0]".to_string())]));
        let input1 = Arc::new(Int64Vector::from(vec![Some(-1)]));

        let err = func
            .eval(&FunctionContext::default(), &[input0, input1])
            .unwrap_err();
        assert_invalid_func_args(
            err,
            "Invalid argument: k must be a non-negative integer, but got k = -1.",
        );
    }
}