common_function/scalars/
vector.rs1mod convert;
16pub mod distance;
17mod elem_avg;
18mod elem_product;
19mod elem_sum;
20pub mod impl_conv;
21mod scalar_add;
22mod scalar_mul;
23mod vector_add;
24mod vector_dim;
25mod vector_div;
26mod vector_kth_elem;
27mod vector_mul;
28mod vector_norm;
29mod vector_sub;
30mod vector_subvector;
31
32use std::borrow::Cow;
33
34use datafusion_common::{DataFusionError, Result, ScalarValue, utils};
35use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
36use datatypes::arrow::array::new_empty_array;
37
38use crate::function_registry::FunctionRegistry;
39use crate::scalars::vector::impl_conv::as_veclit;
40
41pub(crate) struct VectorFunction;
42
43impl VectorFunction {
44 pub fn register(registry: &FunctionRegistry) {
45 registry.register_scalar(convert::ParseVectorFunction::default());
47 registry.register_scalar(convert::VectorToStringFunction::default());
48
49 registry.register_scalar(distance::CosDistanceFunction::default());
51 registry.register_scalar(distance::DotProductFunction::default());
52 registry.register_scalar(distance::L2SqDistanceFunction::default());
53
54 registry.register_scalar(scalar_add::ScalarAddFunction::default());
56 registry.register_scalar(scalar_mul::ScalarMulFunction::default());
57
58 registry.register_scalar(vector_add::VectorAddFunction::default());
60 registry.register_scalar(vector_sub::VectorSubFunction::default());
61 registry.register_scalar(vector_mul::VectorMulFunction::default());
62 registry.register_scalar(vector_div::VectorDivFunction::default());
63 registry.register_scalar(vector_norm::VectorNormFunction::default());
64 registry.register_scalar(vector_dim::VectorDimFunction::default());
65 registry.register_scalar(vector_kth_elem::VectorKthElemFunction::default());
66 registry.register_scalar(vector_subvector::VectorSubvectorFunction::default());
67 registry.register_scalar(elem_sum::ElemSumFunction::default());
68 registry.register_scalar(elem_product::ElemProductFunction::default());
69 registry.register_scalar(elem_avg::ElemAvgFunction::default());
70 }
71}
72
/// Evaluates to a `&ScalarValue` for row `$i` of the columnar value `$col`.
///
/// A scalar column yields its own value regardless of `$i`; an array column
/// yields the element at index `$i`, converted with
/// `ScalarValue::try_from_array`. A conversion failure is propagated with `?`
/// from the function the macro is expanded in.
macro_rules! try_get_scalar_value {
    // This is a macro rather than a function because the `Array` arm borrows a
    // freshly built `ScalarValue`: expanding in place makes that temporary (and
    // the `?`) belong to the calling function.
    ($col: ident, $i: ident) => {
        match $col {
            datafusion::logical_expr::ColumnarValue::Scalar(scalar) => scalar,
            datafusion::logical_expr::ColumnarValue::Array(array) => {
                &datafusion_common::ScalarValue::try_from_array(array.as_ref(), $i)?
            }
        }
    };
}
85
86pub(crate) fn ensure_same_length(values: &[&ColumnarValue]) -> Result<usize> {
87 if values.is_empty() {
88 return Ok(0);
89 }
90
91 let mut array_len = None;
92 for v in values {
93 array_len = match (v, array_len) {
94 (ColumnarValue::Array(a), None) => Some(a.len()),
95 (ColumnarValue::Array(a), Some(array_len)) => {
96 if array_len == a.len() {
97 Some(array_len)
98 } else {
99 return Err(DataFusionError::Internal(format!(
100 "Arguments has mixed length. Expected length: {array_len}, found length: {}",
101 a.len()
102 )));
103 }
104 }
105 (ColumnarValue::Scalar(_), array_len) => array_len,
106 }
107 }
108
109 let array_len = array_len.unwrap_or(1);
111 Ok(array_len)
112}
113
114struct VectorCalculator<'a, F> {
115 name: &'a str,
116 func: F,
117}
118
119impl<F> VectorCalculator<'_, F>
120where
121 F: Fn(&ScalarValue, &ScalarValue) -> Result<ScalarValue>,
122{
123 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
124 let [arg0, arg1] = utils::take_function_args(self.name, &args.args)?;
125
126 if let (ColumnarValue::Scalar(v0), ColumnarValue::Scalar(v1)) = (arg0, arg1) {
127 let result = (self.func)(v0, v1)?;
128 return Ok(ColumnarValue::Scalar(result));
129 }
130
131 let len = ensure_same_length(&[arg0, arg1])?;
132 if len == 0 {
133 return Ok(ColumnarValue::Array(new_empty_array(
134 args.return_field.data_type(),
135 )));
136 }
137 let mut results = Vec::with_capacity(len);
138 for i in 0..len {
139 let v0 = try_get_scalar_value!(arg0, i);
140 let v1 = try_get_scalar_value!(arg1, i);
141 results.push((self.func)(v0, v1)?);
142 }
143
144 let results = ScalarValue::iter_to_array(results.into_iter())?;
145 Ok(ColumnarValue::Array(results))
146 }
147}
148
149impl<F> VectorCalculator<'_, F>
150where
151 F: Fn(&Option<Cow<[f32]>>, &Option<Cow<[f32]>>) -> Result<ScalarValue>,
152{
153 fn invoke_with_vectors(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
154 let [arg0, arg1] = utils::take_function_args(self.name, &args.args)?;
155
156 if let (ColumnarValue::Scalar(v0), ColumnarValue::Scalar(v1)) = (arg0, arg1) {
157 let v0 = as_veclit(v0)?;
158 let v1 = as_veclit(v1)?;
159 let result = (self.func)(&v0, &v1)?;
160 return Ok(ColumnarValue::Scalar(result));
161 }
162
163 let len = ensure_same_length(&[arg0, arg1])?;
164 if len == 0 {
165 return Ok(ColumnarValue::Array(new_empty_array(
166 args.return_field.data_type(),
167 )));
168 }
169 let mut results = Vec::with_capacity(len);
170
171 match (arg0, arg1) {
172 (ColumnarValue::Scalar(v0), ColumnarValue::Array(a1)) => {
173 let v0 = as_veclit(v0)?;
174 for i in 0..len {
175 let v1 = ScalarValue::try_from_array(a1, i)?;
176 let v1 = as_veclit(&v1)?;
177 results.push((self.func)(&v0, &v1)?);
178 }
179 }
180 (ColumnarValue::Array(a0), ColumnarValue::Scalar(v1)) => {
181 let v1 = as_veclit(v1)?;
182 for i in 0..len {
183 let v0 = ScalarValue::try_from_array(a0, i)?;
184 let v0 = as_veclit(&v0)?;
185 results.push((self.func)(&v0, &v1)?);
186 }
187 }
188 (ColumnarValue::Array(a0), ColumnarValue::Array(a1)) => {
189 for i in 0..len {
190 let v0 = ScalarValue::try_from_array(a0, i)?;
191 let v0 = as_veclit(&v0)?;
192 let v1 = ScalarValue::try_from_array(a1, i)?;
193 let v1 = as_veclit(&v1)?;
194 results.push((self.func)(&v0, &v1)?);
195 }
196 }
197 (ColumnarValue::Scalar(_), ColumnarValue::Scalar(_)) => {
198 unreachable!()
200 }
201 }
202
203 let results = ScalarValue::iter_to_array(results.into_iter())?;
204 Ok(ColumnarValue::Array(results))
205 }
206}
207
208impl<F> VectorCalculator<'_, F>
209where
210 F: Fn(&ScalarValue) -> Result<ScalarValue>,
211{
212 fn invoke_with_single_argument(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
213 let [arg0] = utils::take_function_args(self.name, &args.args)?;
214
215 let arg0 = match arg0 {
216 ColumnarValue::Scalar(v) => {
217 let result = (self.func)(v)?;
218 return Ok(ColumnarValue::Scalar(result));
219 }
220 ColumnarValue::Array(a) => a,
221 };
222
223 let len = arg0.len();
224 if len == 0 {
225 return Ok(ColumnarValue::Array(new_empty_array(
226 args.return_field.data_type(),
227 )));
228 }
229 let mut results = Vec::with_capacity(len);
230 for i in 0..len {
231 let v = ScalarValue::try_from_array(arg0, i)?;
232 results.push((self.func)(&v)?);
233 }
234
235 let results = ScalarValue::iter_to_array(results.into_iter())?;
236 Ok(ColumnarValue::Array(results))
237 }
238}
239
/// Declares a UDF struct named `$name` that holds a two-argument signature.
///
/// The generated type:
/// * carries any attributes passed before the name and derives `Debug` and
///   `Clone`;
/// * implements `Default`, building its signature with
///   `crate::helper::one_of_sigs2` from the same list of types (`Utf8`,
///   `Utf8View`, `Binary`, `BinaryView`) for both argument positions;
/// * implements `Display` as the ASCII-uppercased result of `self.name()`.
///
/// `name()` is not generated here, so it must be callable on the type where
/// the macro is used.
macro_rules! define_args_of_two_vector_literals_udf {
    ($(#[$attr:meta])* $name: ident) => {
        $(#[$attr])*
        #[derive(Debug, Clone)]
        pub(crate) struct $name {
            signature: datafusion_expr::Signature,
        }

        impl Default for $name {
            fn default() -> Self {
                use arrow::datatypes::DataType;

                // Both argument positions accept the same set of types.
                let accepted_types = || {
                    vec![
                        DataType::Utf8,
                        DataType::Utf8View,
                        DataType::Binary,
                        DataType::BinaryView,
                    ]
                };

                Self {
                    signature: crate::helper::one_of_sigs2(accepted_types(), accepted_types()),
                }
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                let upper_name = self.name().to_ascii_uppercase();
                write!(f, "{}", upper_name)
            }
        }
    };
}
278
279pub(crate) use define_args_of_two_vector_literals_udf;