common_function/scalars/vector/
vector_subvector.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15use std::fmt::Display;
16use std::sync::Arc;
17
18use common_query::error::InvalidFuncArgsSnafu;
19use datafusion::arrow::array::{Array, AsArray, BinaryViewBuilder};
20use datafusion::arrow::datatypes::Int64Type;
21use datafusion::logical_expr::ColumnarValue;
22use datafusion_common::ScalarValue;
23use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
24use datatypes::arrow::datatypes::DataType;
25use snafu::ensure;
26
27use crate::function::{Function, extract_args};
28use crate::scalars::vector::impl_conv::{as_veclit, veclit_to_binlit};
29
/// Function name as written in SQL queries.
const NAME: &str = "vec_subvector";

32/// Returns a subvector from start(included) to end(excluded) index.
33///
34/// # Example
35///
36/// ```sql
37/// SELECT vec_to_string(vec_subvector("[1, 2, 3, 4, 5]", 1, 3)) as result;
38///
39/// +---------+
40/// | result  |
41/// +---------+
42/// | [2, 3]  |
43/// +---------+
44///
45/// ```
46///
47
48#[derive(Debug, Clone)]
49pub(crate) struct VectorSubvectorFunction {
50    signature: Signature,
51}
52
53impl Default for VectorSubvectorFunction {
54    fn default() -> Self {
55        Self {
56            signature: Signature::one_of(
57                vec![
58                    TypeSignature::Exact(vec![DataType::Utf8, DataType::Int64, DataType::Int64]),
59                    TypeSignature::Exact(vec![DataType::Binary, DataType::Int64, DataType::Int64]),
60                ],
61                Volatility::Immutable,
62            ),
63        }
64    }
65}
66
67impl Function for VectorSubvectorFunction {
68    fn name(&self) -> &str {
69        NAME
70    }
71
72    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
73        Ok(DataType::BinaryView)
74    }
75
76    fn signature(&self) -> &Signature {
77        &self.signature
78    }
79
80    fn invoke_with_args(
81        &self,
82        args: ScalarFunctionArgs,
83    ) -> datafusion_common::Result<ColumnarValue> {
84        let [arg0, arg1, arg2] = extract_args(self.name(), &args)?;
85        let arg1 = arg1.as_primitive::<Int64Type>();
86        let arg2 = arg2.as_primitive::<Int64Type>();
87
88        let len = arg0.len();
89        let mut builder = BinaryViewBuilder::with_capacity(len);
90        if len == 0 {
91            return Ok(ColumnarValue::Array(Arc::new(builder.finish())));
92        }
93
94        for i in 0..len {
95            let v = ScalarValue::try_from_array(&arg0, i)?;
96            let arg0 = as_veclit(&v)?;
97            let arg1 = arg1.is_valid(i).then(|| arg1.value(i));
98            let arg2 = arg2.is_valid(i).then(|| arg2.value(i));
99            let (Some(arg0), Some(arg1), Some(arg2)) = (arg0, arg1, arg2) else {
100                builder.append_null();
101                continue;
102            };
103
104            ensure!(
105                0 <= arg1 && arg1 <= arg2 && arg2 as usize <= arg0.len(),
106                InvalidFuncArgsSnafu {
107                    err_msg: format!(
108                        "Invalid start and end indices: start={}, end={}, vec_len={}",
109                        arg1,
110                        arg2,
111                        arg0.len()
112                    )
113                }
114            );
115
116            let subvector = &arg0[arg1 as usize..arg2 as usize];
117            let binlit = veclit_to_binlit(subvector);
118            builder.append_value(&binlit);
119        }
120
121        Ok(ColumnarValue::Array(Arc::new(builder.finish())))
122    }
123}
124
125impl Display for VectorSubvectorFunction {
126    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127        write!(f, "{}", NAME.to_ascii_uppercase())
128    }
129}
130
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion::arrow::array::{ArrayRef, Int64Array, StringViewArray};
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Builds a vector column from optional string vector literals.
    fn vector_column(rows: &[Option<&str>]) -> ArrayRef {
        Arc::new(StringViewArray::from(
            rows.iter()
                .map(|row| row.map(str::to_string))
                .collect::<Vec<_>>(),
        ))
    }

    /// Builds an Int64 index column.
    fn index_column(rows: &[Option<i64>]) -> ArrayRef {
        Arc::new(Int64Array::from(rows.to_vec()))
    }

    /// Invokes `vec_subvector` on three array arguments, reporting the length
    /// of the vector column as the batch row count.
    fn invoke(
        vectors: ArrayRef,
        starts: ArrayRef,
        ends: ArrayRef,
    ) -> datafusion_common::Result<ColumnarValue> {
        let number_rows = vectors.len();
        let args = ScalarFunctionArgs {
            args: vec![
                ColumnarValue::Array(vectors),
                ColumnarValue::Array(starts),
                ColumnarValue::Array(ends),
            ],
            arg_fields: vec![],
            number_rows,
            return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
            config_options: Arc::new(ConfigOptions::new()),
        };
        VectorSubvectorFunction::default().invoke_with_args(args)
    }

    #[test]
    fn test_subvector() {
        let result = invoke(
            vector_column(&[
                Some("[1.0, 2.0, 3.0, 4.0, 5.0]"),
                Some("[6.0, 7.0, 8.0, 9.0, 10.0]"),
                None,
                Some("[11.0, 12.0, 13.0]"),
            ]),
            index_column(&[Some(1), Some(0), Some(0), Some(1)]),
            index_column(&[Some(3), Some(5), Some(2), Some(3)]),
        )
        .and_then(|x| x.to_array(4))
        .unwrap();

        let result = result.as_binary_view();
        assert_eq!(result.len(), 4);
        assert_eq!(result.value(0), veclit_to_binlit(&[2.0, 3.0]).as_slice());
        assert_eq!(
            result.value(1),
            veclit_to_binlit(&[6.0, 7.0, 8.0, 9.0, 10.0]).as_slice()
        );
        assert!(result.is_null(2));
        assert_eq!(result.value(3), veclit_to_binlit(&[12.0, 13.0]).as_slice());
    }

    #[test]
    fn test_subvector_null_index_and_empty_range() {
        let result = invoke(
            vector_column(&[
                Some("[1.0, 2.0, 3.0]"),
                Some("[4.0, 5.0, 6.0]"),
                Some("[7.0, 8.0]"),
            ]),
            index_column(&[None, Some(1), Some(0)]),
            index_column(&[Some(2), Some(1), None]),
        )
        .and_then(|x| x.to_array(3))
        .unwrap();

        let result = result.as_binary_view();
        assert_eq!(result.len(), 3);
        // A NULL start or end makes the row NULL.
        assert!(result.is_null(0));
        assert!(result.is_null(2));
        // start == end is a valid, empty subvector.
        assert!(result.is_valid(1));
        assert_eq!(result.value(1), veclit_to_binlit(&[]).as_slice());
    }

    #[test]
    fn test_subvector_error() {
        let e = invoke(
            vector_column(&[Some("[1.0, 2.0, 3.0]"), Some("[4.0, 5.0, 6.0]")]),
            index_column(&[Some(1), Some(2)]),
            index_column(&[Some(3)]),
        )
        .unwrap_err();
        assert!(e.to_string().starts_with(
            "Internal error: Arguments has mixed length. Expected length: 2, found length: 1."
        ));
    }

    #[test]
    fn test_subvector_invalid_indices() {
        let e = invoke(
            vector_column(&[Some("[1.0, 2.0, 3.0]"), Some("[4.0, 5.0, 6.0]")]),
            index_column(&[Some(1), Some(3)]),
            index_column(&[Some(3), Some(4)]),
        )
        .unwrap_err();
        assert!(e.to_string().starts_with(
            "External error: Invalid function args: Invalid start and end indices: start=3, end=4, vec_len=3"
        ));
    }

    #[test]
    fn test_subvector_negative_start() {
        let e = invoke(
            vector_column(&[Some("[1.0, 2.0, 3.0]")]),
            index_column(&[Some(-1)]),
            index_column(&[Some(2)]),
        )
        .unwrap_err();
        assert!(e.to_string().starts_with(
            "External error: Invalid function args: Invalid start and end indices: start=-1, end=2, vec_len=3"
        ));
    }

    #[test]
    fn test_subvector_start_after_end() {
        let e = invoke(
            vector_column(&[Some("[1.0, 2.0, 3.0]")]),
            index_column(&[Some(2)]),
            index_column(&[Some(1)]),
        )
        .unwrap_err();
        assert!(e.to_string().starts_with(
            "External error: Invalid function args: Invalid start and end indices: start=2, end=1, vec_len=3"
        ));
    }
}
233}