1#![warn(unused_imports)]
16
17use std::sync::Arc;
18
19use common_error::ext::BoxedError;
20use common_telemetry::debug;
21use datafusion::execution::SessionStateBuilder;
22use datafusion::functions::all_default_functions;
23use datafusion_physical_expr::PhysicalExpr;
24use datafusion_substrait::logical_plan::consumer::DefaultSubstraitConsumer;
25use datatypes::data_type::ConcreteDataType as CDT;
26use snafu::{ensure, OptionExt, ResultExt};
27use substrait_proto::proto::expression::field_reference::ReferenceType::DirectReference;
28use substrait_proto::proto::expression::reference_segment::ReferenceType::StructField;
29use substrait_proto::proto::expression::{IfThen, RexType, ScalarFunction};
30use substrait_proto::proto::function_argument::ArgType;
31use substrait_proto::proto::Expression;
32
33use crate::error::{
34 DatafusionSnafu, DatatypesSnafu, Error, EvalSnafu, ExternalSnafu, InvalidQuerySnafu,
35 NotImplementedSnafu, PlanSnafu, UnexpectedSnafu,
36};
37use crate::expr::{
38 BinaryFunc, DfScalarFunction, RawDfScalarFn, ScalarExpr, TypedExpr, UnaryFunc,
39 UnmaterializableFunc, VariadicFunc, TUMBLE_END, TUMBLE_START,
40};
41use crate::repr::{ColumnType, RelationDesc, RelationType};
42use crate::transform::literal::{
43 from_substrait_literal, from_substrait_type, to_substrait_literal,
44};
45use crate::transform::{substrait_proto, FunctionExtensions};
46
47fn typename_to_cdt(name: &str) -> Result<CDT, Error> {
50 let ret = match name {
51 "Int8" => CDT::int8_datatype(),
52 "Int16" => CDT::int16_datatype(),
53 "Int32" => CDT::int32_datatype(),
54 "Int64" => CDT::int64_datatype(),
55 "UInt8" => CDT::uint8_datatype(),
56 "UInt16" => CDT::uint16_datatype(),
57 "UInt32" => CDT::uint32_datatype(),
58 "UInt64" => CDT::uint64_datatype(),
59 "Float32" => CDT::float32_datatype(),
60 "Float64" => CDT::float64_datatype(),
61 "Boolean" => CDT::boolean_datatype(),
62 "String" => CDT::string_datatype(),
63 "Date" | "Date32" | "Date64" => CDT::date_datatype(),
64 "Timestamp" => CDT::timestamp_second_datatype(),
65 "Timestamp(Second, None)" => CDT::timestamp_second_datatype(),
66 "Timestamp(Millisecond, None)" => CDT::timestamp_millisecond_datatype(),
67 "Timestamp(Microsecond, None)" => CDT::timestamp_microsecond_datatype(),
68 "Timestamp(Nanosecond, None)" => CDT::timestamp_nanosecond_datatype(),
69 "Time32(Second)" | "Time64(Second)" => CDT::time_second_datatype(),
70 "Time32(Millisecond)" | "Time64(Millisecond)" => CDT::time_millisecond_datatype(),
71 "Time32(Microsecond)" | "Time64(Microsecond)" => CDT::time_microsecond_datatype(),
72 "Time32(Nanosecond)" | "Time64(Nanosecond)" => CDT::time_nanosecond_datatype(),
73 _ => NotImplementedSnafu {
74 reason: format!("Unrecognized typename: {}", name),
75 }
76 .fail()?,
77 };
78 Ok(ret)
79}
80
81pub(crate) async fn from_scalar_fn_to_df_fn_impl(
83 f: &ScalarFunction,
84 input_schema: &RelationDesc,
85 extensions: &FunctionExtensions,
86) -> Result<Arc<dyn PhysicalExpr>, Error> {
87 let e = Expression {
88 rex_type: Some(RexType::ScalarFunction(f.clone())),
89 };
90 let schema = input_schema.to_df_schema()?;
91
92 let extensions = extensions.to_extensions();
93 let session_state = SessionStateBuilder::new()
94 .with_scalar_functions(all_default_functions())
95 .build();
96 let consumer = DefaultSubstraitConsumer::new(&extensions, &session_state);
97 let df_expr =
98 substrait::df_logical_plan::consumer::from_substrait_rex(&consumer, &e, &schema).await;
99 let expr = df_expr.context({
100 DatafusionSnafu {
101 context: "Failed to convert substrait scalar function to datafusion scalar function",
102 }
103 })?;
104 let phy_expr =
105 datafusion::physical_expr::create_physical_expr(&expr, &schema, &Default::default())
106 .context(DatafusionSnafu {
107 context: "Failed to create physical expression from logical expression",
108 })?;
109 Ok(phy_expr)
110}
111
112pub(crate) fn proto_col(i: usize) -> substrait_proto::proto::FunctionArgument {
114 use substrait_proto::proto::expression;
115 let expr = Expression {
116 rex_type: Some(expression::RexType::Selection(Box::new(
117 expression::FieldReference {
118 reference_type: Some(expression::field_reference::ReferenceType::DirectReference(
119 expression::ReferenceSegment {
120 reference_type: Some(
121 expression::reference_segment::ReferenceType::StructField(Box::new(
122 expression::reference_segment::StructField {
123 field: i as i32,
124 child: None,
125 },
126 )),
127 ),
128 },
129 )),
130 root_type: None,
131 },
132 ))),
133 };
134 substrait_proto::proto::FunctionArgument {
135 arg_type: Some(substrait_proto::proto::function_argument::ArgType::Value(
136 expr,
137 )),
138 }
139}
140
141fn is_proto_literal(arg: &substrait_proto::proto::FunctionArgument) -> bool {
142 use substrait_proto::proto::expression;
143 matches!(
144 arg.arg_type.as_ref().unwrap(),
145 ArgType::Value(Expression {
146 rex_type: Some(expression::RexType::Literal(_)),
147 })
148 )
149}
150
151fn build_proto_lit(
152 lit: substrait_proto::proto::expression::Literal,
153) -> substrait_proto::proto::FunctionArgument {
154 use substrait_proto::proto;
155 proto::FunctionArgument {
156 arg_type: Some(ArgType::Value(Expression {
157 rex_type: Some(proto::expression::RexType::Literal(lit)),
158 })),
159 }
160}
161
162fn rewrite_scalar_function(
166 f: &ScalarFunction,
167 arg_typed_exprs: &[TypedExpr],
168) -> Result<ScalarFunction, Error> {
169 let mut f_rewrite = f.clone();
170 ensure!(
171 f_rewrite.arguments.len() == arg_typed_exprs.len(),
172 crate::error::InternalSnafu {
173 reason: format!(
174 "Expect `f_rewrite` and `arg_typed_expr` to be same length, found {} and {}",
175 f_rewrite.arguments.len(),
176 arg_typed_exprs.len()
177 )
178 }
179 );
180 for (idx, raw_expr) in f_rewrite.arguments.iter_mut().enumerate() {
181 match (
185 is_proto_literal(raw_expr),
186 arg_typed_exprs[idx].expr.is_literal(),
187 ) {
188 (false, false) => *raw_expr = proto_col(idx),
189 (true, _) => (),
190 (false, true) => {
191 if let ScalarExpr::Literal(val, ty) = &arg_typed_exprs[idx].expr {
192 let df_val = val
193 .try_to_scalar_value(ty)
194 .map_err(BoxedError::new)
195 .context(ExternalSnafu)?;
196 let lit_sub = to_substrait_literal(&df_val)?;
197 *raw_expr = build_proto_lit(lit_sub);
199 } else {
200 UnexpectedSnafu {
201 reason: format!(
202 "Expect value to be literal, but found {:?}",
203 arg_typed_exprs[idx].expr
204 ),
205 }
206 .fail()?
207 }
208 }
209 }
210 }
211 Ok(f_rewrite)
212}
213
214impl TypedExpr {
215 pub async fn from_substrait_to_datafusion_scalar_func(
216 f: &ScalarFunction,
217 arg_typed_exprs: Vec<TypedExpr>,
218 extensions: &FunctionExtensions,
219 ) -> Result<TypedExpr, Error> {
220 let (arg_exprs, arg_types): (Vec<_>, Vec<_>) = arg_typed_exprs
221 .clone()
222 .into_iter()
223 .map(|e| (e.expr, e.typ))
224 .unzip();
225 debug!("Before rewrite: {:?}", f);
226 let f_rewrite = rewrite_scalar_function(f, &arg_typed_exprs)?;
227 debug!("After rewrite: {:?}", f_rewrite);
228 let input_schema = RelationType::new(arg_types).into_unnamed();
229 let raw_fn =
230 RawDfScalarFn::from_proto(&f_rewrite, input_schema.clone(), extensions.clone())?;
231
232 let df_func = DfScalarFunction::try_from_raw_fn(raw_fn).await?;
233 let expr = ScalarExpr::CallDf {
234 df_scalar_fn: df_func,
235 exprs: arg_exprs,
236 };
237 let ret_type = expr.typ(&[])?;
239 Ok(TypedExpr::new(expr, ret_type))
240 }
241
242 pub async fn from_substrait_scalar_func(
244 f: &ScalarFunction,
245 input_schema: &RelationDesc,
246 extensions: &FunctionExtensions,
247 ) -> Result<TypedExpr, Error> {
248 let fn_name =
249 extensions
250 .get(&f.function_reference)
251 .with_context(|| NotImplementedSnafu {
252 reason: format!(
253 "Aggregated function not found: function reference = {:?}",
254 f.function_reference
255 ),
256 })?;
257 let arg_len = f.arguments.len();
258 let arg_typed_exprs: Vec<TypedExpr> = {
259 let mut rets = Vec::new();
260 for arg in f.arguments.iter() {
261 let ret = match &arg.arg_type {
262 Some(ArgType::Value(e)) => {
263 TypedExpr::from_substrait_rex(e, input_schema, extensions).await
264 }
265 _ => not_impl_err!("Aggregated function argument non-Value type not supported"),
266 }?;
267 rets.push(ret);
268 }
269 rets
270 };
271
272 let (arg_exprs, arg_types): (Vec<_>, Vec<_>) = arg_typed_exprs
274 .clone()
275 .into_iter()
276 .map(
277 |TypedExpr {
278 expr: arg_val,
279 typ: arg_type,
280 }| {
281 if arg_val.is_literal() {
282 (arg_val, None)
283 } else {
284 (arg_val, Some(arg_type.scalar_type))
285 }
286 },
287 )
288 .unzip();
289
290 match arg_len {
291 1 if UnaryFunc::is_valid_func_name(fn_name) => {
292 let func = UnaryFunc::from_str_and_type(fn_name, None)?;
293 let arg = arg_exprs[0].clone();
294 let ret_type = ColumnType::new_nullable(func.signature().output.clone());
295
296 Ok(TypedExpr::new(arg.call_unary(func), ret_type))
297 }
298 2 if fn_name == "arrow_cast" => {
299 let cast_to = arg_exprs[1]
300 .clone()
301 .as_literal()
302 .and_then(|lit| lit.as_string())
303 .with_context(|| InvalidQuerySnafu {
304 reason: "array_cast's second argument must be a literal string",
305 })?;
306 let cast_to = typename_to_cdt(&cast_to)?;
307 let func = UnaryFunc::Cast(cast_to.clone());
308 let arg = arg_exprs[0].clone();
309 if arg.is_literal() {
311 let res = func.eval(&[], &arg).context(EvalSnafu)?;
312 Ok(TypedExpr::new(
313 ScalarExpr::Literal(res, cast_to.clone()),
314 ColumnType::new_nullable(cast_to),
315 ))
316 } else {
317 let ret_type = ColumnType::new_nullable(func.signature().output.clone());
318
319 Ok(TypedExpr::new(arg.call_unary(func), ret_type))
320 }
321 }
322 2 if BinaryFunc::is_valid_func_name(fn_name) => {
323 let (func, signature) =
324 BinaryFunc::from_str_expr_and_type(fn_name, &arg_exprs, &arg_types[0..2])?;
325
326 let is_all_literal = arg_exprs.iter().all(|arg| arg.is_literal());
328 if is_all_literal {
329 let res = func
330 .eval(&[], &arg_exprs[0], &arg_exprs[1])
331 .context(EvalSnafu)?;
332
333 let con_typ = signature.output.clone();
335 let typ = ColumnType::new_nullable(con_typ.clone());
336 return Ok(TypedExpr::new(ScalarExpr::Literal(res, con_typ), typ));
337 }
338
339 let mut arg_exprs = arg_exprs;
340 for (idx, arg_expr) in arg_exprs.iter_mut().enumerate() {
341 if let ScalarExpr::Literal(val, typ) = arg_expr {
342 let dest_type = signature.input[idx].clone();
343
344 let dest_val = if !dest_type.is_null() {
346 datatypes::types::cast(val.clone(), &dest_type)
347 .with_context(|_|
348 DatatypesSnafu{
349 extra: format!("Failed to implicitly cast literal {val:?} to type {dest_type:?}")
350 })?
351 } else {
352 val.clone()
353 };
354 *val = dest_val;
355 *typ = dest_type;
356 }
357 }
358
359 let ret_type = ColumnType::new_nullable(func.signature().output.clone());
360 let ret_expr = arg_exprs[0].clone().call_binary(arg_exprs[1].clone(), func);
361 Ok(TypedExpr::new(ret_expr, ret_type))
362 }
363 _var => {
364 if fn_name == TUMBLE_START || fn_name == TUMBLE_END {
365 let (func, arg) = UnaryFunc::from_tumble_func(fn_name, &arg_typed_exprs)?;
366
367 let ret_type = ColumnType::new_nullable(func.signature().output.clone());
368
369 Ok(TypedExpr::new(arg.expr.call_unary(func), ret_type))
370 } else if VariadicFunc::is_valid_func_name(fn_name) {
371 let func = VariadicFunc::from_str_and_types(fn_name, &arg_types)?;
372 let ret_type = ColumnType::new_nullable(func.signature().output.clone());
373 let mut expr = ScalarExpr::CallVariadic {
374 func,
375 exprs: arg_exprs,
376 };
377 expr.optimize();
378 Ok(TypedExpr::new(expr, ret_type))
379 } else if UnmaterializableFunc::is_valid_func_name(fn_name) {
380 let func = UnmaterializableFunc::from_str_args(fn_name, arg_typed_exprs)?;
381 let ret_type = ColumnType::new_nullable(func.signature().output.clone());
382 Ok(TypedExpr::new(
383 ScalarExpr::CallUnmaterializable(func),
384 ret_type,
385 ))
386 } else {
387 let try_as_df = Self::from_substrait_to_datafusion_scalar_func(
388 f,
389 arg_typed_exprs,
390 extensions,
391 )
392 .await?;
393 Ok(try_as_df)
394 }
395 }
396 }
397 }
398
399 pub async fn from_substrait_ifthen_rex(
401 if_then: &IfThen,
402 input_schema: &RelationDesc,
403 extensions: &FunctionExtensions,
404 ) -> Result<TypedExpr, Error> {
405 let ifs: Vec<_> = {
406 let mut ifs = Vec::new();
407 for if_clause in if_then.ifs.iter() {
408 let proto_if = if_clause.r#if.as_ref().with_context(|| InvalidQuerySnafu {
409 reason: "IfThen clause without if",
410 })?;
411 let proto_then = if_clause.then.as_ref().with_context(|| InvalidQuerySnafu {
412 reason: "IfThen clause without then",
413 })?;
414 let cond =
415 TypedExpr::from_substrait_rex(proto_if, input_schema, extensions).await?;
416 let then =
417 TypedExpr::from_substrait_rex(proto_then, input_schema, extensions).await?;
418 ifs.push((cond, then));
419 }
420 ifs
421 };
422 let els = match if_then
424 .r#else
425 .as_ref()
426 .map(|e| TypedExpr::from_substrait_rex(e, input_schema, extensions))
427 {
428 Some(fut) => Some(fut.await),
429 None => None,
430 }
431 .transpose()?
432 .unwrap_or_else(|| {
433 TypedExpr::new(
434 ScalarExpr::literal_null(),
435 ColumnType::new_nullable(CDT::null_datatype()),
436 )
437 });
438
439 fn build_if_then_recur(
440 mut next_if_then: impl Iterator<Item = (TypedExpr, TypedExpr)>,
441 els: TypedExpr,
442 ) -> TypedExpr {
443 if let Some((cond, then)) = next_if_then.next() {
444 TypedExpr::new(
446 ScalarExpr::If {
447 cond: Box::new(cond.expr),
448 then: Box::new(then.expr),
449 els: Box::new(build_if_then_recur(next_if_then, els).expr),
450 },
451 then.typ,
452 )
453 } else {
454 els
455 }
456 }
457 let expr_if = build_if_then_recur(ifs.into_iter(), els);
458 Ok(expr_if)
459 }
460 #[async_recursion::async_recursion]
462 pub async fn from_substrait_rex(
463 e: &Expression,
464 input_schema: &RelationDesc,
465 extensions: &FunctionExtensions,
466 ) -> Result<TypedExpr, Error> {
467 match &e.rex_type {
468 Some(RexType::Literal(lit)) => {
469 let lit = from_substrait_literal(lit)?;
470 Ok(TypedExpr::new(
471 ScalarExpr::Literal(lit.0, lit.1.clone()),
472 ColumnType::new_nullable(lit.1),
473 ))
474 }
475 Some(RexType::SingularOrList(s)) => {
476 let substrait_expr = s.value.as_ref().with_context(|| InvalidQuerySnafu {
477 reason: "SingularOrList expression without value",
478 })?;
479 if !s.options.is_empty() {
481 return not_impl_err!("In list expression is not supported");
482 }
483 TypedExpr::from_substrait_rex(substrait_expr, input_schema, extensions).await
484 }
485 Some(RexType::Selection(field_ref)) => match &field_ref.reference_type {
486 Some(DirectReference(direct)) => match &direct.reference_type.as_ref() {
487 Some(StructField(x)) => match &x.child.as_ref() {
488 Some(_) => {
489 not_impl_err!(
490 "Direct reference StructField with child is not supported"
491 )
492 }
493 None => {
494 let column = x.field as usize;
495 let column_type = input_schema.typ().column_types[column].clone();
496 Ok(TypedExpr::new(ScalarExpr::Column(column), column_type))
497 }
498 },
499 _ => not_impl_err!(
500 "Direct reference with types other than StructField is not supported"
501 ),
502 },
503 _ => not_impl_err!("unsupported field ref type"),
504 },
505 Some(RexType::ScalarFunction(f)) => {
506 TypedExpr::from_substrait_scalar_func(f, input_schema, extensions).await
507 }
508 Some(RexType::IfThen(if_then)) => {
509 TypedExpr::from_substrait_ifthen_rex(if_then, input_schema, extensions).await
510 }
511 Some(RexType::Cast(cast)) => {
512 let input = cast.input.as_ref().with_context(|| InvalidQuerySnafu {
513 reason: "Cast expression without input",
514 })?;
515 let input = TypedExpr::from_substrait_rex(input, input_schema, extensions).await?;
516 let cast_type = from_substrait_type(cast.r#type.as_ref().with_context(|| {
517 InvalidQuerySnafu {
518 reason: "Cast expression without type",
519 }
520 })?)?;
521 let func = UnaryFunc::from_str_and_type("cast", Some(cast_type.clone()))?;
522 Ok(TypedExpr::new(
523 input.expr.call_unary(func),
524 ColumnType::new_nullable(cast_type),
525 ))
526 }
527 Some(RexType::WindowFunction(_)) => PlanSnafu {
528 reason:
529 "Window function is not supported yet. Please use aggregation function instead."
530 .to_string(),
531 }
532 .fail(),
533 _ => not_impl_err!("unsupported rex_type"),
534 }
535 }
536}
537
#[cfg(test)]
mod test {
    use datatypes::prelude::ConcreteDataType;
    use datatypes::value::Value;
    use pretty_assertions::assert_eq;

    use super::*;
    use crate::expr::{GlobalId, MapFilterProject};
    use crate::plan::{Plan, TypedPlan};
    use crate::repr::{self, ColumnType, RelationType};
    use crate::transform::test::{create_test_ctx, create_test_query_engine, sql_to_substrait};

    /// A `WHERE` clause made of three AND-ed comparisons is expected to end up as a single
    /// variadic `And` filter inside the map-filter-project.
    #[tokio::test]
    async fn test_where_and() {
        let engine = create_test_query_engine();
        let sql =
            "SELECT number FROM numbers_with_ts WHERE number >= 1 AND number <= 3 AND number!=2";
        let plan = sql_to_substrait(engine.clone(), sql).await;

        let mut ctx = create_test_ctx();
        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;

        // All three comparisons read column 2, the first mapped expression
        // (the `number` column cast to int64).
        let filter = ScalarExpr::CallVariadic {
            func: VariadicFunc::And,
            exprs: vec![
                ScalarExpr::Column(2).call_binary(
                    ScalarExpr::Literal(Value::from(1i64), CDT::int64_datatype()),
                    BinaryFunc::Gte,
                ),
                ScalarExpr::Column(2).call_binary(
                    ScalarExpr::Literal(Value::from(3i64), CDT::int64_datatype()),
                    BinaryFunc::Lte,
                ),
                ScalarExpr::Column(2).call_binary(
                    ScalarExpr::Literal(Value::from(2i64), CDT::int64_datatype()),
                    BinaryFunc::NotEq,
                ),
            ],
        };
        let expected = TypedPlan {
            schema: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), false)])
                .into_named(vec![Some("number".to_string())]),
            plan: Plan::Mfp {
                input: Box::new(
                    Plan::Get {
                        id: crate::expr::Id::Global(GlobalId::User(1)),
                    }
                    .with_types(
                        RelationType::new(vec![
                            ColumnType::new(ConcreteDataType::uint32_datatype(), false),
                            ColumnType::new(
                                ConcreteDataType::timestamp_millisecond_datatype(),
                                false,
                            ),
                        ])
                        .into_named(vec![Some("number".to_string()), Some("ts".to_string())]),
                    ),
                ),
                mfp: MapFilterProject::new(2)
                    .map(vec![
                        ScalarExpr::CallUnary {
                            func: UnaryFunc::Cast(CDT::int64_datatype()),
                            expr: Box::new(ScalarExpr::Column(0)),
                        },
                        ScalarExpr::Column(0),
                        ScalarExpr::Column(3),
                    ])
                    .unwrap()
                    .filter(vec![filter])
                    .unwrap()
                    .project(vec![4])
                    .unwrap(),
            },
        };
        assert_eq!(flow_plan.unwrap(), expected);
    }

    /// An expression made only of literals is expected to be folded into a constant plan.
    #[tokio::test]
    async fn test_binary_func_and_constant_folding() {
        let engine = create_test_query_engine();
        let sql = "SELECT 1+1*2-1/1+1%2==3 FROM numbers";
        let plan = sql_to_substrait(engine.clone(), sql).await;

        let mut ctx = create_test_ctx();
        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;

        let expected = TypedPlan {
            schema: RelationType::new(vec![ColumnType::new(CDT::boolean_datatype(), true)])
                .into_named(vec![Some("Int64(1) + Int64(1) * Int64(2) - Int64(1) / Int64(1) + Int64(1) % Int64(2) = Int64(3)".to_string())]),
            plan: Plan::Constant {
                rows: vec![(
                    repr::Row::new(vec![Value::from(true)]),
                    repr::Timestamp::MIN,
                    1,
                )],
            },
        };

        assert_eq!(flow_plan.unwrap(), expected);
    }

    /// `number + 1` on a uint32 column: the column is expected to be cast to int64 and
    /// added to an int64 literal with `AddInt64`.
    #[tokio::test]
    async fn test_implicitly_cast() {
        let engine = create_test_query_engine();
        let sql = "SELECT number+1 FROM numbers";
        let plan = sql_to_substrait(engine.clone(), sql).await;

        let mut ctx = create_test_ctx();
        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;

        let expected = TypedPlan {
            schema: RelationType::new(vec![ColumnType::new(CDT::int64_datatype(), true)])
                .into_named(vec![Some("numbers.number + Int64(1)".to_string())]),
            plan: Plan::Mfp {
                input: Box::new(
                    Plan::Get {
                        id: crate::expr::Id::Global(GlobalId::User(0)),
                    }
                    .with_types(
                        RelationType::new(vec![ColumnType::new(
                            ConcreteDataType::uint32_datatype(),
                            false,
                        )])
                        .into_named(vec![Some("number".to_string())]),
                    ),
                ),
                mfp: MapFilterProject::new(1)
                    .map(vec![ScalarExpr::Column(0)
                        .call_unary(UnaryFunc::Cast(CDT::int64_datatype()))
                        .call_binary(
                            ScalarExpr::Literal(Value::from(1i64), CDT::int64_datatype()),
                            BinaryFunc::AddInt64,
                        )])
                    .unwrap()
                    .project(vec![1])
                    .unwrap(),
            },
        };
        assert_eq!(flow_plan.unwrap(), expected);
    }

    /// `CAST(1 AS INT16)` shows up as `arrow_cast` on a literal and is expected to be folded
    /// into an int16 constant.
    #[tokio::test]
    async fn test_cast() {
        let engine = create_test_query_engine();
        let sql = "SELECT CAST(1 AS INT16) FROM numbers";
        let plan = sql_to_substrait(engine.clone(), sql).await;

        let mut ctx = create_test_ctx();
        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;

        let expected = TypedPlan {
            schema: RelationType::new(vec![ColumnType::new(CDT::int16_datatype(), true)])
                .into_named(vec![Some(
                    "arrow_cast(Int64(1),Utf8(\"Int16\"))".to_string(),
                )]),
            plan: Plan::Constant {
                rows: vec![(
                    repr::Row::new(vec![Value::from(1i16)]),
                    repr::Timestamp::MIN,
                    1,
                )],
            },
        };
        assert_eq!(flow_plan.unwrap(), expected);
    }

    /// Adding a uint32 column to itself is expected to use `AddUInt32` without any cast.
    #[tokio::test]
    async fn test_select_add() {
        let engine = create_test_query_engine();
        let sql = "SELECT number+number FROM numbers";
        let plan = sql_to_substrait(engine.clone(), sql).await;

        let mut ctx = create_test_ctx();
        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;

        let expected = TypedPlan {
            schema: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), true)])
                .into_named(vec![Some("numbers.number + numbers.number".to_string())]),
            plan: Plan::Mfp {
                input: Box::new(
                    Plan::Get {
                        id: crate::expr::Id::Global(GlobalId::User(0)),
                    }
                    .with_types(
                        RelationType::new(vec![ColumnType::new(
                            ConcreteDataType::uint32_datatype(),
                            false,
                        )])
                        .into_named(vec![Some("number".to_string())]),
                    ),
                ),
                mfp: MapFilterProject::new(1)
                    .map(vec![ScalarExpr::Column(0)
                        .call_binary(ScalarExpr::Column(0), BinaryFunc::AddUInt32)])
                    .unwrap()
                    .project(vec![1])
                    .unwrap(),
            },
        };

        assert_eq!(flow_plan.unwrap(), expected);
    }

    /// Hand-built substrait scalar functions: a one-argument `is_null` and a two-argument
    /// `add` over uint32 columns are converted directly, without going through SQL.
    #[tokio::test]
    async fn test_func_sig() {
        // Unary case: `is_null(col0)`.
        let f = substrait_proto::proto::expression::ScalarFunction {
            function_reference: 0,
            arguments: vec![proto_col(0)],
            options: vec![],
            output_type: None,
            ..Default::default()
        };
        let input_schema =
            RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), false)]).into_unnamed();
        let extensions = FunctionExtensions::from_iter([(0, "is_null".to_string())]);
        let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions)
            .await
            .unwrap();

        assert_eq!(
            res,
            TypedExpr {
                expr: ScalarExpr::Column(0).call_unary(UnaryFunc::IsNull),
                typ: ColumnType {
                    scalar_type: CDT::boolean_datatype(),
                    nullable: true,
                },
            }
        );

        // Binary case: `add(col0, col1)`.
        let f = substrait_proto::proto::expression::ScalarFunction {
            function_reference: 0,
            arguments: vec![proto_col(0), proto_col(1)],
            options: vec![],
            output_type: None,
            ..Default::default()
        };
        let input_schema = RelationType::new(vec![
            ColumnType::new(CDT::uint32_datatype(), false),
            ColumnType::new(CDT::uint32_datatype(), false),
        ])
        .into_unnamed();
        let extensions = FunctionExtensions::from_iter([(0, "add".to_string())]);
        let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions)
            .await
            .unwrap();

        assert_eq!(
            res,
            TypedExpr {
                expr: ScalarExpr::Column(0)
                    .call_binary(ScalarExpr::Column(1), BinaryFunc::AddUInt32,),
                typ: ColumnType {
                    scalar_type: CDT::uint32_datatype(),
                    nullable: true,
                },
            }
        );
    }
}