common_function/scalars/vector/
vector_div.rs1use std::borrow::Cow;
16use std::fmt::Display;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use datafusion_expr::Signature;
20use datatypes::arrow::datatypes::DataType;
21use datatypes::scalars::ScalarVectorBuilder;
22use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
23use nalgebra::DVectorView;
24use snafu::ensure;
25
26use crate::function::{Function, FunctionContext};
27use crate::helper;
28use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
29
/// Name of the function; returned by `Function::name` and upper-cased by the
/// `Display` implementation.
const NAME: &str = "vec_div";
31
/// Scalar function `vec_div`: divides two vectors component by component.
///
/// Each argument may be a `Utf8` or a `Binary` column, and the result is a
/// `Binary` column. The two vectors of a row must have the same number of
/// elements, otherwise evaluation fails with an invalid-arguments error.
/// A row whose left or right operand is missing yields a null result.
#[derive(Debug, Clone, Default)]
pub struct VectorDivFunction;
48
49impl Function for VectorDivFunction {
50 fn name(&self) -> &str {
51 NAME
52 }
53
54 fn return_type(&self, _: &[DataType]) -> Result<DataType> {
55 Ok(DataType::Binary)
56 }
57
58 fn signature(&self) -> Signature {
59 helper::one_of_sigs2(
60 vec![DataType::Utf8, DataType::Binary],
61 vec![DataType::Utf8, DataType::Binary],
62 )
63 }
64
65 fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
66 ensure!(
67 columns.len() == 2,
68 InvalidFuncArgsSnafu {
69 err_msg: format!(
70 "The length of the args is not correct, expect exactly two, have: {}",
71 columns.len()
72 ),
73 }
74 );
75
76 let arg0 = &columns[0];
77 let arg1 = &columns[1];
78
79 let len = arg0.len();
80 let mut result = BinaryVectorBuilder::with_capacity(len);
81 if len == 0 {
82 return Ok(result.to_vector());
83 }
84
85 let arg0_const = as_veclit_if_const(arg0)?;
86 let arg1_const = as_veclit_if_const(arg1)?;
87
88 for i in 0..len {
89 let arg0 = match arg0_const.as_ref() {
90 Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
91 None => as_veclit(arg0.get_ref(i))?,
92 };
93
94 let arg1 = match arg1_const.as_ref() {
95 Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())),
96 None => as_veclit(arg1.get_ref(i))?,
97 };
98
99 if let (Some(arg0), Some(arg1)) = (arg0, arg1) {
100 ensure!(
101 arg0.len() == arg1.len(),
102 InvalidFuncArgsSnafu {
103 err_msg: format!(
104 "The length of the vectors must match for division, have: {} vs {}",
105 arg0.len(),
106 arg1.len()
107 ),
108 }
109 );
110 let vec0 = DVectorView::from_slice(&arg0, arg0.len());
111 let vec1 = DVectorView::from_slice(&arg1, arg1.len());
112 let vec_res = vec0.component_div(&vec1);
113
114 let veclit = vec_res.as_slice();
115 let binlit = veclit_to_binlit(veclit);
116 result.push(Some(&binlit));
117 } else {
118 result.push_null();
119 }
120 }
121
122 Ok(result.to_vector())
123 }
124}
125
126impl Display for VectorDivFunction {
127 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
128 write!(f, "{}", NAME.to_ascii_uppercase())
129 }
130}
131
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_query::error;
    use datatypes::vectors::StringVector;

    use super::*;

    /// Operands with a different number of elements must be rejected with the
    /// dedicated length-mismatch message.
    #[test]
    fn test_vector_div_length_mismatch() {
        let func = VectorDivFunction;

        let vec0 = vec![1.0, 2.0, 3.0];
        let vec1 = vec![1.0, 1.0];
        let (len0, len1) = (vec0.len(), vec1.len());
        let input0 = Arc::new(StringVector::from(vec![Some(format!("{vec0:?}"))]));
        let input1 = Arc::new(StringVector::from(vec![Some(format!("{vec1:?}"))]));

        let err = func
            .eval(&FunctionContext::default(), &[input0, input1])
            .unwrap_err();

        match err {
            error::Error::InvalidFuncArgs { err_msg, .. } => {
                assert_eq!(
                    err_msg,
                    format!(
                        "The length of the vectors must match for division, have: {} vs {}",
                        len0, len1
                    )
                )
            }
            // Report what was actually returned instead of a bare `unreachable!()`.
            other => panic!("expected InvalidFuncArgs, got: {other:?}"),
        }
    }

    /// Component-wise division of matching vectors; a row is null as soon as
    /// one of its operands is null.
    #[test]
    fn test_vector_div() {
        let func = VectorDivFunction;

        let input0 = Arc::new(StringVector::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[8.0,10.0,12.0]".to_string()),
            Some("[7.0,8.0,9.0]".to_string()),
            None,
        ]));

        let input1 = Arc::new(StringVector::from(vec![
            Some("[1.0,1.0,1.0]".to_string()),
            Some("[2.0,2.0,2.0]".to_string()),
            None,
            Some("[3.0,3.0,3.0]".to_string()),
        ]));

        let result = func
            .eval(&FunctionContext::default(), &[input0, input1])
            .unwrap();

        let result = result.as_ref();
        assert_eq!(result.len(), 4);
        assert_eq!(
            result.get_ref(0).as_binary().unwrap(),
            Some(veclit_to_binlit(&[1.0, 2.0, 3.0]).as_slice())
        );
        assert_eq!(
            result.get_ref(1).as_binary().unwrap(),
            Some(veclit_to_binlit(&[4.0, 5.0, 6.0]).as_slice())
        );
        assert!(result.get_ref(2).is_null());
        assert!(result.get_ref(3).is_null());
    }

    /// Dividing by zero is not an error: the result carries signed infinities.
    #[test]
    fn test_vector_div_by_zero() {
        let func = VectorDivFunction;

        let input0 = Arc::new(StringVector::from(vec![Some("[1.0,-2.0]".to_string())]));
        let input1 = Arc::new(StringVector::from(vec![Some("[0.0,0.0]".to_string())]));

        let result = func
            .eval(&FunctionContext::default(), &[input0, input1])
            .unwrap();

        let result = result.as_ref();
        assert_eq!(
            result.get_ref(0).as_binary().unwrap(),
            Some(veclit_to_binlit(&[f32::INFINITY, f32::NEG_INFINITY]).as_slice())
        );
    }
}