common_function/scalars/vector/
elem_product.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Display;
16
17use datafusion::arrow::datatypes::DataType;
18use datafusion::logical_expr::{Coercion, ColumnarValue, TypeSignature, TypeSignatureClass};
19use datafusion_common::ScalarValue;
20use datafusion_common::types::{logical_binary, logical_string};
21use datafusion_expr::{ScalarFunctionArgs, Signature, Volatility};
22use nalgebra::DVectorView;
23
24use crate::function::Function;
25use crate::scalars::vector::{VectorCalculator, impl_conv};
26
27const NAME: &str = "vec_elem_product";
28
29/// Multiplies all elements of the vector, returns a scalar.
30///
31/// # Example
32///
33/// ```sql
34/// SELECT vec_elem_product(parse_vec('[1.0, 2.0, 3.0, 4.0]'));
35///
36// +-----------------------------------------------------------+
37// | vec_elem_product(parse_vec(Utf8("[1.0, 2.0, 3.0, 4.0]"))) |
38// +-----------------------------------------------------------+
39// | 24.0                                                      |
40// +-----------------------------------------------------------+
41/// ``````
42#[derive(Debug, Clone)]
43pub(crate) struct ElemProductFunction {
44    signature: Signature,
45}
46
47impl Default for ElemProductFunction {
48    fn default() -> Self {
49        Self {
50            signature: Signature::one_of(
51                vec![
52                    TypeSignature::Coercible(vec![Coercion::new_exact(
53                        TypeSignatureClass::Native(logical_binary()),
54                    )]),
55                    TypeSignature::Coercible(vec![Coercion::new_exact(
56                        TypeSignatureClass::Native(logical_string()),
57                    )]),
58                ],
59                Volatility::Immutable,
60            ),
61        }
62    }
63}
64
65impl Function for ElemProductFunction {
66    fn name(&self) -> &str {
67        NAME
68    }
69
70    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
71        Ok(DataType::Float32)
72    }
73
74    fn signature(&self) -> &Signature {
75        &self.signature
76    }
77
78    fn invoke_with_args(
79        &self,
80        args: ScalarFunctionArgs,
81    ) -> datafusion_common::Result<ColumnarValue> {
82        let body = |v0: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
83            let v0 = impl_conv::as_veclit(v0)?
84                .map(|v0| DVectorView::from_slice(&v0, v0.len()).product());
85            Ok(ScalarValue::Float32(v0))
86        };
87
88        let calculator = VectorCalculator {
89            name: self.name(),
90            func: body,
91        };
92        calculator.invoke_with_single_argument(args)
93    }
94}
95
96impl Display for ElemProductFunction {
97    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98        write!(f, "{}", NAME.to_ascii_uppercase())
99    }
100}
101
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion::arrow::array::{Array, AsArray, StringArray};
    use datafusion::arrow::datatypes::Float32Type;
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Multiplies the elements of each string vector row and checks the
    /// `Float32` results, including that a NULL input row stays NULL.
    #[test]
    fn test_elem_product() {
        let func = ElemProductFunction::default();

        let rows: Vec<Option<String>> = vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            None,
        ];
        let input = Arc::new(StringArray::from(rows));
        let row_count = input.len();

        let call_args = ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(input)],
            arg_fields: vec![],
            number_rows: row_count,
            return_field: Arc::new(Field::new("x", DataType::Float32, true)),
            config_options: Arc::new(ConfigOptions::new()),
        };

        let output = func.invoke_with_args(call_args).unwrap();
        let mut arrays = ColumnarValue::values_to_arrays(&[output]).unwrap();
        let products = arrays.remove(0);
        let products = products.as_primitive::<Float32Type>();

        assert_eq!(products.len(), 3);
        assert_eq!(products.value(0), 6.0);
        assert_eq!(products.value(1), 120.0);
        assert!(products.is_null(2));
    }
}