common_function/scalars/vector/
elem_product.rs1use std::fmt::Display;
16
17use datafusion::arrow::datatypes::DataType;
18use datafusion::logical_expr::{Coercion, ColumnarValue, TypeSignature, TypeSignatureClass};
19use datafusion_common::ScalarValue;
20use datafusion_common::types::{logical_binary, logical_string};
21use datafusion_expr::{ScalarFunctionArgs, Signature, Volatility};
22use nalgebra::DVectorView;
23
24use crate::function::Function;
25use crate::scalars::vector::{VectorCalculator, impl_conv};
26
27const NAME: &str = "vec_elem_product";
28
29#[derive(Debug, Clone)]
43pub(crate) struct ElemProductFunction {
44 signature: Signature,
45}
46
47impl Default for ElemProductFunction {
48 fn default() -> Self {
49 Self {
50 signature: Signature::one_of(
51 vec![
52 TypeSignature::Coercible(vec![Coercion::new_exact(
53 TypeSignatureClass::Native(logical_binary()),
54 )]),
55 TypeSignature::Coercible(vec![Coercion::new_exact(
56 TypeSignatureClass::Native(logical_string()),
57 )]),
58 ],
59 Volatility::Immutable,
60 ),
61 }
62 }
63}
64
65impl Function for ElemProductFunction {
66 fn name(&self) -> &str {
67 NAME
68 }
69
70 fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
71 Ok(DataType::Float32)
72 }
73
74 fn signature(&self) -> &Signature {
75 &self.signature
76 }
77
78 fn invoke_with_args(
79 &self,
80 args: ScalarFunctionArgs,
81 ) -> datafusion_common::Result<ColumnarValue> {
82 let body = |v0: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
83 let v0 = impl_conv::as_veclit(v0)?
84 .map(|v0| DVectorView::from_slice(&v0, v0.len()).product());
85 Ok(ScalarValue::Float32(v0))
86 };
87
88 let calculator = VectorCalculator {
89 name: self.name(),
90 func: body,
91 };
92 calculator.invoke_with_single_argument(args)
93 }
94}
95
96impl Display for ElemProductFunction {
97 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98 write!(f, "{}", NAME.to_ascii_uppercase())
99 }
100}
101
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion::arrow::array::{Array, AsArray, StringArray};
    use datafusion::arrow::datatypes::Float32Type;
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Two string-encoded vectors and one NULL row: the products are 1*2*3 = 6
    /// and 4*5*6 = 120, and the NULL row stays NULL.
    #[test]
    fn test_elem_product() {
        let func = ElemProductFunction::default();

        let input = Arc::new(StringArray::from(vec![
            Some(String::from("[1.0,2.0,3.0]")),
            Some(String::from("[4.0,5.0,6.0]")),
            None,
        ]));
        let number_rows = input.len();

        let args = ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(input)],
            arg_fields: vec![],
            number_rows,
            return_field: Arc::new(Field::new("x", DataType::Float32, true)),
            config_options: Arc::new(ConfigOptions::new()),
        };

        let output = func.invoke_with_args(args).unwrap();
        let mut arrays = ColumnarValue::values_to_arrays(&[output]).unwrap();
        let array = arrays.remove(0);
        let products = array.as_primitive::<Float32Type>();

        assert_eq!(products.len(), 3);
        assert_eq!(products.value(0), 6.0);
        assert_eq!(products.value(1), 120.0);
        assert!(products.is_null(2));
    }
}