common_function/scalars/vector/
elem_sum.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Display;
16
17use common_query::error::Result;
18use datafusion::arrow::datatypes::DataType;
19use datafusion::logical_expr::ColumnarValue;
20use datafusion_common::ScalarValue;
21use datafusion_expr::type_coercion::aggregates::{BINARYS, STRINGS};
22use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
23use nalgebra::DVectorView;
24
25use crate::function::Function;
26use crate::scalars::vector::{VectorCalculator, impl_conv};
27
/// Function name returned by [`ElemSumFunction`]'s `Function::name`.
const NAME: &str = "vec_elem_sum";

/// Scalar function `vec_elem_sum`: the sum of all elements of a vector value.
///
/// Stateless unit struct; all behavior lives in its trait impls.
#[derive(Clone, Debug, Default)]
pub struct ElemSumFunction;
32
33impl Function for ElemSumFunction {
34    fn name(&self) -> &str {
35        NAME
36    }
37
38    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
39        Ok(DataType::Float32)
40    }
41
42    fn signature(&self) -> Signature {
43        Signature::one_of(
44            vec![
45                TypeSignature::Uniform(1, STRINGS.to_vec()),
46                TypeSignature::Uniform(1, BINARYS.to_vec()),
47            ],
48            Volatility::Immutable,
49        )
50    }
51
52    fn invoke_with_args(
53        &self,
54        args: ScalarFunctionArgs,
55    ) -> datafusion_common::Result<ColumnarValue> {
56        let body = |v0: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
57            let v0 =
58                impl_conv::as_veclit(v0)?.map(|v0| DVectorView::from_slice(&v0, v0.len()).sum());
59            Ok(ScalarValue::Float32(v0))
60        };
61
62        let calculator = VectorCalculator {
63            name: self.name(),
64            func: body,
65        };
66        calculator.invoke_with_single_argument(args)
67    }
68}
69
70impl Display for ElemSumFunction {
71    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72        write!(f, "{}", NAME.to_ascii_uppercase())
73    }
74}
75
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::StringViewArray;
    use arrow_schema::Field;
    use datafusion::arrow::array::{Array, AsArray};
    use datafusion::arrow::datatypes::Float32Type;
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Sums the elements of each string-encoded vector row and checks that a
    /// NULL row yields a NULL result.
    #[test]
    fn test_elem_sum() {
        let rows = vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            None,
        ];
        let column = Arc::new(StringViewArray::from(rows));
        let row_count = column.len();

        let call_args = ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(column.clone())],
            arg_fields: vec![],
            number_rows: row_count,
            return_field: Arc::new(Field::new("x", DataType::Float32, true)),
            config_options: Arc::new(ConfigOptions::new()),
        };

        let output = ElemSumFunction
            .invoke_with_args(call_args)
            .and_then(|value| ColumnarValue::values_to_arrays(&[value]))
            .map(|mut arrays| arrays.remove(0))
            .unwrap();
        let sums = output.as_primitive::<Float32Type>();

        let collected: Vec<Option<f32>> = sums.iter().collect();
        assert_eq!(collected, vec![Some(6.0), Some(15.0), None]);
    }
}