// common_function/scalars/vector/vector_div.rs
use std::borrow::Cow;
16
17use datafusion::arrow::datatypes::DataType;
18use datafusion::logical_expr::ColumnarValue;
19use datafusion_common::{DataFusionError, ScalarValue};
20use datafusion_expr::{ScalarFunctionArgs, Signature};
21use nalgebra::DVectorView;
22
23use crate::function::Function;
24use crate::scalars::vector::impl_conv::veclit_to_binlit;
25use crate::scalars::vector::{VectorCalculator, define_args_of_two_vector_literals_udf};
26
27const NAME: &str = "vec_div";
28
29define_args_of_two_vector_literals_udf!(
30VectorDivFunction);
46
47impl Function for VectorDivFunction {
48 fn name(&self) -> &str {
49 NAME
50 }
51
52 fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
53 Ok(DataType::BinaryView)
54 }
55
56 fn signature(&self) -> &Signature {
57 &self.signature
58 }
59
60 fn invoke_with_args(
61 &self,
62 args: ScalarFunctionArgs,
63 ) -> datafusion_common::Result<ColumnarValue> {
64 let body = |v0: &Option<Cow<[f32]>>,
65 v1: &Option<Cow<[f32]>>|
66 -> datafusion_common::Result<ScalarValue> {
67 let result = if let (Some(v0), Some(v1)) = (v0, v1) {
68 let v0 = DVectorView::from_slice(v0, v0.len());
69 let v1 = DVectorView::from_slice(v1, v1.len());
70 if v0.len() != v1.len() {
71 return Err(DataFusionError::Execution(format!(
72 "vectors length not match: {}",
73 self.name()
74 )));
75 }
76
77 let result = veclit_to_binlit((v0.component_div(&v1)).as_slice());
78 Some(result)
79 } else {
80 None
81 };
82 Ok(ScalarValue::BinaryView(result))
83 };
84
85 let calculator = VectorCalculator {
86 name: self.name(),
87 func: body,
88 };
89 calculator.invoke_with_vectors(args)
90 }
91}
92
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion::arrow::array::{Array, AsArray, StringViewArray};
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Builds `ScalarFunctionArgs` for `vec_div` from two string columns of
    /// vector literals.
    ///
    /// `number_rows` is taken from the column length, so it always matches the
    /// data actually passed in.
    fn div_args(lhs: Vec<Option<String>>, rhs: Vec<Option<String>>) -> ScalarFunctionArgs {
        assert_eq!(
            lhs.len(),
            rhs.len(),
            "test columns must have the same row count"
        );
        let number_rows = lhs.len();
        ScalarFunctionArgs {
            args: vec![
                ColumnarValue::Array(Arc::new(StringViewArray::from(lhs))),
                ColumnarValue::Array(Arc::new(StringViewArray::from(rhs))),
            ],
            arg_fields: vec![],
            number_rows,
            // Results may contain NULL rows, so the return field is nullable.
            return_field: Arc::new(Field::new("x", DataType::BinaryView, true)),
            config_options: Arc::new(ConfigOptions::new()),
        }
    }

    #[test]
    fn test_vector_div_length_mismatch() {
        let func = VectorDivFunction::default();

        let lhs = vec![1.0, 2.0, 3.0];
        let rhs = vec![1.0, 1.0];
        let args = div_args(
            vec![Some(format!("{lhs:?}"))],
            vec![Some(format!("{rhs:?}"))],
        );

        let e = func.invoke_with_args(args).unwrap_err();
        assert_eq!(
            e.to_string(),
            "Execution error: vectors length not match: vec_div"
        );
    }

    #[test]
    fn test_vector_div() {
        let func = VectorDivFunction::default();

        let args = div_args(
            vec![
                Some("[1.0,2.0,3.0]".to_string()),
                Some("[8.0,10.0,12.0]".to_string()),
                Some("[7.0,8.0,9.0]".to_string()),
                None,
            ],
            vec![
                Some("[1.0,1.0,1.0]".to_string()),
                Some("[2.0,2.0,2.0]".to_string()),
                None,
                Some("[3.0,3.0,3.0]".to_string()),
            ],
        );
        let result = func
            .invoke_with_args(args)
            .and_then(|x| x.to_array(4))
            .unwrap();

        let result = result.as_binary_view();
        assert_eq!(result.len(), 4);
        assert_eq!(
            result.value(0),
            veclit_to_binlit(&[1.0, 2.0, 3.0]).as_slice()
        );
        assert_eq!(
            result.value(1),
            veclit_to_binlit(&[4.0, 5.0, 6.0]).as_slice()
        );
        // A NULL on either side yields a NULL row.
        assert!(result.is_null(2));
        assert!(result.is_null(3));
    }

    #[test]
    fn test_vector_div_by_zero() {
        let func = VectorDivFunction::default();

        let args = div_args(
            vec![Some("[1.0,-2.0]".to_string())],
            vec![Some("[0.0,0.0]".to_string())],
        );
        let result = func
            .invoke_with_args(args)
            .and_then(|x| x.to_array(1))
            .unwrap();

        let result = result.as_binary_view();
        assert_eq!(result.len(), 1);
        // IEEE-754 semantics: division by zero produces signed infinities.
        assert_eq!(
            result.value(0),
            veclit_to_binlit(&[f32::INFINITY, f32::NEG_INFINITY]).as_slice()
        );
    }
}