common_function/scalars/vector/
vector_subvector.rs1use std::borrow::Cow;
16use std::fmt::Display;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use common_query::prelude::{Signature, TypeSignature};
20use datafusion_expr::Volatility;
21use datatypes::prelude::ConcreteDataType;
22use datatypes::scalars::ScalarVectorBuilder;
23use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
24use snafu::ensure;
25
26use crate::function::{Function, FunctionContext};
27use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
28
29const NAME: &str = "vec_subvector";
30
/// Scalar function `vec_subvector(vector, start, end)`.
///
/// Returns the elements of `vector` at zero-based positions `[start, end)`
/// (start included, end excluded), encoded with `veclit_to_binlit`.
///
/// The vector argument may be a string column (e.g. `"[1.0, 2.0, 3.0]"`) or a
/// binary column; both indices are `int64`. A row with a NULL in any argument
/// produces NULL. Indices that do not satisfy `0 <= start <= end <= len` make
/// the evaluation fail with an `InvalidFuncArgs` error.
#[derive(Clone, Debug, Default)]
pub struct VectorSubvectorFunction;
49
50impl Function for VectorSubvectorFunction {
51 fn name(&self) -> &str {
52 NAME
53 }
54
55 fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
56 Ok(ConcreteDataType::binary_datatype())
57 }
58
59 fn signature(&self) -> Signature {
60 Signature::one_of(
61 vec![
62 TypeSignature::Exact(vec![
63 ConcreteDataType::string_datatype(),
64 ConcreteDataType::int64_datatype(),
65 ConcreteDataType::int64_datatype(),
66 ]),
67 TypeSignature::Exact(vec![
68 ConcreteDataType::binary_datatype(),
69 ConcreteDataType::int64_datatype(),
70 ConcreteDataType::int64_datatype(),
71 ]),
72 ],
73 Volatility::Immutable,
74 )
75 }
76
77 fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
78 ensure!(
79 columns.len() == 3,
80 InvalidFuncArgsSnafu {
81 err_msg: format!(
82 "The length of the args is not correct, expect exactly three, have: {}",
83 columns.len()
84 )
85 }
86 );
87
88 let arg0 = &columns[0];
89 let arg1 = &columns[1];
90 let arg2 = &columns[2];
91
92 ensure!(
93 arg0.len() == arg1.len() && arg1.len() == arg2.len(),
94 InvalidFuncArgsSnafu {
95 err_msg: format!(
96 "The lengths of the vector are not aligned, args 0: {}, args 1: {}, args 2: {}",
97 arg0.len(),
98 arg1.len(),
99 arg2.len()
100 )
101 }
102 );
103
104 let len = arg0.len();
105 let mut result = BinaryVectorBuilder::with_capacity(len);
106 if len == 0 {
107 return Ok(result.to_vector());
108 }
109
110 let arg0_const = as_veclit_if_const(arg0)?;
111
112 for i in 0..len {
113 let arg0 = match arg0_const.as_ref() {
114 Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
115 None => as_veclit(arg0.get_ref(i))?,
116 };
117 let arg1 = arg1.get(i).as_i64();
118 let arg2 = arg2.get(i).as_i64();
119 let (Some(arg0), Some(arg1), Some(arg2)) = (arg0, arg1, arg2) else {
120 result.push_null();
121 continue;
122 };
123
124 ensure!(
125 0 <= arg1 && arg1 <= arg2 && arg2 as usize <= arg0.len(),
126 InvalidFuncArgsSnafu {
127 err_msg: format!(
128 "Invalid start and end indices: start={}, end={}, vec_len={}",
129 arg1,
130 arg2,
131 arg0.len()
132 )
133 }
134 );
135
136 let subvector = &arg0[arg1 as usize..arg2 as usize];
137 let binlit = veclit_to_binlit(subvector);
138 result.push(Some(&binlit));
139 }
140
141 Ok(result.to_vector())
142 }
143}
144
145impl Display for VectorSubvectorFunction {
146 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147 write!(f, "{}", NAME.to_ascii_uppercase())
148 }
149}
150
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_query::error::Error;
    use datatypes::vectors::{Int64Vector, StringVector};

    use super::*;
    use crate::function::FunctionContext;

    /// Builds a string column from optional vector literals.
    fn string_column(rows: &[Option<&str>]) -> VectorRef {
        let rows: Vec<Option<String>> = rows.iter().map(|row| row.map(str::to_string)).collect();
        Arc::new(StringVector::from(rows))
    }

    /// Builds an int64 column from optional indices.
    fn int64_column(rows: &[Option<i64>]) -> VectorRef {
        Arc::new(Int64Vector::from(rows.to_vec()))
    }

    /// Evaluates the function and returns the message of the expected
    /// `InvalidFuncArgs` error; any other outcome fails the test.
    fn invalid_args_message(columns: &[VectorRef]) -> String {
        match VectorSubvectorFunction.eval(&FunctionContext::default(), columns) {
            Err(Error::InvalidFuncArgs { err_msg, .. }) => err_msg,
            _ => unreachable!(),
        }
    }

    /// Valid ranges produce the encoded slice; a NULL vector produces NULL.
    #[test]
    fn test_subvector() {
        let columns = [
            string_column(&[
                Some("[1.0, 2.0, 3.0, 4.0, 5.0]"),
                Some("[6.0, 7.0, 8.0, 9.0, 10.0]"),
                None,
                Some("[11.0, 12.0, 13.0]"),
            ]),
            int64_column(&[Some(1), Some(0), Some(0), Some(1)]),
            int64_column(&[Some(3), Some(5), Some(2), Some(3)]),
        ];

        let output = VectorSubvectorFunction
            .eval(&FunctionContext::default(), &columns)
            .unwrap();
        assert_eq!(output.len(), 4);

        let expected = [
            (0, veclit_to_binlit(&[2.0, 3.0])),
            (1, veclit_to_binlit(&[6.0, 7.0, 8.0, 9.0, 10.0])),
            (3, veclit_to_binlit(&[12.0, 13.0])),
        ];
        for (row, bytes) in &expected {
            assert_eq!(
                output.get_ref(*row).as_binary().unwrap(),
                Some(bytes.as_slice()),
                "unexpected subvector at row {row}"
            );
        }
        assert!(output.get_ref(2).is_null());
    }

    /// Columns of different lengths are rejected before any row is evaluated.
    #[test]
    fn test_subvector_error() {
        let columns = [
            string_column(&[Some("[1.0, 2.0, 3.0]"), Some("[4.0, 5.0, 6.0]")]),
            int64_column(&[Some(1), Some(2)]),
            int64_column(&[Some(3)]),
        ];

        assert_eq!(
            invalid_args_message(&columns),
            "The lengths of the vector are not aligned, args 0: 2, args 1: 2, args 2: 1"
        );
    }

    /// An end index past the vector length is reported with the offending row's values.
    #[test]
    fn test_subvector_invalid_indices() {
        let columns = [
            string_column(&[Some("[1.0, 2.0, 3.0]"), Some("[4.0, 5.0, 6.0]")]),
            int64_column(&[Some(1), Some(3)]),
            int64_column(&[Some(3), Some(4)]),
        ];

        assert_eq!(
            invalid_args_message(&columns),
            "Invalid start and end indices: start=3, end=4, vec_len=3"
        );
    }
}