common_function/scalars/vector/
vector_dim.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Display;
16
17use common_query::error::Result;
18use datafusion::arrow::datatypes::DataType;
19use datafusion::logical_expr::ColumnarValue;
20use datafusion::logical_expr_common::type_coercion::aggregates::{BINARYS, STRINGS};
21use datafusion_common::ScalarValue;
22use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
23
24use crate::function::Function;
25use crate::scalars::vector::VectorCalculator;
26use crate::scalars::vector::impl_conv::as_veclit;
27
/// Name of the function, as returned by `Function::name` (displayed upper-cased as `VEC_DIM`).
const NAME: &str = "vec_dim";

/// Returns the dimension (number of elements) of a vector literal.
///
/// Takes exactly one argument of a string or binary type and yields a
/// `UInt64`. A NULL input produces a NULL output.
///
/// # Examples
///
/// ```sql
/// SELECT vec_dim('[7.0, 8.0, 9.0, 10.0]');
///
/// +---------------------------------------------------------------+
/// | vec_dim(Utf8("[7.0, 8.0, 9.0, 10.0]"))                        |
/// +---------------------------------------------------------------+
/// | 4                                                             |
/// +---------------------------------------------------------------+
/// ```
#[derive(Debug, Clone, Default)]
pub struct VectorDimFunction;
45
46impl Function for VectorDimFunction {
47    fn name(&self) -> &str {
48        NAME
49    }
50
51    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
52        Ok(DataType::UInt64)
53    }
54
55    fn signature(&self) -> Signature {
56        Signature::one_of(
57            vec![
58                TypeSignature::Uniform(1, STRINGS.to_vec()),
59                TypeSignature::Uniform(1, BINARYS.to_vec()),
60            ],
61            Volatility::Immutable,
62        )
63    }
64
65    fn invoke_with_args(
66        &self,
67        args: ScalarFunctionArgs,
68    ) -> datafusion_common::Result<ColumnarValue> {
69        let body = |v0: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
70            let v = as_veclit(v0)?.map(|v0| v0.len() as u64);
71            Ok(ScalarValue::UInt64(v))
72        };
73
74        let calculator = VectorCalculator {
75            name: self.name(),
76            func: body,
77        };
78        calculator.invoke_with_single_argument(args)
79    }
80}
81
82impl Display for VectorDimFunction {
83    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84        write!(f, "{}", NAME.to_ascii_uppercase())
85    }
86}
87
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion::arrow::array::{Array, AsArray, StringViewArray};
    use datafusion::arrow::datatypes::UInt64Type;
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Builds the `ScalarFunctionArgs` passed to `vec_dim` in these tests.
    ///
    /// The return field is declared nullable because `vec_dim` yields NULL
    /// for NULL inputs (see `test_vec_dim`).
    fn make_args(args: Vec<ColumnarValue>, number_rows: usize) -> ScalarFunctionArgs {
        ScalarFunctionArgs {
            args,
            arg_fields: vec![],
            number_rows,
            return_field: Arc::new(Field::new("x", DataType::UInt64, true)),
            config_options: Arc::new(ConfigOptions::new()),
        }
    }

    #[test]
    fn test_vec_dim() {
        let func = VectorDimFunction;

        let input0 = Arc::new(StringViewArray::from(vec![
            Some("[0.0,2.0,3.0]".to_string()),
            Some("[1.0,2.0,3.0,4.0]".to_string()),
            None,
            Some("[5.0]".to_string()),
        ]));

        let args = make_args(vec![ColumnarValue::Array(input0)], 4);
        let result = func
            .invoke_with_args(args)
            .and_then(|x| x.to_array(4))
            .unwrap();

        let result = result.as_primitive::<UInt64Type>();
        assert_eq!(result.len(), 4);
        assert_eq!(result.value(0), 3);
        assert_eq!(result.value(1), 4);
        assert!(result.is_null(2));
        assert_eq!(result.value(3), 1);
    }

    #[test]
    fn test_dim_error() {
        let func = VectorDimFunction;

        // Both columns have `number_rows` rows so that the only thing wrong
        // with this call is the argument count.
        let input0 = Arc::new(StringViewArray::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            None,
            Some("[2.0,3.0,3.0]".to_string()),
        ]));
        let input1 = Arc::new(StringViewArray::from(vec![
            Some("[1.0,1.0,1.0]".to_string()),
            Some("[6.0,5.0,4.0]".to_string()),
            Some("[3.0,2.0,2.0]".to_string()),
            None,
        ]));

        let args = make_args(
            vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
            4,
        );
        let err = func.invoke_with_args(args).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.starts_with("Execution error: vec_dim function requires 1 argument, got 2"),
            "unexpected error message: {msg}"
        );
    }
}