common_function/scalars/vector/
vector_mul.rs1use std::borrow::Cow;
16
17use datafusion::arrow::datatypes::DataType;
18use datafusion::logical_expr::ColumnarValue;
19use datafusion_common::{DataFusionError, ScalarValue};
20use datafusion_expr::{ScalarFunctionArgs, Signature};
21use nalgebra::DVectorView;
22
23use crate::function::Function;
24use crate::scalars::vector::impl_conv::veclit_to_binlit;
25use crate::scalars::vector::{VectorCalculator, define_args_of_two_vector_literals_udf};
26
/// Function name returned by `VectorMulFunction::name`; it is also embedded in
/// this function's length-mismatch error message.
const NAME: &str = "vec_mul";

// Declares the `VectorMulFunction` UDF type (element-wise product of two
// vectors, see the `Function` impl below) via a macro from the parent `vector`
// module. Judging from its uses in this file, the generated type carries a
// `signature` field and implements `Default`; the macro body is not visible
// here — NOTE(review): presumably the signature accepts two vector arguments,
// confirm against the macro definition.
define_args_of_two_vector_literals_udf!(
VectorMulFunction);

46impl Function for VectorMulFunction {
47 fn name(&self) -> &str {
48 NAME
49 }
50
51 fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
52 Ok(DataType::BinaryView)
53 }
54
55 fn signature(&self) -> &Signature {
56 &self.signature
57 }
58
59 fn invoke_with_args(
60 &self,
61 args: ScalarFunctionArgs,
62 ) -> datafusion_common::Result<ColumnarValue> {
63 let body = |v0: &Option<Cow<[f32]>>,
64 v1: &Option<Cow<[f32]>>|
65 -> datafusion_common::Result<ScalarValue> {
66 let result = if let (Some(v0), Some(v1)) = (v0, v1) {
67 let v0 = DVectorView::from_slice(v0, v0.len());
68 let v1 = DVectorView::from_slice(v1, v1.len());
69 if v0.len() != v1.len() {
70 return Err(DataFusionError::Execution(format!(
71 "vectors length not match: {}",
72 self.name()
73 )));
74 }
75
76 let result = veclit_to_binlit((v0.component_mul(&v1)).as_slice());
77 Some(result)
78 } else {
79 None
80 };
81 Ok(ScalarValue::BinaryView(result))
82 };
83
84 let calculator = VectorCalculator {
85 name: self.name(),
86 func: body,
87 };
88 calculator.invoke_with_vectors(args)
89 }
90}
91
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion::arrow::array::{Array, AsArray, StringViewArray};
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Builds `ScalarFunctionArgs` for two string-encoded vector columns.
    ///
    /// `number_rows` is derived from the inputs so it always matches the
    /// arrays. The return field is nullable because `vec_mul` yields a null row
    /// whenever either operand is null.
    fn mul_args(lhs: Vec<Option<String>>, rhs: Vec<Option<String>>) -> ScalarFunctionArgs {
        assert_eq!(lhs.len(), rhs.len(), "test inputs must have equal row counts");
        let number_rows = lhs.len();
        ScalarFunctionArgs {
            args: vec![
                ColumnarValue::Array(Arc::new(StringViewArray::from(lhs))),
                ColumnarValue::Array(Arc::new(StringViewArray::from(rhs))),
            ],
            arg_fields: vec![],
            number_rows,
            return_field: Arc::new(Field::new("x", DataType::BinaryView, true)),
            config_options: Arc::new(ConfigOptions::new()),
        }
    }

    #[test]
    fn test_vector_mul_length_mismatch() {
        let func = VectorMulFunction::default();

        let lhs = [1.0, 2.0, 3.0];
        let rhs = [1.0, 1.0];
        let args = mul_args(
            vec![Some(format!("{lhs:?}"))],
            vec![Some(format!("{rhs:?}"))],
        );

        let e = func.invoke_with_args(args).unwrap_err();
        assert!(
            e.to_string()
                .starts_with("Execution error: vectors length not match: vec_mul"),
            "unexpected error: {e}"
        );
    }

    #[test]
    fn test_vector_mul() {
        let func = VectorMulFunction::default();

        let args = mul_args(
            vec![
                Some("[1.0,2.0,3.0]".to_string()),
                Some("[8.0,10.0,12.0]".to_string()),
                Some("[7.0,8.0,9.0]".to_string()),
                None,
            ],
            vec![
                Some("[1.0,1.0,1.0]".to_string()),
                Some("[2.0,2.0,2.0]".to_string()),
                None,
                Some("[3.0,3.0,3.0]".to_string()),
            ],
        );
        let result = func
            .invoke_with_args(args)
            .and_then(|x| x.to_array(4))
            .unwrap();

        let result = result.as_binary_view();
        assert_eq!(result.len(), 4);
        assert_eq!(
            result.value(0),
            veclit_to_binlit(&[1.0, 2.0, 3.0]).as_slice()
        );
        assert_eq!(
            result.value(1),
            veclit_to_binlit(&[16.0, 20.0, 24.0]).as_slice()
        );
        // A null operand on either side must produce a null row.
        assert!(result.is_null(2));
        assert!(result.is_null(3));
    }
}