1use std::sync::Arc;
18
19use arrow::array::RecordBatchOptions;
20use bytes::BytesMut;
21use common_error::ext::BoxedError;
22use common_recordbatch::DfRecordBatch;
23use common_telemetry::debug;
24use datafusion_physical_expr::PhysicalExpr;
25use datatypes::data_type::DataType;
26use datatypes::value::Value;
27use datatypes::vectors::VectorRef;
28use prost::Message;
29use snafu::{IntoError, ResultExt};
30use substrait::error::{DecodeRelSnafu, EncodeRelSnafu};
31use substrait::substrait_proto_df::proto::expression::ScalarFunction;
32
33use crate::error::Error;
34use crate::expr::error::{
35 ArrowSnafu, DatafusionSnafu as EvalDatafusionSnafu, EvalError, ExternalSnafu,
36 InvalidArgumentSnafu,
37};
38use crate::expr::{Batch, ScalarExpr};
39use crate::repr::RelationDesc;
40use crate::transform::{from_scalar_fn_to_df_fn_impl, FunctionExtensions};
41
42#[derive(Debug, Clone)]
44pub struct DfScalarFunction {
45 pub(crate) raw_fn: RawDfScalarFn,
47 pub(crate) fn_impl: Arc<dyn PhysicalExpr>,
50 pub(crate) df_schema: Arc<datafusion_common::DFSchema>,
52}
53
54impl DfScalarFunction {
55 pub fn new(raw_fn: RawDfScalarFn, fn_impl: Arc<dyn PhysicalExpr>) -> Result<Self, Error> {
56 Ok(Self {
57 df_schema: Arc::new(raw_fn.input_schema.to_df_schema()?),
58 raw_fn,
59 fn_impl,
60 })
61 }
62
63 pub async fn try_from_raw_fn(raw_fn: RawDfScalarFn) -> Result<Self, Error> {
64 Ok(Self {
65 fn_impl: raw_fn.get_fn_impl().await?,
66 df_schema: Arc::new(raw_fn.input_schema.to_df_schema()?),
67 raw_fn,
68 })
69 }
70
71 pub fn eval_batch(&self, batch: &Batch, exprs: &[ScalarExpr]) -> Result<VectorRef, EvalError> {
73 let row_count = batch.row_count();
74 let batch: Vec<_> = exprs
75 .iter()
76 .map(|expr| expr.eval_batch(batch))
77 .collect::<Result<_, _>>()?;
78
79 let schema = self.df_schema.inner().clone();
80
81 let arrays = batch
82 .iter()
83 .map(|array| array.to_arrow_array())
84 .collect::<Vec<_>>();
85 let rb = DfRecordBatch::try_new_with_options(schema, arrays, &RecordBatchOptions::new().with_row_count(Some(row_count))).map_err(|err| {
86 ArrowSnafu {
87 context:
88 "Failed to create RecordBatch from values when eval_batch datafusion scalar function",
89 }
90 .into_error(err)
91 })?;
92
93 let len = rb.num_rows();
94
95 let res = self.fn_impl.evaluate(&rb).context(EvalDatafusionSnafu {
96 context: "Failed to evaluate datafusion scalar function",
97 })?;
98 let res = common_query::columnar_value::ColumnarValue::try_from(&res)
99 .map_err(BoxedError::new)
100 .context(ExternalSnafu)?;
101 let res_vec = res
102 .try_into_vector(len)
103 .map_err(BoxedError::new)
104 .context(ExternalSnafu)?;
105
106 Ok(res_vec)
107 }
108
109 fn eval_args(values: &[Value], exprs: &[ScalarExpr]) -> Result<Vec<Value>, EvalError> {
111 exprs
112 .iter()
113 .map(|expr| expr.eval(values))
114 .collect::<Result<_, _>>()
115 }
116
117 pub fn eval(&self, values: &[Value], exprs: &[ScalarExpr]) -> Result<Value, EvalError> {
119 let values: Vec<_> = Self::eval_args(values, exprs)?;
121 if values.is_empty() {
122 return InvalidArgumentSnafu {
123 reason: "values is empty".to_string(),
124 }
125 .fail();
126 }
127 let mut cols = vec![];
129 for (idx, typ) in self
130 .raw_fn
131 .input_schema
132 .typ()
133 .column_types
134 .iter()
135 .enumerate()
136 {
137 let typ = typ.scalar_type();
138 let mut array = typ.create_mutable_vector(1);
139 array.push_value_ref(values[idx].as_value_ref());
140 cols.push(array.to_vector().to_arrow_array());
141 }
142 let schema = self.df_schema.inner().clone();
143 let rb = DfRecordBatch::try_new_with_options(
144 schema,
145 cols,
146 &RecordBatchOptions::new().with_row_count(Some(1)),
147 )
148 .map_err(|err| {
149 ArrowSnafu {
150 context:
151 "Failed to create RecordBatch from values when eval datafusion scalar function",
152 }
153 .into_error(err)
154 })?;
155
156 let res = self.fn_impl.evaluate(&rb).context(EvalDatafusionSnafu {
157 context: "Failed to evaluate datafusion scalar function",
158 })?;
159 let res = common_query::columnar_value::ColumnarValue::try_from(&res)
160 .map_err(BoxedError::new)
161 .context(ExternalSnafu)?;
162 let res_vec = res
163 .try_into_vector(1)
164 .map_err(BoxedError::new)
165 .context(ExternalSnafu)?;
166 let res_val = res_vec
167 .try_get(0)
168 .map_err(BoxedError::new)
169 .context(ExternalSnafu)?;
170 Ok(res_val)
171 }
172}
173
174#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
175pub struct RawDfScalarFn {
176 pub(crate) f: bytes::BytesMut,
182 pub(crate) input_schema: RelationDesc,
184 pub(crate) extensions: FunctionExtensions,
186}
187
188impl std::fmt::Debug for RawDfScalarFn {
189 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190 f.debug_struct("RawDfScalarFn")
191 .field("f", &self.f)
192 .field("f_decoded", &ScalarFunction::decode(&mut self.f.as_ref()))
193 .field("df_schema", &self.input_schema)
194 .field("extensions", &self.extensions)
195 .finish()
196 }
197}
198
199impl RawDfScalarFn {
200 pub fn from_proto(
201 f: &substrait::substrait_proto_df::proto::expression::ScalarFunction,
202 input_schema: RelationDesc,
203 extensions: FunctionExtensions,
204 ) -> Result<Self, Error> {
205 let mut buf = BytesMut::new();
206 f.encode(&mut buf)
207 .context(EncodeRelSnafu)
208 .map_err(BoxedError::new)
209 .context(crate::error::ExternalSnafu)?;
210 Ok(Self {
211 f: buf,
212 input_schema,
213 extensions,
214 })
215 }
216 async fn get_fn_impl(&self) -> Result<Arc<dyn PhysicalExpr>, Error> {
217 let f = ScalarFunction::decode(&mut self.f.as_ref())
218 .context(DecodeRelSnafu)
219 .map_err(BoxedError::new)
220 .context(crate::error::ExternalSnafu)?;
221 debug!("Decoded scalar function: {:?}", f);
222
223 let input_schema = &self.input_schema;
224 let extensions = &self.extensions;
225
226 from_scalar_fn_to_df_fn_impl(&f, input_schema, extensions).await
227 }
228}
229
230impl std::cmp::PartialEq for DfScalarFunction {
231 fn eq(&self, other: &Self) -> bool {
232 self.raw_fn.eq(&other.raw_fn)
233 }
234}
235
236impl std::cmp::Eq for DfScalarFunction {}
238
239impl std::cmp::PartialOrd for DfScalarFunction {
240 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
241 Some(self.cmp(other))
242 }
243}
244impl std::cmp::Ord for DfScalarFunction {
245 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
246 self.raw_fn.cmp(&other.raw_fn)
247 }
248}
249impl std::hash::Hash for DfScalarFunction {
250 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
251 self.raw_fn.hash(state);
252 }
253}
254
#[cfg(test)]
mod test {

    use datatypes::prelude::ConcreteDataType;
    use substrait::substrait_proto_df::proto::expression::literal::LiteralType;
    use substrait::substrait_proto_df::proto::expression::{Literal, RexType};
    use substrait::substrait_proto_df::proto::function_argument::ArgType;
    use substrait::substrait_proto_df::proto::{Expression, FunctionArgument};

    use super::*;
    use crate::repr::{ColumnType, RelationType};

    /// Evaluates `abs` applied to the literal `-1` against a single-row input
    /// made of one null column and expects `1`.
    #[tokio::test]
    async fn test_df_scalar_function() {
        // The only function argument: the i64 literal -1.
        let minus_one = Expression {
            rex_type: Some(RexType::Literal(Literal {
                nullable: false,
                type_variation_reference: 0,
                literal_type: Some(LiteralType::I64(-1)),
            })),
        };
        let raw_scalar_func = ScalarFunction {
            function_reference: 0,
            arguments: vec![FunctionArgument {
                arg_type: Some(ArgType::Value(minus_one)),
            }],
            output_type: None,
            ..Default::default()
        };

        let null_column = ColumnType::new_nullable(ConcreteDataType::null_datatype());
        let input_schema = RelationDesc::try_new(
            RelationType::new(vec![null_column]),
            vec!["null_column".to_string()],
        )
        .unwrap();

        // Function reference 0 (used above) is registered as "abs".
        let extensions = FunctionExtensions::from_iter(vec![(0, "abs")]);
        let raw_fn = RawDfScalarFn::from_proto(&raw_scalar_func, input_schema, extensions).unwrap();
        let df_func = DfScalarFunction::try_from_raw_fn(raw_fn).await.unwrap();

        let result = df_func
            .eval(&[Value::Null], &[ScalarExpr::Column(0)])
            .unwrap();
        assert_eq!(result, Value::Int64(1));
    }
}