common_function/scalars/vector/
elem_sum.rs1use std::fmt::Display;
16
17use common_query::error::Result;
18use datafusion::arrow::datatypes::DataType;
19use datafusion::logical_expr::ColumnarValue;
20use datafusion_common::ScalarValue;
21use datafusion_expr::type_coercion::aggregates::{BINARYS, STRINGS};
22use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
23use nalgebra::DVectorView;
24
25use crate::function::Function;
26use crate::scalars::vector::{VectorCalculator, impl_conv};
27
/// SQL-facing name under which this function is exposed.
const NAME: &str = "vec_elem_sum";

/// Scalar function that sums the elements of a vector value.
///
/// The type carries no state; all behavior lives in its trait
/// implementations.
#[derive(Clone, Debug, Default)]
pub struct ElemSumFunction;
32
33impl Function for ElemSumFunction {
34 fn name(&self) -> &str {
35 NAME
36 }
37
38 fn return_type(&self, _: &[DataType]) -> Result<DataType> {
39 Ok(DataType::Float32)
40 }
41
42 fn signature(&self) -> Signature {
43 Signature::one_of(
44 vec![
45 TypeSignature::Uniform(1, STRINGS.to_vec()),
46 TypeSignature::Uniform(1, BINARYS.to_vec()),
47 ],
48 Volatility::Immutable,
49 )
50 }
51
52 fn invoke_with_args(
53 &self,
54 args: ScalarFunctionArgs,
55 ) -> datafusion_common::Result<ColumnarValue> {
56 let body = |v0: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
57 let v0 =
58 impl_conv::as_veclit(v0)?.map(|v0| DVectorView::from_slice(&v0, v0.len()).sum());
59 Ok(ScalarValue::Float32(v0))
60 };
61
62 let calculator = VectorCalculator {
63 name: self.name(),
64 func: body,
65 };
66 calculator.invoke_with_single_argument(args)
67 }
68}
69
70impl Display for ElemSumFunction {
71 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72 write!(f, "{}", NAME.to_ascii_uppercase())
73 }
74}
75
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::StringViewArray;
    use arrow_schema::Field;
    use datafusion::arrow::array::{Array, AsArray};
    use datafusion::arrow::datatypes::Float32Type;
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Two textual vectors and one null go in; their element sums and a null
    /// are expected back, in the same order.
    #[test]
    fn test_elem_sum() {
        let func = ElemSumFunction;

        let vectors = Arc::new(StringViewArray::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            None,
        ]));
        // Capture the row count first so the array can be moved into the arguments.
        let row_count = vectors.len();

        let call_args = ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(vectors)],
            arg_fields: vec![],
            number_rows: row_count,
            return_field: Arc::new(Field::new("x", DataType::Float32, true)),
            config_options: Arc::new(ConfigOptions::new()),
        };

        let output = func.invoke_with_args(call_args).unwrap();
        let mut arrays = ColumnarValue::values_to_arrays(&[output]).unwrap();
        let first = arrays.remove(0);
        let sums = first.as_primitive::<Float32Type>();

        assert_eq!(sums.len(), 3);
        for (row, expected) in [6.0_f32, 15.0].into_iter().enumerate() {
            assert_eq!(sums.value(row), expected);
        }
        assert!(sums.is_null(2));
    }
}