common_function/scalars/vector/elem_product.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::borrow::Cow;
16use std::fmt::Display;
17
18use common_query::error::InvalidFuncArgsSnafu;
19use common_query::prelude::{Signature, TypeSignature, Volatility};
20use datatypes::prelude::ConcreteDataType;
21use datatypes::scalars::ScalarVectorBuilder;
22use datatypes::vectors::{Float32VectorBuilder, MutableVector, VectorRef};
23use nalgebra::DVectorView;
24use snafu::ensure;
25
26use crate::function::{Function, FunctionContext};
27use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const};
28
/// Name of this function, as returned by `Function::name` (and shown upper-cased by
/// its `Display` impl).
const NAME: &str = "vec_elem_product";

/// Multiplies all elements of the vector, returns a scalar.
///
/// Takes a single vector literal, given either as a string or as binary, and returns
/// the product of its elements as a `Float32`. A `NULL` input produces `NULL`.
///
/// # Example
///
/// ```sql
/// SELECT vec_elem_product(parse_vec('[1.0, 2.0, 3.0, 4.0]'));
///
/// +-----------------------------------------------------------+
/// | vec_elem_product(parse_vec(Utf8("[1.0, 2.0, 3.0, 4.0]"))) |
/// +-----------------------------------------------------------+
/// | 24.0                                                      |
/// +-----------------------------------------------------------+
/// ```
#[derive(Debug, Clone, Default)]
pub struct ElemProductFunction;
46
47impl Function for ElemProductFunction {
48    fn name(&self) -> &str {
49        NAME
50    }
51
52    fn return_type(
53        &self,
54        _input_types: &[ConcreteDataType],
55    ) -> common_query::error::Result<ConcreteDataType> {
56        Ok(ConcreteDataType::float32_datatype())
57    }
58
59    fn signature(&self) -> Signature {
60        Signature::one_of(
61            vec![
62                TypeSignature::Exact(vec![ConcreteDataType::string_datatype()]),
63                TypeSignature::Exact(vec![ConcreteDataType::binary_datatype()]),
64            ],
65            Volatility::Immutable,
66        )
67    }
68
69    fn eval(
70        &self,
71        _func_ctx: &FunctionContext,
72        columns: &[VectorRef],
73    ) -> common_query::error::Result<VectorRef> {
74        ensure!(
75            columns.len() == 1,
76            InvalidFuncArgsSnafu {
77                err_msg: format!(
78                    "The length of the args is not correct, expect exactly one, have: {}",
79                    columns.len()
80                )
81            }
82        );
83        let arg0 = &columns[0];
84
85        let len = arg0.len();
86        let mut result = Float32VectorBuilder::with_capacity(len);
87        if len == 0 {
88            return Ok(result.to_vector());
89        }
90
91        let arg0_const = as_veclit_if_const(arg0)?;
92
93        for i in 0..len {
94            let arg0 = match arg0_const.as_ref() {
95                Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
96                None => as_veclit(arg0.get_ref(i))?,
97            };
98            let Some(arg0) = arg0 else {
99                result.push_null();
100                continue;
101            };
102            result.push(Some(DVectorView::from_slice(&arg0, arg0.len()).product()));
103        }
104
105        Ok(result.to_vector())
106    }
107}
108
109impl Display for ElemProductFunction {
110    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
111        write!(f, "{}", NAME.to_ascii_uppercase())
112    }
113}
114
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::vectors::{ConstantVector, StringVector};

    use super::*;
    use crate::function::FunctionContext;

    /// Row-by-row evaluation, including a NULL row.
    #[test]
    fn test_elem_product() {
        let func = ElemProductFunction;

        let input0 = Arc::new(StringVector::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            None,
        ]));

        let result = func.eval(&FunctionContext::default(), &[input0]).unwrap();

        let result = result.as_ref();
        assert_eq!(result.len(), 3);
        assert_eq!(result.get_ref(0).as_f32().unwrap(), Some(6.0));
        assert_eq!(result.get_ref(1).as_f32().unwrap(), Some(120.0));
        assert_eq!(result.get_ref(2).as_f32().unwrap(), None);
    }

    /// A constant column yields the same product in every row.
    #[test]
    fn test_elem_product_const() {
        let func = ElemProductFunction;

        let literal = Arc::new(StringVector::from(vec![Some("[2.0,3.0,4.0]".to_string())]));
        let input0: VectorRef = Arc::new(ConstantVector::new(literal, 3));

        let result = func.eval(&FunctionContext::default(), &[input0]).unwrap();

        let result = result.as_ref();
        assert_eq!(result.len(), 3);
        for i in 0..3 {
            assert_eq!(result.get_ref(i).as_f32().unwrap(), Some(24.0));
        }
    }

    /// A constant NULL column yields NULL in every row.
    #[test]
    fn test_elem_product_const_null() {
        let func = ElemProductFunction;

        let literal = Arc::new(StringVector::from(vec![None::<String>]));
        let input0: VectorRef = Arc::new(ConstantVector::new(literal, 2));

        let result = func.eval(&FunctionContext::default(), &[input0]).unwrap();

        let result = result.as_ref();
        assert_eq!(result.len(), 2);
        for i in 0..2 {
            assert_eq!(result.get_ref(i).as_f32().unwrap(), None);
        }
    }
}