common_function/scalars/vector/vector_subvector.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::borrow::Cow;
16use std::fmt::Display;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use common_query::prelude::{Signature, TypeSignature};
20use datafusion_expr::Volatility;
21use datatypes::prelude::ConcreteDataType;
22use datatypes::scalars::ScalarVectorBuilder;
23use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
24use snafu::ensure;
25
26use crate::function::{Function, FunctionContext};
27use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
28
/// The SQL name of this function.
const NAME: &str = "vec_subvector";

/// Returns the subvector of a vector from a start index (inclusive) to an end
/// index (exclusive).
///
/// Indices are zero-based and must satisfy `0 <= start <= end <= vector length`;
/// otherwise evaluation fails with an invalid-argument error. If the vector or
/// either index is null, the result for that row is null. The result is
/// produced in binary vector form; wrap it in `vec_to_string` to render it as
/// text.
///
/// # Example
///
/// ```sql
/// SELECT vec_to_string(vec_subvector('[1, 2, 3, 4, 5]', 1, 3)) as result;
///
/// +---------+
/// | result  |
/// +---------+
/// | [2, 3]  |
/// +---------+
/// ```
#[derive(Debug, Clone, Default)]
pub struct VectorSubvectorFunction;
49
50impl Function for VectorSubvectorFunction {
51    fn name(&self) -> &str {
52        NAME
53    }
54
55    fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
56        Ok(ConcreteDataType::binary_datatype())
57    }
58
59    fn signature(&self) -> Signature {
60        Signature::one_of(
61            vec![
62                TypeSignature::Exact(vec![
63                    ConcreteDataType::string_datatype(),
64                    ConcreteDataType::int64_datatype(),
65                    ConcreteDataType::int64_datatype(),
66                ]),
67                TypeSignature::Exact(vec![
68                    ConcreteDataType::binary_datatype(),
69                    ConcreteDataType::int64_datatype(),
70                    ConcreteDataType::int64_datatype(),
71                ]),
72            ],
73            Volatility::Immutable,
74        )
75    }
76
77    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
78        ensure!(
79            columns.len() == 3,
80            InvalidFuncArgsSnafu {
81                err_msg: format!(
82                    "The length of the args is not correct, expect exactly three, have: {}",
83                    columns.len()
84                )
85            }
86        );
87
88        let arg0 = &columns[0];
89        let arg1 = &columns[1];
90        let arg2 = &columns[2];
91
92        ensure!(
93            arg0.len() == arg1.len() && arg1.len() == arg2.len(),
94            InvalidFuncArgsSnafu {
95                err_msg: format!(
96                    "The lengths of the vector are not aligned, args 0: {}, args 1: {}, args 2: {}",
97                    arg0.len(),
98                    arg1.len(),
99                    arg2.len()
100                )
101            }
102        );
103
104        let len = arg0.len();
105        let mut result = BinaryVectorBuilder::with_capacity(len);
106        if len == 0 {
107            return Ok(result.to_vector());
108        }
109
110        let arg0_const = as_veclit_if_const(arg0)?;
111
112        for i in 0..len {
113            let arg0 = match arg0_const.as_ref() {
114                Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
115                None => as_veclit(arg0.get_ref(i))?,
116            };
117            let arg1 = arg1.get(i).as_i64();
118            let arg2 = arg2.get(i).as_i64();
119            let (Some(arg0), Some(arg1), Some(arg2)) = (arg0, arg1, arg2) else {
120                result.push_null();
121                continue;
122            };
123
124            ensure!(
125                0 <= arg1 && arg1 <= arg2 && arg2 as usize <= arg0.len(),
126                InvalidFuncArgsSnafu {
127                    err_msg: format!(
128                        "Invalid start and end indices: start={}, end={}, vec_len={}",
129                        arg1,
130                        arg2,
131                        arg0.len()
132                    )
133                }
134            );
135
136            let subvector = &arg0[arg1 as usize..arg2 as usize];
137            let binlit = veclit_to_binlit(subvector);
138            result.push(Some(&binlit));
139        }
140
141        Ok(result.to_vector())
142    }
143}
144
145impl Display for VectorSubvectorFunction {
146    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147        write!(f, "{}", NAME.to_ascii_uppercase())
148    }
149}
150
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_query::error::Error;
    use datatypes::vectors::{Int64Vector, StringVector};

    use super::*;

    /// Evaluates `vec_subvector` over `columns`, expecting an `InvalidFuncArgs`
    /// error, and returns that error's message.
    ///
    /// Panics with a descriptive message if evaluation succeeds or fails with a
    /// different error.
    fn invalid_args_message(columns: &[VectorRef]) -> String {
        match VectorSubvectorFunction.eval(&FunctionContext::default(), columns) {
            Err(Error::InvalidFuncArgs { err_msg, .. }) => err_msg,
            Err(other) => panic!("expected an InvalidFuncArgs error, got: {other}"),
            Ok(_) => panic!("expected an InvalidFuncArgs error, but eval succeeded"),
        }
    }

    #[test]
    fn test_subvector() {
        let func = VectorSubvectorFunction;

        let input0 = Arc::new(StringVector::from(vec![
            Some("[1.0, 2.0, 3.0, 4.0, 5.0]".to_string()),
            Some("[6.0, 7.0, 8.0, 9.0, 10.0]".to_string()),
            None,
            Some("[11.0, 12.0, 13.0]".to_string()),
        ]));
        let input1 = Arc::new(Int64Vector::from(vec![Some(1), Some(0), Some(0), Some(1)]));
        let input2 = Arc::new(Int64Vector::from(vec![Some(3), Some(5), Some(2), Some(3)]));

        let result = func
            .eval(&FunctionContext::default(), &[input0, input1, input2])
            .unwrap();

        let result = result.as_ref();
        assert_eq!(result.len(), 4);
        assert_eq!(
            result.get_ref(0).as_binary().unwrap(),
            Some(veclit_to_binlit(&[2.0, 3.0]).as_slice())
        );
        assert_eq!(
            result.get_ref(1).as_binary().unwrap(),
            Some(veclit_to_binlit(&[6.0, 7.0, 8.0, 9.0, 10.0]).as_slice())
        );
        assert!(result.get_ref(2).is_null());
        assert_eq!(
            result.get_ref(3).as_binary().unwrap(),
            Some(veclit_to_binlit(&[12.0, 13.0]).as_slice())
        );
    }

    #[test]
    fn test_subvector_null_indices() {
        let func = VectorSubvectorFunction;

        let input0 = Arc::new(StringVector::from(vec![
            Some("[1.0, 2.0, 3.0]".to_string()),
            Some("[4.0, 5.0, 6.0]".to_string()),
        ]));
        let input1 = Arc::new(Int64Vector::from(vec![None, Some(0)]));
        let input2 = Arc::new(Int64Vector::from(vec![Some(2), None]));

        let result = func
            .eval(&FunctionContext::default(), &[input0, input1, input2])
            .unwrap();

        let result = result.as_ref();
        assert_eq!(result.len(), 2);
        assert!(result.get_ref(0).is_null());
        assert!(result.get_ref(1).is_null());
    }

    #[test]
    fn test_subvector_error() {
        let input0 = Arc::new(StringVector::from(vec![
            Some("[1.0, 2.0, 3.0]".to_string()),
            Some("[4.0, 5.0, 6.0]".to_string()),
        ]));
        let input1 = Arc::new(Int64Vector::from(vec![Some(1), Some(2)]));
        let input2 = Arc::new(Int64Vector::from(vec![Some(3)]));

        assert_eq!(
            invalid_args_message(&[input0, input1, input2]),
            "The lengths of the vector are not aligned, args 0: 2, args 1: 2, args 2: 1"
        );
    }

    #[test]
    fn test_subvector_invalid_indices() {
        let input0 = Arc::new(StringVector::from(vec![
            Some("[1.0, 2.0, 3.0]".to_string()),
            Some("[4.0, 5.0, 6.0]".to_string()),
        ]));
        let input1 = Arc::new(Int64Vector::from(vec![Some(1), Some(3)]));
        let input2 = Arc::new(Int64Vector::from(vec![Some(3), Some(4)]));

        assert_eq!(
            invalid_args_message(&[input0, input1, input2]),
            "Invalid start and end indices: start=3, end=4, vec_len=3"
        );
    }

    #[test]
    fn test_subvector_negative_start() {
        let input0 = Arc::new(StringVector::from(vec![Some("[1.0, 2.0, 3.0]".to_string())]));
        let input1 = Arc::new(Int64Vector::from(vec![Some(-1)]));
        let input2 = Arc::new(Int64Vector::from(vec![Some(2)]));

        assert_eq!(
            invalid_args_message(&[input0, input1, input2]),
            "Invalid start and end indices: start=-1, end=2, vec_len=3"
        );
    }

    #[test]
    fn test_subvector_start_after_end() {
        let input0 = Arc::new(StringVector::from(vec![Some("[1.0, 2.0, 3.0]".to_string())]));
        let input1 = Arc::new(Int64Vector::from(vec![Some(2)]));
        let input2 = Arc::new(Int64Vector::from(vec![Some(1)]));

        assert_eq!(
            invalid_args_message(&[input0, input1, input2]),
            "Invalid start and end indices: start=2, end=1, vec_len=3"
        );
    }
}