common_function/aggrs/vector/
product.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::borrow::Cow;
16use std::sync::Arc;
17
18use arrow::array::{Array, ArrayRef, AsArray, BinaryArray, LargeStringArray, StringArray};
19use arrow_schema::{DataType, Field};
20use datafusion::logical_expr::{Signature, TypeSignature, Volatility};
21use datafusion_common::{Result, ScalarValue};
22use datafusion_expr::{Accumulator, AggregateUDF, SimpleAggregateUDF};
23use datafusion_functions_aggregate_common::accumulator::AccumulatorArgs;
24use nalgebra::{Const, DVectorView, Dyn, OVector};
25
26use crate::scalars::vector::impl_conv::{
27    binlit_as_veclit, parse_veclit_from_strlit, veclit_to_binlit,
28};
29
30/// Aggregates by multiplying elements across the same dimension, returns a vector.
31#[derive(Debug, Default)]
32pub struct VectorProduct {
33    product: Option<OVector<f32, Dyn>>,
34    has_null: bool,
35}
36
37impl VectorProduct {
38    /// Create a new `AggregateUDF` for the `vec_product` aggregate function.
39    pub fn uadf_impl() -> AggregateUDF {
40        let signature = Signature::one_of(
41            vec![
42                TypeSignature::Exact(vec![DataType::Utf8]),
43                TypeSignature::Exact(vec![DataType::Binary]),
44            ],
45            Volatility::Immutable,
46        );
47        let udaf = SimpleAggregateUDF::new_with_signature(
48            "vec_product",
49            signature,
50            DataType::Binary,
51            Arc::new(Self::accumulator),
52            vec![Arc::new(Field::new("x", DataType::Binary, true))],
53        );
54        AggregateUDF::from(udaf)
55    }
56
57    fn accumulator(args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
58        if args.schema.fields().len() != 1 {
59            return Err(datafusion_common::DataFusionError::Internal(format!(
60                "expect creating `VEC_PRODUCT` with only one input field, actual {}",
61                args.schema.fields().len()
62            )));
63        }
64
65        let t = args.schema.field(0).data_type();
66        if !matches!(t, DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary) {
67            return Err(datafusion_common::DataFusionError::Internal(format!(
68                "unexpected input datatype {t} when creating `VEC_PRODUCT`"
69            )));
70        }
71
72        Ok(Box::new(VectorProduct::default()))
73    }
74
75    fn inner(&mut self, len: usize) -> &mut OVector<f32, Dyn> {
76        self.product.get_or_insert_with(|| {
77            OVector::from_iterator_generic(Dyn(len), Const::<1>, (0..len).map(|_| 1.0))
78        })
79    }
80
81    fn update(&mut self, values: &[ArrayRef], is_update: bool) -> Result<()> {
82        if values.is_empty() || self.has_null {
83            return Ok(());
84        };
85
86        let vectors = match values[0].data_type() {
87            DataType::Utf8 => {
88                let arr: &StringArray = values[0].as_string();
89                arr.iter()
90                    .filter_map(|x| x.map(|s| parse_veclit_from_strlit(s).map_err(Into::into)))
91                    .map(|x| x.map(Cow::Owned))
92                    .collect::<Result<Vec<_>>>()?
93            }
94            DataType::LargeUtf8 => {
95                let arr: &LargeStringArray = values[0].as_string();
96                arr.iter()
97                    .filter_map(|x| x.map(|s| parse_veclit_from_strlit(s).map_err(Into::into)))
98                    .map(|x: Result<Vec<f32>>| x.map(Cow::Owned))
99                    .collect::<Result<Vec<_>>>()?
100            }
101            DataType::Binary => {
102                let arr: &BinaryArray = values[0].as_binary();
103                arr.iter()
104                    .filter_map(|x| x.map(|b| binlit_as_veclit(b).map_err(Into::into)))
105                    .collect::<Result<Vec<_>>>()?
106            }
107            _ => {
108                return Err(datafusion_common::DataFusionError::NotImplemented(format!(
109                    "unsupported data type {} for `VEC_PRODUCT`",
110                    values[0].data_type()
111                )));
112            }
113        };
114        if vectors.len() != values[0].len() {
115            if is_update {
116                self.has_null = true;
117                self.product = None;
118            }
119            return Ok(());
120        }
121
122        vectors.iter().for_each(|v| {
123            let v = DVectorView::from_slice(v, v.len());
124            let inner = self.inner(v.len());
125            *inner = inner.component_mul(&v);
126        });
127        Ok(())
128    }
129}
130
131impl Accumulator for VectorProduct {
132    fn state(&mut self) -> Result<Vec<ScalarValue>> {
133        self.evaluate().map(|v| vec![v])
134    }
135
136    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
137        self.update(values, true)
138    }
139
140    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
141        self.update(states, false)
142    }
143
144    fn evaluate(&mut self) -> Result<ScalarValue> {
145        match &self.product {
146            None => Ok(ScalarValue::Binary(None)),
147            Some(vector) => Ok(ScalarValue::Binary(Some(veclit_to_binlit(
148                vector.as_slice(),
149            )))),
150        }
151    }
152
153    fn size(&self) -> usize {
154        size_of_val(self)
155    }
156}
157
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::scalars::ScalarVector;
    use datatypes::vectors::{ConstantVector, StringVector, Vector};

    use super::*;

    /// Builds a one-column `Utf8` batch from optional vector literals.
    fn utf8_batch(rows: Vec<Option<&str>>) -> Vec<ArrayRef> {
        vec![Arc::new(StringArray::from(rows)) as ArrayRef]
    }

    /// Feeds `batch` into a fresh accumulator and returns what it evaluates to.
    fn product_of(batch: &[ArrayRef]) -> ScalarValue {
        let mut acc = VectorProduct::default();
        acc.update_batch(batch).unwrap();
        acc.evaluate().unwrap()
    }

    /// The binary scalar expected for the given vector elements.
    fn binary_of(elements: &[f32]) -> ScalarValue {
        ScalarValue::Binary(Some(veclit_to_binlit(elements)))
    }

    #[test]
    fn test_update_batch() {
        // An empty batch must leave the accumulator untouched.
        let mut acc = VectorProduct::default();
        acc.update_batch(&[]).unwrap();
        assert!(acc.product.is_none());
        assert!(!acc.has_null);
        assert_eq!(ScalarValue::Binary(None), acc.evaluate().unwrap());

        // A single non-null value is returned unchanged.
        assert_eq!(
            binary_of(&[1.0, 2.0, 3.0]),
            product_of(&utf8_batch(vec![Some("[1.0,2.0,3.0]")]))
        );

        // A single null value yields null.
        assert_eq!(
            ScalarValue::Binary(None),
            product_of(&utf8_batch(vec![None]))
        );

        // A batch without nulls is multiplied component by component.
        assert_eq!(
            binary_of(&[28.0, 80.0, 162.0]),
            product_of(&utf8_batch(vec![
                Some("[1.0,2.0,3.0]"),
                Some("[4.0,5.0,6.0]"),
                Some("[7.0,8.0,9.0]"),
            ]))
        );

        // Any null inside the batch makes the whole result null.
        assert_eq!(
            ScalarValue::Binary(None),
            product_of(&utf8_batch(vec![
                Some("[1.0,2.0,3.0]"),
                None,
                Some("[7.0,8.0,9.0]"),
            ]))
        );

        // A constant vector of length 4 multiplies the same value four times.
        let repeated = ConstantVector::new(
            Arc::new(StringVector::from_vec(vec!["[1.0,2.0,3.0]".to_string()])),
            4,
        );
        let batch: Vec<ArrayRef> = vec![Arc::new(repeated).to_arrow_array()];
        assert_eq!(binary_of(&[1.0, 16.0, 81.0]), product_of(&batch));
    }
}