common_function/scalars/vector/
elem_product.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::borrow::Cow;
16use std::fmt::Display;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use datafusion::arrow::datatypes::DataType;
20use datafusion::logical_expr_common::type_coercion::aggregates::{BINARYS, STRINGS};
21use datafusion_expr::{Signature, TypeSignature, Volatility};
22use datatypes::scalars::ScalarVectorBuilder;
23use datatypes::vectors::{Float32VectorBuilder, MutableVector, VectorRef};
24use nalgebra::DVectorView;
25use snafu::ensure;
26
27use crate::function::{Function, FunctionContext};
28use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const};
29
30const NAME: &str = "vec_elem_product";
31
/// Multiplies all elements of a vector together and returns the product as a
/// single `Float32` scalar.
///
/// The function takes exactly one argument, either a string or a binary
/// value. A row whose argument cannot be read as a vector literal (for
/// example a null) produces a null result.
///
/// # Example
///
/// ```sql
/// SELECT vec_elem_product(parse_vec('[1.0, 2.0, 3.0, 4.0]'));
///
/// +-----------------------------------------------------------+
/// | vec_elem_product(parse_vec(Utf8("[1.0, 2.0, 3.0, 4.0]"))) |
/// +-----------------------------------------------------------+
/// | 24.0                                                      |
/// +-----------------------------------------------------------+
/// ```
#[derive(Debug, Clone, Default)]
pub struct ElemProductFunction;
47
48impl Function for ElemProductFunction {
49    fn name(&self) -> &str {
50        NAME
51    }
52
53    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
54        Ok(DataType::Float32)
55    }
56
57    fn signature(&self) -> Signature {
58        Signature::one_of(
59            vec![
60                TypeSignature::Uniform(1, STRINGS.to_vec()),
61                TypeSignature::Uniform(1, BINARYS.to_vec()),
62            ],
63            Volatility::Immutable,
64        )
65    }
66
67    fn eval(
68        &self,
69        _func_ctx: &FunctionContext,
70        columns: &[VectorRef],
71    ) -> common_query::error::Result<VectorRef> {
72        ensure!(
73            columns.len() == 1,
74            InvalidFuncArgsSnafu {
75                err_msg: format!(
76                    "The length of the args is not correct, expect exactly one, have: {}",
77                    columns.len()
78                )
79            }
80        );
81        let arg0 = &columns[0];
82
83        let len = arg0.len();
84        let mut result = Float32VectorBuilder::with_capacity(len);
85        if len == 0 {
86            return Ok(result.to_vector());
87        }
88
89        let arg0_const = as_veclit_if_const(arg0)?;
90
91        for i in 0..len {
92            let arg0 = match arg0_const.as_ref() {
93                Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
94                None => as_veclit(arg0.get_ref(i))?,
95            };
96            let Some(arg0) = arg0 else {
97                result.push_null();
98                continue;
99            };
100            result.push(Some(DVectorView::from_slice(&arg0, arg0.len()).product()));
101        }
102
103        Ok(result.to_vector())
104    }
105}
106
107impl Display for ElemProductFunction {
108    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
109        write!(f, "{}", NAME.to_ascii_uppercase())
110    }
111}
112
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::vectors::StringVector;

    use super::*;
    use crate::function::FunctionContext;

    /// Two vector literals yield their element products; a null row stays null.
    #[test]
    fn test_elem_product() {
        let rows = vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            None,
        ];
        // One expected product per input row, in the same order.
        let expected: [Option<f32>; 3] = [Some(6.0), Some(120.0), None];

        let column = Arc::new(StringVector::from(rows));
        let output = ElemProductFunction
            .eval(&FunctionContext::default(), &[column])
            .unwrap();

        assert_eq!(output.len(), 3);
        for (row, want) in expected.iter().enumerate() {
            assert_eq!(output.get_ref(row).as_f32().unwrap(), *want);
        }
    }
}