common_function/aggrs/vector/
product.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use common_macro::{as_aggr_func_creator, AggrFuncTypeStore};
18use common_query::error::{CreateAccumulatorSnafu, Error, InvalidFuncArgsSnafu};
19use common_query::logical_plan::{
20    create_aggregate_function, Accumulator, AggregateFunctionCreator,
21};
22use common_query::prelude::AccumulatorCreatorFunction;
23use datafusion_expr::AggregateUDF;
24use datatypes::prelude::{ConcreteDataType, Value, *};
25use datatypes::vectors::VectorRef;
26use nalgebra::{Const, DVectorView, Dyn, OVector};
27use snafu::ensure;
28
29use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
30
31/// Aggregates by multiplying elements across the same dimension, returns a vector.
32#[derive(Debug, Default)]
33pub struct VectorProduct {
34    product: Option<OVector<f32, Dyn>>,
35    has_null: bool,
36}
37
38#[as_aggr_func_creator]
39#[derive(Debug, Default, AggrFuncTypeStore)]
40pub struct VectorProductCreator {}
41
42impl AggregateFunctionCreator for VectorProductCreator {
43    fn creator(&self) -> AccumulatorCreatorFunction {
44        let creator: AccumulatorCreatorFunction = Arc::new(move |types: &[ConcreteDataType]| {
45            ensure!(
46                types.len() == 1,
47                InvalidFuncArgsSnafu {
48                    err_msg: format!(
49                        "The length of the args is not correct, expect exactly one, have: {}",
50                        types.len()
51                    )
52                }
53            );
54            let input_type = &types[0];
55            match input_type {
56                ConcreteDataType::String(_) | ConcreteDataType::Binary(_) => {
57                    Ok(Box::new(VectorProduct::default()))
58                }
59                _ => {
60                    let err_msg = format!(
61                        "\"VEC_PRODUCT\" aggregate function not support data type {:?}",
62                        input_type.logical_type_id(),
63                    );
64                    CreateAccumulatorSnafu { err_msg }.fail()?
65                }
66            }
67        });
68        creator
69    }
70
71    fn output_type(&self) -> common_query::error::Result<ConcreteDataType> {
72        Ok(ConcreteDataType::binary_datatype())
73    }
74
75    fn state_types(&self) -> common_query::error::Result<Vec<ConcreteDataType>> {
76        Ok(vec![self.output_type()?])
77    }
78}
79
80impl VectorProduct {
81    /// Create a new `AggregateUDF` for the `vec_product` aggregate function.
82    pub fn uadf_impl() -> AggregateUDF {
83        create_aggregate_function(
84            "vec_product".to_string(),
85            1,
86            Arc::new(VectorProductCreator::default()),
87        )
88        .into()
89    }
90
91    fn inner(&mut self, len: usize) -> &mut OVector<f32, Dyn> {
92        self.product.get_or_insert_with(|| {
93            OVector::from_iterator_generic(Dyn(len), Const::<1>, (0..len).map(|_| 1.0))
94        })
95    }
96
97    fn update(&mut self, values: &[VectorRef], is_update: bool) -> Result<(), Error> {
98        if values.is_empty() || self.has_null {
99            return Ok(());
100        };
101        let column = &values[0];
102        let len = column.len();
103
104        match as_veclit_if_const(column)? {
105            Some(column) => {
106                let vec_column = DVectorView::from_slice(&column, column.len()).scale(len as f32);
107                *self.inner(vec_column.len()) =
108                    (*self.inner(vec_column.len())).component_mul(&vec_column);
109            }
110            None => {
111                for i in 0..len {
112                    let Some(arg0) = as_veclit(column.get_ref(i))? else {
113                        if is_update {
114                            self.has_null = true;
115                            self.product = None;
116                        }
117                        return Ok(());
118                    };
119                    let vec_column = DVectorView::from_slice(&arg0, arg0.len());
120                    *self.inner(vec_column.len()) =
121                        (*self.inner(vec_column.len())).component_mul(&vec_column);
122                }
123            }
124        }
125        Ok(())
126    }
127}
128
129impl Accumulator for VectorProduct {
130    fn state(&self) -> common_query::error::Result<Vec<Value>> {
131        self.evaluate().map(|v| vec![v])
132    }
133
134    fn update_batch(&mut self, values: &[VectorRef]) -> common_query::error::Result<()> {
135        self.update(values, true)
136    }
137
138    fn merge_batch(&mut self, states: &[VectorRef]) -> common_query::error::Result<()> {
139        self.update(states, false)
140    }
141
142    fn evaluate(&self) -> common_query::error::Result<Value> {
143        match &self.product {
144            None => Ok(Value::Null),
145            Some(vector) => {
146                let v = vector.as_slice();
147                Ok(Value::from(veclit_to_binlit(v)))
148            }
149        }
150    }
151}
152
153#[cfg(test)]
154mod tests {
155    use std::sync::Arc;
156
157    use datatypes::vectors::{ConstantVector, StringVector};
158
159    use super::*;
160
161    #[test]
162    fn test_update_batch() {
163        // test update empty batch, expect not updating anything
164        let mut vec_product = VectorProduct::default();
165        vec_product.update_batch(&[]).unwrap();
166        assert!(vec_product.product.is_none());
167        assert!(!vec_product.has_null);
168        assert_eq!(Value::Null, vec_product.evaluate().unwrap());
169
170        // test update one not-null value
171        let mut vec_product = VectorProduct::default();
172        let v: Vec<VectorRef> = vec![Arc::new(StringVector::from(vec![Some(
173            "[1.0,2.0,3.0]".to_string(),
174        )]))];
175        vec_product.update_batch(&v).unwrap();
176        assert_eq!(
177            Value::from(veclit_to_binlit(&[1.0, 2.0, 3.0])),
178            vec_product.evaluate().unwrap()
179        );
180
181        // test update one null value
182        let mut vec_product = VectorProduct::default();
183        let v: Vec<VectorRef> = vec![Arc::new(StringVector::from(vec![Option::<String>::None]))];
184        vec_product.update_batch(&v).unwrap();
185        assert_eq!(Value::Null, vec_product.evaluate().unwrap());
186
187        // test update no null-value batch
188        let mut vec_product = VectorProduct::default();
189        let v: Vec<VectorRef> = vec![Arc::new(StringVector::from(vec![
190            Some("[1.0,2.0,3.0]".to_string()),
191            Some("[4.0,5.0,6.0]".to_string()),
192            Some("[7.0,8.0,9.0]".to_string()),
193        ]))];
194        vec_product.update_batch(&v).unwrap();
195        assert_eq!(
196            Value::from(veclit_to_binlit(&[28.0, 80.0, 162.0])),
197            vec_product.evaluate().unwrap()
198        );
199
200        // test update null-value batch
201        let mut vec_product = VectorProduct::default();
202        let v: Vec<VectorRef> = vec![Arc::new(StringVector::from(vec![
203            Some("[1.0,2.0,3.0]".to_string()),
204            None,
205            Some("[7.0,8.0,9.0]".to_string()),
206        ]))];
207        vec_product.update_batch(&v).unwrap();
208        assert_eq!(Value::Null, vec_product.evaluate().unwrap());
209
210        // test update with constant vector
211        let mut vec_product = VectorProduct::default();
212        let v: Vec<VectorRef> = vec![Arc::new(ConstantVector::new(
213            Arc::new(StringVector::from_vec(vec!["[1.0,2.0,3.0]".to_string()])),
214            4,
215        ))];
216
217        vec_product.update_batch(&v).unwrap();
218
219        assert_eq!(
220            Value::from(veclit_to_binlit(&[4.0, 8.0, 12.0])),
221            vec_product.evaluate().unwrap()
222        );
223    }
224}