common_function/scalars/vector/
vector_dim.rs1use std::borrow::Cow;
16use std::fmt::Display;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use datafusion::arrow::datatypes::DataType;
20use datafusion::logical_expr_common::type_coercion::aggregates::{BINARYS, STRINGS};
21use datafusion_expr::{Signature, TypeSignature, Volatility};
22use datatypes::scalars::ScalarVectorBuilder;
23use datatypes::vectors::{MutableVector, UInt64VectorBuilder, VectorRef};
24use snafu::ensure;
25
26use crate::function::{Function, FunctionContext};
27use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const};
28
/// Name of the function as returned by `VectorDimFunction::name`; the `Display`
/// implementation renders the same string upper-cased.
const NAME: &str = "vec_dim";
30
/// Scalar function `vec_dim`: reports the dimension (number of elements) of a vector.
///
/// It takes exactly one argument, a string or binary column holding vectors,
/// and produces a `UInt64` column. A null input row yields a null output row.
#[derive(Clone, Debug, Default)]
pub struct VectorDimFunction;
46
47impl Function for VectorDimFunction {
48 fn name(&self) -> &str {
49 NAME
50 }
51
52 fn return_type(&self, _: &[DataType]) -> Result<DataType> {
53 Ok(DataType::UInt64)
54 }
55
56 fn signature(&self) -> Signature {
57 Signature::one_of(
58 vec![
59 TypeSignature::Uniform(1, STRINGS.to_vec()),
60 TypeSignature::Uniform(1, BINARYS.to_vec()),
61 ],
62 Volatility::Immutable,
63 )
64 }
65
66 fn eval(
67 &self,
68 _func_ctx: &FunctionContext,
69 columns: &[VectorRef],
70 ) -> common_query::error::Result<VectorRef> {
71 ensure!(
72 columns.len() == 1,
73 InvalidFuncArgsSnafu {
74 err_msg: format!(
75 "The length of the args is not correct, expect exactly one, have: {}",
76 columns.len()
77 )
78 }
79 );
80 let arg0 = &columns[0];
81
82 let len = arg0.len();
83 let mut result = UInt64VectorBuilder::with_capacity(len);
84 if len == 0 {
85 return Ok(result.to_vector());
86 }
87
88 let arg0_const = as_veclit_if_const(arg0)?;
89
90 for i in 0..len {
91 let arg0 = match arg0_const.as_ref() {
92 Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
93 None => as_veclit(arg0.get_ref(i))?,
94 };
95 let Some(arg0) = arg0 else {
96 result.push_null();
97 continue;
98 };
99 result.push(Some(arg0.len() as u64));
100 }
101
102 Ok(result.to_vector())
103 }
104}
105
106impl Display for VectorDimFunction {
107 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108 write!(f, "{}", NAME.to_ascii_uppercase())
109 }
110}
111
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_query::error::Error;
    use datatypes::vectors::StringVector;

    use super::*;

    /// Every non-null row yields its element count; the null row stays null.
    #[test]
    fn test_vec_dim() {
        let func = VectorDimFunction;

        let literals = vec![
            Some("[0.0,2.0,3.0]".to_string()),
            Some("[1.0,2.0,3.0,4.0]".to_string()),
            None,
            Some("[5.0]".to_string()),
        ];
        let input0 = Arc::new(StringVector::from(literals));

        let output = func.eval(&FunctionContext::default(), &[input0]).unwrap();
        let dims = output.as_ref();

        let expected: [Option<u64>; 4] = [Some(3), Some(4), None, Some(1)];
        assert_eq!(dims.len(), 4);
        for (row, want) in expected.iter().enumerate() {
            match want {
                Some(_) => assert_eq!(dims.get_ref(row).as_u64().unwrap(), *want),
                None => assert!(dims.get_ref(row).is_null()),
            }
        }
    }

    /// Passing two columns instead of one is rejected with `InvalidFuncArgs`.
    #[test]
    fn test_dim_error() {
        let func = VectorDimFunction;

        let first = vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            None,
            Some("[2.0,3.0,3.0]".to_string()),
        ];
        let second = vec![
            Some("[1.0,1.0,1.0]".to_string()),
            Some("[6.0,5.0,4.0]".to_string()),
            Some("[3.0,2.0,2.0]".to_string()),
        ];
        let input0 = Arc::new(StringVector::from(first));
        let input1 = Arc::new(StringVector::from(second));

        let outcome = func.eval(&FunctionContext::default(), &[input0, input1]);

        let Err(Error::InvalidFuncArgs { err_msg, .. }) = outcome else {
            unreachable!()
        };
        assert_eq!(
            err_msg,
            "The length of the args is not correct, expect exactly one, have: 2"
        );
    }
}