common_function/scalars/vector/
vector_subvector.rs1use std::fmt::Display;
16use std::sync::Arc;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use datafusion::arrow::array::{Array, AsArray, BinaryViewBuilder};
20use datafusion::arrow::datatypes::Int64Type;
21use datafusion::logical_expr::ColumnarValue;
22use datafusion_common::{ScalarValue, utils};
23use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
24use datatypes::arrow::datatypes::DataType;
25use snafu::ensure;
26
27use crate::function::Function;
28use crate::scalars::vector::impl_conv::{as_veclit, veclit_to_binlit};
29
/// Name of the function: returned verbatim by `Function::name` and, upper-cased,
/// by the `Display` implementation of `VectorSubvectorFunction`.
const NAME: &str = "vec_subvector";
31
/// Scalar function `vec_subvector(vector, start, end)`.
///
/// For every row it takes the elements of `vector` in the half-open index range
/// `[start, end)` and emits them re-encoded with `veclit_to_binlit` into a
/// `BinaryView` column.
///
/// * A row whose vector, `start`, or `end` is null produces a null result.
/// * A row whose indices do not satisfy `0 <= start <= end <= vector length`
///   makes the whole invocation fail with an "Invalid function args" error.
#[derive(Debug, Clone, Default)]
pub struct VectorSubvectorFunction;
50
51impl Function for VectorSubvectorFunction {
52 fn name(&self) -> &str {
53 NAME
54 }
55
56 fn return_type(&self, _: &[DataType]) -> Result<DataType> {
57 Ok(DataType::BinaryView)
58 }
59
60 fn signature(&self) -> Signature {
61 Signature::one_of(
62 vec![
63 TypeSignature::Exact(vec![DataType::Utf8, DataType::Int64, DataType::Int64]),
64 TypeSignature::Exact(vec![DataType::Binary, DataType::Int64, DataType::Int64]),
65 ],
66 Volatility::Immutable,
67 )
68 }
69
70 fn invoke_with_args(
71 &self,
72 args: ScalarFunctionArgs,
73 ) -> datafusion_common::Result<ColumnarValue> {
74 let args = ColumnarValue::values_to_arrays(&args.args)?;
75 let [arg0, arg1, arg2] = utils::take_function_args(self.name(), args)?;
76 let arg1 = arg1.as_primitive::<Int64Type>();
77 let arg2 = arg2.as_primitive::<Int64Type>();
78
79 let len = arg0.len();
80 let mut builder = BinaryViewBuilder::with_capacity(len);
81 if len == 0 {
82 return Ok(ColumnarValue::Array(Arc::new(builder.finish())));
83 }
84
85 for i in 0..len {
86 let v = ScalarValue::try_from_array(&arg0, i)?;
87 let arg0 = as_veclit(&v)?;
88 let arg1 = arg1.is_valid(i).then(|| arg1.value(i));
89 let arg2 = arg2.is_valid(i).then(|| arg2.value(i));
90 let (Some(arg0), Some(arg1), Some(arg2)) = (arg0, arg1, arg2) else {
91 builder.append_null();
92 continue;
93 };
94
95 ensure!(
96 0 <= arg1 && arg1 <= arg2 && arg2 as usize <= arg0.len(),
97 InvalidFuncArgsSnafu {
98 err_msg: format!(
99 "Invalid start and end indices: start={}, end={}, vec_len={}",
100 arg1,
101 arg2,
102 arg0.len()
103 )
104 }
105 );
106
107 let subvector = &arg0[arg1 as usize..arg2 as usize];
108 let binlit = veclit_to_binlit(subvector);
109 builder.append_value(&binlit);
110 }
111
112 Ok(ColumnarValue::Array(Arc::new(builder.finish())))
113 }
114}
115
116impl Display for VectorSubvectorFunction {
117 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118 write!(f, "{}", NAME.to_ascii_uppercase())
119 }
120}
121
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion::arrow::array::{ArrayRef, Int64Array, StringViewArray};
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Runs `vec_subvector` over the three argument columns.
    ///
    /// `number_rows` is the batch row count reported to the function; the result
    /// field is declared nullable because null inputs produce null outputs.
    fn invoke(
        vectors: ArrayRef,
        starts: ArrayRef,
        ends: ArrayRef,
        number_rows: usize,
    ) -> datafusion_common::Result<ColumnarValue> {
        let args = ScalarFunctionArgs {
            args: vec![
                ColumnarValue::Array(vectors),
                ColumnarValue::Array(starts),
                ColumnarValue::Array(ends),
            ],
            arg_fields: vec![],
            number_rows,
            return_field: Arc::new(Field::new("x", DataType::BinaryView, true)),
            config_options: Arc::new(ConfigOptions::new()),
        };
        VectorSubvectorFunction.invoke_with_args(args)
    }

    /// Valid ranges yield the expected slices; a null vector yields a null row.
    #[test]
    fn test_subvector() {
        let input0: ArrayRef = Arc::new(StringViewArray::from(vec![
            Some("[1.0, 2.0, 3.0, 4.0, 5.0]".to_string()),
            Some("[6.0, 7.0, 8.0, 9.0, 10.0]".to_string()),
            None,
            Some("[11.0, 12.0, 13.0]".to_string()),
        ]));
        let input1: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(0), Some(0), Some(1)]));
        let input2: ArrayRef = Arc::new(Int64Array::from(vec![Some(3), Some(5), Some(2), Some(3)]));

        let result = invoke(input0, input1, input2, 4)
            .and_then(|x| x.to_array(4))
            .unwrap();

        let result = result.as_binary_view();
        assert_eq!(result.len(), 4);
        assert_eq!(result.value(0), veclit_to_binlit(&[2.0, 3.0]).as_slice());
        assert_eq!(
            result.value(1),
            veclit_to_binlit(&[6.0, 7.0, 8.0, 9.0, 10.0]).as_slice()
        );
        assert!(result.is_null(2));
        assert_eq!(result.value(3), veclit_to_binlit(&[12.0, 13.0]).as_slice());
    }

    /// Argument columns of different lengths are rejected before any row is processed.
    #[test]
    fn test_subvector_error() {
        let input0: ArrayRef = Arc::new(StringViewArray::from(vec![
            Some("[1.0, 2.0, 3.0]".to_string()),
            Some("[4.0, 5.0, 6.0]".to_string()),
        ]));
        let input1: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(2)]));
        // Deliberately one element short to trigger the mixed-length error.
        let input2: ArrayRef = Arc::new(Int64Array::from(vec![Some(3)]));

        let e = invoke(input0, input1, input2, 2).unwrap_err();
        assert!(e.to_string().starts_with(
            "Internal error: Arguments has mixed length. Expected length: 2, found length: 1."
        ));
    }

    /// An end index past the vector length fails with a descriptive error.
    #[test]
    fn test_subvector_invalid_indices() {
        let input0: ArrayRef = Arc::new(StringViewArray::from(vec![
            Some("[1.0, 2.0, 3.0]".to_string()),
            Some("[4.0, 5.0, 6.0]".to_string()),
        ]));
        let input1: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(3)]));
        let input2: ArrayRef = Arc::new(Int64Array::from(vec![Some(3), Some(4)]));

        let e = invoke(input0, input1, input2, 2).unwrap_err();
        assert!(e.to_string().starts_with("External error: Invalid function args: Invalid start and end indices: start=3, end=4, vec_len=3"));
    }
}