common_function/scalars/vector/
vector_sub.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::borrow::Cow;
16use std::fmt::Display;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use datafusion_expr::Signature;
20use datatypes::arrow::datatypes::DataType;
21use datatypes::scalars::ScalarVectorBuilder;
22use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
23use nalgebra::DVectorView;
24use snafu::ensure;
25
26use crate::function::{Function, FunctionContext};
27use crate::helper;
28use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
29
/// Function name, as returned by `Function::name` for [`VectorSubFunction`].
const NAME: &str = "vec_sub";

/// Subtracts corresponding elements of two vectors, returning a new vector.
///
/// Each argument may be a vector string literal (e.g. `"[1.0, 2.0]"`) or a
/// binary-encoded vector. The result is a binary-encoded vector. Both vectors
/// are expected to have the same number of elements. If either operand of a
/// row is NULL, the result for that row is NULL.
///
/// # Example
///
/// ```sql
/// SELECT vec_to_string(vec_sub("[1.0, 1.0]", "[1.0, 2.0]")) as result;
///
/// +---------------------------------------------------------------+
/// | vec_to_string(vec_sub(Utf8("[1.0, 1.0]"),Utf8("[1.0, 2.0]"))) |
/// +---------------------------------------------------------------+
/// | [0,-1]                                                        |
/// +---------------------------------------------------------------+
/// ```
#[derive(Debug, Clone, Default)]
pub struct VectorSubFunction;
47
48impl Function for VectorSubFunction {
49    fn name(&self) -> &str {
50        NAME
51    }
52
53    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
54        Ok(DataType::Binary)
55    }
56
57    fn signature(&self) -> Signature {
58        helper::one_of_sigs2(
59            vec![DataType::Utf8, DataType::Binary],
60            vec![DataType::Utf8, DataType::Binary],
61        )
62    }
63
64    fn eval(
65        &self,
66        _func_ctx: &FunctionContext,
67        columns: &[VectorRef],
68    ) -> common_query::error::Result<VectorRef> {
69        ensure!(
70            columns.len() == 2,
71            InvalidFuncArgsSnafu {
72                err_msg: format!(
73                    "The length of the args is not correct, expect exactly two, have: {}",
74                    columns.len()
75                )
76            }
77        );
78        let arg0 = &columns[0];
79        let arg1 = &columns[1];
80
81        ensure!(
82            arg0.len() == arg1.len(),
83            InvalidFuncArgsSnafu {
84                err_msg: format!(
85                    "The lengths of the vector are not aligned, args 0: {}, args 1: {}",
86                    arg0.len(),
87                    arg1.len(),
88                )
89            }
90        );
91
92        let len = arg0.len();
93        let mut result = BinaryVectorBuilder::with_capacity(len);
94        if len == 0 {
95            return Ok(result.to_vector());
96        }
97
98        let arg0_const = as_veclit_if_const(arg0)?;
99        let arg1_const = as_veclit_if_const(arg1)?;
100
101        for i in 0..len {
102            let arg0 = match arg0_const.as_ref() {
103                Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
104                None => as_veclit(arg0.get_ref(i))?,
105            };
106            let arg1 = match arg1_const.as_ref() {
107                Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())),
108                None => as_veclit(arg1.get_ref(i))?,
109            };
110            let (Some(arg0), Some(arg1)) = (arg0, arg1) else {
111                result.push_null();
112                continue;
113            };
114            let vec0 = DVectorView::from_slice(&arg0, arg0.len());
115            let vec1 = DVectorView::from_slice(&arg1, arg1.len());
116
117            let vec_res = vec0 - vec1;
118            let veclit = vec_res.as_slice();
119            let binlit = veclit_to_binlit(veclit);
120            result.push(Some(&binlit));
121        }
122
123        Ok(result.to_vector())
124    }
125}
126
127impl Display for VectorSubFunction {
128    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
129        write!(f, "{}", NAME.to_ascii_uppercase())
130    }
131}
132
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_query::error::Error;
    use datatypes::vectors::StringVector;

    use super::*;

    /// Builds a string column from optional vector literals.
    fn string_column(values: &[Option<&str>]) -> VectorRef {
        let owned: Vec<Option<String>> = values.iter().map(|v| v.map(String::from)).collect();
        Arc::new(StringVector::from(owned))
    }

    #[test]
    fn test_sub() {
        let lhs = string_column(&[
            Some("[1.0,2.0,3.0]"),
            Some("[4.0,5.0,6.0]"),
            None,
            Some("[2.0,3.0,3.0]"),
        ]);
        let rhs = string_column(&[
            Some("[1.0,1.0,1.0]"),
            Some("[6.0,5.0,4.0]"),
            Some("[3.0,2.0,2.0]"),
            None,
        ]);

        let output = VectorSubFunction
            .eval(&FunctionContext::default(), &[lhs, rhs])
            .unwrap();
        assert_eq!(output.len(), 4);

        let expected_rows = [[0.0, 1.0, 2.0], [-2.0, 0.0, 2.0]];
        for (row, expected) in expected_rows.iter().enumerate() {
            assert_eq!(
                output.get_ref(row).as_binary().unwrap(),
                Some(veclit_to_binlit(expected).as_slice())
            );
        }

        // A NULL operand on either side yields a NULL row.
        assert!(output.get_ref(2).is_null());
        assert!(output.get_ref(3).is_null());
    }

    #[test]
    fn test_sub_error() {
        let lhs = string_column(&[
            Some("[1.0,2.0,3.0]"),
            Some("[4.0,5.0,6.0]"),
            None,
            Some("[2.0,3.0,3.0]"),
        ]);
        let rhs = string_column(&[
            Some("[1.0,1.0,1.0]"),
            Some("[6.0,5.0,4.0]"),
            Some("[3.0,2.0,2.0]"),
        ]);

        let outcome = VectorSubFunction.eval(&FunctionContext::default(), &[lhs, rhs]);
        let Err(Error::InvalidFuncArgs { err_msg, .. }) = outcome else {
            unreachable!()
        };
        assert_eq!(
            err_msg,
            "The lengths of the vector are not aligned, args 0: 4, args 1: 3"
        );
    }
}