common_function/aggrs/vector/
product.rs1use std::sync::Arc;
16
17use common_macro::{as_aggr_func_creator, AggrFuncTypeStore};
18use common_query::error::{CreateAccumulatorSnafu, Error, InvalidFuncArgsSnafu};
19use common_query::logical_plan::{
20 create_aggregate_function, Accumulator, AggregateFunctionCreator,
21};
22use common_query::prelude::AccumulatorCreatorFunction;
23use datafusion_expr::AggregateUDF;
24use datatypes::prelude::{ConcreteDataType, Value, *};
25use datatypes::vectors::VectorRef;
26use nalgebra::{Const, DVectorView, Dyn, OVector};
27use snafu::ensure;
28
29use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
30
31#[derive(Debug, Default)]
33pub struct VectorProduct {
34 product: Option<OVector<f32, Dyn>>,
35 has_null: bool,
36}
37
38#[as_aggr_func_creator]
39#[derive(Debug, Default, AggrFuncTypeStore)]
40pub struct VectorProductCreator {}
41
42impl AggregateFunctionCreator for VectorProductCreator {
43 fn creator(&self) -> AccumulatorCreatorFunction {
44 let creator: AccumulatorCreatorFunction = Arc::new(move |types: &[ConcreteDataType]| {
45 ensure!(
46 types.len() == 1,
47 InvalidFuncArgsSnafu {
48 err_msg: format!(
49 "The length of the args is not correct, expect exactly one, have: {}",
50 types.len()
51 )
52 }
53 );
54 let input_type = &types[0];
55 match input_type {
56 ConcreteDataType::String(_) | ConcreteDataType::Binary(_) => {
57 Ok(Box::new(VectorProduct::default()))
58 }
59 _ => {
60 let err_msg = format!(
61 "\"VEC_PRODUCT\" aggregate function not support data type {:?}",
62 input_type.logical_type_id(),
63 );
64 CreateAccumulatorSnafu { err_msg }.fail()?
65 }
66 }
67 });
68 creator
69 }
70
71 fn output_type(&self) -> common_query::error::Result<ConcreteDataType> {
72 Ok(ConcreteDataType::binary_datatype())
73 }
74
75 fn state_types(&self) -> common_query::error::Result<Vec<ConcreteDataType>> {
76 Ok(vec![self.output_type()?])
77 }
78}
79
80impl VectorProduct {
81 pub fn uadf_impl() -> AggregateUDF {
83 create_aggregate_function(
84 "vec_product".to_string(),
85 1,
86 Arc::new(VectorProductCreator::default()),
87 )
88 .into()
89 }
90
91 fn inner(&mut self, len: usize) -> &mut OVector<f32, Dyn> {
92 self.product.get_or_insert_with(|| {
93 OVector::from_iterator_generic(Dyn(len), Const::<1>, (0..len).map(|_| 1.0))
94 })
95 }
96
97 fn update(&mut self, values: &[VectorRef], is_update: bool) -> Result<(), Error> {
98 if values.is_empty() || self.has_null {
99 return Ok(());
100 };
101 let column = &values[0];
102 let len = column.len();
103
104 match as_veclit_if_const(column)? {
105 Some(column) => {
106 let vec_column = DVectorView::from_slice(&column, column.len()).scale(len as f32);
107 *self.inner(vec_column.len()) =
108 (*self.inner(vec_column.len())).component_mul(&vec_column);
109 }
110 None => {
111 for i in 0..len {
112 let Some(arg0) = as_veclit(column.get_ref(i))? else {
113 if is_update {
114 self.has_null = true;
115 self.product = None;
116 }
117 return Ok(());
118 };
119 let vec_column = DVectorView::from_slice(&arg0, arg0.len());
120 *self.inner(vec_column.len()) =
121 (*self.inner(vec_column.len())).component_mul(&vec_column);
122 }
123 }
124 }
125 Ok(())
126 }
127}
128
129impl Accumulator for VectorProduct {
130 fn state(&self) -> common_query::error::Result<Vec<Value>> {
131 self.evaluate().map(|v| vec![v])
132 }
133
134 fn update_batch(&mut self, values: &[VectorRef]) -> common_query::error::Result<()> {
135 self.update(values, true)
136 }
137
138 fn merge_batch(&mut self, states: &[VectorRef]) -> common_query::error::Result<()> {
139 self.update(states, false)
140 }
141
142 fn evaluate(&self) -> common_query::error::Result<Value> {
143 match &self.product {
144 None => Ok(Value::Null),
145 Some(vector) => {
146 let v = vector.as_slice();
147 Ok(Value::from(veclit_to_binlit(v)))
148 }
149 }
150 }
151}
152
153#[cfg(test)]
154mod tests {
155 use std::sync::Arc;
156
157 use datatypes::vectors::{ConstantVector, StringVector};
158
159 use super::*;
160
161 #[test]
162 fn test_update_batch() {
163 let mut vec_product = VectorProduct::default();
165 vec_product.update_batch(&[]).unwrap();
166 assert!(vec_product.product.is_none());
167 assert!(!vec_product.has_null);
168 assert_eq!(Value::Null, vec_product.evaluate().unwrap());
169
170 let mut vec_product = VectorProduct::default();
172 let v: Vec<VectorRef> = vec![Arc::new(StringVector::from(vec![Some(
173 "[1.0,2.0,3.0]".to_string(),
174 )]))];
175 vec_product.update_batch(&v).unwrap();
176 assert_eq!(
177 Value::from(veclit_to_binlit(&[1.0, 2.0, 3.0])),
178 vec_product.evaluate().unwrap()
179 );
180
181 let mut vec_product = VectorProduct::default();
183 let v: Vec<VectorRef> = vec![Arc::new(StringVector::from(vec![Option::<String>::None]))];
184 vec_product.update_batch(&v).unwrap();
185 assert_eq!(Value::Null, vec_product.evaluate().unwrap());
186
187 let mut vec_product = VectorProduct::default();
189 let v: Vec<VectorRef> = vec![Arc::new(StringVector::from(vec![
190 Some("[1.0,2.0,3.0]".to_string()),
191 Some("[4.0,5.0,6.0]".to_string()),
192 Some("[7.0,8.0,9.0]".to_string()),
193 ]))];
194 vec_product.update_batch(&v).unwrap();
195 assert_eq!(
196 Value::from(veclit_to_binlit(&[28.0, 80.0, 162.0])),
197 vec_product.evaluate().unwrap()
198 );
199
200 let mut vec_product = VectorProduct::default();
202 let v: Vec<VectorRef> = vec![Arc::new(StringVector::from(vec![
203 Some("[1.0,2.0,3.0]".to_string()),
204 None,
205 Some("[7.0,8.0,9.0]".to_string()),
206 ]))];
207 vec_product.update_batch(&v).unwrap();
208 assert_eq!(Value::Null, vec_product.evaluate().unwrap());
209
210 let mut vec_product = VectorProduct::default();
212 let v: Vec<VectorRef> = vec![Arc::new(ConstantVector::new(
213 Arc::new(StringVector::from_vec(vec!["[1.0,2.0,3.0]".to_string()])),
214 4,
215 ))];
216
217 vec_product.update_batch(&v).unwrap();
218
219 assert_eq!(
220 Value::from(veclit_to_binlit(&[4.0, 8.0, 12.0])),
221 vec_product.evaluate().unwrap()
222 );
223 }
224}