common_function/scalars/vector/distance/
cos.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use nalgebra::DVectorView;
16
17/// Calculates the cos distance between two vectors.
18///
19/// **Note:** Must ensure that the length of the two vectors are the same.
20pub fn cos(lhs: &[f32], rhs: &[f32]) -> f32 {
21    let lhs_vec = DVectorView::from_slice(lhs, lhs.len());
22    let rhs_vec = DVectorView::from_slice(rhs, rhs.len());
23
24    let dot_product = lhs_vec.dot(&rhs_vec);
25    let lhs_norm = lhs_vec.norm();
26    let rhs_norm = rhs_vec.norm();
27    if dot_product.abs() < f32::EPSILON
28        || lhs_norm.abs() < f32::EPSILON
29        || rhs_norm.abs() < f32::EPSILON
30    {
31        return 1.0;
32    }
33
34    let cos_similar = dot_product / (lhs_norm * rhs_norm);
35    let res = 1.0 - cos_similar;
36    if res.abs() < f32::EPSILON {
37        0.0
38    } else {
39        res
40    }
41}
42
#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;

    use super::*;

    #[test]
    fn test_cos_scalar() {
        // Each case is (lhs, rhs, expected cosine distance).
        let cases: &[(&[f32], &[f32], f32)] = &[
            (&[1.0, 2.0, 3.0], &[1.0, 2.0, 3.0], 0.0),
            (&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0], 0.025),
            (&[1.0, 2.0, 3.0], &[7.0, 8.0, 9.0], 0.04),
            // A zero vector has no direction; the distance is defined as 1.0.
            (&[0.0, 0.0, 0.0], &[1.0, 2.0, 3.0], 1.0),
            (&[0.0, 0.0, 0.0], &[4.0, 5.0, 6.0], 1.0),
            (&[0.0, 0.0, 0.0], &[7.0, 8.0, 9.0], 1.0),
            (&[7.0, 8.0, 9.0], &[1.0, 2.0, 3.0], 0.04),
            (&[7.0, 8.0, 9.0], &[4.0, 5.0, 6.0], 0.0),
            (&[7.0, 8.0, 9.0], &[7.0, 8.0, 9.0], 0.0),
            // Scaling a vector does not change its direction.
            (&[1.0, 2.0, 3.0], &[2.0, 4.0, 6.0], 0.0),
            // Orthogonal vectors.
            (&[1.0, 0.0, 0.0], &[0.0, 1.0, 0.0], 1.0),
            // Opposite vectors are at the maximum distance.
            (&[1.0, 2.0, 3.0], &[-1.0, -2.0, -3.0], 2.0),
        ];

        for &(lhs, rhs, expected) in cases {
            assert_relative_eq!(cos(lhs, rhs), expected, epsilon = 1e-2);
        }
    }
}