common_function/scalars/vector/
elem_product.rs1use std::borrow::Cow;
16use std::fmt::Display;
17
18use common_query::error::InvalidFuncArgsSnafu;
19use common_query::prelude::{Signature, TypeSignature, Volatility};
20use datatypes::prelude::ConcreteDataType;
21use datatypes::scalars::ScalarVectorBuilder;
22use datatypes::vectors::{Float32VectorBuilder, MutableVector, VectorRef};
23use nalgebra::DVectorView;
24use snafu::ensure;
25
26use crate::function::{Function, FunctionContext};
27use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const};
28
29const NAME: &str = "vec_elem_product";
30
/// Scalar function registered under the SQL name `vec_elem_product`.
///
/// It takes a single vector argument, given either as a string or as a
/// binary value, and yields the product of all elements of that vector as a
/// `Float32`. For example, the literal `[1.0,2.0,3.0]` evaluates to `6.0`.
#[derive(Clone, Debug, Default)]
pub struct ElemProductFunction;
46
47impl Function for ElemProductFunction {
48 fn name(&self) -> &str {
49 NAME
50 }
51
52 fn return_type(
53 &self,
54 _input_types: &[ConcreteDataType],
55 ) -> common_query::error::Result<ConcreteDataType> {
56 Ok(ConcreteDataType::float32_datatype())
57 }
58
59 fn signature(&self) -> Signature {
60 Signature::one_of(
61 vec![
62 TypeSignature::Exact(vec![ConcreteDataType::string_datatype()]),
63 TypeSignature::Exact(vec![ConcreteDataType::binary_datatype()]),
64 ],
65 Volatility::Immutable,
66 )
67 }
68
69 fn eval(
70 &self,
71 _func_ctx: &FunctionContext,
72 columns: &[VectorRef],
73 ) -> common_query::error::Result<VectorRef> {
74 ensure!(
75 columns.len() == 1,
76 InvalidFuncArgsSnafu {
77 err_msg: format!(
78 "The length of the args is not correct, expect exactly one, have: {}",
79 columns.len()
80 )
81 }
82 );
83 let arg0 = &columns[0];
84
85 let len = arg0.len();
86 let mut result = Float32VectorBuilder::with_capacity(len);
87 if len == 0 {
88 return Ok(result.to_vector());
89 }
90
91 let arg0_const = as_veclit_if_const(arg0)?;
92
93 for i in 0..len {
94 let arg0 = match arg0_const.as_ref() {
95 Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
96 None => as_veclit(arg0.get_ref(i))?,
97 };
98 let Some(arg0) = arg0 else {
99 result.push_null();
100 continue;
101 };
102 result.push(Some(DVectorView::from_slice(&arg0, arg0.len()).product()));
103 }
104
105 Ok(result.to_vector())
106 }
107}
108
109impl Display for ElemProductFunction {
110 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
111 write!(f, "{}", NAME.to_ascii_uppercase())
112 }
113}
114
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::vectors::StringVector;

    use super::*;
    use crate::function::FunctionContext;

    /// Evaluates the function over two vector literals and one null row and
    /// checks each output row against the expected product.
    #[test]
    fn test_elem_product() {
        let function = ElemProductFunction;

        let literals = vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            None,
        ];
        let column = Arc::new(StringVector::from(literals));

        let output = function
            .eval(&FunctionContext::default(), &[column])
            .unwrap();
        let output = output.as_ref();

        // One expected entry per input row; the null input maps to a null output.
        let expected: [Option<f32>; 3] = [Some(6.0), Some(120.0), None];
        assert_eq!(output.len(), expected.len());
        for (row, want) in expected.into_iter().enumerate() {
            assert_eq!(output.get_ref(row).as_f32().unwrap(), want);
        }
    }
}