common_function/scalars/vector/
elem_product.rs1use std::borrow::Cow;
16use std::fmt::Display;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use datafusion::arrow::datatypes::DataType;
20use datafusion::logical_expr_common::type_coercion::aggregates::{BINARYS, STRINGS};
21use datafusion_expr::{Signature, TypeSignature, Volatility};
22use datatypes::scalars::ScalarVectorBuilder;
23use datatypes::vectors::{Float32VectorBuilder, MutableVector, VectorRef};
24use nalgebra::DVectorView;
25use snafu::ensure;
26
27use crate::function::{Function, FunctionContext};
28use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const};
29
/// Scalar SQL function that multiplies together all elements of a vector literal.
///
/// It takes exactly one argument, either a string or a binary value holding the vector.
/// It produces a `Float32` per row, and rows whose input is `NULL` produce `NULL`.
///
/// ```sql
/// SELECT vec_elem_product('[1.0,2.0,3.0]'); -- 6.0
/// ```
#[derive(Clone, Debug, Default)]
pub struct ElemProductFunction;

/// Name under which the function is exposed; also rendered upper-cased by `Display`.
const NAME: &str = "vec_elem_product";
47
48impl Function for ElemProductFunction {
49 fn name(&self) -> &str {
50 NAME
51 }
52
53 fn return_type(&self, _: &[DataType]) -> Result<DataType> {
54 Ok(DataType::Float32)
55 }
56
57 fn signature(&self) -> Signature {
58 Signature::one_of(
59 vec![
60 TypeSignature::Uniform(1, STRINGS.to_vec()),
61 TypeSignature::Uniform(1, BINARYS.to_vec()),
62 ],
63 Volatility::Immutable,
64 )
65 }
66
67 fn eval(
68 &self,
69 _func_ctx: &FunctionContext,
70 columns: &[VectorRef],
71 ) -> common_query::error::Result<VectorRef> {
72 ensure!(
73 columns.len() == 1,
74 InvalidFuncArgsSnafu {
75 err_msg: format!(
76 "The length of the args is not correct, expect exactly one, have: {}",
77 columns.len()
78 )
79 }
80 );
81 let arg0 = &columns[0];
82
83 let len = arg0.len();
84 let mut result = Float32VectorBuilder::with_capacity(len);
85 if len == 0 {
86 return Ok(result.to_vector());
87 }
88
89 let arg0_const = as_veclit_if_const(arg0)?;
90
91 for i in 0..len {
92 let arg0 = match arg0_const.as_ref() {
93 Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
94 None => as_veclit(arg0.get_ref(i))?,
95 };
96 let Some(arg0) = arg0 else {
97 result.push_null();
98 continue;
99 };
100 result.push(Some(DVectorView::from_slice(&arg0, arg0.len()).product()));
101 }
102
103 Ok(result.to_vector())
104 }
105}
106
107impl Display for ElemProductFunction {
108 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
109 write!(f, "{}", NAME.to_ascii_uppercase())
110 }
111}
112
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::vectors::{ConstantVector, StringVector};

    use super::*;
    use crate::function::FunctionContext;

    /// Asserts that `result` holds exactly the `expected` Float32 values, row by row.
    fn assert_f32_rows(result: &VectorRef, expected: &[Option<f32>]) {
        assert_eq!(result.len(), expected.len());
        for (i, want) in expected.iter().enumerate() {
            assert_eq!(result.get_ref(i).as_f32().unwrap(), *want, "row {i}");
        }
    }

    #[test]
    fn test_elem_product() {
        let func = ElemProductFunction;

        let input0 = Arc::new(StringVector::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            None,
        ]));

        let result = func.eval(&FunctionContext::default(), &[input0]).unwrap();
        assert_f32_rows(&result, &[Some(6.0), Some(120.0), None]);
    }

    #[test]
    fn test_elem_product_constant_input() {
        let func = ElemProductFunction;

        // A constant column: one vector literal logically repeated over 4 rows.
        let single = Arc::new(StringVector::from(vec![Some("[2.0,3.0]".to_string())]));
        let input0 = Arc::new(ConstantVector::new(single, 4));

        let result = func.eval(&FunctionContext::default(), &[input0]).unwrap();
        assert_f32_rows(&result, &[Some(6.0); 4]);
    }

    #[test]
    fn test_elem_product_constant_null_input() {
        let func = ElemProductFunction;

        let single = Arc::new(StringVector::from(vec![None::<String>]));
        let input0 = Arc::new(ConstantVector::new(single, 3));

        let result = func.eval(&FunctionContext::default(), &[input0]).unwrap();
        assert_f32_rows(&result, &[None; 3]);
    }

    #[test]
    fn test_elem_product_empty_input() {
        let func = ElemProductFunction;

        let input0 = Arc::new(StringVector::from(Vec::<Option<String>>::new()));

        let result = func.eval(&FunctionContext::default(), &[input0]).unwrap();
        assert_eq!(result.len(), 0);
    }

    #[test]
    fn test_elem_product_wrong_arity() {
        let func = ElemProductFunction;

        assert!(func.eval(&FunctionContext::default(), &[]).is_err());

        let arg: VectorRef = Arc::new(StringVector::from(vec![Some("[1.0]".to_string())]));
        assert!(func
            .eval(&FunctionContext::default(), &[Arc::clone(&arg), arg])
            .is_err());
    }
}