common_function/scalars/vector/
scalar_mul.rs1use std::borrow::Cow;
16use std::fmt::Display;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use common_query::prelude::Signature;
20use datatypes::prelude::ConcreteDataType;
21use datatypes::scalars::ScalarVectorBuilder;
22use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
23use nalgebra::DVectorView;
24use snafu::ensure;
25
26use crate::function::{Function, FunctionContext};
27use crate::helper;
28use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
29
30const NAME: &str = "vec_scalar_mul";
31
/// SQL scalar function `vec_scalar_mul`: multiplies every element of a vector
/// literal (given as string or binary) by a numeric scalar, row by row.
#[derive(Clone, Debug, Default)]
pub struct ScalarMulFunction;
56
57impl Function for ScalarMulFunction {
58 fn name(&self) -> &str {
59 NAME
60 }
61
62 fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
63 Ok(ConcreteDataType::binary_datatype())
64 }
65
66 fn signature(&self) -> Signature {
67 helper::one_of_sigs2(
68 vec![ConcreteDataType::float64_datatype()],
69 vec![
70 ConcreteDataType::string_datatype(),
71 ConcreteDataType::binary_datatype(),
72 ],
73 )
74 }
75
76 fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
77 ensure!(
78 columns.len() == 2,
79 InvalidFuncArgsSnafu {
80 err_msg: format!(
81 "The length of the args is not correct, expect exactly two, have: {}",
82 columns.len()
83 ),
84 }
85 );
86 let arg0 = &columns[0];
87 let arg1 = &columns[1];
88
89 let len = arg0.len();
90 let mut result = BinaryVectorBuilder::with_capacity(len);
91 if len == 0 {
92 return Ok(result.to_vector());
93 }
94
95 let arg1_const = as_veclit_if_const(arg1)?;
96
97 for i in 0..len {
98 let arg0 = arg0.get(i).as_f64_lossy();
99 let Some(arg0) = arg0 else {
100 result.push_null();
101 continue;
102 };
103
104 let arg1 = match arg1_const.as_ref() {
105 Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())),
106 None => as_veclit(arg1.get_ref(i))?,
107 };
108 let Some(arg1) = arg1 else {
109 result.push_null();
110 continue;
111 };
112
113 let vec = DVectorView::from_slice(&arg1, arg1.len());
114 let vec_res = vec.scale(arg0 as _);
115
116 let veclit = vec_res.as_slice();
117 let binlit = veclit_to_binlit(veclit);
118 result.push(Some(&binlit));
119 }
120
121 Ok(result.to_vector())
122 }
123}
124
125impl Display for ScalarMulFunction {
126 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127 write!(f, "{}", NAME.to_ascii_uppercase())
128 }
129}
130
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::vectors::{Float32Vector, StringVector};

    use super::*;

    /// Row-wise scalar multiplication: a null scalar or a null vector yields a null row.
    #[test]
    fn test_scalar_mul() {
        let scalars = Arc::new(Float32Vector::from(vec![
            Some(2.0),
            Some(-0.5),
            None,
            Some(3.0),
        ]));
        let vectors = Arc::new(StringVector::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[8.0,10.0,12.0]".to_string()),
            Some("[7.0,8.0,9.0]".to_string()),
            None,
        ]));

        let output = ScalarMulFunction
            .eval(&FunctionContext::default(), &[scalars, vectors])
            .unwrap();
        let output = output.as_ref();
        assert_eq!(output.len(), 4);

        // `None` marks rows that must be null: row 2 has a null scalar, row 3 a null vector.
        let expected = [
            Some(veclit_to_binlit(&[2.0, 4.0, 6.0])),
            Some(veclit_to_binlit(&[-4.0, -5.0, -6.0])),
            None,
            None,
        ];
        for (row, want) in expected.iter().enumerate() {
            let got = output.get_ref(row);
            match want {
                Some(binlit) => {
                    assert_eq!(got.as_binary().unwrap(), Some(binlit.as_slice()));
                }
                None => assert!(got.is_null()),
            }
        }
    }
}