common_function/scalars/vector/
vector_subvector.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::borrow::Cow;
16use std::fmt::Display;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use datafusion_expr::{Signature, TypeSignature, Volatility};
20use datatypes::arrow::datatypes::DataType;
21use datatypes::scalars::ScalarVectorBuilder;
22use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
23use snafu::ensure;
24
25use crate::function::{Function, FunctionContext};
26use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
27
28const NAME: &str = "vec_subvector";
29
/// Scalar function `vec_subvector(vector, start, end)`.
///
/// Yields the elements of `vector` whose zero-based positions fall in the half-open
/// range `[start, end)`, i.e. `start` is included and `end` is excluded.
///
/// # Example
///
/// ```sql
/// SELECT vec_to_string(vec_subvector("[1, 2, 3, 4, 5]", 1, 3)) as result;
///
/// +---------+
/// | result  |
/// +---------+
/// | [2, 3]  |
/// +---------+
///
/// ```
#[derive(Default, Clone, Debug)]
pub struct VectorSubvectorFunction;
48
49impl Function for VectorSubvectorFunction {
50    fn name(&self) -> &str {
51        NAME
52    }
53
54    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
55        Ok(DataType::Binary)
56    }
57
58    fn signature(&self) -> Signature {
59        Signature::one_of(
60            vec![
61                TypeSignature::Exact(vec![DataType::Utf8, DataType::Int64, DataType::Int64]),
62                TypeSignature::Exact(vec![DataType::Binary, DataType::Int64, DataType::Int64]),
63            ],
64            Volatility::Immutable,
65        )
66    }
67
68    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
69        ensure!(
70            columns.len() == 3,
71            InvalidFuncArgsSnafu {
72                err_msg: format!(
73                    "The length of the args is not correct, expect exactly three, have: {}",
74                    columns.len()
75                )
76            }
77        );
78
79        let arg0 = &columns[0];
80        let arg1 = &columns[1];
81        let arg2 = &columns[2];
82
83        ensure!(
84            arg0.len() == arg1.len() && arg1.len() == arg2.len(),
85            InvalidFuncArgsSnafu {
86                err_msg: format!(
87                    "The lengths of the vector are not aligned, args 0: {}, args 1: {}, args 2: {}",
88                    arg0.len(),
89                    arg1.len(),
90                    arg2.len()
91                )
92            }
93        );
94
95        let len = arg0.len();
96        let mut result = BinaryVectorBuilder::with_capacity(len);
97        if len == 0 {
98            return Ok(result.to_vector());
99        }
100
101        let arg0_const = as_veclit_if_const(arg0)?;
102
103        for i in 0..len {
104            let arg0 = match arg0_const.as_ref() {
105                Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
106                None => as_veclit(arg0.get_ref(i))?,
107            };
108            let arg1 = arg1.get(i).as_i64();
109            let arg2 = arg2.get(i).as_i64();
110            let (Some(arg0), Some(arg1), Some(arg2)) = (arg0, arg1, arg2) else {
111                result.push_null();
112                continue;
113            };
114
115            ensure!(
116                0 <= arg1 && arg1 <= arg2 && arg2 as usize <= arg0.len(),
117                InvalidFuncArgsSnafu {
118                    err_msg: format!(
119                        "Invalid start and end indices: start={}, end={}, vec_len={}",
120                        arg1,
121                        arg2,
122                        arg0.len()
123                    )
124                }
125            );
126
127            let subvector = &arg0[arg1 as usize..arg2 as usize];
128            let binlit = veclit_to_binlit(subvector);
129            result.push(Some(&binlit));
130        }
131
132        Ok(result.to_vector())
133    }
134}
135
136impl Display for VectorSubvectorFunction {
137    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
138        write!(f, "{}", NAME.to_ascii_uppercase())
139    }
140}
141
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_query::error::Error;
    use datatypes::vectors::{Int64Vector, StringVector};

    use super::*;
    use crate::function::FunctionContext;

    /// Builds a nullable string column from literal rows.
    fn string_column(rows: &[Option<&str>]) -> VectorRef {
        let owned: Vec<Option<String>> = rows.iter().map(|row| row.map(str::to_string)).collect();
        Arc::new(StringVector::from(owned))
    }

    /// Builds an `Int64` column in which every row is present.
    fn int64_column(rows: &[i64]) -> VectorRef {
        let wrapped: Vec<Option<i64>> = rows.iter().copied().map(Some).collect();
        Arc::new(Int64Vector::from(wrapped))
    }

    /// Valid ranges produce the expected slices; a null vector row produces a null.
    #[test]
    fn test_subvector() {
        let vectors = string_column(&[
            Some("[1.0, 2.0, 3.0, 4.0, 5.0]"),
            Some("[6.0, 7.0, 8.0, 9.0, 10.0]"),
            None,
            Some("[11.0, 12.0, 13.0]"),
        ]);
        let starts = int64_column(&[1, 0, 0, 1]);
        let ends = int64_column(&[3, 5, 2, 3]);

        let output = VectorSubvectorFunction
            .eval(&FunctionContext::default(), &[vectors, starts, ends])
            .unwrap();

        let expected: [Option<&[f32]>; 4] = [
            Some(&[2.0, 3.0][..]),
            Some(&[6.0, 7.0, 8.0, 9.0, 10.0][..]),
            None,
            Some(&[12.0, 13.0][..]),
        ];

        assert_eq!(output.len(), 4);
        for (row, want) in expected.iter().copied().enumerate() {
            match want {
                Some(values) => assert_eq!(
                    output.get_ref(row).as_binary().unwrap(),
                    Some(veclit_to_binlit(values).as_slice())
                ),
                None => assert!(output.get_ref(row).is_null()),
            }
        }
    }

    /// Columns of different lengths are rejected with `InvalidFuncArgs`.
    #[test]
    fn test_subvector_error() {
        let vectors = string_column(&[Some("[1.0, 2.0, 3.0]"), Some("[4.0, 5.0, 6.0]")]);
        let starts = int64_column(&[1, 2]);
        let ends = int64_column(&[3]);

        let outcome =
            VectorSubvectorFunction.eval(&FunctionContext::default(), &[vectors, starts, ends]);

        let Err(Error::InvalidFuncArgs { err_msg, .. }) = outcome else {
            unreachable!()
        };
        assert_eq!(
            err_msg,
            "The lengths of the vector are not aligned, args 0: 2, args 1: 2, args 2: 1"
        );
    }

    /// An end index beyond the vector length is rejected with `InvalidFuncArgs`.
    #[test]
    fn test_subvector_invalid_indices() {
        let vectors = string_column(&[Some("[1.0, 2.0, 3.0]"), Some("[4.0, 5.0, 6.0]")]);
        let starts = int64_column(&[1, 3]);
        let ends = int64_column(&[3, 4]);

        let outcome =
            VectorSubvectorFunction.eval(&FunctionContext::default(), &[vectors, starts, ends]);

        let Err(Error::InvalidFuncArgs { err_msg, .. }) = outcome else {
            unreachable!()
        };
        assert_eq!(
            err_msg,
            "Invalid start and end indices: start=3, end=4, vec_len=3"
        );
    }
}