common_function/scalars/vector/
product.rs1use std::sync::Arc;
16
17use common_macro::{as_aggr_func_creator, AggrFuncTypeStore};
18use common_query::error::{CreateAccumulatorSnafu, Error, InvalidFuncArgsSnafu};
19use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
20use common_query::prelude::AccumulatorCreatorFunction;
21use datatypes::prelude::{ConcreteDataType, Value, *};
22use datatypes::vectors::VectorRef;
23use nalgebra::{Const, DVectorView, Dyn, OVector};
24use snafu::ensure;
25
26use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
27
28#[derive(Debug, Default)]
30pub struct VectorProduct {
31 product: Option<OVector<f32, Dyn>>,
32 has_null: bool,
33}
34
/// Creator for the `VEC_PRODUCT` aggregate function; hands out [`VectorProduct`]
/// accumulators.
// NOTE(review): `#[as_aggr_func_creator]` and the `AggrFuncTypeStore` derive come from
// `common_macro`; presumably they inject/maintain the stored input types this creator
// relies on — confirm there before changing the attribute order or the struct shape.
#[as_aggr_func_creator]
#[derive(Debug, Default, AggrFuncTypeStore)]
pub struct VectorProductCreator {}
38
39impl AggregateFunctionCreator for VectorProductCreator {
40 fn creator(&self) -> AccumulatorCreatorFunction {
41 let creator: AccumulatorCreatorFunction = Arc::new(move |types: &[ConcreteDataType]| {
42 ensure!(
43 types.len() == 1,
44 InvalidFuncArgsSnafu {
45 err_msg: format!(
46 "The length of the args is not correct, expect exactly one, have: {}",
47 types.len()
48 )
49 }
50 );
51 let input_type = &types[0];
52 match input_type {
53 ConcreteDataType::String(_) | ConcreteDataType::Binary(_) => {
54 Ok(Box::new(VectorProduct::default()))
55 }
56 _ => {
57 let err_msg = format!(
58 "\"VEC_PRODUCT\" aggregate function not support data type {:?}",
59 input_type.logical_type_id(),
60 );
61 CreateAccumulatorSnafu { err_msg }.fail()?
62 }
63 }
64 });
65 creator
66 }
67
68 fn output_type(&self) -> common_query::error::Result<ConcreteDataType> {
69 Ok(ConcreteDataType::binary_datatype())
70 }
71
72 fn state_types(&self) -> common_query::error::Result<Vec<ConcreteDataType>> {
73 Ok(vec![self.output_type()?])
74 }
75}
76
77impl VectorProduct {
78 fn inner(&mut self, len: usize) -> &mut OVector<f32, Dyn> {
79 self.product.get_or_insert_with(|| {
80 OVector::from_iterator_generic(Dyn(len), Const::<1>, (0..len).map(|_| 1.0))
81 })
82 }
83
84 fn update(&mut self, values: &[VectorRef], is_update: bool) -> Result<(), Error> {
85 if values.is_empty() || self.has_null {
86 return Ok(());
87 };
88 let column = &values[0];
89 let len = column.len();
90
91 match as_veclit_if_const(column)? {
92 Some(column) => {
93 let vec_column = DVectorView::from_slice(&column, column.len()).scale(len as f32);
94 *self.inner(vec_column.len()) =
95 (*self.inner(vec_column.len())).component_mul(&vec_column);
96 }
97 None => {
98 for i in 0..len {
99 let Some(arg0) = as_veclit(column.get_ref(i))? else {
100 if is_update {
101 self.has_null = true;
102 self.product = None;
103 }
104 return Ok(());
105 };
106 let vec_column = DVectorView::from_slice(&arg0, arg0.len());
107 *self.inner(vec_column.len()) =
108 (*self.inner(vec_column.len())).component_mul(&vec_column);
109 }
110 }
111 }
112 Ok(())
113 }
114}
115
116impl Accumulator for VectorProduct {
117 fn state(&self) -> common_query::error::Result<Vec<Value>> {
118 self.evaluate().map(|v| vec![v])
119 }
120
121 fn update_batch(&mut self, values: &[VectorRef]) -> common_query::error::Result<()> {
122 self.update(values, true)
123 }
124
125 fn merge_batch(&mut self, states: &[VectorRef]) -> common_query::error::Result<()> {
126 self.update(states, false)
127 }
128
129 fn evaluate(&self) -> common_query::error::Result<Value> {
130 match &self.product {
131 None => Ok(Value::Null),
132 Some(vector) => {
133 let v = vector.as_slice();
134 Ok(Value::from(veclit_to_binlit(v)))
135 }
136 }
137 }
138}
139
140#[cfg(test)]
141mod tests {
142 use std::sync::Arc;
143
144 use datatypes::vectors::{ConstantVector, StringVector};
145
146 use super::*;
147
148 #[test]
149 fn test_update_batch() {
150 let mut vec_product = VectorProduct::default();
152 vec_product.update_batch(&[]).unwrap();
153 assert!(vec_product.product.is_none());
154 assert!(!vec_product.has_null);
155 assert_eq!(Value::Null, vec_product.evaluate().unwrap());
156
157 let mut vec_product = VectorProduct::default();
159 let v: Vec<VectorRef> = vec![Arc::new(StringVector::from(vec![Some(
160 "[1.0,2.0,3.0]".to_string(),
161 )]))];
162 vec_product.update_batch(&v).unwrap();
163 assert_eq!(
164 Value::from(veclit_to_binlit(&[1.0, 2.0, 3.0])),
165 vec_product.evaluate().unwrap()
166 );
167
168 let mut vec_product = VectorProduct::default();
170 let v: Vec<VectorRef> = vec![Arc::new(StringVector::from(vec![Option::<String>::None]))];
171 vec_product.update_batch(&v).unwrap();
172 assert_eq!(Value::Null, vec_product.evaluate().unwrap());
173
174 let mut vec_product = VectorProduct::default();
176 let v: Vec<VectorRef> = vec![Arc::new(StringVector::from(vec![
177 Some("[1.0,2.0,3.0]".to_string()),
178 Some("[4.0,5.0,6.0]".to_string()),
179 Some("[7.0,8.0,9.0]".to_string()),
180 ]))];
181 vec_product.update_batch(&v).unwrap();
182 assert_eq!(
183 Value::from(veclit_to_binlit(&[28.0, 80.0, 162.0])),
184 vec_product.evaluate().unwrap()
185 );
186
187 let mut vec_product = VectorProduct::default();
189 let v: Vec<VectorRef> = vec![Arc::new(StringVector::from(vec![
190 Some("[1.0,2.0,3.0]".to_string()),
191 None,
192 Some("[7.0,8.0,9.0]".to_string()),
193 ]))];
194 vec_product.update_batch(&v).unwrap();
195 assert_eq!(Value::Null, vec_product.evaluate().unwrap());
196
197 let mut vec_product = VectorProduct::default();
199 let v: Vec<VectorRef> = vec![Arc::new(ConstantVector::new(
200 Arc::new(StringVector::from_vec(vec!["[1.0,2.0,3.0]".to_string()])),
201 4,
202 ))];
203
204 vec_product.update_batch(&v).unwrap();
205
206 assert_eq!(
207 Value::from(veclit_to_binlit(&[4.0, 8.0, 12.0])),
208 vec_product.evaluate().unwrap()
209 );
210 }
211}