common_function/scalars/vector/
sum.rs1use std::sync::Arc;
16
17use common_macro::{as_aggr_func_creator, AggrFuncTypeStore};
18use common_query::error::{CreateAccumulatorSnafu, Error, InvalidFuncArgsSnafu};
19use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
20use common_query::prelude::AccumulatorCreatorFunction;
21use datatypes::prelude::{ConcreteDataType, Value, *};
22use datatypes::vectors::VectorRef;
23use nalgebra::{Const, DVectorView, Dyn, OVector};
24use snafu::ensure;
25
26use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
27
28#[derive(Debug, Default)]
29pub struct VectorSum {
30 sum: Option<OVector<f32, Dyn>>,
31 has_null: bool,
32}
33
34#[as_aggr_func_creator]
35#[derive(Debug, Default, AggrFuncTypeStore)]
36pub struct VectorSumCreator {}
37
38impl AggregateFunctionCreator for VectorSumCreator {
39 fn creator(&self) -> AccumulatorCreatorFunction {
40 let creator: AccumulatorCreatorFunction = Arc::new(move |types: &[ConcreteDataType]| {
41 ensure!(
42 types.len() == 1,
43 InvalidFuncArgsSnafu {
44 err_msg: format!(
45 "The length of the args is not correct, expect exactly one, have: {}",
46 types.len()
47 )
48 }
49 );
50 let input_type = &types[0];
51 match input_type {
52 ConcreteDataType::String(_) | ConcreteDataType::Binary(_) => {
53 Ok(Box::new(VectorSum::default()))
54 }
55 _ => {
56 let err_msg = format!(
57 "\"VEC_SUM\" aggregate function not support data type {:?}",
58 input_type.logical_type_id(),
59 );
60 CreateAccumulatorSnafu { err_msg }.fail()?
61 }
62 }
63 });
64 creator
65 }
66
67 fn output_type(&self) -> common_query::error::Result<ConcreteDataType> {
68 Ok(ConcreteDataType::binary_datatype())
69 }
70
71 fn state_types(&self) -> common_query::error::Result<Vec<ConcreteDataType>> {
72 Ok(vec![self.output_type()?])
73 }
74}
75
76impl VectorSum {
77 fn inner(&mut self, len: usize) -> &mut OVector<f32, Dyn> {
78 self.sum
79 .get_or_insert_with(|| OVector::zeros_generic(Dyn(len), Const::<1>))
80 }
81
82 fn update(&mut self, values: &[VectorRef], is_update: bool) -> Result<(), Error> {
83 if values.is_empty() || self.has_null {
84 return Ok(());
85 };
86 let column = &values[0];
87 let len = column.len();
88
89 match as_veclit_if_const(column)? {
90 Some(column) => {
91 let vec_column = DVectorView::from_slice(&column, column.len()).scale(len as f32);
92 *self.inner(vec_column.len()) += vec_column;
93 }
94 None => {
95 for i in 0..len {
96 let Some(arg0) = as_veclit(column.get_ref(i))? else {
97 if is_update {
98 self.has_null = true;
99 self.sum = None;
100 }
101 return Ok(());
102 };
103 let vec_column = DVectorView::from_slice(&arg0, arg0.len());
104 *self.inner(vec_column.len()) += vec_column;
105 }
106 }
107 }
108 Ok(())
109 }
110}
111
112impl Accumulator for VectorSum {
113 fn state(&self) -> common_query::error::Result<Vec<Value>> {
114 self.evaluate().map(|v| vec![v])
115 }
116
117 fn update_batch(&mut self, values: &[VectorRef]) -> common_query::error::Result<()> {
118 self.update(values, true)
119 }
120
121 fn merge_batch(&mut self, states: &[VectorRef]) -> common_query::error::Result<()> {
122 self.update(states, false)
123 }
124
125 fn evaluate(&self) -> common_query::error::Result<Value> {
126 match &self.sum {
127 None => Ok(Value::Null),
128 Some(vector) => Ok(Value::from(veclit_to_binlit(vector.as_slice()))),
129 }
130 }
131}
132
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::vectors::{ConstantVector, StringVector};

    use super::*;

    /// Builds a single string column from optional vector literals.
    fn string_column(rows: &[Option<&str>]) -> Vec<VectorRef> {
        let rows: Vec<Option<String>> = rows
            .iter()
            .map(|row| row.map(str::to_string))
            .collect();
        let column: VectorRef = Arc::new(StringVector::from(rows));
        vec![column]
    }

    /// Feeds `columns` to a fresh accumulator via `update_batch` and returns its result.
    fn sum_of(columns: &[VectorRef]) -> Value {
        let mut acc = VectorSum::default();
        acc.update_batch(columns).unwrap();
        acc.evaluate().unwrap()
    }

    #[test]
    fn test_update_batch() {
        // No input columns: the state stays untouched and the result is null.
        let mut acc = VectorSum::default();
        acc.update_batch(&[]).unwrap();
        assert!(acc.sum.is_none());
        assert!(!acc.has_null);
        assert_eq!(Value::Null, acc.evaluate().unwrap());

        // A single vector sums to itself.
        let single = string_column(&[Some("[1.0,2.0,3.0]")]);
        assert_eq!(
            Value::from(veclit_to_binlit(&[1.0, 2.0, 3.0])),
            sum_of(&single)
        );

        // A single null row yields null.
        let only_null = string_column(&[None]);
        assert_eq!(Value::Null, sum_of(&only_null));

        // Several vectors are summed element-wise.
        let several = string_column(&[
            Some("[1.0,2.0,3.0]"),
            Some("[4.0,5.0,6.0]"),
            Some("[7.0,8.0,9.0]"),
        ]);
        assert_eq!(
            Value::from(veclit_to_binlit(&[12.0, 15.0, 18.0])),
            sum_of(&several)
        );

        // A null anywhere in the batch makes the result null.
        let with_null = string_column(&[Some("[1.0,2.0,3.0]"), None, Some("[7.0,8.0,9.0]")]);
        assert_eq!(Value::Null, sum_of(&with_null));

        // A constant column contributes its vector once per row.
        let literal = StringVector::from_vec(vec!["[1.0,2.0,3.0]".to_string()]);
        let constant: VectorRef = Arc::new(ConstantVector::new(Arc::new(literal), 4));
        assert_eq!(
            Value::from(veclit_to_binlit(&[4.0, 8.0, 12.0])),
            sum_of(&[constant])
        );
    }
}