common_function/scalars/vector/
elem_sum.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Display;
16
17use datafusion::arrow::datatypes::DataType;
18use datafusion::logical_expr::ColumnarValue;
19use datafusion_common::ScalarValue;
20use datafusion_expr::type_coercion::aggregates::{BINARYS, STRINGS};
21use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
22use nalgebra::DVectorView;
23
24use crate::function::Function;
25use crate::scalars::vector::{VectorCalculator, impl_conv};
26
27const NAME: &str = "vec_elem_sum";
28
29#[derive(Debug, Clone)]
30pub(crate) struct ElemSumFunction {
31    signature: Signature,
32}
33
34impl Default for ElemSumFunction {
35    fn default() -> Self {
36        Self {
37            signature: Signature::one_of(
38                vec![
39                    TypeSignature::Uniform(1, STRINGS.to_vec()),
40                    TypeSignature::Uniform(1, BINARYS.to_vec()),
41                    TypeSignature::Uniform(1, vec![DataType::BinaryView]),
42                ],
43                Volatility::Immutable,
44            ),
45        }
46    }
47}
48
49impl Function for ElemSumFunction {
50    fn name(&self) -> &str {
51        NAME
52    }
53
54    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
55        Ok(DataType::Float32)
56    }
57
58    fn signature(&self) -> &Signature {
59        &self.signature
60    }
61
62    fn invoke_with_args(
63        &self,
64        args: ScalarFunctionArgs,
65    ) -> datafusion_common::Result<ColumnarValue> {
66        let body = |v0: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
67            let v0 =
68                impl_conv::as_veclit(v0)?.map(|v0| DVectorView::from_slice(&v0, v0.len()).sum());
69            Ok(ScalarValue::Float32(v0))
70        };
71
72        let calculator = VectorCalculator {
73            name: self.name(),
74            func: body,
75        };
76        calculator.invoke_with_single_argument(args)
77    }
78}
79
80impl Display for ElemSumFunction {
81    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82        write!(f, "{}", NAME.to_ascii_uppercase())
83    }
84}
85
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::StringViewArray;
    use arrow_schema::Field;
    use datafusion::arrow::array::{Array, AsArray};
    use datafusion::arrow::datatypes::Float32Type;
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Two string vector literals are summed element-wise; a NULL row stays NULL.
    #[test]
    fn test_elem_sum() {
        let func = ElemSumFunction::default();

        let rows = vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            None,
        ];
        let input = Arc::new(StringViewArray::from(rows));
        let number_rows = input.len();

        let call_args = ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(input)],
            arg_fields: vec![],
            number_rows,
            return_field: Arc::new(Field::new("x", DataType::Float32, true)),
            config_options: Arc::new(ConfigOptions::new()),
        };

        let output = func.invoke_with_args(call_args).unwrap();
        let mut arrays = ColumnarValue::values_to_arrays(&[output]).unwrap();
        let first = arrays.remove(0);
        let sums = first.as_primitive::<Float32Type>();

        assert_eq!(sums.len(), 3);
        assert_eq!(sums.value(0), 6.0);
        assert_eq!(sums.value(1), 15.0);
        assert!(sums.is_null(2));
    }
}