common_function/scalars/
udf.rs1use std::any::Any;
16use std::fmt::{Debug, Formatter};
17use std::hash::{Hash, Hasher};
18
19use datafusion::arrow::datatypes::DataType;
20use datafusion::logical_expr::{ScalarFunctionArgs, ScalarUDFImpl};
21use datafusion_expr::ScalarUDF;
22
23use crate::function::FunctionRef;
24
25struct ScalarUdf {
26 function: FunctionRef,
27}
28
29impl Debug for ScalarUdf {
30 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
31 f.debug_struct("ScalarUdf")
32 .field("function", &self.function.name())
33 .finish()
34 }
35}
36
37impl PartialEq for ScalarUdf {
38 fn eq(&self, other: &Self) -> bool {
39 self.function.signature() == other.function.signature()
40 }
41}
42
43impl Eq for ScalarUdf {}
44
45impl Hash for ScalarUdf {
46 fn hash<H: Hasher>(&self, state: &mut H) {
47 self.function.signature().hash(state)
48 }
49}
50
51impl ScalarUDFImpl for ScalarUdf {
52 fn as_any(&self) -> &dyn Any {
53 self
54 }
55
56 fn name(&self) -> &str {
57 self.function.name()
58 }
59
60 fn aliases(&self) -> &[String] {
61 self.function.aliases()
62 }
63
64 fn signature(&self) -> &datafusion_expr::Signature {
65 self.function.signature()
66 }
67
68 fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
69 self.function.return_type(arg_types)
70 }
71
72 fn return_field_from_args(
73 &self,
74 args: datafusion_expr::ReturnFieldArgs,
75 ) -> datafusion_common::Result<arrow_schema::FieldRef> {
76 self.function.return_field_from_args(args)
77 }
78
79 fn invoke_with_args(
80 &self,
81 args: ScalarFunctionArgs,
82 ) -> datafusion_common::Result<datafusion_expr::ColumnarValue> {
83 self.function.invoke_with_args(args)
84 }
85}
86
87pub fn create_udf(function: FunctionRef) -> ScalarUDF {
89 ScalarUDF::new_from_impl(ScalarUdf { function })
90}
91
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_query::prelude::ScalarValue;
    use datafusion::arrow::array::BooleanArray;
    use datafusion_common::arrow::array::AsArray;
    use datafusion_common::arrow::datatypes::DataType as ArrowDataType;
    use datafusion_common::config::ConfigOptions;
    use datatypes::arrow::datatypes::Field;
    use datatypes::data_type::{ConcreteDataType, DataType};

    use super::*;
    use crate::function::Function;
    use crate::scalars::test::TestAndFunction;

    /// Checks `TestAndFunction` when invoked directly and through the UDF built by
    /// `create_udf`, including the wrapper's name, return type and equality.
    #[test]
    fn test_create_udf() {
        let f = Arc::new(TestAndFunction::default());

        // Direct invocation: element-wise AND of two boolean arrays.
        let args = ScalarFunctionArgs {
            args: vec![
                datafusion_expr::ColumnarValue::Array(Arc::new(BooleanArray::from(vec![
                    true, true, true,
                ]))),
                datafusion_expr::ColumnarValue::Array(Arc::new(BooleanArray::from(vec![
                    true, false, true,
                ]))),
            ],
            arg_fields: vec![],
            number_rows: 3,
            return_field: Arc::new(Field::new("x", ArrowDataType::Boolean, true)),
            config_options: Arc::new(Default::default()),
        };

        let result = f
            .invoke_with_args(args)
            .and_then(|x| x.to_array(3))
            .unwrap();
        // Whole-array comparison checks length, values and validity, so an unexpected
        // null in the output fails the test instead of being skipped.
        assert_eq!(
            result.as_boolean(),
            &BooleanArray::from(vec![true, false, true])
        );

        let udf = create_udf(f.clone());
        // Two wrappers around the same function must compare equal.
        assert_eq!(udf, create_udf(f));

        assert_eq!("test_and", udf.name());
        assert_eq!(
            ConcreteDataType::boolean_datatype(),
            udf.return_type(&[])
                .map(|x| ConcreteDataType::from_arrow_type(&x))
                .unwrap()
        );

        // Invocation through the UDF, mixing a scalar and an array argument.
        let args = vec![
            datafusion_expr::ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))),
            datafusion_expr::ColumnarValue::Array(Arc::new(BooleanArray::from(vec![
                true, false, false, true,
            ]))),
        ];

        let arg_fields = vec![
            Arc::new(Field::new("a", args[0].data_type(), false)),
            Arc::new(Field::new("b", args[1].data_type(), false)),
        ];
        let return_field = Arc::new(Field::new(
            "x",
            ConcreteDataType::boolean_datatype().as_arrow_type(),
            false,
        ));
        let args = ScalarFunctionArgs {
            args,
            arg_fields,
            number_rows: 4,
            return_field,
            config_options: Arc::new(ConfigOptions::default()),
        };
        match udf.invoke_with_args(args).unwrap() {
            datafusion_expr::ColumnarValue::Array(x) => {
                let x = x.as_any().downcast_ref::<BooleanArray>().unwrap();
                assert_eq!(x, &BooleanArray::from(vec![true, false, false, true]));
            }
            other => panic!("expected an array result, got {other:?}"),
        }
    }
}