common_function/scalars/vector/vector_mul.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15use std::borrow::Cow;
16
17use datafusion::arrow::datatypes::DataType;
18use datafusion::logical_expr::ColumnarValue;
19use datafusion_common::{DataFusionError, ScalarValue};
20use datafusion_expr::{ScalarFunctionArgs, Signature};
21use nalgebra::DVectorView;
22
23use crate::function::Function;
24use crate::scalars::vector::impl_conv::veclit_to_binlit;
25use crate::scalars::vector::{VectorCalculator, define_args_of_two_vector_literals_udf};
26
27const NAME: &str = "vec_mul";
28
29define_args_of_two_vector_literals_udf!(
30/// Multiplies corresponding elements of two vectors.
31///
32/// # Example
33///
34/// ```sql
35/// SELECT vec_to_string(vec_mul("[1, 2, 3]", "[1, 2, 3]")) as result;
36///
37/// +---------+
38/// | result  |
39/// +---------+
40/// | [1,4,9] |
41/// +---------+
42///
43/// ```
44VectorMulFunction);
45
46impl Function for VectorMulFunction {
47    fn name(&self) -> &str {
48        NAME
49    }
50
51    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
52        Ok(DataType::BinaryView)
53    }
54
55    fn signature(&self) -> &Signature {
56        &self.signature
57    }
58
59    fn invoke_with_args(
60        &self,
61        args: ScalarFunctionArgs,
62    ) -> datafusion_common::Result<ColumnarValue> {
63        let body = |v0: &Option<Cow<[f32]>>,
64                    v1: &Option<Cow<[f32]>>|
65         -> datafusion_common::Result<ScalarValue> {
66            let result = if let (Some(v0), Some(v1)) = (v0, v1) {
67                let v0 = DVectorView::from_slice(v0, v0.len());
68                let v1 = DVectorView::from_slice(v1, v1.len());
69                if v0.len() != v1.len() {
70                    return Err(DataFusionError::Execution(format!(
71                        "vectors length not match: {}",
72                        self.name()
73                    )));
74                }
75
76                let result = veclit_to_binlit((v0.component_mul(&v1)).as_slice());
77                Some(result)
78            } else {
79                None
80            };
81            Ok(ScalarValue::BinaryView(result))
82        };
83
84        let calculator = VectorCalculator {
85            name: self.name(),
86            func: body,
87        };
88        calculator.invoke_with_vectors(args)
89    }
90}
91
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion::arrow::array::{Array, AsArray, StringViewArray};
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Builds `ScalarFunctionArgs` for a two-column `vec_mul` call.
    ///
    /// `rows` must equal the length of both input arrays. The return field is
    /// nullable because `vec_mul` yields NULL whenever either input is NULL.
    fn mul_args(
        input0: Arc<StringViewArray>,
        input1: Arc<StringViewArray>,
        rows: usize,
    ) -> ScalarFunctionArgs {
        ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
            arg_fields: vec![],
            number_rows: rows,
            return_field: Arc::new(Field::new("x", DataType::BinaryView, true)),
            config_options: Arc::new(ConfigOptions::new()),
        }
    }

    #[test]
    fn test_vector_mul() {
        let func = VectorMulFunction::default();

        // A single row whose vectors differ in length (3 vs 2) must be rejected.
        let vec0 = vec![1.0, 2.0, 3.0];
        let vec1 = vec![1.0, 1.0];
        let input0 = Arc::new(StringViewArray::from(vec![Some(format!("{vec0:?}"))]));
        let input1 = Arc::new(StringViewArray::from(vec![Some(format!("{vec1:?}"))]));

        let e = func
            .invoke_with_args(mul_args(input0, input1, 1))
            .unwrap_err();
        assert!(
            e.to_string()
                .starts_with("Execution error: vectors length not match: vec_mul")
        );

        let input0 = Arc::new(StringViewArray::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[8.0,10.0,12.0]".to_string()),
            Some("[7.0,8.0,9.0]".to_string()),
            None,
            Some("[2.0,3.0,4.0]".to_string()),
        ]));

        let input1 = Arc::new(StringViewArray::from(vec![
            Some("[1.0,1.0,1.0]".to_string()),
            Some("[2.0,2.0,2.0]".to_string()),
            None,
            Some("[3.0,3.0,3.0]".to_string()),
            Some("[5.0,6.0,7.0]".to_string()),
        ]));

        let result = func
            .invoke_with_args(mul_args(input0, input1, 5))
            .and_then(|x| x.to_array(5))
            .unwrap();

        let result = result.as_binary_view();
        assert_eq!(result.len(), 5);
        assert_eq!(
            result.value(0),
            veclit_to_binlit(&[1.0, 2.0, 3.0]).as_slice()
        );
        assert_eq!(
            result.value(1),
            veclit_to_binlit(&[16.0, 20.0, 24.0]).as_slice()
        );
        // NULL on either side yields NULL.
        assert!(result.is_null(2));
        assert!(result.is_null(3));
        // Distinct per-element factors catch any broadcast/ordering mistakes.
        assert_eq!(
            result.value(4),
            veclit_to_binlit(&[10.0, 18.0, 28.0]).as_slice()
        );
    }
}