common_function/scalars/vector/vector_subvector.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Display;
16use std::sync::Arc;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use datafusion::arrow::array::{Array, AsArray, BinaryViewBuilder};
20use datafusion::arrow::datatypes::Int64Type;
21use datafusion::logical_expr::ColumnarValue;
22use datafusion_common::{ScalarValue, utils};
23use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
24use datatypes::arrow::datatypes::DataType;
25use snafu::ensure;
26
27use crate::function::Function;
28use crate::scalars::vector::impl_conv::{as_veclit, veclit_to_binlit};
29
const NAME: &str = "vec_subvector";

/// Returns the subvector of a vector covering the zero-based, half-open index
/// range `[start, end)`: `start` is included and `end` is excluded.
///
/// The vector argument may be a textual vector literal (e.g. `'[1, 2, 3]'`)
/// or a binary vector literal. The result is always a binary vector literal.
///
/// A row yields `NULL` when its vector or either index is `NULL`. The call
/// fails when `start < 0`, `start > end`, or `end` exceeds the vector length.
///
/// # Example
///
/// ```sql
/// SELECT vec_to_string(vec_subvector('[1, 2, 3, 4, 5]', 1, 3)) as result;
///
/// +---------+
/// | result  |
/// +---------+
/// | [2, 3]  |
/// +---------+
/// ```
#[derive(Debug, Clone, Default)]
pub struct VectorSubvectorFunction;
50
51impl Function for VectorSubvectorFunction {
52    fn name(&self) -> &str {
53        NAME
54    }
55
56    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
57        Ok(DataType::BinaryView)
58    }
59
60    fn signature(&self) -> Signature {
61        Signature::one_of(
62            vec![
63                TypeSignature::Exact(vec![DataType::Utf8, DataType::Int64, DataType::Int64]),
64                TypeSignature::Exact(vec![DataType::Binary, DataType::Int64, DataType::Int64]),
65            ],
66            Volatility::Immutable,
67        )
68    }
69
70    fn invoke_with_args(
71        &self,
72        args: ScalarFunctionArgs,
73    ) -> datafusion_common::Result<ColumnarValue> {
74        let args = ColumnarValue::values_to_arrays(&args.args)?;
75        let [arg0, arg1, arg2] = utils::take_function_args(self.name(), args)?;
76        let arg1 = arg1.as_primitive::<Int64Type>();
77        let arg2 = arg2.as_primitive::<Int64Type>();
78
79        let len = arg0.len();
80        let mut builder = BinaryViewBuilder::with_capacity(len);
81        if len == 0 {
82            return Ok(ColumnarValue::Array(Arc::new(builder.finish())));
83        }
84
85        for i in 0..len {
86            let v = ScalarValue::try_from_array(&arg0, i)?;
87            let arg0 = as_veclit(&v)?;
88            let arg1 = arg1.is_valid(i).then(|| arg1.value(i));
89            let arg2 = arg2.is_valid(i).then(|| arg2.value(i));
90            let (Some(arg0), Some(arg1), Some(arg2)) = (arg0, arg1, arg2) else {
91                builder.append_null();
92                continue;
93            };
94
95            ensure!(
96                0 <= arg1 && arg1 <= arg2 && arg2 as usize <= arg0.len(),
97                InvalidFuncArgsSnafu {
98                    err_msg: format!(
99                        "Invalid start and end indices: start={}, end={}, vec_len={}",
100                        arg1,
101                        arg2,
102                        arg0.len()
103                    )
104                }
105            );
106
107            let subvector = &arg0[arg1 as usize..arg2 as usize];
108            let binlit = veclit_to_binlit(subvector);
109            builder.append_value(&binlit);
110        }
111
112        Ok(ColumnarValue::Array(Arc::new(builder.finish())))
113    }
114}
115
116impl Display for VectorSubvectorFunction {
117    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118        write!(f, "{}", NAME.to_ascii_uppercase())
119    }
120}
121
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion::arrow::array::{ArrayRef, Int64Array, StringViewArray};
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Builds `(vector, start, end)` call arguments evaluated over `number_rows` rows.
    fn call_args(
        vectors: ArrayRef,
        starts: ArrayRef,
        ends: ArrayRef,
        number_rows: usize,
    ) -> ScalarFunctionArgs {
        ScalarFunctionArgs {
            args: vec![
                ColumnarValue::Array(vectors),
                ColumnarValue::Array(starts),
                ColumnarValue::Array(ends),
            ],
            arg_fields: vec![],
            number_rows,
            // Results may contain NULL rows, so the return field is nullable.
            return_field: Arc::new(Field::new("x", DataType::BinaryView, true)),
            config_options: Arc::new(ConfigOptions::new()),
        }
    }

    #[test]
    fn test_subvector() {
        let func = VectorSubvectorFunction;

        let input0: ArrayRef = Arc::new(StringViewArray::from(vec![
            Some("[1.0, 2.0, 3.0, 4.0, 5.0]".to_string()),
            Some("[6.0, 7.0, 8.0, 9.0, 10.0]".to_string()),
            None,
            Some("[11.0, 12.0, 13.0]".to_string()),
        ]));
        let input1: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(0), Some(0), Some(1)]));
        let input2: ArrayRef = Arc::new(Int64Array::from(vec![Some(3), Some(5), Some(2), Some(3)]));

        let result = func
            .invoke_with_args(call_args(input0, input1, input2, 4))
            .and_then(|x| x.to_array(4))
            .unwrap();

        let result = result.as_binary_view();
        assert_eq!(result.len(), 4);
        assert_eq!(result.value(0), veclit_to_binlit(&[2.0, 3.0]).as_slice());
        assert_eq!(
            result.value(1),
            veclit_to_binlit(&[6.0, 7.0, 8.0, 9.0, 10.0]).as_slice()
        );
        assert!(result.is_null(2));
        assert_eq!(result.value(3), veclit_to_binlit(&[12.0, 13.0]).as_slice());
    }

    #[test]
    fn test_subvector_null_indices_and_empty_range() {
        let func = VectorSubvectorFunction;

        let input0: ArrayRef = Arc::new(StringViewArray::from(vec![
            Some("[1.0, 2.0, 3.0]".to_string()),
            Some("[4.0, 5.0, 6.0]".to_string()),
            Some("[7.0, 8.0, 9.0]".to_string()),
        ]));
        let input1: ArrayRef = Arc::new(Int64Array::from(vec![None, Some(1), Some(2)]));
        let input2: ArrayRef = Arc::new(Int64Array::from(vec![Some(2), None, Some(2)]));

        let result = func
            .invoke_with_args(call_args(input0, input1, input2, 3))
            .and_then(|x| x.to_array(3))
            .unwrap();

        let result = result.as_binary_view();
        assert_eq!(result.len(), 3);
        // A NULL start or end index yields a NULL row.
        assert!(result.is_null(0));
        assert!(result.is_null(1));
        // `start == end` is a valid, empty range.
        let empty: [f32; 0] = [];
        assert!(result.is_valid(2));
        assert_eq!(result.value(2), veclit_to_binlit(&empty).as_slice());
    }

    #[test]
    fn test_subvector_error() {
        let func = VectorSubvectorFunction;

        let input0: ArrayRef = Arc::new(StringViewArray::from(vec![
            Some("[1.0, 2.0, 3.0]".to_string()),
            Some("[4.0, 5.0, 6.0]".to_string()),
        ]));
        let input1: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(2)]));
        let input2: ArrayRef = Arc::new(Int64Array::from(vec![Some(3)]));

        let e = func
            .invoke_with_args(call_args(input0, input1, input2, 2))
            .unwrap_err();
        assert!(e.to_string().starts_with(
            "Internal error: Arguments has mixed length. Expected length: 2, found length: 1."
        ));
    }

    #[test]
    fn test_subvector_invalid_indices() {
        let func = VectorSubvectorFunction;

        let input0: ArrayRef = Arc::new(StringViewArray::from(vec![
            Some("[1.0, 2.0, 3.0]".to_string()),
            Some("[4.0, 5.0, 6.0]".to_string()),
        ]));
        let input1: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(3)]));
        let input2: ArrayRef = Arc::new(Int64Array::from(vec![Some(3), Some(4)]));

        let e = func
            .invoke_with_args(call_args(input0, input1, input2, 2))
            .unwrap_err();
        assert!(e.to_string().starts_with("External error: Invalid function args: Invalid start and end indices: start=3, end=4, vec_len=3"));
    }

    #[test]
    fn test_subvector_rejects_negative_start_and_reversed_range() {
        let func = VectorSubvectorFunction;

        for (start, end) in [(-1_i64, 2_i64), (2, 1)] {
            let input0: ArrayRef = Arc::new(StringViewArray::from(vec![Some(
                "[1.0, 2.0, 3.0]".to_string(),
            )]));
            let input1: ArrayRef = Arc::new(Int64Array::from(vec![Some(start)]));
            let input2: ArrayRef = Arc::new(Int64Array::from(vec![Some(end)]));

            let e = func
                .invoke_with_args(call_args(input0, input1, input2, 1))
                .unwrap_err();
            let expected = format!(
                "External error: Invalid function args: Invalid start and end indices: start={start}, end={end}, vec_len=3"
            );
            assert!(e.to_string().starts_with(&expected), "unexpected error: {e}");
        }
    }
}
224}