common_function/scalars/vector/
vector_norm.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::borrow::Cow;
16use std::fmt::Display;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use datafusion::arrow::datatypes::DataType;
20use datafusion::logical_expr_common::type_coercion::aggregates::{BINARYS, STRINGS};
21use datafusion_expr::{Signature, TypeSignature, Volatility};
22use datatypes::scalars::ScalarVectorBuilder;
23use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
24use nalgebra::DVectorView;
25use snafu::ensure;
26
27use crate::function::{Function, FunctionContext};
28use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
29
const NAME: &str = "vec_norm";

/// Normalizes a vector to unit Euclidean (L2) length and returns it as a vector.
///
/// This is equivalent to `VECTOR_SCALAR_MUL(1/SQRT(VECTOR_ELEM_SUM(VECTOR_MUL(v, v))), v)`.
///
/// # Example
///
/// ```sql
/// SELECT vec_to_string(vec_norm('[7.0, 8.0, 9.0]'));
///
/// +--------------------------------------------------+
/// | vec_to_string(vec_norm(Utf8("[7.0, 8.0, 9.0]"))) |
/// +--------------------------------------------------+
/// | [0.5025707,0.5743665,0.64616233]                 |
/// +--------------------------------------------------+
///
/// ```
#[derive(Debug, Clone, Default)]
pub struct VectorNormFunction;
49
50impl Function for VectorNormFunction {
51    fn name(&self) -> &str {
52        NAME
53    }
54
55    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
56        Ok(DataType::Binary)
57    }
58
59    fn signature(&self) -> Signature {
60        Signature::one_of(
61            vec![
62                TypeSignature::Uniform(1, STRINGS.to_vec()),
63                TypeSignature::Uniform(1, BINARYS.to_vec()),
64            ],
65            Volatility::Immutable,
66        )
67    }
68
69    fn eval(
70        &self,
71        _func_ctx: &FunctionContext,
72        columns: &[VectorRef],
73    ) -> common_query::error::Result<VectorRef> {
74        ensure!(
75            columns.len() == 1,
76            InvalidFuncArgsSnafu {
77                err_msg: format!(
78                    "The length of the args is not correct, expect exactly one, have: {}",
79                    columns.len()
80                )
81            }
82        );
83        let arg0 = &columns[0];
84
85        let len = arg0.len();
86        let mut result = BinaryVectorBuilder::with_capacity(len);
87        if len == 0 {
88            return Ok(result.to_vector());
89        }
90
91        let arg0_const = as_veclit_if_const(arg0)?;
92
93        for i in 0..len {
94            let arg0 = match arg0_const.as_ref() {
95                Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
96                None => as_veclit(arg0.get_ref(i))?,
97            };
98            let Some(arg0) = arg0 else {
99                result.push_null();
100                continue;
101            };
102
103            let vec0 = DVectorView::from_slice(&arg0, arg0.len());
104            let vec1 = DVectorView::from_slice(&arg0, arg0.len());
105            let vec2scalar = vec1.component_mul(&vec0);
106            let scalar_var = vec2scalar.sum().sqrt();
107
108            let vec = DVectorView::from_slice(&arg0, arg0.len());
109            // Use unscale to avoid division by zero and keep more precision as possible
110            let vec_res = vec.unscale(scalar_var);
111
112            let veclit = vec_res.as_slice();
113            let binlit = veclit_to_binlit(veclit);
114            result.push(Some(&binlit));
115        }
116
117        Ok(result.to_vector())
118    }
119}
120
121impl Display for VectorNormFunction {
122    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123        write!(f, "{}", NAME.to_ascii_uppercase())
124    }
125}
126
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::vectors::StringVector;

    use super::*;

    #[test]
    fn test_vec_norm() {
        let input = Arc::new(StringVector::from(vec![
            Some("[0.0,2.0,3.0]".to_string()),
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[7.0,8.0,9.0]".to_string()),
            Some("[7.0,-8.0,9.0]".to_string()),
            None,
        ]));

        let output = VectorNormFunction
            .eval(&FunctionContext::default(), &[input])
            .unwrap();
        let output = output.as_ref();

        // Expected normalized rows, in input order; `None` marks a null row.
        let expected: [Option<&[f32]>; 5] = [
            Some(&[0.0, 0.5547002, 0.8320503]),
            Some(&[0.26726124, 0.5345225, 0.8017837]),
            Some(&[0.5025707, 0.5743665, 0.64616233]),
            Some(&[0.5025707, -0.5743665, 0.64616233]),
            None,
        ];
        assert_eq!(output.len(), expected.len());

        for (row, want) in expected.iter().enumerate() {
            match want {
                Some(veclit) => assert_eq!(
                    output.get_ref(row).as_binary().unwrap(),
                    Some(veclit_to_binlit(veclit).as_slice())
                ),
                None => assert!(output.get_ref(row).is_null()),
            }
        }
    }
}
169}