common_function/scalars/vector/
vector_add.rs1use std::borrow::Cow;
16
17use datafusion::arrow::datatypes::DataType;
18use datafusion::logical_expr::ColumnarValue;
19use datafusion_common::{DataFusionError, ScalarValue};
20use datafusion_expr::{ScalarFunctionArgs, Signature};
21use nalgebra::DVectorView;
22
23use crate::function::Function;
24use crate::scalars::vector::impl_conv::veclit_to_binlit;
25use crate::scalars::vector::{VectorCalculator, define_args_of_two_vector_literals_udf};
26
// SQL-visible name of this scalar function; returned by `Function::name`.
const NAME: &str = "vec_add";

// Declares the `VectorAddFunction` UDF type through the shared
// two-vector-literal macro from `crate::scalars::vector`.
// NOTE(review): presumably the macro defines a struct carrying the `signature`
// field read by `Function::signature` and a `Default` impl (used by the tests);
// confirm against the macro definition.
define_args_of_two_vector_literals_udf!(VectorAddFunction);
45
46impl Function for VectorAddFunction {
47 fn name(&self) -> &str {
48 NAME
49 }
50
51 fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
52 Ok(DataType::BinaryView)
53 }
54
55 fn signature(&self) -> &Signature {
56 &self.signature
57 }
58
59 fn invoke_with_args(
60 &self,
61 args: ScalarFunctionArgs,
62 ) -> datafusion_common::Result<ColumnarValue> {
63 let body = |v0: &Option<Cow<[f32]>>,
64 v1: &Option<Cow<[f32]>>|
65 -> datafusion_common::Result<ScalarValue> {
66 let result = if let (Some(v0), Some(v1)) = (v0, v1) {
67 let v0 = DVectorView::from_slice(v0, v0.len());
68 let v1 = DVectorView::from_slice(v1, v1.len());
69 if v0.len() != v1.len() {
70 return Err(DataFusionError::Execution(format!(
71 "vectors length not match: {}",
72 self.name()
73 )));
74 }
75
76 let result = veclit_to_binlit((v0 + v1).as_slice());
77 Some(result)
78 } else {
79 None
80 };
81 Ok(ScalarValue::BinaryView(result))
82 };
83
84 let calculator = VectorCalculator {
85 name: self.name(),
86 func: body,
87 };
88 calculator.invoke_with_vectors(args)
89 }
90}
91
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion::arrow::array::{Array, AsArray, StringViewArray};
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Builds `ScalarFunctionArgs` for a two-column `vec_add` call over
    /// string-encoded vector literals.
    fn make_args(
        input0: Arc<StringViewArray>,
        input1: Arc<StringViewArray>,
        number_rows: usize,
    ) -> ScalarFunctionArgs {
        ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
            arg_fields: vec![],
            number_rows,
            // The output can contain nulls (a null on either side yields null),
            // so the return field is declared nullable.
            return_field: Arc::new(Field::new("x", DataType::BinaryView, true)),
            config_options: Arc::new(ConfigOptions::new()),
        }
    }

    #[test]
    fn test_add() {
        let func = VectorAddFunction::default();

        let input0 = Arc::new(StringViewArray::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            None,
            Some("[2.0,3.0,3.0]".to_string()),
        ]));
        let input1 = Arc::new(StringViewArray::from(vec![
            Some("[1.0,1.0,1.0]".to_string()),
            Some("[6.0,5.0,4.0]".to_string()),
            Some("[3.0,2.0,2.0]".to_string()),
            None,
        ]));

        let result = func
            .invoke_with_args(make_args(input0, input1, 4))
            .and_then(|x| x.to_array(4))
            .unwrap();

        let result = result.as_binary_view();
        assert_eq!(result.len(), 4);
        assert_eq!(
            result.value(0),
            veclit_to_binlit(&[2.0, 3.0, 4.0]).as_slice()
        );
        assert_eq!(
            result.value(1),
            veclit_to_binlit(&[10.0, 10.0, 10.0]).as_slice()
        );
        assert!(result.is_null(2));
        assert!(result.is_null(3));
    }

    #[test]
    fn test_add_error() {
        let func = VectorAddFunction::default();

        let input0 = Arc::new(StringViewArray::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            None,
            Some("[2.0,3.0,3.0]".to_string()),
        ]));
        let input1 = Arc::new(StringViewArray::from(vec![
            Some("[1.0,1.0,1.0]".to_string()),
            Some("[6.0,5.0,4.0]".to_string()),
            Some("[3.0,2.0,2.0]".to_string()),
        ]));

        let e = func
            .invoke_with_args(make_args(input0, input1, 4))
            .unwrap_err();
        assert!(e.to_string().starts_with(
            "Internal error: Arguments has mixed length. Expected length: 4, found length: 3."
        ));
    }

    /// Two present vectors of different lengths in the same row must be
    /// rejected by the explicit length check rather than reaching nalgebra.
    #[test]
    fn test_add_vector_length_mismatch() {
        let func = VectorAddFunction::default();

        let input0 = Arc::new(StringViewArray::from(vec![Some("[1.0,2.0]".to_string())]));
        let input1 = Arc::new(StringViewArray::from(vec![Some(
            "[1.0,2.0,3.0]".to_string(),
        )]));

        let e = func
            .invoke_with_args(make_args(input0, input1, 1))
            .unwrap_err();
        assert!(e.to_string().contains("vectors length not match: vec_add"));
    }
}