common_function/scalars/
udf.rs1use std::any::Any;
16use std::fmt::{Debug, Formatter};
17use std::sync::Arc;
18
19use common_query::error::FromScalarValueSnafu;
20use common_query::prelude::ColumnarValue;
21use datafusion::logical_expr::{ScalarFunctionArgs, ScalarUDFImpl};
22use datafusion_expr::ScalarUDF;
23use datatypes::data_type::DataType;
24use datatypes::prelude::*;
25use datatypes::vectors::Helper;
26use session::context::QueryContextRef;
27use snafu::ResultExt;
28
29use crate::function::{FunctionContext, FunctionRef};
30use crate::state::FunctionState;
31
32struct ScalarUdf {
33 function: FunctionRef,
34 signature: datafusion_expr::Signature,
35 context: FunctionContext,
36}
37
38impl Debug for ScalarUdf {
39 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
40 f.debug_struct("ScalarUdf")
41 .field("function", &self.function.name())
42 .field("signature", &self.signature)
43 .finish()
44 }
45}
46
47impl ScalarUDFImpl for ScalarUdf {
48 fn as_any(&self) -> &dyn Any {
49 self
50 }
51
52 fn name(&self) -> &str {
53 self.function.name()
54 }
55
56 fn signature(&self) -> &datafusion_expr::Signature {
57 &self.signature
58 }
59
60 fn return_type(
61 &self,
62 arg_types: &[datatypes::arrow::datatypes::DataType],
63 ) -> datafusion_common::Result<datatypes::arrow::datatypes::DataType> {
64 let arg_types = arg_types
65 .iter()
66 .map(ConcreteDataType::from_arrow_type)
67 .collect::<Vec<_>>();
68 let t = self.function.return_type(&arg_types)?;
69 Ok(t.as_arrow_type())
70 }
71
72 fn invoke_with_args(
73 &self,
74 args: ScalarFunctionArgs,
75 ) -> datafusion_common::Result<datafusion_expr::ColumnarValue> {
76 let columns = args
77 .args
78 .iter()
79 .map(|x| {
80 ColumnarValue::try_from(x).and_then(|y| match y {
81 ColumnarValue::Vector(z) => Ok(z),
82 ColumnarValue::Scalar(z) => Helper::try_from_scalar_value(z, args.number_rows)
83 .context(FromScalarValueSnafu),
84 })
85 })
86 .collect::<common_query::error::Result<Vec<_>>>()?;
87 let v = self
88 .function
89 .eval(&self.context, &columns)
90 .map(ColumnarValue::Vector)?;
91 Ok(v.into())
92 }
93}
94
95pub fn create_udf(
97 func: FunctionRef,
98 query_ctx: QueryContextRef,
99 state: Arc<FunctionState>,
100) -> ScalarUDF {
101 let signature = func.signature().into();
102 let udf = ScalarUdf {
103 function: func,
104 signature,
105 context: FunctionContext { query_ctx, state },
106 };
107 ScalarUDF::new_from_impl(udf)
108}
109
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_query::prelude::ScalarValue;
    use datafusion::arrow::array::BooleanArray;
    use datatypes::data_type::ConcreteDataType;
    use datatypes::prelude::VectorRef;
    use datatypes::vectors::{BooleanVector, ConstantVector};
    use session::context::QueryContextBuilder;

    use super::*;
    use crate::function::Function;
    use crate::scalars::test::TestAndFunction;

    /// Checks `TestAndFunction` directly, then through the UDF built by
    /// `create_udf`: name, signature, return type and invocation result.
    #[test]
    fn test_create_udf() {
        let f = Arc::new(TestAndFunction);
        let query_ctx = QueryContextBuilder::default().build().into();

        // Evaluate the function directly: a constant `true` column of length 3
        // paired with a regular boolean column.
        let constant_true: VectorRef = Arc::new(ConstantVector::new(
            Arc::new(BooleanVector::from(vec![true])),
            3,
        ));
        let mixed: VectorRef = Arc::new(BooleanVector::from(vec![true, false, true]));
        let inputs: Vec<VectorRef> = vec![constant_true, mixed];

        let vector = f.eval(&FunctionContext::default(), &inputs).unwrap();
        assert_eq!(3, vector.len());

        for (i, expected) in [true, false, true].into_iter().enumerate() {
            assert!(matches!(vector.get(i), Value::Boolean(b) if b == expected));
        }

        // Wrap the same function as a DataFusion UDF and check its metadata.
        let udf = create_udf(f.clone(), query_ctx, Arc::new(FunctionState::default()));

        assert_eq!("test_and", udf.name());
        let expected_signature: datafusion_expr::Signature = f.signature().into();
        assert_eq!(udf.signature(), &expected_signature);

        let returned = udf
            .return_type(&[])
            .map(|x| ConcreteDataType::from_arrow_type(&x))
            .unwrap();
        assert_eq!(ConcreteDataType::boolean_datatype(), returned);

        // Invoke the UDF with one scalar and one array argument over four rows.
        let return_type = ConcreteDataType::boolean_datatype().as_arrow_type();
        let args = ScalarFunctionArgs {
            args: vec![
                datafusion_expr::ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))),
                datafusion_expr::ColumnarValue::Array(Arc::new(BooleanArray::from(vec![
                    true, false, false, true,
                ]))),
            ],
            number_rows: 4,
            return_type: &return_type,
        };

        let datafusion_expr::ColumnarValue::Array(array) = udf.invoke_with_args(args).unwrap()
        else {
            unreachable!()
        };
        let booleans = array.as_any().downcast_ref::<BooleanArray>().unwrap();
        assert_eq!(booleans.len(), 4);
        assert_eq!(
            booleans.iter().flatten().collect::<Vec<bool>>(),
            vec![true, false, false, true]
        );
    }
}