common_function/scalars/
udf.rs1use std::any::Any;
16use std::fmt::{Debug, Formatter};
17use std::hash::{Hash, Hasher};
18
19use datafusion::arrow::datatypes::DataType;
20use datafusion::logical_expr::{ScalarFunctionArgs, ScalarUDFImpl};
21use datafusion_expr::ScalarUDF;
22
23use crate::function::FunctionRef;
24
25struct ScalarUdf {
26 function: FunctionRef,
27}
28
29impl Debug for ScalarUdf {
30 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
31 f.debug_struct("ScalarUdf")
32 .field("function", &self.function.name())
33 .finish()
34 }
35}
36
37impl PartialEq for ScalarUdf {
38 fn eq(&self, other: &Self) -> bool {
39 self.function.signature() == other.function.signature()
40 }
41}
42
43impl Eq for ScalarUdf {}
44
45impl Hash for ScalarUdf {
46 fn hash<H: Hasher>(&self, state: &mut H) {
47 self.function.signature().hash(state)
48 }
49}
50
51impl ScalarUDFImpl for ScalarUdf {
52 fn as_any(&self) -> &dyn Any {
53 self
54 }
55
56 fn name(&self) -> &str {
57 self.function.name()
58 }
59
60 fn aliases(&self) -> &[String] {
61 self.function.aliases()
62 }
63
64 fn signature(&self) -> &datafusion_expr::Signature {
65 self.function.signature()
66 }
67
68 fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
69 self.function.return_type(arg_types)
70 }
71
72 fn invoke_with_args(
73 &self,
74 args: ScalarFunctionArgs,
75 ) -> datafusion_common::Result<datafusion_expr::ColumnarValue> {
76 self.function.invoke_with_args(args)
77 }
78}
79
80pub fn create_udf(function: FunctionRef) -> ScalarUDF {
82 ScalarUDF::new_from_impl(ScalarUdf { function })
83}
84
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_query::prelude::ScalarValue;
    use datafusion::arrow::array::BooleanArray;
    use datafusion_common::arrow::array::AsArray;
    use datafusion_common::arrow::datatypes::DataType as ArrowDataType;
    use datafusion_common::config::ConfigOptions;
    use datatypes::arrow::datatypes::Field;
    use datatypes::data_type::{ConcreteDataType, DataType};

    use super::*;
    use crate::function::Function;
    use crate::scalars::test::TestAndFunction;

    /// Builds an array-backed boolean column from `values`.
    fn bool_column(values: Vec<bool>) -> datafusion_expr::ColumnarValue {
        datafusion_expr::ColumnarValue::Array(Arc::new(BooleanArray::from(values)))
    }

    #[test]
    fn test_create_udf() {
        let f = Arc::new(TestAndFunction::default());

        // First call the function directly, bypassing the UDF adapter.
        let direct_args = ScalarFunctionArgs {
            args: vec![
                bool_column(vec![true, true, true]),
                bool_column(vec![true, false, true]),
            ],
            arg_fields: vec![],
            number_rows: 3,
            return_field: Arc::new(Field::new("x", ArrowDataType::Boolean, true)),
            config_options: Arc::new(ConfigOptions::default()),
        };

        let direct_value = f.invoke_with_args(direct_args).unwrap();
        let result = direct_value.to_array(3).unwrap();
        let vector = result.as_boolean();
        assert_eq!(3, vector.len());

        assert!(vector.value(0));
        assert!(!vector.value(1));
        assert!(vector.value(2));

        // Then go through the DataFusion UDF built by `create_udf`.
        let udf = create_udf(f);

        assert_eq!("test_and", udf.name());
        let arrow_return_type = udf.return_type(&[]).unwrap();
        assert_eq!(
            ConcreteDataType::boolean_datatype(),
            ConcreteDataType::from_arrow_type(&arrow_return_type)
        );

        // Mix a scalar argument with an array argument.
        let args = vec![
            datafusion_expr::ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))),
            bool_column(vec![true, false, false, true]),
        ];

        let arg_fields = vec![
            Arc::new(Field::new("a", args[0].data_type(), false)),
            Arc::new(Field::new("b", args[1].data_type(), false)),
        ];
        let return_field = Arc::new(Field::new(
            "x",
            ConcreteDataType::boolean_datatype().as_arrow_type(),
            false,
        ));
        let udf_args = ScalarFunctionArgs {
            args,
            arg_fields,
            number_rows: 4,
            return_field,
            config_options: Arc::new(ConfigOptions::default()),
        };

        let datafusion_expr::ColumnarValue::Array(output) = udf.invoke_with_args(udf_args).unwrap()
        else {
            unreachable!()
        };
        let booleans = output.as_any().downcast_ref::<BooleanArray>().unwrap();
        assert_eq!(booleans.len(), 4);
        assert_eq!(
            booleans.iter().flatten().collect::<Vec<bool>>(),
            vec![true, false, false, true]
        );
    }
}