common_function/aggrs/vector/
sum.rs1use std::sync::Arc;
16
17use arrow::array::{Array, ArrayRef, AsArray, BinaryArray, LargeStringArray, StringArray};
18use arrow_schema::{DataType, Field};
19use datafusion_common::{Result, ScalarValue};
20use datafusion_expr::{
21 Accumulator, AggregateUDF, Signature, SimpleAggregateUDF, TypeSignature, Volatility,
22};
23use datafusion_functions_aggregate_common::accumulator::AccumulatorArgs;
24use nalgebra::{Const, DVectorView, Dyn, OVector};
25
26use crate::scalars::vector::impl_conv::{
27 binlit_as_veclit, parse_veclit_from_strlit, veclit_to_binlit,
28};
29
30#[derive(Debug, Default)]
32pub struct VectorSum {
33 sum: Option<OVector<f32, Dyn>>,
34 has_null: bool,
35}
36
37impl VectorSum {
38 pub fn uadf_impl() -> AggregateUDF {
40 let signature = Signature::one_of(
41 vec![
42 TypeSignature::Exact(vec![DataType::Utf8]),
43 TypeSignature::Exact(vec![DataType::Binary]),
44 ],
45 Volatility::Immutable,
46 );
47 let udaf = SimpleAggregateUDF::new_with_signature(
48 "vec_sum",
49 signature,
50 DataType::Binary,
51 Arc::new(Self::accumulator),
52 vec![Arc::new(Field::new("x", DataType::Binary, true))],
53 );
54 AggregateUDF::from(udaf)
55 }
56
57 fn accumulator(args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
58 if args.schema.fields().len() != 1 {
59 return Err(datafusion_common::DataFusionError::Internal(format!(
60 "expect creating `VEC_SUM` with only one input field, actual {}",
61 args.schema.fields().len()
62 )));
63 }
64
65 let t = args.schema.field(0).data_type();
66 if !matches!(t, DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary) {
67 return Err(datafusion_common::DataFusionError::Internal(format!(
68 "unexpected input datatype {t} when creating `VEC_SUM`"
69 )));
70 }
71
72 Ok(Box::new(VectorSum::default()))
73 }
74
75 fn inner(&mut self, len: usize) -> &mut OVector<f32, Dyn> {
76 self.sum
77 .get_or_insert_with(|| OVector::zeros_generic(Dyn(len), Const::<1>))
78 }
79
80 fn update(&mut self, values: &[ArrayRef], is_update: bool) -> Result<()> {
81 if values.is_empty() || self.has_null {
82 return Ok(());
83 };
84
85 match values[0].data_type() {
86 DataType::Utf8 => {
87 let arr: &StringArray = values[0].as_string();
88 for s in arr.iter() {
89 let Some(s) = s else {
90 if is_update {
91 self.has_null = true;
92 self.sum = None;
93 }
94 return Ok(());
95 };
96 let values = parse_veclit_from_strlit(s)?;
97 let vec_column = DVectorView::from_slice(&values, values.len());
98 *self.inner(vec_column.len()) += vec_column;
99 }
100 }
101 DataType::LargeUtf8 => {
102 let arr: &LargeStringArray = values[0].as_string();
103 for s in arr.iter() {
104 let Some(s) = s else {
105 if is_update {
106 self.has_null = true;
107 self.sum = None;
108 }
109 return Ok(());
110 };
111 let values = parse_veclit_from_strlit(s)?;
112 let vec_column = DVectorView::from_slice(&values, values.len());
113 *self.inner(vec_column.len()) += vec_column;
114 }
115 }
116 DataType::Binary => {
117 let arr: &BinaryArray = values[0].as_binary();
118 for b in arr.iter() {
119 let Some(b) = b else {
120 if is_update {
121 self.has_null = true;
122 self.sum = None;
123 }
124 return Ok(());
125 };
126 let values = binlit_as_veclit(b)?;
127 let vec_column = DVectorView::from_slice(&values, values.len());
128 *self.inner(vec_column.len()) += vec_column;
129 }
130 }
131 _ => {
132 return Err(datafusion_common::DataFusionError::NotImplemented(format!(
133 "unsupported data type {} for `VEC_SUM`",
134 values[0].data_type()
135 )));
136 }
137 }
138 Ok(())
139 }
140}
141
142impl Accumulator for VectorSum {
143 fn state(&mut self) -> Result<Vec<ScalarValue>> {
144 self.evaluate().map(|v| vec![v])
145 }
146
147 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
148 self.update(values, true)
149 }
150
151 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
152 self.update(states, false)
153 }
154
155 fn evaluate(&mut self) -> Result<ScalarValue> {
156 match &self.sum {
157 None => Ok(ScalarValue::Binary(None)),
158 Some(vector) => Ok(ScalarValue::Binary(Some(veclit_to_binlit(
159 vector.as_slice(),
160 )))),
161 }
162 }
163
164 fn size(&self) -> usize {
165 size_of_val(self)
166 }
167}
168
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::StringArray;
    use datatypes::scalars::ScalarVector;
    use datatypes::vectors::{ConstantVector, StringVector, Vector};

    use super::*;

    /// Runs `columns` through a fresh `VectorSum` via `update_batch` and returns
    /// the evaluated aggregate.
    fn aggregate(columns: &[ArrayRef]) -> ScalarValue {
        let mut acc = VectorSum::default();
        acc.update_batch(columns).unwrap();
        acc.evaluate().unwrap()
    }

    /// Wraps nullable string rows into a single-column batch.
    fn strings(rows: Vec<Option<&str>>) -> Vec<ArrayRef> {
        let column: ArrayRef = Arc::new(StringArray::from(rows));
        vec![column]
    }

    /// Encodes an expected non-NULL aggregate result.
    fn binary(expected: &[f32]) -> ScalarValue {
        ScalarValue::Binary(Some(veclit_to_binlit(expected)))
    }

    #[test]
    fn test_update_batch() {
        // An empty batch leaves the accumulator untouched and evaluates to NULL.
        let mut untouched = VectorSum::default();
        untouched.update_batch(&[]).unwrap();
        assert!(untouched.sum.is_none());
        assert!(!untouched.has_null);
        assert_eq!(ScalarValue::Binary(None), untouched.evaluate().unwrap());

        // A single vector sums to itself.
        assert_eq!(binary(&[1.0, 2.0, 3.0]), aggregate(&strings(vec![Some("[1.0,2.0,3.0]")])));

        // A lone NULL row yields NULL.
        assert_eq!(ScalarValue::Binary(None), aggregate(&strings(vec![None])));

        // Several vectors are summed element-wise.
        assert_eq!(
            binary(&[12.0, 15.0, 18.0]),
            aggregate(&strings(vec![
                Some("[1.0,2.0,3.0]"),
                Some("[4.0,5.0,6.0]"),
                Some("[7.0,8.0,9.0]"),
            ]))
        );

        // Any NULL row makes the whole aggregate NULL.
        assert_eq!(
            ScalarValue::Binary(None),
            aggregate(&strings(vec![Some("[1.0,2.0,3.0]"), None, Some("[7.0,8.0,9.0]")]))
        );

        // A constant column contributes its value once per row (4 rows here).
        let constant = ConstantVector::new(
            Arc::new(StringVector::from_vec(vec!["[1.0,2.0,3.0]".to_string()])),
            4,
        );
        assert_eq!(binary(&[4.0, 8.0, 12.0]), aggregate(&[constant.to_arrow_array()]));
    }
}