common_function/scalars/vector/
vector_norm.rs1use std::fmt::Display;
16
17use datafusion::arrow::datatypes::DataType;
18use datafusion::logical_expr::ColumnarValue;
19use datafusion::logical_expr_common::type_coercion::aggregates::{BINARYS, STRINGS};
20use datafusion_common::ScalarValue;
21use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
22use nalgebra::DVectorView;
23
24use crate::function::Function;
25use crate::scalars::vector::VectorCalculator;
26use crate::scalars::vector::impl_conv::{as_veclit, veclit_to_binlit};
27
/// Name of this scalar function: returned by `Function::name` and, upper-cased,
/// used by the `Display` impl.
const NAME: &str = "vec_norm";
29
/// `vec_norm(vector)`: divides a vector by its L2 norm (square root of the sum of
/// squared components).
///
/// The single argument may be a string vector literal (e.g. `'[1.0,2.0,3.0]'`) or a
/// binary-encoded vector; the result is always returned as a `BinaryView` vector.
#[derive(Debug, Clone)]
pub(crate) struct VectorNormFunction {
    /// Accepted argument types (one string, binary, or `BinaryView` argument),
    /// constructed in `Default::default`.
    signature: Signature,
}
49
50impl Default for VectorNormFunction {
51 fn default() -> Self {
52 Self {
53 signature: Signature::one_of(
54 vec![
55 TypeSignature::Uniform(1, STRINGS.to_vec()),
56 TypeSignature::Uniform(1, BINARYS.to_vec()),
57 TypeSignature::Uniform(1, vec![DataType::BinaryView]),
58 ],
59 Volatility::Immutable,
60 ),
61 }
62 }
63}
64
65impl Function for VectorNormFunction {
66 fn name(&self) -> &str {
67 NAME
68 }
69
70 fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
71 Ok(DataType::BinaryView)
72 }
73
74 fn signature(&self) -> &Signature {
75 &self.signature
76 }
77
78 fn invoke_with_args(
79 &self,
80 args: ScalarFunctionArgs,
81 ) -> datafusion_common::Result<ColumnarValue> {
82 let body = |v0: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
83 let v0 = as_veclit(v0)?;
84 let Some(v0) = v0 else {
85 return Ok(ScalarValue::BinaryView(None));
86 };
87
88 let v0 = DVectorView::from_slice(&v0, v0.len());
89 let result =
90 veclit_to_binlit(v0.unscale(v0.component_mul(&v0).sum().sqrt()).as_slice());
91 Ok(ScalarValue::BinaryView(Some(result)))
92 };
93
94 let calculator = VectorCalculator {
95 name: self.name(),
96 func: body,
97 };
98 calculator.invoke_with_single_argument(args)
99 }
100}
101
102impl Display for VectorNormFunction {
103 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104 write!(f, "{}", NAME.to_ascii_uppercase())
105 }
106}
107
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion::arrow::array::{Array, AsArray, StringViewArray};
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Normalizes four string vector literals plus a NULL. Each non-NULL row is
    /// checked against its precomputed unit vector; the NULL must stay NULL.
    #[test]
    fn test_vec_norm() {
        let func = VectorNormFunction::default();

        let literals = [
            Some("[0.0,2.0,3.0]"),
            Some("[1.0,2.0,3.0]"),
            Some("[7.0,8.0,9.0]"),
            Some("[7.0,-8.0,9.0]"),
            None,
        ];
        let row_count = literals.len();
        let column = StringViewArray::from(
            literals
                .iter()
                .map(|lit| lit.map(String::from))
                .collect::<Vec<Option<String>>>(),
        );

        let args = ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(Arc::new(column))],
            arg_fields: vec![],
            number_rows: row_count,
            return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
            config_options: Arc::new(ConfigOptions::new()),
        };
        let output = func
            .invoke_with_args(args)
            .and_then(|value| value.to_array(row_count))
            .unwrap();

        let output = output.as_binary_view();
        assert_eq!(output.len(), row_count);

        // Expected unit vectors for the non-NULL rows, in input order.
        let expected: [[f32; 3]; 4] = [
            [0.0, 0.5547002, 0.8320503],
            [0.26726124, 0.5345225, 0.8017837],
            [0.5025707, 0.5743665, 0.64616233],
            [0.5025707, -0.5743665, 0.64616233],
        ];
        for (row, unit) in expected.iter().enumerate() {
            assert_eq!(
                output.value(row),
                veclit_to_binlit(unit).as_slice(),
                "row {row}"
            );
        }
        assert!(output.is_null(4));
    }
}