common_function/scalars/vector/
scalar_mul.rs1use std::borrow::Cow;
16use std::fmt::Display;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use datafusion_expr::Signature;
20use datatypes::arrow::datatypes::DataType;
21use datatypes::scalars::ScalarVectorBuilder;
22use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
23use nalgebra::DVectorView;
24use snafu::ensure;
25
26use crate::function::{Function, FunctionContext};
27use crate::helper;
28use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
29
/// SQL-facing name of this function, returned by `Function::name` and
/// upper-cased by the `Display` impl.
const NAME: &str = "vec_scalar_mul";
31
/// Scalar function `vec_scalar_mul`: multiplies a vector literal by a scalar.
///
/// This is a stateless unit struct; all behavior lives in its trait impls.
#[derive(Clone, Debug, Default)]
pub struct ScalarMulFunction;
56
57impl Function for ScalarMulFunction {
58 fn name(&self) -> &str {
59 NAME
60 }
61
62 fn return_type(&self, _: &[DataType]) -> Result<DataType> {
63 Ok(DataType::Binary)
64 }
65
66 fn signature(&self) -> Signature {
67 helper::one_of_sigs2(
68 vec![DataType::Float64],
69 vec![DataType::Utf8, DataType::Binary],
70 )
71 }
72
73 fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
74 ensure!(
75 columns.len() == 2,
76 InvalidFuncArgsSnafu {
77 err_msg: format!(
78 "The length of the args is not correct, expect exactly two, have: {}",
79 columns.len()
80 ),
81 }
82 );
83 let arg0 = &columns[0];
84 let arg1 = &columns[1];
85
86 let len = arg0.len();
87 let mut result = BinaryVectorBuilder::with_capacity(len);
88 if len == 0 {
89 return Ok(result.to_vector());
90 }
91
92 let arg1_const = as_veclit_if_const(arg1)?;
93
94 for i in 0..len {
95 let arg0 = arg0.get(i).as_f64_lossy();
96 let Some(arg0) = arg0 else {
97 result.push_null();
98 continue;
99 };
100
101 let arg1 = match arg1_const.as_ref() {
102 Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())),
103 None => as_veclit(arg1.get_ref(i))?,
104 };
105 let Some(arg1) = arg1 else {
106 result.push_null();
107 continue;
108 };
109
110 let vec = DVectorView::from_slice(&arg1, arg1.len());
111 let vec_res = vec.scale(arg0 as _);
112
113 let veclit = vec_res.as_slice();
114 let binlit = veclit_to_binlit(veclit);
115 result.push(Some(&binlit));
116 }
117
118 Ok(result.to_vector())
119 }
120}
121
122impl Display for ScalarMulFunction {
123 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
124 write!(f, "{}", NAME.to_ascii_uppercase())
125 }
126}
127
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::vectors::{Float32Vector, StringVector};

    use super::*;

    /// Row-wise scaling, including NULL propagation from either argument.
    #[test]
    fn test_scalar_mul() {
        let scalars = Arc::new(Float32Vector::from(vec![
            Some(2.0),
            Some(-0.5),
            None,
            Some(3.0),
        ]));
        let vectors = Arc::new(StringVector::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[8.0,10.0,12.0]".to_string()),
            Some("[7.0,8.0,9.0]".to_string()),
            None,
        ]));

        let output = ScalarMulFunction
            .eval(&FunctionContext::default(), &[scalars, vectors])
            .unwrap();
        let output = output.as_ref();
        assert_eq!(output.len(), 4);

        // Rows with both inputs present are scaled element-wise.
        for (row, expected) in [(0, [2.0, 4.0, 6.0]), (1, [-4.0, -5.0, -6.0])] {
            assert_eq!(
                output.get_ref(row).as_binary().unwrap(),
                Some(veclit_to_binlit(&expected).as_slice())
            );
        }

        // Row 2 has a NULL scalar, row 3 a NULL vector.
        assert!(output.get_ref(2).is_null());
        assert!(output.get_ref(3).is_null());
    }
}