flow/transform/
expr.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#![warn(unused_imports)]
16
17use std::sync::Arc;
18
19use common_error::ext::BoxedError;
20use common_telemetry::debug;
21use datafusion::execution::SessionStateBuilder;
22use datafusion::functions::all_default_functions;
23use datafusion_physical_expr::PhysicalExpr;
24use datafusion_substrait::logical_plan::consumer::DefaultSubstraitConsumer;
25use datatypes::data_type::ConcreteDataType as CDT;
26use snafu::{ensure, OptionExt, ResultExt};
27use substrait_proto::proto::expression::field_reference::ReferenceType::DirectReference;
28use substrait_proto::proto::expression::reference_segment::ReferenceType::StructField;
29use substrait_proto::proto::expression::{IfThen, RexType, ScalarFunction};
30use substrait_proto::proto::function_argument::ArgType;
31use substrait_proto::proto::Expression;
32
33use crate::error::{
34    DatafusionSnafu, DatatypesSnafu, Error, EvalSnafu, ExternalSnafu, InvalidQuerySnafu,
35    NotImplementedSnafu, PlanSnafu, UnexpectedSnafu,
36};
37use crate::expr::{
38    BinaryFunc, DfScalarFunction, RawDfScalarFn, ScalarExpr, TypedExpr, UnaryFunc,
39    UnmaterializableFunc, VariadicFunc, TUMBLE_END, TUMBLE_START,
40};
41use crate::repr::{ColumnType, RelationDesc, RelationType};
42use crate::transform::literal::{
43    from_substrait_literal, from_substrait_type, to_substrait_literal,
44};
45use crate::transform::{substrait_proto, FunctionExtensions};
46
47// TODO(discord9): refactor plan to substrait convert of `arrow_cast` function thus remove this function
48/// ref to `arrow_schema::datatype` for type name
49fn typename_to_cdt(name: &str) -> Result<CDT, Error> {
50    let ret = match name {
51        "Int8" => CDT::int8_datatype(),
52        "Int16" => CDT::int16_datatype(),
53        "Int32" => CDT::int32_datatype(),
54        "Int64" => CDT::int64_datatype(),
55        "UInt8" => CDT::uint8_datatype(),
56        "UInt16" => CDT::uint16_datatype(),
57        "UInt32" => CDT::uint32_datatype(),
58        "UInt64" => CDT::uint64_datatype(),
59        "Float32" => CDT::float32_datatype(),
60        "Float64" => CDT::float64_datatype(),
61        "Boolean" => CDT::boolean_datatype(),
62        "String" => CDT::string_datatype(),
63        "Date" | "Date32" | "Date64" => CDT::date_datatype(),
64        "Timestamp" => CDT::timestamp_second_datatype(),
65        "Timestamp(Second, None)" => CDT::timestamp_second_datatype(),
66        "Timestamp(Millisecond, None)" => CDT::timestamp_millisecond_datatype(),
67        "Timestamp(Microsecond, None)" => CDT::timestamp_microsecond_datatype(),
68        "Timestamp(Nanosecond, None)" => CDT::timestamp_nanosecond_datatype(),
69        "Time32(Second)" | "Time64(Second)" => CDT::time_second_datatype(),
70        "Time32(Millisecond)" | "Time64(Millisecond)" => CDT::time_millisecond_datatype(),
71        "Time32(Microsecond)" | "Time64(Microsecond)" => CDT::time_microsecond_datatype(),
72        "Time32(Nanosecond)" | "Time64(Nanosecond)" => CDT::time_nanosecond_datatype(),
73        _ => NotImplementedSnafu {
74            reason: format!("Unrecognized typename: {}", name),
75        }
76        .fail()?,
77    };
78    Ok(ret)
79}
80
81/// Convert [`ScalarFunction`] to corresponding Datafusion's [`PhysicalExpr`]
82pub(crate) async fn from_scalar_fn_to_df_fn_impl(
83    f: &ScalarFunction,
84    input_schema: &RelationDesc,
85    extensions: &FunctionExtensions,
86) -> Result<Arc<dyn PhysicalExpr>, Error> {
87    let e = Expression {
88        rex_type: Some(RexType::ScalarFunction(f.clone())),
89    };
90    let schema = input_schema.to_df_schema()?;
91
92    let extensions = extensions.to_extensions();
93    let session_state = SessionStateBuilder::new()
94        .with_scalar_functions(all_default_functions())
95        .build();
96    let consumer = DefaultSubstraitConsumer::new(&extensions, &session_state);
97    let df_expr =
98        substrait::df_logical_plan::consumer::from_substrait_rex(&consumer, &e, &schema).await;
99    let expr = df_expr.context({
100        DatafusionSnafu {
101            context: "Failed to convert substrait scalar function to datafusion scalar function",
102        }
103    })?;
104    let phy_expr =
105        datafusion::physical_expr::create_physical_expr(&expr, &schema, &Default::default())
106            .context(DatafusionSnafu {
107                context: "Failed to create physical expression from logical expression",
108            })?;
109    Ok(phy_expr)
110}
111
112/// Return an [`Expression`](wrapped in a [`FunctionArgument`]) that references the i-th column of the input relation
113pub(crate) fn proto_col(i: usize) -> substrait_proto::proto::FunctionArgument {
114    use substrait_proto::proto::expression;
115    let expr = Expression {
116        rex_type: Some(expression::RexType::Selection(Box::new(
117            expression::FieldReference {
118                reference_type: Some(expression::field_reference::ReferenceType::DirectReference(
119                    expression::ReferenceSegment {
120                        reference_type: Some(
121                            expression::reference_segment::ReferenceType::StructField(Box::new(
122                                expression::reference_segment::StructField {
123                                    field: i as i32,
124                                    child: None,
125                                },
126                            )),
127                        ),
128                    },
129                )),
130                root_type: None,
131            },
132        ))),
133    };
134    substrait_proto::proto::FunctionArgument {
135        arg_type: Some(substrait_proto::proto::function_argument::ArgType::Value(
136            expr,
137        )),
138    }
139}
140
141fn is_proto_literal(arg: &substrait_proto::proto::FunctionArgument) -> bool {
142    use substrait_proto::proto::expression;
143    matches!(
144        arg.arg_type.as_ref().unwrap(),
145        ArgType::Value(Expression {
146            rex_type: Some(expression::RexType::Literal(_)),
147        })
148    )
149}
150
151fn build_proto_lit(
152    lit: substrait_proto::proto::expression::Literal,
153) -> substrait_proto::proto::FunctionArgument {
154    use substrait_proto::proto;
155    proto::FunctionArgument {
156        arg_type: Some(ArgType::Value(Expression {
157            rex_type: Some(proto::expression::RexType::Literal(lit)),
158        })),
159    }
160}
161
162/// rewrite ScalarFunction's arguments to Columns 0..n so nested exprs are still handled by us instead of datafusion
163///
164/// specially, if a argument is a literal, the replacement will not happen
165fn rewrite_scalar_function(
166    f: &ScalarFunction,
167    arg_typed_exprs: &[TypedExpr],
168) -> Result<ScalarFunction, Error> {
169    let mut f_rewrite = f.clone();
170    ensure!(
171        f_rewrite.arguments.len() == arg_typed_exprs.len(),
172        crate::error::InternalSnafu {
173            reason: format!(
174                "Expect `f_rewrite` and `arg_typed_expr` to be same length, found {} and {}",
175                f_rewrite.arguments.len(),
176                arg_typed_exprs.len()
177            )
178        }
179    );
180    for (idx, raw_expr) in f_rewrite.arguments.iter_mut().enumerate() {
181        // only replace it with col(idx) if it is not literal
182        // will try best to determine if it is literal, i.e. for function like `cast(<literal>)` will try
183        // in both world to understand if it results in a literal
184        match (
185            is_proto_literal(raw_expr),
186            arg_typed_exprs[idx].expr.is_literal(),
187        ) {
188            (false, false) => *raw_expr = proto_col(idx),
189            (true, _) => (),
190            (false, true) => {
191                if let ScalarExpr::Literal(val, ty) = &arg_typed_exprs[idx].expr {
192                    let df_val = val
193                        .try_to_scalar_value(ty)
194                        .map_err(BoxedError::new)
195                        .context(ExternalSnafu)?;
196                    let lit_sub = to_substrait_literal(&df_val)?;
197                    // put const-folded literal back to df to simplify stuff
198                    *raw_expr = build_proto_lit(lit_sub);
199                } else {
200                    UnexpectedSnafu {
201                        reason: format!(
202                            "Expect value to be literal, but found {:?}",
203                            arg_typed_exprs[idx].expr
204                        ),
205                    }
206                    .fail()?
207                }
208            }
209        }
210    }
211    Ok(f_rewrite)
212}
213
214impl TypedExpr {
215    pub async fn from_substrait_to_datafusion_scalar_func(
216        f: &ScalarFunction,
217        arg_typed_exprs: Vec<TypedExpr>,
218        extensions: &FunctionExtensions,
219    ) -> Result<TypedExpr, Error> {
220        let (arg_exprs, arg_types): (Vec<_>, Vec<_>) = arg_typed_exprs
221            .clone()
222            .into_iter()
223            .map(|e| (e.expr, e.typ))
224            .unzip();
225        debug!("Before rewrite: {:?}", f);
226        let f_rewrite = rewrite_scalar_function(f, &arg_typed_exprs)?;
227        debug!("After rewrite: {:?}", f_rewrite);
228        let input_schema = RelationType::new(arg_types).into_unnamed();
229        let raw_fn =
230            RawDfScalarFn::from_proto(&f_rewrite, input_schema.clone(), extensions.clone())?;
231
232        let df_func = DfScalarFunction::try_from_raw_fn(raw_fn).await?;
233        let expr = ScalarExpr::CallDf {
234            df_scalar_fn: df_func,
235            exprs: arg_exprs,
236        };
237        // df already know it's own schema, so not providing here
238        let ret_type = expr.typ(&[])?;
239        Ok(TypedExpr::new(expr, ret_type))
240    }
241
242    /// Convert ScalarFunction into Flow's ScalarExpr
243    pub async fn from_substrait_scalar_func(
244        f: &ScalarFunction,
245        input_schema: &RelationDesc,
246        extensions: &FunctionExtensions,
247    ) -> Result<TypedExpr, Error> {
248        let fn_name =
249            extensions
250                .get(&f.function_reference)
251                .with_context(|| NotImplementedSnafu {
252                    reason: format!(
253                        "Aggregated function not found: function reference = {:?}",
254                        f.function_reference
255                    ),
256                })?;
257        let arg_len = f.arguments.len();
258        let arg_typed_exprs: Vec<TypedExpr> = {
259            let mut rets = Vec::new();
260            for arg in f.arguments.iter() {
261                let ret = match &arg.arg_type {
262                    Some(ArgType::Value(e)) => {
263                        TypedExpr::from_substrait_rex(e, input_schema, extensions).await
264                    }
265                    _ => not_impl_err!("Aggregated function argument non-Value type not supported"),
266                }?;
267                rets.push(ret);
268            }
269            rets
270        };
271
272        // literal's type is determined by the function and type of other args
273        let (arg_exprs, arg_types): (Vec<_>, Vec<_>) = arg_typed_exprs
274            .clone()
275            .into_iter()
276            .map(
277                |TypedExpr {
278                     expr: arg_val,
279                     typ: arg_type,
280                 }| {
281                    if arg_val.is_literal() {
282                        (arg_val, None)
283                    } else {
284                        (arg_val, Some(arg_type.scalar_type))
285                    }
286                },
287            )
288            .unzip();
289
290        match arg_len {
291            1 if UnaryFunc::is_valid_func_name(fn_name) => {
292                let func = UnaryFunc::from_str_and_type(fn_name, None)?;
293                let arg = arg_exprs[0].clone();
294                let ret_type = ColumnType::new_nullable(func.signature().output.clone());
295
296                Ok(TypedExpr::new(arg.call_unary(func), ret_type))
297            }
298            2 if fn_name == "arrow_cast" => {
299                let cast_to = arg_exprs[1]
300                    .clone()
301                    .as_literal()
302                    .and_then(|lit| lit.as_string())
303                    .with_context(|| InvalidQuerySnafu {
304                        reason: "array_cast's second argument must be a literal string",
305                    })?;
306                let cast_to = typename_to_cdt(&cast_to)?;
307                let func = UnaryFunc::Cast(cast_to.clone());
308                let arg = arg_exprs[0].clone();
309                // constant folding here since some datafusion function require it for constant arg(i.e. `DATE_BIN`)
310                if arg.is_literal() {
311                    let res = func.eval(&[], &arg).context(EvalSnafu)?;
312                    Ok(TypedExpr::new(
313                        ScalarExpr::Literal(res, cast_to.clone()),
314                        ColumnType::new_nullable(cast_to),
315                    ))
316                } else {
317                    let ret_type = ColumnType::new_nullable(func.signature().output.clone());
318
319                    Ok(TypedExpr::new(arg.call_unary(func), ret_type))
320                }
321            }
322            2 if BinaryFunc::is_valid_func_name(fn_name) => {
323                let (func, signature) =
324                    BinaryFunc::from_str_expr_and_type(fn_name, &arg_exprs, &arg_types[0..2])?;
325
326                // constant folding here
327                let is_all_literal = arg_exprs.iter().all(|arg| arg.is_literal());
328                if is_all_literal {
329                    let res = func
330                        .eval(&[], &arg_exprs[0], &arg_exprs[1])
331                        .context(EvalSnafu)?;
332
333                    // if output type is null, it should be inferred from the input types
334                    let con_typ = signature.output.clone();
335                    let typ = ColumnType::new_nullable(con_typ.clone());
336                    return Ok(TypedExpr::new(ScalarExpr::Literal(res, con_typ), typ));
337                }
338
339                let mut arg_exprs = arg_exprs;
340                for (idx, arg_expr) in arg_exprs.iter_mut().enumerate() {
341                    if let ScalarExpr::Literal(val, typ) = arg_expr {
342                        let dest_type = signature.input[idx].clone();
343
344                        // cast val to target_type
345                        let dest_val = if !dest_type.is_null() {
346                            datatypes::types::cast(val.clone(), &dest_type)
347                        .with_context(|_|
348                            DatatypesSnafu{
349                                extra: format!("Failed to implicitly cast literal {val:?} to type {dest_type:?}")
350                            })?
351                        } else {
352                            val.clone()
353                        };
354                        *val = dest_val;
355                        *typ = dest_type;
356                    }
357                }
358
359                let ret_type = ColumnType::new_nullable(func.signature().output.clone());
360                let ret_expr = arg_exprs[0].clone().call_binary(arg_exprs[1].clone(), func);
361                Ok(TypedExpr::new(ret_expr, ret_type))
362            }
363            _var => {
364                if fn_name == TUMBLE_START || fn_name == TUMBLE_END {
365                    let (func, arg) = UnaryFunc::from_tumble_func(fn_name, &arg_typed_exprs)?;
366
367                    let ret_type = ColumnType::new_nullable(func.signature().output.clone());
368
369                    Ok(TypedExpr::new(arg.expr.call_unary(func), ret_type))
370                } else if VariadicFunc::is_valid_func_name(fn_name) {
371                    let func = VariadicFunc::from_str_and_types(fn_name, &arg_types)?;
372                    let ret_type = ColumnType::new_nullable(func.signature().output.clone());
373                    let mut expr = ScalarExpr::CallVariadic {
374                        func,
375                        exprs: arg_exprs,
376                    };
377                    expr.optimize();
378                    Ok(TypedExpr::new(expr, ret_type))
379                } else if UnmaterializableFunc::is_valid_func_name(fn_name) {
380                    let func = UnmaterializableFunc::from_str_args(fn_name, arg_typed_exprs)?;
381                    let ret_type = ColumnType::new_nullable(func.signature().output.clone());
382                    Ok(TypedExpr::new(
383                        ScalarExpr::CallUnmaterializable(func),
384                        ret_type,
385                    ))
386                } else {
387                    let try_as_df = Self::from_substrait_to_datafusion_scalar_func(
388                        f,
389                        arg_typed_exprs,
390                        extensions,
391                    )
392                    .await?;
393                    Ok(try_as_df)
394                }
395            }
396        }
397    }
398
399    /// Convert IfThen into Flow's ScalarExpr
400    pub async fn from_substrait_ifthen_rex(
401        if_then: &IfThen,
402        input_schema: &RelationDesc,
403        extensions: &FunctionExtensions,
404    ) -> Result<TypedExpr, Error> {
405        let ifs: Vec<_> = {
406            let mut ifs = Vec::new();
407            for if_clause in if_then.ifs.iter() {
408                let proto_if = if_clause.r#if.as_ref().with_context(|| InvalidQuerySnafu {
409                    reason: "IfThen clause without if",
410                })?;
411                let proto_then = if_clause.then.as_ref().with_context(|| InvalidQuerySnafu {
412                    reason: "IfThen clause without then",
413                })?;
414                let cond =
415                    TypedExpr::from_substrait_rex(proto_if, input_schema, extensions).await?;
416                let then =
417                    TypedExpr::from_substrait_rex(proto_then, input_schema, extensions).await?;
418                ifs.push((cond, then));
419            }
420            ifs
421        };
422        // if no else is presented
423        let els = match if_then
424            .r#else
425            .as_ref()
426            .map(|e| TypedExpr::from_substrait_rex(e, input_schema, extensions))
427        {
428            Some(fut) => Some(fut.await),
429            None => None,
430        }
431        .transpose()?
432        .unwrap_or_else(|| {
433            TypedExpr::new(
434                ScalarExpr::literal_null(),
435                ColumnType::new_nullable(CDT::null_datatype()),
436            )
437        });
438
439        fn build_if_then_recur(
440            mut next_if_then: impl Iterator<Item = (TypedExpr, TypedExpr)>,
441            els: TypedExpr,
442        ) -> TypedExpr {
443            if let Some((cond, then)) = next_if_then.next() {
444                // always assume the type of `if`` expr is the same with the `then`` expr
445                TypedExpr::new(
446                    ScalarExpr::If {
447                        cond: Box::new(cond.expr),
448                        then: Box::new(then.expr),
449                        els: Box::new(build_if_then_recur(next_if_then, els).expr),
450                    },
451                    then.typ,
452                )
453            } else {
454                els
455            }
456        }
457        let expr_if = build_if_then_recur(ifs.into_iter(), els);
458        Ok(expr_if)
459    }
460    /// Convert Substrait Rex into Flow's ScalarExpr
461    #[async_recursion::async_recursion]
462    pub async fn from_substrait_rex(
463        e: &Expression,
464        input_schema: &RelationDesc,
465        extensions: &FunctionExtensions,
466    ) -> Result<TypedExpr, Error> {
467        match &e.rex_type {
468            Some(RexType::Literal(lit)) => {
469                let lit = from_substrait_literal(lit)?;
470                Ok(TypedExpr::new(
471                    ScalarExpr::Literal(lit.0, lit.1.clone()),
472                    ColumnType::new_nullable(lit.1),
473                ))
474            }
475            Some(RexType::SingularOrList(s)) => {
476                let substrait_expr = s.value.as_ref().with_context(|| InvalidQuerySnafu {
477                    reason: "SingularOrList expression without value",
478                })?;
479                let typed_expr =
480                    TypedExpr::from_substrait_rex(substrait_expr, input_schema, extensions).await?;
481                // Note that we didn't impl support to in list expr
482                if !s.options.is_empty() {
483                    let mut list = Vec::with_capacity(s.options.len());
484                    for opt in s.options.iter() {
485                        let opt_expr =
486                            TypedExpr::from_substrait_rex(opt, input_schema, extensions).await?;
487                        list.push(opt_expr.expr);
488                    }
489                    let in_list_expr = ScalarExpr::InList {
490                        expr: Box::new(typed_expr.expr),
491                        list,
492                    };
493                    Ok(TypedExpr::new(
494                        in_list_expr,
495                        ColumnType::new_nullable(CDT::boolean_datatype()),
496                    ))
497                } else {
498                    Ok(typed_expr)
499                }
500            }
501            Some(RexType::Selection(field_ref)) => match &field_ref.reference_type {
502                Some(DirectReference(direct)) => match &direct.reference_type.as_ref() {
503                    Some(StructField(x)) => match &x.child.as_ref() {
504                        Some(_) => {
505                            not_impl_err!(
506                                "Direct reference StructField with child is not supported"
507                            )
508                        }
509                        None => {
510                            let column = x.field as usize;
511                            let column_type = input_schema.typ().column_types[column].clone();
512                            Ok(TypedExpr::new(ScalarExpr::Column(column), column_type))
513                        }
514                    },
515                    _ => not_impl_err!(
516                        "Direct reference with types other than StructField is not supported"
517                    ),
518                },
519                _ => not_impl_err!("unsupported field ref type"),
520            },
521            Some(RexType::ScalarFunction(f)) => {
522                TypedExpr::from_substrait_scalar_func(f, input_schema, extensions).await
523            }
524            Some(RexType::IfThen(if_then)) => {
525                TypedExpr::from_substrait_ifthen_rex(if_then, input_schema, extensions).await
526            }
527            Some(RexType::Cast(cast)) => {
528                let input = cast.input.as_ref().with_context(|| InvalidQuerySnafu {
529                    reason: "Cast expression without input",
530                })?;
531                let input = TypedExpr::from_substrait_rex(input, input_schema, extensions).await?;
532                let cast_type = from_substrait_type(cast.r#type.as_ref().with_context(|| {
533                    InvalidQuerySnafu {
534                        reason: "Cast expression without type",
535                    }
536                })?)?;
537                let func = UnaryFunc::from_str_and_type("cast", Some(cast_type.clone()))?;
538                Ok(TypedExpr::new(
539                    input.expr.call_unary(func),
540                    ColumnType::new_nullable(cast_type),
541                ))
542            }
543            Some(RexType::WindowFunction(_)) => PlanSnafu {
544                reason:
545                    "Window function is not supported yet. Please use aggregation function instead."
546                        .to_string(),
547            }
548            .fail(),
549            _ => not_impl_err!("unsupported rex_type"),
550        }
551    }
552}
553
554#[cfg(test)]
555mod test {
556    use datatypes::prelude::ConcreteDataType;
557    use datatypes::value::Value;
558    use pretty_assertions::assert_eq;
559
560    use super::*;
561    use crate::expr::{GlobalId, MapFilterProject};
562    use crate::plan::{Plan, TypedPlan};
563    use crate::repr::{self, ColumnType, RelationType};
564    use crate::transform::test::{create_test_ctx, create_test_query_engine, sql_to_substrait};
565
566    /// test if `WHERE` condition can be converted to Flow's ScalarExpr in mfp's filter
567    #[tokio::test]
568    async fn test_where_and() {
569        let engine = create_test_query_engine();
570        let sql =
571            "SELECT number FROM numbers_with_ts WHERE number >= 1 AND number <= 3 AND number!=2";
572        let plan = sql_to_substrait(engine.clone(), sql).await;
573
574        let mut ctx = create_test_ctx();
575        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;
576
577        // optimize binary and to variadic and
578        let filter = ScalarExpr::CallVariadic {
579            func: VariadicFunc::And,
580            exprs: vec![
581                ScalarExpr::Column(2).call_binary(
582                    ScalarExpr::Literal(Value::from(1i64), CDT::int64_datatype()),
583                    BinaryFunc::Gte,
584                ),
585                ScalarExpr::Column(2).call_binary(
586                    ScalarExpr::Literal(Value::from(3i64), CDT::int64_datatype()),
587                    BinaryFunc::Lte,
588                ),
589                ScalarExpr::Column(2).call_binary(
590                    ScalarExpr::Literal(Value::from(2i64), CDT::int64_datatype()),
591                    BinaryFunc::NotEq,
592                ),
593            ],
594        };
595        let expected = TypedPlan {
596            schema: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), false)])
597                .into_named(vec![Some("number".to_string())]),
598            plan: Plan::Mfp {
599                input: Box::new(
600                    Plan::Get {
601                        id: crate::expr::Id::Global(GlobalId::User(1)),
602                    }
603                    .with_types(
604                        RelationType::new(vec![
605                            ColumnType::new(ConcreteDataType::uint32_datatype(), false),
606                            ColumnType::new(
607                                ConcreteDataType::timestamp_millisecond_datatype(),
608                                false,
609                            ),
610                        ])
611                        .into_named(vec![Some("number".to_string()), Some("ts".to_string())]),
612                    ),
613                ),
614                mfp: MapFilterProject::new(2)
615                    .map(vec![
616                        ScalarExpr::CallUnary {
617                            func: UnaryFunc::Cast(CDT::int64_datatype()),
618                            expr: Box::new(ScalarExpr::Column(0)),
619                        },
620                        ScalarExpr::Column(0),
621                        ScalarExpr::Column(3),
622                    ])
623                    .unwrap()
624                    .filter(vec![filter])
625                    .unwrap()
626                    .project(vec![4])
627                    .unwrap(),
628            },
629        };
630        assert_eq!(flow_plan.unwrap(), expected);
631    }
632
633    /// case: binary functions&constant folding can happen in converting substrait plan
634    #[tokio::test]
635    async fn test_binary_func_and_constant_folding() {
636        let engine = create_test_query_engine();
637        let sql = "SELECT 1+1*2-1/1+1%2==3 FROM numbers";
638        let plan = sql_to_substrait(engine.clone(), sql).await;
639
640        let mut ctx = create_test_ctx();
641        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;
642
643        let expected = TypedPlan {
644            schema: RelationType::new(vec![ColumnType::new(CDT::boolean_datatype(), true)])
645                .into_named(vec![Some("Int64(1) + Int64(1) * Int64(2) - Int64(1) / Int64(1) + Int64(1) % Int64(2) = Int64(3)".to_string())]),
646            plan: Plan::Constant {
647                rows: vec![(
648                    repr::Row::new(vec![Value::from(true)]),
649                    repr::Timestamp::MIN,
650                    1,
651                )],
652            },
653        };
654
655        assert_eq!(flow_plan.unwrap(), expected);
656    }
657
658    /// test if the type of the literal is correctly inferred, i.e. in here literal is decoded to be int64, but need to be uint32,
659    #[tokio::test]
660    async fn test_implicitly_cast() {
661        let engine = create_test_query_engine();
662        let sql = "SELECT number+1 FROM numbers";
663        let plan = sql_to_substrait(engine.clone(), sql).await;
664
665        let mut ctx = create_test_ctx();
666        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;
667
668        let expected = TypedPlan {
669            schema: RelationType::new(vec![ColumnType::new(CDT::int64_datatype(), true)])
670                .into_named(vec![Some("numbers.number + Int64(1)".to_string())]),
671            plan: Plan::Mfp {
672                input: Box::new(
673                    Plan::Get {
674                        id: crate::expr::Id::Global(GlobalId::User(0)),
675                    }
676                    .with_types(
677                        RelationType::new(vec![ColumnType::new(
678                            ConcreteDataType::uint32_datatype(),
679                            false,
680                        )])
681                        .into_named(vec![Some("number".to_string())]),
682                    ),
683                ),
684                mfp: MapFilterProject::new(1)
685                    .map(vec![ScalarExpr::Column(0)
686                        .call_unary(UnaryFunc::Cast(CDT::int64_datatype()))
687                        .call_binary(
688                            ScalarExpr::Literal(Value::from(1i64), CDT::int64_datatype()),
689                            BinaryFunc::AddInt64,
690                        )])
691                    .unwrap()
692                    .project(vec![1])
693                    .unwrap(),
694            },
695        };
696        assert_eq!(flow_plan.unwrap(), expected);
697    }
698
699    #[tokio::test]
700    async fn test_cast() {
701        let engine = create_test_query_engine();
702        let sql = "SELECT CAST(1 AS INT16) FROM numbers";
703        let plan = sql_to_substrait(engine.clone(), sql).await;
704
705        let mut ctx = create_test_ctx();
706        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;
707
708        let expected = TypedPlan {
709            schema: RelationType::new(vec![ColumnType::new(CDT::int16_datatype(), true)])
710                .into_named(vec![Some(
711                    "arrow_cast(Int64(1),Utf8(\"Int16\"))".to_string(),
712                )]),
713            plan: Plan::Constant {
714                // cast of literal is constant folded
715                rows: vec![(repr::Row::new(vec![Value::from(1i16)]), i64::MIN, 1)],
716            },
717        };
718        assert_eq!(flow_plan.unwrap(), expected);
719    }
720
721    #[tokio::test]
722    async fn test_select_add() {
723        let engine = create_test_query_engine();
724        let sql = "SELECT number+number FROM numbers";
725        let plan = sql_to_substrait(engine.clone(), sql).await;
726
727        let mut ctx = create_test_ctx();
728        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;
729
730        let expected = TypedPlan {
731            schema: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), true)])
732                .into_named(vec![Some("numbers.number + numbers.number".to_string())]),
733            plan: Plan::Mfp {
734                input: Box::new(
735                    Plan::Get {
736                        id: crate::expr::Id::Global(GlobalId::User(0)),
737                    }
738                    .with_types(
739                        RelationType::new(vec![ColumnType::new(
740                            ConcreteDataType::uint32_datatype(),
741                            false,
742                        )])
743                        .into_named(vec![Some("number".to_string())]),
744                    ),
745                ),
746                mfp: MapFilterProject::new(1)
747                    .map(vec![ScalarExpr::Column(0)
748                        .call_binary(ScalarExpr::Column(0), BinaryFunc::AddUInt32)])
749                    .unwrap()
750                    .project(vec![1])
751                    .unwrap(),
752            },
753        };
754
755        assert_eq!(flow_plan.unwrap(), expected);
756    }
757
758    #[tokio::test]
759    async fn test_func_sig() {
760        fn lit(v: impl ToString) -> substrait_proto::proto::FunctionArgument {
761            use substrait_proto::proto::expression;
762            let expr = Expression {
763                rex_type: Some(expression::RexType::Literal(expression::Literal {
764                    nullable: false,
765                    type_variation_reference: 0,
766                    literal_type: Some(expression::literal::LiteralType::String(v.to_string())),
767                })),
768            };
769            substrait_proto::proto::FunctionArgument {
770                arg_type: Some(substrait_proto::proto::function_argument::ArgType::Value(
771                    expr,
772                )),
773            }
774        }
775
776        let f = substrait_proto::proto::expression::ScalarFunction {
777            function_reference: 0,
778            arguments: vec![proto_col(0)],
779            options: vec![],
780            output_type: None,
781            ..Default::default()
782        };
783        let input_schema =
784            RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), false)]).into_unnamed();
785        let extensions = FunctionExtensions::from_iter([(0, "is_null".to_string())]);
786        let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions)
787            .await
788            .unwrap();
789
790        assert_eq!(
791            res,
792            TypedExpr {
793                expr: ScalarExpr::Column(0).call_unary(UnaryFunc::IsNull),
794                typ: ColumnType {
795                    scalar_type: CDT::boolean_datatype(),
796                    nullable: true,
797                },
798            }
799        );
800
801        let f = substrait_proto::proto::expression::ScalarFunction {
802            function_reference: 0,
803            arguments: vec![proto_col(0), proto_col(1)],
804            options: vec![],
805            output_type: None,
806            ..Default::default()
807        };
808        let input_schema = RelationType::new(vec![
809            ColumnType::new(CDT::uint32_datatype(), false),
810            ColumnType::new(CDT::uint32_datatype(), false),
811        ])
812        .into_unnamed();
813        let extensions = FunctionExtensions::from_iter([(0, "add".to_string())]);
814        let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions)
815            .await
816            .unwrap();
817
818        assert_eq!(
819            res,
820            TypedExpr {
821                expr: ScalarExpr::Column(0)
822                    .call_binary(ScalarExpr::Column(1), BinaryFunc::AddUInt32,),
823                typ: ColumnType {
824                    scalar_type: CDT::uint32_datatype(),
825                    nullable: true,
826                },
827            }
828        );
829    }
830}