common_function/scalars/vector/elem_sum.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::borrow::Cow;
16use std::fmt::Display;
17
18use common_query::error::InvalidFuncArgsSnafu;
19use common_query::prelude::{Signature, TypeSignature, Volatility};
20use datatypes::prelude::ConcreteDataType;
21use datatypes::scalars::ScalarVectorBuilder;
22use datatypes::vectors::{Float32VectorBuilder, MutableVector, VectorRef};
23use nalgebra::DVectorView;
24use snafu::ensure;
25
26use crate::function::{Function, FunctionContext};
27use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const};
28
/// Name reported by [`Function::name`] for this function.
const NAME: &str = "vec_elem_sum";

/// Scalar function `vec_elem_sum`: for each row of its single string/binary argument,
/// parses the vector literal and returns the sum of its elements as a `Float32`.
#[derive(Clone, Debug, Default)]
pub struct ElemSumFunction;
33
34impl Function for ElemSumFunction {
35    fn name(&self) -> &str {
36        NAME
37    }
38
39    fn return_type(
40        &self,
41        _input_types: &[ConcreteDataType],
42    ) -> common_query::error::Result<ConcreteDataType> {
43        Ok(ConcreteDataType::float32_datatype())
44    }
45
46    fn signature(&self) -> Signature {
47        Signature::one_of(
48            vec![
49                TypeSignature::Exact(vec![ConcreteDataType::string_datatype()]),
50                TypeSignature::Exact(vec![ConcreteDataType::binary_datatype()]),
51            ],
52            Volatility::Immutable,
53        )
54    }
55
56    fn eval(
57        &self,
58        _func_ctx: &FunctionContext,
59        columns: &[VectorRef],
60    ) -> common_query::error::Result<VectorRef> {
61        ensure!(
62            columns.len() == 1,
63            InvalidFuncArgsSnafu {
64                err_msg: format!(
65                    "The length of the args is not correct, expect exactly one, have: {}",
66                    columns.len()
67                )
68            }
69        );
70        let arg0 = &columns[0];
71
72        let len = arg0.len();
73        let mut result = Float32VectorBuilder::with_capacity(len);
74        if len == 0 {
75            return Ok(result.to_vector());
76        }
77
78        let arg0_const = as_veclit_if_const(arg0)?;
79
80        for i in 0..len {
81            let arg0 = match arg0_const.as_ref() {
82                Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
83                None => as_veclit(arg0.get_ref(i))?,
84            };
85            let Some(arg0) = arg0 else {
86                result.push_null();
87                continue;
88            };
89            result.push(Some(DVectorView::from_slice(&arg0, arg0.len()).sum()));
90        }
91
92        Ok(result.to_vector())
93    }
94}
95
96impl Display for ElemSumFunction {
97    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98        write!(f, "{}", NAME.to_ascii_uppercase())
99    }
100}
101
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::vectors::{ConstantVector, StringVector};

    use super::*;
    use crate::function::FunctionContext;

    /// Per-row sums over a plain string column, with a null row staying null.
    #[test]
    fn test_elem_sum() {
        let func = ElemSumFunction;

        let input0 = Arc::new(StringVector::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            None,
        ]));

        let result = func.eval(&FunctionContext::default(), &[input0]).unwrap();

        let result = result.as_ref();
        assert_eq!(result.len(), 3);
        assert_eq!(result.get_ref(0).as_f32().unwrap(), Some(6.0));
        assert_eq!(result.get_ref(1).as_f32().unwrap(), Some(15.0));
        assert_eq!(result.get_ref(2).as_f32().unwrap(), None);
    }

    /// A zero-length input column yields an empty result instead of an error.
    #[test]
    fn test_elem_sum_empty_input() {
        let func = ElemSumFunction;

        let input0 = Arc::new(StringVector::from(Vec::<Option<String>>::new()));

        let result = func.eval(&FunctionContext::default(), &[input0]).unwrap();
        assert_eq!(result.len(), 0);
    }

    /// A constant column produces the same sum on every row.
    #[test]
    fn test_elem_sum_constant_input() {
        let func = ElemSumFunction;

        let literal: VectorRef = Arc::new(StringVector::from(vec![Some(
            "[1.0,2.0,3.0]".to_string(),
        )]));
        let input0 = Arc::new(ConstantVector::new(literal, 3));

        let result = func.eval(&FunctionContext::default(), &[input0]).unwrap();

        let result = result.as_ref();
        assert_eq!(result.len(), 3);
        for i in 0..3 {
            assert_eq!(result.get_ref(i).as_f32().unwrap(), Some(6.0));
        }
    }
}
129}