common_function/scalars/
udf.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::any::Any;
16use std::fmt::{Debug, Formatter};
17use std::sync::Arc;
18
19use common_query::prelude::ColumnarValue;
20use datafusion::arrow::datatypes::DataType;
21use datafusion::logical_expr::{ScalarFunctionArgs, ScalarUDFImpl};
22use datafusion_expr::ScalarUDF;
23use session::context::QueryContextRef;
24
25use crate::function::{FunctionContext, FunctionRef};
26use crate::state::FunctionState;
27
/// Adapter that exposes a [`FunctionRef`] through DataFusion's
/// [`ScalarUDFImpl`] trait.
struct ScalarUdf {
    /// The wrapped function; name, aliases, return type and evaluation are
    /// all delegated to it.
    function: FunctionRef,
    /// Signature captured from the function when the UDF is created.
    signature: datafusion_expr::Signature,
    /// Context (query context and function state) passed to the function's
    /// `eval` on the fallback path.
    context: FunctionContext,
}
33
34impl Debug for ScalarUdf {
35    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
36        f.debug_struct("ScalarUdf")
37            .field("function", &self.function.name())
38            .field("signature", &self.signature)
39            .finish()
40    }
41}
42
43impl ScalarUDFImpl for ScalarUdf {
44    fn as_any(&self) -> &dyn Any {
45        self
46    }
47
48    fn name(&self) -> &str {
49        self.function.name()
50    }
51
52    fn aliases(&self) -> &[String] {
53        self.function.aliases()
54    }
55
56    fn signature(&self) -> &datafusion_expr::Signature {
57        &self.signature
58    }
59
60    fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
61        self.function.return_type(arg_types).map_err(Into::into)
62    }
63
64    fn invoke_with_args(
65        &self,
66        args: ScalarFunctionArgs,
67    ) -> datafusion_common::Result<datafusion_expr::ColumnarValue> {
68        let result = self.function.invoke_with_args(args.clone());
69        if !matches!(
70            result,
71            Err(datafusion_common::DataFusionError::NotImplemented(_))
72        ) {
73            return result;
74        }
75
76        let columns = args
77            .args
78            .iter()
79            .map(|x| ColumnarValue::try_from(x).and_then(|y| y.try_into_vector(args.number_rows)))
80            .collect::<common_query::error::Result<Vec<_>>>()?;
81        let v = self
82            .function
83            .eval(&self.context, &columns)
84            .map(ColumnarValue::Vector)?;
85        Ok(v.into())
86    }
87}
88
89/// Create a ScalarUdf from function, query context and state.
90pub fn create_udf(
91    func: FunctionRef,
92    query_ctx: QueryContextRef,
93    state: Arc<FunctionState>,
94) -> ScalarUDF {
95    let signature = func.signature();
96    let udf = ScalarUdf {
97        function: func,
98        signature,
99        context: FunctionContext { query_ctx, state },
100    };
101    ScalarUDF::new_from_impl(udf)
102}
103
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_query::prelude::ScalarValue;
    use datafusion::arrow::array::BooleanArray;
    use datafusion_common::config::ConfigOptions;
    use datatypes::arrow::datatypes::Field;
    use datatypes::data_type::{ConcreteDataType, DataType};
    use datatypes::prelude::VectorRef;
    use datatypes::value::Value;
    use datatypes::vectors::{BooleanVector, ConstantVector};
    use session::context::QueryContextBuilder;

    use super::*;
    use crate::function::Function;
    use crate::scalars::test::TestAndFunction;

    /// Evaluates `TestAndFunction` directly through `eval`, then wraps it
    /// with `create_udf` and checks the UDF's name, signature, return type
    /// and invocation result.
    #[test]
    fn test_create_udf() {
        let f = Arc::new(TestAndFunction);
        let query_ctx = QueryContextBuilder::default().build().into();

        // Direct evaluation: a constant `true` of length 3 and a boolean column.
        let constant_true: VectorRef = Arc::new(ConstantVector::new(
            Arc::new(BooleanVector::from(vec![true])),
            3,
        ));
        let bool_column: VectorRef = Arc::new(BooleanVector::from(vec![true, false, true]));
        let vectors = vec![constant_true, bool_column];

        let vector = f.eval(&FunctionContext::default(), &vectors).unwrap();
        assert_eq!(3, vector.len());

        for (idx, expected) in [true, false, true].iter().copied().enumerate() {
            assert!(matches!(vector.get(idx), Value::Boolean(actual) if actual == expected));
        }

        // Wrap the function in a UDF and check it again.
        let udf = create_udf(f.clone(), query_ctx, Arc::new(FunctionState::default()));

        assert_eq!("test_and", udf.name());
        let expected_signature: datafusion_expr::Signature = f.signature();
        assert_eq!(udf.signature(), &expected_signature);

        let return_type = udf.return_type(&[]).unwrap();
        assert_eq!(
            ConcreteDataType::boolean_datatype(),
            ConcreteDataType::from_arrow_type(&return_type)
        );

        let scalar_arg = datafusion_expr::ColumnarValue::Scalar(ScalarValue::Boolean(Some(true)));
        let array_arg = datafusion_expr::ColumnarValue::Array(Arc::new(BooleanArray::from(vec![
            true, false, false, true,
        ])));

        let arg_fields = vec![
            Arc::new(Field::new("a", scalar_arg.data_type(), false)),
            Arc::new(Field::new("b", array_arg.data_type(), false)),
        ];
        let return_field = Arc::new(Field::new(
            "x",
            ConcreteDataType::boolean_datatype().as_arrow_type(),
            false,
        ));
        let invocation = ScalarFunctionArgs {
            args: vec![scalar_arg, array_arg],
            arg_fields,
            number_rows: 4,
            return_field,
            config_options: Arc::new(ConfigOptions::default()),
        };

        let datafusion_expr::ColumnarValue::Array(array) =
            udf.invoke_with_args(invocation).unwrap()
        else {
            unreachable!()
        };
        let result = array.as_any().downcast_ref::<BooleanArray>().unwrap();
        assert_eq!(result.len(), 4);

        let actual: Vec<bool> = result.iter().flatten().collect();
        assert_eq!(actual, vec![true, false, false, true]);
    }
}