common_function/scalars/vector/
elem_sum.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15use std::borrow::Cow;
16use std::fmt::Display;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use datafusion::arrow::datatypes::DataType;
20use datafusion_expr::type_coercion::aggregates::{BINARYS, STRINGS};
21use datafusion_expr::{Signature, TypeSignature, Volatility};
22use datatypes::scalars::ScalarVectorBuilder;
23use datatypes::vectors::{Float32VectorBuilder, MutableVector, VectorRef};
24use nalgebra::DVectorView;
25use snafu::ensure;
26
27use crate::function::{Function, FunctionContext};
28use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const};
29
/// Name returned by `ElemSumFunction::name`; the `Display` impl prints its
/// upper-cased form.
const NAME: &str = "vec_elem_sum";

/// Scalar function that adds up every element of a vector literal and
/// yields the total as a `Float32`.
#[derive(Clone, Debug, Default)]
pub struct ElemSumFunction;

35impl Function for ElemSumFunction {
36    fn name(&self) -> &str {
37        NAME
38    }
39
40    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
41        Ok(DataType::Float32)
42    }
43
44    fn signature(&self) -> Signature {
45        Signature::one_of(
46            vec![
47                TypeSignature::Uniform(1, STRINGS.to_vec()),
48                TypeSignature::Uniform(1, BINARYS.to_vec()),
49            ],
50            Volatility::Immutable,
51        )
52    }
53
54    fn eval(
55        &self,
56        _func_ctx: &FunctionContext,
57        columns: &[VectorRef],
58    ) -> common_query::error::Result<VectorRef> {
59        ensure!(
60            columns.len() == 1,
61            InvalidFuncArgsSnafu {
62                err_msg: format!(
63                    "The length of the args is not correct, expect exactly one, have: {}",
64                    columns.len()
65                )
66            }
67        );
68        let arg0 = &columns[0];
69
70        let len = arg0.len();
71        let mut result = Float32VectorBuilder::with_capacity(len);
72        if len == 0 {
73            return Ok(result.to_vector());
74        }
75
76        let arg0_const = as_veclit_if_const(arg0)?;
77
78        for i in 0..len {
79            let arg0 = match arg0_const.as_ref() {
80                Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
81                None => as_veclit(arg0.get_ref(i))?,
82            };
83            let Some(arg0) = arg0 else {
84                result.push_null();
85                continue;
86            };
87            result.push(Some(DVectorView::from_slice(&arg0, arg0.len()).sum()));
88        }
89
90        Ok(result.to_vector())
91    }
92}
93
94impl Display for ElemSumFunction {
95    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96        write!(f, "{}", NAME.to_ascii_uppercase())
97    }
98}
99
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::vectors::{ConstantVector, StringVector};

    use super::*;
    use crate::function::FunctionContext;

    #[test]
    fn test_elem_sum() {
        let func = ElemSumFunction;

        let input0 = Arc::new(StringVector::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            None,
        ]));

        let result = func.eval(&FunctionContext::default(), &[input0]).unwrap();

        let result = result.as_ref();
        assert_eq!(result.len(), 3);
        assert_eq!(result.get_ref(0).as_f32().unwrap(), Some(6.0));
        assert_eq!(result.get_ref(1).as_f32().unwrap(), Some(15.0));
        assert_eq!(result.get_ref(2).as_f32().unwrap(), None);
    }

    /// A constant column must yield the same sum on every row.
    #[test]
    fn test_elem_sum_const() {
        let func = ElemSumFunction;

        let literal = Arc::new(StringVector::from(vec![Some(
            "[1.0,2.0,3.0]".to_string(),
        )]));
        let input0 = Arc::new(ConstantVector::new(literal, 4));

        let result = func.eval(&FunctionContext::default(), &[input0]).unwrap();

        let result = result.as_ref();
        assert_eq!(result.len(), 4);
        for i in 0..4 {
            assert_eq!(result.get_ref(i).as_f32().unwrap(), Some(6.0));
        }
    }

    /// An empty input column produces an empty result.
    #[test]
    fn test_elem_sum_empty() {
        let func = ElemSumFunction;

        let input0 = Arc::new(StringVector::from(Vec::<Option<String>>::new()));

        let result = func.eval(&FunctionContext::default(), &[input0]).unwrap();

        let result = result.as_ref();
        assert_eq!(result.len(), 0);
    }

    /// Anything other than exactly one argument is rejected.
    #[test]
    fn test_elem_sum_wrong_arg_count() {
        let func = ElemSumFunction;

        assert!(func.eval(&FunctionContext::default(), &[]).is_err());
    }
}