common_function/scalars/vector/
vector_add.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::borrow::Cow;
16use std::fmt::Display;
17
18use common_query::error::InvalidFuncArgsSnafu;
19use common_query::prelude::Signature;
20use datatypes::prelude::ConcreteDataType;
21use datatypes::scalars::ScalarVectorBuilder;
22use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
23use nalgebra::DVectorView;
24use snafu::ensure;
25
26use crate::function::{Function, FunctionContext};
27use crate::helper;
28use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
29
const NAME: &str = "vec_add";

/// Adds corresponding elements of two vectors, returning a new vector.
///
/// Each argument is declared as either a string (e.g. `"[1.0, 2.0]"`) or a binary
/// value; the result is always a binary-encoded vector. If either argument of a
/// row is NULL, the result for that row is NULL.
///
/// # Example
///
/// ```sql
/// SELECT vec_to_string(vec_add("[1.0, 1.0]", "[1.0, 2.0]")) as result;
///
/// +---------------------------------------------------------------+
/// | vec_to_string(vec_add(Utf8("[1.0, 1.0]"),Utf8("[1.0, 2.0]"))) |
/// +---------------------------------------------------------------+
/// | [2,3]                                                         |
/// +---------------------------------------------------------------+
/// ```
#[derive(Debug, Clone, Default)]
pub struct VectorAddFunction;
47
48impl Function for VectorAddFunction {
49    fn name(&self) -> &str {
50        NAME
51    }
52
53    fn return_type(
54        &self,
55        _input_types: &[ConcreteDataType],
56    ) -> common_query::error::Result<ConcreteDataType> {
57        Ok(ConcreteDataType::binary_datatype())
58    }
59
60    fn signature(&self) -> Signature {
61        helper::one_of_sigs2(
62            vec![
63                ConcreteDataType::string_datatype(),
64                ConcreteDataType::binary_datatype(),
65            ],
66            vec![
67                ConcreteDataType::string_datatype(),
68                ConcreteDataType::binary_datatype(),
69            ],
70        )
71    }
72
73    fn eval(
74        &self,
75        _func_ctx: &FunctionContext,
76        columns: &[VectorRef],
77    ) -> common_query::error::Result<VectorRef> {
78        ensure!(
79            columns.len() == 2,
80            InvalidFuncArgsSnafu {
81                err_msg: format!(
82                    "The length of the args is not correct, expect exactly two, have: {}",
83                    columns.len()
84                )
85            }
86        );
87        let arg0 = &columns[0];
88        let arg1 = &columns[1];
89
90        ensure!(
91            arg0.len() == arg1.len(),
92            InvalidFuncArgsSnafu {
93                err_msg: format!(
94                    "The lengths of the vector are not aligned, args 0: {}, args 1: {}",
95                    arg0.len(),
96                    arg1.len(),
97                )
98            }
99        );
100
101        let len = arg0.len();
102        let mut result = BinaryVectorBuilder::with_capacity(len);
103        if len == 0 {
104            return Ok(result.to_vector());
105        }
106
107        let arg0_const = as_veclit_if_const(arg0)?;
108        let arg1_const = as_veclit_if_const(arg1)?;
109
110        for i in 0..len {
111            let arg0 = match arg0_const.as_ref() {
112                Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
113                None => as_veclit(arg0.get_ref(i))?,
114            };
115            let arg1 = match arg1_const.as_ref() {
116                Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())),
117                None => as_veclit(arg1.get_ref(i))?,
118            };
119            let (Some(arg0), Some(arg1)) = (arg0, arg1) else {
120                result.push_null();
121                continue;
122            };
123            let vec0 = DVectorView::from_slice(&arg0, arg0.len());
124            let vec1 = DVectorView::from_slice(&arg1, arg1.len());
125
126            let vec_res = vec0 + vec1;
127            let veclit = vec_res.as_slice();
128            let binlit = veclit_to_binlit(veclit);
129            result.push(Some(&binlit));
130        }
131
132        Ok(result.to_vector())
133    }
134}
135
136impl Display for VectorAddFunction {
137    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
138        write!(f, "{}", NAME.to_ascii_uppercase())
139    }
140}
141
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_query::error::Error;
    use datatypes::vectors::StringVector;

    use super::*;

    /// Row-wise addition, including NULL propagation from either side.
    #[test]
    fn test_add() {
        let func = VectorAddFunction;

        let input0 = Arc::new(StringVector::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            None,
            Some("[2.0,3.0,3.0]".to_string()),
        ]));
        let input1 = Arc::new(StringVector::from(vec![
            Some("[1.0,1.0,1.0]".to_string()),
            Some("[6.0,5.0,4.0]".to_string()),
            Some("[3.0,2.0,2.0]".to_string()),
            None,
        ]));

        let result = func
            .eval(&FunctionContext::default(), &[input0, input1])
            .unwrap();

        let result = result.as_ref();
        assert_eq!(result.len(), 4);
        assert_eq!(
            result.get_ref(0).as_binary().unwrap(),
            Some(veclit_to_binlit(&[2.0, 3.0, 4.0]).as_slice())
        );
        assert_eq!(
            result.get_ref(1).as_binary().unwrap(),
            Some(veclit_to_binlit(&[10.0, 10.0, 10.0]).as_slice())
        );
        assert!(result.get_ref(2).is_null());
        assert!(result.get_ref(3).is_null());
    }

    /// Columns with different row counts must be rejected with `InvalidFuncArgs`.
    #[test]
    fn test_add_error() {
        let func = VectorAddFunction;

        let input0 = Arc::new(StringVector::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            None,
            Some("[2.0,3.0,3.0]".to_string()),
        ]));
        let input1 = Arc::new(StringVector::from(vec![
            Some("[1.0,1.0,1.0]".to_string()),
            Some("[6.0,5.0,4.0]".to_string()),
            Some("[3.0,2.0,2.0]".to_string()),
        ]));

        let result = func.eval(&FunctionContext::default(), &[input0, input1]);

        match result {
            Err(Error::InvalidFuncArgs { err_msg, .. }) => {
                assert_eq!(
                    err_msg,
                    "The lengths of the vector are not aligned, args 0: 4, args 1: 3"
                )
            }
            Err(other) => panic!("expected InvalidFuncArgs, got a different error: {other}"),
            Ok(_) => panic!("expected InvalidFuncArgs, but evaluation succeeded"),
        }
    }
}