common_function/scalars/vector/scalar_add.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::borrow::Cow;
16use std::fmt::Display;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use datafusion_expr::Signature;
20use datatypes::arrow::datatypes::DataType;
21use datatypes::scalars::ScalarVectorBuilder;
22use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
23use nalgebra::DVectorView;
24use snafu::ensure;
25
26use crate::function::{Function, FunctionContext};
27use crate::helper;
28use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
29
/// SQL-visible name of this function; its `Display` form is this string upper-cased.
const NAME: &str = "vec_scalar_add";

/// `vec_scalar_add(scalar, vector)`: adds `scalar` to every element of `vector`.
///
/// The vector argument may be a string literal such as `"[1, 2, 3]"` or a binary vector.
/// Passing a negative scalar performs element-wise subtraction. A row whose scalar or
/// vector is null yields a null result row.
///
/// # Example
///
/// ```sql
/// SELECT vec_to_string(vec_scalar_add(1, "[1, 2, 3]")) as result;
///
/// +---------+
/// | result  |
/// +---------+
/// | [2,3,4] |
/// +---------+
///
/// -- Negative scalar to simulate subtraction
/// SELECT vec_to_string(vec_scalar_add(-1, "[1, 2, 3]")) as result;
///
/// +---------+
/// | result  |
/// +---------+
/// | [0,1,2] |
/// +---------+
/// ```
#[derive(Clone, Debug, Default)]
pub struct ScalarAddFunction;
56
57impl Function for ScalarAddFunction {
58    fn name(&self) -> &str {
59        NAME
60    }
61
62    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
63        Ok(DataType::Binary)
64    }
65
66    fn signature(&self) -> Signature {
67        helper::one_of_sigs2(
68            vec![DataType::Float64],
69            vec![DataType::Utf8, DataType::Binary],
70        )
71    }
72
73    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
74        ensure!(
75            columns.len() == 2,
76            InvalidFuncArgsSnafu {
77                err_msg: format!(
78                    "The length of the args is not correct, expect exactly two, have: {}",
79                    columns.len()
80                ),
81            }
82        );
83        let arg0 = &columns[0];
84        let arg1 = &columns[1];
85
86        let len = arg0.len();
87        let mut result = BinaryVectorBuilder::with_capacity(len);
88        if len == 0 {
89            return Ok(result.to_vector());
90        }
91
92        let arg1_const = as_veclit_if_const(arg1)?;
93
94        for i in 0..len {
95            let arg0 = arg0.get(i).as_f64_lossy();
96            let Some(arg0) = arg0 else {
97                result.push_null();
98                continue;
99            };
100
101            let arg1 = match arg1_const.as_ref() {
102                Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())),
103                None => as_veclit(arg1.get_ref(i))?,
104            };
105            let Some(arg1) = arg1 else {
106                result.push_null();
107                continue;
108            };
109
110            let vec = DVectorView::from_slice(&arg1, arg1.len());
111            let vec_res = vec.add_scalar(arg0 as _);
112
113            let veclit = vec_res.as_slice();
114            let binlit = veclit_to_binlit(veclit);
115            result.push(Some(&binlit));
116        }
117
118        Ok(result.to_vector())
119    }
120}
121
122impl Display for ScalarAddFunction {
123    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
124        write!(f, "{}", NAME.to_ascii_uppercase())
125    }
126}
127
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::vectors::{Float32Vector, StringVector};

    use super::*;

    #[test]
    fn test_scalar_add() {
        let func = ScalarAddFunction;

        // Scalars to add; row 2 is null.
        let scalars = Arc::new(Float32Vector::from(vec![
            Some(1.0),
            Some(-1.0),
            None,
            Some(3.0),
        ]));
        // Vector literals; row 3 is null.
        let vectors = Arc::new(StringVector::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            Some("[7.0,8.0,9.0]".to_string()),
            None,
        ]));

        let output = func
            .eval(&FunctionContext::default(), &[scalars, vectors])
            .unwrap();
        let output = output.as_ref();

        // A null on either side must produce a null output row.
        let expected: Vec<Option<Vec<f32>>> = vec![
            Some(vec![2.0, 3.0, 4.0]),
            Some(vec![3.0, 4.0, 5.0]),
            None,
            None,
        ];
        assert_eq!(output.len(), expected.len());

        for (row, want) in expected.iter().enumerate() {
            match want {
                Some(veclit) => assert_eq!(
                    output.get_ref(row).as_binary().unwrap(),
                    Some(veclit_to_binlit(veclit).as_slice())
                ),
                None => assert!(output.get_ref(row).is_null()),
            }
        }
    }
}
170}