common_function/aggrs/vector/
sum.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use arrow::array::{Array, ArrayRef, AsArray, BinaryArray, StringArray};
18use arrow_schema::{DataType, Field};
19use datafusion_common::{Result, ScalarValue};
20use datafusion_expr::{
21    Accumulator, AggregateUDF, Signature, SimpleAggregateUDF, TypeSignature, Volatility,
22};
23use datafusion_functions_aggregate_common::accumulator::AccumulatorArgs;
24use nalgebra::{Const, DVectorView, Dyn, OVector};
25
26use crate::scalars::vector::impl_conv::{
27    binlit_as_veclit, parse_veclit_from_strlit, veclit_to_binlit,
28};
29
30/// The accumulator for the `vec_sum` aggregate function.
31#[derive(Debug, Default)]
32pub struct VectorSum {
33    sum: Option<OVector<f32, Dyn>>,
34    has_null: bool,
35}
36
37impl VectorSum {
38    /// Create a new `AggregateUDF` for the `vec_sum` aggregate function.
39    pub fn uadf_impl() -> AggregateUDF {
40        let signature = Signature::one_of(
41            vec![
42                TypeSignature::Exact(vec![DataType::Utf8]),
43                TypeSignature::Exact(vec![DataType::Binary]),
44            ],
45            Volatility::Immutable,
46        );
47        let udaf = SimpleAggregateUDF::new_with_signature(
48            "vec_sum",
49            signature,
50            DataType::Binary,
51            Arc::new(Self::accumulator),
52            vec![Arc::new(Field::new("x", DataType::Binary, true))],
53        );
54        AggregateUDF::from(udaf)
55    }
56
57    fn accumulator(args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
58        if args.schema.fields().len() != 1 {
59            return Err(datafusion_common::DataFusionError::Internal(format!(
60                "expect creating `VEC_SUM` with only one input field, actual {}",
61                args.schema.fields().len()
62            )));
63        }
64
65        let t = args.schema.field(0).data_type();
66        if !matches!(t, DataType::Utf8 | DataType::Binary) {
67            return Err(datafusion_common::DataFusionError::Internal(format!(
68                "unexpected input datatype {t} when creating `VEC_SUM`"
69            )));
70        }
71
72        Ok(Box::new(VectorSum::default()))
73    }
74
75    fn inner(&mut self, len: usize) -> &mut OVector<f32, Dyn> {
76        self.sum
77            .get_or_insert_with(|| OVector::zeros_generic(Dyn(len), Const::<1>))
78    }
79
80    fn update(&mut self, values: &[ArrayRef], is_update: bool) -> Result<()> {
81        if values.is_empty() || self.has_null {
82            return Ok(());
83        };
84
85        match values[0].data_type() {
86            DataType::Utf8 => {
87                let arr: &StringArray = values[0].as_string();
88                for s in arr.iter() {
89                    let Some(s) = s else {
90                        if is_update {
91                            self.has_null = true;
92                            self.sum = None;
93                        }
94                        return Ok(());
95                    };
96                    let values = parse_veclit_from_strlit(s)?;
97                    let vec_column = DVectorView::from_slice(&values, values.len());
98                    *self.inner(vec_column.len()) += vec_column;
99                }
100            }
101            DataType::Binary => {
102                let arr: &BinaryArray = values[0].as_binary();
103                for b in arr.iter() {
104                    let Some(b) = b else {
105                        if is_update {
106                            self.has_null = true;
107                            self.sum = None;
108                        }
109                        return Ok(());
110                    };
111                    let values = binlit_as_veclit(b)?;
112                    let vec_column = DVectorView::from_slice(&values, values.len());
113                    *self.inner(vec_column.len()) += vec_column;
114                }
115            }
116            _ => {
117                return Err(datafusion_common::DataFusionError::NotImplemented(format!(
118                    "unsupported data type {} for `VEC_SUM`",
119                    values[0].data_type()
120                )))
121            }
122        }
123        Ok(())
124    }
125}
126
127impl Accumulator for VectorSum {
128    fn state(&mut self) -> Result<Vec<ScalarValue>> {
129        self.evaluate().map(|v| vec![v])
130    }
131
132    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
133        self.update(values, true)
134    }
135
136    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
137        self.update(states, false)
138    }
139
140    fn evaluate(&mut self) -> Result<ScalarValue> {
141        match &self.sum {
142            None => Ok(ScalarValue::Binary(None)),
143            Some(vector) => Ok(ScalarValue::Binary(Some(veclit_to_binlit(
144                vector.as_slice(),
145            )))),
146        }
147    }
148
149    fn size(&self) -> usize {
150        size_of_val(self)
151    }
152}
153
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::StringArray;
    use datatypes::scalars::ScalarVector;
    use datatypes::vectors::{ConstantVector, StringVector, Vector};

    use super::*;

    #[test]
    fn test_update_batch() {
        // test update empty batch, expect not updating anything
        let mut vec_sum = VectorSum::default();
        vec_sum.update_batch(&[]).unwrap();
        assert!(vec_sum.sum.is_none());
        assert!(!vec_sum.has_null);
        assert_eq!(ScalarValue::Binary(None), vec_sum.evaluate().unwrap());

        // test update one not-null value
        let mut vec_sum = VectorSum::default();
        let v: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![Some(
            "[1.0,2.0,3.0]".to_string(),
        )]))];
        vec_sum.update_batch(&v).unwrap();
        assert_eq!(
            ScalarValue::Binary(Some(veclit_to_binlit(&[1.0, 2.0, 3.0]))),
            vec_sum.evaluate().unwrap()
        );

        // test update one null value
        let mut vec_sum = VectorSum::default();
        let v: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![Option::<String>::None]))];
        vec_sum.update_batch(&v).unwrap();
        assert_eq!(ScalarValue::Binary(None), vec_sum.evaluate().unwrap());

        // test update no null-value batch
        let mut vec_sum = VectorSum::default();
        let v: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            Some("[7.0,8.0,9.0]".to_string()),
        ]))];
        vec_sum.update_batch(&v).unwrap();
        assert_eq!(
            ScalarValue::Binary(Some(veclit_to_binlit(&[12.0, 15.0, 18.0]))),
            vec_sum.evaluate().unwrap()
        );

        // test update null-value batch
        let mut vec_sum = VectorSum::default();
        let v: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            None,
            Some("[7.0,8.0,9.0]".to_string()),
        ]))];
        vec_sum.update_batch(&v).unwrap();
        assert_eq!(ScalarValue::Binary(None), vec_sum.evaluate().unwrap());

        // test update with constant vector
        let mut vec_sum = VectorSum::default();
        let v: Vec<ArrayRef> = vec![Arc::new(ConstantVector::new(
            Arc::new(StringVector::from_vec(vec!["[1.0,2.0,3.0]".to_string()])),
            4,
        ))
        .to_arrow_array()];
        vec_sum.update_batch(&v).unwrap();
        assert_eq!(
            ScalarValue::Binary(Some(veclit_to_binlit(&[4.0, 8.0, 12.0]))),
            vec_sum.evaluate().unwrap()
        );
    }

    #[test]
    fn test_update_binary_batch() {
        // binary input is decoded and summed the same way as string literals
        let a = veclit_to_binlit(&[1.0, 2.0, 3.0]);
        let b = veclit_to_binlit(&[4.0, 5.0, 6.0]);
        let mut vec_sum = VectorSum::default();
        let v: Vec<ArrayRef> = vec![Arc::new(BinaryArray::from(vec![
            Some(a.as_slice()),
            Some(b.as_slice()),
        ]))];
        vec_sum.update_batch(&v).unwrap();
        assert_eq!(
            ScalarValue::Binary(Some(veclit_to_binlit(&[5.0, 7.0, 9.0]))),
            vec_sum.evaluate().unwrap()
        );
    }

    #[test]
    fn test_state_and_merge_batch() {
        // partial aggregation exposes the partial sum as a single binary state
        let mut partial = VectorSum::default();
        let v: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
        ]))];
        partial.update_batch(&v).unwrap();
        assert_eq!(
            vec![ScalarValue::Binary(Some(veclit_to_binlit(&[5.0, 7.0, 9.0])))],
            partial.state().unwrap()
        );

        // an accumulator that saw no rows reports a null state
        let mut empty = VectorSum::default();
        assert_eq!(vec![ScalarValue::Binary(None)], empty.state().unwrap());

        // merging several non-null states sums them
        let s1 = veclit_to_binlit(&[5.0, 7.0, 9.0]);
        let s2 = veclit_to_binlit(&[1.0, 1.0, 1.0]);
        let mut merged = VectorSum::default();
        let states: Vec<ArrayRef> = vec![Arc::new(BinaryArray::from(vec![
            Some(s1.as_slice()),
            Some(s2.as_slice()),
        ]))];
        merged.merge_batch(&states).unwrap();
        assert_eq!(
            ScalarValue::Binary(Some(veclit_to_binlit(&[6.0, 8.0, 10.0]))),
            merged.evaluate().unwrap()
        );

        // a null state does not poison the merged result
        let mut merged = VectorSum::default();
        let null_states: Vec<ArrayRef> =
            vec![Arc::new(BinaryArray::from(vec![Option::<&[u8]>::None]))];
        merged.merge_batch(&null_states).unwrap();
        assert!(!merged.has_null);
        let states: Vec<ArrayRef> =
            vec![Arc::new(BinaryArray::from(vec![Some(s2.as_slice())]))];
        merged.merge_batch(&states).unwrap();
        assert_eq!(
            ScalarValue::Binary(Some(veclit_to_binlit(&[1.0, 1.0, 1.0]))),
            merged.evaluate().unwrap()
        );
    }
}