common_function/scalars/vector/
distance.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15mod cos;
16mod dot;
17mod l2sq;
18
19pub const VEC_COS_DISTANCE: &str = "vec_cos_distance";
20pub const VEC_L2SQ_DISTANCE: &str = "vec_l2sq_distance";
21pub const VEC_DOT_PRODUCT: &str = "vec_dot_product";
22
23use std::borrow::Cow;
24use std::fmt::Display;
25
26use datafusion::logical_expr::ColumnarValue;
27use datafusion_common::ScalarValue;
28use datafusion_expr::{ScalarFunctionArgs, Signature};
29use datatypes::arrow::datatypes::DataType;
30
31use crate::function::Function;
32use crate::helper;
33
/// Generates a scalar [`Function`] that applies a pairwise `f32` vector routine.
///
/// Arguments:
/// * `$StructName` - name of the generated `pub(crate)` struct.
/// * `$display_name` - string returned by `Function::name`; `Display` prints it
///   in ASCII upper case.
/// * `$similarity_method` - path of the routine invoked with the two decoded
///   vectors; its result is stored in a `ScalarValue::Float32`.
macro_rules! define_distance_function {
    ($StructName:ident, $display_name:expr, $similarity_method:path) => {
        /// A function that calculates the distance between two vectors.
        ///
        /// Each argument may be given as `Utf8`, `Utf8View`, `Binary` or
        /// `BinaryView`; the result type is always `Float32`.
        #[derive(Debug, Clone)]
        pub(crate) struct $StructName {
            signature: Signature,
        }

        impl Default for $StructName {
            fn default() -> Self {
                // Both arguments accept the same set of encodings, so the list
                // is written once and reused for each position.
                let vector_types = vec![
                    DataType::Utf8,
                    DataType::Utf8View,
                    DataType::Binary,
                    DataType::BinaryView,
                ];
                Self {
                    signature: helper::one_of_sigs2(vector_types.clone(), vector_types),
                }
            }
        }

        impl Function for $StructName {
            fn name(&self) -> &str {
                $display_name
            }

            fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
                Ok(DataType::Float32)
            }

            fn signature(&self) -> &Signature {
                &self.signature
            }

            /// Evaluates the function through `VectorCalculator`.
            ///
            /// For every pair of vectors handed to the closure:
            /// * if either side is `None`, the result is a null `Float32`;
            /// * if the lengths differ, an `Execution` error is returned;
            /// * otherwise the result is the value of the wrapped routine.
            fn invoke_with_args(
                &self,
                args: ScalarFunctionArgs,
            ) -> datafusion_common::Result<ColumnarValue> {
                let body = |v0: &Option<Cow<[f32]>>,
                            v1: &Option<Cow<[f32]>>|
                 -> datafusion_common::Result<ScalarValue> {
                    let distance = match (v0, v1) {
                        (Some(lhs), Some(rhs)) => {
                            // The pairwise routine is only called on vectors of
                            // equal length; anything else is rejected here.
                            if lhs.len() != rhs.len() {
                                return Err(datafusion_common::DataFusionError::Execution(format!(
                                    "vectors length not match: {}",
                                    self.name()
                                )));
                            }
                            Some($similarity_method(lhs, rhs))
                        }
                        // A missing operand on either side yields a null result.
                        _ => None,
                    };
                    Ok(ScalarValue::Float32(distance))
                };

                let calculator = $crate::scalars::vector::VectorCalculator {
                    name: self.name(),
                    func: body,
                };
                calculator.invoke_with_vectors(args)
            }
        }

        impl Display for $StructName {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "{}", $display_name.to_ascii_uppercase())
            }
        }
    };
}
115
// Defines `CosDistanceFunction`, named `vec_cos_distance`, whose per-pair
// computation is delegated to `cos::cos`.
define_distance_function!(CosDistanceFunction, VEC_COS_DISTANCE, cos::cos);
// Defines `L2SqDistanceFunction`, named `vec_l2sq_distance`, whose per-pair
// computation is delegated to `l2sq::l2sq`.
define_distance_function!(L2SqDistanceFunction, VEC_L2SQ_DISTANCE, l2sq::l2sq);
// Defines `DotProductFunction`, named `vec_dot_product`, whose per-pair
// computation is delegated to `dot::dot`.
define_distance_function!(DotProductFunction, VEC_DOT_PRODUCT, dot::dot);
119
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion::arrow::array::{Array, ArrayRef, AsArray, BinaryArray, StringViewArray};
    use datafusion::arrow::datatypes::Float32Type;
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Expected null-ness of each output row for the four-row fixtures below:
    /// rows 0 and 1 have both operands present, rows 2 and 3 each miss one.
    const EXPECTED_NULLS: [bool; 4] = [false, false, true, true];

    /// Returns one default instance of every distance function defined in this module.
    fn all_distance_functions() -> [Box<dyn Function>; 3] {
        [
            Box::new(CosDistanceFunction::default()) as Box<dyn Function>,
            Box::new(L2SqDistanceFunction::default()) as Box<dyn Function>,
            Box::new(DotProductFunction::default()) as Box<dyn Function>,
        ]
    }

    /// Invokes `func` on the given argument arrays and materializes the result
    /// as an array with as many rows as the first argument.
    fn test_invoke(func: &dyn Function, args: &[ArrayRef]) -> datafusion_common::Result<ArrayRef> {
        let number_rows = args[0].len();
        let columns = args
            .iter()
            .map(|array| ColumnarValue::Array(array.clone()))
            .collect::<Vec<_>>();
        let args = ScalarFunctionArgs {
            args: columns,
            arg_fields: vec![],
            number_rows,
            // Declared nullable because the tests below expect null output rows.
            return_field: Arc::new(Field::new("x", DataType::Float32, true)),
            config_options: Arc::new(ConfigOptions::new()),
        };
        func.invoke_with_args(args)
            .and_then(|value| value.to_array(number_rows))
    }

    /// Runs `func` on `(lhs, rhs)` and checks every row against [`EXPECTED_NULLS`].
    fn assert_expected_nulls(func: &dyn Function, lhs: &ArrayRef, rhs: &ArrayRef) {
        let result = test_invoke(func, &[lhs.clone(), rhs.clone()]).unwrap();
        let result = result.as_primitive::<Float32Type>();

        for (row, expect_null) in EXPECTED_NULLS.iter().enumerate() {
            assert_eq!(
                result.is_null(row),
                *expect_null,
                "{}: unexpected null-ness at row {}",
                func.name(),
                row
            );
        }
    }

    /// Checks the null pattern with the operands in both orders, for every function.
    fn assert_nulls_both_orders(lhs: &ArrayRef, rhs: &ArrayRef) {
        for func in all_distance_functions() {
            assert_expected_nulls(func.as_ref(), lhs, rhs);
            assert_expected_nulls(func.as_ref(), rhs, lhs);
        }
    }

    /// String-encoded fixture with a null in row 2.
    fn string_lhs() -> ArrayRef {
        Arc::new(StringViewArray::from(vec![
            Some("[0.0, 1.0]"),
            Some("[1.0, 0.0]"),
            None,
            Some("[1.0, 0.0]"),
        ]))
    }

    /// String-encoded fixture with a null in row 3.
    fn string_rhs() -> ArrayRef {
        Arc::new(StringViewArray::from(vec![
            Some("[0.0, 1.0]"),
            Some("[0.0, 1.0]"),
            Some("[0.0, 1.0]"),
            None,
        ]))
    }

    /// Binary-encoded fixture with a null in row 2; mirrors [`string_lhs`].
    fn binary_lhs() -> ArrayRef {
        Arc::new(BinaryArray::from_iter(vec![
            // [0.0, 1.0]
            Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
            // [1.0, 0.0]
            Some(vec![0, 0, 128, 63, 0, 0, 0, 0]),
            None,
            // [1.0, 0.0]
            Some(vec![0, 0, 128, 63, 0, 0, 0, 0]),
        ]))
    }

    /// Binary-encoded fixture with a null in row 3; mirrors [`string_rhs`].
    fn binary_rhs() -> ArrayRef {
        Arc::new(BinaryArray::from_iter(vec![
            // [0.0, 1.0]
            Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
            Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
            Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
            None,
        ]))
    }

    #[test]
    fn test_distance_string_string() {
        assert_nulls_both_orders(&string_lhs(), &string_rhs());
    }

    #[test]
    fn test_distance_binary_binary() {
        assert_nulls_both_orders(&binary_lhs(), &binary_rhs());
    }

    #[test]
    fn test_distance_string_binary() {
        assert_nulls_both_orders(&string_lhs(), &binary_rhs());
    }

    #[test]
    fn test_invalid_vector_length() {
        for func in all_distance_functions() {
            // String-encoded vectors of length 1 and 2.
            let vec1: ArrayRef = Arc::new(StringViewArray::from(vec!["[1.0]"]));
            let vec2: ArrayRef = Arc::new(StringViewArray::from(vec!["[1.0, 1.0]"]));
            let result = test_invoke(func.as_ref(), &[vec1, vec2]);
            assert!(result.is_err());

            // Binary-encoded vectors of 4 and 8 bytes respectively.
            let vec1: ArrayRef = Arc::new(BinaryArray::from_iter_values(vec![vec![0, 0, 128, 63]]));
            let vec2: ArrayRef = Arc::new(BinaryArray::from_iter_values(vec![vec![
                0, 0, 128, 63, 0, 0, 0, 64,
            ]]));
            let result = test_invoke(func.as_ref(), &[vec1, vec2]);
            assert!(result.is_err());
        }
    }
}