common_function/scalars/vector/
elem_avg.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Display;
16
17use datafusion::arrow::datatypes::DataType;
18use datafusion::logical_expr::ColumnarValue;
19use datafusion_common::ScalarValue;
20use datafusion_expr::type_coercion::aggregates::{BINARYS, STRINGS};
21use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
22use nalgebra::DVectorView;
23
24use crate::function::Function;
25use crate::scalars::vector::{VectorCalculator, impl_conv};
26
27const NAME: &str = "vec_elem_avg";
28
29#[derive(Debug, Clone)]
30pub(crate) struct ElemAvgFunction {
31    signature: Signature,
32}
33
34impl Default for ElemAvgFunction {
35    fn default() -> Self {
36        Self {
37            signature: Signature::one_of(
38                vec![
39                    TypeSignature::Uniform(1, STRINGS.to_vec()),
40                    TypeSignature::Uniform(1, BINARYS.to_vec()),
41                    TypeSignature::Uniform(1, vec![DataType::BinaryView]),
42                ],
43                Volatility::Immutable,
44            ),
45        }
46    }
47}
48
49impl Function for ElemAvgFunction {
50    fn name(&self) -> &str {
51        NAME
52    }
53
54    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
55        Ok(DataType::Float32)
56    }
57
58    fn signature(&self) -> &Signature {
59        &self.signature
60    }
61
62    fn invoke_with_args(
63        &self,
64        args: ScalarFunctionArgs,
65    ) -> datafusion_common::Result<ColumnarValue> {
66        let body = |v0: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
67            let v0 =
68                impl_conv::as_veclit(v0)?.map(|v0| DVectorView::from_slice(&v0, v0.len()).mean());
69            Ok(ScalarValue::Float32(v0))
70        };
71
72        let calculator = VectorCalculator {
73            name: self.name(),
74            func: body,
75        };
76        calculator.invoke_with_single_argument(args)
77    }
78}
79
80impl Display for ElemAvgFunction {
81    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82        write!(f, "{}", NAME.to_ascii_uppercase())
83    }
84}
85
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::StringViewArray;
    use arrow_schema::Field;
    use datafusion::arrow::array::{Array, AsArray};
    use datafusion::arrow::datatypes::Float32Type;
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Feeds three string-encoded vectors and one null through the function
    /// and checks the per-row means, including null propagation.
    #[test]
    fn test_elem_avg() {
        let rows = vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            Some("[7.0,8.0,9.0]".to_string()),
            None,
        ];
        let input = Arc::new(StringViewArray::from(rows));

        let call_args = ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(input.clone())],
            arg_fields: vec![],
            number_rows: input.len(),
            return_field: Arc::new(Field::new("x", DataType::Float32, true)),
            config_options: Arc::new(ConfigOptions::new()),
        };

        let output = ElemAvgFunction::default()
            .invoke_with_args(call_args)
            .unwrap();
        let mut arrays = ColumnarValue::values_to_arrays(&[output]).unwrap();
        let column = arrays.remove(0);
        let means = column.as_primitive::<Float32Type>();

        assert_eq!(means.len(), 4);
        // Rows 0..=2 hold the means of the three non-null vectors.
        for (row, expected) in [2.0_f32, 5.0, 8.0].into_iter().enumerate() {
            assert_eq!(means.value(row), expected);
        }
        // The null input row stays null in the output.
        assert!(means.is_null(3));
    }
}