common_function/scalars/vector/
vector_sub.rs1use std::borrow::Cow;
16use std::fmt::Display;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use datafusion_expr::Signature;
20use datatypes::arrow::datatypes::DataType;
21use datatypes::scalars::ScalarVectorBuilder;
22use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
23use nalgebra::DVectorView;
24use snafu::ensure;
25
26use crate::function::{Function, FunctionContext};
27use crate::helper;
28use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
29
/// Function name returned by `VectorSubFunction::name`; the `Display` impl prints
/// its ASCII-upper-cased form.
30const NAME: &str = "vec_sub";
31
/// Scalar function `vec_sub`: subtracts the second vector argument from the first,
/// element by element.
///
/// The declared return type is `Binary`. A row evaluates to NULL when either operand
/// of that row yields no vector (for example a NULL input).
32#[derive(Debug, Clone, Default)]
46pub struct VectorSubFunction;
47
48impl Function for VectorSubFunction {
49 fn name(&self) -> &str {
50 NAME
51 }
52
53 fn return_type(&self, _: &[DataType]) -> Result<DataType> {
54 Ok(DataType::Binary)
55 }
56
57 fn signature(&self) -> Signature {
58 helper::one_of_sigs2(
59 vec![DataType::Utf8, DataType::Binary],
60 vec![DataType::Utf8, DataType::Binary],
61 )
62 }
63
64 fn eval(
65 &self,
66 _func_ctx: &FunctionContext,
67 columns: &[VectorRef],
68 ) -> common_query::error::Result<VectorRef> {
69 ensure!(
70 columns.len() == 2,
71 InvalidFuncArgsSnafu {
72 err_msg: format!(
73 "The length of the args is not correct, expect exactly two, have: {}",
74 columns.len()
75 )
76 }
77 );
78 let arg0 = &columns[0];
79 let arg1 = &columns[1];
80
81 ensure!(
82 arg0.len() == arg1.len(),
83 InvalidFuncArgsSnafu {
84 err_msg: format!(
85 "The lengths of the vector are not aligned, args 0: {}, args 1: {}",
86 arg0.len(),
87 arg1.len(),
88 )
89 }
90 );
91
92 let len = arg0.len();
93 let mut result = BinaryVectorBuilder::with_capacity(len);
94 if len == 0 {
95 return Ok(result.to_vector());
96 }
97
98 let arg0_const = as_veclit_if_const(arg0)?;
99 let arg1_const = as_veclit_if_const(arg1)?;
100
101 for i in 0..len {
102 let arg0 = match arg0_const.as_ref() {
103 Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
104 None => as_veclit(arg0.get_ref(i))?,
105 };
106 let arg1 = match arg1_const.as_ref() {
107 Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())),
108 None => as_veclit(arg1.get_ref(i))?,
109 };
110 let (Some(arg0), Some(arg1)) = (arg0, arg1) else {
111 result.push_null();
112 continue;
113 };
114 let vec0 = DVectorView::from_slice(&arg0, arg0.len());
115 let vec1 = DVectorView::from_slice(&arg1, arg1.len());
116
117 let vec_res = vec0 - vec1;
118 let veclit = vec_res.as_slice();
119 let binlit = veclit_to_binlit(veclit);
120 result.push(Some(&binlit));
121 }
122
123 Ok(result.to_vector())
124 }
125}
126
127impl Display for VectorSubFunction {
128 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
129 write!(f, "{}", NAME.to_ascii_uppercase())
130 }
131}
132
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_query::error::Error;
    use datatypes::vectors::StringVector;

    use super::*;

    /// Builds a nullable string column from textual vector literals.
    fn string_column(rows: &[Option<&str>]) -> Arc<StringVector> {
        let owned: Vec<Option<String>> = rows
            .iter()
            .copied()
            .map(|row| row.map(String::from))
            .collect();
        Arc::new(StringVector::from(owned))
    }

    /// Rows with two present operands are subtracted; rows with a missing operand
    /// come back as NULL.
    #[test]
    fn test_sub() {
        let lhs = string_column(&[
            Some("[1.0,2.0,3.0]"),
            Some("[4.0,5.0,6.0]"),
            None,
            Some("[2.0,3.0,3.0]"),
        ]);
        let rhs = string_column(&[
            Some("[1.0,1.0,1.0]"),
            Some("[6.0,5.0,4.0]"),
            Some("[3.0,2.0,2.0]"),
            None,
        ]);

        let output = VectorSubFunction
            .eval(&FunctionContext::default(), &[lhs, rhs])
            .unwrap();

        assert_eq!(output.len(), 4);

        let expected_rows = [
            veclit_to_binlit(&[0.0, 1.0, 2.0]),
            veclit_to_binlit(&[-2.0, 0.0, 2.0]),
        ];
        for (row, expected) in expected_rows.iter().enumerate() {
            assert_eq!(
                output.get_ref(row).as_binary().unwrap(),
                Some(expected.as_slice())
            );
        }
        for row in [2, 3] {
            assert!(output.get_ref(row).is_null());
        }
    }

    /// Columns with different row counts are rejected with `InvalidFuncArgs`.
    #[test]
    fn test_sub_error() {
        let lhs = string_column(&[
            Some("[1.0,2.0,3.0]"),
            Some("[4.0,5.0,6.0]"),
            None,
            Some("[2.0,3.0,3.0]"),
        ]);
        let rhs = string_column(&[
            Some("[1.0,1.0,1.0]"),
            Some("[6.0,5.0,4.0]"),
            Some("[3.0,2.0,2.0]"),
        ]);

        let outcome = VectorSubFunction.eval(&FunctionContext::default(), &[lhs, rhs]);

        let Err(Error::InvalidFuncArgs { err_msg, .. }) = outcome else {
            unreachable!()
        };
        assert_eq!(
            err_msg,
            "The lengths of the vector are not aligned, args 0: 4, args 1: 3"
        );
    }
}