common_function/scalars/vector/
vector_kth_elem.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::borrow::Cow;
16use std::fmt::Display;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use common_query::prelude::Signature;
20use datatypes::prelude::ConcreteDataType;
21use datatypes::scalars::ScalarVectorBuilder;
22use datatypes::vectors::{Float32VectorBuilder, MutableVector, VectorRef};
23use snafu::ensure;
24
25use crate::function::{Function, FunctionContext};
26use crate::helper;
27use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const};
28
/// SQL name of this function. It is returned by `Function::name`, and its
/// upper-cased form is what the `Display` impl prints.
const NAME: &str = "vec_kth_elem";
30
/// Returns the k-th (0-based) element of a vector as a `Float32`.
///
/// The first argument is a vector literal, given as a `String` (for example
/// `'[2, 4, 6]'`) or as a `Binary` value. The second argument `k` is an `Int64`.
/// If either argument is NULL for a row, the result for that row is NULL.
/// A `k` that is negative, not an integer, or not a valid index into the vector
/// is reported as an invalid-argument error.
///
/// # Example
///
/// ```sql
/// SELECT vec_kth_elem('[2, 4, 6]', 1) AS result;
///
/// +--------+
/// | result |
/// +--------+
/// | 4      |
/// +--------+
/// ```
#[derive(Debug, Clone, Default)]
pub struct VectorKthElemFunction;
49
50impl Function for VectorKthElemFunction {
51    fn name(&self) -> &str {
52        NAME
53    }
54
55    fn return_type(
56        &self,
57        _input_types: &[ConcreteDataType],
58    ) -> common_query::error::Result<ConcreteDataType> {
59        Ok(ConcreteDataType::float32_datatype())
60    }
61
62    fn signature(&self) -> Signature {
63        helper::one_of_sigs2(
64            vec![
65                ConcreteDataType::string_datatype(),
66                ConcreteDataType::binary_datatype(),
67            ],
68            vec![ConcreteDataType::int64_datatype()],
69        )
70    }
71
72    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
73        ensure!(
74            columns.len() == 2,
75            InvalidFuncArgsSnafu {
76                err_msg: format!(
77                    "The length of the args is not correct, expect exactly two, have: {}",
78                    columns.len()
79                ),
80            }
81        );
82
83        let arg0 = &columns[0];
84        let arg1 = &columns[1];
85
86        let len = arg0.len();
87        let mut result = Float32VectorBuilder::with_capacity(len);
88        if len == 0 {
89            return Ok(result.to_vector());
90        };
91
92        let arg0_const = as_veclit_if_const(arg0)?;
93
94        for i in 0..len {
95            let arg0 = match arg0_const.as_ref() {
96                Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
97                None => as_veclit(arg0.get_ref(i))?,
98            };
99            let Some(arg0) = arg0 else {
100                result.push_null();
101                continue;
102            };
103
104            let arg1 = arg1.get(i).as_f64_lossy();
105            let Some(arg1) = arg1 else {
106                result.push_null();
107                continue;
108            };
109
110            ensure!(
111                arg1 >= 0.0 && arg1.fract() == 0.0,
112                InvalidFuncArgsSnafu {
113                    err_msg: format!(
114                        "Invalid argument: k must be a non-negative integer, but got k = {}.",
115                        arg1
116                    ),
117                }
118            );
119
120            let k = arg1 as usize;
121
122            ensure!(
123                k < arg0.len(),
124                InvalidFuncArgsSnafu {
125                    err_msg: format!(
126                        "Out of range: k must be in the range [0, {}], but got k = {}.",
127                        arg0.len() - 1,
128                        k
129                    ),
130                }
131            );
132
133            let value = arg0[k];
134
135            result.push(Some(value));
136        }
137        Ok(result.to_vector())
138    }
139}
140
141impl Display for VectorKthElemFunction {
142    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143        write!(f, "{}", NAME.to_ascii_uppercase())
144    }
145}
146
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_query::error;
    use datatypes::vectors::{Int64Vector, StringVector};

    use super::*;

    /// Evaluates `func` on a single row made of the vector literal `vector` and
    /// the index `k`.
    fn eval_single(func: &VectorKthElemFunction, vector: &str, k: i64) -> Result<VectorRef> {
        let input0 = Arc::new(StringVector::from(vec![Some(vector.to_string())]));
        let input1 = Arc::new(Int64Vector::from(vec![Some(k)]));
        func.eval(&FunctionContext::default(), &[input0, input1])
    }

    /// Extracts the message of an `InvalidFuncArgs` error, panicking with
    /// details on any other outcome.
    fn invalid_args_msg(result: Result<VectorRef>) -> String {
        match result {
            Err(error::Error::InvalidFuncArgs { err_msg, .. }) => err_msg,
            Err(other) => panic!("expected InvalidFuncArgs, got: {other:?}"),
            Ok(_) => panic!("expected InvalidFuncArgs, got Ok"),
        }
    }

    #[test]
    fn test_vec_kth_elem() {
        let func = VectorKthElemFunction;

        let input0 = Arc::new(StringVector::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            Some("[7.0,8.0,9.0]".to_string()),
            None,
        ]));
        let input1 = Arc::new(Int64Vector::from(vec![Some(0), Some(2), None, Some(1)]));

        let result = func
            .eval(&FunctionContext::default(), &[input0, input1])
            .unwrap();

        let result = result.as_ref();
        assert_eq!(result.len(), 4);
        assert_eq!(result.get_ref(0).as_f32().unwrap(), Some(1.0));
        // The last valid index is accepted.
        assert_eq!(result.get_ref(1).as_f32().unwrap(), Some(6.0));
        assert!(result.get_ref(2).is_null());
        assert!(result.get_ref(3).is_null());

        assert_eq!(
            invalid_args_msg(eval_single(&func, "[1.0,2.0,3.0]", 3)),
            "Out of range: k must be in the range [0, 2], but got k = 3."
        );
        assert_eq!(
            invalid_args_msg(eval_single(&func, "[1.0,2.0,3.0]", -1)),
            "Invalid argument: k must be a non-negative integer, but got k = -1."
        );
    }

    #[test]
    fn test_vec_kth_elem_wrong_arg_count() {
        let func = VectorKthElemFunction;
        let input0 = Arc::new(StringVector::from(vec![Some("[1.0]".to_string())]));

        assert_eq!(
            invalid_args_msg(func.eval(&FunctionContext::default(), &[input0])),
            "The length of the args is not correct, expect exactly two, have: 1"
        );
    }

    #[test]
    fn test_vec_kth_elem_empty_input() {
        let func = VectorKthElemFunction;
        let input0 = Arc::new(StringVector::from(Vec::<Option<String>>::new()));
        let input1 = Arc::new(Int64Vector::from(Vec::<Option<i64>>::new()));

        let result = func
            .eval(&FunctionContext::default(), &[input0, input1])
            .unwrap();
        assert_eq!(result.len(), 0);
    }
}