common_function/scalars/vector.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15mod convert;
16mod distance;
17mod elem_avg;
18mod elem_product;
19mod elem_sum;
20pub mod impl_conv;
21mod scalar_add;
22mod scalar_mul;
23mod vector_add;
24mod vector_dim;
25mod vector_div;
26mod vector_kth_elem;
27mod vector_mul;
28mod vector_norm;
29mod vector_sub;
30mod vector_subvector;
31
32use std::borrow::Cow;
33
34use datafusion_common::{DataFusionError, Result, ScalarValue, utils};
35use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
36
37use crate::function_registry::FunctionRegistry;
38use crate::scalars::vector::impl_conv::as_veclit;
39
40pub(crate) struct VectorFunction;
41
42impl VectorFunction {
43    pub fn register(registry: &FunctionRegistry) {
44        // conversion
45        registry.register_scalar(convert::ParseVectorFunction::default());
46        registry.register_scalar(convert::VectorToStringFunction::default());
47
48        // distance
49        registry.register_scalar(distance::CosDistanceFunction::default());
50        registry.register_scalar(distance::DotProductFunction::default());
51        registry.register_scalar(distance::L2SqDistanceFunction::default());
52
53        // scalar calculation
54        registry.register_scalar(scalar_add::ScalarAddFunction::default());
55        registry.register_scalar(scalar_mul::ScalarMulFunction::default());
56
57        // vector calculation
58        registry.register_scalar(vector_add::VectorAddFunction::default());
59        registry.register_scalar(vector_sub::VectorSubFunction::default());
60        registry.register_scalar(vector_mul::VectorMulFunction::default());
61        registry.register_scalar(vector_div::VectorDivFunction::default());
62        registry.register_scalar(vector_norm::VectorNormFunction::default());
63        registry.register_scalar(vector_dim::VectorDimFunction::default());
64        registry.register_scalar(vector_kth_elem::VectorKthElemFunction::default());
65        registry.register_scalar(vector_subvector::VectorSubvectorFunction::default());
66        registry.register_scalar(elem_sum::ElemSumFunction::default());
67        registry.register_scalar(elem_product::ElemProductFunction::default());
68        registry.register_scalar(elem_avg::ElemAvgFunction::default());
69    }
70}
71
// Evaluates to a `&ScalarValue` for row `$row` of `$value`:
// - `ColumnarValue::Scalar` yields the scalar itself (the same value for every row).
// - `ColumnarValue::Array` yields the element extracted with
//   `ScalarValue::try_from_array`, with extraction errors propagated via `?`.
//
// This is a macro rather than a function so that the `ColumnarValue::Array` arm can
// "return" a reference to the freshly extracted `ScalarValue`.
macro_rules! try_get_scalar_value {
    ($value: ident, $row: ident) => {
        match $value {
            datafusion::logical_expr::ColumnarValue::Scalar(scalar) => scalar,
            datafusion::logical_expr::ColumnarValue::Array(array) => {
                &datafusion_common::ScalarValue::try_from_array(array.as_ref(), $row)?
            }
        }
    };
}
84
85pub(crate) fn ensure_same_length(values: &[&ColumnarValue]) -> Result<usize> {
86    if values.is_empty() {
87        return Ok(0);
88    }
89
90    let mut array_len = None;
91    for v in values {
92        array_len = match (v, array_len) {
93            (ColumnarValue::Array(a), None) => Some(a.len()),
94            (ColumnarValue::Array(a), Some(array_len)) => {
95                if array_len == a.len() {
96                    Some(array_len)
97                } else {
98                    return Err(DataFusionError::Internal(format!(
99                        "Arguments has mixed length. Expected length: {array_len}, found length: {}",
100                        a.len()
101                    )));
102                }
103            }
104            (ColumnarValue::Scalar(_), array_len) => array_len,
105        }
106    }
107
108    // If array_len is none, it means there are only scalars, treat them each as 1 element array.
109    let array_len = array_len.unwrap_or(1);
110    Ok(array_len)
111}
112
113struct VectorCalculator<'a, F> {
114    name: &'a str,
115    func: F,
116}
117
118impl<F> VectorCalculator<'_, F>
119where
120    F: Fn(&ScalarValue, &ScalarValue) -> Result<ScalarValue>,
121{
122    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
123        let [arg0, arg1] = utils::take_function_args(self.name, &args.args)?;
124
125        if let (ColumnarValue::Scalar(v0), ColumnarValue::Scalar(v1)) = (arg0, arg1) {
126            let result = (self.func)(v0, v1)?;
127            return Ok(ColumnarValue::Scalar(result));
128        }
129
130        let len = ensure_same_length(&[arg0, arg1])?;
131        let mut results = Vec::with_capacity(len);
132        for i in 0..len {
133            let v0 = try_get_scalar_value!(arg0, i);
134            let v1 = try_get_scalar_value!(arg1, i);
135            results.push((self.func)(v0, v1)?);
136        }
137
138        let results = ScalarValue::iter_to_array(results.into_iter())?;
139        Ok(ColumnarValue::Array(results))
140    }
141}
142
143impl<F> VectorCalculator<'_, F>
144where
145    F: Fn(&Option<Cow<[f32]>>, &Option<Cow<[f32]>>) -> Result<ScalarValue>,
146{
147    fn invoke_with_vectors(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
148        let [arg0, arg1] = utils::take_function_args(self.name, &args.args)?;
149
150        if let (ColumnarValue::Scalar(v0), ColumnarValue::Scalar(v1)) = (arg0, arg1) {
151            let v0 = as_veclit(v0)?;
152            let v1 = as_veclit(v1)?;
153            let result = (self.func)(&v0, &v1)?;
154            return Ok(ColumnarValue::Scalar(result));
155        }
156
157        let len = ensure_same_length(&[arg0, arg1])?;
158        let mut results = Vec::with_capacity(len);
159
160        match (arg0, arg1) {
161            (ColumnarValue::Scalar(v0), ColumnarValue::Array(a1)) => {
162                let v0 = as_veclit(v0)?;
163                for i in 0..len {
164                    let v1 = ScalarValue::try_from_array(a1, i)?;
165                    let v1 = as_veclit(&v1)?;
166                    results.push((self.func)(&v0, &v1)?);
167                }
168            }
169            (ColumnarValue::Array(a0), ColumnarValue::Scalar(v1)) => {
170                let v1 = as_veclit(v1)?;
171                for i in 0..len {
172                    let v0 = ScalarValue::try_from_array(a0, i)?;
173                    let v0 = as_veclit(&v0)?;
174                    results.push((self.func)(&v0, &v1)?);
175                }
176            }
177            (ColumnarValue::Array(a0), ColumnarValue::Array(a1)) => {
178                for i in 0..len {
179                    let v0 = ScalarValue::try_from_array(a0, i)?;
180                    let v0 = as_veclit(&v0)?;
181                    let v1 = ScalarValue::try_from_array(a1, i)?;
182                    let v1 = as_veclit(&v1)?;
183                    results.push((self.func)(&v0, &v1)?);
184                }
185            }
186            (ColumnarValue::Scalar(_), ColumnarValue::Scalar(_)) => {
187                // unreachable because this arm has been separately dealt with above
188                unreachable!()
189            }
190        }
191
192        let results = ScalarValue::iter_to_array(results.into_iter())?;
193        Ok(ColumnarValue::Array(results))
194    }
195}
196
197impl<F> VectorCalculator<'_, F>
198where
199    F: Fn(&ScalarValue) -> Result<ScalarValue>,
200{
201    fn invoke_with_single_argument(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
202        let [arg0] = utils::take_function_args(self.name, &args.args)?;
203
204        let arg0 = match arg0 {
205            ColumnarValue::Scalar(v) => {
206                let result = (self.func)(v)?;
207                return Ok(ColumnarValue::Scalar(result));
208            }
209            ColumnarValue::Array(a) => a,
210        };
211
212        let len = arg0.len();
213        let mut results = Vec::with_capacity(len);
214        for i in 0..len {
215            let v = ScalarValue::try_from_array(arg0, i)?;
216            results.push((self.func)(&v)?);
217        }
218
219        let results = ScalarValue::iter_to_array(results.into_iter())?;
220        Ok(ColumnarValue::Array(results))
221    }
222}
223
/// Declares a `pub(crate)` scalar-UDF struct whose signature takes two vector
/// literals.
///
/// Each argument's accepted types are `Utf8`, `Utf8View`, `Binary` and `BinaryView`.
/// Both lists are handed to `crate::helper::one_of_sigs2` to build the signature. The
/// generated struct derives `Debug` and `Clone` and implements `Default`. It also gets a
/// `Display` impl that prints `self.name()` in ASCII upper case, so the invoking site
/// must provide a `name()` method for the type (presumably through its `ScalarUDFImpl`
/// impl — confirm at the call site).
macro_rules! define_args_of_two_vector_literals_udf {
    ($(#[$meta:meta])* $udf: ident) => {
        $(#[$meta])*
        #[derive(Debug, Clone)]
        pub(crate) struct $udf {
            signature: datafusion_expr::Signature,
        }

        impl Default for $udf {
            fn default() -> Self {
                use arrow::datatypes::DataType;

                // Types accepted for each of the two vector-literal arguments.
                let vector_literal_types = || {
                    vec![
                        DataType::Utf8,
                        DataType::Utf8View,
                        DataType::Binary,
                        DataType::BinaryView,
                    ]
                };

                Self {
                    signature: crate::helper::one_of_sigs2(
                        vector_literal_types(),
                        vector_literal_types(),
                    ),
                }
            }
        }

        impl std::fmt::Display for $udf {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                let upper_name = self.name().to_ascii_uppercase();
                f.write_str(&upper_name)
            }
        }
    };
}

pub(crate) use define_args_of_two_vector_literals_udf;