common_function/aggrs/vector/
sum.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use arrow::array::{Array, ArrayRef, AsArray, BinaryArray, LargeStringArray, StringArray};
18use arrow_schema::{DataType, Field};
19use datafusion_common::{Result, ScalarValue};
20use datafusion_expr::{
21    Accumulator, AggregateUDF, Signature, SimpleAggregateUDF, TypeSignature, Volatility,
22};
23use datafusion_functions_aggregate_common::accumulator::AccumulatorArgs;
24use nalgebra::{Const, DVectorView, Dyn, OVector};
25
26use crate::scalars::vector::impl_conv::{
27    binlit_as_veclit, parse_veclit_from_strlit, veclit_to_binlit,
28};
29
30/// The accumulator for the `vec_sum` aggregate function.
31#[derive(Debug, Default)]
32pub struct VectorSum {
33    sum: Option<OVector<f32, Dyn>>,
34    has_null: bool,
35}
36
37impl VectorSum {
38    /// Create a new `AggregateUDF` for the `vec_sum` aggregate function.
39    pub fn uadf_impl() -> AggregateUDF {
40        let signature = Signature::one_of(
41            vec![
42                TypeSignature::Exact(vec![DataType::Utf8]),
43                TypeSignature::Exact(vec![DataType::Binary]),
44            ],
45            Volatility::Immutable,
46        );
47        let udaf = SimpleAggregateUDF::new_with_signature(
48            "vec_sum",
49            signature,
50            DataType::Binary,
51            Arc::new(Self::accumulator),
52            vec![Arc::new(Field::new("x", DataType::Binary, true))],
53        );
54        AggregateUDF::from(udaf)
55    }
56
57    fn accumulator(args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
58        if args.schema.fields().len() != 1 {
59            return Err(datafusion_common::DataFusionError::Internal(format!(
60                "expect creating `VEC_SUM` with only one input field, actual {}",
61                args.schema.fields().len()
62            )));
63        }
64
65        let t = args.schema.field(0).data_type();
66        if !matches!(t, DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary) {
67            return Err(datafusion_common::DataFusionError::Internal(format!(
68                "unexpected input datatype {t} when creating `VEC_SUM`"
69            )));
70        }
71
72        Ok(Box::new(VectorSum::default()))
73    }
74
75    fn inner(&mut self, len: usize) -> &mut OVector<f32, Dyn> {
76        self.sum
77            .get_or_insert_with(|| OVector::zeros_generic(Dyn(len), Const::<1>))
78    }
79
80    fn update(&mut self, values: &[ArrayRef], is_update: bool) -> Result<()> {
81        if values.is_empty() || self.has_null {
82            return Ok(());
83        };
84
85        match values[0].data_type() {
86            DataType::Utf8 => {
87                let arr: &StringArray = values[0].as_string();
88                for s in arr.iter() {
89                    let Some(s) = s else {
90                        if is_update {
91                            self.has_null = true;
92                            self.sum = None;
93                        }
94                        return Ok(());
95                    };
96                    let values = parse_veclit_from_strlit(s)?;
97                    let vec_column = DVectorView::from_slice(&values, values.len());
98                    *self.inner(vec_column.len()) += vec_column;
99                }
100            }
101            DataType::LargeUtf8 => {
102                let arr: &LargeStringArray = values[0].as_string();
103                for s in arr.iter() {
104                    let Some(s) = s else {
105                        if is_update {
106                            self.has_null = true;
107                            self.sum = None;
108                        }
109                        return Ok(());
110                    };
111                    let values = parse_veclit_from_strlit(s)?;
112                    let vec_column = DVectorView::from_slice(&values, values.len());
113                    *self.inner(vec_column.len()) += vec_column;
114                }
115            }
116            DataType::Binary => {
117                let arr: &BinaryArray = values[0].as_binary();
118                for b in arr.iter() {
119                    let Some(b) = b else {
120                        if is_update {
121                            self.has_null = true;
122                            self.sum = None;
123                        }
124                        return Ok(());
125                    };
126                    let values = binlit_as_veclit(b)?;
127                    let vec_column = DVectorView::from_slice(&values, values.len());
128                    *self.inner(vec_column.len()) += vec_column;
129                }
130            }
131            _ => {
132                return Err(datafusion_common::DataFusionError::NotImplemented(format!(
133                    "unsupported data type {} for `VEC_SUM`",
134                    values[0].data_type()
135                )));
136            }
137        }
138        Ok(())
139    }
140}
141
142impl Accumulator for VectorSum {
143    fn state(&mut self) -> Result<Vec<ScalarValue>> {
144        self.evaluate().map(|v| vec![v])
145    }
146
147    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
148        self.update(values, true)
149    }
150
151    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
152        self.update(states, false)
153    }
154
155    fn evaluate(&mut self) -> Result<ScalarValue> {
156        match &self.sum {
157            None => Ok(ScalarValue::Binary(None)),
158            Some(vector) => Ok(ScalarValue::Binary(Some(veclit_to_binlit(
159                vector.as_slice(),
160            )))),
161        }
162    }
163
164    fn size(&self) -> usize {
165        size_of_val(self)
166    }
167}
168
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::StringArray;
    use datatypes::scalars::ScalarVector;
    use datatypes::vectors::{ConstantVector, StringVector, Vector};

    use super::*;

    /// Builds a one-column `Utf8` batch from optional vector literals.
    fn utf8_batch(literals: &[Option<&str>]) -> Vec<ArrayRef> {
        let column: ArrayRef = Arc::new(StringArray::from(literals.to_vec()));
        vec![column]
    }

    /// Feeds `batch` to a fresh accumulator and returns what it evaluates to.
    fn sum_of(batch: &[ArrayRef]) -> ScalarValue {
        let mut acc = VectorSum::default();
        acc.update_batch(batch).unwrap();
        acc.evaluate().unwrap()
    }

    /// The `Binary` scalar holding the given vector.
    fn binary_of(elements: &[f32]) -> ScalarValue {
        ScalarValue::Binary(Some(veclit_to_binlit(elements)))
    }

    #[test]
    fn test_update_batch() {
        // An empty batch must leave the accumulator untouched.
        let mut untouched = VectorSum::default();
        untouched.update_batch(&[]).unwrap();
        assert!(untouched.sum.is_none());
        assert!(!untouched.has_null);
        assert_eq!(ScalarValue::Binary(None), untouched.evaluate().unwrap());

        // A single non-null value is returned as is.
        assert_eq!(
            binary_of(&[1.0, 2.0, 3.0]),
            sum_of(&utf8_batch(&[Some("[1.0,2.0,3.0]")]))
        );

        // A single null value yields a null result.
        assert_eq!(ScalarValue::Binary(None), sum_of(&utf8_batch(&[None])));

        // A batch without nulls is summed element-wise.
        assert_eq!(
            binary_of(&[12.0, 15.0, 18.0]),
            sum_of(&utf8_batch(&[
                Some("[1.0,2.0,3.0]"),
                Some("[4.0,5.0,6.0]"),
                Some("[7.0,8.0,9.0]"),
            ]))
        );

        // Any null inside the batch makes the whole result null.
        assert_eq!(
            ScalarValue::Binary(None),
            sum_of(&utf8_batch(&[
                Some("[1.0,2.0,3.0]"),
                None,
                Some("[7.0,8.0,9.0]"),
            ]))
        );

        // A constant vector repeated four times counts four times.
        let repeated = ConstantVector::new(
            Arc::new(StringVector::from_vec(vec!["[1.0,2.0,3.0]".to_string()])),
            4,
        );
        let batch: Vec<ArrayRef> = vec![repeated.to_arrow_array()];
        assert_eq!(binary_of(&[4.0, 8.0, 12.0]), sum_of(&batch));
    }
}