common_function/scalars/
vector.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15mod convert;
16pub mod distance;
17mod elem_avg;
18mod elem_product;
19mod elem_sum;
20pub mod impl_conv;
21mod scalar_add;
22mod scalar_mul;
23mod vector_add;
24mod vector_dim;
25mod vector_div;
26mod vector_kth_elem;
27mod vector_mul;
28mod vector_norm;
29mod vector_sub;
30mod vector_subvector;
31
32use std::borrow::Cow;
33
34use datafusion_common::{DataFusionError, Result, ScalarValue, utils};
35use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
36use datatypes::arrow::array::new_empty_array;
37
38use crate::function_registry::FunctionRegistry;
39use crate::scalars::vector::impl_conv::as_veclit;
40
41pub(crate) struct VectorFunction;
42
43impl VectorFunction {
44    pub fn register(registry: &FunctionRegistry) {
45        // conversion
46        registry.register_scalar(convert::ParseVectorFunction::default());
47        registry.register_scalar(convert::VectorToStringFunction::default());
48
49        // distance
50        registry.register_scalar(distance::CosDistanceFunction::default());
51        registry.register_scalar(distance::DotProductFunction::default());
52        registry.register_scalar(distance::L2SqDistanceFunction::default());
53
54        // scalar calculation
55        registry.register_scalar(scalar_add::ScalarAddFunction::default());
56        registry.register_scalar(scalar_mul::ScalarMulFunction::default());
57
58        // vector calculation
59        registry.register_scalar(vector_add::VectorAddFunction::default());
60        registry.register_scalar(vector_sub::VectorSubFunction::default());
61        registry.register_scalar(vector_mul::VectorMulFunction::default());
62        registry.register_scalar(vector_div::VectorDivFunction::default());
63        registry.register_scalar(vector_norm::VectorNormFunction::default());
64        registry.register_scalar(vector_dim::VectorDimFunction::default());
65        registry.register_scalar(vector_kth_elem::VectorKthElemFunction::default());
66        registry.register_scalar(vector_subvector::VectorSubvectorFunction::default());
67        registry.register_scalar(elem_sum::ElemSumFunction::default());
68        registry.register_scalar(elem_product::ElemProductFunction::default());
69        registry.register_scalar(elem_avg::ElemAvgFunction::default());
70    }
71}
72
// Use macro instead of function to "return" the reference to `ScalarValue` in the
// `ColumnarValue::Array` match arm.
//
// Evaluates to a reference to the `ScalarValue` of `$col` at row `$i`:
// - `Array` arm: builds a fresh `ScalarValue` from element `$i` with
//   `ScalarValue::try_from_array` and borrows that temporary. The `?` propagates a
//   conversion failure, so the macro can only be expanded inside a function (or
//   closure) whose return type accepts that error.
// - `Scalar` arm: yields the wrapped value unchanged; `$i` is not used.
//
// Both arguments must be plain identifiers. `$col` is expected to be a reference to
// a `ColumnarValue` so that both arms produce a `&ScalarValue`.
macro_rules! try_get_scalar_value {
    ($col: ident, $i: ident) => {
        match $col {
            datafusion::logical_expr::ColumnarValue::Array(a) => {
                // Borrow of a temporary: relies on being used as the initializer of a
                // `let` binding at the expansion site so the value stays alive.
                &datafusion_common::ScalarValue::try_from_array(a.as_ref(), $i)?
            }
            datafusion::logical_expr::ColumnarValue::Scalar(v) => v,
        }
    };
}
85
86pub(crate) fn ensure_same_length(values: &[&ColumnarValue]) -> Result<usize> {
87    if values.is_empty() {
88        return Ok(0);
89    }
90
91    let mut array_len = None;
92    for v in values {
93        array_len = match (v, array_len) {
94            (ColumnarValue::Array(a), None) => Some(a.len()),
95            (ColumnarValue::Array(a), Some(array_len)) => {
96                if array_len == a.len() {
97                    Some(array_len)
98                } else {
99                    return Err(DataFusionError::Internal(format!(
100                        "Arguments has mixed length. Expected length: {array_len}, found length: {}",
101                        a.len()
102                    )));
103                }
104            }
105            (ColumnarValue::Scalar(_), array_len) => array_len,
106        }
107    }
108
109    // If array_len is none, it means there are only scalars, treat them each as 1 element array.
110    let array_len = array_len.unwrap_or(1);
111    Ok(array_len)
112}
113
114struct VectorCalculator<'a, F> {
115    name: &'a str,
116    func: F,
117}
118
119impl<F> VectorCalculator<'_, F>
120where
121    F: Fn(&ScalarValue, &ScalarValue) -> Result<ScalarValue>,
122{
123    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
124        let [arg0, arg1] = utils::take_function_args(self.name, &args.args)?;
125
126        if let (ColumnarValue::Scalar(v0), ColumnarValue::Scalar(v1)) = (arg0, arg1) {
127            let result = (self.func)(v0, v1)?;
128            return Ok(ColumnarValue::Scalar(result));
129        }
130
131        let len = ensure_same_length(&[arg0, arg1])?;
132        if len == 0 {
133            return Ok(ColumnarValue::Array(new_empty_array(
134                args.return_field.data_type(),
135            )));
136        }
137        let mut results = Vec::with_capacity(len);
138        for i in 0..len {
139            let v0 = try_get_scalar_value!(arg0, i);
140            let v1 = try_get_scalar_value!(arg1, i);
141            results.push((self.func)(v0, v1)?);
142        }
143
144        let results = ScalarValue::iter_to_array(results.into_iter())?;
145        Ok(ColumnarValue::Array(results))
146    }
147}
148
149impl<F> VectorCalculator<'_, F>
150where
151    F: Fn(&Option<Cow<[f32]>>, &Option<Cow<[f32]>>) -> Result<ScalarValue>,
152{
153    fn invoke_with_vectors(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
154        let [arg0, arg1] = utils::take_function_args(self.name, &args.args)?;
155
156        if let (ColumnarValue::Scalar(v0), ColumnarValue::Scalar(v1)) = (arg0, arg1) {
157            let v0 = as_veclit(v0)?;
158            let v1 = as_veclit(v1)?;
159            let result = (self.func)(&v0, &v1)?;
160            return Ok(ColumnarValue::Scalar(result));
161        }
162
163        let len = ensure_same_length(&[arg0, arg1])?;
164        if len == 0 {
165            return Ok(ColumnarValue::Array(new_empty_array(
166                args.return_field.data_type(),
167            )));
168        }
169        let mut results = Vec::with_capacity(len);
170
171        match (arg0, arg1) {
172            (ColumnarValue::Scalar(v0), ColumnarValue::Array(a1)) => {
173                let v0 = as_veclit(v0)?;
174                for i in 0..len {
175                    let v1 = ScalarValue::try_from_array(a1, i)?;
176                    let v1 = as_veclit(&v1)?;
177                    results.push((self.func)(&v0, &v1)?);
178                }
179            }
180            (ColumnarValue::Array(a0), ColumnarValue::Scalar(v1)) => {
181                let v1 = as_veclit(v1)?;
182                for i in 0..len {
183                    let v0 = ScalarValue::try_from_array(a0, i)?;
184                    let v0 = as_veclit(&v0)?;
185                    results.push((self.func)(&v0, &v1)?);
186                }
187            }
188            (ColumnarValue::Array(a0), ColumnarValue::Array(a1)) => {
189                for i in 0..len {
190                    let v0 = ScalarValue::try_from_array(a0, i)?;
191                    let v0 = as_veclit(&v0)?;
192                    let v1 = ScalarValue::try_from_array(a1, i)?;
193                    let v1 = as_veclit(&v1)?;
194                    results.push((self.func)(&v0, &v1)?);
195                }
196            }
197            (ColumnarValue::Scalar(_), ColumnarValue::Scalar(_)) => {
198                // unreachable because this arm has been separately dealt with above
199                unreachable!()
200            }
201        }
202
203        let results = ScalarValue::iter_to_array(results.into_iter())?;
204        Ok(ColumnarValue::Array(results))
205    }
206}
207
208impl<F> VectorCalculator<'_, F>
209where
210    F: Fn(&ScalarValue) -> Result<ScalarValue>,
211{
212    fn invoke_with_single_argument(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
213        let [arg0] = utils::take_function_args(self.name, &args.args)?;
214
215        let arg0 = match arg0 {
216            ColumnarValue::Scalar(v) => {
217                let result = (self.func)(v)?;
218                return Ok(ColumnarValue::Scalar(result));
219            }
220            ColumnarValue::Array(a) => a,
221        };
222
223        let len = arg0.len();
224        if len == 0 {
225            return Ok(ColumnarValue::Array(new_empty_array(
226                args.return_field.data_type(),
227            )));
228        }
229        let mut results = Vec::with_capacity(len);
230        for i in 0..len {
231            let v = ScalarValue::try_from_array(arg0, i)?;
232            results.push((self.func)(&v)?);
233        }
234
235        let results = ScalarValue::iter_to_array(results.into_iter())?;
236        Ok(ColumnarValue::Array(results))
237    }
238}
239
/// Declares a UDF struct named `$name` that takes two vector-literal arguments.
///
/// The expansion contains:
/// - the struct itself (`Debug`, `Clone`, plus any attributes passed before the
///   name), holding a `datafusion_expr::Signature`;
/// - a `Default` impl whose signature comes from `crate::helper::one_of_sigs2`, with
///   `Utf8`, `Utf8View`, `Binary` and `BinaryView` passed for each of the two
///   arguments;
/// - a `Display` impl that prints `self.name()` in ASCII upper case.
///
/// The expansion calls `self.name()`, so a `name` method must be resolvable for
/// `$name` at the expansion site.
macro_rules! define_args_of_two_vector_literals_udf {
    ($(#[$attr:meta])* $name: ident) => {
        $(#[$attr])*
        #[derive(Debug, Clone)]
        pub(crate) struct $name {
            signature: datafusion_expr::Signature,
        }

        impl Default for $name {
            fn default() -> Self {
                use arrow::datatypes::DataType;

                // Both argument positions accept the same set of types.
                let accepted = vec![
                    DataType::Utf8,
                    DataType::Utf8View,
                    DataType::Binary,
                    DataType::BinaryView,
                ];
                let signature = crate::helper::one_of_sigs2(accepted.clone(), accepted);

                Self { signature }
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                let upper = self.name().to_ascii_uppercase();
                write!(f, "{}", upper)
            }
        }
    };
}
278
279pub(crate) use define_args_of_two_vector_literals_udf;