common_function/scalars/vector/
product.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use common_macro::{as_aggr_func_creator, AggrFuncTypeStore};
18use common_query::error::{CreateAccumulatorSnafu, Error, InvalidFuncArgsSnafu};
19use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
20use common_query::prelude::AccumulatorCreatorFunction;
21use datatypes::prelude::{ConcreteDataType, Value, *};
22use datatypes::vectors::VectorRef;
23use nalgebra::{Const, DVectorView, Dyn, OVector};
24use snafu::ensure;
25
26use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
27
/// Aggregates by multiplying elements across the same dimension, returns a vector.
///
/// The accumulator keeps a single running element-wise product of every
/// vector it has been fed and reports it as a binary-encoded vector literal.
#[derive(Debug, Default)]
pub struct VectorProduct {
    /// Running element-wise product. `None` until the first vector is
    /// multiplied in, and reset to `None` when a null row is seen during an
    /// update.
    product: Option<OVector<f32, Dyn>>,
    /// Set when an update encounters a null row; once set, all later batches
    /// are ignored.
    has_null: bool,
}
34
35#[as_aggr_func_creator]
36#[derive(Debug, Default, AggrFuncTypeStore)]
37pub struct VectorProductCreator {}
38
39impl AggregateFunctionCreator for VectorProductCreator {
40    fn creator(&self) -> AccumulatorCreatorFunction {
41        let creator: AccumulatorCreatorFunction = Arc::new(move |types: &[ConcreteDataType]| {
42            ensure!(
43                types.len() == 1,
44                InvalidFuncArgsSnafu {
45                    err_msg: format!(
46                        "The length of the args is not correct, expect exactly one, have: {}",
47                        types.len()
48                    )
49                }
50            );
51            let input_type = &types[0];
52            match input_type {
53                ConcreteDataType::String(_) | ConcreteDataType::Binary(_) => {
54                    Ok(Box::new(VectorProduct::default()))
55                }
56                _ => {
57                    let err_msg = format!(
58                        "\"VEC_PRODUCT\" aggregate function not support data type {:?}",
59                        input_type.logical_type_id(),
60                    );
61                    CreateAccumulatorSnafu { err_msg }.fail()?
62                }
63            }
64        });
65        creator
66    }
67
68    fn output_type(&self) -> common_query::error::Result<ConcreteDataType> {
69        Ok(ConcreteDataType::binary_datatype())
70    }
71
72    fn state_types(&self) -> common_query::error::Result<Vec<ConcreteDataType>> {
73        Ok(vec![self.output_type()?])
74    }
75}
76
77impl VectorProduct {
78    fn inner(&mut self, len: usize) -> &mut OVector<f32, Dyn> {
79        self.product.get_or_insert_with(|| {
80            OVector::from_iterator_generic(Dyn(len), Const::<1>, (0..len).map(|_| 1.0))
81        })
82    }
83
84    fn update(&mut self, values: &[VectorRef], is_update: bool) -> Result<(), Error> {
85        if values.is_empty() || self.has_null {
86            return Ok(());
87        };
88        let column = &values[0];
89        let len = column.len();
90
91        match as_veclit_if_const(column)? {
92            Some(column) => {
93                let vec_column = DVectorView::from_slice(&column, column.len()).scale(len as f32);
94                *self.inner(vec_column.len()) =
95                    (*self.inner(vec_column.len())).component_mul(&vec_column);
96            }
97            None => {
98                for i in 0..len {
99                    let Some(arg0) = as_veclit(column.get_ref(i))? else {
100                        if is_update {
101                            self.has_null = true;
102                            self.product = None;
103                        }
104                        return Ok(());
105                    };
106                    let vec_column = DVectorView::from_slice(&arg0, arg0.len());
107                    *self.inner(vec_column.len()) =
108                        (*self.inner(vec_column.len())).component_mul(&vec_column);
109                }
110            }
111        }
112        Ok(())
113    }
114}
115
116impl Accumulator for VectorProduct {
117    fn state(&self) -> common_query::error::Result<Vec<Value>> {
118        self.evaluate().map(|v| vec![v])
119    }
120
121    fn update_batch(&mut self, values: &[VectorRef]) -> common_query::error::Result<()> {
122        self.update(values, true)
123    }
124
125    fn merge_batch(&mut self, states: &[VectorRef]) -> common_query::error::Result<()> {
126        self.update(states, false)
127    }
128
129    fn evaluate(&self) -> common_query::error::Result<Value> {
130        match &self.product {
131            None => Ok(Value::Null),
132            Some(vector) => {
133                let v = vector.as_slice();
134                Ok(Value::from(veclit_to_binlit(v)))
135            }
136        }
137    }
138}
139
140#[cfg(test)]
141mod tests {
142    use std::sync::Arc;
143
144    use datatypes::vectors::{ConstantVector, StringVector};
145
146    use super::*;
147
148    #[test]
149    fn test_update_batch() {
150        // test update empty batch, expect not updating anything
151        let mut vec_product = VectorProduct::default();
152        vec_product.update_batch(&[]).unwrap();
153        assert!(vec_product.product.is_none());
154        assert!(!vec_product.has_null);
155        assert_eq!(Value::Null, vec_product.evaluate().unwrap());
156
157        // test update one not-null value
158        let mut vec_product = VectorProduct::default();
159        let v: Vec<VectorRef> = vec![Arc::new(StringVector::from(vec![Some(
160            "[1.0,2.0,3.0]".to_string(),
161        )]))];
162        vec_product.update_batch(&v).unwrap();
163        assert_eq!(
164            Value::from(veclit_to_binlit(&[1.0, 2.0, 3.0])),
165            vec_product.evaluate().unwrap()
166        );
167
168        // test update one null value
169        let mut vec_product = VectorProduct::default();
170        let v: Vec<VectorRef> = vec![Arc::new(StringVector::from(vec![Option::<String>::None]))];
171        vec_product.update_batch(&v).unwrap();
172        assert_eq!(Value::Null, vec_product.evaluate().unwrap());
173
174        // test update no null-value batch
175        let mut vec_product = VectorProduct::default();
176        let v: Vec<VectorRef> = vec![Arc::new(StringVector::from(vec![
177            Some("[1.0,2.0,3.0]".to_string()),
178            Some("[4.0,5.0,6.0]".to_string()),
179            Some("[7.0,8.0,9.0]".to_string()),
180        ]))];
181        vec_product.update_batch(&v).unwrap();
182        assert_eq!(
183            Value::from(veclit_to_binlit(&[28.0, 80.0, 162.0])),
184            vec_product.evaluate().unwrap()
185        );
186
187        // test update null-value batch
188        let mut vec_product = VectorProduct::default();
189        let v: Vec<VectorRef> = vec![Arc::new(StringVector::from(vec![
190            Some("[1.0,2.0,3.0]".to_string()),
191            None,
192            Some("[7.0,8.0,9.0]".to_string()),
193        ]))];
194        vec_product.update_batch(&v).unwrap();
195        assert_eq!(Value::Null, vec_product.evaluate().unwrap());
196
197        // test update with constant vector
198        let mut vec_product = VectorProduct::default();
199        let v: Vec<VectorRef> = vec![Arc::new(ConstantVector::new(
200            Arc::new(StringVector::from_vec(vec!["[1.0,2.0,3.0]".to_string()])),
201            4,
202        ))];
203
204        vec_product.update_batch(&v).unwrap();
205
206        assert_eq!(
207            Value::from(veclit_to_binlit(&[4.0, 8.0, 12.0])),
208            vec_product.evaluate().unwrap()
209        );
210    }
211}