common_function/scalars/vector/
distance.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15mod cos;
16mod dot;
17mod l2sq;
18
19use std::borrow::Cow;
20use std::fmt::Display;
21
22use common_query::error::{InvalidFuncArgsSnafu, Result};
23use datafusion_expr::Signature;
24use datatypes::arrow::datatypes::DataType;
25use datatypes::scalars::ScalarVectorBuilder;
26use datatypes::vectors::{Float32VectorBuilder, MutableVector, VectorRef};
27use snafu::ensure;
28
29use crate::function::{Function, FunctionContext};
30use crate::helper;
31use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const};
32
/// Defines a unit struct implementing [`Function`] that applies
/// `$similarity_method` to two vector arguments, row by row.
///
/// * `$StructName` – name of the generated struct.
/// * `$display_name` – SQL-visible function name; its upper-cased form is used for `Display`.
/// * `$similarity_method` – kernel called as `$similarity_method(&[f32], &[f32]) -> f32`
///   (inferred from how the result is pushed into a `Float32VectorBuilder`).
macro_rules! define_distance_function {
    ($StructName:ident, $display_name:expr, $similarity_method:path) => {
        /// A scalar function that applies a vector distance/similarity kernel to two
        /// vector arguments row by row, producing one `Float32` value per row.
        ///
        /// Each argument may be a `Utf8` column of vector literals (e.g. `"[0.0, 1.0]"`)
        /// or a `Binary` column, and either argument may be a constant. A row yields
        /// `NULL` when either side does not decode to a vector (for example, when the
        /// input value is `NULL`).
        #[derive(Debug, Clone, Default)]
        pub struct $StructName;

        impl Function for $StructName {
            fn name(&self) -> &str {
                $display_name
            }

            fn return_type(&self, _: &[DataType]) -> Result<DataType> {
                Ok(DataType::Float32)
            }

            /// Accepts `Utf8` or `Binary` for each of the two arguments (presumably every
            /// combination — see `helper::one_of_sigs2`).
            fn signature(&self) -> Signature {
                helper::one_of_sigs2(
                    vec![DataType::Utf8, DataType::Binary],
                    vec![DataType::Utf8, DataType::Binary],
                )
            }

            /// Evaluates the kernel for every row of the two input columns.
            ///
            /// # Errors
            ///
            /// Returns an `InvalidFuncArgs` error if `columns` does not contain exactly two
            /// vectors, or if the two vectors of some row differ in length. Errors raised
            /// while decoding an argument into a vector are propagated unchanged.
            fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
                ensure!(
                    columns.len() == 2,
                    InvalidFuncArgsSnafu {
                        err_msg: format!(
                            "The length of the args is not correct, expect exactly two, have: {}",
                            columns.len()
                        ),
                    }
                );
                let arg0 = &columns[0];
                let arg1 = &columns[1];

                // The row count comes from the first argument; presumably the caller
                // guarantees both columns have the same number of rows.
                let size = arg0.len();
                let mut result = Float32VectorBuilder::with_capacity(size);
                if size == 0 {
                    return Ok(result.to_vector());
                }

                // Constant arguments are decoded once here instead of once per row.
                let arg0_const = as_veclit_if_const(arg0)?;
                let arg1_const = as_veclit_if_const(arg1)?;

                for i in 0..size {
                    // Borrow the pre-decoded constant if present; otherwise decode row `i`.
                    let vec0 = match arg0_const.as_ref() {
                        Some(a) => Some(Cow::Borrowed(a.as_ref())),
                        None => as_veclit(arg0.get_ref(i))?,
                    };
                    let vec1 = match arg1_const.as_ref() {
                        Some(b) => Some(Cow::Borrowed(b.as_ref())),
                        None => as_veclit(arg1.get_ref(i))?,
                    };

                    if let (Some(vec0), Some(vec1)) = (vec0, vec1) {
                        ensure!(
                            vec0.len() == vec1.len(),
                            InvalidFuncArgsSnafu {
                                err_msg: format!(
                                    "The length of the vectors must match to calculate distance, have: {} vs {}",
                                    vec0.len(),
                                    vec1.len()
                                ),
                            }
                        );

                        // Lengths were validated just above; the kernel is only ever
                        // handed equal-length slices.
                        let d = $similarity_method(vec0.as_ref(), vec1.as_ref());
                        result.push(Some(d));
                    } else {
                        result.push_null();
                    }
                }

                Ok(result.to_vector())
            }
        }

        /// Displays the SQL function name in upper case
        /// (e.g. `vec_cos_distance` becomes `VEC_COS_DISTANCE`).
        impl Display for $StructName {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "{}", $display_name.to_ascii_uppercase())
            }
        }
    };
}
120
121define_distance_function!(CosDistanceFunction, "vec_cos_distance", cos::cos);
122define_distance_function!(L2SqDistanceFunction, "vec_l2sq_distance", l2sq::l2sq);
123define_distance_function!(DotProductFunction, "vec_dot_product", dot::dot);
124
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::vectors::{BinaryVector, ConstantVector, StringVector};

    use super::*;

    /// Little-endian `f32` bytes of the vector `[0.0, 1.0]`.
    const BIN_0_1: [u8; 8] = [0, 0, 0, 0, 0, 0, 128, 63];
    /// Little-endian `f32` bytes of the vector `[1.0, 0.0]`.
    const BIN_1_0: [u8; 8] = [0, 0, 128, 63, 0, 0, 0, 0];

    /// Left-hand rows used by most tests: `[0, 1]`, `[1, 0]`, NULL, `[1, 0]`.
    const LHS_STRINGS: [Option<&str>; 4] = [
        Some("[0.0, 1.0]"),
        Some("[1.0, 0.0]"),
        None,
        Some("[1.0, 0.0]"),
    ];
    /// Binary counterpart of `LHS_STRINGS`.
    const LHS_BINARY: [Option<[u8; 8]>; 4] = [Some(BIN_0_1), Some(BIN_1_0), None, Some(BIN_1_0)];

    /// Right-hand rows: `[0, 1]` three times, then NULL.
    const RHS_STRINGS: [Option<&str>; 4] = [
        Some("[0.0, 1.0]"),
        Some("[0.0, 1.0]"),
        Some("[0.0, 1.0]"),
        None,
    ];
    /// Binary counterpart of `RHS_STRINGS`.
    const RHS_BINARY: [Option<[u8; 8]>; 4] = [Some(BIN_0_1), Some(BIN_0_1), Some(BIN_0_1), None];

    /// One boxed instance of each distance function defined in this module.
    fn all_funcs() -> [Box<dyn Function>; 3] {
        [
            Box::new(CosDistanceFunction {}) as Box<dyn Function>,
            Box::new(L2SqDistanceFunction {}) as Box<dyn Function>,
            Box::new(DotProductFunction {}) as Box<dyn Function>,
        ]
    }

    /// Builds a string column from the given rows.
    fn string_column(rows: &[Option<&'static str>]) -> VectorRef {
        Arc::new(StringVector::from(rows.to_vec()))
    }

    /// Builds a binary column from the given rows.
    fn binary_column(rows: &[Option<[u8; 8]>]) -> VectorRef {
        let rows: Vec<Option<Vec<u8>>> = rows
            .iter()
            .map(|row| row.map(|bytes| bytes.to_vec()))
            .collect();
        Arc::new(BinaryVector::from(rows))
    }

    /// Evaluates `func(lhs, rhs)` and asserts, row by row, whether each output is NULL.
    fn assert_null_rows(
        func: &dyn Function,
        lhs: &VectorRef,
        rhs: &VectorRef,
        expected_nulls: &[bool],
    ) {
        let result = func
            .eval(&FunctionContext::default(), &[lhs.clone(), rhs.clone()])
            .unwrap();
        for (row, &expect_null) in expected_nulls.iter().enumerate() {
            assert_eq!(
                result.get(row).is_null(),
                expect_null,
                "{}: unexpected nullness at row {}",
                func.name(),
                row
            );
        }
    }

    /// Checks the same NULL pattern for `(lhs, rhs)` and `(rhs, lhs)` on every function.
    fn assert_null_rows_both_orders(lhs: &VectorRef, rhs: &VectorRef, expected_nulls: &[bool]) {
        for func in all_funcs() {
            assert_null_rows(&*func, lhs, rhs, expected_nulls);
            assert_null_rows(&*func, rhs, lhs, expected_nulls);
        }
    }

    #[test]
    fn test_distance_string_string() {
        let lhs = string_column(&LHS_STRINGS);
        let rhs = string_column(&RHS_STRINGS);
        assert_null_rows_both_orders(&lhs, &rhs, &[false, false, true, true]);
    }

    #[test]
    fn test_distance_binary_binary() {
        let lhs = binary_column(&LHS_BINARY);
        let rhs = binary_column(&RHS_BINARY);
        assert_null_rows_both_orders(&lhs, &rhs, &[false, false, true, true]);
    }

    #[test]
    fn test_distance_string_binary() {
        let lhs = string_column(&LHS_STRINGS);
        let rhs = binary_column(&RHS_BINARY);
        assert_null_rows_both_orders(&lhs, &rhs, &[false, false, true, true]);
    }

    #[test]
    fn test_distance_const_string() {
        let const_str = Arc::new(ConstantVector::new(
            Arc::new(StringVector::from(vec!["[0.0, 1.0]"])),
            4,
        )) as VectorRef;
        let strings = string_column(&LHS_STRINGS);
        let binary = binary_column(&RHS_BINARY);

        // The constant never contributes a NULL, so only the column's NULL rows show up.
        assert_null_rows_both_orders(&const_str, &strings, &[false, false, true, false]);
        assert_null_rows_both_orders(&const_str, &binary, &[false, false, false, true]);
    }

    #[test]
    fn test_invalid_vector_length() {
        for func in all_funcs() {
            let short = Arc::new(StringVector::from(vec!["[1.0]"])) as VectorRef;
            let long = Arc::new(StringVector::from(vec!["[1.0, 1.0]"])) as VectorRef;
            assert!(func
                .eval(&FunctionContext::default(), &[short, long])
                .is_err());

            let short = Arc::new(BinaryVector::from(vec![vec![0, 0, 128, 63]])) as VectorRef;
            let long =
                Arc::new(BinaryVector::from(vec![vec![0, 0, 128, 63, 0, 0, 0, 64]])) as VectorRef;
            assert!(func
                .eval(&FunctionContext::default(), &[short, long])
                .is_err());
        }
    }
}