common_function/scalars/vector/
vector_div.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::borrow::Cow;
16
17use datafusion::arrow::datatypes::DataType;
18use datafusion::logical_expr::ColumnarValue;
19use datafusion_common::{DataFusionError, ScalarValue};
20use datafusion_expr::{ScalarFunctionArgs, Signature};
21use nalgebra::DVectorView;
22
23use crate::function::Function;
24use crate::scalars::vector::impl_conv::veclit_to_binlit;
25use crate::scalars::vector::{VectorCalculator, define_args_of_two_vector_literals_udf};
26
/// Name reported by `VectorDivFunction::name`; it is also embedded in the
/// function's error messages (e.g. the vector-length mismatch error).
const NAME: &str = "vec_div";

define_args_of_two_vector_literals_udf!(
/// Divides corresponding elements of two vectors.
///
/// Both arguments are interpreted as vectors of `f32` elements, and the
/// element-wise quotient is returned as a binary vector literal
/// (`BinaryView`).
///
/// - If either argument is `NULL`, the result for that row is `NULL`.
/// - If the two vectors have different lengths, the call fails with an
///   execution error.
/// - Division by zero follows IEEE 754 float semantics: for example
///   `1 / 0 = inf` and `-2 / 0 = -inf`.
///
/// # Example
///
/// ```sql
/// SELECT vec_to_string(vec_div("[2, 4, 6]", "[2, 2, 2]")) as result;
///
/// +---------+
/// | result  |
/// +---------+
/// | [1,2,3] |
/// +---------+
///
/// ```
VectorDivFunction);
46
47impl Function for VectorDivFunction {
48    fn name(&self) -> &str {
49        NAME
50    }
51
52    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
53        Ok(DataType::BinaryView)
54    }
55
56    fn signature(&self) -> &Signature {
57        &self.signature
58    }
59
60    fn invoke_with_args(
61        &self,
62        args: ScalarFunctionArgs,
63    ) -> datafusion_common::Result<ColumnarValue> {
64        let body = |v0: &Option<Cow<[f32]>>,
65                    v1: &Option<Cow<[f32]>>|
66         -> datafusion_common::Result<ScalarValue> {
67            let result = if let (Some(v0), Some(v1)) = (v0, v1) {
68                let v0 = DVectorView::from_slice(v0, v0.len());
69                let v1 = DVectorView::from_slice(v1, v1.len());
70                if v0.len() != v1.len() {
71                    return Err(DataFusionError::Execution(format!(
72                        "vectors length not match: {}",
73                        self.name()
74                    )));
75                }
76
77                let result = veclit_to_binlit((v0.component_div(&v1)).as_slice());
78                Some(result)
79            } else {
80                None
81            };
82            Ok(ScalarValue::BinaryView(result))
83        };
84
85        let calculator = VectorCalculator {
86            name: self.name(),
87            func: body,
88        };
89        calculator.invoke_with_vectors(args)
90    }
91}
92
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion::arrow::array::{Array, AsArray, StringViewArray};
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Builds the arguments for a `vec_div` call over two string columns of
    /// vector literals; `number_rows` must equal the columns' row count.
    fn div_args(
        lhs: Arc<StringViewArray>,
        rhs: Arc<StringViewArray>,
        number_rows: usize,
    ) -> ScalarFunctionArgs {
        ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(lhs), ColumnarValue::Array(rhs)],
            arg_fields: vec![],
            number_rows,
            return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
            config_options: Arc::new(ConfigOptions::new()),
        }
    }

    #[test]
    fn test_vector_div() {
        let func = VectorDivFunction::default();

        // Vectors of different lengths are rejected with an execution error.
        let vec0 = vec![1.0, 2.0, 3.0];
        let vec1 = vec![1.0, 1.0];
        let input0 = Arc::new(StringViewArray::from(vec![Some(format!("{vec0:?}"))]));
        let input1 = Arc::new(StringViewArray::from(vec![Some(format!("{vec1:?}"))]));

        let e = func
            .invoke_with_args(div_args(input0, input1, 1))
            .unwrap_err();
        assert_eq!(
            e.to_string(),
            "Execution error: vectors length not match: vec_div"
        );

        // Element-wise division; a NULL on either side yields NULL.
        let input0 = Arc::new(StringViewArray::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[8.0,10.0,12.0]".to_string()),
            Some("[7.0,8.0,9.0]".to_string()),
            None,
        ]));

        let input1 = Arc::new(StringViewArray::from(vec![
            Some("[1.0,1.0,1.0]".to_string()),
            Some("[2.0,2.0,2.0]".to_string()),
            None,
            Some("[3.0,3.0,3.0]".to_string()),
        ]));

        let result = func
            .invoke_with_args(div_args(input0, input1, 4))
            .and_then(|x| x.to_array(4))
            .unwrap();

        let result = result.as_binary_view();
        assert_eq!(result.len(), 4);
        assert_eq!(
            result.value(0),
            veclit_to_binlit(&[1.0, 2.0, 3.0]).as_slice()
        );
        assert_eq!(
            result.value(1),
            veclit_to_binlit(&[4.0, 5.0, 6.0]).as_slice()
        );
        assert!(result.is_null(2));
        assert!(result.is_null(3));

        // Division by zero follows IEEE 754: x / 0.0 is +inf or -inf
        // depending on the sign of x.
        let input0 = Arc::new(StringViewArray::from(vec![Some("[1.0,-2.0]".to_string())]));
        let input1 = Arc::new(StringViewArray::from(vec![Some("[0.0,0.0]".to_string())]));

        let result = func
            .invoke_with_args(div_args(input0, input1, 1))
            .and_then(|x| x.to_array(1))
            .unwrap();

        let result = result.as_binary_view();
        assert_eq!(result.len(), 1);
        assert_eq!(
            result.value(0),
            veclit_to_binlit(&[f32::INFINITY, f32::NEG_INFINITY]).as_slice()
        );
    }
}