common_function/scalars/vector/elem_product.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Display;
16
17use datafusion::arrow::datatypes::DataType;
18use datafusion::logical_expr::ColumnarValue;
19use datafusion::logical_expr_common::type_coercion::aggregates::{BINARYS, STRINGS};
20use datafusion_common::ScalarValue;
21use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
22use nalgebra::DVectorView;
23
24use crate::function::Function;
25use crate::scalars::vector::{VectorCalculator, impl_conv};
26
27const NAME: &str = "vec_elem_product";
28
29/// Multiplies all elements of the vector, returns a scalar.
30///
31/// # Example
32///
33/// ```sql
34/// SELECT vec_elem_product(parse_vec('[1.0, 2.0, 3.0, 4.0]'));
35///
36// +-----------------------------------------------------------+
37// | vec_elem_product(parse_vec(Utf8("[1.0, 2.0, 3.0, 4.0]"))) |
38// +-----------------------------------------------------------+
39// | 24.0                                                      |
40// +-----------------------------------------------------------+
41/// ``````
42#[derive(Debug, Clone)]
43pub(crate) struct ElemProductFunction {
44    signature: Signature,
45}
46
47impl Default for ElemProductFunction {
48    fn default() -> Self {
49        Self {
50            signature: Signature::one_of(
51                vec![
52                    TypeSignature::Uniform(1, STRINGS.to_vec()),
53                    TypeSignature::Uniform(1, BINARYS.to_vec()),
54                    TypeSignature::Uniform(1, vec![DataType::BinaryView]),
55                ],
56                Volatility::Immutable,
57            ),
58        }
59    }
60}
61
62impl Function for ElemProductFunction {
63    fn name(&self) -> &str {
64        NAME
65    }
66
67    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
68        Ok(DataType::Float32)
69    }
70
71    fn signature(&self) -> &Signature {
72        &self.signature
73    }
74
75    fn invoke_with_args(
76        &self,
77        args: ScalarFunctionArgs,
78    ) -> datafusion_common::Result<ColumnarValue> {
79        let body = |v0: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
80            let v0 = impl_conv::as_veclit(v0)?
81                .map(|v0| DVectorView::from_slice(&v0, v0.len()).product());
82            Ok(ScalarValue::Float32(v0))
83        };
84
85        let calculator = VectorCalculator {
86            name: self.name(),
87            func: body,
88        };
89        calculator.invoke_with_single_argument(args)
90    }
91}
92
93impl Display for ElemProductFunction {
94    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95        write!(f, "{}", NAME.to_ascii_uppercase())
96    }
97}
98
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion::arrow::array::{Array, AsArray, StringArray};
    use datafusion::arrow::datatypes::Float32Type;
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Each non-null row yields the product of its elements; a null row stays null.
    #[test]
    fn test_elem_product() {
        let func = ElemProductFunction::default();

        let rows: Vec<Option<String>> = vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            None,
        ];
        let input = Arc::new(StringArray::from(rows));
        let row_count = input.len();

        let call_args = ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(input)],
            arg_fields: vec![],
            number_rows: row_count,
            return_field: Arc::new(Field::new("x", DataType::Float32, true)),
            config_options: Arc::new(ConfigOptions::new()),
        };
        let output = func.invoke_with_args(call_args).unwrap();
        let mut arrays = ColumnarValue::values_to_arrays(&[output]).unwrap();
        let array = arrays.swap_remove(0);
        let products = array.as_primitive::<Float32Type>();

        assert_eq!(products.len(), 3);
        assert_eq!(products.value(0), 6.0);
        assert_eq!(products.value(1), 120.0);
        assert!(products.is_null(2));
    }
}
138}