common_function/scalars/vector/
vector_subvector.rs1use std::borrow::Cow;
16use std::fmt::Display;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use datafusion_expr::{Signature, TypeSignature, Volatility};
20use datatypes::arrow::datatypes::DataType;
21use datatypes::scalars::ScalarVectorBuilder;
22use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
23use snafu::ensure;
24
25use crate::function::{Function, FunctionContext};
26use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
27
/// Name returned by `Function::name` for this function and, upper-cased, by its `Display` impl.
const NAME: &str = "vec_subvector";
29
/// Scalar function that extracts a contiguous run of elements from a vector.
///
/// Given a vector (as a string or binary literal) together with a start and an end index,
/// it yields the elements in the half-open range `[start, end)` as a binary vector literal.
/// The type carries no state; all behaviour lives in its `Function` implementation.
#[derive(Clone, Debug, Default)]
pub struct VectorSubvectorFunction;
48
49impl Function for VectorSubvectorFunction {
50 fn name(&self) -> &str {
51 NAME
52 }
53
54 fn return_type(&self, _: &[DataType]) -> Result<DataType> {
55 Ok(DataType::Binary)
56 }
57
58 fn signature(&self) -> Signature {
59 Signature::one_of(
60 vec![
61 TypeSignature::Exact(vec![DataType::Utf8, DataType::Int64, DataType::Int64]),
62 TypeSignature::Exact(vec![DataType::Binary, DataType::Int64, DataType::Int64]),
63 ],
64 Volatility::Immutable,
65 )
66 }
67
68 fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
69 ensure!(
70 columns.len() == 3,
71 InvalidFuncArgsSnafu {
72 err_msg: format!(
73 "The length of the args is not correct, expect exactly three, have: {}",
74 columns.len()
75 )
76 }
77 );
78
79 let arg0 = &columns[0];
80 let arg1 = &columns[1];
81 let arg2 = &columns[2];
82
83 ensure!(
84 arg0.len() == arg1.len() && arg1.len() == arg2.len(),
85 InvalidFuncArgsSnafu {
86 err_msg: format!(
87 "The lengths of the vector are not aligned, args 0: {}, args 1: {}, args 2: {}",
88 arg0.len(),
89 arg1.len(),
90 arg2.len()
91 )
92 }
93 );
94
95 let len = arg0.len();
96 let mut result = BinaryVectorBuilder::with_capacity(len);
97 if len == 0 {
98 return Ok(result.to_vector());
99 }
100
101 let arg0_const = as_veclit_if_const(arg0)?;
102
103 for i in 0..len {
104 let arg0 = match arg0_const.as_ref() {
105 Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
106 None => as_veclit(arg0.get_ref(i))?,
107 };
108 let arg1 = arg1.get(i).as_i64();
109 let arg2 = arg2.get(i).as_i64();
110 let (Some(arg0), Some(arg1), Some(arg2)) = (arg0, arg1, arg2) else {
111 result.push_null();
112 continue;
113 };
114
115 ensure!(
116 0 <= arg1 && arg1 <= arg2 && arg2 as usize <= arg0.len(),
117 InvalidFuncArgsSnafu {
118 err_msg: format!(
119 "Invalid start and end indices: start={}, end={}, vec_len={}",
120 arg1,
121 arg2,
122 arg0.len()
123 )
124 }
125 );
126
127 let subvector = &arg0[arg1 as usize..arg2 as usize];
128 let binlit = veclit_to_binlit(subvector);
129 result.push(Some(&binlit));
130 }
131
132 Ok(result.to_vector())
133 }
134}
135
136impl Display for VectorSubvectorFunction {
137 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
138 write!(f, "{}", NAME.to_ascii_uppercase())
139 }
140}
141
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_query::error::Error;
    use datatypes::vectors::{Int64Vector, StringVector};

    use super::*;
    use crate::function::FunctionContext;

    /// Asserts that `result` is an `InvalidFuncArgs` error carrying exactly `expected`.
    ///
    /// Any other outcome fails the test with a message describing what was actually
    /// returned, instead of an uninformative "unreachable" panic.
    fn assert_invalid_args(result: Result<VectorRef>, expected: &str) {
        match result {
            Err(Error::InvalidFuncArgs { err_msg, .. }) => assert_eq!(err_msg, expected),
            Err(other) => panic!("expected InvalidFuncArgs, got a different error: {other:?}"),
            Ok(_) => panic!("expected InvalidFuncArgs, but evaluation succeeded"),
        }
    }

    /// Valid ranges yield the expected slices; a missing vector yields a null row.
    #[test]
    fn test_subvector() {
        let func = VectorSubvectorFunction;

        let input0 = Arc::new(StringVector::from(vec![
            Some("[1.0, 2.0, 3.0, 4.0, 5.0]".to_string()),
            Some("[6.0, 7.0, 8.0, 9.0, 10.0]".to_string()),
            None,
            Some("[11.0, 12.0, 13.0]".to_string()),
        ]));
        let input1 = Arc::new(Int64Vector::from(vec![Some(1), Some(0), Some(0), Some(1)]));
        let input2 = Arc::new(Int64Vector::from(vec![Some(3), Some(5), Some(2), Some(3)]));

        let result = func
            .eval(&FunctionContext::default(), &[input0, input1, input2])
            .unwrap();

        let result = result.as_ref();
        assert_eq!(result.len(), 4);
        assert_eq!(
            result.get_ref(0).as_binary().unwrap(),
            Some(veclit_to_binlit(&[2.0, 3.0]).as_slice())
        );
        assert_eq!(
            result.get_ref(1).as_binary().unwrap(),
            Some(veclit_to_binlit(&[6.0, 7.0, 8.0, 9.0, 10.0]).as_slice())
        );
        assert!(result.get_ref(2).is_null());
        assert_eq!(
            result.get_ref(3).as_binary().unwrap(),
            Some(veclit_to_binlit(&[12.0, 13.0]).as_slice())
        );
    }

    /// Columns of different lengths are rejected before any row is evaluated.
    #[test]
    fn test_subvector_error() {
        let func = VectorSubvectorFunction;

        let input0 = Arc::new(StringVector::from(vec![
            Some("[1.0, 2.0, 3.0]".to_string()),
            Some("[4.0, 5.0, 6.0]".to_string()),
        ]));
        let input1 = Arc::new(Int64Vector::from(vec![Some(1), Some(2)]));
        let input2 = Arc::new(Int64Vector::from(vec![Some(3)]));

        let result = func.eval(&FunctionContext::default(), &[input0, input1, input2]);

        assert_invalid_args(
            result,
            "The lengths of the vector are not aligned, args 0: 2, args 1: 2, args 2: 1",
        );
    }

    /// An end index past the vector length is rejected.
    #[test]
    fn test_subvector_invalid_indices() {
        let func = VectorSubvectorFunction;

        let input0 = Arc::new(StringVector::from(vec![
            Some("[1.0, 2.0, 3.0]".to_string()),
            Some("[4.0, 5.0, 6.0]".to_string()),
        ]));
        let input1 = Arc::new(Int64Vector::from(vec![Some(1), Some(3)]));
        let input2 = Arc::new(Int64Vector::from(vec![Some(3), Some(4)]));

        let result = func.eval(&FunctionContext::default(), &[input0, input1, input2]);

        assert_invalid_args(
            result,
            "Invalid start and end indices: start=3, end=4, vec_len=3",
        );
    }

    /// A negative start index is rejected rather than wrapped or clamped.
    #[test]
    fn test_subvector_negative_start() {
        let func = VectorSubvectorFunction;

        let input0 = Arc::new(StringVector::from(vec![Some("[1.0, 2.0, 3.0]".to_string())]));
        let input1 = Arc::new(Int64Vector::from(vec![Some(-1)]));
        let input2 = Arc::new(Int64Vector::from(vec![Some(2)]));

        let result = func.eval(&FunctionContext::default(), &[input0, input1, input2]);

        assert_invalid_args(
            result,
            "Invalid start and end indices: start=-1, end=2, vec_len=3",
        );
    }

    /// A start index greater than the end index is rejected.
    #[test]
    fn test_subvector_start_after_end() {
        let func = VectorSubvectorFunction;

        let input0 = Arc::new(StringVector::from(vec![Some("[1.0, 2.0, 3.0]".to_string())]));
        let input1 = Arc::new(Int64Vector::from(vec![Some(2)]));
        let input2 = Arc::new(Int64Vector::from(vec![Some(1)]));

        let result = func.eval(&FunctionContext::default(), &[input0, input1, input2]);

        assert_invalid_args(
            result,
            "Invalid start and end indices: start=2, end=1, vec_len=3",
        );
    }
}