common_function/scalars/vector/
vector_norm.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::borrow::Cow;
16use std::fmt::Display;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use common_query::prelude::{Signature, TypeSignature, Volatility};
20use datatypes::prelude::ConcreteDataType;
21use datatypes::scalars::ScalarVectorBuilder;
22use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
23use nalgebra::DVectorView;
24use snafu::ensure;
25
26use crate::function::{Function, FunctionContext};
27use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
28
29const NAME: &str = "vec_norm";
30
/// Normalizes the vector to length 1 (L2 normalization) and returns the resulting vector.
///
/// Every element is divided by the Euclidean norm of the vector, i.e. this is equivalent to
/// `VECTOR_SCALAR_MUL(1/SQRT(VECTOR_ELEM_SUM(VECTOR_MUL(v, v))), v)`.
///
/// # Example
///
/// ```sql
/// SELECT vec_to_string(vec_norm('[7.0, 8.0, 9.0]'));
///
/// +--------------------------------------------------+
/// | vec_to_string(vec_norm(Utf8("[7.0, 8.0, 9.0]"))) |
/// +--------------------------------------------------+
/// | [0.5025707,0.5743665,0.64616233]                 |
/// +--------------------------------------------------+
///
/// ```
#[derive(Debug, Clone, Default)]
pub struct VectorNormFunction;
48
49impl Function for VectorNormFunction {
50    fn name(&self) -> &str {
51        NAME
52    }
53
54    fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
55        Ok(ConcreteDataType::binary_datatype())
56    }
57
58    fn signature(&self) -> Signature {
59        Signature::one_of(
60            vec![
61                TypeSignature::Exact(vec![ConcreteDataType::string_datatype()]),
62                TypeSignature::Exact(vec![ConcreteDataType::binary_datatype()]),
63            ],
64            Volatility::Immutable,
65        )
66    }
67
68    fn eval(
69        &self,
70        _func_ctx: &FunctionContext,
71        columns: &[VectorRef],
72    ) -> common_query::error::Result<VectorRef> {
73        ensure!(
74            columns.len() == 1,
75            InvalidFuncArgsSnafu {
76                err_msg: format!(
77                    "The length of the args is not correct, expect exactly one, have: {}",
78                    columns.len()
79                )
80            }
81        );
82        let arg0 = &columns[0];
83
84        let len = arg0.len();
85        let mut result = BinaryVectorBuilder::with_capacity(len);
86        if len == 0 {
87            return Ok(result.to_vector());
88        }
89
90        let arg0_const = as_veclit_if_const(arg0)?;
91
92        for i in 0..len {
93            let arg0 = match arg0_const.as_ref() {
94                Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
95                None => as_veclit(arg0.get_ref(i))?,
96            };
97            let Some(arg0) = arg0 else {
98                result.push_null();
99                continue;
100            };
101
102            let vec0 = DVectorView::from_slice(&arg0, arg0.len());
103            let vec1 = DVectorView::from_slice(&arg0, arg0.len());
104            let vec2scalar = vec1.component_mul(&vec0);
105            let scalar_var = vec2scalar.sum().sqrt();
106
107            let vec = DVectorView::from_slice(&arg0, arg0.len());
108            // Use unscale to avoid division by zero and keep more precision as possible
109            let vec_res = vec.unscale(scalar_var);
110
111            let veclit = vec_res.as_slice();
112            let binlit = veclit_to_binlit(veclit);
113            result.push(Some(&binlit));
114        }
115
116        Ok(result.to_vector())
117    }
118}
119
120impl Display for VectorNormFunction {
121    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122        write!(f, "{}", NAME.to_ascii_uppercase())
123    }
124}
125
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::vectors::StringVector;

    use super::*;

    /// Evaluates `vec_norm` over four string-encoded vectors followed by a null row and
    /// checks every output row against its expected unit vector.
    #[test]
    fn test_vec_norm() {
        // Input literal paired with the expected normalized vector.
        let cases: [(&str, [f32; 3]); 4] = [
            ("[0.0,2.0,3.0]", [0.0, 0.5547002, 0.8320503]),
            ("[1.0,2.0,3.0]", [0.26726124, 0.5345225, 0.8017837]),
            ("[7.0,8.0,9.0]", [0.5025707, 0.5743665, 0.64616233]),
            ("[7.0,-8.0,9.0]", [0.5025707, -0.5743665, 0.64616233]),
        ];

        // The trailing `None` exercises the null pass-through.
        let mut rows: Vec<Option<String>> = cases
            .iter()
            .map(|(literal, _)| Some(literal.to_string()))
            .collect();
        rows.push(None);
        let input0 = Arc::new(StringVector::from(rows));

        let func = VectorNormFunction;
        let output = func.eval(&FunctionContext::default(), &[input0]).unwrap();
        let output = output.as_ref();

        assert_eq!(output.len(), 5);
        for (row, (_, expected)) in cases.iter().enumerate() {
            assert_eq!(
                output.get_ref(row).as_binary().unwrap(),
                Some(veclit_to_binlit(&expected[..]).as_slice())
            );
        }
        assert!(output.get_ref(4).is_null());
    }
}