// common_function/scalars/udf.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::any::Any;
16use std::fmt::{Debug, Formatter};
17use std::hash::{Hash, Hasher};
18
19use datafusion::arrow::datatypes::DataType;
20use datafusion::logical_expr::{ScalarFunctionArgs, ScalarUDFImpl};
21use datafusion_expr::ScalarUDF;
22
23use crate::function::FunctionRef;
24
25struct ScalarUdf {
26    function: FunctionRef,
27}
28
29impl Debug for ScalarUdf {
30    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
31        f.debug_struct("ScalarUdf")
32            .field("function", &self.function.name())
33            .finish()
34    }
35}
36
37impl PartialEq for ScalarUdf {
38    fn eq(&self, other: &Self) -> bool {
39        self.function.signature() == other.function.signature()
40    }
41}
42
43impl Eq for ScalarUdf {}
44
45impl Hash for ScalarUdf {
46    fn hash<H: Hasher>(&self, state: &mut H) {
47        self.function.signature().hash(state)
48    }
49}
50
51impl ScalarUDFImpl for ScalarUdf {
52    fn as_any(&self) -> &dyn Any {
53        self
54    }
55
56    fn name(&self) -> &str {
57        self.function.name()
58    }
59
60    fn aliases(&self) -> &[String] {
61        self.function.aliases()
62    }
63
64    fn signature(&self) -> &datafusion_expr::Signature {
65        self.function.signature()
66    }
67
68    fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
69        self.function.return_type(arg_types)
70    }
71
72    fn return_field_from_args(
73        &self,
74        args: datafusion_expr::ReturnFieldArgs,
75    ) -> datafusion_common::Result<arrow_schema::FieldRef> {
76        self.function.return_field_from_args(args)
77    }
78
79    fn invoke_with_args(
80        &self,
81        args: ScalarFunctionArgs,
82    ) -> datafusion_common::Result<datafusion_expr::ColumnarValue> {
83        self.function.invoke_with_args(args)
84    }
85}
86
87/// Create a ScalarUdf from function, query context and state.
88pub fn create_udf(function: FunctionRef) -> ScalarUDF {
89    ScalarUDF::new_from_impl(ScalarUdf { function })
90}
91
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_query::prelude::ScalarValue;
    use datafusion::arrow::array::BooleanArray;
    use datafusion_common::arrow::array::AsArray;
    use datafusion_common::arrow::datatypes::DataType as ArrowDataType;
    use datafusion_common::config::ConfigOptions;
    use datafusion_expr::ColumnarValue;
    use datatypes::arrow::datatypes::Field;
    use datatypes::data_type::{ConcreteDataType, DataType};

    use super::*;
    use crate::function::Function;
    use crate::scalars::test::TestAndFunction;

    /// Builds an array-valued function argument from plain booleans.
    fn bool_array(values: Vec<bool>) -> ColumnarValue {
        ColumnarValue::Array(Arc::new(BooleanArray::from(values)))
    }

    #[test]
    fn test_create_udf() {
        let function = Arc::new(TestAndFunction::default());

        // Invoke the function directly on two boolean arrays.
        let direct_args = ScalarFunctionArgs {
            args: vec![
                bool_array(vec![true, true, true]),
                bool_array(vec![true, false, true]),
            ],
            arg_fields: vec![],
            number_rows: 3,
            return_field: Arc::new(Field::new("x", ArrowDataType::Boolean, true)),
            config_options: Arc::new(ConfigOptions::default()),
        };
        let direct_array = function
            .invoke_with_args(direct_args)
            .and_then(|value| value.to_array(3))
            .unwrap();
        let direct = direct_array.as_boolean();
        assert_eq!(3, direct.len());
        for (row, expected) in [true, false, true].into_iter().enumerate() {
            assert_eq!(expected, direct.value(row), "row {row}");
        }

        // Wrap the same function into a DataFusion UDF and exercise it again.
        let udf = create_udf(function);
        assert_eq!("test_and", udf.name());
        let return_type = udf.return_type(&[]).unwrap();
        assert_eq!(
            ConcreteDataType::boolean_datatype(),
            ConcreteDataType::from_arrow_type(&return_type)
        );

        // Mix a scalar and an array argument.
        let udf_args = vec![
            ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))),
            bool_array(vec![true, false, false, true]),
        ];
        let arg_fields: Vec<_> = udf_args
            .iter()
            .zip(["a", "b"])
            .map(|(arg, name)| Arc::new(Field::new(name, arg.data_type(), false)))
            .collect();
        let udf_call = ScalarFunctionArgs {
            args: udf_args,
            arg_fields,
            number_rows: 4,
            return_field: Arc::new(Field::new(
                "x",
                ConcreteDataType::boolean_datatype().as_arrow_type(),
                false,
            )),
            config_options: Arc::new(ConfigOptions::default()),
        };

        let ColumnarValue::Array(output) = udf.invoke_with_args(udf_call).unwrap() else {
            unreachable!()
        };
        let output = output.as_any().downcast_ref::<BooleanArray>().unwrap();
        assert_eq!(output.len(), 4);
        let values: Vec<bool> = output.iter().flatten().collect();
        assert_eq!(values, vec![true, false, false, true]);
    }
}