common_function/scalars/vector/
elem_product.rs1use std::fmt::Display;
16
17use datafusion::arrow::datatypes::DataType;
18use datafusion::logical_expr::ColumnarValue;
19use datafusion::logical_expr_common::type_coercion::aggregates::{BINARYS, STRINGS};
20use datafusion_common::ScalarValue;
21use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
22use nalgebra::DVectorView;
23
24use crate::function::Function;
25use crate::scalars::vector::{VectorCalculator, impl_conv};
26
/// SQL-visible name of the element-product vector function.
const NAME: &str = "vec_elem_product";
28
29#[derive(Debug, Clone)]
43pub(crate) struct ElemProductFunction {
44 signature: Signature,
45}
46
47impl Default for ElemProductFunction {
48 fn default() -> Self {
49 Self {
50 signature: Signature::one_of(
51 vec![
52 TypeSignature::Uniform(1, STRINGS.to_vec()),
53 TypeSignature::Uniform(1, BINARYS.to_vec()),
54 TypeSignature::Uniform(1, vec![DataType::BinaryView]),
55 ],
56 Volatility::Immutable,
57 ),
58 }
59 }
60}
61
62impl Function for ElemProductFunction {
63 fn name(&self) -> &str {
64 NAME
65 }
66
67 fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
68 Ok(DataType::Float32)
69 }
70
71 fn signature(&self) -> &Signature {
72 &self.signature
73 }
74
75 fn invoke_with_args(
76 &self,
77 args: ScalarFunctionArgs,
78 ) -> datafusion_common::Result<ColumnarValue> {
79 let body = |v0: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
80 let v0 = impl_conv::as_veclit(v0)?
81 .map(|v0| DVectorView::from_slice(&v0, v0.len()).product());
82 Ok(ScalarValue::Float32(v0))
83 };
84
85 let calculator = VectorCalculator {
86 name: self.name(),
87 func: body,
88 };
89 calculator.invoke_with_single_argument(args)
90 }
91}
92
93impl Display for ElemProductFunction {
94 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95 write!(f, "{}", NAME.to_ascii_uppercase())
96 }
97}
98
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion::arrow::array::{Array, AsArray, StringArray};
    use datafusion::arrow::datatypes::Float32Type;
    use datafusion_common::config::ConfigOptions;

    use super::*;

    #[test]
    fn test_elem_product() {
        // Two non-null vector literals followed by a NULL row.
        let rows: Vec<Option<String>> = vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            None,
        ];
        let input = Arc::new(StringArray::from(rows));
        let number_rows = input.len();

        let call_args = ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(input)],
            arg_fields: vec![],
            number_rows,
            return_field: Arc::new(Field::new("x", DataType::Float32, true)),
            config_options: Arc::new(ConfigOptions::new()),
        };

        let output = ElemProductFunction::default()
            .invoke_with_args(call_args)
            .and_then(|value| ColumnarValue::values_to_arrays(&[value]))
            .map(|mut arrays| arrays.remove(0))
            .unwrap();
        let products = output.as_primitive::<Float32Type>();

        assert_eq!(products.len(), 3);
        assert_eq!(products.value(0), 6.0);
        assert_eq!(products.value(1), 120.0);
        assert!(products.is_null(2));
    }
}