common_function/scalars/vector/
vector_subvector.rs1use std::fmt::Display;
16use std::sync::Arc;
17
18use common_query::error::InvalidFuncArgsSnafu;
19use datafusion::arrow::array::{Array, AsArray, BinaryViewBuilder};
20use datafusion::arrow::datatypes::Int64Type;
21use datafusion::logical_expr::ColumnarValue;
22use datafusion_common::ScalarValue;
23use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
24use datatypes::arrow::datatypes::DataType;
25use snafu::ensure;
26
27use crate::function::{Function, extract_args};
28use crate::scalars::vector::impl_conv::{as_veclit, veclit_to_binlit};
29
/// Name of the function: returned by `Function::name` and, upper-cased, written by the
/// `Display` implementation of `VectorSubvectorFunction`.
const NAME: &str = "vec_subvector";
31
/// Scalar function `vec_subvector(vector, start, end)`.
///
/// For each row it returns the elements of `vector` from index `start` (inclusive) to index
/// `end` (exclusive), encoded as a binary vector literal. A row yields null when any of its
/// three arguments is null, and the call fails when the indices do not satisfy
/// `0 <= start <= end <= vector length`.
#[derive(Debug, Clone)]
pub(crate) struct VectorSubvectorFunction {
    /// Argument signatures accepted by the function; populated by the `Default` impl.
    signature: Signature,
}
52
53impl Default for VectorSubvectorFunction {
54 fn default() -> Self {
55 Self {
56 signature: Signature::one_of(
57 vec![
58 TypeSignature::Exact(vec![DataType::Utf8, DataType::Int64, DataType::Int64]),
59 TypeSignature::Exact(vec![DataType::Binary, DataType::Int64, DataType::Int64]),
60 ],
61 Volatility::Immutable,
62 ),
63 }
64 }
65}
66
67impl Function for VectorSubvectorFunction {
68 fn name(&self) -> &str {
69 NAME
70 }
71
72 fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
73 Ok(DataType::BinaryView)
74 }
75
76 fn signature(&self) -> &Signature {
77 &self.signature
78 }
79
80 fn invoke_with_args(
81 &self,
82 args: ScalarFunctionArgs,
83 ) -> datafusion_common::Result<ColumnarValue> {
84 let [arg0, arg1, arg2] = extract_args(self.name(), &args)?;
85 let arg1 = arg1.as_primitive::<Int64Type>();
86 let arg2 = arg2.as_primitive::<Int64Type>();
87
88 let len = arg0.len();
89 let mut builder = BinaryViewBuilder::with_capacity(len);
90 if len == 0 {
91 return Ok(ColumnarValue::Array(Arc::new(builder.finish())));
92 }
93
94 for i in 0..len {
95 let v = ScalarValue::try_from_array(&arg0, i)?;
96 let arg0 = as_veclit(&v)?;
97 let arg1 = arg1.is_valid(i).then(|| arg1.value(i));
98 let arg2 = arg2.is_valid(i).then(|| arg2.value(i));
99 let (Some(arg0), Some(arg1), Some(arg2)) = (arg0, arg1, arg2) else {
100 builder.append_null();
101 continue;
102 };
103
104 ensure!(
105 0 <= arg1 && arg1 <= arg2 && arg2 as usize <= arg0.len(),
106 InvalidFuncArgsSnafu {
107 err_msg: format!(
108 "Invalid start and end indices: start={}, end={}, vec_len={}",
109 arg1,
110 arg2,
111 arg0.len()
112 )
113 }
114 );
115
116 let subvector = &arg0[arg1 as usize..arg2 as usize];
117 let binlit = veclit_to_binlit(subvector);
118 builder.append_value(&binlit);
119 }
120
121 Ok(ColumnarValue::Array(Arc::new(builder.finish())))
122 }
123}
124
125impl Display for VectorSubvectorFunction {
126 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127 write!(f, "{}", NAME.to_ascii_uppercase())
128 }
129}
130
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion::arrow::array::{ArrayRef, Int64Array, StringViewArray};
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Builds the arguments for one `vec_subvector` call over array inputs.
    ///
    /// `number_rows` is taken from the vector column so that it always matches the data.
    fn subvector_args(vectors: ArrayRef, starts: ArrayRef, ends: ArrayRef) -> ScalarFunctionArgs {
        let number_rows = vectors.len();
        ScalarFunctionArgs {
            args: vec![
                ColumnarValue::Array(vectors),
                ColumnarValue::Array(starts),
                ColumnarValue::Array(ends),
            ],
            arg_fields: vec![],
            number_rows,
            return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
            config_options: Arc::new(ConfigOptions::new()),
        }
    }

    /// Valid ranges produce the expected slices and a null vector produces a null row.
    #[test]
    fn test_subvector() {
        let func = VectorSubvectorFunction::default();

        let input0: ArrayRef = Arc::new(StringViewArray::from(vec![
            Some("[1.0, 2.0, 3.0, 4.0, 5.0]".to_string()),
            Some("[6.0, 7.0, 8.0, 9.0, 10.0]".to_string()),
            None,
            Some("[11.0, 12.0, 13.0]".to_string()),
        ]));
        let input1: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(0), Some(0), Some(1)]));
        let input2: ArrayRef = Arc::new(Int64Array::from(vec![Some(3), Some(5), Some(2), Some(3)]));

        let result = func
            .invoke_with_args(subvector_args(input0, input1, input2))
            .and_then(|x| x.to_array(4))
            .unwrap();

        let result = result.as_binary_view();
        assert_eq!(result.len(), 4);
        assert_eq!(result.value(0), veclit_to_binlit(&[2.0, 3.0]).as_slice());
        assert_eq!(
            result.value(1),
            veclit_to_binlit(&[6.0, 7.0, 8.0, 9.0, 10.0]).as_slice()
        );
        assert!(result.is_null(2));
        assert_eq!(result.value(3), veclit_to_binlit(&[12.0, 13.0]).as_slice());
    }

    /// Argument arrays of different lengths are rejected before any row is evaluated.
    #[test]
    fn test_subvector_error() {
        let func = VectorSubvectorFunction::default();

        let input0: ArrayRef = Arc::new(StringViewArray::from(vec![
            Some("[1.0, 2.0, 3.0]".to_string()),
            Some("[4.0, 5.0, 6.0]".to_string()),
        ]));
        let input1: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(2)]));
        let input2: ArrayRef = Arc::new(Int64Array::from(vec![Some(3)]));

        let e = func
            .invoke_with_args(subvector_args(input0, input1, input2))
            .unwrap_err();
        assert!(e.to_string().starts_with(
            "Internal error: Arguments has mixed length. Expected length: 2, found length: 1."
        ));
    }

    /// An end index past the vector length is reported as invalid.
    #[test]
    fn test_subvector_invalid_indices() {
        let func = VectorSubvectorFunction::default();

        let input0: ArrayRef = Arc::new(StringViewArray::from(vec![
            Some("[1.0, 2.0, 3.0]".to_string()),
            Some("[4.0, 5.0, 6.0]".to_string()),
        ]));
        let input1: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(3)]));
        let input2: ArrayRef = Arc::new(Int64Array::from(vec![Some(3), Some(4)]));

        let e = func
            .invoke_with_args(subvector_args(input0, input1, input2))
            .unwrap_err();
        assert!(e.to_string().starts_with(
            "External error: Invalid function args: Invalid start and end indices: start=3, end=4, vec_len=3"
        ));
    }

    /// A negative start index is reported as invalid rather than wrapping around.
    #[test]
    fn test_subvector_negative_start() {
        let func = VectorSubvectorFunction::default();

        let input0: ArrayRef = Arc::new(StringViewArray::from(vec![Some(
            "[1.0, 2.0, 3.0]".to_string(),
        )]));
        let input1: ArrayRef = Arc::new(Int64Array::from(vec![Some(-1)]));
        let input2: ArrayRef = Arc::new(Int64Array::from(vec![Some(2)]));

        let e = func
            .invoke_with_args(subvector_args(input0, input1, input2))
            .unwrap_err();
        assert!(e.to_string().starts_with(
            "External error: Invalid function args: Invalid start and end indices: start=-1, end=2, vec_len=3"
        ));
    }
}