common_function/scalars/vector/
vector_sub.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::borrow::Cow;
16
17use datafusion::arrow::datatypes::DataType;
18use datafusion::logical_expr::ColumnarValue;
19use datafusion_common::{DataFusionError, ScalarValue};
20use datafusion_expr::{ScalarFunctionArgs, Signature};
21use nalgebra::DVectorView;
22
23use crate::function::Function;
24use crate::scalars::vector::impl_conv::veclit_to_binlit;
25use crate::scalars::vector::{VectorCalculator, define_args_of_two_vector_literals_udf};
26
27const NAME: &str = "vec_sub";
28
29define_args_of_two_vector_literals_udf!(
30/// Subtracts corresponding elements of two vectors, returns a vector.
31///
32/// # Example
33///
34/// ```sql
35/// SELECT vec_to_string(vec_sub("[1.0, 1.0]", "[1.0, 2.0]")) as result;
36///
37/// +---------------------------------------------------------------+
38/// | vec_to_string(vec_sub(Utf8("[1.0, 1.0]"),Utf8("[1.0, 2.0]"))) |
39/// +---------------------------------------------------------------+
40/// | [0,-1]                                                        |
41/// +---------------------------------------------------------------+
42///
43VectorSubFunction);
44
45impl Function for VectorSubFunction {
46    fn name(&self) -> &str {
47        NAME
48    }
49
50    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
51        Ok(DataType::BinaryView)
52    }
53
54    fn signature(&self) -> &Signature {
55        &self.signature
56    }
57
58    fn invoke_with_args(
59        &self,
60        args: ScalarFunctionArgs,
61    ) -> datafusion_common::Result<ColumnarValue> {
62        let body = |v0: &Option<Cow<[f32]>>,
63                    v1: &Option<Cow<[f32]>>|
64         -> datafusion_common::Result<ScalarValue> {
65            let result = if let (Some(v0), Some(v1)) = (v0, v1) {
66                let v0 = DVectorView::from_slice(v0, v0.len());
67                let v1 = DVectorView::from_slice(v1, v1.len());
68                if v0.len() != v1.len() {
69                    return Err(DataFusionError::Execution(format!(
70                        "vectors length not match: {}",
71                        self.name()
72                    )));
73                }
74
75                let result = veclit_to_binlit((v0 - v1).as_slice());
76                Some(result)
77            } else {
78                None
79            };
80            Ok(ScalarValue::BinaryView(result))
81        };
82
83        let calculator = VectorCalculator {
84            name: self.name(),
85            func: body,
86        };
87        calculator.invoke_with_vectors(args)
88    }
89}
90
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion::arrow::array::{Array, ArrayRef, AsArray, StringViewArray};
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Builds the `ScalarFunctionArgs` for a `vec_sub` call over two columns.
    fn make_args(input0: ArrayRef, input1: ArrayRef, number_rows: usize) -> ScalarFunctionArgs {
        ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
            arg_fields: vec![],
            number_rows,
            return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
            config_options: Arc::new(ConfigOptions::new()),
        }
    }

    #[test]
    fn test_sub() {
        let func = VectorSubFunction::default();

        let input0: ArrayRef = Arc::new(StringViewArray::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            None,
            Some("[2.0,3.0,3.0]".to_string()),
        ]));
        let input1: ArrayRef = Arc::new(StringViewArray::from(vec![
            Some("[1.0,1.0,1.0]".to_string()),
            Some("[6.0,5.0,4.0]".to_string()),
            Some("[3.0,2.0,2.0]".to_string()),
            None,
        ]));

        let result = func
            .invoke_with_args(make_args(input0, input1, 4))
            .and_then(|x| x.to_array(4))
            .unwrap();

        let result = result.as_binary_view();
        assert_eq!(result.len(), 4);
        assert_eq!(
            result.value(0),
            veclit_to_binlit(&[0.0, 1.0, 2.0]).as_slice()
        );
        assert_eq!(
            result.value(1),
            veclit_to_binlit(&[-2.0, 0.0, 2.0]).as_slice()
        );
        assert!(result.is_null(2));
        assert!(result.is_null(3));
    }

    #[test]
    fn test_sub_error() {
        let func = VectorSubFunction::default();

        let input0: ArrayRef = Arc::new(StringViewArray::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            None,
            Some("[2.0,3.0,3.0]".to_string()),
        ]));
        let input1: ArrayRef = Arc::new(StringViewArray::from(vec![
            Some("[1.0,1.0,1.0]".to_string()),
            Some("[6.0,5.0,4.0]".to_string()),
            Some("[3.0,2.0,2.0]".to_string()),
        ]));

        let e = func
            .invoke_with_args(make_args(input0, input1, 4))
            .unwrap_err();
        assert!(e.to_string().starts_with(
            "Internal error: Arguments has mixed length. Expected length: 4, found length: 3."
        ));
    }

    /// Two non-NULL vectors in the same row with different element counts must
    /// be rejected, not silently truncated.
    #[test]
    fn test_sub_length_mismatch() {
        let func = VectorSubFunction::default();

        let input0: ArrayRef = Arc::new(StringViewArray::from(vec![Some(
            "[1.0,2.0]".to_string(),
        )]));
        let input1: ArrayRef = Arc::new(StringViewArray::from(vec![Some(
            "[1.0,2.0,3.0]".to_string(),
        )]));

        let e = func
            .invoke_with_args(make_args(input0, input1, 1))
            .unwrap_err();
        assert!(
            e.to_string()
                .contains("vectors length not match: vec_sub")
        );
    }
}