common_function/scalars/vector/
distance.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15mod cos;
16mod dot;
17mod l2sq;
18
19use std::borrow::Cow;
20use std::fmt::Display;
21
22use datafusion::logical_expr::ColumnarValue;
23use datafusion_common::ScalarValue;
24use datafusion_expr::{ScalarFunctionArgs, Signature};
25use datatypes::arrow::datatypes::DataType;
26
27use crate::function::Function;
28use crate::helper;
29
30macro_rules! define_distance_function {
31    ($StructName:ident, $display_name:expr, $similarity_method:path) => {
32        /// A function calculates the distance between two vectors.
33
34        #[derive(Debug, Clone)]
35        pub(crate) struct $StructName {
36            signature: Signature,
37        }
38
39        impl Default for $StructName {
40            fn default() -> Self {
41                Self {
42                    signature: helper::one_of_sigs2(
43                        vec![
44                            DataType::Utf8,
45                            DataType::Utf8View,
46                            DataType::Binary,
47                            DataType::BinaryView,
48                        ],
49                        vec![
50                            DataType::Utf8,
51                            DataType::Utf8View,
52                            DataType::Binary,
53                            DataType::BinaryView,
54                        ],
55                    ),
56                }
57            }
58        }
59
60        impl Function for $StructName {
61            fn name(&self) -> &str {
62                $display_name
63            }
64
65            fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
66                Ok(DataType::Float32)
67            }
68
69            fn signature(&self) -> &Signature {
70                &self.signature
71            }
72
73            fn invoke_with_args(
74                &self,
75                args: ScalarFunctionArgs,
76            ) -> datafusion_common::Result<ColumnarValue> {
77                let body = |v0: &Option<Cow<[f32]>>,
78                            v1: &Option<Cow<[f32]>>|
79                 -> datafusion_common::Result<ScalarValue> {
80                    let result = if let (Some(v0), Some(v1)) = (v0, v1) {
81                        if v0.len() != v1.len() {
82                            return Err(datafusion_common::DataFusionError::Execution(format!(
83                                "vectors length not match: {}",
84                                self.name()
85                            )));
86                        }
87
88                        let d = $similarity_method(v0, v1);
89                        Some(d)
90                    } else {
91                        None
92                    };
93                    Ok(ScalarValue::Float32(result))
94                };
95
96                let calculator = $crate::scalars::vector::VectorCalculator {
97                    name: self.name(),
98                    func: body,
99                };
100                calculator.invoke_with_vectors(args)
101            }
102        }
103
104        impl Display for $StructName {
105            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106                write!(f, "{}", $display_name.to_ascii_uppercase())
107            }
108        }
109    };
110}
111
112define_distance_function!(CosDistanceFunction, "vec_cos_distance", cos::cos);
113define_distance_function!(L2SqDistanceFunction, "vec_l2sq_distance", l2sq::l2sq);
114define_distance_function!(DotProductFunction, "vec_dot_product", dot::dot);
115
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion::arrow::array::{Array, ArrayRef, AsArray, BinaryArray, StringViewArray};
    use datafusion::arrow::datatypes::Float32Type;
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Little-endian `f32` bytes of the vector `[0.0, 1.0]`.
    const BYTES_0_1: [u8; 8] = [0, 0, 0, 0, 0, 0, 128, 63];
    /// Little-endian `f32` bytes of the vector `[1.0, 0.0]`.
    const BYTES_1_0: [u8; 8] = [0, 0, 128, 63, 0, 0, 0, 0];

    /// Returns one instance of every function generated by `define_distance_function!`.
    fn distance_functions() -> [Box<dyn Function>; 3] {
        [
            Box::new(CosDistanceFunction::default()) as Box<dyn Function>,
            Box::new(L2SqDistanceFunction::default()) as Box<dyn Function>,
            Box::new(DotProductFunction::default()) as Box<dyn Function>,
        ]
    }

    /// Invokes `func` on `args` (all passed as array columns) and materializes the
    /// result as an array with as many rows as the first argument.
    fn test_invoke(func: &dyn Function, args: &[ArrayRef]) -> datafusion_common::Result<ArrayRef> {
        let number_rows = args[0].len();
        let args = ScalarFunctionArgs {
            args: args.iter().cloned().map(ColumnarValue::Array).collect(),
            arg_fields: vec![],
            number_rows,
            return_field: Arc::new(Field::new("x", DataType::Float32, false)),
            config_options: Arc::new(ConfigOptions::new()),
        };
        func.invoke_with_args(args)?.to_array(number_rows)
    }

    // Fixture layout shared by the builders below (row index -> content):
    //   0, 1: both sides hold a vector
    //   2:    the "lhs" side is null
    //   3:    the "rhs" side is null

    /// Text-encoded left-hand fixture column.
    fn string_lhs() -> ArrayRef {
        Arc::new(StringViewArray::from(vec![
            Some("[0.0, 1.0]"),
            Some("[1.0, 0.0]"),
            None,
            Some("[1.0, 0.0]"),
        ]))
    }

    /// Text-encoded right-hand fixture column.
    fn string_rhs() -> ArrayRef {
        Arc::new(StringViewArray::from(vec![
            Some("[0.0, 1.0]"),
            Some("[0.0, 1.0]"),
            Some("[0.0, 1.0]"),
            None,
        ]))
    }

    /// Binary-encoded left-hand fixture column (same vectors as `string_lhs`).
    fn binary_lhs() -> ArrayRef {
        Arc::new(BinaryArray::from_iter(vec![
            Some(BYTES_0_1),
            Some(BYTES_1_0),
            None,
            Some(BYTES_1_0),
        ]))
    }

    /// Binary-encoded right-hand fixture column (same vectors as `string_rhs`).
    fn binary_rhs() -> ArrayRef {
        Arc::new(BinaryArray::from_iter(vec![
            Some(BYTES_0_1),
            Some(BYTES_0_1),
            Some(BYTES_0_1),
            None,
        ]))
    }

    /// Runs `func` on `(vec1, vec2)` and then on `(vec2, vec1)`, checking each time
    /// that rows 0 and 1 are non-null and rows 2 and 3 are null.
    fn assert_null_propagation(func: &dyn Function, vec1: &ArrayRef, vec2: &ArrayRef) {
        for (lhs, rhs) in [(vec1, vec2), (vec2, vec1)] {
            let result = test_invoke(func, &[Arc::clone(lhs), Arc::clone(rhs)]).unwrap();
            let result = result.as_primitive::<Float32Type>();
            let name = func.name();

            assert!(!result.is_null(0), "{}: row 0 must not be null", name);
            assert!(!result.is_null(1), "{}: row 1 must not be null", name);
            assert!(result.is_null(2), "{}: row 2 must be null", name);
            assert!(result.is_null(3), "{}: row 3 must be null", name);
        }
    }

    #[test]
    fn test_distance_string_string() {
        for func in distance_functions() {
            assert_null_propagation(func.as_ref(), &string_lhs(), &string_rhs());
        }
    }

    #[test]
    fn test_distance_binary_binary() {
        for func in distance_functions() {
            assert_null_propagation(func.as_ref(), &binary_lhs(), &binary_rhs());
        }
    }

    #[test]
    fn test_distance_string_binary() {
        for func in distance_functions() {
            assert_null_propagation(func.as_ref(), &string_lhs(), &binary_rhs());
        }
    }

    #[test]
    fn test_invalid_vector_length() {
        for func in distance_functions() {
            // Text inputs: one element against two elements.
            let vec1: ArrayRef = Arc::new(StringViewArray::from(vec!["[1.0]"]));
            let vec2: ArrayRef = Arc::new(StringViewArray::from(vec!["[1.0, 1.0]"]));
            let result = test_invoke(func.as_ref(), &[vec1, vec2]);
            assert!(
                result.is_err(),
                "{}: text vectors of different lengths must be rejected",
                func.name()
            );

            // Binary inputs: 4 bytes (one f32) against 8 bytes (two f32).
            let vec1: ArrayRef = Arc::new(BinaryArray::from_iter_values(vec![vec![0, 0, 128, 63]]));
            let vec2: ArrayRef = Arc::new(BinaryArray::from_iter_values(vec![vec![
                0, 0, 128, 63, 0, 0, 0, 64,
            ]]));
            let result = test_invoke(func.as_ref(), &[vec1, vec2]);
            assert!(
                result.is_err(),
                "{}: binary vectors of different lengths must be rejected",
                func.name()
            );
        }
    }
}