common_function/scalars/vector/
vector_norm.rs1use std::fmt::Display;
16
17use datafusion::arrow::datatypes::DataType;
18use datafusion::logical_expr::ColumnarValue;
19use datafusion::logical_expr_common::type_coercion::aggregates::{BINARYS, STRINGS};
20use datafusion_common::ScalarValue;
21use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
22use nalgebra::DVectorView;
23
24use crate::function::Function;
25use crate::scalars::vector::VectorCalculator;
26use crate::scalars::vector::impl_conv::{as_veclit, veclit_to_binlit};
27
/// Name this function reports through `Function::name`; its `Display` form is the
/// upper-cased spelling of the same string.
const NAME: &str = "vec_norm";
29
30#[derive(Debug, Clone)]
46pub(crate) struct VectorNormFunction {
47    signature: Signature,
48}
49
50impl Default for VectorNormFunction {
51    fn default() -> Self {
52        Self {
53            signature: Signature::one_of(
54                vec![
55                    TypeSignature::Uniform(1, STRINGS.to_vec()),
56                    TypeSignature::Uniform(1, BINARYS.to_vec()),
57                    TypeSignature::Uniform(1, vec![DataType::BinaryView]),
58                ],
59                Volatility::Immutable,
60            ),
61        }
62    }
63}
64
65impl Function for VectorNormFunction {
66    fn name(&self) -> &str {
67        NAME
68    }
69
70    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
71        Ok(DataType::BinaryView)
72    }
73
74    fn signature(&self) -> &Signature {
75        &self.signature
76    }
77
78    fn invoke_with_args(
79        &self,
80        args: ScalarFunctionArgs,
81    ) -> datafusion_common::Result<ColumnarValue> {
82        let body = |v0: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
83            let v0 = as_veclit(v0)?;
84            let Some(v0) = v0 else {
85                return Ok(ScalarValue::BinaryView(None));
86            };
87
88            let v0 = DVectorView::from_slice(&v0, v0.len());
89            let result =
90                veclit_to_binlit(v0.unscale(v0.component_mul(&v0).sum().sqrt()).as_slice());
91            Ok(ScalarValue::BinaryView(Some(result)))
92        };
93
94        let calculator = VectorCalculator {
95            name: self.name(),
96            func: body,
97        };
98        calculator.invoke_with_single_argument(args)
99    }
100}
101
102impl Display for VectorNormFunction {
103    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104        write!(f, "{}", NAME.to_ascii_uppercase())
105    }
106}
107
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion::arrow::array::{Array, AsArray, StringViewArray};
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Normalizes several string vector literals (one with a negative component)
    /// plus a NULL row, and compares each encoded output against precomputed
    /// unit vectors.
    #[test]
    fn test_vec_norm() {
        let func = VectorNormFunction::default();

        let input: Vec<Option<String>> = [
            Some("[0.0,2.0,3.0]"),
            Some("[1.0,2.0,3.0]"),
            Some("[7.0,8.0,9.0]"),
            Some("[7.0,-8.0,9.0]"),
            None,
        ]
        .into_iter()
        .map(|row| row.map(str::to_string))
        .collect();
        let row_count = input.len();

        // Expected unit vectors, row by row; `None` marks a NULL result.
        let expected: [Option<&[f32]>; 5] = [
            Some(&[0.0, 0.5547002, 0.8320503][..]),
            Some(&[0.26726124, 0.5345225, 0.8017837][..]),
            Some(&[0.5025707, 0.5743665, 0.64616233][..]),
            Some(&[0.5025707, -0.5743665, 0.64616233][..]),
            None,
        ];

        let args = ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(Arc::new(StringViewArray::from(input)))],
            arg_fields: vec![],
            number_rows: row_count,
            return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
            config_options: Arc::new(ConfigOptions::new()),
        };

        let output = func
            .invoke_with_args(args)
            .and_then(|value| value.to_array(row_count))
            .unwrap();
        let output = output.as_binary_view();
        assert_eq!(output.len(), expected.len());

        for (row, want) in expected.iter().enumerate() {
            match want {
                Some(components) => assert_eq!(
                    output.value(row),
                    veclit_to_binlit(components).as_slice(),
                    "unexpected normalized vector at row {row}"
                ),
                None => assert!(output.is_null(row), "row {row} should be NULL"),
            }
        }
    }
}