common_function/scalars/vector/
vector_div.rs1use std::borrow::Cow;
16use std::fmt::Display;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use common_query::prelude::Signature;
20use datatypes::prelude::ConcreteDataType;
21use datatypes::scalars::ScalarVectorBuilder;
22use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
23use nalgebra::DVectorView;
24use snafu::ensure;
25
26use crate::function::{Function, FunctionContext};
27use crate::helper;
28use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
29
/// Name returned by `Function::name` for [`VectorDivFunction`].
const NAME: &str = "vec_div";

/// Scalar function `vec_div(a, b)`: element-wise division of two vectors.
///
/// Each argument may be a string vector literal such as `'[1.0,2.0,3.0]'` or a
/// binary-encoded vector. The result is the binary-encoded vector
/// `[a0 / b0, a1 / b1, ...]`.
///
/// * If either argument of a row is null (decodes to no vector), that output row is null.
/// * If the two vectors of a row differ in length, evaluation fails with an
///   invalid-argument error.
/// * Division by zero follows plain `f32` semantics, e.g. `1.0 / 0.0` is `inf`
///   and `-2.0 / 0.0` is `-inf`.
///
/// # Example
///
/// ```sql
/// SELECT vec_div('[2.0,4.0,6.0]', '[2.0,2.0,2.0]');
/// -- binary-encoded vector [1.0, 2.0, 3.0]
/// ```
#[derive(Debug, Clone, Default)]
pub struct VectorDivFunction;
48
49impl Function for VectorDivFunction {
50 fn name(&self) -> &str {
51 NAME
52 }
53
54 fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
55 Ok(ConcreteDataType::binary_datatype())
56 }
57
58 fn signature(&self) -> Signature {
59 helper::one_of_sigs2(
60 vec![
61 ConcreteDataType::string_datatype(),
62 ConcreteDataType::binary_datatype(),
63 ],
64 vec![
65 ConcreteDataType::string_datatype(),
66 ConcreteDataType::binary_datatype(),
67 ],
68 )
69 }
70
71 fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
72 ensure!(
73 columns.len() == 2,
74 InvalidFuncArgsSnafu {
75 err_msg: format!(
76 "The length of the args is not correct, expect exactly two, have: {}",
77 columns.len()
78 ),
79 }
80 );
81
82 let arg0 = &columns[0];
83 let arg1 = &columns[1];
84
85 let len = arg0.len();
86 let mut result = BinaryVectorBuilder::with_capacity(len);
87 if len == 0 {
88 return Ok(result.to_vector());
89 }
90
91 let arg0_const = as_veclit_if_const(arg0)?;
92 let arg1_const = as_veclit_if_const(arg1)?;
93
94 for i in 0..len {
95 let arg0 = match arg0_const.as_ref() {
96 Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
97 None => as_veclit(arg0.get_ref(i))?,
98 };
99
100 let arg1 = match arg1_const.as_ref() {
101 Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())),
102 None => as_veclit(arg1.get_ref(i))?,
103 };
104
105 if let (Some(arg0), Some(arg1)) = (arg0, arg1) {
106 ensure!(
107 arg0.len() == arg1.len(),
108 InvalidFuncArgsSnafu {
109 err_msg: format!(
110 "The length of the vectors must match for division, have: {} vs {}",
111 arg0.len(),
112 arg1.len()
113 ),
114 }
115 );
116 let vec0 = DVectorView::from_slice(&arg0, arg0.len());
117 let vec1 = DVectorView::from_slice(&arg1, arg1.len());
118 let vec_res = vec0.component_div(&vec1);
119
120 let veclit = vec_res.as_slice();
121 let binlit = veclit_to_binlit(veclit);
122 result.push(Some(&binlit));
123 } else {
124 result.push_null();
125 }
126 }
127
128 Ok(result.to_vector())
129 }
130}
131
132impl Display for VectorDivFunction {
133 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
134 write!(f, "{}", NAME.to_ascii_uppercase())
135 }
136}
137
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_query::error;
    use datatypes::vectors::StringVector;

    use super::*;

    /// Covers length mismatch, null propagation, regular division and division by zero.
    #[test]
    fn test_vector_div() {
        let func = VectorDivFunction;

        // Vectors of different lengths must be rejected with a descriptive error.
        let vec0 = vec![1.0, 2.0, 3.0];
        let vec1 = vec![1.0, 1.0];
        let (len0, len1) = (vec0.len(), vec1.len());
        let input0 = Arc::new(StringVector::from(vec![Some(format!("{vec0:?}"))]));
        let input1 = Arc::new(StringVector::from(vec![Some(format!("{vec1:?}"))]));

        let err = func
            .eval(&FunctionContext::default(), &[input0, input1])
            .unwrap_err();

        match err {
            error::Error::InvalidFuncArgs { err_msg, .. } => {
                assert_eq!(
                    err_msg,
                    format!(
                        "The length of the vectors must match for division, have: {} vs {}",
                        len0, len1
                    )
                )
            }
            other => panic!("unexpected error: {other:?}"),
        }

        // Row-wise division; a null on either side yields a null row.
        let input0 = Arc::new(StringVector::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[8.0,10.0,12.0]".to_string()),
            Some("[7.0,8.0,9.0]".to_string()),
            None,
        ]));

        let input1 = Arc::new(StringVector::from(vec![
            Some("[1.0,1.0,1.0]".to_string()),
            Some("[2.0,2.0,2.0]".to_string()),
            None,
            Some("[3.0,3.0,3.0]".to_string()),
        ]));

        let result = func
            .eval(&FunctionContext::default(), &[input0, input1])
            .unwrap();

        let result = result.as_ref();
        assert_eq!(result.len(), 4);
        assert_eq!(
            result.get_ref(0).as_binary().unwrap(),
            Some(veclit_to_binlit(&[1.0, 2.0, 3.0]).as_slice())
        );
        assert_eq!(
            result.get_ref(1).as_binary().unwrap(),
            Some(veclit_to_binlit(&[4.0, 5.0, 6.0]).as_slice())
        );
        assert!(result.get_ref(2).is_null());
        assert!(result.get_ref(3).is_null());

        // Division by zero follows f32 semantics: signed infinities, no error.
        let input0 = Arc::new(StringVector::from(vec![Some("[1.0,-2.0]".to_string())]));
        let input1 = Arc::new(StringVector::from(vec![Some("[0.0,0.0]".to_string())]));

        let result = func
            .eval(&FunctionContext::default(), &[input0, input1])
            .unwrap();

        let result = result.as_ref();
        assert_eq!(result.len(), 1);
        assert_eq!(
            result.get_ref(0).as_binary().unwrap(),
            Some(veclit_to_binlit(&[f32::INFINITY, f32::NEG_INFINITY]).as_slice())
        );
    }
}