common_function/scalars/vector/
vector_add.rs1use std::borrow::Cow;
16use std::fmt::Display;
17
18use common_query::error::InvalidFuncArgsSnafu;
19use common_query::prelude::Signature;
20use datatypes::prelude::ConcreteDataType;
21use datatypes::scalars::ScalarVectorBuilder;
22use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
23use nalgebra::DVectorView;
24use snafu::ensure;
25
26use crate::function::{Function, FunctionContext};
27use crate::helper;
28use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
29
30const NAME: &str = "vec_add";
31
/// Scalar function `vec_add`: adds two vectors element by element.
///
/// The type carries no state; all behavior lives in its `Function`
/// implementation.
#[derive(Clone, Debug, Default)]
pub struct VectorAddFunction;
47
48impl Function for VectorAddFunction {
49 fn name(&self) -> &str {
50 NAME
51 }
52
53 fn return_type(
54 &self,
55 _input_types: &[ConcreteDataType],
56 ) -> common_query::error::Result<ConcreteDataType> {
57 Ok(ConcreteDataType::binary_datatype())
58 }
59
60 fn signature(&self) -> Signature {
61 helper::one_of_sigs2(
62 vec![
63 ConcreteDataType::string_datatype(),
64 ConcreteDataType::binary_datatype(),
65 ],
66 vec![
67 ConcreteDataType::string_datatype(),
68 ConcreteDataType::binary_datatype(),
69 ],
70 )
71 }
72
73 fn eval(
74 &self,
75 _func_ctx: &FunctionContext,
76 columns: &[VectorRef],
77 ) -> common_query::error::Result<VectorRef> {
78 ensure!(
79 columns.len() == 2,
80 InvalidFuncArgsSnafu {
81 err_msg: format!(
82 "The length of the args is not correct, expect exactly two, have: {}",
83 columns.len()
84 )
85 }
86 );
87 let arg0 = &columns[0];
88 let arg1 = &columns[1];
89
90 ensure!(
91 arg0.len() == arg1.len(),
92 InvalidFuncArgsSnafu {
93 err_msg: format!(
94 "The lengths of the vector are not aligned, args 0: {}, args 1: {}",
95 arg0.len(),
96 arg1.len(),
97 )
98 }
99 );
100
101 let len = arg0.len();
102 let mut result = BinaryVectorBuilder::with_capacity(len);
103 if len == 0 {
104 return Ok(result.to_vector());
105 }
106
107 let arg0_const = as_veclit_if_const(arg0)?;
108 let arg1_const = as_veclit_if_const(arg1)?;
109
110 for i in 0..len {
111 let arg0 = match arg0_const.as_ref() {
112 Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
113 None => as_veclit(arg0.get_ref(i))?,
114 };
115 let arg1 = match arg1_const.as_ref() {
116 Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())),
117 None => as_veclit(arg1.get_ref(i))?,
118 };
119 let (Some(arg0), Some(arg1)) = (arg0, arg1) else {
120 result.push_null();
121 continue;
122 };
123 let vec0 = DVectorView::from_slice(&arg0, arg0.len());
124 let vec1 = DVectorView::from_slice(&arg1, arg1.len());
125
126 let vec_res = vec0 + vec1;
127 let veclit = vec_res.as_slice();
128 let binlit = veclit_to_binlit(veclit);
129 result.push(Some(&binlit));
130 }
131
132 Ok(result.to_vector())
133 }
134}
135
136impl Display for VectorAddFunction {
137 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
138 write!(f, "{}", NAME.to_ascii_uppercase())
139 }
140}
141
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_query::error::Error;
    use datatypes::vectors::StringVector;

    use super::*;

    /// Row-wise addition over two string columns; a null on either side
    /// must produce a null result for that row.
    #[test]
    fn test_add() {
        let func = VectorAddFunction;

        let input0 = Arc::new(StringVector::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            None,
            Some("[2.0,3.0,3.0]".to_string()),
        ]));
        let input1 = Arc::new(StringVector::from(vec![
            Some("[1.0,1.0,1.0]".to_string()),
            Some("[6.0,5.0,4.0]".to_string()),
            Some("[3.0,2.0,2.0]".to_string()),
            None,
        ]));

        let result = func
            .eval(&FunctionContext::default(), &[input0, input1])
            .unwrap();

        let result = result.as_ref();
        assert_eq!(result.len(), 4);
        assert_eq!(
            result.get_ref(0).as_binary().unwrap(),
            Some(veclit_to_binlit(&[2.0, 3.0, 4.0]).as_slice())
        );
        assert_eq!(
            result.get_ref(1).as_binary().unwrap(),
            Some(veclit_to_binlit(&[10.0, 10.0, 10.0]).as_slice())
        );
        assert!(result.get_ref(2).is_null());
        assert!(result.get_ref(3).is_null());
    }

    /// Columns with a different number of rows must be rejected with
    /// `InvalidFuncArgs`.
    #[test]
    fn test_add_error() {
        let func = VectorAddFunction;

        let input0 = Arc::new(StringVector::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            None,
            Some("[2.0,3.0,3.0]".to_string()),
        ]));
        let input1 = Arc::new(StringVector::from(vec![
            Some("[1.0,1.0,1.0]".to_string()),
            Some("[6.0,5.0,4.0]".to_string()),
            Some("[3.0,2.0,2.0]".to_string()),
        ]));

        let result = func.eval(&FunctionContext::default(), &[input0, input1]);

        match result {
            Err(Error::InvalidFuncArgs { err_msg, .. }) => {
                assert_eq!(
                    err_msg,
                    "The lengths of the vector are not aligned, args 0: 4, args 1: 3"
                )
            }
            // Fail with a message that says what went wrong instead of a bare
            // `unreachable!()`.
            Err(other) => panic!("expected InvalidFuncArgs, got a different error: {other:?}"),
            Ok(_) => panic!("expected InvalidFuncArgs for mismatched column lengths, got Ok"),
        }
    }
}