common_function/scalars/vector/
vector_dim.rs1use std::fmt::Display;
16
17use common_query::error::Result;
18use datafusion::arrow::datatypes::DataType;
19use datafusion::logical_expr::ColumnarValue;
20use datafusion::logical_expr_common::type_coercion::aggregates::{BINARYS, STRINGS};
21use datafusion_common::ScalarValue;
22use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
23
24use crate::function::Function;
25use crate::scalars::vector::VectorCalculator;
26use crate::scalars::vector::impl_conv::as_veclit;
27
/// Name under which this scalar function is exposed (`vec_dim`).
const NAME: &str = "vec_dim";

/// Scalar function `vec_dim(v)`: returns, for every row, the number of
/// elements of the vector `v` as a `UInt64`.
///
/// The argument may be a string (e.g. `'[1.0,2.0,3.0]'`, which yields `3`)
/// or a binary value; a NULL row yields NULL.
#[derive(Default, Clone, Debug)]
pub struct VectorDimFunction;
45
46impl Function for VectorDimFunction {
47 fn name(&self) -> &str {
48 NAME
49 }
50
51 fn return_type(&self, _: &[DataType]) -> Result<DataType> {
52 Ok(DataType::UInt64)
53 }
54
55 fn signature(&self) -> Signature {
56 Signature::one_of(
57 vec![
58 TypeSignature::Uniform(1, STRINGS.to_vec()),
59 TypeSignature::Uniform(1, BINARYS.to_vec()),
60 ],
61 Volatility::Immutable,
62 )
63 }
64
65 fn invoke_with_args(
66 &self,
67 args: ScalarFunctionArgs,
68 ) -> datafusion_common::Result<ColumnarValue> {
69 let body = |v0: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
70 let v = as_veclit(v0)?.map(|v0| v0.len() as u64);
71 Ok(ScalarValue::UInt64(v))
72 };
73
74 let calculator = VectorCalculator {
75 name: self.name(),
76 func: body,
77 };
78 calculator.invoke_with_single_argument(args)
79 }
80}
81
82impl Display for VectorDimFunction {
83 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84 write!(f, "{}", NAME.to_ascii_uppercase())
85 }
86}
87
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion::arrow::array::{Array, AsArray, StringViewArray};
    use datafusion::arrow::datatypes::UInt64Type;
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Builds the invocation arguments for `vec_dim` over `number_rows` rows.
    ///
    /// The return field is declared nullable because a NULL input row yields a
    /// NULL dimension; a non-nullable field would misdescribe the output.
    fn make_args(args: Vec<ColumnarValue>, number_rows: usize) -> ScalarFunctionArgs {
        ScalarFunctionArgs {
            args,
            arg_fields: vec![],
            number_rows,
            return_field: Arc::new(Field::new("x", DataType::UInt64, true)),
            config_options: Arc::new(ConfigOptions::new()),
        }
    }

    /// Dimensions of string-encoded vectors, including a NULL row.
    #[test]
    fn test_vec_dim() {
        let func = VectorDimFunction;

        let input0 = Arc::new(StringViewArray::from(vec![
            Some("[0.0,2.0,3.0]".to_string()),
            Some("[1.0,2.0,3.0,4.0]".to_string()),
            None,
            Some("[5.0]".to_string()),
        ]));

        let args = make_args(vec![ColumnarValue::Array(input0)], 4);
        let result = func
            .invoke_with_args(args)
            .and_then(|x| x.to_array(4))
            .unwrap();

        let result = result.as_primitive::<UInt64Type>();
        assert_eq!(result.len(), 4);
        assert_eq!(result.value(0), 3);
        assert_eq!(result.value(1), 4);
        assert!(result.is_null(2));
        assert_eq!(result.value(3), 1);
    }

    /// Passing two arguments must be rejected.
    ///
    /// Both inputs carry the full `number_rows` rows, so the failure can only be
    /// attributed to the argument-count check, not to a length mismatch.
    #[test]
    fn test_dim_error() {
        let func = VectorDimFunction;

        let input0 = Arc::new(StringViewArray::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            None,
            Some("[2.0,3.0,3.0]".to_string()),
        ]));
        let input1 = Arc::new(StringViewArray::from(vec![
            Some("[1.0,1.0,1.0]".to_string()),
            Some("[6.0,5.0,4.0]".to_string()),
            Some("[3.0,2.0,2.0]".to_string()),
            Some("[7.0,8.0,9.0]".to_string()),
        ]));

        let args = make_args(
            vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
            4,
        );
        let e = func.invoke_with_args(args).unwrap_err();
        let msg = e.to_string();
        assert!(
            msg.starts_with("Execution error: vec_dim function requires 1 argument, got 2"),
            "unexpected error: {msg}"
        );
    }
}