common_function/scalars/vector/
elem_avg.rs1use std::fmt::Display;
16
17use datafusion::arrow::datatypes::DataType;
18use datafusion::logical_expr::{Coercion, ColumnarValue, TypeSignature, TypeSignatureClass};
19use datafusion_common::ScalarValue;
20use datafusion_common::types::{logical_binary, logical_string};
21use datafusion_expr::{ScalarFunctionArgs, Signature, Volatility};
22use nalgebra::DVectorView;
23
24use crate::function::Function;
25use crate::scalars::vector::{VectorCalculator, impl_conv};
26
/// Name of the function: returned by `Function::name` and, upper-cased,
/// printed by the `Display` implementation of `ElemAvgFunction`.
const NAME: &str = "vec_elem_avg";
28
29#[derive(Debug, Clone)]
30pub(crate) struct ElemAvgFunction {
31 signature: Signature,
32}
33
34impl Default for ElemAvgFunction {
35 fn default() -> Self {
36 Self {
37 signature: Signature::one_of(
38 vec![
39 TypeSignature::Coercible(vec![Coercion::new_exact(
40 TypeSignatureClass::Native(logical_binary()),
41 )]),
42 TypeSignature::Coercible(vec![Coercion::new_exact(
43 TypeSignatureClass::Native(logical_string()),
44 )]),
45 ],
46 Volatility::Immutable,
47 ),
48 }
49 }
50}
51
52impl Function for ElemAvgFunction {
53 fn name(&self) -> &str {
54 NAME
55 }
56
57 fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
58 Ok(DataType::Float32)
59 }
60
61 fn signature(&self) -> &Signature {
62 &self.signature
63 }
64
65 fn invoke_with_args(
66 &self,
67 args: ScalarFunctionArgs,
68 ) -> datafusion_common::Result<ColumnarValue> {
69 let body = |v0: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
70 let v0 =
71 impl_conv::as_veclit(v0)?.map(|v0| DVectorView::from_slice(&v0, v0.len()).mean());
72 Ok(ScalarValue::Float32(v0))
73 };
74
75 let calculator = VectorCalculator {
76 name: self.name(),
77 func: body,
78 };
79 calculator.invoke_with_single_argument(args)
80 }
81}
82
83impl Display for ElemAvgFunction {
84 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85 write!(f, "{}", NAME.to_ascii_uppercase())
86 }
87}
88
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::StringViewArray;
    use arrow_schema::Field;
    use datafusion::arrow::array::{Array, AsArray};
    use datafusion::arrow::datatypes::Float32Type;
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Feeds three textual vectors plus one null row through the function and
    /// checks the per-row element averages, with the null row staying null.
    #[test]
    fn test_elem_avg() {
        let func = ElemAvgFunction::default();

        let rows: Vec<Option<String>> = vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            Some("[7.0,8.0,9.0]".to_string()),
            None,
        ];
        let input = Arc::new(StringViewArray::from(rows));

        let call = ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(input.clone())],
            arg_fields: vec![],
            number_rows: input.len(),
            return_field: Arc::new(Field::new("x", DataType::Float32, true)),
            config_options: Arc::new(ConfigOptions::new()),
        };

        // Materialize the columnar result as a single Arrow array.
        let output = func.invoke_with_args(call).unwrap();
        let mut arrays = ColumnarValue::values_to_arrays(&[output]).unwrap();
        let array = arrays.remove(0);
        let averages = array.as_primitive::<Float32Type>();

        // `None` marks a row that must come back as null.
        let expected = [Some(2.0_f32), Some(5.0), Some(8.0), None];
        assert_eq!(averages.len(), expected.len());
        for (row, want) in expected.iter().enumerate() {
            match want {
                Some(avg) => assert_eq!(averages.value(row), *avg),
                None => assert!(averages.is_null(row)),
            }
        }
    }
}