common_function/aggrs/vector/
product.rs1use std::borrow::Cow;
16use std::sync::Arc;
17
18use arrow::array::{Array, ArrayRef, AsArray, BinaryArray, StringArray};
19use arrow_schema::{DataType, Field};
20use datafusion::logical_expr::{Signature, TypeSignature, Volatility};
21use datafusion_common::{Result, ScalarValue};
22use datafusion_expr::{Accumulator, AggregateUDF, SimpleAggregateUDF};
23use datafusion_functions_aggregate_common::accumulator::AccumulatorArgs;
24use nalgebra::{Const, DVectorView, Dyn, OVector};
25
26use crate::scalars::vector::impl_conv::{
27 binlit_as_veclit, parse_veclit_from_strlit, veclit_to_binlit,
28};
29
/// Accumulator state for `vec_product`: the running element-wise product of the
/// input vectors.
#[derive(Debug, Default)]
pub struct VectorProduct {
    /// Running product; `None` until the first non-NULL vector is multiplied in, and
    /// reset to `None` once a NULL input row is seen.
    product: Option<OVector<f32, Dyn>>,
    /// Set when `update_batch` encounters a NULL row; all further input is then
    /// ignored, so the result stays NULL.
    has_null: bool,
}
36
37impl VectorProduct {
38 pub fn uadf_impl() -> AggregateUDF {
40 let signature = Signature::one_of(
41 vec![
42 TypeSignature::Exact(vec![DataType::Utf8]),
43 TypeSignature::Exact(vec![DataType::Binary]),
44 ],
45 Volatility::Immutable,
46 );
47 let udaf = SimpleAggregateUDF::new_with_signature(
48 "vec_product",
49 signature,
50 DataType::Binary,
51 Arc::new(Self::accumulator),
52 vec![Arc::new(Field::new("x", DataType::Binary, true))],
53 );
54 AggregateUDF::from(udaf)
55 }
56
57 fn accumulator(args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
58 if args.schema.fields().len() != 1 {
59 return Err(datafusion_common::DataFusionError::Internal(format!(
60 "expect creating `VEC_PRODUCT` with only one input field, actual {}",
61 args.schema.fields().len()
62 )));
63 }
64
65 let t = args.schema.field(0).data_type();
66 if !matches!(t, DataType::Utf8 | DataType::Binary) {
67 return Err(datafusion_common::DataFusionError::Internal(format!(
68 "unexpected input datatype {t} when creating `VEC_PRODUCT`"
69 )));
70 }
71
72 Ok(Box::new(VectorProduct::default()))
73 }
74
75 fn inner(&mut self, len: usize) -> &mut OVector<f32, Dyn> {
76 self.product.get_or_insert_with(|| {
77 OVector::from_iterator_generic(Dyn(len), Const::<1>, (0..len).map(|_| 1.0))
78 })
79 }
80
81 fn update(&mut self, values: &[ArrayRef], is_update: bool) -> Result<()> {
82 if values.is_empty() || self.has_null {
83 return Ok(());
84 };
85
86 let vectors = match values[0].data_type() {
87 DataType::Utf8 => {
88 let arr: &StringArray = values[0].as_string();
89 arr.iter()
90 .filter_map(|x| x.map(|s| parse_veclit_from_strlit(s).map_err(Into::into)))
91 .map(|x| x.map(Cow::Owned))
92 .collect::<Result<Vec<_>>>()?
93 }
94 DataType::Binary => {
95 let arr: &BinaryArray = values[0].as_binary();
96 arr.iter()
97 .filter_map(|x| x.map(|b| binlit_as_veclit(b).map_err(Into::into)))
98 .collect::<Result<Vec<_>>>()?
99 }
100 _ => {
101 return Err(datafusion_common::DataFusionError::NotImplemented(format!(
102 "unsupported data type {} for `VEC_PRODUCT`",
103 values[0].data_type()
104 )))
105 }
106 };
107 if vectors.len() != values[0].len() {
108 if is_update {
109 self.has_null = true;
110 self.product = None;
111 }
112 return Ok(());
113 }
114
115 vectors.iter().for_each(|v| {
116 let v = DVectorView::from_slice(v, v.len());
117 let inner = self.inner(v.len());
118 *inner = inner.component_mul(&v);
119 });
120 Ok(())
121 }
122}
123
124impl Accumulator for VectorProduct {
125 fn state(&mut self) -> Result<Vec<ScalarValue>> {
126 self.evaluate().map(|v| vec![v])
127 }
128
129 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
130 self.update(values, true)
131 }
132
133 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
134 self.update(states, false)
135 }
136
137 fn evaluate(&mut self) -> Result<ScalarValue> {
138 match &self.product {
139 None => Ok(ScalarValue::Binary(None)),
140 Some(vector) => Ok(ScalarValue::Binary(Some(veclit_to_binlit(
141 vector.as_slice(),
142 )))),
143 }
144 }
145
146 fn size(&self) -> usize {
147 size_of_val(self)
148 }
149}
150
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::scalars::ScalarVector;
    use datatypes::vectors::{ConstantVector, StringVector, Vector};

    use super::*;

    /// Runs a single string column through a fresh accumulator and returns its result.
    fn product_of_strings(rows: Vec<Option<String>>) -> ScalarValue {
        let column: ArrayRef = Arc::new(StringArray::from(rows));
        let mut acc = VectorProduct::default();
        acc.update_batch(&[column]).unwrap();
        acc.evaluate().unwrap()
    }

    /// Encodes `values` the same way the accumulator encodes its output.
    fn encoded(values: &[f32]) -> ScalarValue {
        ScalarValue::Binary(Some(veclit_to_binlit(values)))
    }

    #[test]
    fn test_update_batch() {
        // An empty argument list leaves the accumulator untouched.
        let mut acc = VectorProduct::default();
        acc.update_batch(&[]).unwrap();
        assert!(acc.product.is_none());
        assert!(!acc.has_null);
        assert_eq!(ScalarValue::Binary(None), acc.evaluate().unwrap());

        // A single vector is its own product.
        assert_eq!(
            encoded(&[1.0, 2.0, 3.0]),
            product_of_strings(vec![Some("[1.0,2.0,3.0]".to_string())])
        );

        // A lone NULL row yields NULL.
        assert_eq!(ScalarValue::Binary(None), product_of_strings(vec![None]));

        // Element-wise product across several rows.
        assert_eq!(
            encoded(&[28.0, 80.0, 162.0]),
            product_of_strings(vec![
                Some("[1.0,2.0,3.0]".to_string()),
                Some("[4.0,5.0,6.0]".to_string()),
                Some("[7.0,8.0,9.0]".to_string()),
            ])
        );

        // Any NULL row makes the whole result NULL.
        assert_eq!(
            ScalarValue::Binary(None),
            product_of_strings(vec![
                Some("[1.0,2.0,3.0]".to_string()),
                None,
                Some("[7.0,8.0,9.0]".to_string()),
            ])
        );

        // Four identical rows from a constant column raise each element to the 4th power.
        let constant = Arc::new(ConstantVector::new(
            Arc::new(StringVector::from_vec(vec!["[1.0,2.0,3.0]".to_string()])),
            4,
        ))
        .to_arrow_array();
        let mut acc = VectorProduct::default();
        acc.update_batch(&[constant]).unwrap();
        assert_eq!(encoded(&[1.0, 16.0, 81.0]), acc.evaluate().unwrap());
    }
}