common_function/scalars/vector/
vector_dim.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Display;
16
17use datafusion::arrow::datatypes::DataType;
18use datafusion::logical_expr::ColumnarValue;
19use datafusion::logical_expr_common::type_coercion::aggregates::{BINARYS, STRINGS};
20use datafusion_common::ScalarValue;
21use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
22
23use crate::function::Function;
24use crate::scalars::vector::VectorCalculator;
25use crate::scalars::vector::impl_conv::as_veclit;
26
27const NAME: &str = "vec_dim";
28
29/// Returns the dimension of the vector.
30///
31/// # Example
32///
33/// ```sql
34/// SELECT vec_dim('[7.0, 8.0, 9.0, 10.0]');
35///
36/// +---------------------------------------------------------------+
37/// | vec_dim(Utf8("[7.0, 8.0, 9.0, 10.0]"))                        |
38/// +---------------------------------------------------------------+
39/// | 4                                                             |
40/// +---------------------------------------------------------------+
41///
42#[derive(Debug, Clone)]
43pub(crate) struct VectorDimFunction {
44    signature: Signature,
45}
46
47impl Default for VectorDimFunction {
48    fn default() -> Self {
49        Self {
50            signature: Signature::one_of(
51                vec![
52                    TypeSignature::Uniform(1, STRINGS.to_vec()),
53                    TypeSignature::Uniform(1, BINARYS.to_vec()),
54                ],
55                Volatility::Immutable,
56            ),
57        }
58    }
59}
60
61impl Function for VectorDimFunction {
62    fn name(&self) -> &str {
63        NAME
64    }
65
66    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
67        Ok(DataType::UInt64)
68    }
69
70    fn signature(&self) -> &Signature {
71        &self.signature
72    }
73
74    fn invoke_with_args(
75        &self,
76        args: ScalarFunctionArgs,
77    ) -> datafusion_common::Result<ColumnarValue> {
78        let body = |v0: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
79            let v = as_veclit(v0)?.map(|v0| v0.len() as u64);
80            Ok(ScalarValue::UInt64(v))
81        };
82
83        let calculator = VectorCalculator {
84            name: self.name(),
85            func: body,
86        };
87        calculator.invoke_with_single_argument(args)
88    }
89}
90
91impl Display for VectorDimFunction {
92    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93        write!(f, "{}", NAME.to_ascii_uppercase())
94    }
95}
96
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion::arrow::array::{Array, AsArray, StringViewArray};
    use datafusion::arrow::datatypes::UInt64Type;
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Row count passed to the function and used to materialize the result.
    const ROWS: usize = 4;

    /// Wraps optional string rows into a string-view array column.
    fn string_view_column(rows: &[Option<&str>]) -> ColumnarValue {
        let owned: Vec<Option<String>> = rows
            .iter()
            .map(|row| row.map(|text| text.to_string()))
            .collect();
        ColumnarValue::Array(Arc::new(StringViewArray::from(owned)))
    }

    /// Invokes a default-constructed `vec_dim` on the given argument columns.
    fn call_vec_dim(columns: Vec<ColumnarValue>) -> datafusion_common::Result<ColumnarValue> {
        let args = ScalarFunctionArgs {
            args: columns,
            arg_fields: vec![],
            number_rows: ROWS,
            return_field: Arc::new(Field::new("x", DataType::UInt64, false)),
            config_options: Arc::new(ConfigOptions::new()),
        };
        VectorDimFunction::default().invoke_with_args(args)
    }

    #[test]
    fn test_vec_dim() {
        let column = string_view_column(&[
            Some("[0.0,2.0,3.0]"),
            Some("[1.0,2.0,3.0,4.0]"),
            None,
            Some("[5.0]"),
        ]);

        let output = call_vec_dim(vec![column])
            .and_then(|value| value.to_array(ROWS))
            .unwrap();
        let dims = output.as_primitive::<UInt64Type>();

        // `None` marks a row that must come back as NULL.
        let expected: [Option<u64>; 4] = [Some(3), Some(4), None, Some(1)];
        assert_eq!(dims.len(), expected.len());
        for (row, want) in expected.iter().enumerate() {
            match want {
                Some(dim) => assert_eq!(dims.value(row), *dim),
                None => assert!(dims.is_null(row)),
            }
        }
    }

    #[test]
    fn test_dim_error() {
        let lhs = string_view_column(&[
            Some("[1.0,2.0,3.0]"),
            Some("[4.0,5.0,6.0]"),
            None,
            Some("[2.0,3.0,3.0]"),
        ]);
        let rhs = string_view_column(&[
            Some("[1.0,1.0,1.0]"),
            Some("[6.0,5.0,4.0]"),
            Some("[3.0,2.0,2.0]"),
        ]);

        // Two arguments are passed to a one-argument function.
        let message = call_vec_dim(vec![lhs, rhs]).unwrap_err().to_string();
        assert!(
            message.starts_with("Execution error: vec_dim function requires 1 argument, got 2")
        );
    }
}