flow/transform/
expr.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15#![warn(unused_imports)]
16
17use std::sync::Arc;
18
19use common_error::ext::BoxedError;
20use common_telemetry::debug;
21use datafusion::execution::SessionStateBuilder;
22use datafusion::functions::all_default_functions;
23use datafusion_physical_expr::PhysicalExpr;
24use datafusion_substrait::logical_plan::consumer::DefaultSubstraitConsumer;
25use datatypes::data_type::ConcreteDataType as CDT;
26use snafu::{ensure, OptionExt, ResultExt};
27use substrait_proto::proto::expression::field_reference::ReferenceType::DirectReference;
28use substrait_proto::proto::expression::reference_segment::ReferenceType::StructField;
29use substrait_proto::proto::expression::{IfThen, RexType, ScalarFunction};
30use substrait_proto::proto::function_argument::ArgType;
31use substrait_proto::proto::Expression;
32
33use crate::error::{
34    DatafusionSnafu, DatatypesSnafu, Error, EvalSnafu, ExternalSnafu, InvalidQuerySnafu,
35    NotImplementedSnafu, PlanSnafu, UnexpectedSnafu,
36};
37use crate::expr::{
38    BinaryFunc, DfScalarFunction, RawDfScalarFn, ScalarExpr, TypedExpr, UnaryFunc,
39    UnmaterializableFunc, VariadicFunc, TUMBLE_END, TUMBLE_START,
40};
41use crate::repr::{ColumnType, RelationDesc, RelationType};
42use crate::transform::literal::{
43    from_substrait_literal, from_substrait_type, to_substrait_literal,
44};
45use crate::transform::{substrait_proto, FunctionExtensions};
46
47// TODO(discord9): refactor plan to substrait convert of `arrow_cast` function thus remove this function
48/// ref to `arrow_schema::datatype` for type name
49fn typename_to_cdt(name: &str) -> Result<CDT, Error> {
50    let ret = match name {
51        "Int8" => CDT::int8_datatype(),
52        "Int16" => CDT::int16_datatype(),
53        "Int32" => CDT::int32_datatype(),
54        "Int64" => CDT::int64_datatype(),
55        "UInt8" => CDT::uint8_datatype(),
56        "UInt16" => CDT::uint16_datatype(),
57        "UInt32" => CDT::uint32_datatype(),
58        "UInt64" => CDT::uint64_datatype(),
59        "Float32" => CDT::float32_datatype(),
60        "Float64" => CDT::float64_datatype(),
61        "Boolean" => CDT::boolean_datatype(),
62        "String" => CDT::string_datatype(),
63        "Date" | "Date32" | "Date64" => CDT::date_datatype(),
64        "Timestamp" => CDT::timestamp_second_datatype(),
65        "Timestamp(Second, None)" => CDT::timestamp_second_datatype(),
66        "Timestamp(Millisecond, None)" => CDT::timestamp_millisecond_datatype(),
67        "Timestamp(Microsecond, None)" => CDT::timestamp_microsecond_datatype(),
68        "Timestamp(Nanosecond, None)" => CDT::timestamp_nanosecond_datatype(),
69        "Time32(Second)" | "Time64(Second)" => CDT::time_second_datatype(),
70        "Time32(Millisecond)" | "Time64(Millisecond)" => CDT::time_millisecond_datatype(),
71        "Time32(Microsecond)" | "Time64(Microsecond)" => CDT::time_microsecond_datatype(),
72        "Time32(Nanosecond)" | "Time64(Nanosecond)" => CDT::time_nanosecond_datatype(),
73        _ => NotImplementedSnafu {
74            reason: format!("Unrecognized typename: {}", name),
75        }
76        .fail()?,
77    };
78    Ok(ret)
79}
80
81/// Convert [`ScalarFunction`] to corresponding Datafusion's [`PhysicalExpr`]
82pub(crate) async fn from_scalar_fn_to_df_fn_impl(
83    f: &ScalarFunction,
84    input_schema: &RelationDesc,
85    extensions: &FunctionExtensions,
86) -> Result<Arc<dyn PhysicalExpr>, Error> {
87    let e = Expression {
88        rex_type: Some(RexType::ScalarFunction(f.clone())),
89    };
90    let schema = input_schema.to_df_schema()?;
91
92    let extensions = extensions.to_extensions();
93    let session_state = SessionStateBuilder::new()
94        .with_scalar_functions(all_default_functions())
95        .build();
96    let consumer = DefaultSubstraitConsumer::new(&extensions, &session_state);
97    let df_expr =
98        substrait::df_logical_plan::consumer::from_substrait_rex(&consumer, &e, &schema).await;
99    let expr = df_expr.context({
100        DatafusionSnafu {
101            context: "Failed to convert substrait scalar function to datafusion scalar function",
102        }
103    })?;
104    let phy_expr =
105        datafusion::physical_expr::create_physical_expr(&expr, &schema, &Default::default())
106            .context(DatafusionSnafu {
107                context: "Failed to create physical expression from logical expression",
108            })?;
109    Ok(phy_expr)
110}
111
112/// Return an [`Expression`](wrapped in a [`FunctionArgument`]) that references the i-th column of the input relation
113pub(crate) fn proto_col(i: usize) -> substrait_proto::proto::FunctionArgument {
114    use substrait_proto::proto::expression;
115    let expr = Expression {
116        rex_type: Some(expression::RexType::Selection(Box::new(
117            expression::FieldReference {
118                reference_type: Some(expression::field_reference::ReferenceType::DirectReference(
119                    expression::ReferenceSegment {
120                        reference_type: Some(
121                            expression::reference_segment::ReferenceType::StructField(Box::new(
122                                expression::reference_segment::StructField {
123                                    field: i as i32,
124                                    child: None,
125                                },
126                            )),
127                        ),
128                    },
129                )),
130                root_type: None,
131            },
132        ))),
133    };
134    substrait_proto::proto::FunctionArgument {
135        arg_type: Some(substrait_proto::proto::function_argument::ArgType::Value(
136            expr,
137        )),
138    }
139}
140
141fn is_proto_literal(arg: &substrait_proto::proto::FunctionArgument) -> bool {
142    use substrait_proto::proto::expression;
143    matches!(
144        arg.arg_type.as_ref().unwrap(),
145        ArgType::Value(Expression {
146            rex_type: Some(expression::RexType::Literal(_)),
147        })
148    )
149}
150
151fn build_proto_lit(
152    lit: substrait_proto::proto::expression::Literal,
153) -> substrait_proto::proto::FunctionArgument {
154    use substrait_proto::proto;
155    proto::FunctionArgument {
156        arg_type: Some(ArgType::Value(Expression {
157            rex_type: Some(proto::expression::RexType::Literal(lit)),
158        })),
159    }
160}
161
162/// rewrite ScalarFunction's arguments to Columns 0..n so nested exprs are still handled by us instead of datafusion
163///
164/// specially, if a argument is a literal, the replacement will not happen
165fn rewrite_scalar_function(
166    f: &ScalarFunction,
167    arg_typed_exprs: &[TypedExpr],
168) -> Result<ScalarFunction, Error> {
169    let mut f_rewrite = f.clone();
170    ensure!(
171        f_rewrite.arguments.len() == arg_typed_exprs.len(),
172        crate::error::InternalSnafu {
173            reason: format!(
174                "Expect `f_rewrite` and `arg_typed_expr` to be same length, found {} and {}",
175                f_rewrite.arguments.len(),
176                arg_typed_exprs.len()
177            )
178        }
179    );
180    for (idx, raw_expr) in f_rewrite.arguments.iter_mut().enumerate() {
181        // only replace it with col(idx) if it is not literal
182        // will try best to determine if it is literal, i.e. for function like `cast(<literal>)` will try
183        // in both world to understand if it results in a literal
184        match (
185            is_proto_literal(raw_expr),
186            arg_typed_exprs[idx].expr.is_literal(),
187        ) {
188            (false, false) => *raw_expr = proto_col(idx),
189            (true, _) => (),
190            (false, true) => {
191                if let ScalarExpr::Literal(val, ty) = &arg_typed_exprs[idx].expr {
192                    let df_val = val
193                        .try_to_scalar_value(ty)
194                        .map_err(BoxedError::new)
195                        .context(ExternalSnafu)?;
196                    let lit_sub = to_substrait_literal(&df_val)?;
197                    // put const-folded literal back to df to simplify stuff
198                    *raw_expr = build_proto_lit(lit_sub);
199                } else {
200                    UnexpectedSnafu {
201                        reason: format!(
202                            "Expect value to be literal, but found {:?}",
203                            arg_typed_exprs[idx].expr
204                        ),
205                    }
206                    .fail()?
207                }
208            }
209        }
210    }
211    Ok(f_rewrite)
212}
213
214impl TypedExpr {
215    pub async fn from_substrait_to_datafusion_scalar_func(
216        f: &ScalarFunction,
217        arg_typed_exprs: Vec<TypedExpr>,
218        extensions: &FunctionExtensions,
219    ) -> Result<TypedExpr, Error> {
220        let (arg_exprs, arg_types): (Vec<_>, Vec<_>) = arg_typed_exprs
221            .clone()
222            .into_iter()
223            .map(|e| (e.expr, e.typ))
224            .unzip();
225        debug!("Before rewrite: {:?}", f);
226        let f_rewrite = rewrite_scalar_function(f, &arg_typed_exprs)?;
227        debug!("After rewrite: {:?}", f_rewrite);
228        let input_schema = RelationType::new(arg_types).into_unnamed();
229        let raw_fn =
230            RawDfScalarFn::from_proto(&f_rewrite, input_schema.clone(), extensions.clone())?;
231
232        let df_func = DfScalarFunction::try_from_raw_fn(raw_fn).await?;
233        let expr = ScalarExpr::CallDf {
234            df_scalar_fn: df_func,
235            exprs: arg_exprs,
236        };
237        // df already know it's own schema, so not providing here
238        let ret_type = expr.typ(&[])?;
239        Ok(TypedExpr::new(expr, ret_type))
240    }
241
242    /// Convert ScalarFunction into Flow's ScalarExpr
243    pub async fn from_substrait_scalar_func(
244        f: &ScalarFunction,
245        input_schema: &RelationDesc,
246        extensions: &FunctionExtensions,
247    ) -> Result<TypedExpr, Error> {
248        let fn_name =
249            extensions
250                .get(&f.function_reference)
251                .with_context(|| NotImplementedSnafu {
252                    reason: format!(
253                        "Aggregated function not found: function reference = {:?}",
254                        f.function_reference
255                    ),
256                })?;
257        let arg_len = f.arguments.len();
258        let arg_typed_exprs: Vec<TypedExpr> = {
259            let mut rets = Vec::new();
260            for arg in f.arguments.iter() {
261                let ret = match &arg.arg_type {
262                    Some(ArgType::Value(e)) => {
263                        TypedExpr::from_substrait_rex(e, input_schema, extensions).await
264                    }
265                    _ => not_impl_err!("Aggregated function argument non-Value type not supported"),
266                }?;
267                rets.push(ret);
268            }
269            rets
270        };
271
272        // literal's type is determined by the function and type of other args
273        let (arg_exprs, arg_types): (Vec<_>, Vec<_>) = arg_typed_exprs
274            .clone()
275            .into_iter()
276            .map(
277                |TypedExpr {
278                     expr: arg_val,
279                     typ: arg_type,
280                 }| {
281                    if arg_val.is_literal() {
282                        (arg_val, None)
283                    } else {
284                        (arg_val, Some(arg_type.scalar_type))
285                    }
286                },
287            )
288            .unzip();
289
290        match arg_len {
291            1 if UnaryFunc::is_valid_func_name(fn_name) => {
292                let func = UnaryFunc::from_str_and_type(fn_name, None)?;
293                let arg = arg_exprs[0].clone();
294                let ret_type = ColumnType::new_nullable(func.signature().output.clone());
295
296                Ok(TypedExpr::new(arg.call_unary(func), ret_type))
297            }
298            2 if fn_name == "arrow_cast" => {
299                let cast_to = arg_exprs[1]
300                    .clone()
301                    .as_literal()
302                    .and_then(|lit| lit.as_string())
303                    .with_context(|| InvalidQuerySnafu {
304                        reason: "array_cast's second argument must be a literal string",
305                    })?;
306                let cast_to = typename_to_cdt(&cast_to)?;
307                let func = UnaryFunc::Cast(cast_to.clone());
308                let arg = arg_exprs[0].clone();
309                // constant folding here since some datafusion function require it for constant arg(i.e. `DATE_BIN`)
310                if arg.is_literal() {
311                    let res = func.eval(&[], &arg).context(EvalSnafu)?;
312                    Ok(TypedExpr::new(
313                        ScalarExpr::Literal(res, cast_to.clone()),
314                        ColumnType::new_nullable(cast_to),
315                    ))
316                } else {
317                    let ret_type = ColumnType::new_nullable(func.signature().output.clone());
318
319                    Ok(TypedExpr::new(arg.call_unary(func), ret_type))
320                }
321            }
322            2 if BinaryFunc::is_valid_func_name(fn_name) => {
323                let (func, signature) =
324                    BinaryFunc::from_str_expr_and_type(fn_name, &arg_exprs, &arg_types[0..2])?;
325
326                // constant folding here
327                let is_all_literal = arg_exprs.iter().all(|arg| arg.is_literal());
328                if is_all_literal {
329                    let res = func
330                        .eval(&[], &arg_exprs[0], &arg_exprs[1])
331                        .context(EvalSnafu)?;
332
333                    // if output type is null, it should be inferred from the input types
334                    let con_typ = signature.output.clone();
335                    let typ = ColumnType::new_nullable(con_typ.clone());
336                    return Ok(TypedExpr::new(ScalarExpr::Literal(res, con_typ), typ));
337                }
338
339                let mut arg_exprs = arg_exprs;
340                for (idx, arg_expr) in arg_exprs.iter_mut().enumerate() {
341                    if let ScalarExpr::Literal(val, typ) = arg_expr {
342                        let dest_type = signature.input[idx].clone();
343
344                        // cast val to target_type
345                        let dest_val = if !dest_type.is_null() {
346                            datatypes::types::cast(val.clone(), &dest_type)
347                        .with_context(|_|
348                            DatatypesSnafu{
349                                extra: format!("Failed to implicitly cast literal {val:?} to type {dest_type:?}")
350                            })?
351                        } else {
352                            val.clone()
353                        };
354                        *val = dest_val;
355                        *typ = dest_type;
356                    }
357                }
358
359                let ret_type = ColumnType::new_nullable(func.signature().output.clone());
360                let ret_expr = arg_exprs[0].clone().call_binary(arg_exprs[1].clone(), func);
361                Ok(TypedExpr::new(ret_expr, ret_type))
362            }
363            _var => {
364                if fn_name == TUMBLE_START || fn_name == TUMBLE_END {
365                    let (func, arg) = UnaryFunc::from_tumble_func(fn_name, &arg_typed_exprs)?;
366
367                    let ret_type = ColumnType::new_nullable(func.signature().output.clone());
368
369                    Ok(TypedExpr::new(arg.expr.call_unary(func), ret_type))
370                } else if VariadicFunc::is_valid_func_name(fn_name) {
371                    let func = VariadicFunc::from_str_and_types(fn_name, &arg_types)?;
372                    let ret_type = ColumnType::new_nullable(func.signature().output.clone());
373                    let mut expr = ScalarExpr::CallVariadic {
374                        func,
375                        exprs: arg_exprs,
376                    };
377                    expr.optimize();
378                    Ok(TypedExpr::new(expr, ret_type))
379                } else if UnmaterializableFunc::is_valid_func_name(fn_name) {
380                    let func = UnmaterializableFunc::from_str_args(fn_name, arg_typed_exprs)?;
381                    let ret_type = ColumnType::new_nullable(func.signature().output.clone());
382                    Ok(TypedExpr::new(
383                        ScalarExpr::CallUnmaterializable(func),
384                        ret_type,
385                    ))
386                } else {
387                    let try_as_df = Self::from_substrait_to_datafusion_scalar_func(
388                        f,
389                        arg_typed_exprs,
390                        extensions,
391                    )
392                    .await?;
393                    Ok(try_as_df)
394                }
395            }
396        }
397    }
398
399    /// Convert IfThen into Flow's ScalarExpr
400    pub async fn from_substrait_ifthen_rex(
401        if_then: &IfThen,
402        input_schema: &RelationDesc,
403        extensions: &FunctionExtensions,
404    ) -> Result<TypedExpr, Error> {
405        let ifs: Vec<_> = {
406            let mut ifs = Vec::new();
407            for if_clause in if_then.ifs.iter() {
408                let proto_if = if_clause.r#if.as_ref().with_context(|| InvalidQuerySnafu {
409                    reason: "IfThen clause without if",
410                })?;
411                let proto_then = if_clause.then.as_ref().with_context(|| InvalidQuerySnafu {
412                    reason: "IfThen clause without then",
413                })?;
414                let cond =
415                    TypedExpr::from_substrait_rex(proto_if, input_schema, extensions).await?;
416                let then =
417                    TypedExpr::from_substrait_rex(proto_then, input_schema, extensions).await?;
418                ifs.push((cond, then));
419            }
420            ifs
421        };
422        // if no else is presented
423        let els = match if_then
424            .r#else
425            .as_ref()
426            .map(|e| TypedExpr::from_substrait_rex(e, input_schema, extensions))
427        {
428            Some(fut) => Some(fut.await),
429            None => None,
430        }
431        .transpose()?
432        .unwrap_or_else(|| {
433            TypedExpr::new(
434                ScalarExpr::literal_null(),
435                ColumnType::new_nullable(CDT::null_datatype()),
436            )
437        });
438
439        fn build_if_then_recur(
440            mut next_if_then: impl Iterator<Item = (TypedExpr, TypedExpr)>,
441            els: TypedExpr,
442        ) -> TypedExpr {
443            if let Some((cond, then)) = next_if_then.next() {
444                // always assume the type of `if`` expr is the same with the `then`` expr
445                TypedExpr::new(
446                    ScalarExpr::If {
447                        cond: Box::new(cond.expr),
448                        then: Box::new(then.expr),
449                        els: Box::new(build_if_then_recur(next_if_then, els).expr),
450                    },
451                    then.typ,
452                )
453            } else {
454                els
455            }
456        }
457        let expr_if = build_if_then_recur(ifs.into_iter(), els);
458        Ok(expr_if)
459    }
460    /// Convert Substrait Rex into Flow's ScalarExpr
461    #[async_recursion::async_recursion]
462    pub async fn from_substrait_rex(
463        e: &Expression,
464        input_schema: &RelationDesc,
465        extensions: &FunctionExtensions,
466    ) -> Result<TypedExpr, Error> {
467        match &e.rex_type {
468            Some(RexType::Literal(lit)) => {
469                let lit = from_substrait_literal(lit)?;
470                Ok(TypedExpr::new(
471                    ScalarExpr::Literal(lit.0, lit.1.clone()),
472                    ColumnType::new_nullable(lit.1),
473                ))
474            }
475            Some(RexType::SingularOrList(s)) => {
476                let substrait_expr = s.value.as_ref().with_context(|| InvalidQuerySnafu {
477                    reason: "SingularOrList expression without value",
478                })?;
479                // Note that we didn't impl support to in list expr
480                if !s.options.is_empty() {
481                    return not_impl_err!("In list expression is not supported");
482                }
483                TypedExpr::from_substrait_rex(substrait_expr, input_schema, extensions).await
484            }
485            Some(RexType::Selection(field_ref)) => match &field_ref.reference_type {
486                Some(DirectReference(direct)) => match &direct.reference_type.as_ref() {
487                    Some(StructField(x)) => match &x.child.as_ref() {
488                        Some(_) => {
489                            not_impl_err!(
490                                "Direct reference StructField with child is not supported"
491                            )
492                        }
493                        None => {
494                            let column = x.field as usize;
495                            let column_type = input_schema.typ().column_types[column].clone();
496                            Ok(TypedExpr::new(ScalarExpr::Column(column), column_type))
497                        }
498                    },
499                    _ => not_impl_err!(
500                        "Direct reference with types other than StructField is not supported"
501                    ),
502                },
503                _ => not_impl_err!("unsupported field ref type"),
504            },
505            Some(RexType::ScalarFunction(f)) => {
506                TypedExpr::from_substrait_scalar_func(f, input_schema, extensions).await
507            }
508            Some(RexType::IfThen(if_then)) => {
509                TypedExpr::from_substrait_ifthen_rex(if_then, input_schema, extensions).await
510            }
511            Some(RexType::Cast(cast)) => {
512                let input = cast.input.as_ref().with_context(|| InvalidQuerySnafu {
513                    reason: "Cast expression without input",
514                })?;
515                let input = TypedExpr::from_substrait_rex(input, input_schema, extensions).await?;
516                let cast_type = from_substrait_type(cast.r#type.as_ref().with_context(|| {
517                    InvalidQuerySnafu {
518                        reason: "Cast expression without type",
519                    }
520                })?)?;
521                let func = UnaryFunc::from_str_and_type("cast", Some(cast_type.clone()))?;
522                Ok(TypedExpr::new(
523                    input.expr.call_unary(func),
524                    ColumnType::new_nullable(cast_type),
525                ))
526            }
527            Some(RexType::WindowFunction(_)) => PlanSnafu {
528                reason:
529                    "Window function is not supported yet. Please use aggregation function instead."
530                        .to_string(),
531            }
532            .fail(),
533            _ => not_impl_err!("unsupported rex_type"),
534        }
535    }
536}
537
538#[cfg(test)]
539mod test {
540    use datatypes::prelude::ConcreteDataType;
541    use datatypes::value::Value;
542    use pretty_assertions::assert_eq;
543
544    use super::*;
545    use crate::expr::{GlobalId, MapFilterProject};
546    use crate::plan::{Plan, TypedPlan};
547    use crate::repr::{self, ColumnType, RelationType};
548    use crate::transform::test::{create_test_ctx, create_test_query_engine, sql_to_substrait};
549
550    /// test if `WHERE` condition can be converted to Flow's ScalarExpr in mfp's filter
551    #[tokio::test]
552    async fn test_where_and() {
553        let engine = create_test_query_engine();
554        let sql =
555            "SELECT number FROM numbers_with_ts WHERE number >= 1 AND number <= 3 AND number!=2";
556        let plan = sql_to_substrait(engine.clone(), sql).await;
557
558        let mut ctx = create_test_ctx();
559        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;
560
561        // optimize binary and to variadic and
562        let filter = ScalarExpr::CallVariadic {
563            func: VariadicFunc::And,
564            exprs: vec![
565                ScalarExpr::Column(2).call_binary(
566                    ScalarExpr::Literal(Value::from(1i64), CDT::int64_datatype()),
567                    BinaryFunc::Gte,
568                ),
569                ScalarExpr::Column(2).call_binary(
570                    ScalarExpr::Literal(Value::from(3i64), CDT::int64_datatype()),
571                    BinaryFunc::Lte,
572                ),
573                ScalarExpr::Column(2).call_binary(
574                    ScalarExpr::Literal(Value::from(2i64), CDT::int64_datatype()),
575                    BinaryFunc::NotEq,
576                ),
577            ],
578        };
579        let expected = TypedPlan {
580            schema: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), false)])
581                .into_named(vec![Some("number".to_string())]),
582            plan: Plan::Mfp {
583                input: Box::new(
584                    Plan::Get {
585                        id: crate::expr::Id::Global(GlobalId::User(1)),
586                    }
587                    .with_types(
588                        RelationType::new(vec![
589                            ColumnType::new(ConcreteDataType::uint32_datatype(), false),
590                            ColumnType::new(
591                                ConcreteDataType::timestamp_millisecond_datatype(),
592                                false,
593                            ),
594                        ])
595                        .into_named(vec![Some("number".to_string()), Some("ts".to_string())]),
596                    ),
597                ),
598                mfp: MapFilterProject::new(2)
599                    .map(vec![
600                        ScalarExpr::CallUnary {
601                            func: UnaryFunc::Cast(CDT::int64_datatype()),
602                            expr: Box::new(ScalarExpr::Column(0)),
603                        },
604                        ScalarExpr::Column(0),
605                        ScalarExpr::Column(3),
606                    ])
607                    .unwrap()
608                    .filter(vec![filter])
609                    .unwrap()
610                    .project(vec![4])
611                    .unwrap(),
612            },
613        };
614        assert_eq!(flow_plan.unwrap(), expected);
615    }
616
617    /// case: binary functions&constant folding can happen in converting substrait plan
618    #[tokio::test]
619    async fn test_binary_func_and_constant_folding() {
620        let engine = create_test_query_engine();
621        let sql = "SELECT 1+1*2-1/1+1%2==3 FROM numbers";
622        let plan = sql_to_substrait(engine.clone(), sql).await;
623
624        let mut ctx = create_test_ctx();
625        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;
626
627        let expected = TypedPlan {
628            schema: RelationType::new(vec![ColumnType::new(CDT::boolean_datatype(), true)])
629                .into_named(vec![Some("Int64(1) + Int64(1) * Int64(2) - Int64(1) / Int64(1) + Int64(1) % Int64(2) = Int64(3)".to_string())]),
630            plan: Plan::Constant {
631                rows: vec![(
632                    repr::Row::new(vec![Value::from(true)]),
633                    repr::Timestamp::MIN,
634                    1,
635                )],
636            },
637        };
638
639        assert_eq!(flow_plan.unwrap(), expected);
640    }
641
642    /// test if the type of the literal is correctly inferred, i.e. in here literal is decoded to be int64, but need to be uint32,
643    #[tokio::test]
644    async fn test_implicitly_cast() {
645        let engine = create_test_query_engine();
646        let sql = "SELECT number+1 FROM numbers";
647        let plan = sql_to_substrait(engine.clone(), sql).await;
648
649        let mut ctx = create_test_ctx();
650        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;
651
652        let expected = TypedPlan {
653            schema: RelationType::new(vec![ColumnType::new(CDT::int64_datatype(), true)])
654                .into_named(vec![Some("numbers.number + Int64(1)".to_string())]),
655            plan: Plan::Mfp {
656                input: Box::new(
657                    Plan::Get {
658                        id: crate::expr::Id::Global(GlobalId::User(0)),
659                    }
660                    .with_types(
661                        RelationType::new(vec![ColumnType::new(
662                            ConcreteDataType::uint32_datatype(),
663                            false,
664                        )])
665                        .into_named(vec![Some("number".to_string())]),
666                    ),
667                ),
668                mfp: MapFilterProject::new(1)
669                    .map(vec![ScalarExpr::Column(0)
670                        .call_unary(UnaryFunc::Cast(CDT::int64_datatype()))
671                        .call_binary(
672                            ScalarExpr::Literal(Value::from(1i64), CDT::int64_datatype()),
673                            BinaryFunc::AddInt64,
674                        )])
675                    .unwrap()
676                    .project(vec![1])
677                    .unwrap(),
678            },
679        };
680        assert_eq!(flow_plan.unwrap(), expected);
681    }
682
683    #[tokio::test]
684    async fn test_cast() {
685        let engine = create_test_query_engine();
686        let sql = "SELECT CAST(1 AS INT16) FROM numbers";
687        let plan = sql_to_substrait(engine.clone(), sql).await;
688
689        let mut ctx = create_test_ctx();
690        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;
691
692        let expected = TypedPlan {
693            schema: RelationType::new(vec![ColumnType::new(CDT::int16_datatype(), true)])
694                .into_named(vec![Some(
695                    "arrow_cast(Int64(1),Utf8(\"Int16\"))".to_string(),
696                )]),
697            plan: Plan::Constant {
698                // cast of literal is constant folded
699                rows: vec![(repr::Row::new(vec![Value::from(1i16)]), i64::MIN, 1)],
700            },
701        };
702        assert_eq!(flow_plan.unwrap(), expected);
703    }
704
705    #[tokio::test]
706    async fn test_select_add() {
707        let engine = create_test_query_engine();
708        let sql = "SELECT number+number FROM numbers";
709        let plan = sql_to_substrait(engine.clone(), sql).await;
710
711        let mut ctx = create_test_ctx();
712        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;
713
714        let expected = TypedPlan {
715            schema: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), true)])
716                .into_named(vec![Some("numbers.number + numbers.number".to_string())]),
717            plan: Plan::Mfp {
718                input: Box::new(
719                    Plan::Get {
720                        id: crate::expr::Id::Global(GlobalId::User(0)),
721                    }
722                    .with_types(
723                        RelationType::new(vec![ColumnType::new(
724                            ConcreteDataType::uint32_datatype(),
725                            false,
726                        )])
727                        .into_named(vec![Some("number".to_string())]),
728                    ),
729                ),
730                mfp: MapFilterProject::new(1)
731                    .map(vec![ScalarExpr::Column(0)
732                        .call_binary(ScalarExpr::Column(0), BinaryFunc::AddUInt32)])
733                    .unwrap()
734                    .project(vec![1])
735                    .unwrap(),
736            },
737        };
738
739        assert_eq!(flow_plan.unwrap(), expected);
740    }
741
742    #[tokio::test]
743    async fn test_func_sig() {
744        fn lit(v: impl ToString) -> substrait_proto::proto::FunctionArgument {
745            use substrait_proto::proto::expression;
746            let expr = Expression {
747                rex_type: Some(expression::RexType::Literal(expression::Literal {
748                    nullable: false,
749                    type_variation_reference: 0,
750                    literal_type: Some(expression::literal::LiteralType::String(v.to_string())),
751                })),
752            };
753            substrait_proto::proto::FunctionArgument {
754                arg_type: Some(substrait_proto::proto::function_argument::ArgType::Value(
755                    expr,
756                )),
757            }
758        }
759
760        let f = substrait_proto::proto::expression::ScalarFunction {
761            function_reference: 0,
762            arguments: vec![proto_col(0)],
763            options: vec![],
764            output_type: None,
765            ..Default::default()
766        };
767        let input_schema =
768            RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), false)]).into_unnamed();
769        let extensions = FunctionExtensions::from_iter([(0, "is_null".to_string())]);
770        let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions)
771            .await
772            .unwrap();
773
774        assert_eq!(
775            res,
776            TypedExpr {
777                expr: ScalarExpr::Column(0).call_unary(UnaryFunc::IsNull),
778                typ: ColumnType {
779                    scalar_type: CDT::boolean_datatype(),
780                    nullable: true,
781                },
782            }
783        );
784
785        let f = substrait_proto::proto::expression::ScalarFunction {
786            function_reference: 0,
787            arguments: vec![proto_col(0), proto_col(1)],
788            options: vec![],
789            output_type: None,
790            ..Default::default()
791        };
792        let input_schema = RelationType::new(vec![
793            ColumnType::new(CDT::uint32_datatype(), false),
794            ColumnType::new(CDT::uint32_datatype(), false),
795        ])
796        .into_unnamed();
797        let extensions = FunctionExtensions::from_iter([(0, "add".to_string())]);
798        let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions)
799            .await
800            .unwrap();
801
802        assert_eq!(
803            res,
804            TypedExpr {
805                expr: ScalarExpr::Column(0)
806                    .call_binary(ScalarExpr::Column(1), BinaryFunc::AddUInt32,),
807                typ: ColumnType {
808                    scalar_type: CDT::uint32_datatype(),
809                    nullable: true,
810                },
811            }
812        );
813    }
814}