1mod cos;
16mod dot;
17mod l2sq;
18
/// Name returned by `CosDistanceFunction::name` (the function backed by `cos::cos`).
pub const VEC_COS_DISTANCE: &str = "vec_cos_distance";
/// Name returned by `L2SqDistanceFunction::name` (the function backed by `l2sq::l2sq`).
pub const VEC_L2SQ_DISTANCE: &str = "vec_l2sq_distance";
/// Name returned by `DotProductFunction::name` (the function backed by `dot::dot`).
pub const VEC_DOT_PRODUCT: &str = "vec_dot_product";
22
23use std::borrow::Cow;
24use std::fmt::Display;
25
26use datafusion::logical_expr::ColumnarValue;
27use datafusion_common::ScalarValue;
28use datafusion_expr::{ScalarFunctionArgs, Signature};
29use datatypes::arrow::datatypes::DataType;
30
31use crate::function::Function;
32use crate::helper;
33
/// Generates a two-argument vector distance function type.
///
/// Macro parameters:
/// * `$StructName` - identifier of the generated `pub(crate)` struct.
/// * `$display_name` - string expression returned by `Function::name`; its
///   ASCII-uppercased form is what the `Display` impl writes.
/// * `$similarity_method` - path of the kernel called with the two vectors.
///
/// The generated type implements `Default`, `Function` and `Display`.
macro_rules! define_distance_function {
    ($StructName:ident, $display_name:expr, $similarity_method:path) => {
        /// Vector distance function produced by `define_distance_function!`.
        #[derive(Debug, Clone)]
        pub(crate) struct $StructName {
            signature: Signature,
        }

        impl Default for $StructName {
            /// Builds the signature: both arguments accept the same four
            /// string/binary Arrow types.
            fn default() -> Self {
                // Both positions take the same type list; build it once per position.
                let accepted_types = || {
                    vec![
                        DataType::Utf8,
                        DataType::Utf8View,
                        DataType::Binary,
                        DataType::BinaryView,
                    ]
                };
                let signature = helper::one_of_sigs2(accepted_types(), accepted_types());
                Self { signature }
            }
        }

        impl Function for $StructName {
            fn name(&self) -> &str {
                $display_name
            }

            /// The result is always `Float32`, whatever the argument types are.
            fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
                Ok(DataType::Float32)
            }

            fn signature(&self) -> &Signature {
                &self.signature
            }

            /// Evaluates the distance through `VectorCalculator`.
            ///
            /// The per-pair closure yields a null `Float32` when either vector
            /// is absent.
            ///
            /// # Errors
            ///
            /// The per-pair closure returns an `Execution` error when the two
            /// vectors have different lengths.
            fn invoke_with_args(
                &self,
                args: ScalarFunctionArgs,
            ) -> datafusion_common::Result<ColumnarValue> {
                let distance = |lhs: &Option<Cow<[f32]>>,
                                rhs: &Option<Cow<[f32]>>|
                 -> datafusion_common::Result<ScalarValue> {
                    // A missing operand on either side yields a null result.
                    let (Some(lhs), Some(rhs)) = (lhs, rhs) else {
                        return Ok(ScalarValue::Float32(None));
                    };

                    // The kernel is only called on equally sized vectors.
                    if lhs.len() != rhs.len() {
                        return Err(datafusion_common::DataFusionError::Execution(format!(
                            "vectors length not match: {}",
                            self.name()
                        )));
                    }

                    Ok(ScalarValue::Float32(Some($similarity_method(lhs, rhs))))
                };

                // NOTE(review): presumably `VectorCalculator` decodes the raw
                // arguments into `f32` slices before calling `func` — confirm
                // against its definition.
                let runner = $crate::scalars::vector::VectorCalculator {
                    name: self.name(),
                    func: distance,
                };
                runner.invoke_with_vectors(args)
            }
        }

        impl Display for $StructName {
            /// Writes the function name in ASCII upper case.
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.write_str(&$display_name.to_ascii_uppercase())
            }
        }
    };
}
115
// Cosine variant: `CosDistanceFunction`, named `VEC_COS_DISTANCE`, kernel `cos::cos`.
define_distance_function!(CosDistanceFunction, VEC_COS_DISTANCE, cos::cos);
// Squared-L2 variant: `L2SqDistanceFunction`, named `VEC_L2SQ_DISTANCE`, kernel `l2sq::l2sq`.
define_distance_function!(L2SqDistanceFunction, VEC_L2SQ_DISTANCE, l2sq::l2sq);
// Dot-product variant: `DotProductFunction`, named `VEC_DOT_PRODUCT`, kernel `dot::dot`.
define_distance_function!(DotProductFunction, VEC_DOT_PRODUCT, dot::dot);
119
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion::arrow::array::{Array, ArrayRef, AsArray, BinaryArray, StringViewArray};
    use datafusion::arrow::datatypes::Float32Type;
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Expected null-ness of the four fixture rows: the first two rows have
    /// both operands present, the last two each miss one operand.
    const EXPECTED_NULLS: [bool; 4] = [false, false, true, true];

    /// One default-constructed instance of every generated distance function.
    fn all_distance_functions() -> [Box<dyn Function>; 3] {
        [
            Box::new(CosDistanceFunction::default()),
            Box::new(L2SqDistanceFunction::default()),
            Box::new(DotProductFunction::default()),
        ]
    }

    /// Invokes `func` on `args` (all passed as array columns) and materialises
    /// the result as an array with as many rows as the first argument.
    fn test_invoke(func: &dyn Function, args: &[ArrayRef]) -> datafusion_common::Result<ArrayRef> {
        let number_rows = args[0].len();
        let columns: Vec<ColumnarValue> = args.iter().cloned().map(ColumnarValue::Array).collect();
        let invocation = ScalarFunctionArgs {
            args: columns,
            arg_fields: vec![],
            number_rows,
            // NOTE(review): declared non-nullable although the tests below
            // expect null rows — presumably ignored by the functions; confirm.
            return_field: Arc::new(Field::new("x", DataType::Float32, false)),
            config_options: Arc::new(ConfigOptions::new()),
        };
        let output = func.invoke_with_args(invocation)?;
        output.to_array(number_rows)
    }

    /// String-encoded fixture whose third row is null.
    fn string_left() -> ArrayRef {
        Arc::new(StringViewArray::from(vec![
            Some("[0.0, 1.0]"),
            Some("[1.0, 0.0]"),
            None,
            Some("[1.0, 0.0]"),
        ]))
    }

    /// String-encoded fixture whose fourth row is null.
    fn string_right() -> ArrayRef {
        Arc::new(StringViewArray::from(vec![
            Some("[0.0, 1.0]"),
            Some("[0.0, 1.0]"),
            Some("[0.0, 1.0]"),
            None,
        ]))
    }

    /// Wraps optional byte rows into a binary column.
    fn binary_column(rows: Vec<Option<Vec<u8>>>) -> ArrayRef {
        Arc::new(BinaryArray::from_iter(rows))
    }

    /// Binary fixture whose third row is null. The byte rows are the
    /// little-endian `f32` encodings of the values used by `string_left`.
    fn binary_left() -> ArrayRef {
        binary_column(vec![
            Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
            Some(vec![0, 0, 128, 63, 0, 0, 0, 0]),
            None,
            Some(vec![0, 0, 128, 63, 0, 0, 0, 0]),
        ])
    }

    /// Binary fixture whose fourth row is null. The byte rows are the
    /// little-endian `f32` encodings of the values used by `string_right`.
    fn binary_right() -> ArrayRef {
        binary_column(vec![
            Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
            Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
            Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
            None,
        ])
    }

    /// Runs `func` on `(left, right)` and then on `(right, left)`, checking
    /// each time that the result rows match `EXPECTED_NULLS`.
    fn assert_null_pattern(func: &dyn Function, left: &ArrayRef, right: &ArrayRef) {
        for (first, second) in [(left, right), (right, left)] {
            let output = test_invoke(func, &[first.clone(), second.clone()]).unwrap();
            let distances = output.as_primitive::<Float32Type>();

            for (row, expect_null) in EXPECTED_NULLS.iter().copied().enumerate() {
                assert_eq!(distances.is_null(row), expect_null, "row {row}");
            }
        }
    }

    #[test]
    fn test_distance_string_string() {
        for func in all_distance_functions() {
            assert_null_pattern(func.as_ref(), &string_left(), &string_right());
        }
    }

    #[test]
    fn test_distance_binary_binary() {
        for func in all_distance_functions() {
            assert_null_pattern(func.as_ref(), &binary_left(), &binary_right());
        }
    }

    #[test]
    fn test_distance_string_binary() {
        for func in all_distance_functions() {
            assert_null_pattern(func.as_ref(), &string_left(), &binary_right());
        }
    }

    #[test]
    fn test_invalid_vector_length() {
        for func in all_distance_functions() {
            // String inputs with one element versus two elements.
            let short: ArrayRef = Arc::new(StringViewArray::from(vec!["[1.0]"]));
            let long: ArrayRef = Arc::new(StringViewArray::from(vec!["[1.0, 1.0]"]));
            assert!(test_invoke(func.as_ref(), &[short, long]).is_err());

            // Binary inputs of four bytes versus eight bytes.
            let short: ArrayRef = Arc::new(BinaryArray::from_iter_values(vec![vec![0, 0, 128, 63]]));
            let long: ArrayRef = Arc::new(BinaryArray::from_iter_values(vec![vec![
                0, 0, 128, 63, 0, 0, 0, 64,
            ]]));
            assert!(test_invoke(func.as_ref(), &[short, long]).is_err());
        }
    }
}