common_function/scalars/
udf.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::any::Any;
16use std::fmt::{Debug, Formatter};
17
18use datafusion::arrow::datatypes::DataType;
19use datafusion::logical_expr::{ScalarFunctionArgs, ScalarUDFImpl};
20use datafusion_expr::ScalarUDF;
21
22use crate::function::FunctionRef;
23
24struct ScalarUdf {
25    function: FunctionRef,
26}
27
28impl Debug for ScalarUdf {
29    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
30        f.debug_struct("ScalarUdf")
31            .field("function", &self.function.name())
32            .finish()
33    }
34}
35
36impl ScalarUDFImpl for ScalarUdf {
37    fn as_any(&self) -> &dyn Any {
38        self
39    }
40
41    fn name(&self) -> &str {
42        self.function.name()
43    }
44
45    fn aliases(&self) -> &[String] {
46        self.function.aliases()
47    }
48
49    fn signature(&self) -> &datafusion_expr::Signature {
50        self.function.signature()
51    }
52
53    fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
54        self.function.return_type(arg_types)
55    }
56
57    fn invoke_with_args(
58        &self,
59        args: ScalarFunctionArgs,
60    ) -> datafusion_common::Result<datafusion_expr::ColumnarValue> {
61        self.function.invoke_with_args(args)
62    }
63}
64
65/// Create a ScalarUdf from function, query context and state.
66pub fn create_udf(function: FunctionRef) -> ScalarUDF {
67    ScalarUDF::new_from_impl(ScalarUdf { function })
68}
69
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_query::prelude::ScalarValue;
    use datafusion::arrow::array::BooleanArray;
    use datafusion_common::arrow::array::AsArray;
    use datafusion_common::arrow::datatypes::DataType as ArrowDataType;
    use datafusion_common::config::ConfigOptions;
    use datafusion_expr::ColumnarValue;
    use datatypes::arrow::datatypes::Field;
    use datatypes::data_type::{ConcreteDataType, DataType};

    use super::*;
    use crate::function::Function;
    use crate::scalars::test::TestAndFunction;

    /// Builds an array-backed [`ColumnarValue`] from a slice of booleans.
    fn bool_column(values: &[bool]) -> ColumnarValue {
        ColumnarValue::Array(Arc::new(BooleanArray::from(values.to_vec())))
    }

    /// Invokes `TestAndFunction` directly, then again through the UDF produced
    /// by `create_udf`, checking name, return type and results.
    #[test]
    fn test_create_udf() {
        let f = Arc::new(TestAndFunction::default());

        // Step 1: call the function directly with two array arguments.
        let direct_args = ScalarFunctionArgs {
            args: vec![
                bool_column(&[true, true, true]),
                bool_column(&[true, false, true]),
            ],
            arg_fields: vec![],
            number_rows: 3,
            return_field: Arc::new(Field::new("x", ArrowDataType::Boolean, true)),
            config_options: Arc::new(ConfigOptions::default()),
        };

        let direct_output = f.invoke_with_args(direct_args).unwrap();
        let result = direct_output.to_array(3).unwrap();
        let vector = result.as_boolean();
        assert_eq!(3, vector.len());
        assert_eq!(
            [true, false, true],
            [vector.value(0), vector.value(1), vector.value(2)]
        );

        // Step 2: create a udf and test it again.
        let udf = create_udf(f);
        assert_eq!("test_and", udf.name());

        let arrow_return_type = udf.return_type(&[]).unwrap();
        assert_eq!(
            ConcreteDataType::boolean_datatype(),
            ConcreteDataType::from_arrow_type(&arrow_return_type)
        );

        // One scalar argument and one array argument this time.
        let args = vec![
            ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))),
            bool_column(&[true, false, false, true]),
        ];
        let arg_fields: Vec<Arc<Field>> = ["a", "b"]
            .iter()
            .zip(&args)
            .map(|(name, value)| Arc::new(Field::new(*name, value.data_type(), false)))
            .collect();
        let boolean_arrow_type = ConcreteDataType::boolean_datatype().as_arrow_type();
        let return_field = Arc::new(Field::new("x", boolean_arrow_type, false));

        let udf_args = ScalarFunctionArgs {
            args,
            arg_fields,
            number_rows: 4,
            return_field,
            config_options: Arc::new(ConfigOptions::default()),
        };

        let ColumnarValue::Array(output) = udf.invoke_with_args(udf_args).unwrap() else {
            unreachable!()
        };
        let booleans = output.as_any().downcast_ref::<BooleanArray>().unwrap();
        assert_eq!(booleans.len(), 4);

        let collected: Vec<bool> = booleans.iter().flatten().collect();
        assert_eq!(collected, vec![true, false, false, true]);
    }
}