common_function/scalars/vector/
vector_dim.rs1use std::borrow::Cow;
16use std::fmt::Display;
17
18use common_query::error::InvalidFuncArgsSnafu;
19use common_query::prelude::{Signature, TypeSignature, Volatility};
20use datatypes::prelude::ConcreteDataType;
21use datatypes::scalars::ScalarVectorBuilder;
22use datatypes::vectors::{MutableVector, UInt64VectorBuilder, VectorRef};
23use snafu::ensure;
24
25use crate::function::{Function, FunctionContext};
26use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const};
27
const NAME: &str = "vec_dim";

/// Scalar function `vec_dim`: returns the number of elements (the dimension)
/// of each input vector as a `UInt64`.
///
/// The single argument may be a string vector literal such as
/// `'[1.0,2.0,3.0]'` or a binary value. A null input row yields a null
/// output row.
///
/// # Examples
///
/// ```sql
/// SELECT vec_dim('[7.0,8.0,9.0,10.0]');
/// -- 4
/// ```
#[derive(Debug, Clone, Default)]
pub struct VectorDimFunction;
45
46impl Function for VectorDimFunction {
47 fn name(&self) -> &str {
48 NAME
49 }
50
51 fn return_type(
52 &self,
53 _input_types: &[ConcreteDataType],
54 ) -> common_query::error::Result<ConcreteDataType> {
55 Ok(ConcreteDataType::uint64_datatype())
56 }
57
58 fn signature(&self) -> Signature {
59 Signature::one_of(
60 vec![
61 TypeSignature::Exact(vec![ConcreteDataType::string_datatype()]),
62 TypeSignature::Exact(vec![ConcreteDataType::binary_datatype()]),
63 ],
64 Volatility::Immutable,
65 )
66 }
67
68 fn eval(
69 &self,
70 _func_ctx: &FunctionContext,
71 columns: &[VectorRef],
72 ) -> common_query::error::Result<VectorRef> {
73 ensure!(
74 columns.len() == 1,
75 InvalidFuncArgsSnafu {
76 err_msg: format!(
77 "The length of the args is not correct, expect exactly one, have: {}",
78 columns.len()
79 )
80 }
81 );
82 let arg0 = &columns[0];
83
84 let len = arg0.len();
85 let mut result = UInt64VectorBuilder::with_capacity(len);
86 if len == 0 {
87 return Ok(result.to_vector());
88 }
89
90 let arg0_const = as_veclit_if_const(arg0)?;
91
92 for i in 0..len {
93 let arg0 = match arg0_const.as_ref() {
94 Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
95 None => as_veclit(arg0.get_ref(i))?,
96 };
97 let Some(arg0) = arg0 else {
98 result.push_null();
99 continue;
100 };
101 result.push(Some(arg0.len() as u64));
102 }
103
104 Ok(result.to_vector())
105 }
106}
107
108impl Display for VectorDimFunction {
109 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
110 write!(f, "{}", NAME.to_ascii_uppercase())
111 }
112}
113
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_query::error::Error;
    use datatypes::vectors::{ConstantVector, StringVector};

    use super::*;

    #[test]
    fn test_vec_dim() {
        let func = VectorDimFunction;

        let input0 = Arc::new(StringVector::from(vec![
            Some("[0.0,2.0,3.0]".to_string()),
            Some("[1.0,2.0,3.0,4.0]".to_string()),
            None,
            Some("[5.0]".to_string()),
        ]));

        let result = func.eval(&FunctionContext::default(), &[input0]).unwrap();

        let result = result.as_ref();
        assert_eq!(result.len(), 4);
        assert_eq!(result.get_ref(0).as_u64().unwrap(), Some(3));
        assert_eq!(result.get_ref(1).as_u64().unwrap(), Some(4));
        assert!(result.get_ref(2).is_null());
        assert_eq!(result.get_ref(3).as_u64().unwrap(), Some(1));
    }

    #[test]
    fn test_vec_dim_const_input() {
        let func = VectorDimFunction;

        // A constant column must still produce one output row per logical row.
        let input0 = Arc::new(ConstantVector::new(
            Arc::new(StringVector::from(vec![Some("[1.0,2.0]".to_string())])),
            3,
        ));

        let result = func.eval(&FunctionContext::default(), &[input0]).unwrap();

        let result = result.as_ref();
        assert_eq!(result.len(), 3);
        for i in 0..3 {
            assert_eq!(result.get_ref(i).as_u64().unwrap(), Some(2));
        }
    }

    #[test]
    fn test_vec_dim_empty_input() {
        let func = VectorDimFunction;

        let input0 = Arc::new(StringVector::from(Vec::<Option<String>>::new()));

        let result = func.eval(&FunctionContext::default(), &[input0]).unwrap();
        assert_eq!(result.len(), 0);
    }

    #[test]
    fn test_dim_error() {
        let func = VectorDimFunction;

        let input0 = Arc::new(StringVector::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            None,
            Some("[2.0,3.0,3.0]".to_string()),
        ]));
        let input1 = Arc::new(StringVector::from(vec![
            Some("[1.0,1.0,1.0]".to_string()),
            Some("[6.0,5.0,4.0]".to_string()),
            Some("[3.0,2.0,2.0]".to_string()),
        ]));

        let result = func.eval(&FunctionContext::default(), &[input0, input1]);

        // A regression could make this branch reachable, so fail with a clear
        // message instead of `unreachable!()`.
        let Err(Error::InvalidFuncArgs { err_msg, .. }) = result else {
            panic!("expected an InvalidFuncArgs error for two arguments");
        };
        assert_eq!(
            err_msg,
            "The length of the args is not correct, expect exactly one, have: 2"
        );
    }
}