common_function/scalars/vector/vector_mul.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::borrow::Cow;
16use std::fmt::Display;
17
18use common_query::error::Result;
19use datafusion::arrow::datatypes::DataType;
20use datafusion::logical_expr::ColumnarValue;
21use datafusion_common::{DataFusionError, ScalarValue};
22use datafusion_expr::{ScalarFunctionArgs, Signature};
23use nalgebra::DVectorView;
24
25use crate::function::Function;
26use crate::helper;
27use crate::scalars::vector::VectorCalculator;
28use crate::scalars::vector::impl_conv::veclit_to_binlit;
29
/// Name under which this function is exposed to SQL.
const NAME: &str = "vec_mul";

/// Multiplies corresponding elements of two vectors (element-wise / Hadamard product).
///
/// Each operand may be a `Utf8` or `Binary` value. The product is returned as a
/// `BinaryView` value produced by `veclit_to_binlit`.
///
/// For a given row the result is `NULL` when either operand is `NULL`.
///
/// # Errors
///
/// Returns an execution error when the two vectors have different lengths.
///
/// # Example
///
/// ```sql
/// SELECT vec_to_string(vec_mul('[1, 2, 3]', '[1, 2, 3]')) as result;
///
/// +---------+
/// | result  |
/// +---------+
/// | [1,4,9] |
/// +---------+
///
/// ```
#[derive(Debug, Clone, Default)]
pub struct VectorMulFunction;
48
49impl Function for VectorMulFunction {
50    fn name(&self) -> &str {
51        NAME
52    }
53
54    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
55        Ok(DataType::BinaryView)
56    }
57
58    fn signature(&self) -> Signature {
59        helper::one_of_sigs2(
60            vec![DataType::Utf8, DataType::Binary],
61            vec![DataType::Utf8, DataType::Binary],
62        )
63    }
64
65    fn invoke_with_args(
66        &self,
67        args: ScalarFunctionArgs,
68    ) -> datafusion_common::Result<ColumnarValue> {
69        let body = |v0: &Option<Cow<[f32]>>,
70                    v1: &Option<Cow<[f32]>>|
71         -> datafusion_common::Result<ScalarValue> {
72            let result = if let (Some(v0), Some(v1)) = (v0, v1) {
73                let v0 = DVectorView::from_slice(v0, v0.len());
74                let v1 = DVectorView::from_slice(v1, v1.len());
75                if v0.len() != v1.len() {
76                    return Err(DataFusionError::Execution(format!(
77                        "vectors length not match: {}",
78                        self.name()
79                    )));
80                }
81
82                let result = veclit_to_binlit((v0.component_mul(&v1)).as_slice());
83                Some(result)
84            } else {
85                None
86            };
87            Ok(ScalarValue::BinaryView(result))
88        };
89
90        let calculator = VectorCalculator {
91            name: self.name(),
92            func: body,
93        };
94        calculator.invoke_with_vectors(args)
95    }
96}
97
98impl Display for VectorMulFunction {
99    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100        write!(f, "{}", NAME.to_ascii_uppercase())
101    }
102}
103
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion::arrow::array::{Array, ArrayRef, AsArray, StringViewArray};
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Builds the arguments for a `vec_mul(lhs, rhs)` call over a batch of
    /// `number_rows` rows. `number_rows` should equal the input array lengths.
    fn mul_args(lhs: ArrayRef, rhs: ArrayRef, number_rows: usize) -> ScalarFunctionArgs {
        ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(lhs), ColumnarValue::Array(rhs)],
            arg_fields: vec![],
            number_rows,
            return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
            config_options: Arc::new(ConfigOptions::new()),
        }
    }

    #[test]
    fn test_vector_mul_length_mismatch() {
        let func = VectorMulFunction;

        // One row whose operands have 3 and 2 elements respectively.
        let lhs = Arc::new(StringViewArray::from(vec![Some(
            "[1.0, 2.0, 3.0]".to_string(),
        )]));
        let rhs = Arc::new(StringViewArray::from(vec![Some("[1.0, 1.0]".to_string())]));

        let e = func.invoke_with_args(mul_args(lhs, rhs, 1)).unwrap_err();
        assert!(
            e.to_string()
                .starts_with("Execution error: vectors length not match: vec_mul")
        );
    }

    #[test]
    fn test_vector_mul() {
        let func = VectorMulFunction;

        let lhs = Arc::new(StringViewArray::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[8.0,10.0,12.0]".to_string()),
            Some("[7.0,8.0,9.0]".to_string()),
            None,
        ]));
        let rhs = Arc::new(StringViewArray::from(vec![
            Some("[1.0,1.0,1.0]".to_string()),
            Some("[2.0,2.0,2.0]".to_string()),
            None,
            Some("[3.0,3.0,3.0]".to_string()),
        ]));

        let result = func
            .invoke_with_args(mul_args(lhs, rhs, 4))
            .and_then(|x| x.to_array(4))
            .unwrap();

        let result = result.as_binary_view();
        assert_eq!(result.len(), 4);
        assert_eq!(
            result.value(0),
            veclit_to_binlit(&[1.0, 2.0, 3.0]).as_slice()
        );
        assert_eq!(
            result.value(1),
            veclit_to_binlit(&[16.0, 20.0, 24.0]).as_slice()
        );
        // A NULL on either side propagates to the output row.
        assert!(result.is_null(2));
        assert!(result.is_null(3));
    }
}