common_function/scalars/
udf.rs1use std::any::Any;
16use std::fmt::{Debug, Formatter};
17use std::sync::Arc;
18
19use common_query::error::FromScalarValueSnafu;
20use common_query::prelude::ColumnarValue;
21use datafusion::logical_expr::{ScalarFunctionArgs, ScalarUDFImpl};
22use datafusion_expr::ScalarUDF;
23use datatypes::data_type::DataType;
24use datatypes::prelude::*;
25use datatypes::vectors::Helper;
26use session::context::QueryContextRef;
27use snafu::ResultExt;
28
29use crate::function::{FunctionContext, FunctionRef};
30use crate::state::FunctionState;
31
32struct ScalarUdf {
33 function: FunctionRef,
34 signature: datafusion_expr::Signature,
35 context: FunctionContext,
36}
37
38impl Debug for ScalarUdf {
39 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
40 f.debug_struct("ScalarUdf")
41 .field("function", &self.function.name())
42 .field("signature", &self.signature)
43 .finish()
44 }
45}
46
47impl ScalarUDFImpl for ScalarUdf {
48 fn as_any(&self) -> &dyn Any {
49 self
50 }
51
52 fn name(&self) -> &str {
53 self.function.name()
54 }
55
56 fn signature(&self) -> &datafusion_expr::Signature {
57 &self.signature
58 }
59
60 fn return_type(
61 &self,
62 arg_types: &[datatypes::arrow::datatypes::DataType],
63 ) -> datafusion_common::Result<datatypes::arrow::datatypes::DataType> {
64 let arg_types = arg_types
65 .iter()
66 .map(ConcreteDataType::from_arrow_type)
67 .collect::<Vec<_>>();
68 let t = self.function.return_type(&arg_types)?;
69 Ok(t.as_arrow_type())
70 }
71
72 fn invoke_with_args(
73 &self,
74 args: ScalarFunctionArgs,
75 ) -> datafusion_common::Result<datafusion_expr::ColumnarValue> {
76 let columns = args
77 .args
78 .iter()
79 .map(|x| {
80 ColumnarValue::try_from(x).and_then(|y| match y {
81 ColumnarValue::Vector(z) => Ok(z),
82 ColumnarValue::Scalar(z) => Helper::try_from_scalar_value(z, args.number_rows)
83 .context(FromScalarValueSnafu),
84 })
85 })
86 .collect::<common_query::error::Result<Vec<_>>>()?;
87 let v = self
88 .function
89 .eval(&self.context, &columns)
90 .map(ColumnarValue::Vector)?;
91 Ok(v.into())
92 }
93}
94
95pub fn create_udf(
97 func: FunctionRef,
98 query_ctx: QueryContextRef,
99 state: Arc<FunctionState>,
100) -> ScalarUDF {
101 let signature = func.signature().into();
102 let udf = ScalarUdf {
103 function: func,
104 signature,
105 context: FunctionContext { query_ctx, state },
106 };
107 ScalarUDF::new_from_impl(udf)
108}
109
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_query::prelude::ScalarValue;
    use datafusion::arrow::array::BooleanArray;
    use datafusion_common::config::ConfigOptions;
    use datatypes::arrow::datatypes::Field;
    use datatypes::data_type::ConcreteDataType;
    use datatypes::prelude::VectorRef;
    use datatypes::vectors::{BooleanVector, ConstantVector};
    use session::context::QueryContextBuilder;

    use super::*;
    use crate::function::Function;
    use crate::scalars::test::TestAndFunction;

    /// Checks `TestAndFunction` directly first, then the same function
    /// wrapped by `create_udf`: name, signature, return type and invocation.
    #[test]
    fn test_create_udf() {
        let f = Arc::new(TestAndFunction);
        let query_ctx = QueryContextBuilder::default().build().into();

        // Direct evaluation: a constant `true` of length 3 and a plain
        // boolean vector.
        let constant_true = ConstantVector::new(Arc::new(BooleanVector::from(vec![true])), 3);
        let args: Vec<VectorRef> = vec![
            Arc::new(constant_true),
            Arc::new(BooleanVector::from(vec![true, false, true])),
        ];

        let vector = f.eval(&FunctionContext::default(), &args).unwrap();
        assert_eq!(3, vector.len());

        let expected = [true, false, true];
        for (i, want) in expected.iter().enumerate() {
            assert!(matches!(vector.get(i), Value::Boolean(b) if b == *want));
        }

        // The same function, now behind the DataFusion UDF wrapper.
        let udf = create_udf(f.clone(), query_ctx, Arc::new(FunctionState::default()));

        assert_eq!("test_and", udf.name());
        let expected_signature: datafusion_expr::Signature = f.signature().into();
        assert_eq!(udf.signature(), &expected_signature);

        let return_type = udf.return_type(&[]).unwrap();
        assert_eq!(
            ConcreteDataType::boolean_datatype(),
            ConcreteDataType::from_arrow_type(&return_type)
        );

        // One scalar and one array argument, four rows.
        let array_arg = BooleanArray::from(vec![true, false, false, true]);
        let args = vec![
            datafusion_expr::ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))),
            datafusion_expr::ColumnarValue::Array(Arc::new(array_arg)),
        ];

        let arg_fields = vec![
            Arc::new(Field::new("a", args[0].data_type(), false)),
            Arc::new(Field::new("b", args[1].data_type(), false)),
        ];
        let return_field = Arc::new(Field::new(
            "x",
            ConcreteDataType::boolean_datatype().as_arrow_type(),
            false,
        ));
        let args = ScalarFunctionArgs {
            args,
            arg_fields,
            number_rows: 4,
            return_field,
            config_options: Arc::new(ConfigOptions::default()),
        };

        let datafusion_expr::ColumnarValue::Array(output) = udf.invoke_with_args(args).unwrap()
        else {
            unreachable!()
        };
        let output = output.as_any().downcast_ref::<BooleanArray>().unwrap();
        assert_eq!(output.len(), 4);
        assert_eq!(
            output.iter().flatten().collect::<Vec<bool>>(),
            vec![true, false, false, true]
        );
    }
}