common_function/scalars/
udf.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15use std::any::Any;
16use std::fmt::{Debug, Formatter};
17use std::sync::Arc;
18
19use common_query::error::FromScalarValueSnafu;
20use common_query::prelude::ColumnarValue;
21use datafusion::logical_expr::{ScalarFunctionArgs, ScalarUDFImpl};
22use datafusion_expr::ScalarUDF;
23use datatypes::data_type::DataType;
24use datatypes::prelude::*;
25use datatypes::vectors::Helper;
26use session::context::QueryContextRef;
27use snafu::ResultExt;
28
29use crate::function::{FunctionContext, FunctionRef};
30use crate::state::FunctionState;
31
32struct ScalarUdf {
33    function: FunctionRef,
34    signature: datafusion_expr::Signature,
35    context: FunctionContext,
36}
37
38impl Debug for ScalarUdf {
39    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
40        f.debug_struct("ScalarUdf")
41            .field("function", &self.function.name())
42            .field("signature", &self.signature)
43            .finish()
44    }
45}
46
47impl ScalarUDFImpl for ScalarUdf {
48    fn as_any(&self) -> &dyn Any {
49        self
50    }
51
52    fn name(&self) -> &str {
53        self.function.name()
54    }
55
56    fn signature(&self) -> &datafusion_expr::Signature {
57        &self.signature
58    }
59
60    fn return_type(
61        &self,
62        arg_types: &[datatypes::arrow::datatypes::DataType],
63    ) -> datafusion_common::Result<datatypes::arrow::datatypes::DataType> {
64        let arg_types = arg_types
65            .iter()
66            .map(ConcreteDataType::from_arrow_type)
67            .collect::<Vec<_>>();
68        let t = self.function.return_type(&arg_types)?;
69        Ok(t.as_arrow_type())
70    }
71
72    fn invoke_with_args(
73        &self,
74        args: ScalarFunctionArgs,
75    ) -> datafusion_common::Result<datafusion_expr::ColumnarValue> {
76        let columns = args
77            .args
78            .iter()
79            .map(|x| {
80                ColumnarValue::try_from(x).and_then(|y| match y {
81                    ColumnarValue::Vector(z) => Ok(z),
82                    ColumnarValue::Scalar(z) => Helper::try_from_scalar_value(z, args.number_rows)
83                        .context(FromScalarValueSnafu),
84                })
85            })
86            .collect::<common_query::error::Result<Vec<_>>>()?;
87        let v = self
88            .function
89            .eval(&self.context, &columns)
90            .map(ColumnarValue::Vector)?;
91        Ok(v.into())
92    }
93}
94
95/// Create a ScalarUdf from function, query context and state.
96pub fn create_udf(
97    func: FunctionRef,
98    query_ctx: QueryContextRef,
99    state: Arc<FunctionState>,
100) -> ScalarUDF {
101    let signature = func.signature().into();
102    let udf = ScalarUdf {
103        function: func,
104        signature,
105        context: FunctionContext { query_ctx, state },
106    };
107    ScalarUDF::new_from_impl(udf)
108}
109
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_query::prelude::ScalarValue;
    use datafusion::arrow::array::BooleanArray;
    use datafusion_common::config::ConfigOptions;
    use datatypes::arrow::datatypes::Field;
    use datatypes::data_type::ConcreteDataType;
    use datatypes::prelude::VectorRef;
    use datatypes::vectors::{BooleanVector, ConstantVector};
    use session::context::QueryContextBuilder;

    use super::*;
    use crate::function::Function;
    use crate::scalars::test::TestAndFunction;

    #[test]
    fn test_create_udf() {
        let func = Arc::new(TestAndFunction);

        // Evaluate the function directly: a constant `true` spread over three
        // rows, AND-ed with `[true, false, true]`.
        let direct_args: Vec<VectorRef> = vec![
            Arc::new(ConstantVector::new(
                Arc::new(BooleanVector::from(vec![true])),
                3,
            )),
            Arc::new(BooleanVector::from(vec![true, false, true])),
        ];
        let direct_result = func
            .eval(&FunctionContext::default(), &direct_args)
            .unwrap();
        assert_eq!(3, direct_result.len());
        for (row, &expected) in [true, false, true].iter().enumerate() {
            assert!(matches!(direct_result.get(row), Value::Boolean(b) if b == expected));
        }

        // Wrap the same function as a DataFusion UDF and check its metadata.
        let query_ctx = QueryContextBuilder::default().build().into();
        let udf = create_udf(func.clone(), query_ctx, Arc::new(FunctionState::default()));

        assert_eq!("test_and", udf.name());
        let expected_signature: datafusion_expr::Signature = func.signature().into();
        assert_eq!(udf.signature(), &expected_signature);
        let arrow_return_type = udf.return_type(&[]).unwrap();
        assert_eq!(
            ConcreteDataType::boolean_datatype(),
            ConcreteDataType::from_arrow_type(&arrow_return_type)
        );

        // Invoke the UDF with one scalar and one array argument over four rows.
        let udf_args = vec![
            datafusion_expr::ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))),
            datafusion_expr::ColumnarValue::Array(Arc::new(BooleanArray::from(vec![
                true, false, false, true,
            ]))),
        ];
        let arg_fields: Vec<_> = ["a", "b"]
            .iter()
            .zip(&udf_args)
            .map(|(name, arg)| Arc::new(Field::new(*name, arg.data_type(), false)))
            .collect();
        let return_field = Arc::new(Field::new(
            "x",
            ConcreteDataType::boolean_datatype().as_arrow_type(),
            false,
        ));
        let invocation = ScalarFunctionArgs {
            args: udf_args,
            arg_fields,
            number_rows: 4,
            return_field,
            config_options: Arc::new(ConfigOptions::default()),
        };

        let output = match udf.invoke_with_args(invocation).unwrap() {
            datafusion_expr::ColumnarValue::Array(array) => array,
            _ => unreachable!(),
        };
        let booleans = output.as_any().downcast_ref::<BooleanArray>().unwrap();
        assert_eq!(booleans.len(), 4);
        assert_eq!(
            booleans.iter().flatten().collect::<Vec<bool>>(),
            vec![true, false, false, true]
        );
    }
}