common_function/scalars/vector/scalar_add.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::borrow::Cow;
16use std::fmt::Display;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use common_query::prelude::Signature;
20use datatypes::prelude::ConcreteDataType;
21use datatypes::scalars::ScalarVectorBuilder;
22use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
23use nalgebra::DVectorView;
24use snafu::ensure;
25
26use crate::function::{Function, FunctionContext};
27use crate::helper;
28use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
29
/// SQL-facing name of this function; returned by `Function::name` and, upper-cased,
/// by the `Display` implementation.
const NAME: &str = "vec_scalar_add";

/// Adds a scalar to each element of a vector.
///
/// The first argument is the scalar to add. The second is the vector, given either as
/// a string literal such as `"[1, 2, 3]"` or in binary-encoded form. The result is the
/// shifted vector in binary form; wrap the call in `vec_to_string` to read it as text.
/// A row evaluates to null when either argument is null.
///
/// # Example
///
/// ```sql
/// SELECT vec_to_string(vec_scalar_add(1, "[1, 2, 3]")) as result;
///
/// +---------+
/// | result  |
/// +---------+
/// | [2,3,4] |
/// +---------+
///
/// -- Negative scalar to simulate subtraction
/// SELECT vec_to_string(vec_scalar_add(-1, "[1, 2, 3]")) as result;
///
/// +---------+
/// | result  |
/// +---------+
/// | [0,1,2] |
/// +---------+
/// ```
#[derive(Clone, Debug, Default)]
pub struct ScalarAddFunction;
56
57impl Function for ScalarAddFunction {
58    fn name(&self) -> &str {
59        NAME
60    }
61
62    fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
63        Ok(ConcreteDataType::binary_datatype())
64    }
65
66    fn signature(&self) -> Signature {
67        helper::one_of_sigs2(
68            vec![ConcreteDataType::float64_datatype()],
69            vec![
70                ConcreteDataType::string_datatype(),
71                ConcreteDataType::binary_datatype(),
72            ],
73        )
74    }
75
76    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
77        ensure!(
78            columns.len() == 2,
79            InvalidFuncArgsSnafu {
80                err_msg: format!(
81                    "The length of the args is not correct, expect exactly two, have: {}",
82                    columns.len()
83                ),
84            }
85        );
86        let arg0 = &columns[0];
87        let arg1 = &columns[1];
88
89        let len = arg0.len();
90        let mut result = BinaryVectorBuilder::with_capacity(len);
91        if len == 0 {
92            return Ok(result.to_vector());
93        }
94
95        let arg1_const = as_veclit_if_const(arg1)?;
96
97        for i in 0..len {
98            let arg0 = arg0.get(i).as_f64_lossy();
99            let Some(arg0) = arg0 else {
100                result.push_null();
101                continue;
102            };
103
104            let arg1 = match arg1_const.as_ref() {
105                Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())),
106                None => as_veclit(arg1.get_ref(i))?,
107            };
108            let Some(arg1) = arg1 else {
109                result.push_null();
110                continue;
111            };
112
113            let vec = DVectorView::from_slice(&arg1, arg1.len());
114            let vec_res = vec.add_scalar(arg0 as _);
115
116            let veclit = vec_res.as_slice();
117            let binlit = veclit_to_binlit(veclit);
118            result.push(Some(&binlit));
119        }
120
121        Ok(result.to_vector())
122    }
123}
124
125impl Display for ScalarAddFunction {
126    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127        write!(f, "{}", NAME.to_ascii_uppercase())
128    }
129}
130
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::vectors::{Float32Vector, StringVector};

    use super::*;

    #[test]
    fn test_scalar_add() {
        let func = ScalarAddFunction;

        // Row 2 has a null scalar and row 3 a null vector; both must produce null.
        let scalars = Arc::new(Float32Vector::from(vec![
            Some(1.0),
            Some(-1.0),
            None,
            Some(3.0),
        ]));
        let vectors = Arc::new(StringVector::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            Some("[7.0,8.0,9.0]".to_string()),
            None,
        ]));

        let output = func
            .eval(&FunctionContext::default(), &[scalars, vectors])
            .unwrap();
        let output = output.as_ref();
        assert_eq!(output.len(), 4);

        // `None` marks rows that are expected to be null in the output.
        let expected: [Option<&[f32]>; 4] = [
            Some(&[2.0, 3.0, 4.0]),
            Some(&[3.0, 4.0, 5.0]),
            None,
            None,
        ];
        for (row, want) in expected.iter().enumerate() {
            let got = output.get_ref(row);
            match want {
                Some(veclit) => assert_eq!(
                    got.as_binary().unwrap(),
                    Some(veclit_to_binlit(veclit).as_slice())
                ),
                None => assert!(got.is_null()),
            }
        }
    }
}
173}