common_function/scalars/
udf.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::any::Any;
16use std::fmt::{Debug, Formatter};
17use std::sync::Arc;
18
19use common_query::error::FromScalarValueSnafu;
20use common_query::prelude::ColumnarValue;
21use datafusion::logical_expr::{ScalarFunctionArgs, ScalarUDFImpl};
22use datafusion_expr::ScalarUDF;
23use datatypes::data_type::DataType;
24use datatypes::prelude::*;
25use datatypes::vectors::Helper;
26use session::context::QueryContextRef;
27use snafu::ResultExt;
28
29use crate::function::{FunctionContext, FunctionRef};
30use crate::state::FunctionState;
31
/// Adapter that lets a [`FunctionRef`] be used as a DataFusion scalar UDF
/// through its [`ScalarUDFImpl`] implementation.
struct ScalarUdf {
    /// The wrapped function; supplies the UDF name, return type and evaluation.
    function: FunctionRef,
    /// DataFusion signature handed out by `ScalarUDFImpl::signature`.
    signature: datafusion_expr::Signature,
    /// Context passed to `function.eval` on every invocation.
    context: FunctionContext,
}
37
38impl Debug for ScalarUdf {
39    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
40        f.debug_struct("ScalarUdf")
41            .field("function", &self.function.name())
42            .field("signature", &self.signature)
43            .finish()
44    }
45}
46
47impl ScalarUDFImpl for ScalarUdf {
48    fn as_any(&self) -> &dyn Any {
49        self
50    }
51
52    fn name(&self) -> &str {
53        self.function.name()
54    }
55
56    fn signature(&self) -> &datafusion_expr::Signature {
57        &self.signature
58    }
59
60    fn return_type(
61        &self,
62        arg_types: &[datatypes::arrow::datatypes::DataType],
63    ) -> datafusion_common::Result<datatypes::arrow::datatypes::DataType> {
64        let arg_types = arg_types
65            .iter()
66            .map(ConcreteDataType::from_arrow_type)
67            .collect::<Vec<_>>();
68        let t = self.function.return_type(&arg_types)?;
69        Ok(t.as_arrow_type())
70    }
71
72    fn invoke_with_args(
73        &self,
74        args: ScalarFunctionArgs,
75    ) -> datafusion_common::Result<datafusion_expr::ColumnarValue> {
76        let columns = args
77            .args
78            .iter()
79            .map(|x| {
80                ColumnarValue::try_from(x).and_then(|y| match y {
81                    ColumnarValue::Vector(z) => Ok(z),
82                    ColumnarValue::Scalar(z) => Helper::try_from_scalar_value(z, args.number_rows)
83                        .context(FromScalarValueSnafu),
84                })
85            })
86            .collect::<common_query::error::Result<Vec<_>>>()?;
87        let v = self
88            .function
89            .eval(&self.context, &columns)
90            .map(ColumnarValue::Vector)?;
91        Ok(v.into())
92    }
93}
94
95/// Create a ScalarUdf from function, query context and state.
96pub fn create_udf(
97    func: FunctionRef,
98    query_ctx: QueryContextRef,
99    state: Arc<FunctionState>,
100) -> ScalarUDF {
101    let signature = func.signature().into();
102    let udf = ScalarUdf {
103        function: func,
104        signature,
105        context: FunctionContext { query_ctx, state },
106    };
107    ScalarUDF::new_from_impl(udf)
108}
109
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_query::prelude::ScalarValue;
    use datafusion::arrow::array::BooleanArray;
    use datatypes::data_type::ConcreteDataType;
    use datatypes::prelude::VectorRef;
    use datatypes::vectors::{BooleanVector, ConstantVector};
    use session::context::QueryContextBuilder;

    use super::*;
    use crate::function::Function;
    use crate::scalars::test::TestAndFunction;

    /// Evaluates `TestAndFunction` directly, then wraps it with `create_udf`
    /// and checks name, signature, return type and invocation through the
    /// DataFusion UDF interface.
    #[test]
    fn test_create_udf() {
        let func = Arc::new(TestAndFunction);
        let query_ctx = QueryContextBuilder::default().build().into();

        // Direct evaluation: a constant `true` vector of length 3 together
        // with a plain boolean vector.
        let constant_true = ConstantVector::new(Arc::new(BooleanVector::from(vec![true])), 3);
        let mixed = BooleanVector::from(vec![true, false, true]);
        let inputs: Vec<VectorRef> = vec![Arc::new(constant_true), Arc::new(mixed)];

        let output = func.eval(&FunctionContext::default(), &inputs).unwrap();
        assert_eq!(3, output.len());

        let expected = [true, false, true];
        for (idx, want) in expected.iter().enumerate() {
            assert!(matches!(output.get(idx), Value::Boolean(b) if b == *want));
        }

        // create a udf and test it again
        let udf = create_udf(func.clone(), query_ctx, Arc::new(FunctionState::default()));

        assert_eq!("test_and", udf.name());
        let expected_signature: datafusion_expr::Signature = func.signature().into();
        assert_eq!(udf.signature(), &expected_signature);

        let arrow_return_type = udf.return_type(&[]).unwrap();
        assert_eq!(
            ConcreteDataType::boolean_datatype(),
            ConcreteDataType::from_arrow_type(&arrow_return_type)
        );

        // Invoke through DataFusion with one scalar and one array argument.
        let scalar_arg = datafusion_expr::ColumnarValue::Scalar(ScalarValue::Boolean(Some(true)));
        let array_arg = datafusion_expr::ColumnarValue::Array(Arc::new(BooleanArray::from(vec![
            true, false, false, true,
        ])));
        let boolean_arrow_type = ConcreteDataType::boolean_datatype().as_arrow_type();
        let udf_args = ScalarFunctionArgs {
            args: vec![scalar_arg, array_arg],
            number_rows: 4,
            return_type: &boolean_arrow_type,
        };

        let result = match udf.invoke_with_args(udf_args).unwrap() {
            datafusion_expr::ColumnarValue::Array(array) => array,
            _ => unreachable!(),
        };
        let booleans = result.as_any().downcast_ref::<BooleanArray>().unwrap();
        assert_eq!(booleans.len(), 4);
        assert_eq!(
            booleans.iter().flatten().collect::<Vec<bool>>(),
            vec![true, false, false, true]
        );
    }
}