common_function/scalars/vector/
vector_norm.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Display;
16
17use datafusion::arrow::datatypes::DataType;
18use datafusion::logical_expr::ColumnarValue;
19use datafusion::logical_expr_common::type_coercion::aggregates::{BINARYS, STRINGS};
20use datafusion_common::ScalarValue;
21use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
22use nalgebra::DVectorView;
23
24use crate::function::Function;
25use crate::scalars::vector::VectorCalculator;
26use crate::scalars::vector::impl_conv::{as_veclit, veclit_to_binlit};
27
28const NAME: &str = "vec_norm";
29
/// Scalar function `vec_norm`: normalizes a vector to length 1 and returns the resulting vector.
///
/// Every component is divided by the square root of the sum of squared components, so this
/// is equivalent to `VECTOR_SCALAR_MUL(1/SQRT(VECTOR_ELEM_SUM(VECTOR_MUL(v, v))), v)`.
/// A NULL input produces a NULL output.
///
/// NOTE(review): for an all-zero vector the divisor is 0, so the components are the result
/// of a division by zero — confirm whether that is the intended behavior.
///
/// # Example
///
/// The components below are the ones asserted by this file's unit test for the same input;
/// the exact text rendering depends on `vec_to_string`.
///
/// ```sql
/// SELECT vec_to_string(vec_norm('[7.0, 8.0, 9.0]'));
///
/// +--------------------------------------------------+
/// | vec_to_string(vec_norm(Utf8("[7.0, 8.0, 9.0]"))) |
/// +--------------------------------------------------+
/// | [0.5025707,0.5743665,0.64616233]                 |
/// +--------------------------------------------------+
///
/// ```
#[derive(Debug, Clone)]
pub(crate) struct VectorNormFunction {
    /// Accepted argument types; populated by the `Default` implementation.
    signature: Signature,
}
49
50impl Default for VectorNormFunction {
51    fn default() -> Self {
52        Self {
53            signature: Signature::one_of(
54                vec![
55                    TypeSignature::Uniform(1, STRINGS.to_vec()),
56                    TypeSignature::Uniform(1, BINARYS.to_vec()),
57                    TypeSignature::Uniform(1, vec![DataType::BinaryView]),
58                ],
59                Volatility::Immutable,
60            ),
61        }
62    }
63}
64
65impl Function for VectorNormFunction {
66    fn name(&self) -> &str {
67        NAME
68    }
69
70    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
71        Ok(DataType::BinaryView)
72    }
73
74    fn signature(&self) -> &Signature {
75        &self.signature
76    }
77
78    fn invoke_with_args(
79        &self,
80        args: ScalarFunctionArgs,
81    ) -> datafusion_common::Result<ColumnarValue> {
82        let body = |v0: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
83            let v0 = as_veclit(v0)?;
84            let Some(v0) = v0 else {
85                return Ok(ScalarValue::BinaryView(None));
86            };
87
88            let v0 = DVectorView::from_slice(&v0, v0.len());
89            let result =
90                veclit_to_binlit(v0.unscale(v0.component_mul(&v0).sum().sqrt()).as_slice());
91            Ok(ScalarValue::BinaryView(Some(result)))
92        };
93
94        let calculator = VectorCalculator {
95            name: self.name(),
96            func: body,
97        };
98        calculator.invoke_with_single_argument(args)
99    }
100}
101
102impl Display for VectorNormFunction {
103    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104        write!(f, "{}", NAME.to_ascii_uppercase())
105    }
106}
107
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion::arrow::array::{Array, AsArray, StringViewArray};
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Runs `vec_norm` over a string-view column of four vector literals plus one NULL and
    /// checks every output row against its expected encoded vector (or NULL).
    #[test]
    fn test_vec_norm() {
        let literals = vec![
            Some("[0.0,2.0,3.0]".to_string()),
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[7.0,8.0,9.0]".to_string()),
            Some("[7.0,-8.0,9.0]".to_string()),
            None,
        ];
        let row_count = literals.len();
        let input0 = Arc::new(StringViewArray::from(literals));

        // Expected output per input row; `None` marks the row that must be NULL.
        let expected = [
            Some(veclit_to_binlit(&[0.0, 0.5547002, 0.8320503])),
            Some(veclit_to_binlit(&[0.26726124, 0.5345225, 0.8017837])),
            Some(veclit_to_binlit(&[0.5025707, 0.5743665, 0.64616233])),
            Some(veclit_to_binlit(&[0.5025707, -0.5743665, 0.64616233])),
            None,
        ];

        let function = VectorNormFunction::default();
        let call_args = ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(input0)],
            arg_fields: vec![],
            number_rows: row_count,
            return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
            config_options: Arc::new(ConfigOptions::new()),
        };

        let columnar = function.invoke_with_args(call_args).unwrap();
        let array = columnar.to_array(row_count).unwrap();
        let output = array.as_binary_view();

        assert_eq!(output.len(), 5);
        for (row, want) in expected.iter().enumerate() {
            match want {
                Some(bytes) => assert_eq!(output.value(row), bytes.as_slice()),
                None => assert!(output.is_null(row)),
            }
        }
    }
}