common_function/scalars/vector/
vector_mul.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::borrow::Cow;
16use std::fmt::Display;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use common_query::prelude::Signature;
20use datatypes::prelude::ConcreteDataType;
21use datatypes::scalars::ScalarVectorBuilder;
22use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
23use nalgebra::DVectorView;
24use snafu::ensure;
25
26use crate::function::{Function, FunctionContext};
27use crate::helper;
28use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
29
/// Name under which the function is exposed.
const NAME: &str = "vec_mul";

/// Scalar function computing the element-wise product of two vectors.
///
/// Rows where either operand is missing produce a null result, and operands
/// of different lengths are rejected with an invalid-arguments error.
///
/// # Example
///
/// ```sql
/// SELECT vec_to_string(vec_mul("[1, 2, 3]", "[1, 2, 3]")) as result;
///
/// +---------+
/// | result  |
/// +---------+
/// | [1,4,9] |
/// +---------+
///
/// ```
#[derive(Clone, Debug, Default)]
pub struct VectorMulFunction;
48
49impl Function for VectorMulFunction {
50    fn name(&self) -> &str {
51        NAME
52    }
53
54    fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
55        Ok(ConcreteDataType::binary_datatype())
56    }
57
58    fn signature(&self) -> Signature {
59        helper::one_of_sigs2(
60            vec![
61                ConcreteDataType::string_datatype(),
62                ConcreteDataType::binary_datatype(),
63            ],
64            vec![
65                ConcreteDataType::string_datatype(),
66                ConcreteDataType::binary_datatype(),
67            ],
68        )
69    }
70
71    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
72        ensure!(
73            columns.len() == 2,
74            InvalidFuncArgsSnafu {
75                err_msg: format!(
76                    "The length of the args is not correct, expect exactly two, have: {}",
77                    columns.len()
78                ),
79            }
80        );
81
82        let arg0 = &columns[0];
83        let arg1 = &columns[1];
84
85        let len = arg0.len();
86        let mut result = BinaryVectorBuilder::with_capacity(len);
87        if len == 0 {
88            return Ok(result.to_vector());
89        }
90
91        let arg0_const = as_veclit_if_const(arg0)?;
92        let arg1_const = as_veclit_if_const(arg1)?;
93
94        for i in 0..len {
95            let arg0 = match arg0_const.as_ref() {
96                Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
97                None => as_veclit(arg0.get_ref(i))?,
98            };
99
100            let arg1 = match arg1_const.as_ref() {
101                Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())),
102                None => as_veclit(arg1.get_ref(i))?,
103            };
104
105            if let (Some(arg0), Some(arg1)) = (arg0, arg1) {
106                ensure!(
107                    arg0.len() == arg1.len(),
108                    InvalidFuncArgsSnafu {
109                        err_msg: format!(
110                            "The length of the vectors must match for multiplying, have: {} vs {}",
111                            arg0.len(),
112                            arg1.len()
113                        ),
114                    }
115                );
116                let vec0 = DVectorView::from_slice(&arg0, arg0.len());
117                let vec1 = DVectorView::from_slice(&arg1, arg1.len());
118                let vec_res = vec1.component_mul(&vec0);
119
120                let veclit = vec_res.as_slice();
121                let binlit = veclit_to_binlit(veclit);
122                result.push(Some(&binlit));
123            } else {
124                result.push_null();
125            }
126        }
127
128        Ok(result.to_vector())
129    }
130}
131
132impl Display for VectorMulFunction {
133    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
134        write!(f, "{}", NAME.to_ascii_uppercase())
135    }
136}
137
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_query::error;
    use datatypes::vectors::StringVector;

    use super::*;

    /// Wraps optional string rows into a column usable as a function argument.
    fn string_column(rows: Vec<Option<String>>) -> VectorRef {
        Arc::new(StringVector::from(rows))
    }

    #[test]
    fn test_vector_mul() {
        let func = VectorMulFunction;
        let ctx = FunctionContext::default();

        // Operands of different lengths must be rejected.
        let three_elems = vec![1.0, 2.0, 3.0];
        let two_elems = vec![1.0, 1.0];
        let mismatched = [
            string_column(vec![Some(format!("{three_elems:?}"))]),
            string_column(vec![Some(format!("{two_elems:?}"))]),
        ];
        let expected_msg = format!(
            "The length of the vectors must match for multiplying, have: {} vs {}",
            three_elems.len(),
            two_elems.len()
        );

        let err = func.eval(&ctx, &mismatched).unwrap_err();
        match err {
            error::Error::InvalidFuncArgs { err_msg, .. } => assert_eq!(err_msg, expected_msg),
            _ => unreachable!(),
        }

        // Matching lengths multiply element-wise; a missing operand yields null.
        let lit = |text: &str| Some(text.to_string());
        let lhs = string_column(vec![
            lit("[1.0,2.0,3.0]"),
            lit("[8.0,10.0,12.0]"),
            lit("[7.0,8.0,9.0]"),
            None,
        ]);
        let rhs = string_column(vec![
            lit("[1.0,1.0,1.0]"),
            lit("[2.0,2.0,2.0]"),
            None,
            lit("[3.0,3.0,3.0]"),
        ]);

        let output = func.eval(&ctx, &[lhs, rhs]).unwrap();
        assert_eq!(output.len(), 4);

        assert_eq!(
            output.get_ref(0).as_binary().unwrap(),
            Some(veclit_to_binlit(&[1.0, 2.0, 3.0]).as_slice())
        );
        assert_eq!(
            output.get_ref(1).as_binary().unwrap(),
            Some(veclit_to_binlit(&[16.0, 20.0, 24.0]).as_slice())
        );
        for row in 2..4 {
            assert!(output.get_ref(row).is_null());
        }
    }
}