common_function/aggrs/vector/
avg.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::borrow::Cow;
16use std::sync::Arc;
17
18use arrow::array::{Array, ArrayRef, AsArray, BinaryArray, LargeStringArray, StringArray};
19use arrow::compute::sum;
20use arrow::datatypes::UInt64Type;
21use arrow_schema::{DataType, Field};
22use datafusion_common::{Result, ScalarValue};
23use datafusion_expr::{
24    Accumulator, AggregateUDF, Signature, SimpleAggregateUDF, TypeSignature, Volatility,
25};
26use datafusion_functions_aggregate_common::accumulator::AccumulatorArgs;
27use nalgebra::{Const, DVector, DVectorView, Dyn, OVector};
28
29use crate::scalars::vector::impl_conv::{
30    binlit_as_veclit, parse_veclit_from_strlit, veclit_to_binlit,
31};
32
33/// The accumulator for the `vec_avg` aggregate function.
34#[derive(Debug, Default)]
35pub struct VectorAvg {
36    sum: Option<OVector<f32, Dyn>>,
37    count: u64,
38}
39
40impl VectorAvg {
41    /// Create a new `AggregateUDF` for the `vec_avg` aggregate function.
42    pub fn uadf_impl() -> AggregateUDF {
43        let signature = Signature::one_of(
44            vec![
45                TypeSignature::Exact(vec![DataType::Utf8]),
46                TypeSignature::Exact(vec![DataType::LargeUtf8]),
47                TypeSignature::Exact(vec![DataType::Binary]),
48            ],
49            Volatility::Immutable,
50        );
51        let udaf = SimpleAggregateUDF::new_with_signature(
52            "vec_avg",
53            signature,
54            DataType::Binary,
55            Arc::new(Self::accumulator),
56            vec![
57                Arc::new(Field::new("sum", DataType::Binary, true)),
58                Arc::new(Field::new("count", DataType::UInt64, true)),
59            ],
60        );
61        AggregateUDF::from(udaf)
62    }
63
64    fn accumulator(args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
65        if args.schema.fields().len() != 1 {
66            return Err(datafusion_common::DataFusionError::Internal(format!(
67                "expect creating `VEC_AVG` with only one input field, actual {}",
68                args.schema.fields().len()
69            )));
70        }
71
72        let t = args.schema.field(0).data_type();
73        if !matches!(t, DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary) {
74            return Err(datafusion_common::DataFusionError::Internal(format!(
75                "unexpected input datatype {t} when creating `VEC_AVG`"
76            )));
77        }
78
79        Ok(Box::new(VectorAvg::default()))
80    }
81
82    fn inner(&mut self, len: usize) -> &mut OVector<f32, Dyn> {
83        self.sum
84            .get_or_insert_with(|| OVector::zeros_generic(Dyn(len), Const::<1>))
85    }
86
87    fn update(&mut self, values: &[ArrayRef], is_update: bool) -> Result<()> {
88        if values.is_empty() {
89            return Ok(());
90        };
91
92        let vectors = match values[0].data_type() {
93            DataType::Utf8 => {
94                let arr: &StringArray = values[0].as_string();
95                arr.iter()
96                    .filter_map(|x| x.map(|s| parse_veclit_from_strlit(s).map_err(Into::into)))
97                    .map(|x| x.map(Cow::Owned))
98                    .collect::<Result<Vec<_>>>()?
99            }
100            DataType::LargeUtf8 => {
101                let arr: &LargeStringArray = values[0].as_string();
102                arr.iter()
103                    .filter_map(|x| x.map(|s| parse_veclit_from_strlit(s).map_err(Into::into)))
104                    .map(|x: Result<Vec<f32>>| x.map(Cow::Owned))
105                    .collect::<Result<Vec<_>>>()?
106            }
107            DataType::Binary => {
108                let arr: &BinaryArray = values[0].as_binary();
109                arr.iter()
110                    .filter_map(|x| x.map(|b| binlit_as_veclit(b).map_err(Into::into)))
111                    .collect::<Result<Vec<_>>>()?
112            }
113            _ => {
114                return Err(datafusion_common::DataFusionError::NotImplemented(format!(
115                    "unsupported data type {} for `VEC_AVG`",
116                    values[0].data_type()
117                )));
118            }
119        };
120
121        if vectors.is_empty() {
122            return Ok(());
123        }
124
125        let len = if is_update {
126            vectors.len() as u64
127        } else {
128            sum(values[1].as_primitive::<UInt64Type>()).unwrap_or_default()
129        };
130
131        let dims = vectors[0].len();
132        let mut sum = DVector::zeros(dims);
133        for v in vectors {
134            if v.len() != dims {
135                return Err(datafusion_common::DataFusionError::Execution(
136                    "vectors length not match: VEC_AVG".to_string(),
137                ));
138            }
139            let v_view = DVectorView::from_slice(&v, dims);
140            sum += &v_view;
141        }
142
143        *self.inner(dims) += sum;
144        self.count += len;
145
146        Ok(())
147    }
148}
149
150impl Accumulator for VectorAvg {
151    fn state(&mut self) -> Result<Vec<ScalarValue>> {
152        let vector = match &self.sum {
153            None => ScalarValue::Binary(None),
154            Some(sum) => ScalarValue::Binary(Some(veclit_to_binlit(sum.as_slice()))),
155        };
156        Ok(vec![vector, ScalarValue::from(self.count)])
157    }
158
159    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
160        self.update(values, true)
161    }
162
163    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
164        self.update(states, false)
165    }
166
167    fn evaluate(&mut self) -> Result<ScalarValue> {
168        match &self.sum {
169            None => Ok(ScalarValue::Binary(None)),
170            Some(sum) => Ok(ScalarValue::Binary(Some(veclit_to_binlit(
171                (sum / self.count as f32).as_slice(),
172            )))),
173        }
174    }
175
176    fn size(&self) -> usize {
177        size_of_val(self)
178    }
179}
180
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::StringArray;
    use datatypes::scalars::ScalarVector;
    use datatypes::vectors::{ConstantVector, StringVector, Vector};

    use super::*;

    #[test]
    fn test_update_batch() {
        // test update empty batch, expect not updating anything
        let mut vec_avg = VectorAvg::default();
        vec_avg.update_batch(&[]).unwrap();
        assert!(vec_avg.sum.is_none());
        assert_eq!(ScalarValue::Binary(None), vec_avg.evaluate().unwrap());

        // test update one not-null value
        let mut vec_avg = VectorAvg::default();
        let v: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
        ]))];
        vec_avg.update_batch(&v).unwrap();
        assert_eq!(
            ScalarValue::Binary(Some(veclit_to_binlit(&[2.5, 3.5, 4.5]))),
            vec_avg.evaluate().unwrap()
        );

        // test update one null value
        let mut vec_avg = VectorAvg::default();
        let v: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![Option::<String>::None]))];
        vec_avg.update_batch(&v).unwrap();
        assert_eq!(ScalarValue::Binary(None), vec_avg.evaluate().unwrap());

        // test update no null-value batch
        let mut vec_avg = VectorAvg::default();
        let v: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            Some("[7.0,8.0,9.0]".to_string()),
        ]))];
        vec_avg.update_batch(&v).unwrap();
        assert_eq!(
            ScalarValue::Binary(Some(veclit_to_binlit(&[4.0, 5.0, 6.0]))),
            vec_avg.evaluate().unwrap()
        );

        // test update null-value batch
        let mut vec_avg = VectorAvg::default();
        let v: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            None,
            Some("[7.0,8.0,9.0]".to_string()),
        ]))];
        vec_avg.update_batch(&v).unwrap();
        assert_eq!(
            ScalarValue::Binary(Some(veclit_to_binlit(&[4.0, 5.0, 6.0]))),
            vec_avg.evaluate().unwrap()
        );

        let mut vec_avg = VectorAvg::default();
        let v: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![
            None,
            Some("[4.0,5.0,6.0]".to_string()),
            Some("[7.0,8.0,9.0]".to_string()),
        ]))];
        vec_avg.update_batch(&v).unwrap();
        assert_eq!(
            ScalarValue::Binary(Some(veclit_to_binlit(&[5.5, 6.5, 7.5]))),
            vec_avg.evaluate().unwrap()
        );

        // test update with constant vector
        let mut vec_avg = VectorAvg::default();
        let v: Vec<ArrayRef> = vec![
            Arc::new(ConstantVector::new(
                Arc::new(StringVector::from_vec(vec!["[1.0,2.0,3.0]".to_string()])),
                4,
            ))
            .to_arrow_array(),
        ];
        vec_avg.update_batch(&v).unwrap();
        assert_eq!(
            ScalarValue::Binary(Some(veclit_to_binlit(&[1.0, 2.0, 3.0]))),
            vec_avg.evaluate().unwrap()
        );
    }

    #[test]
    fn test_update_batch_dimension_mismatch() {
        let mut vec_avg = VectorAvg::default();
        let v: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0]".to_string()),
        ]))];
        assert!(vec_avg.update_batch(&v).is_err());
        // a rejected batch must not leave partial state behind
        assert!(vec_avg.sum.is_none());
        assert_eq!(0, vec_avg.count);
    }

    #[test]
    fn test_merge_batch() {
        let mut left = VectorAvg::default();
        let v: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
        ]))];
        left.update_batch(&v).unwrap();

        let mut right = VectorAvg::default();
        let v: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![Some(
            "[7.0,8.0,9.0]".to_string(),
        )]))];
        right.update_batch(&v).unwrap();

        // an accumulator that never saw a value contributes a null sum and a zero count
        let mut empty = VectorAvg::default();

        let states = [
            left.state().unwrap(),
            right.state().unwrap(),
            empty.state().unwrap(),
        ];
        let sums = ScalarValue::iter_to_array(states.iter().map(|s| s[0].clone())).unwrap();
        let counts = ScalarValue::iter_to_array(states.iter().map(|s| s[1].clone())).unwrap();

        let mut merged = VectorAvg::default();
        merged.merge_batch(&[sums, counts]).unwrap();
        assert_eq!(3, merged.count);
        assert_eq!(
            ScalarValue::Binary(Some(veclit_to_binlit(&[4.0, 5.0, 6.0]))),
            merged.evaluate().unwrap()
        );
    }
}