// common_function/scalars/vector.rs
mod convert;
16mod distance;
17mod elem_avg;
18mod elem_product;
19mod elem_sum;
20pub mod impl_conv;
21mod scalar_add;
22mod scalar_mul;
23mod vector_add;
24mod vector_dim;
25mod vector_div;
26mod vector_kth_elem;
27mod vector_mul;
28mod vector_norm;
29mod vector_sub;
30mod vector_subvector;
31
32use std::borrow::Cow;
33
34use datafusion_common::{DataFusionError, Result, ScalarValue, utils};
35use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
36
37use crate::function_registry::FunctionRegistry;
38use crate::scalars::vector::impl_conv::as_veclit;
39
40pub(crate) struct VectorFunction;
41
42impl VectorFunction {
43 pub fn register(registry: &FunctionRegistry) {
44 registry.register_scalar(convert::ParseVectorFunction::default());
46 registry.register_scalar(convert::VectorToStringFunction::default());
47
48 registry.register_scalar(distance::CosDistanceFunction::default());
50 registry.register_scalar(distance::DotProductFunction::default());
51 registry.register_scalar(distance::L2SqDistanceFunction::default());
52
53 registry.register_scalar(scalar_add::ScalarAddFunction::default());
55 registry.register_scalar(scalar_mul::ScalarMulFunction::default());
56
57 registry.register_scalar(vector_add::VectorAddFunction::default());
59 registry.register_scalar(vector_sub::VectorSubFunction::default());
60 registry.register_scalar(vector_mul::VectorMulFunction::default());
61 registry.register_scalar(vector_div::VectorDivFunction::default());
62 registry.register_scalar(vector_norm::VectorNormFunction::default());
63 registry.register_scalar(vector_dim::VectorDimFunction::default());
64 registry.register_scalar(vector_kth_elem::VectorKthElemFunction::default());
65 registry.register_scalar(vector_subvector::VectorSubvectorFunction::default());
66 registry.register_scalar(elem_sum::ElemSumFunction::default());
67 registry.register_scalar(elem_product::ElemProductFunction::default());
68 registry.register_scalar(elem_avg::ElemAvgFunction::default());
69 }
70}
71
/// Expands to a `&ScalarValue` for row `$i` of the columnar value `$col`.
///
/// * For `ColumnarValue::Scalar` the stored scalar is borrowed as-is and `$i`
///   is not used.
/// * For `ColumnarValue::Array` a new `ScalarValue` is built from element `$i`
///   and a reference to that temporary is produced.
///
/// This is a macro rather than a function because the array arm borrows a
/// value created at the expansion site, and its `?` returns the conversion
/// error from the *enclosing* function.
macro_rules! try_get_scalar_value {
    ($col: ident, $i: ident) => {
        match $col {
            datafusion::logical_expr::ColumnarValue::Scalar(scalar) => scalar,
            datafusion::logical_expr::ColumnarValue::Array(array) => {
                &datafusion_common::ScalarValue::try_from_array(array.as_ref(), $i)?
            }
        }
    };
}
84
85pub(crate) fn ensure_same_length(values: &[&ColumnarValue]) -> Result<usize> {
86 if values.is_empty() {
87 return Ok(0);
88 }
89
90 let mut array_len = None;
91 for v in values {
92 array_len = match (v, array_len) {
93 (ColumnarValue::Array(a), None) => Some(a.len()),
94 (ColumnarValue::Array(a), Some(array_len)) => {
95 if array_len == a.len() {
96 Some(array_len)
97 } else {
98 return Err(DataFusionError::Internal(format!(
99 "Arguments has mixed length. Expected length: {array_len}, found length: {}",
100 a.len()
101 )));
102 }
103 }
104 (ColumnarValue::Scalar(_), array_len) => array_len,
105 }
106 }
107
108 let array_len = array_len.unwrap_or(1);
110 Ok(array_len)
111}
112
113struct VectorCalculator<'a, F> {
114 name: &'a str,
115 func: F,
116}
117
118impl<F> VectorCalculator<'_, F>
119where
120 F: Fn(&ScalarValue, &ScalarValue) -> Result<ScalarValue>,
121{
122 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
123 let [arg0, arg1] = utils::take_function_args(self.name, &args.args)?;
124
125 if let (ColumnarValue::Scalar(v0), ColumnarValue::Scalar(v1)) = (arg0, arg1) {
126 let result = (self.func)(v0, v1)?;
127 return Ok(ColumnarValue::Scalar(result));
128 }
129
130 let len = ensure_same_length(&[arg0, arg1])?;
131 let mut results = Vec::with_capacity(len);
132 for i in 0..len {
133 let v0 = try_get_scalar_value!(arg0, i);
134 let v1 = try_get_scalar_value!(arg1, i);
135 results.push((self.func)(v0, v1)?);
136 }
137
138 let results = ScalarValue::iter_to_array(results.into_iter())?;
139 Ok(ColumnarValue::Array(results))
140 }
141}
142
143impl<F> VectorCalculator<'_, F>
144where
145 F: Fn(&Option<Cow<[f32]>>, &Option<Cow<[f32]>>) -> Result<ScalarValue>,
146{
147 fn invoke_with_vectors(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
148 let [arg0, arg1] = utils::take_function_args(self.name, &args.args)?;
149
150 if let (ColumnarValue::Scalar(v0), ColumnarValue::Scalar(v1)) = (arg0, arg1) {
151 let v0 = as_veclit(v0)?;
152 let v1 = as_veclit(v1)?;
153 let result = (self.func)(&v0, &v1)?;
154 return Ok(ColumnarValue::Scalar(result));
155 }
156
157 let len = ensure_same_length(&[arg0, arg1])?;
158 let mut results = Vec::with_capacity(len);
159
160 match (arg0, arg1) {
161 (ColumnarValue::Scalar(v0), ColumnarValue::Array(a1)) => {
162 let v0 = as_veclit(v0)?;
163 for i in 0..len {
164 let v1 = ScalarValue::try_from_array(a1, i)?;
165 let v1 = as_veclit(&v1)?;
166 results.push((self.func)(&v0, &v1)?);
167 }
168 }
169 (ColumnarValue::Array(a0), ColumnarValue::Scalar(v1)) => {
170 let v1 = as_veclit(v1)?;
171 for i in 0..len {
172 let v0 = ScalarValue::try_from_array(a0, i)?;
173 let v0 = as_veclit(&v0)?;
174 results.push((self.func)(&v0, &v1)?);
175 }
176 }
177 (ColumnarValue::Array(a0), ColumnarValue::Array(a1)) => {
178 for i in 0..len {
179 let v0 = ScalarValue::try_from_array(a0, i)?;
180 let v0 = as_veclit(&v0)?;
181 let v1 = ScalarValue::try_from_array(a1, i)?;
182 let v1 = as_veclit(&v1)?;
183 results.push((self.func)(&v0, &v1)?);
184 }
185 }
186 (ColumnarValue::Scalar(_), ColumnarValue::Scalar(_)) => {
187 unreachable!()
189 }
190 }
191
192 let results = ScalarValue::iter_to_array(results.into_iter())?;
193 Ok(ColumnarValue::Array(results))
194 }
195}
196
197impl<F> VectorCalculator<'_, F>
198where
199 F: Fn(&ScalarValue) -> Result<ScalarValue>,
200{
201 fn invoke_with_single_argument(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
202 let [arg0] = utils::take_function_args(self.name, &args.args)?;
203
204 let arg0 = match arg0 {
205 ColumnarValue::Scalar(v) => {
206 let result = (self.func)(v)?;
207 return Ok(ColumnarValue::Scalar(result));
208 }
209 ColumnarValue::Array(a) => a,
210 };
211
212 let len = arg0.len();
213 let mut results = Vec::with_capacity(len);
214 for i in 0..len {
215 let v = ScalarValue::try_from_array(arg0, i)?;
216 results.push((self.func)(&v)?);
217 }
218
219 let results = ScalarValue::iter_to_array(results.into_iter())?;
220 Ok(ColumnarValue::Array(results))
221 }
222}
223
/// Generates the boilerplate for a UDF struct named `$name` that takes two
/// arguments.
///
/// The expansion contains:
///
/// * `pub(crate) struct $name` with a single `signature` field, deriving
///   `Debug` and `Clone`; attributes written before the name (doc comments
///   included) are forwarded onto the struct.
/// * A `Default` impl whose signature is built by
///   `crate::helper::one_of_sigs2`, passing the same type list
///   (`Utf8`, `Utf8View`, `Binary`, `BinaryView`) for both arguments.
/// * A `Display` impl that prints `self.name()` in ASCII upper case. The
///   `name` method is not generated here and must be available on `$name`
///   at the expansion site.
macro_rules! define_args_of_two_vector_literals_udf {
    ($(#[$attr:meta])* $name: ident) => {
        $(#[$attr])*
        #[derive(Debug, Clone)]
        pub(crate) struct $name {
            signature: datafusion_expr::Signature,
        }

        impl Default for $name {
            fn default() -> Self {
                use arrow::datatypes::DataType;

                // Both arguments are given the same list of data types.
                let supported = vec![
                    DataType::Utf8,
                    DataType::Utf8View,
                    DataType::Binary,
                    DataType::BinaryView,
                ];
                let signature = crate::helper::one_of_sigs2(supported.clone(), supported);

                Self { signature }
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                let upper = self.name().to_ascii_uppercase();
                f.write_str(&upper)
            }
        }
    };
}
262
263pub(crate) use define_args_of_two_vector_literals_udf;