common_function/scalars/vector/
scalar_add.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::fmt::Display;
16
17use datafusion::arrow::datatypes::DataType;
18use datafusion::logical_expr::ColumnarValue;
19use datafusion_common::ScalarValue;
20use datafusion_expr::{ScalarFunctionArgs, Signature};
21use nalgebra::DVectorView;
22
23use crate::function::Function;
24use crate::helper;
25use crate::scalars::vector::VectorCalculator;
26use crate::scalars::vector::impl_conv::{as_veclit, veclit_to_binlit};
27
/// Function name reported by `ScalarAddFunction::name`; its upper-case form is
/// what `ScalarAddFunction`'s `Display` implementation prints.
const NAME: &str = "vec_scalar_add";
29
30/// Adds a scalar to each element of a vector.
31///
32/// # Example
33///
34/// ```sql
35/// SELECT vec_to_string(vec_scalar_add(1, "[1, 2, 3]")) as result;
36///
37/// +---------+
38/// | result  |
39/// +---------+
40/// | [2,3,4] |
41/// +---------+
42///
43/// -- Negative scalar to simulate subtraction
44/// SELECT vec_to_string(vec_scalar_add(-1, "[1, 2, 3]")) as result;
45///
46/// +---------+
47/// | result  |
48/// +---------+
49/// | [0,1,2] |
50/// +---------+
51/// ```
52#[derive(Debug, Clone)]
53pub(crate) struct ScalarAddFunction {
54    signature: Signature,
55}
56
57impl Default for ScalarAddFunction {
58    fn default() -> Self {
59        Self {
60            signature: helper::one_of_sigs2(
61                vec![DataType::Float64],
62                vec![
63                    DataType::Utf8,
64                    DataType::Utf8View,
65                    DataType::Binary,
66                    DataType::BinaryView,
67                ],
68            ),
69        }
70    }
71}
72
73impl Function for ScalarAddFunction {
74    fn name(&self) -> &str {
75        NAME
76    }
77
78    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
79        Ok(DataType::BinaryView)
80    }
81
82    fn signature(&self) -> &Signature {
83        &self.signature
84    }
85
86    fn invoke_with_args(
87        &self,
88        args: ScalarFunctionArgs,
89    ) -> datafusion_common::Result<ColumnarValue> {
90        let body = |v0: &ScalarValue, v1: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
91            let ScalarValue::Float64(Some(v0)) = v0 else {
92                return Ok(ScalarValue::BinaryView(None));
93            };
94
95            let v1 = as_veclit(v1)?
96                .map(|v1| DVectorView::from_slice(&v1, v1.len()).add_scalar(*v0 as f32));
97            let result = v1.map(|v1| veclit_to_binlit(v1.as_slice()));
98            Ok(ScalarValue::BinaryView(result))
99        };
100
101        let calculator = VectorCalculator {
102            name: self.name(),
103            func: body,
104        };
105        calculator.invoke_with_args(args)
106    }
107}
108
109impl Display for ScalarAddFunction {
110    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
111        write!(f, "{}", NAME.to_ascii_uppercase())
112    }
113}
114
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion::arrow::array::{Array, AsArray, Float64Array, StringViewArray};
    use datafusion_common::config::ConfigOptions;

    use super::*;

    #[test]
    fn test_scalar_add() {
        let scalars = Float64Array::from(vec![Some(1.0), Some(-1.0), None, Some(3.0)]);
        let vectors = StringViewArray::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            Some("[7.0,8.0,9.0]".to_string()),
            None,
        ]);
        let num_rows = scalars.len();

        let args = ScalarFunctionArgs {
            args: vec![
                ColumnarValue::Array(Arc::new(scalars)),
                ColumnarValue::Array(Arc::new(vectors)),
            ],
            arg_fields: vec![],
            number_rows: num_rows,
            return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
            config_options: Arc::new(ConfigOptions::new()),
        };

        let output = ScalarAddFunction::default()
            .invoke_with_args(args)
            .and_then(|value| value.to_array(num_rows))
            .unwrap();
        let output = output.as_binary_view();

        // One entry per input row; `None` means the row must come back NULL.
        let expected: Vec<Option<Vec<f32>>> = vec![
            Some(vec![2.0, 3.0, 4.0]),
            Some(vec![3.0, 4.0, 5.0]),
            None, // NULL scalar
            None, // NULL vector
        ];

        assert_eq!(output.len(), expected.len());
        for (row, want) in expected.iter().enumerate() {
            match want {
                Some(elems) => {
                    assert_eq!(output.value(row), veclit_to_binlit(elems).as_slice())
                }
                None => assert!(output.is_null(row)),
            }
        }
    }
}