common_function/aggrs/vector/
avg.rs1use std::borrow::Cow;
16use std::sync::Arc;
17
18use arrow::array::{Array, ArrayRef, AsArray, BinaryArray, LargeStringArray, StringArray};
19use arrow::compute::sum;
20use arrow::datatypes::UInt64Type;
21use arrow_schema::{DataType, Field};
22use datafusion_common::{Result, ScalarValue};
23use datafusion_expr::{
24 Accumulator, AggregateUDF, Signature, SimpleAggregateUDF, TypeSignature, Volatility,
25};
26use datafusion_functions_aggregate_common::accumulator::AccumulatorArgs;
27use nalgebra::{Const, DVector, DVectorView, Dyn, OVector};
28
29use crate::scalars::vector::impl_conv::{
30 binlit_as_veclit, parse_veclit_from_strlit, veclit_to_binlit,
31};
32
/// Accumulator state for the `vec_avg` aggregate.
///
/// Holds the element-wise running sum of every non-null input vector, together with
/// the number of vectors that went into that sum. The average is computed only at
/// evaluation time.
#[derive(Debug, Default)]
pub struct VectorAvg {
    /// Element-wise sum of all vectors folded in so far.
    /// `None` until the first batch containing at least one non-null vector.
    sum: Option<OVector<f32, Dyn>>,
    /// Number of vectors that contributed to `sum`.
    count: u64,
}
39
40impl VectorAvg {
41 pub fn uadf_impl() -> AggregateUDF {
43 let signature = Signature::one_of(
44 vec![
45 TypeSignature::Exact(vec![DataType::Utf8]),
46 TypeSignature::Exact(vec![DataType::LargeUtf8]),
47 TypeSignature::Exact(vec![DataType::Binary]),
48 ],
49 Volatility::Immutable,
50 );
51 let udaf = SimpleAggregateUDF::new_with_signature(
52 "vec_avg",
53 signature,
54 DataType::Binary,
55 Arc::new(Self::accumulator),
56 vec![
57 Arc::new(Field::new("sum", DataType::Binary, true)),
58 Arc::new(Field::new("count", DataType::UInt64, true)),
59 ],
60 );
61 AggregateUDF::from(udaf)
62 }
63
64 fn accumulator(args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
65 if args.schema.fields().len() != 1 {
66 return Err(datafusion_common::DataFusionError::Internal(format!(
67 "expect creating `VEC_AVG` with only one input field, actual {}",
68 args.schema.fields().len()
69 )));
70 }
71
72 let t = args.schema.field(0).data_type();
73 if !matches!(t, DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary) {
74 return Err(datafusion_common::DataFusionError::Internal(format!(
75 "unexpected input datatype {t} when creating `VEC_AVG`"
76 )));
77 }
78
79 Ok(Box::new(VectorAvg::default()))
80 }
81
82 fn inner(&mut self, len: usize) -> &mut OVector<f32, Dyn> {
83 self.sum
84 .get_or_insert_with(|| OVector::zeros_generic(Dyn(len), Const::<1>))
85 }
86
87 fn update(&mut self, values: &[ArrayRef], is_update: bool) -> Result<()> {
88 if values.is_empty() {
89 return Ok(());
90 };
91
92 let vectors = match values[0].data_type() {
93 DataType::Utf8 => {
94 let arr: &StringArray = values[0].as_string();
95 arr.iter()
96 .filter_map(|x| x.map(|s| parse_veclit_from_strlit(s).map_err(Into::into)))
97 .map(|x| x.map(Cow::Owned))
98 .collect::<Result<Vec<_>>>()?
99 }
100 DataType::LargeUtf8 => {
101 let arr: &LargeStringArray = values[0].as_string();
102 arr.iter()
103 .filter_map(|x| x.map(|s| parse_veclit_from_strlit(s).map_err(Into::into)))
104 .map(|x: Result<Vec<f32>>| x.map(Cow::Owned))
105 .collect::<Result<Vec<_>>>()?
106 }
107 DataType::Binary => {
108 let arr: &BinaryArray = values[0].as_binary();
109 arr.iter()
110 .filter_map(|x| x.map(|b| binlit_as_veclit(b).map_err(Into::into)))
111 .collect::<Result<Vec<_>>>()?
112 }
113 _ => {
114 return Err(datafusion_common::DataFusionError::NotImplemented(format!(
115 "unsupported data type {} for `VEC_AVG`",
116 values[0].data_type()
117 )));
118 }
119 };
120
121 if vectors.is_empty() {
122 return Ok(());
123 }
124
125 let len = if is_update {
126 vectors.len() as u64
127 } else {
128 sum(values[1].as_primitive::<UInt64Type>()).unwrap_or_default()
129 };
130
131 let dims = vectors[0].len();
132 let mut sum = DVector::zeros(dims);
133 for v in vectors {
134 if v.len() != dims {
135 return Err(datafusion_common::DataFusionError::Execution(
136 "vectors length not match: VEC_AVG".to_string(),
137 ));
138 }
139 let v_view = DVectorView::from_slice(&v, dims);
140 sum += &v_view;
141 }
142
143 *self.inner(dims) += sum;
144 self.count += len;
145
146 Ok(())
147 }
148}
149
150impl Accumulator for VectorAvg {
151 fn state(&mut self) -> Result<Vec<ScalarValue>> {
152 let vector = match &self.sum {
153 None => ScalarValue::Binary(None),
154 Some(sum) => ScalarValue::Binary(Some(veclit_to_binlit(sum.as_slice()))),
155 };
156 Ok(vec![vector, ScalarValue::from(self.count)])
157 }
158
159 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
160 self.update(values, true)
161 }
162
163 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
164 self.update(states, false)
165 }
166
167 fn evaluate(&mut self) -> Result<ScalarValue> {
168 match &self.sum {
169 None => Ok(ScalarValue::Binary(None)),
170 Some(sum) => Ok(ScalarValue::Binary(Some(veclit_to_binlit(
171 (sum / self.count as f32).as_slice(),
172 )))),
173 }
174 }
175
176 fn size(&self) -> usize {
177 size_of_val(self)
178 }
179}
180
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::StringArray;
    use datatypes::scalars::ScalarVector;
    use datatypes::vectors::{ConstantVector, StringVector, Vector};

    use super::*;

    /// Feeds one batch into a brand-new accumulator and returns what it evaluates to.
    fn run_once(batch: Vec<ArrayRef>) -> ScalarValue {
        let mut acc = VectorAvg::default();
        acc.update_batch(&batch).unwrap();
        acc.evaluate().unwrap()
    }

    /// Wraps nullable vector literals into a single-column Utf8 batch.
    fn utf8_column(rows: &[Option<&str>]) -> Vec<ArrayRef> {
        vec![Arc::new(StringArray::from(rows.to_vec()))]
    }

    /// The non-null result `evaluate` should produce for the given mean vector.
    fn expected(mean: &[f32]) -> ScalarValue {
        ScalarValue::Binary(Some(veclit_to_binlit(mean)))
    }

    #[test]
    fn test_update_batch() {
        // An empty batch list leaves the accumulator untouched and yields NULL.
        let mut acc = VectorAvg::default();
        acc.update_batch(&[]).unwrap();
        assert!(acc.sum.is_none());
        assert_eq!(ScalarValue::Binary(None), acc.evaluate().unwrap());

        // Plain average of two vectors.
        assert_eq!(
            expected(&[2.5, 3.5, 4.5]),
            run_once(utf8_column(&[Some("[1.0,2.0,3.0]"), Some("[4.0,5.0,6.0]")])),
        );

        // A batch consisting only of nulls yields NULL.
        assert_eq!(ScalarValue::Binary(None), run_once(utf8_column(&[None])));

        // Plain average of three vectors.
        assert_eq!(
            expected(&[4.0, 5.0, 6.0]),
            run_once(utf8_column(&[
                Some("[1.0,2.0,3.0]"),
                Some("[4.0,5.0,6.0]"),
                Some("[7.0,8.0,9.0]"),
            ])),
        );

        // A null in the middle is skipped and not counted.
        assert_eq!(
            expected(&[4.0, 5.0, 6.0]),
            run_once(utf8_column(&[Some("[1.0,2.0,3.0]"), None, Some("[7.0,8.0,9.0]")])),
        );

        // A leading null is skipped and not counted.
        assert_eq!(
            expected(&[5.5, 6.5, 7.5]),
            run_once(utf8_column(&[None, Some("[4.0,5.0,6.0]"), Some("[7.0,8.0,9.0]")])),
        );

        // A constant column repeats one vector, so the mean is that vector.
        let constant = Arc::new(ConstantVector::new(
            Arc::new(StringVector::from_vec(vec!["[1.0,2.0,3.0]".to_string()])),
            4,
        ))
        .to_arrow_array();
        assert_eq!(expected(&[1.0, 2.0, 3.0]), run_once(vec![constant]));
    }
}