common_function/aggrs/vector/
sum.rs1use std::sync::Arc;
16
17use arrow::array::{Array, ArrayRef, AsArray, BinaryArray, StringArray};
18use arrow_schema::{DataType, Field};
19use datafusion_common::{Result, ScalarValue};
20use datafusion_expr::{
21 Accumulator, AggregateUDF, Signature, SimpleAggregateUDF, TypeSignature, Volatility,
22};
23use datafusion_functions_aggregate_common::accumulator::AccumulatorArgs;
24use nalgebra::{Const, DVectorView, Dyn, OVector};
25
26use crate::scalars::vector::impl_conv::{
27 binlit_as_veclit, parse_veclit_from_strlit, veclit_to_binlit,
28};
29
30#[derive(Debug, Default)]
32pub struct VectorSum {
33 sum: Option<OVector<f32, Dyn>>,
34 has_null: bool,
35}
36
37impl VectorSum {
38 pub fn uadf_impl() -> AggregateUDF {
40 let signature = Signature::one_of(
41 vec![
42 TypeSignature::Exact(vec![DataType::Utf8]),
43 TypeSignature::Exact(vec![DataType::Binary]),
44 ],
45 Volatility::Immutable,
46 );
47 let udaf = SimpleAggregateUDF::new_with_signature(
48 "vec_sum",
49 signature,
50 DataType::Binary,
51 Arc::new(Self::accumulator),
52 vec![Arc::new(Field::new("x", DataType::Binary, true))],
53 );
54 AggregateUDF::from(udaf)
55 }
56
57 fn accumulator(args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
58 if args.schema.fields().len() != 1 {
59 return Err(datafusion_common::DataFusionError::Internal(format!(
60 "expect creating `VEC_SUM` with only one input field, actual {}",
61 args.schema.fields().len()
62 )));
63 }
64
65 let t = args.schema.field(0).data_type();
66 if !matches!(t, DataType::Utf8 | DataType::Binary) {
67 return Err(datafusion_common::DataFusionError::Internal(format!(
68 "unexpected input datatype {t} when creating `VEC_SUM`"
69 )));
70 }
71
72 Ok(Box::new(VectorSum::default()))
73 }
74
75 fn inner(&mut self, len: usize) -> &mut OVector<f32, Dyn> {
76 self.sum
77 .get_or_insert_with(|| OVector::zeros_generic(Dyn(len), Const::<1>))
78 }
79
80 fn update(&mut self, values: &[ArrayRef], is_update: bool) -> Result<()> {
81 if values.is_empty() || self.has_null {
82 return Ok(());
83 };
84
85 match values[0].data_type() {
86 DataType::Utf8 => {
87 let arr: &StringArray = values[0].as_string();
88 for s in arr.iter() {
89 let Some(s) = s else {
90 if is_update {
91 self.has_null = true;
92 self.sum = None;
93 }
94 return Ok(());
95 };
96 let values = parse_veclit_from_strlit(s)?;
97 let vec_column = DVectorView::from_slice(&values, values.len());
98 *self.inner(vec_column.len()) += vec_column;
99 }
100 }
101 DataType::Binary => {
102 let arr: &BinaryArray = values[0].as_binary();
103 for b in arr.iter() {
104 let Some(b) = b else {
105 if is_update {
106 self.has_null = true;
107 self.sum = None;
108 }
109 return Ok(());
110 };
111 let values = binlit_as_veclit(b)?;
112 let vec_column = DVectorView::from_slice(&values, values.len());
113 *self.inner(vec_column.len()) += vec_column;
114 }
115 }
116 _ => {
117 return Err(datafusion_common::DataFusionError::NotImplemented(format!(
118 "unsupported data type {} for `VEC_SUM`",
119 values[0].data_type()
120 )))
121 }
122 }
123 Ok(())
124 }
125}
126
127impl Accumulator for VectorSum {
128 fn state(&mut self) -> Result<Vec<ScalarValue>> {
129 self.evaluate().map(|v| vec![v])
130 }
131
132 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
133 self.update(values, true)
134 }
135
136 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
137 self.update(states, false)
138 }
139
140 fn evaluate(&mut self) -> Result<ScalarValue> {
141 match &self.sum {
142 None => Ok(ScalarValue::Binary(None)),
143 Some(vector) => Ok(ScalarValue::Binary(Some(veclit_to_binlit(
144 vector.as_slice(),
145 )))),
146 }
147 }
148
149 fn size(&self) -> usize {
150 size_of_val(self)
151 }
152}
153
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::StringArray;
    use datatypes::scalars::ScalarVector;
    use datatypes::vectors::{ConstantVector, StringVector, Vector};

    use super::*;

    /// Runs `update_batch` on a fresh accumulator with one string column built
    /// from `rows`, then returns the evaluated result.
    fn sum_strings(rows: Vec<Option<String>>) -> ScalarValue {
        let mut acc = VectorSum::default();
        let input: Vec<ArrayRef> = vec![Arc::new(StringArray::from(rows))];
        acc.update_batch(&input).unwrap();
        acc.evaluate().unwrap()
    }

    /// The non-null result expected for the element-wise sum `v`.
    fn binary_of(v: &[f32]) -> ScalarValue {
        ScalarValue::Binary(Some(veclit_to_binlit(v)))
    }

    #[test]
    fn test_update_batch() {
        // No input columns: the accumulator stays untouched and yields NULL.
        let mut acc = VectorSum::default();
        acc.update_batch(&[]).unwrap();
        assert!(acc.sum.is_none());
        assert!(!acc.has_null);
        assert_eq!(ScalarValue::Binary(None), acc.evaluate().unwrap());

        // A single vector sums to itself.
        assert_eq!(
            binary_of(&[1.0, 2.0, 3.0]),
            sum_strings(vec![Some("[1.0,2.0,3.0]".to_string())])
        );

        // A lone NULL row yields NULL.
        assert_eq!(ScalarValue::Binary(None), sum_strings(vec![None]));

        // Several vectors are summed element-wise.
        assert_eq!(
            binary_of(&[12.0, 15.0, 18.0]),
            sum_strings(vec![
                Some("[1.0,2.0,3.0]".to_string()),
                Some("[4.0,5.0,6.0]".to_string()),
                Some("[7.0,8.0,9.0]".to_string()),
            ])
        );

        // Any NULL row makes the whole aggregate NULL.
        assert_eq!(
            ScalarValue::Binary(None),
            sum_strings(vec![
                Some("[1.0,2.0,3.0]".to_string()),
                None,
                Some("[7.0,8.0,9.0]".to_string()),
            ])
        );

        // A constant column of length 4 sums to four times the vector.
        let repeated = ConstantVector::new(
            Arc::new(StringVector::from_vec(vec!["[1.0,2.0,3.0]".to_string()])),
            4,
        );
        let input: Vec<ArrayRef> = vec![repeated.to_arrow_array()];
        let mut acc = VectorSum::default();
        acc.update_batch(&input).unwrap();
        assert_eq!(binary_of(&[4.0, 8.0, 12.0]), acc.evaluate().unwrap());
    }
}