common_function/scalars/vector/
vector_dim.rs1use std::fmt::Display;
16
17use datafusion::arrow::datatypes::DataType;
18use datafusion::logical_expr::ColumnarValue;
19use datafusion::logical_expr_common::type_coercion::aggregates::{BINARYS, STRINGS};
20use datafusion_common::ScalarValue;
21use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
22
23use crate::function::Function;
24use crate::scalars::vector::VectorCalculator;
25use crate::scalars::vector::impl_conv::as_veclit;
26
27const NAME: &str = "vec_dim";
28
29#[derive(Debug, Clone)]
43pub(crate) struct VectorDimFunction {
44 signature: Signature,
45}
46
47impl Default for VectorDimFunction {
48 fn default() -> Self {
49 Self {
50 signature: Signature::one_of(
51 vec![
52 TypeSignature::Uniform(1, STRINGS.to_vec()),
53 TypeSignature::Uniform(1, BINARYS.to_vec()),
54 ],
55 Volatility::Immutable,
56 ),
57 }
58 }
59}
60
61impl Function for VectorDimFunction {
62 fn name(&self) -> &str {
63 NAME
64 }
65
66 fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
67 Ok(DataType::UInt64)
68 }
69
70 fn signature(&self) -> &Signature {
71 &self.signature
72 }
73
74 fn invoke_with_args(
75 &self,
76 args: ScalarFunctionArgs,
77 ) -> datafusion_common::Result<ColumnarValue> {
78 let body = |v0: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
79 let v = as_veclit(v0)?.map(|v0| v0.len() as u64);
80 Ok(ScalarValue::UInt64(v))
81 };
82
83 let calculator = VectorCalculator {
84 name: self.name(),
85 func: body,
86 };
87 calculator.invoke_with_single_argument(args)
88 }
89}
90
91impl Display for VectorDimFunction {
92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 write!(f, "{}", NAME.to_ascii_uppercase())
94 }
95}
96
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion::arrow::array::{Array, AsArray, StringViewArray};
    use datafusion::arrow::datatypes::UInt64Type;
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Row count declared in every `ScalarFunctionArgs` built by these tests.
    const ROWS: usize = 4;

    /// Wraps the given optional strings in a `StringViewArray` column.
    fn string_view_column(rows: &[Option<&str>]) -> ColumnarValue {
        let owned: Vec<Option<String>> = rows
            .iter()
            .map(|row| row.map(str::to_string))
            .collect();
        ColumnarValue::Array(Arc::new(StringViewArray::from(owned)))
    }

    /// Assembles the call arguments shared by both tests around `columns`.
    fn call_args(columns: Vec<ColumnarValue>) -> ScalarFunctionArgs {
        ScalarFunctionArgs {
            args: columns,
            arg_fields: vec![],
            number_rows: ROWS,
            return_field: Arc::new(Field::new("x", DataType::UInt64, false)),
            config_options: Arc::new(ConfigOptions::new()),
        }
    }

    /// One string column in: each row yields its element count, null stays null.
    #[test]
    fn test_vec_dim() {
        let func = VectorDimFunction::default();

        let vectors = string_view_column(&[
            Some("[0.0,2.0,3.0]"),
            Some("[1.0,2.0,3.0,4.0]"),
            None,
            Some("[5.0]"),
        ]);

        let output = func
            .invoke_with_args(call_args(vec![vectors]))
            .and_then(|value| value.to_array(ROWS))
            .unwrap();
        let dims = output.as_primitive::<UInt64Type>();

        let expected = [Some(3u64), Some(4), None, Some(1)];
        assert_eq!(dims.len(), expected.len());
        for (row, want) in expected.iter().enumerate() {
            match want {
                Some(dim) => assert_eq!(dims.value(row), *dim),
                None => assert!(dims.is_null(row)),
            }
        }
    }

    /// Two columns in: the call is rejected with an argument-count error.
    #[test]
    fn test_dim_error() {
        let func = VectorDimFunction::default();

        let first = string_view_column(&[
            Some("[1.0,2.0,3.0]"),
            Some("[4.0,5.0,6.0]"),
            None,
            Some("[2.0,3.0,3.0]"),
        ]);
        let second = string_view_column(&[
            Some("[1.0,1.0,1.0]"),
            Some("[6.0,5.0,4.0]"),
            Some("[3.0,2.0,2.0]"),
        ]);

        let err = func
            .invoke_with_args(call_args(vec![first, second]))
            .unwrap_err();
        let message = err.to_string();
        assert!(
            message.starts_with("Execution error: vec_dim function requires 1 argument, got 2")
        );
    }
}