common_function/scalars/vector/
elem_sum.rs1use std::borrow::Cow;
16use std::fmt::Display;
17
18use common_query::error::InvalidFuncArgsSnafu;
19use common_query::prelude::{Signature, TypeSignature, Volatility};
20use datatypes::prelude::ConcreteDataType;
21use datatypes::scalars::ScalarVectorBuilder;
22use datatypes::vectors::{Float32VectorBuilder, MutableVector, VectorRef};
23use nalgebra::DVectorView;
24use snafu::ensure;
25
26use crate::function::{Function, FunctionContext};
27use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const};
28
29const NAME: &str = "vec_elem_sum";
30
/// Scalar function that sums the elements of a vector literal.
///
/// The type carries no state; all behavior lives in its `Function` and
/// `Display` implementations.
#[derive(Clone, Debug, Default)]
pub struct ElemSumFunction;
33
34impl Function for ElemSumFunction {
35 fn name(&self) -> &str {
36 NAME
37 }
38
39 fn return_type(
40 &self,
41 _input_types: &[ConcreteDataType],
42 ) -> common_query::error::Result<ConcreteDataType> {
43 Ok(ConcreteDataType::float32_datatype())
44 }
45
46 fn signature(&self) -> Signature {
47 Signature::one_of(
48 vec![
49 TypeSignature::Exact(vec![ConcreteDataType::string_datatype()]),
50 TypeSignature::Exact(vec![ConcreteDataType::binary_datatype()]),
51 ],
52 Volatility::Immutable,
53 )
54 }
55
56 fn eval(
57 &self,
58 _func_ctx: &FunctionContext,
59 columns: &[VectorRef],
60 ) -> common_query::error::Result<VectorRef> {
61 ensure!(
62 columns.len() == 1,
63 InvalidFuncArgsSnafu {
64 err_msg: format!(
65 "The length of the args is not correct, expect exactly one, have: {}",
66 columns.len()
67 )
68 }
69 );
70 let arg0 = &columns[0];
71
72 let len = arg0.len();
73 let mut result = Float32VectorBuilder::with_capacity(len);
74 if len == 0 {
75 return Ok(result.to_vector());
76 }
77
78 let arg0_const = as_veclit_if_const(arg0)?;
79
80 for i in 0..len {
81 let arg0 = match arg0_const.as_ref() {
82 Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
83 None => as_veclit(arg0.get_ref(i))?,
84 };
85 let Some(arg0) = arg0 else {
86 result.push_null();
87 continue;
88 };
89 result.push(Some(DVectorView::from_slice(&arg0, arg0.len()).sum()));
90 }
91
92 Ok(result.to_vector())
93 }
94}
95
96impl Display for ElemSumFunction {
97 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98 write!(f, "{}", NAME.to_ascii_uppercase())
99 }
100}
101
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::vectors::StringVector;

    use super::*;
    use crate::function::FunctionContext;

    /// Two string-encoded vectors and one null row should evaluate to their
    /// element sums and a null, in the same order.
    #[test]
    fn test_elem_sum() {
        let literals = vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            None,
        ];
        let expected: [Option<f32>; 3] = [Some(6.0), Some(15.0), None];

        let column = Arc::new(StringVector::from(literals));
        let output = ElemSumFunction
            .eval(&FunctionContext::default(), &[column])
            .unwrap();
        let output = output.as_ref();

        assert_eq!(output.len(), 3);
        for (row, want) in expected.iter().enumerate() {
            assert_eq!(output.get_ref(row).as_f32().unwrap(), *want);
        }
    }
}