common_function/scalars/vector/
vector_dim.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Display;
16
17use datafusion::arrow::datatypes::DataType;
18use datafusion::logical_expr::{Coercion, ColumnarValue, TypeSignatureClass};
19use datafusion_common::ScalarValue;
20use datafusion_common::types::{logical_binary, logical_string};
21use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
22
23use crate::function::Function;
24use crate::scalars::vector::VectorCalculator;
25use crate::scalars::vector::impl_conv::as_veclit;
26
/// Name of this scalar function: returned verbatim by `Function::name` and
/// upper-cased by the `Display` implementation.
const NAME: &str = "vec_dim";
28
/// Returns the dimension of the vector.
///
/// The result is a `UInt64` holding the number of elements of the vector
/// literal; a NULL input produces a NULL result.
///
/// # Example
///
/// ```sql
/// SELECT vec_dim('[7.0, 8.0, 9.0, 10.0]');
///
/// +---------------------------------------------------------------+
/// | vec_dim(Utf8("[7.0, 8.0, 9.0, 10.0]"))                        |
/// +---------------------------------------------------------------+
/// | 4                                                             |
/// +---------------------------------------------------------------+
/// ```
#[derive(Debug, Clone)]
pub(crate) struct VectorDimFunction {
    /// Signature handed out by `Function::signature`; built in `Default::default`.
    signature: Signature,
}
46
47impl Default for VectorDimFunction {
48    fn default() -> Self {
49        Self {
50            signature: Signature::one_of(
51                vec![
52                    TypeSignature::Coercible(vec![Coercion::new_exact(
53                        TypeSignatureClass::Native(logical_binary()),
54                    )]),
55                    TypeSignature::Coercible(vec![Coercion::new_exact(
56                        TypeSignatureClass::Native(logical_string()),
57                    )]),
58                ],
59                Volatility::Immutable,
60            ),
61        }
62    }
63}
64
65impl Function for VectorDimFunction {
66    fn name(&self) -> &str {
67        NAME
68    }
69
70    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
71        Ok(DataType::UInt64)
72    }
73
74    fn signature(&self) -> &Signature {
75        &self.signature
76    }
77
78    fn invoke_with_args(
79        &self,
80        args: ScalarFunctionArgs,
81    ) -> datafusion_common::Result<ColumnarValue> {
82        let body = |v0: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
83            let v = as_veclit(v0)?.map(|v0| v0.len() as u64);
84            Ok(ScalarValue::UInt64(v))
85        };
86
87        let calculator = VectorCalculator {
88            name: self.name(),
89            func: body,
90        };
91        calculator.invoke_with_single_argument(args)
92    }
93}
94
95impl Display for VectorDimFunction {
96    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97        write!(f, "{}", NAME.to_ascii_uppercase())
98    }
99}
100
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion::arrow::array::{Array, AsArray, StringViewArray};
    use datafusion::arrow::datatypes::UInt64Type;
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Wraps every column of optional vector literals in a `StringViewArray`
    /// and packs the columns into `ScalarFunctionArgs` declaring four rows.
    fn string_view_args(columns: Vec<Vec<Option<&str>>>) -> ScalarFunctionArgs {
        let args = columns
            .into_iter()
            .map(|column| ColumnarValue::Array(Arc::new(StringViewArray::from(column))))
            .collect();

        ScalarFunctionArgs {
            args,
            arg_fields: vec![],
            number_rows: 4,
            return_field: Arc::new(Field::new("x", DataType::UInt64, false)),
            config_options: Arc::new(ConfigOptions::new()),
        }
    }

    /// A single string column yields one dimension per row, with the null
    /// row staying null.
    #[test]
    fn test_vec_dim() {
        let args = string_view_args(vec![vec![
            Some("[0.0,2.0,3.0]"),
            Some("[1.0,2.0,3.0,4.0]"),
            None,
            Some("[5.0]"),
        ]]);

        let output = VectorDimFunction::default()
            .invoke_with_args(args)
            .and_then(|value| value.to_array(4))
            .unwrap();
        let dims = output.as_primitive::<UInt64Type>();

        assert_eq!(dims.len(), 4);
        // `None` marks the row that must come back as null.
        let expected: Vec<Option<u64>> = vec![Some(3), Some(4), None, Some(1)];
        for (row, want) in expected.into_iter().enumerate() {
            match want {
                Some(dim) => assert_eq!(dims.value(row), dim),
                None => assert!(dims.is_null(row)),
            }
        }
    }

    /// Passing two columns to the one-argument function must be rejected.
    #[test]
    fn test_dim_error() {
        let args = string_view_args(vec![
            vec![
                Some("[1.0,2.0,3.0]"),
                Some("[4.0,5.0,6.0]"),
                None,
                Some("[2.0,3.0,3.0]"),
            ],
            vec![
                Some("[1.0,1.0,1.0]"),
                Some("[6.0,5.0,4.0]"),
                Some("[3.0,2.0,2.0]"),
            ],
        ]);

        let message = VectorDimFunction::default()
            .invoke_with_args(args)
            .unwrap_err()
            .to_string();
        assert!(
            message.starts_with("Execution error: vec_dim function requires 1 argument, got 2")
        );
    }
}
172}