1mod cos;
16mod dot;
17mod l2sq;
18
19use std::borrow::Cow;
20use std::fmt::Display;
21
22use common_query::error::{InvalidFuncArgsSnafu, Result};
23use common_query::prelude::Signature;
24use datatypes::prelude::ConcreteDataType;
25use datatypes::scalars::ScalarVectorBuilder;
26use datatypes::vectors::{Float32VectorBuilder, MutableVector, VectorRef};
27use snafu::ensure;
28
29use crate::function::{Function, FunctionContext};
30use crate::helper;
31use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const};
32
/// Generates a unit struct named `$StructName` together with its [`Function`]
/// and [`Display`] implementations.
///
/// * `$StructName` - identifier of the generated struct.
/// * `$display_name` - name reported by `Function::name`; `Display` prints it
///   upper-cased.
/// * `$similarity_method` - path of the routine invoked with the two decoded
///   vectors of each row; its result is pushed into the `Float32` output.
macro_rules! define_distance_function {
    ($StructName:ident, $display_name:expr, $similarity_method:path) => {
        /// Two-argument scalar function producing one `Float32` value per row.
        ///
        /// Each argument is declared as either a string or a binary column. A row
        /// for which either side yields no vector produces a null result.
        #[derive(Debug, Clone, Default)]
        pub struct $StructName;

        impl Function for $StructName {
            fn name(&self) -> &str {
                $display_name
            }

            fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
                Ok(ConcreteDataType::float32_datatype())
            }

            fn signature(&self) -> Signature {
                // Both positions accept the same set of types.
                let accepted_types = || {
                    vec![
                        ConcreteDataType::string_datatype(),
                        ConcreteDataType::binary_datatype(),
                    ]
                };
                helper::one_of_sigs2(accepted_types(), accepted_types())
            }

            /// Evaluates the function row by row over exactly two input columns.
            ///
            /// # Errors
            ///
            /// Returns an invalid-arguments error when the number of columns is not
            /// two, or when the two vectors of a row differ in length. Errors from
            /// `as_veclit` / `as_veclit_if_const` are propagated unchanged.
            fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
                ensure!(
                    columns.len() == 2,
                    InvalidFuncArgsSnafu {
                        err_msg: format!(
                            "The length of the args is not correct, expect exactly two, have: {}",
                            columns.len()
                        ),
                    }
                );
                let (lhs_col, rhs_col) = (&columns[0], &columns[1]);

                // The row count is taken from the first column only.
                // NOTE(review): the second column is indexed with the same row numbers,
                // so this presumably relies on equally sized columns — confirm with callers.
                let row_count = lhs_col.len();
                let mut distances = Float32VectorBuilder::with_capacity(row_count);
                if row_count == 0 {
                    return Ok(distances.to_vector());
                }

                // Decode constant columns once, ahead of the loop, instead of per row.
                let lhs_const = as_veclit_if_const(lhs_col)?;
                let rhs_const = as_veclit_if_const(rhs_col)?;

                for row in 0..row_count {
                    let lhs = match &lhs_const {
                        Some(lit) => Some(Cow::Borrowed(lit.as_ref())),
                        None => as_veclit(lhs_col.get_ref(row))?,
                    };
                    let rhs = match &rhs_const {
                        Some(lit) => Some(Cow::Borrowed(lit.as_ref())),
                        None => as_veclit(rhs_col.get_ref(row))?,
                    };

                    match (lhs, rhs) {
                        (Some(lhs), Some(rhs)) => {
                            // The similarity routine is only called on equally long vectors.
                            ensure!(
                                lhs.len() == rhs.len(),
                                InvalidFuncArgsSnafu {
                                    err_msg: format!(
                                        "The length of the vectors must match to calculate distance, have: {} vs {}",
                                        lhs.len(),
                                        rhs.len()
                                    ),
                                }
                            );

                            distances.push(Some($similarity_method(lhs.as_ref(), rhs.as_ref())));
                        }
                        // At least one side produced no vector for this row.
                        _ => distances.push_null(),
                    }
                }

                Ok(distances.to_vector())
            }
        }

        impl Display for $StructName {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.write_str(&$display_name.to_ascii_uppercase())
            }
        }
    };
}
126
// `CosDistanceFunction`: named `vec_cos_distance`, per-row value computed by `cos::cos`.
define_distance_function!(CosDistanceFunction, "vec_cos_distance", cos::cos);
// `L2SqDistanceFunction`: named `vec_l2sq_distance`, per-row value computed by `l2sq::l2sq`.
define_distance_function!(L2SqDistanceFunction, "vec_l2sq_distance", l2sq::l2sq);
// `DotProductFunction`: named `vec_dot_product`, per-row value computed by `dot::dot`.
define_distance_function!(DotProductFunction, "vec_dot_product", dot::dot);
130
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::vectors::{BinaryVector, ConstantVector, StringVector};

    use super::*;

    // Binary fixtures used alongside the string fixtures below.
    // NOTE(review): presumably the little-endian f32 encodings of `[0.0, 1.0]` and
    // `[1.0, 0.0]` respectively — confirm against the binary decoding in `as_veclit`.
    const BIN_A: [u8; 8] = [0, 0, 0, 0, 0, 0, 128, 63];
    const BIN_B: [u8; 8] = [0, 0, 128, 63, 0, 0, 0, 0];

    /// All distance functions generated in the parent module, as trait objects.
    fn all_distance_functions() -> [Box<dyn Function>; 3] {
        [
            Box::new(CosDistanceFunction {}),
            Box::new(L2SqDistanceFunction {}),
            Box::new(DotProductFunction {}),
        ]
    }

    /// Four string rows whose third row (index 2) is null.
    fn string_rows_null_at_2() -> VectorRef {
        Arc::new(StringVector::from(vec![
            Some("[0.0, 1.0]"),
            Some("[1.0, 0.0]"),
            None,
            Some("[1.0, 0.0]"),
        ]))
    }

    /// Four string rows whose last row (index 3) is null.
    fn string_rows_null_at_3() -> VectorRef {
        Arc::new(StringVector::from(vec![
            Some("[0.0, 1.0]"),
            Some("[0.0, 1.0]"),
            Some("[0.0, 1.0]"),
            None,
        ]))
    }

    /// Four binary rows whose third row (index 2) is null.
    fn binary_rows_null_at_2() -> VectorRef {
        Arc::new(BinaryVector::from(vec![
            Some(BIN_A.to_vec()),
            Some(BIN_B.to_vec()),
            None,
            Some(BIN_B.to_vec()),
        ]))
    }

    /// Four binary rows whose last row (index 3) is null.
    fn binary_rows_null_at_3() -> VectorRef {
        Arc::new(BinaryVector::from(vec![
            Some(BIN_A.to_vec()),
            Some(BIN_A.to_vec()),
            Some(BIN_A.to_vec()),
            None,
        ]))
    }

    /// Evaluates `func` on `(lhs, rhs)` and unwraps the successful result.
    fn eval_ok(func: &dyn Function, lhs: &VectorRef, rhs: &VectorRef) -> VectorRef {
        func.eval(&FunctionContext::default(), &[lhs.clone(), rhs.clone()])
            .unwrap()
    }

    /// Checks the null-ness of the first four rows of `result`.
    fn assert_nulls(result: &VectorRef, expected_nulls: [bool; 4]) {
        for (row, &expect_null) in expected_nulls.iter().enumerate() {
            assert_eq!(
                result.get(row).is_null(),
                expect_null,
                "unexpected null-ness at row {}",
                row
            );
        }
    }

    /// Evaluates `(a, b)` and then `(b, a)`, expecting the same null pattern both times.
    fn assert_nulls_both_orders(
        func: &dyn Function,
        a: &VectorRef,
        b: &VectorRef,
        expected_nulls: [bool; 4],
    ) {
        assert_nulls(&eval_ok(func, a, b), expected_nulls);
        assert_nulls(&eval_ok(func, b, a), expected_nulls);
    }

    #[test]
    fn test_distance_string_string() {
        for func in all_distance_functions() {
            let lhs = string_rows_null_at_2();
            let rhs = string_rows_null_at_3();

            assert_nulls_both_orders(&*func, &lhs, &rhs, [false, false, true, true]);
        }
    }

    #[test]
    fn test_distance_binary_binary() {
        for func in all_distance_functions() {
            let lhs = binary_rows_null_at_2();
            let rhs = binary_rows_null_at_3();

            assert_nulls_both_orders(&*func, &lhs, &rhs, [false, false, true, true]);
        }
    }

    #[test]
    fn test_distance_string_binary() {
        for func in all_distance_functions() {
            let strings = string_rows_null_at_2();
            let binaries = binary_rows_null_at_3();

            assert_nulls_both_orders(&*func, &strings, &binaries, [false, false, true, true]);
        }
    }

    #[test]
    fn test_distance_const_string() {
        for func in all_distance_functions() {
            // A constant column repeating one string literal for four rows.
            let constant: VectorRef = Arc::new(ConstantVector::new(
                Arc::new(StringVector::from(vec!["[0.0, 1.0]"])),
                4,
            ));
            let strings = string_rows_null_at_2();
            let binaries = binary_rows_null_at_3();

            // Only the rows that are null in the non-constant column come back null.
            assert_nulls_both_orders(&*func, &constant, &strings, [false, false, true, false]);
            assert_nulls_both_orders(&*func, &constant, &binaries, [false, false, false, true]);
        }
    }

    #[test]
    fn test_invalid_vector_length() {
        for func in all_distance_functions() {
            // String inputs holding one element versus two elements.
            let short: VectorRef = Arc::new(StringVector::from(vec!["[1.0]"]));
            let long: VectorRef = Arc::new(StringVector::from(vec!["[1.0, 1.0]"]));
            let outcome = func.eval(&FunctionContext::default(), &[short, long]);
            assert!(outcome.is_err());

            // Binary inputs holding four bytes versus eight bytes.
            let short: VectorRef = Arc::new(BinaryVector::from(vec![vec![0, 0, 128, 63]]));
            let long: VectorRef = Arc::new(BinaryVector::from(vec![vec![
                0, 0, 128, 63, 0, 0, 0, 64,
            ]]));
            let outcome = func.eval(&FunctionContext::default(), &[short, long]);
            assert!(outcome.is_err());
        }
    }
}