common_function/aggrs/vector/
product.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::borrow::Cow;
16use std::sync::Arc;
17
18use arrow::array::{Array, ArrayRef, AsArray, BinaryArray, StringArray};
19use arrow_schema::{DataType, Field};
20use datafusion::logical_expr::{Signature, TypeSignature, Volatility};
21use datafusion_common::{Result, ScalarValue};
22use datafusion_expr::{Accumulator, AggregateUDF, SimpleAggregateUDF};
23use datafusion_functions_aggregate_common::accumulator::AccumulatorArgs;
24use nalgebra::{Const, DVectorView, Dyn, OVector};
25
26use crate::scalars::vector::impl_conv::{
27    binlit_as_veclit, parse_veclit_from_strlit, veclit_to_binlit,
28};
29
30/// Aggregates by multiplying elements across the same dimension, returns a vector.
31#[derive(Debug, Default)]
32pub struct VectorProduct {
33    product: Option<OVector<f32, Dyn>>,
34    has_null: bool,
35}
36
37impl VectorProduct {
38    /// Create a new `AggregateUDF` for the `vec_product` aggregate function.
39    pub fn uadf_impl() -> AggregateUDF {
40        let signature = Signature::one_of(
41            vec![
42                TypeSignature::Exact(vec![DataType::Utf8]),
43                TypeSignature::Exact(vec![DataType::Binary]),
44            ],
45            Volatility::Immutable,
46        );
47        let udaf = SimpleAggregateUDF::new_with_signature(
48            "vec_product",
49            signature,
50            DataType::Binary,
51            Arc::new(Self::accumulator),
52            vec![Arc::new(Field::new("x", DataType::Binary, true))],
53        );
54        AggregateUDF::from(udaf)
55    }
56
57    fn accumulator(args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
58        if args.schema.fields().len() != 1 {
59            return Err(datafusion_common::DataFusionError::Internal(format!(
60                "expect creating `VEC_PRODUCT` with only one input field, actual {}",
61                args.schema.fields().len()
62            )));
63        }
64
65        let t = args.schema.field(0).data_type();
66        if !matches!(t, DataType::Utf8 | DataType::Binary) {
67            return Err(datafusion_common::DataFusionError::Internal(format!(
68                "unexpected input datatype {t} when creating `VEC_PRODUCT`"
69            )));
70        }
71
72        Ok(Box::new(VectorProduct::default()))
73    }
74
75    fn inner(&mut self, len: usize) -> &mut OVector<f32, Dyn> {
76        self.product.get_or_insert_with(|| {
77            OVector::from_iterator_generic(Dyn(len), Const::<1>, (0..len).map(|_| 1.0))
78        })
79    }
80
81    fn update(&mut self, values: &[ArrayRef], is_update: bool) -> Result<()> {
82        if values.is_empty() || self.has_null {
83            return Ok(());
84        };
85
86        let vectors = match values[0].data_type() {
87            DataType::Utf8 => {
88                let arr: &StringArray = values[0].as_string();
89                arr.iter()
90                    .filter_map(|x| x.map(|s| parse_veclit_from_strlit(s).map_err(Into::into)))
91                    .map(|x| x.map(Cow::Owned))
92                    .collect::<Result<Vec<_>>>()?
93            }
94            DataType::Binary => {
95                let arr: &BinaryArray = values[0].as_binary();
96                arr.iter()
97                    .filter_map(|x| x.map(|b| binlit_as_veclit(b).map_err(Into::into)))
98                    .collect::<Result<Vec<_>>>()?
99            }
100            _ => {
101                return Err(datafusion_common::DataFusionError::NotImplemented(format!(
102                    "unsupported data type {} for `VEC_PRODUCT`",
103                    values[0].data_type()
104                )))
105            }
106        };
107        if vectors.len() != values[0].len() {
108            if is_update {
109                self.has_null = true;
110                self.product = None;
111            }
112            return Ok(());
113        }
114
115        vectors.iter().for_each(|v| {
116            let v = DVectorView::from_slice(v, v.len());
117            let inner = self.inner(v.len());
118            *inner = inner.component_mul(&v);
119        });
120        Ok(())
121    }
122}
123
124impl Accumulator for VectorProduct {
125    fn state(&mut self) -> Result<Vec<ScalarValue>> {
126        self.evaluate().map(|v| vec![v])
127    }
128
129    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
130        self.update(values, true)
131    }
132
133    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
134        self.update(states, false)
135    }
136
137    fn evaluate(&mut self) -> Result<ScalarValue> {
138        match &self.product {
139            None => Ok(ScalarValue::Binary(None)),
140            Some(vector) => Ok(ScalarValue::Binary(Some(veclit_to_binlit(
141                vector.as_slice(),
142            )))),
143        }
144    }
145
146    fn size(&self) -> usize {
147        size_of_val(self)
148    }
149}
150
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::scalars::ScalarVector;
    use datatypes::vectors::{ConstantVector, StringVector, Vector};

    use super::*;

    #[test]
    fn test_update_batch() {
        // test update empty batch, expect not updating anything
        let mut vec_product = VectorProduct::default();
        vec_product.update_batch(&[]).unwrap();
        assert!(vec_product.product.is_none());
        assert!(!vec_product.has_null);
        assert_eq!(ScalarValue::Binary(None), vec_product.evaluate().unwrap());

        // test update one not-null value
        let mut vec_product = VectorProduct::default();
        let v: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![Some(
            "[1.0,2.0,3.0]".to_string(),
        )]))];
        vec_product.update_batch(&v).unwrap();
        assert_eq!(
            ScalarValue::Binary(Some(veclit_to_binlit(&[1.0, 2.0, 3.0]))),
            vec_product.evaluate().unwrap()
        );

        // test update one null value
        let mut vec_product = VectorProduct::default();
        let v: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![Option::<String>::None]))];
        vec_product.update_batch(&v).unwrap();
        assert_eq!(ScalarValue::Binary(None), vec_product.evaluate().unwrap());

        // test update no null-value batch
        let mut vec_product = VectorProduct::default();
        let v: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            Some("[7.0,8.0,9.0]".to_string()),
        ]))];
        vec_product.update_batch(&v).unwrap();
        assert_eq!(
            ScalarValue::Binary(Some(veclit_to_binlit(&[28.0, 80.0, 162.0]))),
            vec_product.evaluate().unwrap()
        );

        // test update null-value batch
        let mut vec_product = VectorProduct::default();
        let v: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            None,
            Some("[7.0,8.0,9.0]".to_string()),
        ]))];
        vec_product.update_batch(&v).unwrap();
        assert_eq!(ScalarValue::Binary(None), vec_product.evaluate().unwrap());

        // test update with constant vector
        let mut vec_product = VectorProduct::default();
        let v: Vec<ArrayRef> = vec![Arc::new(ConstantVector::new(
            Arc::new(StringVector::from_vec(vec!["[1.0,2.0,3.0]".to_string()])),
            4,
        ))
        .to_arrow_array()];

        vec_product.update_batch(&v).unwrap();

        assert_eq!(
            ScalarValue::Binary(Some(veclit_to_binlit(&[1.0, 16.0, 81.0]))),
            vec_product.evaluate().unwrap()
        );
    }

    #[test]
    fn test_null_is_sticky_across_batches() {
        // once `update_batch` sees a null, later non-null batches must not revive the result
        let mut vec_product = VectorProduct::default();
        let with_null: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            None,
        ]))];
        vec_product.update_batch(&with_null).unwrap();

        let not_null: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![Some(
            "[4.0,5.0,6.0]".to_string(),
        )]))];
        vec_product.update_batch(&not_null).unwrap();

        assert!(vec_product.has_null);
        assert_eq!(ScalarValue::Binary(None), vec_product.evaluate().unwrap());
    }

    #[test]
    fn test_merge_batch() {
        // partial states are encoded `Binary` vectors; merging multiplies them together
        let mut vec_product = VectorProduct::default();
        let a = veclit_to_binlit(&[1.0, 2.0, 3.0]);
        let b = veclit_to_binlit(&[4.0, 5.0, 6.0]);
        let states: Vec<ArrayRef> = vec![Arc::new(BinaryArray::from_vec(vec![
            a.as_slice(),
            b.as_slice(),
        ]))];
        vec_product.merge_batch(&states).unwrap();
        assert_eq!(
            ScalarValue::Binary(Some(veclit_to_binlit(&[4.0, 10.0, 18.0]))),
            vec_product.evaluate().unwrap()
        );
    }
}