common_function/scalars/vector/
vector_norm.rs1use std::fmt::Display;
16
17use datafusion::arrow::datatypes::DataType;
18use datafusion::logical_expr::{Coercion, ColumnarValue, TypeSignatureClass};
19use datafusion_common::ScalarValue;
20use datafusion_common::types::{logical_binary, logical_string};
21use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
22use nalgebra::DVectorView;
23
24use crate::function::Function;
25use crate::scalars::vector::VectorCalculator;
26use crate::scalars::vector::impl_conv::{as_veclit, veclit_to_binlit};
27
/// Function name, returned by `Function::name`; `Display` prints it upper-cased.
const NAME: &str = "vec_norm";
29
30#[derive(Debug, Clone)]
46pub(crate) struct VectorNormFunction {
47 signature: Signature,
48}
49
50impl Default for VectorNormFunction {
51 fn default() -> Self {
52 Self {
53 signature: Signature::one_of(
54 vec![
55 TypeSignature::Coercible(vec![Coercion::new_exact(
56 TypeSignatureClass::Native(logical_binary()),
57 )]),
58 TypeSignature::Coercible(vec![Coercion::new_exact(
59 TypeSignatureClass::Native(logical_string()),
60 )]),
61 ],
62 Volatility::Immutable,
63 ),
64 }
65 }
66}
67
68impl Function for VectorNormFunction {
69 fn name(&self) -> &str {
70 NAME
71 }
72
73 fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
74 Ok(DataType::BinaryView)
75 }
76
77 fn signature(&self) -> &Signature {
78 &self.signature
79 }
80
81 fn invoke_with_args(
82 &self,
83 args: ScalarFunctionArgs,
84 ) -> datafusion_common::Result<ColumnarValue> {
85 let body = |v0: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
86 let v0 = as_veclit(v0)?;
87 let Some(v0) = v0 else {
88 return Ok(ScalarValue::BinaryView(None));
89 };
90
91 let v0 = DVectorView::from_slice(&v0, v0.len());
92 let result =
93 veclit_to_binlit(v0.unscale(v0.component_mul(&v0).sum().sqrt()).as_slice());
94 Ok(ScalarValue::BinaryView(Some(result)))
95 };
96
97 let calculator = VectorCalculator {
98 name: self.name(),
99 func: body,
100 };
101 calculator.invoke_with_single_argument(args)
102 }
103}
104
105impl Display for VectorNormFunction {
106 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
107 write!(f, "{}", NAME.to_ascii_uppercase())
108 }
109}
110
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion::arrow::array::{Array, AsArray, StringViewArray};
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Each non-NULL row must come back scaled to unit length with signs preserved,
    /// and a NULL row must stay NULL.
    #[test]
    fn test_vec_norm() {
        let func = VectorNormFunction::default();

        // (input literal, expected unit vector); `None` marks a NULL row.
        let cases: [(Option<&str>, Option<[f32; 3]>); 5] = [
            (Some("[0.0,2.0,3.0]"), Some([0.0, 0.5547002, 0.8320503])),
            (Some("[1.0,2.0,3.0]"), Some([0.26726124, 0.5345225, 0.8017837])),
            (Some("[7.0,8.0,9.0]"), Some([0.5025707, 0.5743665, 0.64616233])),
            (Some("[7.0,-8.0,9.0]"), Some([0.5025707, -0.5743665, 0.64616233])),
            (None, None),
        ];
        let row_count = cases.len();

        let input = Arc::new(StringViewArray::from(
            cases
                .iter()
                .map(|(text, _)| text.map(str::to_string))
                .collect::<Vec<_>>(),
        ));

        let args = ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(input)],
            arg_fields: vec![],
            number_rows: row_count,
            return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
            config_options: Arc::new(ConfigOptions::new()),
        };
        let output = func
            .invoke_with_args(args)
            .and_then(|value| value.to_array(row_count))
            .unwrap();
        let output = output.as_binary_view();

        assert_eq!(output.len(), row_count);
        for (row, (_, expected)) in cases.iter().enumerate() {
            match expected {
                Some(unit) => assert_eq!(
                    output.value(row),
                    veclit_to_binlit(unit).as_slice(),
                    "row {row}"
                ),
                None => assert!(output.is_null(row), "row {row}"),
            }
        }
    }
}