common_function/scalars/vector/
elem_sum.rs1use std::fmt::Display;
16
17use datafusion::arrow::datatypes::DataType;
18use datafusion::logical_expr::ColumnarValue;
19use datafusion_common::ScalarValue;
20use datafusion_expr::type_coercion::aggregates::{BINARYS, STRINGS};
21use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
22use nalgebra::DVectorView;
23
24use crate::function::Function;
25use crate::scalars::vector::{VectorCalculator, impl_conv};
26
/// Name reported by `ElemSumFunction::name`; `Display` renders it upper-cased.
const NAME: &str = "vec_elem_sum";
28
29#[derive(Debug, Clone)]
30pub(crate) struct ElemSumFunction {
31 signature: Signature,
32}
33
34impl Default for ElemSumFunction {
35 fn default() -> Self {
36 Self {
37 signature: Signature::one_of(
38 vec![
39 TypeSignature::Uniform(1, STRINGS.to_vec()),
40 TypeSignature::Uniform(1, BINARYS.to_vec()),
41 TypeSignature::Uniform(1, vec![DataType::BinaryView]),
42 ],
43 Volatility::Immutable,
44 ),
45 }
46 }
47}
48
49impl Function for ElemSumFunction {
50 fn name(&self) -> &str {
51 NAME
52 }
53
54 fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
55 Ok(DataType::Float32)
56 }
57
58 fn signature(&self) -> &Signature {
59 &self.signature
60 }
61
62 fn invoke_with_args(
63 &self,
64 args: ScalarFunctionArgs,
65 ) -> datafusion_common::Result<ColumnarValue> {
66 let body = |v0: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
67 let v0 =
68 impl_conv::as_veclit(v0)?.map(|v0| DVectorView::from_slice(&v0, v0.len()).sum());
69 Ok(ScalarValue::Float32(v0))
70 };
71
72 let calculator = VectorCalculator {
73 name: self.name(),
74 func: body,
75 };
76 calculator.invoke_with_single_argument(args)
77 }
78}
79
80impl Display for ElemSumFunction {
81 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82 write!(f, "{}", NAME.to_ascii_uppercase())
83 }
84}
85
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::StringViewArray;
    use arrow_schema::Field;
    use datafusion::arrow::array::{Array, AsArray};
    use datafusion::arrow::datatypes::Float32Type;
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Sums two string-encoded vectors and checks that a null input row
    /// yields a null output row.
    #[test]
    fn test_elem_sum() {
        let function = ElemSumFunction::default();

        let vectors = Arc::new(StringViewArray::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            None,
        ]));
        let row_count = vectors.len();

        let call_args = ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(vectors.clone())],
            arg_fields: vec![],
            number_rows: row_count,
            return_field: Arc::new(Field::new("x", DataType::Float32, true)),
            config_options: Arc::new(ConfigOptions::new()),
        };

        let output = function.invoke_with_args(call_args).unwrap();
        let mut arrays = ColumnarValue::values_to_arrays(&[output]).unwrap();
        let summed = arrays.remove(0);
        let sums = summed.as_primitive::<Float32Type>();

        assert_eq!(sums.len(), 3);
        assert_eq!(sums.value(0), 6.0);
        assert_eq!(sums.value(1), 15.0);
        assert!(sums.is_null(2));
    }
}