common_function/scalars/vector/
vector_dim.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::borrow::Cow;
16use std::fmt::Display;
17
18use common_query::error::InvalidFuncArgsSnafu;
19use common_query::prelude::{Signature, TypeSignature, Volatility};
20use datatypes::prelude::ConcreteDataType;
21use datatypes::scalars::ScalarVectorBuilder;
22use datatypes::vectors::{MutableVector, UInt64VectorBuilder, VectorRef};
23use snafu::ensure;
24
25use crate::function::{Function, FunctionContext};
26use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const};
27
/// Name of the function: returned as-is by `Function::name` and written in
/// upper case by the `Display` implementation of `VectorDimFunction`.
const NAME: &str = "vec_dim";
29
/// Scalar function `vec_dim`: reports how many elements a vector has.
///
/// The function takes a single argument, given either as a string or as a
/// binary value, and produces an unsigned 64-bit integer. A row for which no
/// vector can be obtained (for example a `NULL` input) yields `NULL`.
///
/// # Example
///
/// ```sql
/// SELECT vec_dim('[7.0, 8.0, 9.0, 10.0]');
///
/// +---------------------------------------------------------------+
/// | vec_dim(Utf8("[7.0, 8.0, 9.0, 10.0]"))                        |
/// +---------------------------------------------------------------+
/// | 4                                                             |
/// +---------------------------------------------------------------+
/// ```
#[derive(Clone, Debug, Default)]
pub struct VectorDimFunction;
45
46impl Function for VectorDimFunction {
47    fn name(&self) -> &str {
48        NAME
49    }
50
51    fn return_type(
52        &self,
53        _input_types: &[ConcreteDataType],
54    ) -> common_query::error::Result<ConcreteDataType> {
55        Ok(ConcreteDataType::uint64_datatype())
56    }
57
58    fn signature(&self) -> Signature {
59        Signature::one_of(
60            vec![
61                TypeSignature::Exact(vec![ConcreteDataType::string_datatype()]),
62                TypeSignature::Exact(vec![ConcreteDataType::binary_datatype()]),
63            ],
64            Volatility::Immutable,
65        )
66    }
67
68    fn eval(
69        &self,
70        _func_ctx: &FunctionContext,
71        columns: &[VectorRef],
72    ) -> common_query::error::Result<VectorRef> {
73        ensure!(
74            columns.len() == 1,
75            InvalidFuncArgsSnafu {
76                err_msg: format!(
77                    "The length of the args is not correct, expect exactly one, have: {}",
78                    columns.len()
79                )
80            }
81        );
82        let arg0 = &columns[0];
83
84        let len = arg0.len();
85        let mut result = UInt64VectorBuilder::with_capacity(len);
86        if len == 0 {
87            return Ok(result.to_vector());
88        }
89
90        let arg0_const = as_veclit_if_const(arg0)?;
91
92        for i in 0..len {
93            let arg0 = match arg0_const.as_ref() {
94                Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
95                None => as_veclit(arg0.get_ref(i))?,
96            };
97            let Some(arg0) = arg0 else {
98                result.push_null();
99                continue;
100            };
101            result.push(Some(arg0.len() as u64));
102        }
103
104        Ok(result.to_vector())
105    }
106}
107
108impl Display for VectorDimFunction {
109    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
110        write!(f, "{}", NAME.to_ascii_uppercase())
111    }
112}
113
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_query::error::Error;
    use datatypes::vectors::StringVector;

    use super::*;

    /// Builds a nullable string column from borrowed literals.
    fn string_column(rows: &[Option<&str>]) -> VectorRef {
        let owned: Vec<Option<String>> = rows
            .iter()
            .map(|row| row.map(|text| text.to_string()))
            .collect();
        Arc::new(StringVector::from(owned))
    }

    /// Each non-null row reports its element count; the null row stays null.
    #[test]
    fn test_vec_dim() {
        let func = VectorDimFunction;

        let vectors = string_column(&[
            Some("[0.0,2.0,3.0]"),
            Some("[1.0,2.0,3.0,4.0]"),
            None,
            Some("[5.0]"),
        ]);

        let dims = func.eval(&FunctionContext::default(), &[vectors]).unwrap();

        let expected: [Option<u64>; 4] = [Some(3), Some(4), None, Some(1)];
        assert_eq!(dims.len(), expected.len());
        for (row, want) in expected.into_iter().enumerate() {
            match want {
                Some(dim) => assert_eq!(dims.get_ref(row).as_u64().unwrap(), Some(dim)),
                None => assert!(dims.get_ref(row).is_null()),
            }
        }
    }

    /// Passing two columns must be rejected with an `InvalidFuncArgs` error.
    #[test]
    fn test_dim_error() {
        let func = VectorDimFunction;

        let first = string_column(&[
            Some("[1.0,2.0,3.0]"),
            Some("[4.0,5.0,6.0]"),
            None,
            Some("[2.0,3.0,3.0]"),
        ]);
        let second = string_column(&[
            Some("[1.0,1.0,1.0]"),
            Some("[6.0,5.0,4.0]"),
            Some("[3.0,2.0,2.0]"),
        ]);

        let outcome = func.eval(&FunctionContext::default(), &[first, second]);

        let Err(Error::InvalidFuncArgs { err_msg, .. }) = outcome else {
            unreachable!()
        };
        assert_eq!(
            err_msg,
            "The length of the args is not correct, expect exactly one, have: 2"
        );
    }
}