common_function/scalars/vector/
vector_div.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::borrow::Cow;
16use std::fmt::Display;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use datafusion_expr::Signature;
20use datatypes::arrow::datatypes::DataType;
21use datatypes::scalars::ScalarVectorBuilder;
22use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
23use nalgebra::DVectorView;
24use snafu::ensure;
25
26use crate::function::{Function, FunctionContext};
27use crate::helper;
28use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
29
/// Name reported by [`Function::name`] for this function.
const NAME: &str = "vec_div";

/// `vec_div(dividend, divisor)`: element-wise quotient of two vectors of the
/// same dimension.
///
/// Either operand may be given as a string (`Utf8`) vector literal or as its
/// binary encoding; the quotient is returned binary-encoded.
///
/// # Example
///
/// ```sql
/// SELECT vec_to_string(vec_div("[2, 4, 6]", "[2, 2, 2]")) as result;
///
/// +---------+
/// | result  |
/// +---------+
/// | [1,2,3] |
/// +---------+
///
/// ```
#[derive(Clone, Debug, Default)]
pub struct VectorDivFunction;
48
49impl Function for VectorDivFunction {
50    fn name(&self) -> &str {
51        NAME
52    }
53
54    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
55        Ok(DataType::Binary)
56    }
57
58    fn signature(&self) -> Signature {
59        helper::one_of_sigs2(
60            vec![DataType::Utf8, DataType::Binary],
61            vec![DataType::Utf8, DataType::Binary],
62        )
63    }
64
65    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
66        ensure!(
67            columns.len() == 2,
68            InvalidFuncArgsSnafu {
69                err_msg: format!(
70                    "The length of the args is not correct, expect exactly two, have: {}",
71                    columns.len()
72                ),
73            }
74        );
75
76        let arg0 = &columns[0];
77        let arg1 = &columns[1];
78
79        let len = arg0.len();
80        let mut result = BinaryVectorBuilder::with_capacity(len);
81        if len == 0 {
82            return Ok(result.to_vector());
83        }
84
85        let arg0_const = as_veclit_if_const(arg0)?;
86        let arg1_const = as_veclit_if_const(arg1)?;
87
88        for i in 0..len {
89            let arg0 = match arg0_const.as_ref() {
90                Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
91                None => as_veclit(arg0.get_ref(i))?,
92            };
93
94            let arg1 = match arg1_const.as_ref() {
95                Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())),
96                None => as_veclit(arg1.get_ref(i))?,
97            };
98
99            if let (Some(arg0), Some(arg1)) = (arg0, arg1) {
100                ensure!(
101                    arg0.len() == arg1.len(),
102                    InvalidFuncArgsSnafu {
103                        err_msg: format!(
104                            "The length of the vectors must match for division, have: {} vs {}",
105                            arg0.len(),
106                            arg1.len()
107                        ),
108                    }
109                );
110                let vec0 = DVectorView::from_slice(&arg0, arg0.len());
111                let vec1 = DVectorView::from_slice(&arg1, arg1.len());
112                let vec_res = vec0.component_div(&vec1);
113
114                let veclit = vec_res.as_slice();
115                let binlit = veclit_to_binlit(veclit);
116                result.push(Some(&binlit));
117            } else {
118                result.push_null();
119            }
120        }
121
122        Ok(result.to_vector())
123    }
124}
125
126impl Display for VectorDivFunction {
127    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
128        write!(f, "{}", NAME.to_ascii_uppercase())
129    }
130}
131
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_query::error;
    use datatypes::vectors::StringVector;

    use super::*;

    /// Vectors of different dimensions in the same row are rejected.
    #[test]
    fn test_vector_div_dimension_mismatch() {
        let func = VectorDivFunction;

        let vec0 = vec![1.0, 2.0, 3.0];
        let vec1 = vec![1.0, 1.0];
        let (len0, len1) = (vec0.len(), vec1.len());
        let input0 = Arc::new(StringVector::from(vec![Some(format!("{vec0:?}"))]));
        let input1 = Arc::new(StringVector::from(vec![Some(format!("{vec1:?}"))]));

        let err = func
            .eval(&FunctionContext::default(), &[input0, input1])
            .unwrap_err();

        match err {
            error::Error::InvalidFuncArgs { err_msg, .. } => {
                assert_eq!(
                    err_msg,
                    format!(
                        "The length of the vectors must match for division, have: {} vs {}",
                        len0, len1
                    )
                )
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }

    /// Element-wise division, with nulls on either side producing null rows.
    #[test]
    fn test_vector_div() {
        let func = VectorDivFunction;

        let input0 = Arc::new(StringVector::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[8.0,10.0,12.0]".to_string()),
            Some("[7.0,8.0,9.0]".to_string()),
            None,
        ]));

        let input1 = Arc::new(StringVector::from(vec![
            Some("[1.0,1.0,1.0]".to_string()),
            Some("[2.0,2.0,2.0]".to_string()),
            None,
            Some("[3.0,3.0,3.0]".to_string()),
        ]));

        let result = func
            .eval(&FunctionContext::default(), &[input0, input1])
            .unwrap();

        let result = result.as_ref();
        assert_eq!(result.len(), 4);
        assert_eq!(
            result.get_ref(0).as_binary().unwrap(),
            Some(veclit_to_binlit(&[1.0, 2.0, 3.0]).as_slice())
        );
        assert_eq!(
            result.get_ref(1).as_binary().unwrap(),
            Some(veclit_to_binlit(&[4.0, 5.0, 6.0]).as_slice())
        );
        assert!(result.get_ref(2).is_null());
        assert!(result.get_ref(3).is_null());
    }

    /// Division by zero follows IEEE 754 and yields signed infinities.
    #[test]
    fn test_vector_div_by_zero() {
        let func = VectorDivFunction;

        let input0 = Arc::new(StringVector::from(vec![Some("[1.0,-2.0]".to_string())]));
        let input1 = Arc::new(StringVector::from(vec![Some("[0.0,0.0]".to_string())]));

        let result = func
            .eval(&FunctionContext::default(), &[input0, input1])
            .unwrap();

        let result = result.as_ref();
        assert_eq!(
            result.get_ref(0).as_binary().unwrap(),
            Some(veclit_to_binlit(&[f32::INFINITY, f32::NEG_INFINITY]).as_slice())
        );
    }
}
212}