// common_function/scalars/vector/vector_norm.rs
use std::borrow::Cow;
16use std::fmt::Display;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use datafusion::arrow::datatypes::DataType;
20use datafusion::logical_expr_common::type_coercion::aggregates::{BINARYS, STRINGS};
21use datafusion_expr::{Signature, TypeSignature, Volatility};
22use datatypes::scalars::ScalarVectorBuilder;
23use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
24use nalgebra::DVectorView;
25use snafu::ensure;
26
27use crate::function::{Function, FunctionContext};
28use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
29
/// Name of this scalar function: returned verbatim by `Function::name` and
/// upper-cased by the `Display` implementation.
const NAME: &str = "vec_norm";
31
/// Scalar function `vec_norm`: rescales a vector by its Euclidean (L2) norm.
///
/// Takes exactly one argument, a vector given as a string or binary value,
/// and returns the rescaled vector as a binary value. Each component is
/// divided by the square root of the sum of squared components. A null input
/// row yields a null output row.
///
/// # Example
///
/// The input `[7.0,8.0,9.0]` produces `[0.5025707, 0.5743665, 0.64616233]`.
#[derive(Debug, Clone, Default)]
pub struct VectorNormFunction;
49
50impl Function for VectorNormFunction {
51 fn name(&self) -> &str {
52 NAME
53 }
54
55 fn return_type(&self, _: &[DataType]) -> Result<DataType> {
56 Ok(DataType::Binary)
57 }
58
59 fn signature(&self) -> Signature {
60 Signature::one_of(
61 vec![
62 TypeSignature::Uniform(1, STRINGS.to_vec()),
63 TypeSignature::Uniform(1, BINARYS.to_vec()),
64 ],
65 Volatility::Immutable,
66 )
67 }
68
69 fn eval(
70 &self,
71 _func_ctx: &FunctionContext,
72 columns: &[VectorRef],
73 ) -> common_query::error::Result<VectorRef> {
74 ensure!(
75 columns.len() == 1,
76 InvalidFuncArgsSnafu {
77 err_msg: format!(
78 "The length of the args is not correct, expect exactly one, have: {}",
79 columns.len()
80 )
81 }
82 );
83 let arg0 = &columns[0];
84
85 let len = arg0.len();
86 let mut result = BinaryVectorBuilder::with_capacity(len);
87 if len == 0 {
88 return Ok(result.to_vector());
89 }
90
91 let arg0_const = as_veclit_if_const(arg0)?;
92
93 for i in 0..len {
94 let arg0 = match arg0_const.as_ref() {
95 Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
96 None => as_veclit(arg0.get_ref(i))?,
97 };
98 let Some(arg0) = arg0 else {
99 result.push_null();
100 continue;
101 };
102
103 let vec0 = DVectorView::from_slice(&arg0, arg0.len());
104 let vec1 = DVectorView::from_slice(&arg0, arg0.len());
105 let vec2scalar = vec1.component_mul(&vec0);
106 let scalar_var = vec2scalar.sum().sqrt();
107
108 let vec = DVectorView::from_slice(&arg0, arg0.len());
109 let vec_res = vec.unscale(scalar_var);
111
112 let veclit = vec_res.as_slice();
113 let binlit = veclit_to_binlit(veclit);
114 result.push(Some(&binlit));
115 }
116
117 Ok(result.to_vector())
118 }
119}
120
121impl Display for VectorNormFunction {
122 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123 write!(f, "{}", NAME.to_ascii_uppercase())
124 }
125}
126
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::vectors::StringVector;

    use super::*;

    /// Non-null rows come back rescaled to unit length; a null row stays null.
    #[test]
    fn test_vec_norm() {
        let func = VectorNormFunction;

        let input0 = Arc::new(StringVector::from(vec![
            Some("[0.0,2.0,3.0]".to_string()),
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[7.0,8.0,9.0]".to_string()),
            Some("[7.0,-8.0,9.0]".to_string()),
            None,
        ]));

        let result = func.eval(&FunctionContext::default(), &[input0]).unwrap();

        let result = result.as_ref();
        assert_eq!(result.len(), 5);
        assert_eq!(
            result.get_ref(0).as_binary().unwrap(),
            Some(veclit_to_binlit(&[0.0, 0.5547002, 0.8320503]).as_slice())
        );
        assert_eq!(
            result.get_ref(1).as_binary().unwrap(),
            Some(veclit_to_binlit(&[0.26726124, 0.5345225, 0.8017837]).as_slice())
        );
        assert_eq!(
            result.get_ref(2).as_binary().unwrap(),
            Some(veclit_to_binlit(&[0.5025707, 0.5743665, 0.64616233]).as_slice())
        );
        assert_eq!(
            result.get_ref(3).as_binary().unwrap(),
            Some(veclit_to_binlit(&[0.5025707, -0.5743665, 0.64616233]).as_slice())
        );
        assert!(result.get_ref(4).is_null());
    }

    /// Calling the function without its single argument is rejected.
    #[test]
    fn test_vec_norm_wrong_arg_count() {
        let func = VectorNormFunction;

        assert!(func.eval(&FunctionContext::default(), &[]).is_err());
    }
}