common_function/scalars/vector/
vector_norm.rs1use std::borrow::Cow;
16use std::fmt::Display;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use common_query::prelude::{Signature, TypeSignature, Volatility};
20use datatypes::prelude::ConcreteDataType;
21use datatypes::scalars::ScalarVectorBuilder;
22use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
23use nalgebra::DVectorView;
24use snafu::ensure;
25
26use crate::function::{Function, FunctionContext};
27use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
28
/// Function name reported by `VectorNormFunction::name`; its upper-cased form is
/// used for `Display`.
const NAME: &str = "vec_norm";

/// Scalar function `vec_norm`: rescales a vector to unit Euclidean (L2) length by
/// dividing every component by the vector's norm.
///
/// Takes one argument, a vector given either as a string literal such as
/// `[1.0,2.0,3.0]` or in binary form, and produces a binary vector literal.
/// For example, `[1.0,2.0,3.0]` becomes `[0.26726124, 0.5345225, 0.8017837]`.
/// Null inputs yield null outputs.
#[derive(Clone, Debug, Default)]
pub struct VectorNormFunction;
48
49impl Function for VectorNormFunction {
50 fn name(&self) -> &str {
51 NAME
52 }
53
54 fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
55 Ok(ConcreteDataType::binary_datatype())
56 }
57
58 fn signature(&self) -> Signature {
59 Signature::one_of(
60 vec![
61 TypeSignature::Exact(vec![ConcreteDataType::string_datatype()]),
62 TypeSignature::Exact(vec![ConcreteDataType::binary_datatype()]),
63 ],
64 Volatility::Immutable,
65 )
66 }
67
68 fn eval(
69 &self,
70 _func_ctx: &FunctionContext,
71 columns: &[VectorRef],
72 ) -> common_query::error::Result<VectorRef> {
73 ensure!(
74 columns.len() == 1,
75 InvalidFuncArgsSnafu {
76 err_msg: format!(
77 "The length of the args is not correct, expect exactly one, have: {}",
78 columns.len()
79 )
80 }
81 );
82 let arg0 = &columns[0];
83
84 let len = arg0.len();
85 let mut result = BinaryVectorBuilder::with_capacity(len);
86 if len == 0 {
87 return Ok(result.to_vector());
88 }
89
90 let arg0_const = as_veclit_if_const(arg0)?;
91
92 for i in 0..len {
93 let arg0 = match arg0_const.as_ref() {
94 Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
95 None => as_veclit(arg0.get_ref(i))?,
96 };
97 let Some(arg0) = arg0 else {
98 result.push_null();
99 continue;
100 };
101
102 let vec0 = DVectorView::from_slice(&arg0, arg0.len());
103 let vec1 = DVectorView::from_slice(&arg0, arg0.len());
104 let vec2scalar = vec1.component_mul(&vec0);
105 let scalar_var = vec2scalar.sum().sqrt();
106
107 let vec = DVectorView::from_slice(&arg0, arg0.len());
108 let vec_res = vec.unscale(scalar_var);
110
111 let veclit = vec_res.as_slice();
112 let binlit = veclit_to_binlit(veclit);
113 result.push(Some(&binlit));
114 }
115
116 Ok(result.to_vector())
117 }
118}
119
120impl Display for VectorNormFunction {
121 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122 write!(f, "{}", NAME.to_ascii_uppercase())
123 }
124}
125
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::vectors::StringVector;

    use super::*;

    /// Each non-null row is scaled to unit length; the null row stays null.
    #[test]
    fn test_vec_norm() {
        let input = Arc::new(StringVector::from(vec![
            Some("[0.0,2.0,3.0]".to_string()),
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[7.0,8.0,9.0]".to_string()),
            Some("[7.0,-8.0,9.0]".to_string()),
            None,
        ]));

        let output = VectorNormFunction
            .eval(&FunctionContext::default(), &[input])
            .unwrap();
        let output = output.as_ref();
        assert_eq!(output.len(), 5);

        // Expected unit vectors for the first four rows, in input order.
        let expected = [
            [0.0, 0.5547002, 0.8320503],
            [0.26726124, 0.5345225, 0.8017837],
            [0.5025707, 0.5743665, 0.64616233],
            [0.5025707, -0.5743665, 0.64616233],
        ];
        for (row, unit) in expected.iter().enumerate() {
            assert_eq!(
                output.get_ref(row).as_binary().unwrap(),
                Some(veclit_to_binlit(unit).as_slice()),
                "row {row}"
            );
        }

        assert!(output.get_ref(4).is_null());
    }
}