common_function/scalars/vector/
vector_add.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15use std::borrow::Cow;
16
17use datafusion::arrow::datatypes::DataType;
18use datafusion::logical_expr::ColumnarValue;
19use datafusion_common::{DataFusionError, ScalarValue};
20use datafusion_expr::{ScalarFunctionArgs, Signature};
21use nalgebra::DVectorView;
22
23use crate::function::Function;
24use crate::scalars::vector::impl_conv::veclit_to_binlit;
25use crate::scalars::vector::{VectorCalculator, define_args_of_two_vector_literals_udf};
26
27const NAME: &str = "vec_add";
28
29define_args_of_two_vector_literals_udf!(
30/// Adds corresponding elements of two vectors, returns a vector.
31///
32/// # Example
33///
34/// ```sql
35/// SELECT vec_to_string(vec_add("[1.0, 1.0]", "[1.0, 2.0]")) as result;
36///
37/// +---------------------------------------------------------------+
38/// | vec_to_string(vec_add(Utf8("[1.0, 1.0]"),Utf8("[1.0, 2.0]"))) |
39/// +---------------------------------------------------------------+
40/// | [2,3]                                                         |
41/// +---------------------------------------------------------------+
42///
43
44VectorAddFunction);
45
46impl Function for VectorAddFunction {
47    fn name(&self) -> &str {
48        NAME
49    }
50
51    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
52        Ok(DataType::BinaryView)
53    }
54
55    fn signature(&self) -> &Signature {
56        &self.signature
57    }
58
59    fn invoke_with_args(
60        &self,
61        args: ScalarFunctionArgs,
62    ) -> datafusion_common::Result<ColumnarValue> {
63        let body = |v0: &Option<Cow<[f32]>>,
64                    v1: &Option<Cow<[f32]>>|
65         -> datafusion_common::Result<ScalarValue> {
66            let result = if let (Some(v0), Some(v1)) = (v0, v1) {
67                let v0 = DVectorView::from_slice(v0, v0.len());
68                let v1 = DVectorView::from_slice(v1, v1.len());
69                if v0.len() != v1.len() {
70                    return Err(DataFusionError::Execution(format!(
71                        "vectors length not match: {}",
72                        self.name()
73                    )));
74                }
75
76                let result = veclit_to_binlit((v0 + v1).as_slice());
77                Some(result)
78            } else {
79                None
80            };
81            Ok(ScalarValue::BinaryView(result))
82        };
83
84        let calculator = VectorCalculator {
85            name: self.name(),
86            func: body,
87        };
88        calculator.invoke_with_vectors(args)
89    }
90}
91
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion::arrow::array::{Array, AsArray, StringViewArray};
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Builds `vec_add` invocation arguments from two string-encoded vector columns.
    fn two_column_args(
        input0: StringViewArray,
        input1: StringViewArray,
        number_rows: usize,
    ) -> ScalarFunctionArgs {
        ScalarFunctionArgs {
            args: vec![
                ColumnarValue::Array(Arc::new(input0)),
                ColumnarValue::Array(Arc::new(input1)),
            ],
            arg_fields: vec![],
            number_rows,
            return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
            config_options: Arc::new(ConfigOptions::new()),
        }
    }

    #[test]
    fn test_add() {
        let func = VectorAddFunction::default();

        let input0 = StringViewArray::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            None,
            Some("[2.0,3.0,3.0]".to_string()),
        ]);
        let input1 = StringViewArray::from(vec![
            Some("[1.0,1.0,1.0]".to_string()),
            Some("[6.0,5.0,4.0]".to_string()),
            Some("[3.0,2.0,2.0]".to_string()),
            None,
        ]);

        let result = func
            .invoke_with_args(two_column_args(input0, input1, 4))
            .and_then(|x| x.to_array(4))
            .unwrap();

        let result = result.as_binary_view();
        assert_eq!(result.len(), 4);
        assert_eq!(
            result.value(0),
            veclit_to_binlit(&[2.0, 3.0, 4.0]).as_slice()
        );
        assert_eq!(
            result.value(1),
            veclit_to_binlit(&[10.0, 10.0, 10.0]).as_slice()
        );
        // A NULL on either side yields a NULL row.
        assert!(result.is_null(2));
        assert!(result.is_null(3));
    }

    #[test]
    fn test_add_error() {
        let func = VectorAddFunction::default();

        let input0 = StringViewArray::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            None,
            Some("[2.0,3.0,3.0]".to_string()),
        ]);
        let input1 = StringViewArray::from(vec![
            Some("[1.0,1.0,1.0]".to_string()),
            Some("[6.0,5.0,4.0]".to_string()),
            Some("[3.0,2.0,2.0]".to_string()),
        ]);

        let e = func
            .invoke_with_args(two_column_args(input0, input1, 4))
            .unwrap_err();
        assert!(e.to_string().starts_with(
            "Internal error: Arguments has mixed length. Expected length: 4, found length: 3."
        ));
    }

    #[test]
    fn test_add_length_mismatch() {
        let func = VectorAddFunction::default();

        // Both rows are non-NULL, but the vectors have 3 and 2 elements.
        let input0 = StringViewArray::from(vec![Some("[1.0,2.0,3.0]".to_string())]);
        let input1 = StringViewArray::from(vec![Some("[1.0,2.0]".to_string())]);

        let e = func
            .invoke_with_args(two_column_args(input0, input1, 1))
            .unwrap_err();
        assert!(
            e.to_string().contains("vectors length not match"),
            "unexpected error: {e}"
        );
    }
}