flow/expr/
df_func.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
//! Ports DataFusion scalar functions to our own scalar function representation so they can be evaluated in dataflow.
16
17use std::sync::Arc;
18
19use arrow::array::RecordBatchOptions;
20use bytes::BytesMut;
21use common_error::ext::BoxedError;
22use common_recordbatch::DfRecordBatch;
23use common_telemetry::debug;
24use datafusion_physical_expr::PhysicalExpr;
25use datatypes::data_type::DataType;
26use datatypes::value::Value;
27use datatypes::vectors::VectorRef;
28use prost::Message;
29use snafu::{IntoError, ResultExt};
30use substrait::error::{DecodeRelSnafu, EncodeRelSnafu};
31use substrait::substrait_proto_df::proto::expression::ScalarFunction;
32
33use crate::error::Error;
34use crate::expr::error::{
35    ArrowSnafu, DatafusionSnafu as EvalDatafusionSnafu, EvalError, ExternalSnafu,
36    InvalidArgumentSnafu,
37};
38use crate::expr::{Batch, ScalarExpr};
39use crate::repr::RelationDesc;
40use crate::transform::{from_scalar_fn_to_df_fn_impl, FunctionExtensions};
41
42/// A way to represent a scalar function that is implemented in Datafusion
43#[derive(Debug, Clone)]
44pub struct DfScalarFunction {
45    /// The raw bytes encoded datafusion scalar function
46    pub(crate) raw_fn: RawDfScalarFn,
47    // TODO(discord9): directly from datafusion expr
48    /// The implementation of the function
49    pub(crate) fn_impl: Arc<dyn PhysicalExpr>,
50    /// The input schema of the function
51    pub(crate) df_schema: Arc<datafusion_common::DFSchema>,
52}
53
54impl DfScalarFunction {
55    pub fn new(raw_fn: RawDfScalarFn, fn_impl: Arc<dyn PhysicalExpr>) -> Result<Self, Error> {
56        Ok(Self {
57            df_schema: Arc::new(raw_fn.input_schema.to_df_schema()?),
58            raw_fn,
59            fn_impl,
60        })
61    }
62
63    pub async fn try_from_raw_fn(raw_fn: RawDfScalarFn) -> Result<Self, Error> {
64        Ok(Self {
65            fn_impl: raw_fn.get_fn_impl().await?,
66            df_schema: Arc::new(raw_fn.input_schema.to_df_schema()?),
67            raw_fn,
68        })
69    }
70
71    /// Evaluate a batch of expressions using input values
72    pub fn eval_batch(&self, batch: &Batch, exprs: &[ScalarExpr]) -> Result<VectorRef, EvalError> {
73        let row_count = batch.row_count();
74        let batch: Vec<_> = exprs
75            .iter()
76            .map(|expr| expr.eval_batch(batch))
77            .collect::<Result<_, _>>()?;
78
79        let schema = self.df_schema.inner().clone();
80
81        let arrays = batch
82            .iter()
83            .map(|array| array.to_arrow_array())
84            .collect::<Vec<_>>();
85        let rb = DfRecordBatch::try_new_with_options(schema, arrays, &RecordBatchOptions::new().with_row_count(Some(row_count))).map_err(|err| {
86            ArrowSnafu {
87                context:
88                    "Failed to create RecordBatch from values when eval_batch datafusion scalar function",
89            }
90            .into_error(err)
91        })?;
92
93        let len = rb.num_rows();
94
95        let res = self.fn_impl.evaluate(&rb).context(EvalDatafusionSnafu {
96            context: "Failed to evaluate datafusion scalar function",
97        })?;
98        let res = common_query::columnar_value::ColumnarValue::try_from(&res)
99            .map_err(BoxedError::new)
100            .context(ExternalSnafu)?;
101        let res_vec = res
102            .try_into_vector(len)
103            .map_err(BoxedError::new)
104            .context(ExternalSnafu)?;
105
106        Ok(res_vec)
107    }
108
109    /// eval a list of expressions using input values
110    fn eval_args(values: &[Value], exprs: &[ScalarExpr]) -> Result<Vec<Value>, EvalError> {
111        exprs
112            .iter()
113            .map(|expr| expr.eval(values))
114            .collect::<Result<_, _>>()
115    }
116
117    // TODO(discord9): add RecordBatch support
118    pub fn eval(&self, values: &[Value], exprs: &[ScalarExpr]) -> Result<Value, EvalError> {
119        // first eval exprs to construct values to feed to datafusion
120        let values: Vec<_> = Self::eval_args(values, exprs)?;
121        if values.is_empty() {
122            return InvalidArgumentSnafu {
123                reason: "values is empty".to_string(),
124            }
125            .fail();
126        }
127        // TODO(discord9): make cols all array length of one
128        let mut cols = vec![];
129        for (idx, typ) in self
130            .raw_fn
131            .input_schema
132            .typ()
133            .column_types
134            .iter()
135            .enumerate()
136        {
137            let typ = typ.scalar_type();
138            let mut array = typ.create_mutable_vector(1);
139            array.push_value_ref(values[idx].as_value_ref());
140            cols.push(array.to_vector().to_arrow_array());
141        }
142        let schema = self.df_schema.inner().clone();
143        let rb = DfRecordBatch::try_new_with_options(
144            schema,
145            cols,
146            &RecordBatchOptions::new().with_row_count(Some(1)),
147        )
148        .map_err(|err| {
149            ArrowSnafu {
150                context:
151                    "Failed to create RecordBatch from values when eval datafusion scalar function",
152            }
153            .into_error(err)
154        })?;
155
156        let res = self.fn_impl.evaluate(&rb).context(EvalDatafusionSnafu {
157            context: "Failed to evaluate datafusion scalar function",
158        })?;
159        let res = common_query::columnar_value::ColumnarValue::try_from(&res)
160            .map_err(BoxedError::new)
161            .context(ExternalSnafu)?;
162        let res_vec = res
163            .try_into_vector(1)
164            .map_err(BoxedError::new)
165            .context(ExternalSnafu)?;
166        let res_val = res_vec
167            .try_get(0)
168            .map_err(BoxedError::new)
169            .context(ExternalSnafu)?;
170        Ok(res_val)
171    }
172}
173
174#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
175pub struct RawDfScalarFn {
176    /// The raw bytes encoded datafusion scalar function,
177    /// due to substrait have too many layers of nested struct and `ScalarFunction` 's derive is different
178    /// for simplicity's sake
179    /// so we store bytes instead of `ScalarFunction` here
180    /// but in unit test we will still compare decoded struct(using `f_decoded` field in Debug impl)
181    pub(crate) f: bytes::BytesMut,
182    /// The input schema of the function
183    pub(crate) input_schema: RelationDesc,
184    /// Extension contains mapping from function reference to function name
185    pub(crate) extensions: FunctionExtensions,
186}
187
188impl std::fmt::Debug for RawDfScalarFn {
189    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190        f.debug_struct("RawDfScalarFn")
191            .field("f", &self.f)
192            .field("f_decoded", &ScalarFunction::decode(&mut self.f.as_ref()))
193            .field("df_schema", &self.input_schema)
194            .field("extensions", &self.extensions)
195            .finish()
196    }
197}
198
199impl RawDfScalarFn {
200    pub fn from_proto(
201        f: &substrait::substrait_proto_df::proto::expression::ScalarFunction,
202        input_schema: RelationDesc,
203        extensions: FunctionExtensions,
204    ) -> Result<Self, Error> {
205        let mut buf = BytesMut::new();
206        f.encode(&mut buf)
207            .context(EncodeRelSnafu)
208            .map_err(BoxedError::new)
209            .context(crate::error::ExternalSnafu)?;
210        Ok(Self {
211            f: buf,
212            input_schema,
213            extensions,
214        })
215    }
216    async fn get_fn_impl(&self) -> Result<Arc<dyn PhysicalExpr>, Error> {
217        let f = ScalarFunction::decode(&mut self.f.as_ref())
218            .context(DecodeRelSnafu)
219            .map_err(BoxedError::new)
220            .context(crate::error::ExternalSnafu)?;
221        debug!("Decoded scalar function: {:?}", f);
222
223        let input_schema = &self.input_schema;
224        let extensions = &self.extensions;
225
226        from_scalar_fn_to_df_fn_impl(&f, input_schema, extensions).await
227    }
228}
229
230impl std::cmp::PartialEq for DfScalarFunction {
231    fn eq(&self, other: &Self) -> bool {
232        self.raw_fn.eq(&other.raw_fn)
233    }
234}
235
236// can't derive Eq because of Arc<dyn PhysicalExpr> not eq, so implement it manually
237impl std::cmp::Eq for DfScalarFunction {}
238
239impl std::cmp::PartialOrd for DfScalarFunction {
240    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
241        Some(self.cmp(other))
242    }
243}
244impl std::cmp::Ord for DfScalarFunction {
245    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
246        self.raw_fn.cmp(&other.raw_fn)
247    }
248}
249impl std::hash::Hash for DfScalarFunction {
250    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
251        self.raw_fn.hash(state);
252    }
253}
254
#[cfg(test)]
mod test {

    use datatypes::prelude::ConcreteDataType;
    use substrait::substrait_proto_df::proto::expression::literal::LiteralType;
    use substrait::substrait_proto_df::proto::expression::{Literal, RexType};
    use substrait::substrait_proto_df::proto::function_argument::ArgType;
    use substrait::substrait_proto_df::proto::{Expression, FunctionArgument};

    use super::*;
    use crate::repr::{ColumnType, RelationType};

    /// `abs(-1)` resolved through DataFusion evaluates to `1` when fed one null-typed column.
    #[tokio::test]
    async fn test_df_scalar_function() {
        // The only argument is the non-nullable i64 literal `-1`.
        let minus_one = Expression {
            rex_type: Some(RexType::Literal(Literal {
                nullable: false,
                type_variation_reference: 0,
                literal_type: Some(LiteralType::I64(-1)),
            })),
        };
        let raw_scalar_func = ScalarFunction {
            function_reference: 0,
            arguments: vec![FunctionArgument {
                arg_type: Some(ArgType::Value(minus_one)),
            }],
            output_type: None,
            ..Default::default()
        };

        // A single nullable column of the null data type.
        let null_column = ColumnType::new_nullable(ConcreteDataType::null_datatype());
        let input_schema = RelationDesc::try_new(
            RelationType::new(vec![null_column]),
            vec!["null_column".to_string()],
        )
        .unwrap();

        // Function reference 0 resolves to `abs`.
        let extensions = FunctionExtensions::from_iter(vec![(0, "abs")]);

        let raw_fn = RawDfScalarFn::from_proto(&raw_scalar_func, input_schema, extensions).unwrap();
        let df_func = DfScalarFunction::try_from_raw_fn(raw_fn).await.unwrap();

        let evaluated = df_func
            .eval(&[Value::Null], &[ScalarExpr::Column(0)])
            .unwrap();
        assert_eq!(evaluated, Value::Int64(1));
    }
}