common_function/scalars/vector/
vector_norm.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Display;
16
17use datafusion::arrow::datatypes::DataType;
18use datafusion::logical_expr::{Coercion, ColumnarValue, TypeSignatureClass};
19use datafusion_common::ScalarValue;
20use datafusion_common::types::{logical_binary, logical_string};
21use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
22use nalgebra::DVectorView;
23
24use crate::function::Function;
25use crate::scalars::vector::VectorCalculator;
26use crate::scalars::vector::impl_conv::{as_veclit, veclit_to_binlit};
27
/// Function name returned by `Function::name`; the `Display` impl prints it upper-cased.
const NAME: &str = "vec_norm";
29
/// Normalizes the vector to length 1, returns a vector.
/// This is equivalent to `VECTOR_SCALAR_MUL(1/SQRT(VECTOR_ELEM_SUM(VECTOR_MUL(v, v))), v)`,
/// i.e. every element is divided by the square root of the sum of the squared elements.
///
/// A NULL input yields a NULL result.
///
/// # Example
///
/// ```sql
/// SELECT vec_to_string(vec_norm('[7.0, 8.0, 9.0]'));
///
/// +--------------------------------------------------+
/// | vec_to_string(vec_norm(Utf8("[7.0, 8.0, 9.0]"))) |
/// +--------------------------------------------------+
/// | [0.5025707,0.5743665,0.64616233]                 |
/// +--------------------------------------------------+
///
/// ```
#[derive(Debug, Clone)]
pub(crate) struct VectorNormFunction {
    /// Argument signature handed back by `Function::signature`.
    signature: Signature,
}
49
50impl Default for VectorNormFunction {
51    fn default() -> Self {
52        Self {
53            signature: Signature::one_of(
54                vec![
55                    TypeSignature::Coercible(vec![Coercion::new_exact(
56                        TypeSignatureClass::Native(logical_binary()),
57                    )]),
58                    TypeSignature::Coercible(vec![Coercion::new_exact(
59                        TypeSignatureClass::Native(logical_string()),
60                    )]),
61                ],
62                Volatility::Immutable,
63            ),
64        }
65    }
66}
67
68impl Function for VectorNormFunction {
69    fn name(&self) -> &str {
70        NAME
71    }
72
73    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
74        Ok(DataType::BinaryView)
75    }
76
77    fn signature(&self) -> &Signature {
78        &self.signature
79    }
80
81    fn invoke_with_args(
82        &self,
83        args: ScalarFunctionArgs,
84    ) -> datafusion_common::Result<ColumnarValue> {
85        let body = |v0: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
86            let v0 = as_veclit(v0)?;
87            let Some(v0) = v0 else {
88                return Ok(ScalarValue::BinaryView(None));
89            };
90
91            let v0 = DVectorView::from_slice(&v0, v0.len());
92            let result =
93                veclit_to_binlit(v0.unscale(v0.component_mul(&v0).sum().sqrt()).as_slice());
94            Ok(ScalarValue::BinaryView(Some(result)))
95        };
96
97        let calculator = VectorCalculator {
98            name: self.name(),
99            func: body,
100        };
101        calculator.invoke_with_single_argument(args)
102    }
103}
104
105impl Display for VectorNormFunction {
106    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
107        write!(f, "{}", NAME.to_ascii_uppercase())
108    }
109}
110
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion::arrow::array::{Array, AsArray, StringViewArray};
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Feeds four string-encoded vectors plus one NULL through `vec_norm`
    /// and checks the encoded unit vectors row by row.
    #[test]
    fn test_vec_norm() {
        const ROWS: usize = 5;

        let function = VectorNormFunction::default();

        let vectors = Arc::new(StringViewArray::from(vec![
            Some("[0.0,2.0,3.0]".to_string()),
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[7.0,8.0,9.0]".to_string()),
            Some("[7.0,-8.0,9.0]".to_string()),
            None,
        ]));

        let call = ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(vectors)],
            arg_fields: vec![],
            number_rows: ROWS,
            return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
            config_options: Arc::new(ConfigOptions::new()),
        };

        let output = function
            .invoke_with_args(call)
            .and_then(|value| value.to_array(ROWS))
            .unwrap();
        let normalized = output.as_binary_view();
        assert_eq!(normalized.len(), ROWS);

        // Expected unit vectors for the four non-null inputs, in input order.
        let expected = [
            [0.0, 0.5547002, 0.8320503],
            [0.26726124, 0.5345225, 0.8017837],
            [0.5025707, 0.5743665, 0.64616233],
            [0.5025707, -0.5743665, 0.64616233],
        ];
        for (row, want) in expected.iter().enumerate() {
            assert_eq!(normalized.value(row), veclit_to_binlit(want).as_slice());
        }

        // The trailing NULL input must stay NULL.
        assert!(normalized.is_null(expected.len()));
    }
}
165}