common_function/scalars/vector/
elem_avg.rs1use std::fmt::Display;
16
17use datafusion::arrow::datatypes::DataType;
18use datafusion::logical_expr::ColumnarValue;
19use datafusion_common::ScalarValue;
20use datafusion_expr::type_coercion::aggregates::{BINARYS, STRINGS};
21use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
22use nalgebra::DVectorView;
23
24use crate::function::Function;
25use crate::scalars::vector::{VectorCalculator, impl_conv};
26
27const NAME: &str = "vec_elem_avg";
28
29#[derive(Debug, Clone)]
30pub(crate) struct ElemAvgFunction {
31 signature: Signature,
32}
33
34impl Default for ElemAvgFunction {
35 fn default() -> Self {
36 Self {
37 signature: Signature::one_of(
38 vec![
39 TypeSignature::Uniform(1, STRINGS.to_vec()),
40 TypeSignature::Uniform(1, BINARYS.to_vec()),
41 TypeSignature::Uniform(1, vec![DataType::BinaryView]),
42 ],
43 Volatility::Immutable,
44 ),
45 }
46 }
47}
48
49impl Function for ElemAvgFunction {
50 fn name(&self) -> &str {
51 NAME
52 }
53
54 fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
55 Ok(DataType::Float32)
56 }
57
58 fn signature(&self) -> &Signature {
59 &self.signature
60 }
61
62 fn invoke_with_args(
63 &self,
64 args: ScalarFunctionArgs,
65 ) -> datafusion_common::Result<ColumnarValue> {
66 let body = |v0: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
67 let v0 =
68 impl_conv::as_veclit(v0)?.map(|v0| DVectorView::from_slice(&v0, v0.len()).mean());
69 Ok(ScalarValue::Float32(v0))
70 };
71
72 let calculator = VectorCalculator {
73 name: self.name(),
74 func: body,
75 };
76 calculator.invoke_with_single_argument(args)
77 }
78}
79
80impl Display for ElemAvgFunction {
81 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82 write!(f, "{}", NAME.to_ascii_uppercase())
83 }
84}
85
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::StringViewArray;
    use arrow_schema::Field;
    use datafusion::arrow::array::{Array, AsArray};
    use datafusion::arrow::datatypes::Float32Type;
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Three textual vectors plus one null row: each non-null row must yield
    /// the mean of its elements and the null row must stay null.
    #[test]
    fn test_elem_avg() {
        let func = ElemAvgFunction::default();

        let rows = vec![
            Some(String::from("[1.0,2.0,3.0]")),
            Some(String::from("[4.0,5.0,6.0]")),
            Some(String::from("[7.0,8.0,9.0]")),
            None,
        ];
        let input = Arc::new(StringViewArray::from(rows));
        let number_rows = input.len();

        let call_args = ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(input)],
            arg_fields: vec![],
            number_rows,
            return_field: Arc::new(Field::new("x", DataType::Float32, true)),
            config_options: Arc::new(ConfigOptions::new()),
        };

        // Evaluate, then materialise the single output column as an array.
        let output = func.invoke_with_args(call_args).unwrap();
        let mut arrays = ColumnarValue::values_to_arrays(&[output]).unwrap();
        let column = arrays.remove(0);
        let averages = column.as_primitive::<Float32Type>();

        assert_eq!(averages.len(), 4);
        for (row, expected) in [2.0_f32, 5.0, 8.0].into_iter().enumerate() {
            assert_eq!(averages.value(row), expected);
        }
        assert!(averages.is_null(3));
    }
}