common_function/scalars/vector/
vector_dim.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::borrow::Cow;
16use std::fmt::Display;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use datafusion::arrow::datatypes::DataType;
20use datafusion::logical_expr_common::type_coercion::aggregates::{BINARYS, STRINGS};
21use datafusion_expr::{Signature, TypeSignature, Volatility};
22use datatypes::scalars::ScalarVectorBuilder;
23use datatypes::vectors::{MutableVector, UInt64VectorBuilder, VectorRef};
24use snafu::ensure;
25
26use crate::function::{Function, FunctionContext};
27use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const};
28
/// Name reported by `VectorDimFunction::name`.
const NAME: &str = "vec_dim";

/// Returns the dimension (number of elements) of a vector.
///
/// The single argument may be a string or a binary value; rows whose value
/// is `NULL` produce `NULL`.
///
/// # Example
///
/// ```sql
/// SELECT vec_dim('[7.0, 8.0, 9.0, 10.0]');
///
/// +---------------------------------------------------------------+
/// | vec_dim(Utf8("[7.0, 8.0, 9.0, 10.0]"))                        |
/// +---------------------------------------------------------------+
/// | 4                                                             |
/// +---------------------------------------------------------------+
/// ```
#[derive(Debug, Clone, Default)]
pub struct VectorDimFunction;
46
47impl Function for VectorDimFunction {
48    fn name(&self) -> &str {
49        NAME
50    }
51
52    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
53        Ok(DataType::UInt64)
54    }
55
56    fn signature(&self) -> Signature {
57        Signature::one_of(
58            vec![
59                TypeSignature::Uniform(1, STRINGS.to_vec()),
60                TypeSignature::Uniform(1, BINARYS.to_vec()),
61            ],
62            Volatility::Immutable,
63        )
64    }
65
66    fn eval(
67        &self,
68        _func_ctx: &FunctionContext,
69        columns: &[VectorRef],
70    ) -> common_query::error::Result<VectorRef> {
71        ensure!(
72            columns.len() == 1,
73            InvalidFuncArgsSnafu {
74                err_msg: format!(
75                    "The length of the args is not correct, expect exactly one, have: {}",
76                    columns.len()
77                )
78            }
79        );
80        let arg0 = &columns[0];
81
82        let len = arg0.len();
83        let mut result = UInt64VectorBuilder::with_capacity(len);
84        if len == 0 {
85            return Ok(result.to_vector());
86        }
87
88        let arg0_const = as_veclit_if_const(arg0)?;
89
90        for i in 0..len {
91            let arg0 = match arg0_const.as_ref() {
92                Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
93                None => as_veclit(arg0.get_ref(i))?,
94            };
95            let Some(arg0) = arg0 else {
96                result.push_null();
97                continue;
98            };
99            result.push(Some(arg0.len() as u64));
100        }
101
102        Ok(result.to_vector())
103    }
104}
105
106impl Display for VectorDimFunction {
107    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108        write!(f, "{}", NAME.to_ascii_uppercase())
109    }
110}
111
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_query::error::Error;
    use datatypes::vectors::StringVector;

    use super::*;

    /// Each row's result is the element count of its vector; NULL rows stay NULL.
    #[test]
    fn test_vec_dim() {
        let func = VectorDimFunction;

        let input0 = Arc::new(StringVector::from(vec![
            Some("[0.0,2.0,3.0]".to_string()),
            Some("[1.0,2.0,3.0,4.0]".to_string()),
            None,
            Some("[5.0]".to_string()),
        ]));

        let result = func.eval(&FunctionContext::default(), &[input0]).unwrap();

        let result = result.as_ref();
        assert_eq!(result.len(), 4);
        assert_eq!(result.get_ref(0).as_u64().unwrap(), Some(3));
        assert_eq!(result.get_ref(1).as_u64().unwrap(), Some(4));
        assert!(result.get_ref(2).is_null());
        assert_eq!(result.get_ref(3).as_u64().unwrap(), Some(1));
    }

    /// An empty input column yields an empty result rather than an error.
    #[test]
    fn test_vec_dim_empty() {
        let func = VectorDimFunction;

        let input0 = Arc::new(StringVector::from(Vec::<Option<String>>::new()));

        let result = func.eval(&FunctionContext::default(), &[input0]).unwrap();

        assert_eq!(result.len(), 0);
    }

    /// Supplying anything other than exactly one column is rejected.
    #[test]
    fn test_dim_error() {
        let func = VectorDimFunction;

        let input0 = Arc::new(StringVector::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            None,
            Some("[2.0,3.0,3.0]".to_string()),
        ]));
        let input1 = Arc::new(StringVector::from(vec![
            Some("[1.0,1.0,1.0]".to_string()),
            Some("[6.0,5.0,4.0]".to_string()),
            Some("[3.0,2.0,2.0]".to_string()),
        ]));

        let err = func
            .eval(&FunctionContext::default(), &[input0, input1])
            .err()
            .expect("vec_dim must reject two arguments");

        // A different error variant is a real (reachable) failure, so report it
        // instead of using `unreachable!()`.
        match err {
            Error::InvalidFuncArgs { err_msg, .. } => assert_eq!(
                err_msg,
                "The length of the args is not correct, expect exactly one, have: 2"
            ),
            other => panic!("expected InvalidFuncArgs error, got: {}", other),
        }
    }
}
170}