common_function/scalars/
udf.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::any::Any;
16use std::fmt::{Debug, Formatter};
17use std::hash::{Hash, Hasher};
18
19use datafusion::arrow::datatypes::DataType;
20use datafusion::logical_expr::{ScalarFunctionArgs, ScalarUDFImpl};
21use datafusion_expr::ScalarUDF;
22
23use crate::function::FunctionRef;
24
25struct ScalarUdf {
26    function: FunctionRef,
27}
28
29impl Debug for ScalarUdf {
30    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
31        f.debug_struct("ScalarUdf")
32            .field("function", &self.function.name())
33            .finish()
34    }
35}
36
37impl PartialEq for ScalarUdf {
38    fn eq(&self, other: &Self) -> bool {
39        self.function.signature() == other.function.signature()
40    }
41}
42
43impl Eq for ScalarUdf {}
44
45impl Hash for ScalarUdf {
46    fn hash<H: Hasher>(&self, state: &mut H) {
47        self.function.signature().hash(state)
48    }
49}
50
51impl ScalarUDFImpl for ScalarUdf {
52    fn as_any(&self) -> &dyn Any {
53        self
54    }
55
56    fn name(&self) -> &str {
57        self.function.name()
58    }
59
60    fn aliases(&self) -> &[String] {
61        self.function.aliases()
62    }
63
64    fn signature(&self) -> &datafusion_expr::Signature {
65        self.function.signature()
66    }
67
68    fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
69        self.function.return_type(arg_types)
70    }
71
72    fn invoke_with_args(
73        &self,
74        args: ScalarFunctionArgs,
75    ) -> datafusion_common::Result<datafusion_expr::ColumnarValue> {
76        self.function.invoke_with_args(args)
77    }
78}
79
80/// Create a ScalarUdf from function, query context and state.
81pub fn create_udf(function: FunctionRef) -> ScalarUDF {
82    ScalarUDF::new_from_impl(ScalarUdf { function })
83}
84
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_query::prelude::ScalarValue;
    use datafusion::arrow::array::BooleanArray;
    use datafusion_common::arrow::array::AsArray;
    use datafusion_common::arrow::datatypes::DataType as ArrowDataType;
    use datafusion_common::config::ConfigOptions;
    use datafusion_expr::ColumnarValue;
    use datatypes::arrow::datatypes::Field;
    use datatypes::data_type::{ConcreteDataType, DataType};

    use super::*;
    use crate::function::Function;
    use crate::scalars::test::TestAndFunction;

    /// Wraps a list of booleans into an array-backed columnar value.
    fn bool_column(values: Vec<bool>) -> ColumnarValue {
        ColumnarValue::Array(Arc::new(BooleanArray::from(values)))
    }

    #[test]
    fn test_create_udf() {
        let and_fn = Arc::new(TestAndFunction::default());

        // First call the function directly, without the UDF wrapper.
        let direct_args = ScalarFunctionArgs {
            args: vec![
                bool_column(vec![true, true, true]),
                bool_column(vec![true, false, true]),
            ],
            arg_fields: vec![],
            number_rows: 3,
            return_field: Arc::new(Field::new("x", ArrowDataType::Boolean, true)),
            config_options: Arc::new(ConfigOptions::default()),
        };
        let direct_result = and_fn
            .invoke_with_args(direct_args)
            .and_then(|value| value.to_array(3))
            .unwrap();
        let booleans = direct_result.as_boolean();
        assert_eq!(3, booleans.len());

        assert!(booleans.value(0));
        assert!(!booleans.value(1));
        assert!(booleans.value(2));

        // Then wrap it into a UDF and exercise it again through that interface.
        let udf = create_udf(and_fn);

        assert_eq!("test_and", udf.name());
        let return_type = udf
            .return_type(&[])
            .map(|arrow_type| ConcreteDataType::from_arrow_type(&arrow_type))
            .unwrap();
        assert_eq!(ConcreteDataType::boolean_datatype(), return_type);

        let udf_inputs = vec![
            ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))),
            bool_column(vec![true, false, false, true]),
        ];

        let arg_fields = vec![
            Arc::new(Field::new("a", udf_inputs[0].data_type(), false)),
            Arc::new(Field::new("b", udf_inputs[1].data_type(), false)),
        ];
        let return_field = Arc::new(Field::new(
            "x",
            ConcreteDataType::boolean_datatype().as_arrow_type(),
            false,
        ));
        let udf_args = ScalarFunctionArgs {
            args: udf_inputs,
            arg_fields,
            number_rows: 4,
            return_field,
            config_options: Arc::new(ConfigOptions::default()),
        };

        // The UDF is expected to hand back an array, never a scalar.
        let ColumnarValue::Array(output) = udf.invoke_with_args(udf_args).unwrap() else {
            unreachable!()
        };
        let output = output.as_any().downcast_ref::<BooleanArray>().unwrap();
        assert_eq!(output.len(), 4);
        assert_eq!(
            output.iter().flatten().collect::<Vec<bool>>(),
            vec![true, false, false, true]
        );
    }
}