common_function/scalars/vector/
sum.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use common_macro::{as_aggr_func_creator, AggrFuncTypeStore};
18use common_query::error::{CreateAccumulatorSnafu, Error, InvalidFuncArgsSnafu};
19use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
20use common_query::prelude::AccumulatorCreatorFunction;
21use datatypes::prelude::{ConcreteDataType, Value, *};
22use datatypes::vectors::VectorRef;
23use nalgebra::{Const, DVectorView, Dyn, OVector};
24use snafu::ensure;
25
26use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
27
28#[derive(Debug, Default)]
29pub struct VectorSum {
30    sum: Option<OVector<f32, Dyn>>,
31    has_null: bool,
32}
33
34#[as_aggr_func_creator]
35#[derive(Debug, Default, AggrFuncTypeStore)]
36pub struct VectorSumCreator {}
37
38impl AggregateFunctionCreator for VectorSumCreator {
39    fn creator(&self) -> AccumulatorCreatorFunction {
40        let creator: AccumulatorCreatorFunction = Arc::new(move |types: &[ConcreteDataType]| {
41            ensure!(
42                types.len() == 1,
43                InvalidFuncArgsSnafu {
44                    err_msg: format!(
45                        "The length of the args is not correct, expect exactly one, have: {}",
46                        types.len()
47                    )
48                }
49            );
50            let input_type = &types[0];
51            match input_type {
52                ConcreteDataType::String(_) | ConcreteDataType::Binary(_) => {
53                    Ok(Box::new(VectorSum::default()))
54                }
55                _ => {
56                    let err_msg = format!(
57                        "\"VEC_SUM\" aggregate function not support data type {:?}",
58                        input_type.logical_type_id(),
59                    );
60                    CreateAccumulatorSnafu { err_msg }.fail()?
61                }
62            }
63        });
64        creator
65    }
66
67    fn output_type(&self) -> common_query::error::Result<ConcreteDataType> {
68        Ok(ConcreteDataType::binary_datatype())
69    }
70
71    fn state_types(&self) -> common_query::error::Result<Vec<ConcreteDataType>> {
72        Ok(vec![self.output_type()?])
73    }
74}
75
76impl VectorSum {
77    fn inner(&mut self, len: usize) -> &mut OVector<f32, Dyn> {
78        self.sum
79            .get_or_insert_with(|| OVector::zeros_generic(Dyn(len), Const::<1>))
80    }
81
82    fn update(&mut self, values: &[VectorRef], is_update: bool) -> Result<(), Error> {
83        if values.is_empty() || self.has_null {
84            return Ok(());
85        };
86        let column = &values[0];
87        let len = column.len();
88
89        match as_veclit_if_const(column)? {
90            Some(column) => {
91                let vec_column = DVectorView::from_slice(&column, column.len()).scale(len as f32);
92                *self.inner(vec_column.len()) += vec_column;
93            }
94            None => {
95                for i in 0..len {
96                    let Some(arg0) = as_veclit(column.get_ref(i))? else {
97                        if is_update {
98                            self.has_null = true;
99                            self.sum = None;
100                        }
101                        return Ok(());
102                    };
103                    let vec_column = DVectorView::from_slice(&arg0, arg0.len());
104                    *self.inner(vec_column.len()) += vec_column;
105                }
106            }
107        }
108        Ok(())
109    }
110}
111
112impl Accumulator for VectorSum {
113    fn state(&self) -> common_query::error::Result<Vec<Value>> {
114        self.evaluate().map(|v| vec![v])
115    }
116
117    fn update_batch(&mut self, values: &[VectorRef]) -> common_query::error::Result<()> {
118        self.update(values, true)
119    }
120
121    fn merge_batch(&mut self, states: &[VectorRef]) -> common_query::error::Result<()> {
122        self.update(states, false)
123    }
124
125    fn evaluate(&self) -> common_query::error::Result<Value> {
126        match &self.sum {
127            None => Ok(Value::Null),
128            Some(vector) => Ok(Value::from(veclit_to_binlit(vector.as_slice()))),
129        }
130    }
131}
132
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::vectors::{ConstantVector, StringVector};

    use super::*;

    /// Wraps the given rows into a one-column batch backed by a string vector.
    fn string_batch(rows: &[Option<&str>]) -> Vec<VectorRef> {
        let owned: Vec<Option<String>> = rows
            .iter()
            .map(|row| row.map(|literal| literal.to_string()))
            .collect();
        vec![Arc::new(StringVector::from(owned)) as VectorRef]
    }

    /// Feeds `batch` into a fresh accumulator and returns the evaluated result.
    fn sum_of(batch: &[VectorRef]) -> Value {
        let mut accumulator = VectorSum::default();
        accumulator.update_batch(batch).unwrap();
        accumulator.evaluate().unwrap()
    }

    #[test]
    fn test_update_batch() {
        // An empty batch must leave the accumulator untouched.
        let mut untouched = VectorSum::default();
        untouched.update_batch(&[]).unwrap();
        assert!(untouched.sum.is_none());
        assert!(!untouched.has_null);
        assert_eq!(Value::Null, untouched.evaluate().unwrap());

        // A single non-null value is returned as is.
        let single = string_batch(&[Some("[1.0,2.0,3.0]")]);
        assert_eq!(
            Value::from(veclit_to_binlit(&[1.0, 2.0, 3.0])),
            sum_of(&single)
        );

        // A single null value yields null.
        let single_null = string_batch(&[None]);
        assert_eq!(Value::Null, sum_of(&single_null));

        // A batch without nulls is summed element-wise.
        let dense = string_batch(&[
            Some("[1.0,2.0,3.0]"),
            Some("[4.0,5.0,6.0]"),
            Some("[7.0,8.0,9.0]"),
        ]);
        assert_eq!(
            Value::from(veclit_to_binlit(&[12.0, 15.0, 18.0])),
            sum_of(&dense)
        );

        // A batch containing a null yields null.
        let with_null = string_batch(&[Some("[1.0,2.0,3.0]"), None, Some("[7.0,8.0,9.0]")]);
        assert_eq!(Value::Null, sum_of(&with_null));

        // A constant vector contributes its value once per row.
        let repeated = ConstantVector::new(
            Arc::new(StringVector::from_vec(vec!["[1.0,2.0,3.0]".to_string()])),
            4,
        );
        let constant: Vec<VectorRef> = vec![Arc::new(repeated)];
        assert_eq!(
            Value::from(veclit_to_binlit(&[4.0, 8.0, 12.0])),
            sum_of(&constant)
        );
    }
}