1#![warn(unused_imports)]
16
17use std::sync::Arc;
18
19use common_error::ext::BoxedError;
20use common_telemetry::debug;
21use datafusion::execution::SessionStateBuilder;
22use datafusion::functions::all_default_functions;
23use datafusion_physical_expr::PhysicalExpr;
24use datafusion_substrait::logical_plan::consumer::DefaultSubstraitConsumer;
25use datatypes::data_type::ConcreteDataType as CDT;
26use snafu::{ensure, OptionExt, ResultExt};
27use substrait_proto::proto::expression::field_reference::ReferenceType::DirectReference;
28use substrait_proto::proto::expression::reference_segment::ReferenceType::StructField;
29use substrait_proto::proto::expression::{IfThen, RexType, ScalarFunction};
30use substrait_proto::proto::function_argument::ArgType;
31use substrait_proto::proto::Expression;
32
33use crate::error::{
34 DatafusionSnafu, DatatypesSnafu, Error, EvalSnafu, ExternalSnafu, InvalidQuerySnafu,
35 NotImplementedSnafu, PlanSnafu, UnexpectedSnafu,
36};
37use crate::expr::{
38 BinaryFunc, DfScalarFunction, RawDfScalarFn, ScalarExpr, TypedExpr, UnaryFunc,
39 UnmaterializableFunc, VariadicFunc, TUMBLE_END, TUMBLE_START,
40};
41use crate::repr::{ColumnType, RelationDesc, RelationType};
42use crate::transform::literal::{
43 from_substrait_literal, from_substrait_type, to_substrait_literal,
44};
45use crate::transform::{substrait_proto, FunctionExtensions};
46
47fn typename_to_cdt(name: &str) -> Result<CDT, Error> {
50 let ret = match name {
51 "Int8" => CDT::int8_datatype(),
52 "Int16" => CDT::int16_datatype(),
53 "Int32" => CDT::int32_datatype(),
54 "Int64" => CDT::int64_datatype(),
55 "UInt8" => CDT::uint8_datatype(),
56 "UInt16" => CDT::uint16_datatype(),
57 "UInt32" => CDT::uint32_datatype(),
58 "UInt64" => CDT::uint64_datatype(),
59 "Float32" => CDT::float32_datatype(),
60 "Float64" => CDT::float64_datatype(),
61 "Boolean" => CDT::boolean_datatype(),
62 "String" => CDT::string_datatype(),
63 "Date" | "Date32" | "Date64" => CDT::date_datatype(),
64 "Timestamp" => CDT::timestamp_second_datatype(),
65 "Timestamp(Second, None)" => CDT::timestamp_second_datatype(),
66 "Timestamp(Millisecond, None)" => CDT::timestamp_millisecond_datatype(),
67 "Timestamp(Microsecond, None)" => CDT::timestamp_microsecond_datatype(),
68 "Timestamp(Nanosecond, None)" => CDT::timestamp_nanosecond_datatype(),
69 "Time32(Second)" | "Time64(Second)" => CDT::time_second_datatype(),
70 "Time32(Millisecond)" | "Time64(Millisecond)" => CDT::time_millisecond_datatype(),
71 "Time32(Microsecond)" | "Time64(Microsecond)" => CDT::time_microsecond_datatype(),
72 "Time32(Nanosecond)" | "Time64(Nanosecond)" => CDT::time_nanosecond_datatype(),
73 _ => NotImplementedSnafu {
74 reason: format!("Unrecognized typename: {}", name),
75 }
76 .fail()?,
77 };
78 Ok(ret)
79}
80
81pub(crate) async fn from_scalar_fn_to_df_fn_impl(
83 f: &ScalarFunction,
84 input_schema: &RelationDesc,
85 extensions: &FunctionExtensions,
86) -> Result<Arc<dyn PhysicalExpr>, Error> {
87 let e = Expression {
88 rex_type: Some(RexType::ScalarFunction(f.clone())),
89 };
90 let schema = input_schema.to_df_schema()?;
91
92 let extensions = extensions.to_extensions();
93 let session_state = SessionStateBuilder::new()
94 .with_scalar_functions(all_default_functions())
95 .build();
96 let consumer = DefaultSubstraitConsumer::new(&extensions, &session_state);
97 let df_expr =
98 substrait::df_logical_plan::consumer::from_substrait_rex(&consumer, &e, &schema).await;
99 let expr = df_expr.context({
100 DatafusionSnafu {
101 context: "Failed to convert substrait scalar function to datafusion scalar function",
102 }
103 })?;
104 let phy_expr =
105 datafusion::physical_expr::create_physical_expr(&expr, &schema, &Default::default())
106 .context(DatafusionSnafu {
107 context: "Failed to create physical expression from logical expression",
108 })?;
109 Ok(phy_expr)
110}
111
112pub(crate) fn proto_col(i: usize) -> substrait_proto::proto::FunctionArgument {
114 use substrait_proto::proto::expression;
115 let expr = Expression {
116 rex_type: Some(expression::RexType::Selection(Box::new(
117 expression::FieldReference {
118 reference_type: Some(expression::field_reference::ReferenceType::DirectReference(
119 expression::ReferenceSegment {
120 reference_type: Some(
121 expression::reference_segment::ReferenceType::StructField(Box::new(
122 expression::reference_segment::StructField {
123 field: i as i32,
124 child: None,
125 },
126 )),
127 ),
128 },
129 )),
130 root_type: None,
131 },
132 ))),
133 };
134 substrait_proto::proto::FunctionArgument {
135 arg_type: Some(substrait_proto::proto::function_argument::ArgType::Value(
136 expr,
137 )),
138 }
139}
140
141fn is_proto_literal(arg: &substrait_proto::proto::FunctionArgument) -> bool {
142 use substrait_proto::proto::expression;
143 matches!(
144 arg.arg_type.as_ref().unwrap(),
145 ArgType::Value(Expression {
146 rex_type: Some(expression::RexType::Literal(_)),
147 })
148 )
149}
150
151fn build_proto_lit(
152 lit: substrait_proto::proto::expression::Literal,
153) -> substrait_proto::proto::FunctionArgument {
154 use substrait_proto::proto;
155 proto::FunctionArgument {
156 arg_type: Some(ArgType::Value(Expression {
157 rex_type: Some(proto::expression::RexType::Literal(lit)),
158 })),
159 }
160}
161
162fn rewrite_scalar_function(
166 f: &ScalarFunction,
167 arg_typed_exprs: &[TypedExpr],
168) -> Result<ScalarFunction, Error> {
169 let mut f_rewrite = f.clone();
170 ensure!(
171 f_rewrite.arguments.len() == arg_typed_exprs.len(),
172 crate::error::InternalSnafu {
173 reason: format!(
174 "Expect `f_rewrite` and `arg_typed_expr` to be same length, found {} and {}",
175 f_rewrite.arguments.len(),
176 arg_typed_exprs.len()
177 )
178 }
179 );
180 for (idx, raw_expr) in f_rewrite.arguments.iter_mut().enumerate() {
181 match (
185 is_proto_literal(raw_expr),
186 arg_typed_exprs[idx].expr.is_literal(),
187 ) {
188 (false, false) => *raw_expr = proto_col(idx),
189 (true, _) => (),
190 (false, true) => {
191 if let ScalarExpr::Literal(val, ty) = &arg_typed_exprs[idx].expr {
192 let df_val = val
193 .try_to_scalar_value(ty)
194 .map_err(BoxedError::new)
195 .context(ExternalSnafu)?;
196 let lit_sub = to_substrait_literal(&df_val)?;
197 *raw_expr = build_proto_lit(lit_sub);
199 } else {
200 UnexpectedSnafu {
201 reason: format!(
202 "Expect value to be literal, but found {:?}",
203 arg_typed_exprs[idx].expr
204 ),
205 }
206 .fail()?
207 }
208 }
209 }
210 }
211 Ok(f_rewrite)
212}
213
214impl TypedExpr {
215 pub async fn from_substrait_to_datafusion_scalar_func(
216 f: &ScalarFunction,
217 arg_typed_exprs: Vec<TypedExpr>,
218 extensions: &FunctionExtensions,
219 ) -> Result<TypedExpr, Error> {
220 let (arg_exprs, arg_types): (Vec<_>, Vec<_>) = arg_typed_exprs
221 .clone()
222 .into_iter()
223 .map(|e| (e.expr, e.typ))
224 .unzip();
225 debug!("Before rewrite: {:?}", f);
226 let f_rewrite = rewrite_scalar_function(f, &arg_typed_exprs)?;
227 debug!("After rewrite: {:?}", f_rewrite);
228 let input_schema = RelationType::new(arg_types).into_unnamed();
229 let raw_fn =
230 RawDfScalarFn::from_proto(&f_rewrite, input_schema.clone(), extensions.clone())?;
231
232 let df_func = DfScalarFunction::try_from_raw_fn(raw_fn).await?;
233 let expr = ScalarExpr::CallDf {
234 df_scalar_fn: df_func,
235 exprs: arg_exprs,
236 };
237 let ret_type = expr.typ(&[])?;
239 Ok(TypedExpr::new(expr, ret_type))
240 }
241
242 pub async fn from_substrait_scalar_func(
244 f: &ScalarFunction,
245 input_schema: &RelationDesc,
246 extensions: &FunctionExtensions,
247 ) -> Result<TypedExpr, Error> {
248 let fn_name =
249 extensions
250 .get(&f.function_reference)
251 .with_context(|| NotImplementedSnafu {
252 reason: format!(
253 "Aggregated function not found: function reference = {:?}",
254 f.function_reference
255 ),
256 })?;
257 let arg_len = f.arguments.len();
258 let arg_typed_exprs: Vec<TypedExpr> = {
259 let mut rets = Vec::new();
260 for arg in f.arguments.iter() {
261 let ret = match &arg.arg_type {
262 Some(ArgType::Value(e)) => {
263 TypedExpr::from_substrait_rex(e, input_schema, extensions).await
264 }
265 _ => not_impl_err!("Aggregated function argument non-Value type not supported"),
266 }?;
267 rets.push(ret);
268 }
269 rets
270 };
271
272 let (arg_exprs, arg_types): (Vec<_>, Vec<_>) = arg_typed_exprs
274 .clone()
275 .into_iter()
276 .map(
277 |TypedExpr {
278 expr: arg_val,
279 typ: arg_type,
280 }| {
281 if arg_val.is_literal() {
282 (arg_val, None)
283 } else {
284 (arg_val, Some(arg_type.scalar_type))
285 }
286 },
287 )
288 .unzip();
289
290 match arg_len {
291 1 if UnaryFunc::is_valid_func_name(fn_name) => {
292 let func = UnaryFunc::from_str_and_type(fn_name, None)?;
293 let arg = arg_exprs[0].clone();
294 let ret_type = ColumnType::new_nullable(func.signature().output.clone());
295
296 Ok(TypedExpr::new(arg.call_unary(func), ret_type))
297 }
298 2 if fn_name == "arrow_cast" => {
299 let cast_to = arg_exprs[1]
300 .clone()
301 .as_literal()
302 .and_then(|lit| lit.as_string())
303 .with_context(|| InvalidQuerySnafu {
304 reason: "array_cast's second argument must be a literal string",
305 })?;
306 let cast_to = typename_to_cdt(&cast_to)?;
307 let func = UnaryFunc::Cast(cast_to.clone());
308 let arg = arg_exprs[0].clone();
309 if arg.is_literal() {
311 let res = func.eval(&[], &arg).context(EvalSnafu)?;
312 Ok(TypedExpr::new(
313 ScalarExpr::Literal(res, cast_to.clone()),
314 ColumnType::new_nullable(cast_to),
315 ))
316 } else {
317 let ret_type = ColumnType::new_nullable(func.signature().output.clone());
318
319 Ok(TypedExpr::new(arg.call_unary(func), ret_type))
320 }
321 }
322 2 if BinaryFunc::is_valid_func_name(fn_name) => {
323 let (func, signature) =
324 BinaryFunc::from_str_expr_and_type(fn_name, &arg_exprs, &arg_types[0..2])?;
325
326 let is_all_literal = arg_exprs.iter().all(|arg| arg.is_literal());
328 if is_all_literal {
329 let res = func
330 .eval(&[], &arg_exprs[0], &arg_exprs[1])
331 .context(EvalSnafu)?;
332
333 let con_typ = signature.output.clone();
335 let typ = ColumnType::new_nullable(con_typ.clone());
336 return Ok(TypedExpr::new(ScalarExpr::Literal(res, con_typ), typ));
337 }
338
339 let mut arg_exprs = arg_exprs;
340 for (idx, arg_expr) in arg_exprs.iter_mut().enumerate() {
341 if let ScalarExpr::Literal(val, typ) = arg_expr {
342 let dest_type = signature.input[idx].clone();
343
344 let dest_val = if !dest_type.is_null() {
346 datatypes::types::cast(val.clone(), &dest_type)
347 .with_context(|_|
348 DatatypesSnafu{
349 extra: format!("Failed to implicitly cast literal {val:?} to type {dest_type:?}")
350 })?
351 } else {
352 val.clone()
353 };
354 *val = dest_val;
355 *typ = dest_type;
356 }
357 }
358
359 let ret_type = ColumnType::new_nullable(func.signature().output.clone());
360 let ret_expr = arg_exprs[0].clone().call_binary(arg_exprs[1].clone(), func);
361 Ok(TypedExpr::new(ret_expr, ret_type))
362 }
363 _var => {
364 if fn_name == TUMBLE_START || fn_name == TUMBLE_END {
365 let (func, arg) = UnaryFunc::from_tumble_func(fn_name, &arg_typed_exprs)?;
366
367 let ret_type = ColumnType::new_nullable(func.signature().output.clone());
368
369 Ok(TypedExpr::new(arg.expr.call_unary(func), ret_type))
370 } else if VariadicFunc::is_valid_func_name(fn_name) {
371 let func = VariadicFunc::from_str_and_types(fn_name, &arg_types)?;
372 let ret_type = ColumnType::new_nullable(func.signature().output.clone());
373 let mut expr = ScalarExpr::CallVariadic {
374 func,
375 exprs: arg_exprs,
376 };
377 expr.optimize();
378 Ok(TypedExpr::new(expr, ret_type))
379 } else if UnmaterializableFunc::is_valid_func_name(fn_name) {
380 let func = UnmaterializableFunc::from_str_args(fn_name, arg_typed_exprs)?;
381 let ret_type = ColumnType::new_nullable(func.signature().output.clone());
382 Ok(TypedExpr::new(
383 ScalarExpr::CallUnmaterializable(func),
384 ret_type,
385 ))
386 } else {
387 let try_as_df = Self::from_substrait_to_datafusion_scalar_func(
388 f,
389 arg_typed_exprs,
390 extensions,
391 )
392 .await?;
393 Ok(try_as_df)
394 }
395 }
396 }
397 }
398
399 pub async fn from_substrait_ifthen_rex(
401 if_then: &IfThen,
402 input_schema: &RelationDesc,
403 extensions: &FunctionExtensions,
404 ) -> Result<TypedExpr, Error> {
405 let ifs: Vec<_> = {
406 let mut ifs = Vec::new();
407 for if_clause in if_then.ifs.iter() {
408 let proto_if = if_clause.r#if.as_ref().with_context(|| InvalidQuerySnafu {
409 reason: "IfThen clause without if",
410 })?;
411 let proto_then = if_clause.then.as_ref().with_context(|| InvalidQuerySnafu {
412 reason: "IfThen clause without then",
413 })?;
414 let cond =
415 TypedExpr::from_substrait_rex(proto_if, input_schema, extensions).await?;
416 let then =
417 TypedExpr::from_substrait_rex(proto_then, input_schema, extensions).await?;
418 ifs.push((cond, then));
419 }
420 ifs
421 };
422 let els = match if_then
424 .r#else
425 .as_ref()
426 .map(|e| TypedExpr::from_substrait_rex(e, input_schema, extensions))
427 {
428 Some(fut) => Some(fut.await),
429 None => None,
430 }
431 .transpose()?
432 .unwrap_or_else(|| {
433 TypedExpr::new(
434 ScalarExpr::literal_null(),
435 ColumnType::new_nullable(CDT::null_datatype()),
436 )
437 });
438
439 fn build_if_then_recur(
440 mut next_if_then: impl Iterator<Item = (TypedExpr, TypedExpr)>,
441 els: TypedExpr,
442 ) -> TypedExpr {
443 if let Some((cond, then)) = next_if_then.next() {
444 TypedExpr::new(
446 ScalarExpr::If {
447 cond: Box::new(cond.expr),
448 then: Box::new(then.expr),
449 els: Box::new(build_if_then_recur(next_if_then, els).expr),
450 },
451 then.typ,
452 )
453 } else {
454 els
455 }
456 }
457 let expr_if = build_if_then_recur(ifs.into_iter(), els);
458 Ok(expr_if)
459 }
460 #[async_recursion::async_recursion]
462 pub async fn from_substrait_rex(
463 e: &Expression,
464 input_schema: &RelationDesc,
465 extensions: &FunctionExtensions,
466 ) -> Result<TypedExpr, Error> {
467 match &e.rex_type {
468 Some(RexType::Literal(lit)) => {
469 let lit = from_substrait_literal(lit)?;
470 Ok(TypedExpr::new(
471 ScalarExpr::Literal(lit.0, lit.1.clone()),
472 ColumnType::new_nullable(lit.1),
473 ))
474 }
475 Some(RexType::SingularOrList(s)) => {
476 let substrait_expr = s.value.as_ref().with_context(|| InvalidQuerySnafu {
477 reason: "SingularOrList expression without value",
478 })?;
479 let typed_expr =
480 TypedExpr::from_substrait_rex(substrait_expr, input_schema, extensions).await?;
481 if !s.options.is_empty() {
483 let mut list = Vec::with_capacity(s.options.len());
484 for opt in s.options.iter() {
485 let opt_expr =
486 TypedExpr::from_substrait_rex(opt, input_schema, extensions).await?;
487 list.push(opt_expr.expr);
488 }
489 let in_list_expr = ScalarExpr::InList {
490 expr: Box::new(typed_expr.expr),
491 list,
492 };
493 Ok(TypedExpr::new(
494 in_list_expr,
495 ColumnType::new_nullable(CDT::boolean_datatype()),
496 ))
497 } else {
498 Ok(typed_expr)
499 }
500 }
501 Some(RexType::Selection(field_ref)) => match &field_ref.reference_type {
502 Some(DirectReference(direct)) => match &direct.reference_type.as_ref() {
503 Some(StructField(x)) => match &x.child.as_ref() {
504 Some(_) => {
505 not_impl_err!(
506 "Direct reference StructField with child is not supported"
507 )
508 }
509 None => {
510 let column = x.field as usize;
511 let column_type = input_schema.typ().column_types[column].clone();
512 Ok(TypedExpr::new(ScalarExpr::Column(column), column_type))
513 }
514 },
515 _ => not_impl_err!(
516 "Direct reference with types other than StructField is not supported"
517 ),
518 },
519 _ => not_impl_err!("unsupported field ref type"),
520 },
521 Some(RexType::ScalarFunction(f)) => {
522 TypedExpr::from_substrait_scalar_func(f, input_schema, extensions).await
523 }
524 Some(RexType::IfThen(if_then)) => {
525 TypedExpr::from_substrait_ifthen_rex(if_then, input_schema, extensions).await
526 }
527 Some(RexType::Cast(cast)) => {
528 let input = cast.input.as_ref().with_context(|| InvalidQuerySnafu {
529 reason: "Cast expression without input",
530 })?;
531 let input = TypedExpr::from_substrait_rex(input, input_schema, extensions).await?;
532 let cast_type = from_substrait_type(cast.r#type.as_ref().with_context(|| {
533 InvalidQuerySnafu {
534 reason: "Cast expression without type",
535 }
536 })?)?;
537 let func = UnaryFunc::from_str_and_type("cast", Some(cast_type.clone()))?;
538 Ok(TypedExpr::new(
539 input.expr.call_unary(func),
540 ColumnType::new_nullable(cast_type),
541 ))
542 }
543 Some(RexType::WindowFunction(_)) => PlanSnafu {
544 reason:
545 "Window function is not supported yet. Please use aggregation function instead."
546 .to_string(),
547 }
548 .fail(),
549 _ => not_impl_err!("unsupported rex_type"),
550 }
551 }
552}
553
#[cfg(test)]
mod test {
    use datatypes::prelude::ConcreteDataType;
    use datatypes::value::Value;
    use pretty_assertions::assert_eq;

    use super::*;
    use crate::expr::{GlobalId, MapFilterProject};
    use crate::plan::{Plan, TypedPlan};
    use crate::repr::{self, ColumnType, RelationType};
    use crate::transform::test::{create_test_ctx, create_test_query_engine, sql_to_substrait};

    /// Plans `sql` with the test query engine and converts the resulting
    /// substrait plan into a flow plan, panicking when the conversion fails.
    async fn flow_plan_for(sql: &str) -> TypedPlan {
        let substrait_plan = sql_to_substrait(create_test_query_engine(), sql).await;
        let mut ctx = create_test_ctx();
        TypedPlan::from_substrait_plan(&mut ctx, &substrait_plan)
            .await
            .unwrap()
    }

    /// Non-nullable `uint32` column type, as expected for the `number` column.
    fn uint32_non_null() -> ColumnType {
        ColumnType::new(ConcreteDataType::uint32_datatype(), false)
    }

    /// A conjunction in `WHERE` is expected to become a single variadic `And`.
    #[tokio::test]
    async fn test_where_and() {
        let actual = flow_plan_for(
            "SELECT number FROM numbers_with_ts WHERE number >= 1 AND number <= 3 AND number!=2",
        )
        .await;

        let int64_lit = |v: i64| ScalarExpr::Literal(Value::from(v), CDT::int64_datatype());
        let conjuncts = vec![
            ScalarExpr::Column(2).call_binary(int64_lit(1), BinaryFunc::Gte),
            ScalarExpr::Column(2).call_binary(int64_lit(3), BinaryFunc::Lte),
            ScalarExpr::Column(2).call_binary(int64_lit(2), BinaryFunc::NotEq),
        ];
        let filter = ScalarExpr::CallVariadic {
            func: VariadicFunc::And,
            exprs: conjuncts,
        };

        let source = Plan::Get {
            id: crate::expr::Id::Global(GlobalId::User(1)),
        }
        .with_types(
            RelationType::new(vec![
                uint32_non_null(),
                ColumnType::new(ConcreteDataType::timestamp_millisecond_datatype(), false),
            ])
            .into_named(vec![Some("number".to_string()), Some("ts".to_string())]),
        );
        let mfp = MapFilterProject::new(2)
            .map(vec![
                ScalarExpr::CallUnary {
                    func: UnaryFunc::Cast(CDT::int64_datatype()),
                    expr: Box::new(ScalarExpr::Column(0)),
                },
                ScalarExpr::Column(0),
                ScalarExpr::Column(3),
            ])
            .unwrap()
            .filter(vec![filter])
            .unwrap()
            .project(vec![4])
            .unwrap();

        let expected = TypedPlan {
            schema: RelationType::new(vec![uint32_non_null()])
                .into_named(vec![Some("number".to_string())]),
            plan: Plan::Mfp {
                input: Box::new(source),
                mfp,
            },
        };
        assert_eq!(actual, expected);
    }

    /// An expression made of literals only is expected to fold to a constant.
    #[tokio::test]
    async fn test_binary_func_and_constant_folding() {
        let actual = flow_plan_for("SELECT 1+1*2-1/1+1%2==3 FROM numbers").await;

        let column_name =
            "Int64(1) + Int64(1) * Int64(2) - Int64(1) / Int64(1) + Int64(1) % Int64(2) = Int64(3)"
                .to_string();
        let folded_row = (
            repr::Row::new(vec![Value::from(true)]),
            repr::Timestamp::MIN,
            1,
        );
        let expected = TypedPlan {
            schema: RelationType::new(vec![ColumnType::new(CDT::boolean_datatype(), true)])
                .into_named(vec![Some(column_name)]),
            plan: Plan::Constant {
                rows: vec![folded_row],
            },
        };

        assert_eq!(actual, expected);
    }

    /// Adding an `Int64` literal to the `uint32` column is expected to cast
    /// the column to `int64` first.
    #[tokio::test]
    async fn test_implicitly_cast() {
        let actual = flow_plan_for("SELECT number+1 FROM numbers").await;

        let source = Plan::Get {
            id: crate::expr::Id::Global(GlobalId::User(0)),
        }
        .with_types(
            RelationType::new(vec![uint32_non_null()])
                .into_named(vec![Some("number".to_string())]),
        );
        let number_plus_one = ScalarExpr::Column(0)
            .call_unary(UnaryFunc::Cast(CDT::int64_datatype()))
            .call_binary(
                ScalarExpr::Literal(Value::from(1i64), CDT::int64_datatype()),
                BinaryFunc::AddInt64,
            );
        let mfp = MapFilterProject::new(1)
            .map(vec![number_plus_one])
            .unwrap()
            .project(vec![1])
            .unwrap();

        let expected = TypedPlan {
            schema: RelationType::new(vec![ColumnType::new(CDT::int64_datatype(), true)])
                .into_named(vec![Some("numbers.number + Int64(1)".to_string())]),
            plan: Plan::Mfp {
                input: Box::new(source),
                mfp,
            },
        };
        assert_eq!(actual, expected);
    }

    /// Casting a literal is expected to be folded into a constant row.
    #[tokio::test]
    async fn test_cast() {
        let actual = flow_plan_for("SELECT CAST(1 AS INT16) FROM numbers").await;

        let casted_row = (repr::Row::new(vec![Value::from(1i16)]), i64::MIN, 1);
        let expected = TypedPlan {
            schema: RelationType::new(vec![ColumnType::new(CDT::int16_datatype(), true)])
                .into_named(vec![Some(
                    "arrow_cast(Int64(1),Utf8(\"Int16\"))".to_string(),
                )]),
            plan: Plan::Constant {
                rows: vec![casted_row],
            },
        };
        assert_eq!(actual, expected);
    }

    /// Adding the column to itself is expected to use the `uint32` addition.
    #[tokio::test]
    async fn test_select_add() {
        let actual = flow_plan_for("SELECT number+number FROM numbers").await;

        let source = Plan::Get {
            id: crate::expr::Id::Global(GlobalId::User(0)),
        }
        .with_types(
            RelationType::new(vec![uint32_non_null()])
                .into_named(vec![Some("number".to_string())]),
        );
        let number_twice =
            ScalarExpr::Column(0).call_binary(ScalarExpr::Column(0), BinaryFunc::AddUInt32);
        let mfp = MapFilterProject::new(1)
            .map(vec![number_twice])
            .unwrap()
            .project(vec![1])
            .unwrap();

        let expected = TypedPlan {
            schema: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), true)])
                .into_named(vec![Some("numbers.number + numbers.number".to_string())]),
            plan: Plan::Mfp {
                input: Box::new(source),
                mfp,
            },
        };

        assert_eq!(actual, expected);
    }

    /// Hand-built substrait functions are resolved to the matching unary and
    /// binary flow functions.
    #[tokio::test]
    async fn test_func_sig() {
        // Builds a string-literal function argument; kept from the original
        // test although no assertion below uses it.
        fn lit(v: impl ToString) -> substrait_proto::proto::FunctionArgument {
            use substrait_proto::proto::expression;
            let literal = expression::Literal {
                nullable: false,
                type_variation_reference: 0,
                literal_type: Some(expression::literal::LiteralType::String(v.to_string())),
            };
            let expr = Expression {
                rex_type: Some(expression::RexType::Literal(literal)),
            };
            substrait_proto::proto::FunctionArgument {
                arg_type: Some(substrait_proto::proto::function_argument::ArgType::Value(
                    expr,
                )),
            }
        }

        // Unary: `is_null(col 0)`.
        let is_null_fn = substrait_proto::proto::expression::ScalarFunction {
            function_reference: 0,
            arguments: vec![proto_col(0)],
            options: vec![],
            output_type: None,
            ..Default::default()
        };
        let one_column = RelationType::new(vec![uint32_non_null()]).into_unnamed();
        let is_null_ext = FunctionExtensions::from_iter([(0, "is_null".to_string())]);
        let unary = TypedExpr::from_substrait_scalar_func(&is_null_fn, &one_column, &is_null_ext)
            .await
            .unwrap();

        let expected_unary = TypedExpr {
            expr: ScalarExpr::Column(0).call_unary(UnaryFunc::IsNull),
            typ: ColumnType {
                scalar_type: CDT::boolean_datatype(),
                nullable: true,
            },
        };
        assert_eq!(unary, expected_unary);

        // Binary: `add(col 0, col 1)`.
        let add_fn = substrait_proto::proto::expression::ScalarFunction {
            function_reference: 0,
            arguments: vec![proto_col(0), proto_col(1)],
            options: vec![],
            output_type: None,
            ..Default::default()
        };
        let two_columns =
            RelationType::new(vec![uint32_non_null(), uint32_non_null()]).into_unnamed();
        let add_ext = FunctionExtensions::from_iter([(0, "add".to_string())]);
        let binary = TypedExpr::from_substrait_scalar_func(&add_fn, &two_columns, &add_ext)
            .await
            .unwrap();

        let expected_binary = TypedExpr {
            expr: ScalarExpr::Column(0).call_binary(ScalarExpr::Column(1), BinaryFunc::AddUInt32),
            typ: ColumnType {
                scalar_type: CDT::uint32_datatype(),
                nullable: true,
            },
        };
        assert_eq!(binary, expected_binary);
    }
}