common_function/scalars/vector/
vector_mul.rs1use std::borrow::Cow;
16use std::fmt::Display;
17
18use common_query::error::Result;
19use datafusion::arrow::datatypes::DataType;
20use datafusion::logical_expr::ColumnarValue;
21use datafusion_common::{DataFusionError, ScalarValue};
22use datafusion_expr::{ScalarFunctionArgs, Signature};
23use nalgebra::DVectorView;
24
25use crate::function::Function;
26use crate::helper;
27use crate::scalars::vector::VectorCalculator;
28use crate::scalars::vector::impl_conv::veclit_to_binlit;
29
/// Function name returned by `Function::name`; its upper-cased form is the
/// `Display` rendering.
const NAME: &str = "vec_mul";

/// Element-wise (Hadamard) product of two vectors, exposed as `vec_mul`.
///
/// Each argument may be a string or binary vector literal, and the product is
/// returned as a binary vector. A NULL argument yields a NULL result.
/// Vectors of different lengths are rejected with an execution error.
///
/// # Example
///
/// ```sql
/// -- element-wise product: [1*4, 2*5, 3*6] = [4, 10, 18]
/// SELECT vec_mul('[1, 2, 3]', '[4, 5, 6]');
/// ```
#[derive(Clone, Debug, Default)]
pub struct VectorMulFunction;
48
49impl Function for VectorMulFunction {
50 fn name(&self) -> &str {
51 NAME
52 }
53
54 fn return_type(&self, _: &[DataType]) -> Result<DataType> {
55 Ok(DataType::BinaryView)
56 }
57
58 fn signature(&self) -> Signature {
59 helper::one_of_sigs2(
60 vec![DataType::Utf8, DataType::Binary],
61 vec![DataType::Utf8, DataType::Binary],
62 )
63 }
64
65 fn invoke_with_args(
66 &self,
67 args: ScalarFunctionArgs,
68 ) -> datafusion_common::Result<ColumnarValue> {
69 let body = |v0: &Option<Cow<[f32]>>,
70 v1: &Option<Cow<[f32]>>|
71 -> datafusion_common::Result<ScalarValue> {
72 let result = if let (Some(v0), Some(v1)) = (v0, v1) {
73 let v0 = DVectorView::from_slice(v0, v0.len());
74 let v1 = DVectorView::from_slice(v1, v1.len());
75 if v0.len() != v1.len() {
76 return Err(DataFusionError::Execution(format!(
77 "vectors length not match: {}",
78 self.name()
79 )));
80 }
81
82 let result = veclit_to_binlit((v0.component_mul(&v1)).as_slice());
83 Some(result)
84 } else {
85 None
86 };
87 Ok(ScalarValue::BinaryView(result))
88 };
89
90 let calculator = VectorCalculator {
91 name: self.name(),
92 func: body,
93 };
94 calculator.invoke_with_vectors(args)
95 }
96}
97
98impl Display for VectorMulFunction {
99 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100 write!(f, "{}", NAME.to_ascii_uppercase())
101 }
102}
103
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion::arrow::array::{Array, ArrayRef, AsArray, StringViewArray};
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Builds `vec_mul` call arguments from two array inputs.
    ///
    /// `number_rows` is taken from `lhs`, so it always matches the batch
    /// actually passed in. The return field is nullable because NULL inputs
    /// produce NULL products.
    fn build_args(lhs: ArrayRef, rhs: ArrayRef) -> ScalarFunctionArgs {
        let number_rows = lhs.len();
        ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(lhs), ColumnarValue::Array(rhs)],
            arg_fields: vec![],
            number_rows,
            return_field: Arc::new(Field::new("x", DataType::BinaryView, true)),
            config_options: Arc::new(ConfigOptions::new()),
        }
    }

    #[test]
    fn test_vector_mul() {
        let func = VectorMulFunction;

        // Vectors of different lengths (3 vs 2) must be rejected.
        let vec0 = vec![1.0, 2.0, 3.0];
        let vec1 = vec![1.0, 1.0];
        let input0 = Arc::new(StringViewArray::from(vec![Some(format!("{vec0:?}"))]));
        let input1 = Arc::new(StringViewArray::from(vec![Some(format!("{vec1:?}"))]));

        let e = func
            .invoke_with_args(build_args(input0, input1))
            .unwrap_err();
        assert!(
            e.to_string()
                .starts_with("Execution error: vectors length not match: vec_mul")
        );

        // Row-wise products; a NULL on either side yields a NULL row.
        let input0 = Arc::new(StringViewArray::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[8.0,10.0,12.0]".to_string()),
            Some("[7.0,8.0,9.0]".to_string()),
            None,
        ]));

        let input1 = Arc::new(StringViewArray::from(vec![
            Some("[1.0,1.0,1.0]".to_string()),
            Some("[2.0,2.0,2.0]".to_string()),
            None,
            Some("[3.0,3.0,3.0]".to_string()),
        ]));

        let result = func
            .invoke_with_args(build_args(input0, input1))
            .and_then(|x| x.to_array(4))
            .unwrap();

        let result = result.as_binary_view();
        assert_eq!(result.len(), 4);
        assert_eq!(
            result.value(0),
            veclit_to_binlit(&[1.0, 2.0, 3.0]).as_slice()
        );
        assert_eq!(
            result.value(1),
            veclit_to_binlit(&[16.0, 20.0, 24.0]).as_slice()
        );
        assert!(result.is_null(2));
        assert!(result.is_null(3));
    }
}