common_function/scalars/vector/
vector_sub.rs1use std::borrow::Cow;
16
17use datafusion::arrow::datatypes::DataType;
18use datafusion::logical_expr::ColumnarValue;
19use datafusion_common::{DataFusionError, ScalarValue};
20use datafusion_expr::{ScalarFunctionArgs, Signature};
21use nalgebra::DVectorView;
22
23use crate::function::Function;
24use crate::scalars::vector::impl_conv::veclit_to_binlit;
25use crate::scalars::vector::{VectorCalculator, define_args_of_two_vector_literals_udf};
26
/// Function name returned by `VectorSubFunction::name`; it is also embedded in the
/// length-mismatch error message produced by `invoke_with_args`.
const NAME: &str = "vec_sub";
28
// Defines the `VectorSubFunction` type implemented below. The `Function` impl reads its
// `signature` field and the tests call `VectorSubFunction::default()`. The generated
// signature presumably accepts two vector-literal arguments, as the macro name suggests.
// TODO confirm against the macro definition in `crate::scalars::vector`.
define_args_of_two_vector_literals_udf!(
VectorSubFunction);
44
45impl Function for VectorSubFunction {
46 fn name(&self) -> &str {
47 NAME
48 }
49
50 fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
51 Ok(DataType::BinaryView)
52 }
53
54 fn signature(&self) -> &Signature {
55 &self.signature
56 }
57
58 fn invoke_with_args(
59 &self,
60 args: ScalarFunctionArgs,
61 ) -> datafusion_common::Result<ColumnarValue> {
62 let body = |v0: &Option<Cow<[f32]>>,
63 v1: &Option<Cow<[f32]>>|
64 -> datafusion_common::Result<ScalarValue> {
65 let result = if let (Some(v0), Some(v1)) = (v0, v1) {
66 let v0 = DVectorView::from_slice(v0, v0.len());
67 let v1 = DVectorView::from_slice(v1, v1.len());
68 if v0.len() != v1.len() {
69 return Err(DataFusionError::Execution(format!(
70 "vectors length not match: {}",
71 self.name()
72 )));
73 }
74
75 let result = veclit_to_binlit((v0 - v1).as_slice());
76 Some(result)
77 } else {
78 None
79 };
80 Ok(ScalarValue::BinaryView(result))
81 };
82
83 let calculator = VectorCalculator {
84 name: self.name(),
85 func: body,
86 };
87 calculator.invoke_with_vectors(args)
88 }
89}
90
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion::arrow::array::{Array, ArrayRef, AsArray, StringViewArray};
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Builds the `ScalarFunctionArgs` for a `vec_sub` call whose two arguments are
    /// arrays, declaring `number_rows` rows.
    fn two_array_args(
        input0: ArrayRef,
        input1: ArrayRef,
        number_rows: usize,
    ) -> ScalarFunctionArgs {
        ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
            arg_fields: vec![],
            number_rows,
            return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
            config_options: Arc::new(ConfigOptions::new()),
        }
    }

    #[test]
    fn test_sub() {
        let func = VectorSubFunction::default();

        let input0: ArrayRef = Arc::new(StringViewArray::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            None,
            Some("[2.0,3.0,3.0]".to_string()),
        ]));
        let input1: ArrayRef = Arc::new(StringViewArray::from(vec![
            Some("[1.0,1.0,1.0]".to_string()),
            Some("[6.0,5.0,4.0]".to_string()),
            Some("[3.0,2.0,2.0]".to_string()),
            None,
        ]));

        let result = func
            .invoke_with_args(two_array_args(input0, input1, 4))
            .and_then(|x| x.to_array(4))
            .unwrap();

        let result = result.as_binary_view();
        assert_eq!(result.len(), 4);
        assert_eq!(
            result.value(0),
            veclit_to_binlit(&[0.0, 1.0, 2.0]).as_slice()
        );
        assert_eq!(
            result.value(1),
            veclit_to_binlit(&[-2.0, 0.0, 2.0]).as_slice()
        );
        // A NULL on either side yields a NULL row.
        assert!(result.is_null(2));
        assert!(result.is_null(3));
    }

    #[test]
    fn test_sub_error() {
        let func = VectorSubFunction::default();

        let input0: ArrayRef = Arc::new(StringViewArray::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            None,
            Some("[2.0,3.0,3.0]".to_string()),
        ]));
        // One row shorter than `input0`: the argument columns themselves mismatch.
        let input1: ArrayRef = Arc::new(StringViewArray::from(vec![
            Some("[1.0,1.0,1.0]".to_string()),
            Some("[6.0,5.0,4.0]".to_string()),
            Some("[3.0,2.0,2.0]".to_string()),
        ]));

        let e = func
            .invoke_with_args(two_array_args(input0, input1, 4))
            .unwrap_err();
        assert!(e.to_string().starts_with(
            "Internal error: Arguments has mixed length. Expected length: 4, found length: 3."
        ));
    }

    /// Covers the per-row guard in `invoke_with_args`: two non-NULL vectors of
    /// different dimensions must be rejected rather than subtracted.
    #[test]
    fn test_sub_length_mismatch() {
        let func = VectorSubFunction::default();

        let input0: ArrayRef = Arc::new(StringViewArray::from(vec![Some(
            "[1.0,2.0,3.0]".to_string(),
        )]));
        let input1: ArrayRef =
            Arc::new(StringViewArray::from(vec![Some("[1.0,2.0]".to_string())]));

        let e = func
            .invoke_with_args(two_array_args(input0, input1, 1))
            .unwrap_err();
        assert!(
            e.to_string().contains("vectors length not match: vec_sub"),
            "unexpected error: {e}"
        );
    }
}