common_function/scalars/
udf.rs1use std::any::Any;
16use std::fmt::{Debug, Formatter};
17use std::sync::Arc;
18
19use common_query::prelude::ColumnarValue;
20use datafusion::arrow::datatypes::DataType;
21use datafusion::logical_expr::{ScalarFunctionArgs, ScalarUDFImpl};
22use datafusion_expr::ScalarUDF;
23use session::context::QueryContextRef;
24
25use crate::function::{FunctionContext, FunctionRef};
26use crate::state::FunctionState;
27
28struct ScalarUdf {
29 function: FunctionRef,
30 signature: datafusion_expr::Signature,
31 context: FunctionContext,
32}
33
34impl Debug for ScalarUdf {
35 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
36 f.debug_struct("ScalarUdf")
37 .field("function", &self.function.name())
38 .field("signature", &self.signature)
39 .finish()
40 }
41}
42
43impl ScalarUDFImpl for ScalarUdf {
44 fn as_any(&self) -> &dyn Any {
45 self
46 }
47
48 fn name(&self) -> &str {
49 self.function.name()
50 }
51
52 fn aliases(&self) -> &[String] {
53 self.function.aliases()
54 }
55
56 fn signature(&self) -> &datafusion_expr::Signature {
57 &self.signature
58 }
59
60 fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
61 self.function.return_type(arg_types).map_err(Into::into)
62 }
63
64 fn invoke_with_args(
65 &self,
66 args: ScalarFunctionArgs,
67 ) -> datafusion_common::Result<datafusion_expr::ColumnarValue> {
68 let result = self.function.invoke_with_args(args.clone());
69 if !matches!(
70 result,
71 Err(datafusion_common::DataFusionError::NotImplemented(_))
72 ) {
73 return result;
74 }
75
76 let columns = args
77 .args
78 .iter()
79 .map(|x| ColumnarValue::try_from(x).and_then(|y| y.try_into_vector(args.number_rows)))
80 .collect::<common_query::error::Result<Vec<_>>>()?;
81 let v = self
82 .function
83 .eval(&self.context, &columns)
84 .map(ColumnarValue::Vector)?;
85 Ok(v.into())
86 }
87}
88
89pub fn create_udf(
91 func: FunctionRef,
92 query_ctx: QueryContextRef,
93 state: Arc<FunctionState>,
94) -> ScalarUDF {
95 let signature = func.signature();
96 let udf = ScalarUdf {
97 function: func,
98 signature,
99 context: FunctionContext { query_ctx, state },
100 };
101 ScalarUDF::new_from_impl(udf)
102}
103
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_query::prelude::ScalarValue;
    use datafusion::arrow::array::BooleanArray;
    use datafusion_common::config::ConfigOptions;
    use datatypes::arrow::datatypes::Field;
    use datatypes::data_type::{ConcreteDataType, DataType};
    use datatypes::prelude::VectorRef;
    use datatypes::value::Value;
    use datatypes::vectors::{BooleanVector, ConstantVector};
    use session::context::QueryContextBuilder;

    use super::*;
    use crate::function::Function;
    use crate::scalars::test::TestAndFunction;

    /// Exercises `TestAndFunction` directly through `eval`, then through the
    /// `ScalarUDF` produced by `create_udf`.
    #[test]
    fn test_create_udf() {
        let f = Arc::new(TestAndFunction);
        let query_ctx = QueryContextBuilder::default().build().into();

        // Direct evaluation: a constant `true` of length 3 and a plain
        // boolean vector.
        let constant_true: VectorRef = Arc::new(ConstantVector::new(
            Arc::new(BooleanVector::from(vec![true])),
            3,
        ));
        let mixed: VectorRef = Arc::new(BooleanVector::from(vec![true, false, true]));
        let args: Vec<VectorRef> = vec![constant_true, mixed];

        let vector = f.eval(&FunctionContext::default(), &args).unwrap();
        assert_eq!(3, vector.len());

        for (i, expected) in [true, false, true].iter().enumerate() {
            assert!(matches!(vector.get(i), Value::Boolean(b) if b == *expected));
        }

        // Metadata exposed by the UDF wrapper.
        let udf = create_udf(f.clone(), query_ctx, Arc::new(FunctionState::default()));

        assert_eq!("test_and", udf.name());
        let expected_signature: datafusion_expr::Signature = f.signature();
        assert_eq!(udf.signature(), &expected_signature);
        let arrow_return_type = udf.return_type(&[]).unwrap();
        assert_eq!(
            ConcreteDataType::boolean_datatype(),
            ConcreteDataType::from_arrow_type(&arrow_return_type)
        );

        // Invocation through DataFusion: a scalar `true` and a four-row
        // boolean array.
        let scalar_arg = datafusion_expr::ColumnarValue::Scalar(ScalarValue::Boolean(Some(true)));
        let array_arg = datafusion_expr::ColumnarValue::Array(Arc::new(BooleanArray::from(vec![
            true, false, false, true,
        ])));
        let args = vec![scalar_arg, array_arg];

        let arg_fields = vec![
            Arc::new(Field::new("a", args[0].data_type(), false)),
            Arc::new(Field::new("b", args[1].data_type(), false)),
        ];
        let return_field = Arc::new(Field::new(
            "x",
            ConcreteDataType::boolean_datatype().as_arrow_type(),
            false,
        ));
        let args = ScalarFunctionArgs {
            args,
            arg_fields,
            number_rows: 4,
            return_field,
            config_options: Arc::new(ConfigOptions::default()),
        };

        let datafusion_expr::ColumnarValue::Array(x) = udf.invoke_with_args(args).unwrap() else {
            unreachable!()
        };
        let x = x.as_any().downcast_ref::<BooleanArray>().unwrap();
        assert_eq!(x.len(), 4);
        let values: Vec<bool> = x.iter().flatten().collect();
        assert_eq!(values, vec![true, false, false, true]);
    }
}