common_function/scalars/vector/
elem_avg.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Display;
16
17use datafusion::arrow::datatypes::DataType;
18use datafusion::logical_expr::{Coercion, ColumnarValue, TypeSignature, TypeSignatureClass};
19use datafusion_common::ScalarValue;
20use datafusion_common::types::{logical_binary, logical_string};
21use datafusion_expr::{ScalarFunctionArgs, Signature, Volatility};
22use nalgebra::DVectorView;
23
24use crate::function::Function;
25use crate::scalars::vector::{VectorCalculator, impl_conv};
26
27const NAME: &str = "vec_elem_avg";
28
29#[derive(Debug, Clone)]
30pub(crate) struct ElemAvgFunction {
31    signature: Signature,
32}
33
34impl Default for ElemAvgFunction {
35    fn default() -> Self {
36        Self {
37            signature: Signature::one_of(
38                vec![
39                    TypeSignature::Coercible(vec![Coercion::new_exact(
40                        TypeSignatureClass::Native(logical_binary()),
41                    )]),
42                    TypeSignature::Coercible(vec![Coercion::new_exact(
43                        TypeSignatureClass::Native(logical_string()),
44                    )]),
45                ],
46                Volatility::Immutable,
47            ),
48        }
49    }
50}
51
52impl Function for ElemAvgFunction {
53    fn name(&self) -> &str {
54        NAME
55    }
56
57    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
58        Ok(DataType::Float32)
59    }
60
61    fn signature(&self) -> &Signature {
62        &self.signature
63    }
64
65    fn invoke_with_args(
66        &self,
67        args: ScalarFunctionArgs,
68    ) -> datafusion_common::Result<ColumnarValue> {
69        let body = |v0: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
70            let v0 =
71                impl_conv::as_veclit(v0)?.map(|v0| DVectorView::from_slice(&v0, v0.len()).mean());
72            Ok(ScalarValue::Float32(v0))
73        };
74
75        let calculator = VectorCalculator {
76            name: self.name(),
77            func: body,
78        };
79        calculator.invoke_with_single_argument(args)
80    }
81}
82
83impl Display for ElemAvgFunction {
84    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85        write!(f, "{}", NAME.to_ascii_uppercase())
86    }
87}
88
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::StringViewArray;
    use arrow_schema::Field;
    use datafusion::arrow::array::{Array, AsArray};
    use datafusion::arrow::datatypes::Float32Type;
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Averages three string-encoded vectors plus a NULL row and checks each output slot.
    #[test]
    fn test_elem_avg() {
        let func = ElemAvgFunction::default();

        let rows: Vec<Option<String>> = [
            Some("[1.0,2.0,3.0]"),
            Some("[4.0,5.0,6.0]"),
            Some("[7.0,8.0,9.0]"),
            None,
        ]
        .iter()
        .map(|row| row.map(str::to_string))
        .collect();
        let input = Arc::new(StringViewArray::from(rows));

        let output = func
            .invoke_with_args(ScalarFunctionArgs {
                args: vec![ColumnarValue::Array(input.clone())],
                arg_fields: vec![],
                number_rows: input.len(),
                return_field: Arc::new(Field::new("x", DataType::Float32, true)),
                config_options: Arc::new(ConfigOptions::new()),
            })
            .unwrap();
        let arrays = ColumnarValue::values_to_arrays(&[output]).unwrap();
        let averages = arrays[0].as_primitive::<Float32Type>();

        // `None` marks a row whose output must be NULL.
        let expected = [Some(2.0), Some(5.0), Some(8.0), None];
        assert_eq!(averages.len(), expected.len());
        for (idx, want) in expected.iter().enumerate() {
            match want {
                Some(mean) => assert_eq!(averages.value(idx), *mean),
                None => assert!(averages.is_null(idx)),
            }
        }
    }
}
131}