1mod cos;
16mod dot;
17mod l2sq;
18
19use std::borrow::Cow;
20use std::fmt::Display;
21
22use datafusion::logical_expr::ColumnarValue;
23use datafusion_common::ScalarValue;
24use datafusion_expr::{ScalarFunctionArgs, Signature};
25use datatypes::arrow::datatypes::DataType;
26
27use crate::function::Function;
28use crate::helper;
29
/// Generates a two-argument vector scalar function type.
///
/// * `$StructName` - identifier of the generated struct.
/// * `$display_name` - expression returned by `Function::name`; its ASCII-uppercased
///   form is what the generated `Display` impl writes.
/// * `$similarity_method` - path of the function applied to the two decoded `f32`
///   vectors to produce the result.
macro_rules! define_distance_function {
    ($StructName:ident, $display_name:expr, $similarity_method:path) => {
        /// Scalar function that combines two vectors into a single `Float32` value.
        #[derive(Debug, Clone)]
        pub(crate) struct $StructName {
            signature: Signature,
        }

        impl Default for $StructName {
            fn default() -> Self {
                // Both argument positions are given the same list of data types.
                let accepted = vec![
                    DataType::Utf8,
                    DataType::Utf8View,
                    DataType::Binary,
                    DataType::BinaryView,
                ];
                let signature = helper::one_of_sigs2(accepted.clone(), accepted);
                Self { signature }
            }
        }

        impl Function for $StructName {
            fn name(&self) -> &str {
                $display_name
            }

            fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
                Ok(DataType::Float32)
            }

            fn signature(&self) -> &Signature {
                &self.signature
            }

            /// Evaluates the function through `VectorCalculator`.
            ///
            /// The per-pair callback yields a null `Float32` when either operand is
            /// missing and an execution error when the operands differ in length.
            fn invoke_with_args(
                &self,
                args: ScalarFunctionArgs,
            ) -> datafusion_common::Result<ColumnarValue> {
                let per_pair = |left: &Option<Cow<[f32]>>,
                                right: &Option<Cow<[f32]>>|
                 -> datafusion_common::Result<ScalarValue> {
                    match (left, right) {
                        // Both present but of different lengths: reject.
                        (Some(lhs), Some(rhs)) if lhs.len() != rhs.len() => {
                            Err(datafusion_common::DataFusionError::Execution(format!(
                                "vectors length not match: {}",
                                self.name()
                            )))
                        }
                        // Both present with matching lengths: compute the value.
                        (Some(lhs), Some(rhs)) => {
                            Ok(ScalarValue::Float32(Some($similarity_method(lhs, rhs))))
                        }
                        // At least one side is missing: the result is null.
                        _ => Ok(ScalarValue::Float32(None)),
                    }
                };

                let vector_calculator = $crate::scalars::vector::VectorCalculator {
                    name: self.name(),
                    func: per_pair,
                };
                vector_calculator.invoke_with_vectors(args)
            }
        }

        impl Display for $StructName {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                let upper_name = $display_name.to_ascii_uppercase();
                f.write_str(&upper_name)
            }
        }
    };
}
111
// Concrete function types generated by `define_distance_function!`.
// Each invocation supplies the struct name, the function's name string, and
// the path of the routine that computes the value from two `f32` vectors.
define_distance_function!(CosDistanceFunction, "vec_cos_distance", cos::cos);
define_distance_function!(L2SqDistanceFunction, "vec_l2sq_distance", l2sq::l2sq);
define_distance_function!(DotProductFunction, "vec_dot_product", dot::dot);
115
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion::arrow::array::{Array, ArrayRef, AsArray, BinaryArray, StringViewArray};
    use datafusion::arrow::datatypes::Float32Type;
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Calls `func` with `args` as array columns and returns the result as an array.
    ///
    /// The row count is taken from the first argument.
    fn test_invoke(func: &dyn Function, args: &[ArrayRef]) -> datafusion_common::Result<ArrayRef> {
        let number_rows = args[0].len();
        let columns: Vec<ColumnarValue> = args.iter().cloned().map(ColumnarValue::Array).collect();
        let invocation = ScalarFunctionArgs {
            args: columns,
            arg_fields: vec![],
            number_rows,
            return_field: Arc::new(Field::new("x", DataType::Float32, false)),
            config_options: Arc::new(ConfigOptions::new()),
        };
        let output = func.invoke_with_args(invocation)?;
        output.to_array(number_rows)
    }

    /// One default instance of every function generated in this module.
    fn distance_functions() -> [Box<dyn Function>; 3] {
        [
            Box::new(CosDistanceFunction::default()),
            Box::new(L2SqDistanceFunction::default()),
            Box::new(DotProductFunction::default()),
        ]
    }

    /// Textual left operand: row 2 is null.
    fn text_lhs() -> ArrayRef {
        Arc::new(StringViewArray::from(vec![
            Some("[0.0, 1.0]"),
            Some("[1.0, 0.0]"),
            None,
            Some("[1.0, 0.0]"),
        ]))
    }

    /// Textual right operand: row 3 is null.
    fn text_rhs() -> ArrayRef {
        Arc::new(StringViewArray::from(vec![
            Some("[0.0, 1.0]"),
            Some("[0.0, 1.0]"),
            Some("[0.0, 1.0]"),
            None,
        ]))
    }

    /// Binary left operand: row 2 is null.
    fn bytes_lhs() -> ArrayRef {
        Arc::new(BinaryArray::from_iter(vec![
            Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
            Some(vec![0, 0, 128, 63, 0, 0, 0, 0]),
            None,
            Some(vec![0, 0, 128, 63, 0, 0, 0, 0]),
        ]))
    }

    /// Binary right operand: row 3 is null.
    fn bytes_rhs() -> ArrayRef {
        Arc::new(BinaryArray::from_iter(vec![
            Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
            Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
            Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
            None,
        ]))
    }

    /// Invokes `func` on `(lhs, rhs)` and then on `(rhs, lhs)`, asserting each time
    /// that rows 0 and 1 are non-null while rows 2 and 3 are null.
    fn assert_null_propagation(func: &dyn Function, lhs: &ArrayRef, rhs: &ArrayRef) {
        for (first, second) in [(lhs, rhs), (rhs, lhs)] {
            let output = test_invoke(func, &[first.clone(), second.clone()]).unwrap();
            let distances = output.as_primitive::<Float32Type>();

            assert!(!distances.is_null(0));
            assert!(!distances.is_null(1));
            assert!(distances.is_null(2));
            assert!(distances.is_null(3));
        }
    }

    #[test]
    fn test_distance_string_string() {
        for func in distance_functions() {
            let (lhs, rhs) = (text_lhs(), text_rhs());
            assert_null_propagation(func.as_ref(), &lhs, &rhs);
        }
    }

    #[test]
    fn test_distance_binary_binary() {
        for func in distance_functions() {
            let (lhs, rhs) = (bytes_lhs(), bytes_rhs());
            assert_null_propagation(func.as_ref(), &lhs, &rhs);
        }
    }

    #[test]
    fn test_distance_string_binary() {
        for func in distance_functions() {
            let (lhs, rhs) = (text_lhs(), bytes_rhs());
            assert_null_propagation(func.as_ref(), &lhs, &rhs);
        }
    }

    #[test]
    fn test_invalid_vector_length() {
        for func in distance_functions() {
            // One-element versus two-element vectors, given as text.
            let text_pair: [ArrayRef; 2] = [
                Arc::new(StringViewArray::from(vec!["[1.0]"])),
                Arc::new(StringViewArray::from(vec!["[1.0, 1.0]"])),
            ];
            assert!(test_invoke(func.as_ref(), &text_pair).is_err());

            // Four bytes versus eight bytes, given as binary.
            let bytes_pair: [ArrayRef; 2] = [
                Arc::new(BinaryArray::from_iter_values(vec![vec![0, 0, 128, 63]])),
                Arc::new(BinaryArray::from_iter_values(vec![vec![
                    0, 0, 128, 63, 0, 0, 0, 64,
                ]])),
            ];
            assert!(test_invoke(func.as_ref(), &bytes_pair).is_err());
        }
    }
}