common_function/scalars/
vector.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15mod convert;
16mod distance;
17mod elem_product;
18mod elem_sum;
19pub mod impl_conv;
20mod scalar_add;
21mod scalar_mul;
22mod vector_add;
23mod vector_dim;
24mod vector_div;
25mod vector_kth_elem;
26mod vector_mul;
27mod vector_norm;
28mod vector_sub;
29mod vector_subvector;
30
31use std::borrow::Cow;
32
33use datafusion_common::{DataFusionError, Result, ScalarValue, utils};
34use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
35
36use crate::function_registry::FunctionRegistry;
37use crate::scalars::vector::impl_conv::as_veclit;
38
39pub(crate) struct VectorFunction;
40
41impl VectorFunction {
42    pub fn register(registry: &FunctionRegistry) {
43        // conversion
44        registry.register_scalar(convert::ParseVectorFunction::default());
45        registry.register_scalar(convert::VectorToStringFunction::default());
46
47        // distance
48        registry.register_scalar(distance::CosDistanceFunction::default());
49        registry.register_scalar(distance::DotProductFunction::default());
50        registry.register_scalar(distance::L2SqDistanceFunction::default());
51
52        // scalar calculation
53        registry.register_scalar(scalar_add::ScalarAddFunction::default());
54        registry.register_scalar(scalar_mul::ScalarMulFunction::default());
55
56        // vector calculation
57        registry.register_scalar(vector_add::VectorAddFunction::default());
58        registry.register_scalar(vector_sub::VectorSubFunction::default());
59        registry.register_scalar(vector_mul::VectorMulFunction::default());
60        registry.register_scalar(vector_div::VectorDivFunction::default());
61        registry.register_scalar(vector_norm::VectorNormFunction::default());
62        registry.register_scalar(vector_dim::VectorDimFunction::default());
63        registry.register_scalar(vector_kth_elem::VectorKthElemFunction::default());
64        registry.register_scalar(vector_subvector::VectorSubvectorFunction::default());
65        registry.register_scalar(elem_sum::ElemSumFunction::default());
66        registry.register_scalar(elem_product::ElemProductFunction::default());
67    }
68}
69
// Evaluates to a `&ScalarValue` for row `$i` of the columnar value bound to `$col`:
// - a `Scalar` yields the borrowed scalar itself, whatever `$i` is;
// - an `Array` yields a reference to a freshly built `ScalarValue` for element `$i`,
//   bailing out of the enclosing function via `?` if the conversion fails.
//
// This is a macro rather than a function so that the `Array` arm can "return" a
// reference to the `ScalarValue` it has just created.
macro_rules! try_get_scalar_value {
    ($col: ident, $i: ident) => {
        match $col {
            datafusion::logical_expr::ColumnarValue::Scalar(scalar) => scalar,
            datafusion::logical_expr::ColumnarValue::Array(array) => {
                &datafusion_common::ScalarValue::try_from_array(array.as_ref(), $i)?
            }
        }
    };
}
82
83pub(crate) fn ensure_same_length(values: &[&ColumnarValue]) -> Result<usize> {
84    if values.is_empty() {
85        return Ok(0);
86    }
87
88    let mut array_len = None;
89    for v in values {
90        array_len = match (v, array_len) {
91            (ColumnarValue::Array(a), None) => Some(a.len()),
92            (ColumnarValue::Array(a), Some(array_len)) => {
93                if array_len == a.len() {
94                    Some(array_len)
95                } else {
96                    return Err(DataFusionError::Internal(format!(
97                        "Arguments has mixed length. Expected length: {array_len}, found length: {}",
98                        a.len()
99                    )));
100                }
101            }
102            (ColumnarValue::Scalar(_), array_len) => array_len,
103        }
104    }
105
106    // If array_len is none, it means there are only scalars, treat them each as 1 element array.
107    let array_len = array_len.unwrap_or(1);
108    Ok(array_len)
109}
110
/// Pairs a function name with the closure that computes a single result, and
/// provides `invoke_*` helpers that apply the closure to scalar or array arguments.
struct VectorCalculator<'name, F> {
    /// Name of the function; handed to `utils::take_function_args` when the
    /// arguments are unpacked.
    name: &'name str,
    /// Computation applied to each row (or once, when all arguments are scalars).
    func: F,
}
115
116impl<F> VectorCalculator<'_, F>
117where
118    F: Fn(&ScalarValue, &ScalarValue) -> Result<ScalarValue>,
119{
120    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
121        let [arg0, arg1] = utils::take_function_args(self.name, &args.args)?;
122
123        if let (ColumnarValue::Scalar(v0), ColumnarValue::Scalar(v1)) = (arg0, arg1) {
124            let result = (self.func)(v0, v1)?;
125            return Ok(ColumnarValue::Scalar(result));
126        }
127
128        let len = ensure_same_length(&[arg0, arg1])?;
129        let mut results = Vec::with_capacity(len);
130        for i in 0..len {
131            let v0 = try_get_scalar_value!(arg0, i);
132            let v1 = try_get_scalar_value!(arg1, i);
133            results.push((self.func)(v0, v1)?);
134        }
135
136        let results = ScalarValue::iter_to_array(results.into_iter())?;
137        Ok(ColumnarValue::Array(results))
138    }
139}
140
141impl<F> VectorCalculator<'_, F>
142where
143    F: Fn(&Option<Cow<[f32]>>, &Option<Cow<[f32]>>) -> Result<ScalarValue>,
144{
145    fn invoke_with_vectors(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
146        let [arg0, arg1] = utils::take_function_args(self.name, &args.args)?;
147
148        if let (ColumnarValue::Scalar(v0), ColumnarValue::Scalar(v1)) = (arg0, arg1) {
149            let v0 = as_veclit(v0)?;
150            let v1 = as_veclit(v1)?;
151            let result = (self.func)(&v0, &v1)?;
152            return Ok(ColumnarValue::Scalar(result));
153        }
154
155        let len = ensure_same_length(&[arg0, arg1])?;
156        let mut results = Vec::with_capacity(len);
157
158        match (arg0, arg1) {
159            (ColumnarValue::Scalar(v0), ColumnarValue::Array(a1)) => {
160                let v0 = as_veclit(v0)?;
161                for i in 0..len {
162                    let v1 = ScalarValue::try_from_array(a1, i)?;
163                    let v1 = as_veclit(&v1)?;
164                    results.push((self.func)(&v0, &v1)?);
165                }
166            }
167            (ColumnarValue::Array(a0), ColumnarValue::Scalar(v1)) => {
168                let v1 = as_veclit(v1)?;
169                for i in 0..len {
170                    let v0 = ScalarValue::try_from_array(a0, i)?;
171                    let v0 = as_veclit(&v0)?;
172                    results.push((self.func)(&v0, &v1)?);
173                }
174            }
175            (ColumnarValue::Array(a0), ColumnarValue::Array(a1)) => {
176                for i in 0..len {
177                    let v0 = ScalarValue::try_from_array(a0, i)?;
178                    let v0 = as_veclit(&v0)?;
179                    let v1 = ScalarValue::try_from_array(a1, i)?;
180                    let v1 = as_veclit(&v1)?;
181                    results.push((self.func)(&v0, &v1)?);
182                }
183            }
184            (ColumnarValue::Scalar(_), ColumnarValue::Scalar(_)) => {
185                // unreachable because this arm has been separately dealt with above
186                unreachable!()
187            }
188        }
189
190        let results = ScalarValue::iter_to_array(results.into_iter())?;
191        Ok(ColumnarValue::Array(results))
192    }
193}
194
195impl<F> VectorCalculator<'_, F>
196where
197    F: Fn(&ScalarValue) -> Result<ScalarValue>,
198{
199    fn invoke_with_single_argument(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
200        let [arg0] = utils::take_function_args(self.name, &args.args)?;
201
202        let arg0 = match arg0 {
203            ColumnarValue::Scalar(v) => {
204                let result = (self.func)(v)?;
205                return Ok(ColumnarValue::Scalar(result));
206            }
207            ColumnarValue::Array(a) => a,
208        };
209
210        let len = arg0.len();
211        let mut results = Vec::with_capacity(len);
212        for i in 0..len {
213            let v = ScalarValue::try_from_array(arg0, i)?;
214            results.push((self.func)(&v)?);
215        }
216
217        let results = ScalarValue::iter_to_array(results.into_iter())?;
218        Ok(ColumnarValue::Array(results))
219    }
220}
221
// Generates the boilerplate for a UDF that takes two vector-literal arguments:
// - a `pub(crate)` struct `$name` (with any attributes passed to the macro) that
//   holds the function signature;
// - a `Default` impl whose signature is built by `crate::helper::one_of_sigs2`
//   from the same list of string/binary types for both arguments;
// - a `Display` impl that prints the upper-cased result of `self.name()`.
//
// `name()` is not generated here; it has to be provided for `$name` elsewhere.
macro_rules! define_args_of_two_vector_literals_udf {
    ($(#[$attr:meta])* $name: ident) => {
        $(#[$attr])*
        #[derive(Debug, Clone)]
        pub(crate) struct $name {
            signature: datafusion_expr::Signature,
        }

        impl Default for $name {
            fn default() -> Self {
                use arrow::datatypes::DataType;

                // Both arguments accept exactly the same set of types.
                let vector_literal_types = vec![
                    DataType::Utf8,
                    DataType::Utf8View,
                    DataType::Binary,
                    DataType::BinaryView,
                ];
                let signature = crate::helper::one_of_sigs2(
                    vector_literal_types.clone(),
                    vector_literal_types,
                );

                Self { signature }
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.write_str(&self.name().to_ascii_uppercase())
            }
        }
    };
}
260
261pub(crate) use define_args_of_two_vector_literals_udf;