common_function/scalars/
vector.rs1mod convert;
16mod distance;
17mod elem_product;
18mod elem_sum;
19pub mod impl_conv;
20mod scalar_add;
21mod scalar_mul;
22mod vector_add;
23mod vector_dim;
24mod vector_div;
25mod vector_kth_elem;
26mod vector_mul;
27mod vector_norm;
28mod vector_sub;
29mod vector_subvector;
30
31use std::borrow::Cow;
32
33use datafusion_common::{DataFusionError, Result, ScalarValue, utils};
34use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
35
36use crate::function_registry::FunctionRegistry;
37use crate::scalars::vector::impl_conv::as_veclit;
38
39pub(crate) struct VectorFunction;
40
41impl VectorFunction {
42 pub fn register(registry: &FunctionRegistry) {
43 registry.register_scalar(convert::ParseVectorFunction::default());
45 registry.register_scalar(convert::VectorToStringFunction::default());
46
47 registry.register_scalar(distance::CosDistanceFunction::default());
49 registry.register_scalar(distance::DotProductFunction::default());
50 registry.register_scalar(distance::L2SqDistanceFunction::default());
51
52 registry.register_scalar(scalar_add::ScalarAddFunction::default());
54 registry.register_scalar(scalar_mul::ScalarMulFunction::default());
55
56 registry.register_scalar(vector_add::VectorAddFunction::default());
58 registry.register_scalar(vector_sub::VectorSubFunction::default());
59 registry.register_scalar(vector_mul::VectorMulFunction::default());
60 registry.register_scalar(vector_div::VectorDivFunction::default());
61 registry.register_scalar(vector_norm::VectorNormFunction::default());
62 registry.register_scalar(vector_dim::VectorDimFunction::default());
63 registry.register_scalar(vector_kth_elem::VectorKthElemFunction::default());
64 registry.register_scalar(vector_subvector::VectorSubvectorFunction::default());
65 registry.register_scalar(elem_sum::ElemSumFunction::default());
66 registry.register_scalar(elem_product::ElemProductFunction::default());
67 }
68}
69
/// Evaluates to a `&ScalarValue` for row `$i` of the `&ColumnarValue` bound to `$col`.
///
/// - `ColumnarValue::Array`: the row is materialized with `ScalarValue::try_from_array`. Any
///   error is propagated with `?`, so the macro must be expanded inside a function or closure
///   returning `Result`.
/// - `ColumnarValue::Scalar`: the contained value is returned for every row, i.e. it is
///   broadcast.
///
/// Uses the same `datafusion_expr::ColumnarValue` path as the rest of this file.
macro_rules! try_get_scalar_value {
    ($col: ident, $i: ident) => {
        match $col {
            datafusion_expr::ColumnarValue::Array(a) => {
                &datafusion_common::ScalarValue::try_from_array(a.as_ref(), $i)?
            }
            datafusion_expr::ColumnarValue::Scalar(v) => v,
        }
    };
}
82
83pub(crate) fn ensure_same_length(values: &[&ColumnarValue]) -> Result<usize> {
84 if values.is_empty() {
85 return Ok(0);
86 }
87
88 let mut array_len = None;
89 for v in values {
90 array_len = match (v, array_len) {
91 (ColumnarValue::Array(a), None) => Some(a.len()),
92 (ColumnarValue::Array(a), Some(array_len)) => {
93 if array_len == a.len() {
94 Some(array_len)
95 } else {
96 return Err(DataFusionError::Internal(format!(
97 "Arguments has mixed length. Expected length: {array_len}, found length: {}",
98 a.len()
99 )));
100 }
101 }
102 (ColumnarValue::Scalar(_), array_len) => array_len,
103 }
104 }
105
106 let array_len = array_len.unwrap_or(1);
108 Ok(array_len)
109}
110
/// Adapter that evaluates a per-row closure over the arguments of a scalar function call.
///
/// The closure's signature determines which `invoke_*` method is available. The impl blocks
/// for this type provide one method for two `ScalarValue` operands, one for two decoded vector
/// literals and one for a single `ScalarValue`.
struct VectorCalculator<'a, F> {
    /// Function name handed to `utils::take_function_args` when validating the argument count.
    name: &'a str,
    /// Computation applied to each row (or once, when every argument is a scalar).
    func: F,
}
115
116impl<F> VectorCalculator<'_, F>
117where
118 F: Fn(&ScalarValue, &ScalarValue) -> Result<ScalarValue>,
119{
120 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
121 let [arg0, arg1] = utils::take_function_args(self.name, &args.args)?;
122
123 if let (ColumnarValue::Scalar(v0), ColumnarValue::Scalar(v1)) = (arg0, arg1) {
124 let result = (self.func)(v0, v1)?;
125 return Ok(ColumnarValue::Scalar(result));
126 }
127
128 let len = ensure_same_length(&[arg0, arg1])?;
129 let mut results = Vec::with_capacity(len);
130 for i in 0..len {
131 let v0 = try_get_scalar_value!(arg0, i);
132 let v1 = try_get_scalar_value!(arg1, i);
133 results.push((self.func)(v0, v1)?);
134 }
135
136 let results = ScalarValue::iter_to_array(results.into_iter())?;
137 Ok(ColumnarValue::Array(results))
138 }
139}
140
141impl<F> VectorCalculator<'_, F>
142where
143 F: Fn(&Option<Cow<[f32]>>, &Option<Cow<[f32]>>) -> Result<ScalarValue>,
144{
145 fn invoke_with_vectors(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
146 let [arg0, arg1] = utils::take_function_args(self.name, &args.args)?;
147
148 if let (ColumnarValue::Scalar(v0), ColumnarValue::Scalar(v1)) = (arg0, arg1) {
149 let v0 = as_veclit(v0)?;
150 let v1 = as_veclit(v1)?;
151 let result = (self.func)(&v0, &v1)?;
152 return Ok(ColumnarValue::Scalar(result));
153 }
154
155 let len = ensure_same_length(&[arg0, arg1])?;
156 let mut results = Vec::with_capacity(len);
157
158 match (arg0, arg1) {
159 (ColumnarValue::Scalar(v0), ColumnarValue::Array(a1)) => {
160 let v0 = as_veclit(v0)?;
161 for i in 0..len {
162 let v1 = ScalarValue::try_from_array(a1, i)?;
163 let v1 = as_veclit(&v1)?;
164 results.push((self.func)(&v0, &v1)?);
165 }
166 }
167 (ColumnarValue::Array(a0), ColumnarValue::Scalar(v1)) => {
168 let v1 = as_veclit(v1)?;
169 for i in 0..len {
170 let v0 = ScalarValue::try_from_array(a0, i)?;
171 let v0 = as_veclit(&v0)?;
172 results.push((self.func)(&v0, &v1)?);
173 }
174 }
175 (ColumnarValue::Array(a0), ColumnarValue::Array(a1)) => {
176 for i in 0..len {
177 let v0 = ScalarValue::try_from_array(a0, i)?;
178 let v0 = as_veclit(&v0)?;
179 let v1 = ScalarValue::try_from_array(a1, i)?;
180 let v1 = as_veclit(&v1)?;
181 results.push((self.func)(&v0, &v1)?);
182 }
183 }
184 (ColumnarValue::Scalar(_), ColumnarValue::Scalar(_)) => {
185 unreachable!()
187 }
188 }
189
190 let results = ScalarValue::iter_to_array(results.into_iter())?;
191 Ok(ColumnarValue::Array(results))
192 }
193}
194
195impl<F> VectorCalculator<'_, F>
196where
197 F: Fn(&ScalarValue) -> Result<ScalarValue>,
198{
199 fn invoke_with_single_argument(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
200 let [arg0] = utils::take_function_args(self.name, &args.args)?;
201
202 let arg0 = match arg0 {
203 ColumnarValue::Scalar(v) => {
204 let result = (self.func)(v)?;
205 return Ok(ColumnarValue::Scalar(result));
206 }
207 ColumnarValue::Array(a) => a,
208 };
209
210 let len = arg0.len();
211 let mut results = Vec::with_capacity(len);
212 for i in 0..len {
213 let v = ScalarValue::try_from_array(arg0, i)?;
214 results.push((self.func)(&v)?);
215 }
216
217 let results = ScalarValue::iter_to_array(results.into_iter())?;
218 Ok(ColumnarValue::Array(results))
219 }
220}
221
/// Declares a UDF struct `$name` that takes two operands.
///
/// The signature is built with `crate::helper::one_of_sigs2` from two copies of the type
/// list `Utf8`, `Utf8View`, `Binary`, `BinaryView`. This presumably accepts any combination of
/// those types; confirm against the helper.
///
/// The expansion produces:
/// - the struct itself, carrying any attributes written before `$name` plus
///   `#[derive(Debug, Clone)]`;
/// - a `Default` impl that builds the signature;
/// - a `Display` impl that prints `self.name()` in ASCII uppercase. `name()` must be provided
///   for `$name` at the expansion site.
macro_rules! define_args_of_two_vector_literals_udf {
    ($(#[$attr:meta])* $name: ident) => {
        $(#[$attr])*
        #[derive(Debug, Clone)]
        pub(crate) struct $name {
            signature: datafusion_expr::Signature,
        }

        impl Default for $name {
            fn default() -> Self {
                use arrow::datatypes::DataType;

                // Both operands share the same accepted type list.
                let accepted = vec![
                    DataType::Utf8,
                    DataType::Utf8View,
                    DataType::Binary,
                    DataType::BinaryView,
                ];

                Self {
                    signature: crate::helper::one_of_sigs2(accepted.clone(), accepted),
                }
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                let display_name = self.name().to_ascii_uppercase();
                f.write_str(&display_name)
            }
        }
    };
}

pub(crate) use define_args_of_two_vector_literals_udf;