common_function/scalars/vector/distance.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15mod cos;
16mod dot;
17mod l2sq;
18
19use std::borrow::Cow;
20use std::fmt::Display;
21
22use common_query::error::{InvalidFuncArgsSnafu, Result};
23use common_query::prelude::Signature;
24use datatypes::prelude::ConcreteDataType;
25use datatypes::scalars::ScalarVectorBuilder;
26use datatypes::vectors::{Float32VectorBuilder, MutableVector, VectorRef};
27use snafu::ensure;
28
29use crate::function::{Function, FunctionContext};
30use crate::helper;
31use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const};
32
/// Defines a unit-struct scalar [`Function`] that applies a vector kernel to each row.
///
/// * `$StructName` — name of the generated unit struct.
/// * `$display_name` — string returned by `Function::name`; its ASCII-uppercased form is what
///   the generated `Display` impl writes.
/// * `$similarity_method` — path to a kernel that takes two equal-length vector slices and
///   returns an `f32`. It is only invoked after the two operands have been checked to have the
///   same length.
///
/// The generated function takes exactly two arguments, each of which may be a string or a
/// binary column (any combination). It returns a `Float32` column holding the kernel result
/// per row, or NULL for rows where either operand yields no vector (e.g. a NULL input).
///
/// # Errors
///
/// `eval` returns `InvalidFuncArgs` when the argument count is not two, when the two columns
/// have different row counts, or when a row's two vectors have different lengths. Errors from
/// decoding a vector literal are propagated unchanged.
macro_rules! define_distance_function {
    ($StructName:ident, $display_name:expr, $similarity_method:path) => {
        /// Scalar function generated by `define_distance_function!`: evaluates a vector kernel
        /// (a distance or similarity measure) on each pair of input rows.
        #[derive(Debug, Clone, Default)]
        pub struct $StructName;

        impl Function for $StructName {
            fn name(&self) -> &str {
                $display_name
            }

            fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
                Ok(ConcreteDataType::float32_datatype())
            }

            fn signature(&self) -> Signature {
                // Each argument may independently be a string or a binary vector literal.
                helper::one_of_sigs2(
                    vec![
                        ConcreteDataType::string_datatype(),
                        ConcreteDataType::binary_datatype(),
                    ],
                    vec![
                        ConcreteDataType::string_datatype(),
                        ConcreteDataType::binary_datatype(),
                    ],
                )
            }

            fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
                ensure!(
                    columns.len() == 2,
                    InvalidFuncArgsSnafu {
                        err_msg: format!(
                            "The length of the args is not correct, expect exactly two, have: {}",
                            columns.len()
                        ),
                    }
                );
                let arg0 = &columns[0];
                let arg1 = &columns[1];

                // Both columns must describe the same rows: otherwise `arg1.get_ref(i)` below
                // could index past the end of a shorter column, or trailing rows of a longer
                // column would be silently ignored.
                ensure!(
                    arg0.len() == arg1.len(),
                    InvalidFuncArgsSnafu {
                        err_msg: format!(
                            "The row counts of the args must match, have: {} vs {}",
                            arg0.len(),
                            arg1.len()
                        ),
                    }
                );

                let size = arg0.len();
                let mut result = Float32VectorBuilder::with_capacity(size);
                if size == 0 {
                    return Ok(result.to_vector());
                }

                // `as_veclit_if_const` presumably yields `Some` only for constant columns; such a
                // value is decoded once and reused for every row instead of per row.
                let arg0_const = as_veclit_if_const(arg0)?;
                let arg1_const = as_veclit_if_const(arg1)?;

                for i in 0..size {
                    let vec0 = match arg0_const.as_ref() {
                        Some(a) => Some(Cow::Borrowed(a.as_ref())),
                        None => as_veclit(arg0.get_ref(i))?,
                    };
                    let vec1 = match arg1_const.as_ref() {
                        Some(b) => Some(Cow::Borrowed(b.as_ref())),
                        None => as_veclit(arg1.get_ref(i))?,
                    };

                    if let (Some(vec0), Some(vec1)) = (vec0, vec1) {
                        ensure!(
                            vec0.len() == vec1.len(),
                            InvalidFuncArgsSnafu {
                                err_msg: format!(
                                    "The length of the vectors must match to calculate distance, have: {} vs {}",
                                    vec0.len(),
                                    vec1.len()
                                ),
                            }
                        );

                        // The kernel may rely on equal lengths; that was checked just above.
                        let d = $similarity_method(vec0.as_ref(), vec1.as_ref());
                        result.push(Some(d));
                    } else {
                        // At least one operand has no vector for this row: the result is NULL.
                        result.push_null();
                    }
                }

                Ok(result.to_vector())
            }
        }

        impl Display for $StructName {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "{}", $display_name.to_ascii_uppercase())
            }
        }
    };
}
126
127define_distance_function!(CosDistanceFunction, "vec_cos_distance", cos::cos);
128define_distance_function!(L2SqDistanceFunction, "vec_l2sq_distance", l2sq::l2sq);
129define_distance_function!(DotProductFunction, "vec_dot_product", dot::dot);
130
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::vectors::{BinaryVector, ConstantVector, StringVector};

    use super::*;

    /// `[0.0, 1.0]` encoded as little-endian `f32` bytes.
    const BIN_ZERO_ONE: [u8; 8] = [0, 0, 0, 0, 0, 0, 128, 63];
    /// `[1.0, 0.0]` encoded as little-endian `f32` bytes.
    const BIN_ONE_ZERO: [u8; 8] = [0, 0, 128, 63, 0, 0, 0, 0];

    /// Expected NULL rows when a "left" fixture is paired with a "right" fixture (either order):
    /// row 2 has a NULL left operand and row 3 a NULL right operand.
    const PAIRED_NULLS: [bool; 4] = [false, false, true, true];

    /// Every function generated by `define_distance_function!` in this file.
    fn all_functions() -> [Box<dyn Function>; 3] {
        [
            Box::new(CosDistanceFunction),
            Box::new(L2SqDistanceFunction),
            Box::new(DotProductFunction),
        ]
    }

    /// String rows: `[0.0, 1.0]`, `[1.0, 0.0]`, NULL, `[1.0, 0.0]`.
    fn left_strings() -> VectorRef {
        Arc::new(StringVector::from(vec![
            Some("[0.0, 1.0]"),
            Some("[1.0, 0.0]"),
            None,
            Some("[1.0, 0.0]"),
        ]))
    }

    /// String rows: `[0.0, 1.0]` three times, then NULL.
    fn right_strings() -> VectorRef {
        Arc::new(StringVector::from(vec![
            Some("[0.0, 1.0]"),
            Some("[0.0, 1.0]"),
            Some("[0.0, 1.0]"),
            None,
        ]))
    }

    /// Binary rows: `[0.0, 1.0]`, `[1.0, 0.0]`, NULL, `[1.0, 0.0]`.
    fn left_binaries() -> VectorRef {
        Arc::new(BinaryVector::from(vec![
            Some(BIN_ZERO_ONE.to_vec()),
            Some(BIN_ONE_ZERO.to_vec()),
            None,
            Some(BIN_ONE_ZERO.to_vec()),
        ]))
    }

    /// Binary rows: `[0.0, 1.0]` three times, then NULL.
    fn right_binaries() -> VectorRef {
        Arc::new(BinaryVector::from(vec![
            Some(BIN_ZERO_ONE.to_vec()),
            Some(BIN_ZERO_ONE.to_vec()),
            Some(BIN_ZERO_ONE.to_vec()),
            None,
        ]))
    }

    /// Evaluates `func` over `columns` and checks, row by row, that the result is NULL exactly
    /// where `expected_nulls` says so.
    fn assert_null_pattern(func: &dyn Function, columns: &[VectorRef], expected_nulls: &[bool]) {
        let output = func.eval(&FunctionContext::default(), columns).unwrap();
        for (row, &should_be_null) in expected_nulls.iter().enumerate() {
            assert_eq!(output.get(row).is_null(), should_be_null);
        }
    }

    /// Runs every function on `(lhs, rhs)` and then `(rhs, lhs)`, expecting `PAIRED_NULLS`.
    fn assert_paired_nulls_both_orders(lhs: &VectorRef, rhs: &VectorRef) {
        for func in all_functions() {
            assert_null_pattern(&*func, &[lhs.clone(), rhs.clone()], &PAIRED_NULLS);
            assert_null_pattern(&*func, &[rhs.clone(), lhs.clone()], &PAIRED_NULLS);
        }
    }

    #[test]
    fn test_distance_string_string() {
        assert_paired_nulls_both_orders(&left_strings(), &right_strings());
    }

    #[test]
    fn test_distance_binary_binary() {
        assert_paired_nulls_both_orders(&left_binaries(), &right_binaries());
    }

    #[test]
    fn test_distance_string_binary() {
        assert_paired_nulls_both_orders(&left_strings(), &right_binaries());
    }

    #[test]
    fn test_distance_const_string() {
        // A constant `[0.0, 1.0]` spanning four rows: only the other operand's NULLs show up.
        let constant: VectorRef = Arc::new(ConstantVector::new(
            Arc::new(StringVector::from(vec!["[0.0, 1.0]"])),
            4,
        ));
        let cases = [
            (left_strings(), [false, false, true, false]),
            (right_binaries(), [false, false, false, true]),
        ];

        for func in all_functions() {
            for (column, nulls) in &cases {
                assert_null_pattern(&*func, &[constant.clone(), column.clone()], nulls);
                assert_null_pattern(&*func, &[column.clone(), constant.clone()], nulls);
            }
        }
    }

    #[test]
    fn test_invalid_vector_length() {
        for func in all_functions() {
            let mismatched_strings: [VectorRef; 2] = [
                Arc::new(StringVector::from(vec!["[1.0]"])),
                Arc::new(StringVector::from(vec!["[1.0, 1.0]"])),
            ];
            assert!(func
                .eval(&FunctionContext::default(), &mismatched_strings)
                .is_err());

            // `[1.0]` vs `[1.0, 2.0]` as little-endian `f32` bytes.
            let mismatched_binaries: [VectorRef; 2] = [
                Arc::new(BinaryVector::from(vec![vec![0, 0, 128, 63]])),
                Arc::new(BinaryVector::from(vec![vec![0, 0, 128, 63, 0, 0, 0, 64]])),
            ];
            assert!(func
                .eval(&FunctionContext::default(), &mismatched_binaries)
                .is_err());
        }
    }
}