1mod cos;
16mod dot;
17mod l2sq;
18
19use std::borrow::Cow;
20use std::fmt::Display;
21
22use common_query::error::{InvalidFuncArgsSnafu, Result};
23use datafusion_expr::Signature;
24use datatypes::arrow::datatypes::DataType;
25use datatypes::scalars::ScalarVectorBuilder;
26use datatypes::vectors::{Float32VectorBuilder, MutableVector, VectorRef};
27use snafu::ensure;
28
29use crate::function::{Function, FunctionContext};
30use crate::helper;
31use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const};
32
/// Generates a two-argument vector distance function.
///
/// The expansion defines a public unit struct named `$StructName` together with its
/// `Function` and `Display` implementations.
///
/// Parameters:
/// * `$StructName` - identifier of the generated struct.
/// * `$display_name` - expression returned by `Function::name`; `Display` prints its
///   ASCII-upper-cased form.
/// * `$similarity_method` - path of the kernel, called as `$similarity_method(lhs, rhs)`
///   for every row in which both arguments decode to a vector. Its result is pushed into a
///   `Float32VectorBuilder`.
macro_rules! define_distance_function {
    ($StructName:ident, $display_name:expr, $similarity_method:path) => {
        /// Distance function over two vector arguments, generated by
        /// `define_distance_function!`.
        #[derive(Debug, Clone, Default)]
        pub struct $StructName;

        impl Function for $StructName {
            /// Returns the display name this function was defined with.
            fn name(&self) -> &str {
                $display_name
            }

            /// The result type is `Float32` regardless of the input types.
            fn return_type(&self, _: &[DataType]) -> Result<DataType> {
                Ok(DataType::Float32)
            }

            /// Signature built by `helper::one_of_sigs2` from `Utf8`/`Binary` for each of the
            /// two arguments (presumably every combination is accepted; see the helper).
            fn signature(&self) -> Signature {
                helper::one_of_sigs2(
                    vec![DataType::Utf8, DataType::Binary],
                    vec![DataType::Utf8, DataType::Binary],
                )
            }

            /// Evaluates the distance row by row.
            ///
            /// The number of output rows is taken from the first column. A row is NULL when
            /// either argument does not yield a vector for that row.
            ///
            /// # Errors
            ///
            /// Returns an `InvalidFuncArgs` error when `columns` does not contain exactly two
            /// entries, or when the two vectors of a row differ in length. Errors from
            /// `as_veclit` / `as_veclit_if_const` are propagated unchanged.
            fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
                ensure!(
                    columns.len() == 2,
                    InvalidFuncArgsSnafu {
                        err_msg: format!(
                            "The length of the args is not correct, expect exactly two, have: {}",
                            columns.len()
                        ),
                    }
                );
                let arg0 = &columns[0];
                let arg1 = &columns[1];

                let size = arg0.len();
                let mut result = Float32VectorBuilder::with_capacity(size);
                if size == 0 {
                    return Ok(result.to_vector());
                }

                // Convert each argument once up front; when this yields a value it is borrowed
                // for every row below instead of converting the column row by row.
                let arg0_const = as_veclit_if_const(arg0)?;
                let arg1_const = as_veclit_if_const(arg1)?;

                for i in 0..size {
                    let vec0 = match arg0_const.as_ref() {
                        Some(a) => Some(Cow::Borrowed(a.as_ref())),
                        None => as_veclit(arg0.get_ref(i))?,
                    };
                    let vec1 = match arg1_const.as_ref() {
                        Some(b) => Some(Cow::Borrowed(b.as_ref())),
                        None => as_veclit(arg1.get_ref(i))?,
                    };

                    if let (Some(vec0), Some(vec1)) = (vec0, vec1) {
                        // The kernel is only ever called with vectors of equal length.
                        ensure!(
                            vec0.len() == vec1.len(),
                            InvalidFuncArgsSnafu {
                                err_msg: format!(
                                    "The length of the vectors must match to calculate distance, have: {} vs {}",
                                    vec0.len(),
                                    vec1.len()
                                ),
                            }
                        );

                        let d = $similarity_method(vec0.as_ref(), vec1.as_ref());
                        result.push(Some(d));
                    } else {
                        // At least one side produced no vector for this row.
                        result.push_null();
                    }
                }

                Ok(result.to_vector())
            }
        }

        impl Display for $StructName {
            /// Writes the display name converted to ASCII upper case.
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "{}", $display_name.to_ascii_uppercase())
            }
        }
    }
}
120
// Concrete distance functions. Each invocation expands to a unit struct implementing
// `Function` and `Display`: the string literal is what `name()` returns, and the last
// argument is the kernel from the matching sibling module (`cos`, `l2sq`, `dot`) that is
// called for every row where both inputs yield a vector.
define_distance_function!(CosDistanceFunction, "vec_cos_distance", cos::cos);
define_distance_function!(L2SqDistanceFunction, "vec_l2sq_distance", l2sq::l2sq);
define_distance_function!(DotProductFunction, "vec_dot_product", dot::dot);
124
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::vectors::{BinaryVector, ConstantVector, StringVector};

    use super::*;

    /// Every function generated by `define_distance_function!` in this file, boxed so that
    /// each scenario can be run against all of them.
    fn distance_functions() -> [Box<dyn Function>; 3] {
        [
            Box::new(CosDistanceFunction {}) as Box<dyn Function>,
            Box::new(L2SqDistanceFunction {}) as Box<dyn Function>,
            Box::new(DotProductFunction {}) as Box<dyn Function>,
        ]
    }

    /// Four-row string column whose row 2 is NULL.
    fn string_column_null_at_2() -> VectorRef {
        Arc::new(StringVector::from(vec![
            Some("[0.0, 1.0]"),
            Some("[1.0, 0.0]"),
            None,
            Some("[1.0, 0.0]"),
        ])) as VectorRef
    }

    /// Four-row string column whose row 3 is NULL.
    fn string_column_null_at_3() -> VectorRef {
        Arc::new(StringVector::from(vec![
            Some("[0.0, 1.0]"),
            Some("[0.0, 1.0]"),
            Some("[0.0, 1.0]"),
            None,
        ])) as VectorRef
    }

    /// Four-row binary column whose row 2 is NULL.
    ///
    /// `[0, 0, 128, 63]` is the little-endian IEEE-754 encoding of `1.0f32`, so each row is
    /// eight bytes (presumably decoded as two packed `f32`s; see `impl_conv` to confirm).
    fn binary_column_null_at_2() -> VectorRef {
        Arc::new(BinaryVector::from(vec![
            Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
            Some(vec![0, 0, 128, 63, 0, 0, 0, 0]),
            None,
            Some(vec![0, 0, 128, 63, 0, 0, 0, 0]),
        ])) as VectorRef
    }

    /// Four-row binary column whose row 3 is NULL; same byte layout as
    /// `binary_column_null_at_2`.
    fn binary_column_null_at_3() -> VectorRef {
        Arc::new(BinaryVector::from(vec![
            Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
            Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
            Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
            None,
        ])) as VectorRef
    }

    /// Runs `func` on the pair `(lhs, rhs)` with a default context.
    fn eval_pair(func: &dyn Function, lhs: &VectorRef, rhs: &VectorRef) -> Result<VectorRef> {
        func.eval(&FunctionContext::default(), &[lhs.clone(), rhs.clone()])
    }

    /// Evaluates `func(lhs, rhs)`, panics if evaluation fails, and checks that rows 0..4 of
    /// the result are NULL exactly where `expected_null` says so.
    ///
    /// The failure message names the function and the row, since every test loops over all
    /// distance functions.
    fn assert_null_rows(
        func: &dyn Function,
        lhs: &VectorRef,
        rhs: &VectorRef,
        expected_null: [bool; 4],
    ) {
        let result = eval_pair(func, lhs, rhs).unwrap();
        for (row, &expect_null) in expected_null.iter().enumerate() {
            assert_eq!(
                result.get(row).is_null(),
                expect_null,
                "{}: unexpected nullness at row {}",
                func.name(),
                row
            );
        }
    }

    /// Same as `assert_null_rows`, but also checks the swapped argument order, which must
    /// give the same NULL pattern.
    fn assert_null_rows_both_orders(
        func: &dyn Function,
        a: &VectorRef,
        b: &VectorRef,
        expected_null: [bool; 4],
    ) {
        assert_null_rows(func, a, b, expected_null);
        assert_null_rows(func, b, a, expected_null);
    }

    #[test]
    fn test_distance_string_string() {
        for boxed in distance_functions() {
            let func: &dyn Function = &*boxed;
            let vec1 = string_column_null_at_2();
            let vec2 = string_column_null_at_3();

            assert_null_rows_both_orders(func, &vec1, &vec2, [false, false, true, true]);
        }
    }

    #[test]
    fn test_distance_binary_binary() {
        for boxed in distance_functions() {
            let func: &dyn Function = &*boxed;
            let vec1 = binary_column_null_at_2();
            let vec2 = binary_column_null_at_3();

            assert_null_rows_both_orders(func, &vec1, &vec2, [false, false, true, true]);
        }
    }

    #[test]
    fn test_distance_string_binary() {
        for boxed in distance_functions() {
            let func: &dyn Function = &*boxed;
            let vec1 = string_column_null_at_2();
            let vec2 = binary_column_null_at_3();

            assert_null_rows_both_orders(func, &vec1, &vec2, [false, false, true, true]);
        }
    }

    #[test]
    fn test_distance_const_string() {
        for boxed in distance_functions() {
            let func: &dyn Function = &*boxed;
            // A constant column of length 4 wrapping a single string literal.
            let const_str = Arc::new(ConstantVector::new(
                Arc::new(StringVector::from(vec!["[0.0, 1.0]"])),
                4,
            )) as VectorRef;
            let vec1 = string_column_null_at_2();
            let vec2 = binary_column_null_at_3();

            // With a non-null constant on one side, only the NULL row of the other side
            // produces a NULL result.
            assert_null_rows_both_orders(func, &const_str, &vec1, [false, false, true, false]);
            assert_null_rows_both_orders(func, &const_str, &vec2, [false, false, false, true]);
        }
    }

    #[test]
    fn test_invalid_vector_length() {
        for boxed in distance_functions() {
            let func: &dyn Function = &*boxed;

            let vec1 = Arc::new(StringVector::from(vec!["[1.0]"])) as VectorRef;
            let vec2 = Arc::new(StringVector::from(vec!["[1.0, 1.0]"])) as VectorRef;
            assert!(
                eval_pair(func, &vec1, &vec2).is_err(),
                "{}: string vectors of different lengths must be rejected",
                func.name()
            );

            let vec1 = Arc::new(BinaryVector::from(vec![vec![0, 0, 128, 63]])) as VectorRef;
            let vec2 =
                Arc::new(BinaryVector::from(vec![vec![0, 0, 128, 63, 0, 0, 0, 64]])) as VectorRef;
            assert!(
                eval_pair(func, &vec1, &vec2).is_err(),
                "{}: binary vectors of different lengths must be rejected",
                func.name()
            );
        }
    }
}