1use itertools::Itertools;
16use snafu::OptionExt;
17use substrait_proto::proto;
18use substrait_proto::proto::aggregate_function::AggregationInvocation;
19use substrait_proto::proto::aggregate_rel::{Grouping, Measure};
20use substrait_proto::proto::function_argument::ArgType;
21
22use crate::error::{Error, NotImplementedSnafu, PlanSnafu};
23use crate::expr::{
24 AggregateExpr, AggregateFunc, MapFilterProject, ScalarExpr, TypedExpr, UnaryFunc,
25};
26use crate::plan::{AccumulablePlan, AggrWithIndex, KeyValPlan, Plan, ReducePlan, TypedPlan};
27use crate::repr::{ColumnType, RelationDesc, RelationType};
28use crate::transform::{substrait_proto, FlownodeContext, FunctionExtensions};
29
30impl TypedExpr {
31 #[allow(deprecated)]
33 async fn from_substrait_agg_grouping(
34 ctx: &mut FlownodeContext,
35 grouping_expressions: &[proto::Expression],
36 groupings: &[Grouping],
37 typ: &RelationDesc,
38 extensions: &FunctionExtensions,
39 ) -> Result<Vec<TypedExpr>, Error> {
40 let _ = ctx;
41 let mut group_expr = vec![];
42 match groupings.len() {
43 1 => {
44 let expressions: Box<dyn Iterator<Item = &proto::Expression> + Send> = if groupings
46 [0]
47 .expression_references
48 .is_empty()
49 {
50 Box::new(groupings[0].grouping_expressions.iter())
51 } else {
52 if groupings[0]
53 .expression_references
54 .iter()
55 .any(|idx| *idx as usize >= grouping_expressions.len())
56 {
57 return PlanSnafu {
58 reason: format!("Invalid grouping expression reference: {:?} for grouping expr: {:?}",
59 groupings[0].expression_references,
60 grouping_expressions
61 ),
62 }.fail()?;
63 }
64 Box::new(
65 groupings[0]
66 .expression_references
67 .iter()
68 .map(|idx| &grouping_expressions[*idx as usize]),
69 )
70 };
71 for e in expressions {
72 let x = TypedExpr::from_substrait_rex(e, typ, extensions).await?;
73 group_expr.push(x);
74 }
75 }
76 _ => {
77 return not_impl_err!(
78 "Grouping sets not support yet, use union all with group by instead."
79 );
80 }
81 };
82 Ok(group_expr)
83 }
84}
85
86impl AggregateExpr {
87 async fn from_substrait_agg_measures(
91 ctx: &mut FlownodeContext,
92 measures: &[Measure],
93 typ: &RelationDesc,
94 extensions: &FunctionExtensions,
95 ) -> Result<Vec<AggregateExpr>, Error> {
96 let _ = ctx;
97 let mut all_aggr_exprs = vec![];
98
99 for m in measures {
100 let filter = match m
101 .filter
102 .as_ref()
103 .map(|fil| TypedExpr::from_substrait_rex(fil, typ, extensions))
104 {
105 Some(fut) => Some(fut.await),
106 None => None,
107 }
108 .transpose()?;
109
110 let aggr_expr = match &m.measure {
111 Some(f) => {
112 let distinct = match f.invocation {
113 _ if f.invocation == AggregationInvocation::Distinct as i32 => true,
114 _ if f.invocation == AggregationInvocation::All as i32 => false,
115 _ => false,
116 };
117 AggregateExpr::from_substrait_agg_func(
118 f, typ, extensions, &filter, &None, distinct,
120 )
121 .await?
122 }
123 None => {
124 return not_impl_err!("Aggregate without aggregate function is not supported")
125 }
126 };
127
128 all_aggr_exprs.extend(aggr_expr);
129 }
130
131 Ok(all_aggr_exprs)
132 }
133
134 pub async fn from_substrait_agg_func(
139 f: &proto::AggregateFunction,
140 input_schema: &RelationDesc,
141 extensions: &FunctionExtensions,
142 filter: &Option<TypedExpr>,
143 order_by: &Option<Vec<TypedExpr>>,
144 distinct: bool,
145 ) -> Result<Vec<AggregateExpr>, Error> {
146 let _ = filter;
148 let _ = order_by;
149 let mut args = vec![];
150 for arg in &f.arguments {
151 let arg_expr = match &arg.arg_type {
152 Some(ArgType::Value(e)) => {
153 TypedExpr::from_substrait_rex(e, input_schema, extensions).await
154 }
155 _ => not_impl_err!("Aggregated function argument non-Value type not supported"),
156 }?;
157 args.push(arg_expr);
158 }
159
160 if args.len() != 1 {
161 let fn_name = extensions.get(&f.function_reference).cloned();
162 return not_impl_err!(
163 "Aggregated function (name={:?}) with multiple arguments is not supported",
164 fn_name
165 );
166 }
167
168 let arg = if let Some(first) = args.first() {
169 first
170 } else {
171 return not_impl_err!("Aggregated function without arguments is not supported");
172 };
173
174 let fn_name = extensions
175 .get(&f.function_reference)
176 .cloned()
177 .map(|s| s.to_lowercase());
178
179 match fn_name.as_ref().map(|s| s.as_ref()) {
180 Some(function_name) => {
181 let func = AggregateFunc::from_str_and_type(
182 function_name,
183 Some(arg.typ.scalar_type.clone()),
184 )?;
185 let exprs = vec![AggregateExpr {
186 func,
187 expr: arg.expr.clone(),
188 distinct,
189 }];
190 Ok(exprs)
191 }
192 None => not_impl_err!(
193 "Aggregated function not found: function anchor = {:?}",
194 f.function_reference
195 ),
196 }
197 }
198}
199
200impl KeyValPlan {
201 fn from_substrait_gen_key_val_plan(
205 aggr_exprs: &mut [AggregateExpr],
206 group_exprs: &[TypedExpr],
207 input_arity: usize,
208 ) -> Result<KeyValPlan, Error> {
209 let group_expr_val = group_exprs
210 .iter()
211 .cloned()
212 .map(|expr| expr.expr.clone())
213 .collect_vec();
214 let output_arity = group_expr_val.len();
215 let key_plan = MapFilterProject::new(input_arity)
216 .map(group_expr_val)?
217 .project(input_arity..input_arity + output_arity)?;
218
219 let val_plan = {
222 let need_mfp = aggr_exprs.iter().any(|agg| agg.expr.as_column().is_none());
223 if need_mfp {
224 let input_exprs = aggr_exprs
226 .iter_mut()
227 .enumerate()
228 .map(|(idx, aggr)| {
229 let ret = aggr.expr.clone();
230 aggr.expr = ScalarExpr::Column(idx);
231 ret
232 })
233 .collect_vec();
234 let aggr_arity = aggr_exprs.len();
235
236 MapFilterProject::new(input_arity)
237 .map(input_exprs)?
238 .project(input_arity..input_arity + aggr_arity)?
239 } else {
240 MapFilterProject::new(input_arity)
242 }
243 };
244 Ok(KeyValPlan {
245 key_plan: key_plan.into_safe(),
246 val_plan: val_plan.into_safe(),
247 })
248 }
249}
250
251fn find_time_index_in_group_exprs(group_exprs: &[TypedExpr]) -> Option<usize> {
255 group_exprs.iter().position(|expr| {
256 matches!(
257 &expr.expr,
258 ScalarExpr::CallUnary {
259 func: UnaryFunc::TumbleWindowFloor { .. },
260 expr: _
261 }
262 ) || expr.typ.scalar_type.is_timestamp()
263 })
264}
265
impl TypedPlan {
    /// Convert a substrait `AggregateRel` into a `TypedPlan` whose root is a
    /// `Plan::Reduce`.
    ///
    /// Steps: transform the input rel, then the grouping expressions and the
    /// aggregate measures, derive the key/value MFP plans, compute the output
    /// schema (group columns first, then one nullable column per aggregate),
    /// and finally assemble the accumulable reduce plan.
    #[async_recursion::async_recursion]
    pub async fn from_substrait_agg_rel(
        ctx: &mut FlownodeContext,
        agg: &proto::AggregateRel,
        extensions: &FunctionExtensions,
    ) -> Result<TypedPlan, Error> {
        let input = if let Some(input) = agg.input.as_ref() {
            TypedPlan::from_substrait_rel(ctx, input, extensions).await?
        } else {
            return not_impl_err!("Aggregate without an input is not supported");
        };

        let group_exprs = TypedExpr::from_substrait_agg_grouping(
            ctx,
            &agg.grouping_expressions,
            &agg.groupings,
            &input.schema,
            extensions,
        )
        .await?;

        let time_index = find_time_index_in_group_exprs(&group_exprs);

        let mut aggr_exprs = AggregateExpr::from_substrait_agg_measures(
            ctx,
            &agg.measures,
            &input.schema,
            extensions,
        )
        .await?;

        // NOTE: this may rewrite `aggr_exprs` in place so that each aggregate
        // reads a plain column of the value plan's output — it must run
        // before `simple_aggrs`/`distinct_aggrs` are derived below.
        let key_val_plan = KeyValPlan::from_substrait_gen_key_val_plan(
            &mut aggr_exprs,
            &group_exprs,
            input.schema.typ.column_types.len(),
        )?;

        // Output schema: group-by columns first (keeping the input column
        // name when the group expr is a plain column reference), then one
        // nullable column per aggregate.
        let output_type = {
            let mut output_types = Vec::new();
            let mut output_names = Vec::new();

            for expr in group_exprs.iter() {
                output_types.push(expr.typ.clone());
                let col_name = match &expr.expr {
                    ScalarExpr::Column(col) => input.schema.get_name(*col).clone(),
                    _ => None,
                };
                output_names.push(col_name)
            }

            for aggr in &aggr_exprs {
                output_types.push(ColumnType::new_nullable(
                    aggr.func.signature().output.clone(),
                ));
                output_names.push(None);
            }
            // The group-by columns (if any) form the output key.
            if group_exprs.is_empty() {
                RelationType::new(output_types)
            } else {
                RelationType::new(output_types).with_key((0..group_exprs.len()).collect_vec())
            }
            .with_time_index(time_index)
            .into_named(output_names)
        };

        // Split the aggregates into simple vs. distinct buckets; by this
        // point every aggregate input must be a column reference (guaranteed
        // by `from_substrait_gen_key_val_plan`), otherwise it's a plan bug.
        let full_aggrs = aggr_exprs;
        let mut simple_aggrs = Vec::new();
        let mut distinct_aggrs = Vec::new();
        for (output_column, aggr_expr) in full_aggrs.iter().enumerate() {
            let input_column = aggr_expr.expr.as_column().with_context(|| PlanSnafu {
                reason: "Expect aggregate argument to be transformed into a column at this point",
            })?;
            if aggr_expr.distinct {
                distinct_aggrs.push(AggrWithIndex::new(
                    aggr_expr.clone(),
                    input_column,
                    output_column,
                ));
            } else {
                simple_aggrs.push(AggrWithIndex::new(
                    aggr_expr.clone(),
                    input_column,
                    output_column,
                ));
            }
        }
        let accum_plan = AccumulablePlan {
            full_aggrs,
            simple_aggrs,
            distinct_aggrs,
        };
        let plan = Plan::Reduce {
            input: Box::new(input),
            key_val_plan,
            reduce_plan: ReducePlan::Accumulable(accum_plan),
        };
        return Ok(TypedPlan {
            schema: output_type,
            plan,
        });
    }
}
383
384#[cfg(test)]
385mod test {
386 use std::time::Duration;
387
388 use bytes::BytesMut;
389 use common_time::{IntervalMonthDayNano, Timestamp};
390 use datatypes::prelude::ConcreteDataType;
391 use datatypes::value::Value;
392 use pretty_assertions::assert_eq;
393
394 use super::*;
395 use crate::expr::{BinaryFunc, DfScalarFunction, GlobalId, RawDfScalarFn};
396 use crate::plan::{Plan, TypedPlan};
397 use crate::repr::{ColumnType, RelationType};
398 use crate::transform::test::{create_test_ctx, create_test_query_engine, sql_to_substrait};
399 use crate::transform::CDT;
400
    #[tokio::test]
    async fn test_df_func_basic() {
        // DataFusion scalar function (`abs`) nested INSIDE the aggregate:
        // the call must end up in the reduce's val_plan.
        let engine = create_test_query_engine();
        let sql = "SELECT sum(abs(number)) FROM numbers_with_ts GROUP BY tumble(ts, '1 second', '2021-07-01 00:00:00');";
        let plan = sql_to_substrait(engine.clone(), sql).await;

        let mut ctx = create_test_ctx();
        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
            .await
            .unwrap();

        let aggr_expr = AggregateExpr {
            func: AggregateFunc::SumUInt64,
            expr: ScalarExpr::Column(0),
            distinct: false,
        };
        let expected = TypedPlan {
            schema: RelationType::new(vec![
                // sum(abs(number))
                ColumnType::new(CDT::uint64_datatype(), true),
                // window start
                ColumnType::new(CDT::timestamp_millisecond_datatype(), true),
                // window end
                ColumnType::new(CDT::timestamp_millisecond_datatype(), true),
            ])
            .with_key(vec![2])
            .with_time_index(Some(1))
            .into_named(vec![
                Some("sum(abs(numbers_with_ts.number))".to_string()),
                Some("window_start".to_string()),
                Some("window_end".to_string()),
            ]),
            plan: Plan::Mfp {
                input: Box::new(
                    Plan::Reduce {
                        input: Box::new(
                            Plan::Get {
                                id: crate::expr::Id::Global(GlobalId::User(1)),
                            }
                            .with_types(
                                RelationType::new(vec![
                                    ColumnType::new(ConcreteDataType::uint32_datatype(), false),
                                    ColumnType::new(
                                        ConcreteDataType::timestamp_millisecond_datatype(),
                                        false,
                                    ),
                                ])
                                .into_named(vec![
                                    Some("number".to_string()),
                                    Some("ts".to_string()),
                                ]),
                            )
                            .mfp(MapFilterProject::new(2).into_safe())
                            .unwrap(),
                        ),
                        key_val_plan: KeyValPlan {
                            key_plan: MapFilterProject::new(2)
                                .map(vec![
                                    ScalarExpr::Column(1).call_unary(
                                        UnaryFunc::TumbleWindowFloor {
                                            window_size: Duration::from_nanos(1_000_000_000),
                                            start_time: Some(Timestamp::new_millisecond(
                                                1625097600000,
                                            )),
                                        },
                                    ),
                                    ScalarExpr::Column(1).call_unary(
                                        UnaryFunc::TumbleWindowCeiling {
                                            window_size: Duration::from_nanos(1_000_000_000),
                                            start_time: Some(Timestamp::new_millisecond(
                                                1625097600000,
                                            )),
                                        },
                                    ),
                                ])
                                .unwrap()
                                .project(vec![2, 3])
                                .unwrap()
                                .into_safe(),
                            // `abs(number)` is evaluated here, before the sum.
                            val_plan: MapFilterProject::new(2)
                                .map(vec![ScalarExpr::CallDf {
                                    df_scalar_fn: DfScalarFunction::try_from_raw_fn(
                                        RawDfScalarFn {
                                            f: BytesMut::from(
                                                b"\x08\x02\"\x08\x1a\x06\x12\x04\n\x02\x12\0"
                                                    .as_ref(),
                                            ),
                                            input_schema: RelationType::new(vec![ColumnType::new(
                                                ConcreteDataType::uint32_datatype(),
                                                false,
                                            )])
                                            .into_unnamed(),
                                            extensions: FunctionExtensions::from_iter(
                                                [
                                                    (0, "tumble_start".to_string()),
                                                    (1, "tumble_end".to_string()),
                                                    (2, "abs".to_string()),
                                                    (3, "sum".to_string()),
                                                ]
                                                .into_iter(),
                                            ),
                                        },
                                    )
                                    .await
                                    .unwrap(),
                                    exprs: vec![ScalarExpr::Column(0)],
                                }
                                .cast(CDT::uint64_datatype())])
                                .unwrap()
                                .project(vec![2])
                                .unwrap()
                                .into_safe(),
                        },
                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
                            full_aggrs: vec![aggr_expr.clone()],
                            simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
                            distinct_aggrs: vec![],
                        }),
                    }
                    .with_types(
                        RelationType::new(vec![
                            // window start
                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true),
                            // window end
                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true),
                            // sum
                            ColumnType::new(CDT::uint64_datatype(), true),
                        ])
                        .with_key(vec![1])
                        .with_time_index(Some(0))
                        .into_unnamed(),
                    ),
                ),
                mfp: MapFilterProject::new(3)
                    .map(vec![
                        ScalarExpr::Column(2),
                        ScalarExpr::Column(0),
                        ScalarExpr::Column(1),
                    ])
                    .unwrap()
                    .project(vec![3, 4, 5])
                    .unwrap(),
            },
        };
        assert_eq!(flow_plan, expected);
    }
541
    #[tokio::test]
    async fn test_df_func_expr_tree() {
        // DataFusion scalar function (`abs`) OUTSIDE the aggregate: the call
        // must end up in the post-reduce Mfp, not in the val_plan.
        let engine = create_test_query_engine();
        let sql = "SELECT abs(sum(number)) FROM numbers_with_ts GROUP BY tumble(ts, '1 second', '2021-07-01 00:00:00');";
        let plan = sql_to_substrait(engine.clone(), sql).await;

        let mut ctx = create_test_ctx();
        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
            .await
            .unwrap();

        let aggr_expr = AggregateExpr {
            func: AggregateFunc::SumUInt64,
            expr: ScalarExpr::Column(0),
            distinct: false,
        };
        let expected = TypedPlan {
            schema: RelationType::new(vec![
                // abs(sum(number))
                ColumnType::new(CDT::uint64_datatype(), true),
                // window start
                ColumnType::new(CDT::timestamp_millisecond_datatype(), true),
                // window end
                ColumnType::new(CDT::timestamp_millisecond_datatype(), true),
            ])
            .with_key(vec![2])
            .with_time_index(Some(1))
            .into_named(vec![
                Some("abs(sum(numbers_with_ts.number))".to_string()),
                Some("window_start".to_string()),
                Some("window_end".to_string()),
            ]),
            plan: Plan::Mfp {
                input: Box::new(
                    Plan::Reduce {
                        input: Box::new(
                            Plan::Get {
                                id: crate::expr::Id::Global(GlobalId::User(1)),
                            }
                            .with_types(
                                RelationType::new(vec![
                                    ColumnType::new(ConcreteDataType::uint32_datatype(), false),
                                    ColumnType::new(
                                        ConcreteDataType::timestamp_millisecond_datatype(),
                                        false,
                                    ),
                                ])
                                .into_named(vec![
                                    Some("number".to_string()),
                                    Some("ts".to_string()),
                                ]),
                            )
                            .mfp(MapFilterProject::new(2).into_safe())
                            .unwrap(),
                        ),
                        key_val_plan: KeyValPlan {
                            key_plan: MapFilterProject::new(2)
                                .map(vec![
                                    ScalarExpr::Column(1).call_unary(
                                        UnaryFunc::TumbleWindowFloor {
                                            window_size: Duration::from_nanos(1_000_000_000),
                                            start_time: Some(Timestamp::new_millisecond(
                                                1625097600000,
                                            )),
                                        },
                                    ),
                                    ScalarExpr::Column(1).call_unary(
                                        UnaryFunc::TumbleWindowCeiling {
                                            window_size: Duration::from_nanos(1_000_000_000),
                                            start_time: Some(Timestamp::new_millisecond(
                                                1625097600000,
                                            )),
                                        },
                                    ),
                                ])
                                .unwrap()
                                .project(vec![2, 3])
                                .unwrap()
                                .into_safe(),
                            val_plan: MapFilterProject::new(2)
                                .map(vec![ScalarExpr::Column(0).cast(CDT::uint64_datatype())])
                                .unwrap()
                                .project(vec![2])
                                .unwrap()
                                .into_safe(),
                        },
                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
                            full_aggrs: vec![aggr_expr.clone()],
                            simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
                            distinct_aggrs: vec![],
                        }),
                    }
                    .with_types(
                        RelationType::new(vec![
                            // window start
                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true),
                            // window end
                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true),
                            // sum
                            ColumnType::new(CDT::uint64_datatype(), true),
                        ])
                        .with_key(vec![1])
                        .with_time_index(Some(0))
                        .into_named(vec![None, None, None]),
                    ),
                ),
                // `abs` is applied here, after the reduce.
                mfp: MapFilterProject::new(3)
                    .map(vec![
                        ScalarExpr::CallDf {
                            df_scalar_fn: DfScalarFunction::try_from_raw_fn(RawDfScalarFn {
                                f: BytesMut::from(b"\"\x08\x1a\x06\x12\x04\n\x02\x12\0".as_ref()),
                                input_schema: RelationType::new(vec![ColumnType::new(
                                    ConcreteDataType::uint64_datatype(),
                                    true,
                                )])
                                .into_unnamed(),
                                extensions: FunctionExtensions::from_iter(
                                    [
                                        (0, "abs".to_string()),
                                        (1, "tumble_start".to_string()),
                                        (2, "tumble_end".to_string()),
                                        (3, "sum".to_string()),
                                    ]
                                    .into_iter(),
                                ),
                            })
                            .await
                            .unwrap(),
                            exprs: vec![ScalarExpr::Column(2)],
                        },
                        ScalarExpr::Column(0),
                        ScalarExpr::Column(1),
                    ])
                    .unwrap()
                    .project(vec![3, 4, 5])
                    .unwrap(),
            },
        };
        assert_eq!(flow_plan, expected);
    }
676
    #[tokio::test]
    async fn test_tumble_composite() {
        // Composite group-by key: a tumble window plus a plain column; also
        // checks that `avg` is lowered to sum/count with a divide-by-zero guard.
        let engine = create_test_query_engine();
        let sql =
            "SELECT number, avg(number) FROM numbers_with_ts GROUP BY tumble(ts, '1 hour'), number";
        let plan = sql_to_substrait(engine.clone(), sql).await;

        let mut ctx = create_test_ctx();
        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
            .await
            .unwrap();

        let aggr_exprs = vec![
            AggregateExpr {
                func: AggregateFunc::SumUInt64,
                expr: ScalarExpr::Column(0),
                distinct: false,
            },
            AggregateExpr {
                func: AggregateFunc::Count,
                expr: ScalarExpr::Column(1),
                distinct: false,
            },
        ];
        // avg = if count != 0 { sum as f64 / count as f64 } else { null }
        let avg_expr = ScalarExpr::If {
            cond: Box::new(ScalarExpr::Column(4).call_binary(
                ScalarExpr::Literal(Value::from(0i64), CDT::int64_datatype()),
                BinaryFunc::NotEq,
            )),
            then: Box::new(
                ScalarExpr::Column(3)
                    .cast(CDT::float64_datatype())
                    .call_binary(
                        ScalarExpr::Column(4).cast(CDT::float64_datatype()),
                        BinaryFunc::DivFloat64,
                    ),
            ),
            els: Box::new(ScalarExpr::Literal(Value::Null, CDT::float64_datatype())),
        };
        let expected = TypedPlan {
            plan: Plan::Mfp {
                input: Box::new(
                    Plan::Reduce {
                        input: Box::new(
                            Plan::Get {
                                id: crate::expr::Id::Global(GlobalId::User(1)),
                            }
                            .with_types(
                                RelationType::new(vec![
                                    ColumnType::new(ConcreteDataType::uint32_datatype(), false),
                                    ColumnType::new(
                                        ConcreteDataType::timestamp_millisecond_datatype(),
                                        false,
                                    ),
                                ])
                                .into_named(vec![
                                    Some("number".to_string()),
                                    Some("ts".to_string()),
                                ]),
                            )
                            .mfp(MapFilterProject::new(2).into_safe())
                            .unwrap(),
                        ),
                        key_val_plan: KeyValPlan {
                            key_plan: MapFilterProject::new(2)
                                .map(vec![
                                    ScalarExpr::Column(1).call_unary(
                                        UnaryFunc::TumbleWindowFloor {
                                            window_size: Duration::from_nanos(3_600_000_000_000),
                                            start_time: None,
                                        },
                                    ),
                                    ScalarExpr::Column(1).call_unary(
                                        UnaryFunc::TumbleWindowCeiling {
                                            window_size: Duration::from_nanos(3_600_000_000_000),
                                            start_time: None,
                                        },
                                    ),
                                    ScalarExpr::Column(0),
                                ])
                                .unwrap()
                                .project(vec![2, 3, 4])
                                .unwrap()
                                .into_safe(),
                            val_plan: MapFilterProject::new(2)
                                .map(vec![
                                    ScalarExpr::Column(0).cast(CDT::uint64_datatype()),
                                    ScalarExpr::Column(0),
                                ])
                                .unwrap()
                                .project(vec![2, 3])
                                .unwrap()
                                .into_safe(),
                        },
                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
                            full_aggrs: aggr_exprs.clone(),
                            simple_aggrs: vec![
                                AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0),
                                AggrWithIndex::new(aggr_exprs[1].clone(), 1, 1),
                            ],
                            distinct_aggrs: vec![],
                        }),
                    }
                    .with_types(
                        RelationType::new(vec![
                            // window start
                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true),
                            // window end
                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true),
                            // number
                            ColumnType::new(CDT::uint32_datatype(), false),
                            // sum(number)
                            ColumnType::new(CDT::uint64_datatype(), true),
                            // count(number)
                            ColumnType::new(CDT::int64_datatype(), true),
                        ])
                        .with_key(vec![1, 2])
                        .with_time_index(Some(0))
                        .into_named(vec![
                            None,
                            None,
                            Some("number".to_string()),
                            None,
                            None,
                        ]),
                    ),
                ),
                mfp: MapFilterProject::new(5)
                    .map(vec![
                        ScalarExpr::Column(2),
                        avg_expr,
                        ScalarExpr::Column(0),
                        ScalarExpr::Column(1),
                    ])
                    .unwrap()
                    .project(vec![5, 6, 7, 8])
                    .unwrap(),
            },
            schema: RelationType::new(vec![
                // number
                ColumnType::new(CDT::uint32_datatype(), false),
                // avg(number)
                ColumnType::new(CDT::float64_datatype(), true),
                // window start
                ColumnType::new(CDT::timestamp_millisecond_datatype(), true),
                // window end
                ColumnType::new(CDT::timestamp_millisecond_datatype(), true),
            ])
            .with_key(vec![0, 3])
            .with_time_index(Some(2))
            .into_named(vec![
                Some("number".to_string()),
                Some("avg(numbers_with_ts.number)".to_string()),
                Some("window_start".to_string()),
                Some("window_end".to_string()),
            ]),
        };
        assert_eq!(flow_plan, expected);
    }
830
    #[tokio::test]
    async fn test_tumble_parse_optional() {
        // `tumble` without the optional start-time argument: window funcs
        // must carry `start_time: None`.
        let engine = create_test_query_engine();
        let sql = "SELECT sum(number) FROM numbers_with_ts GROUP BY tumble(ts, '1 hour')";
        let plan = sql_to_substrait(engine.clone(), sql).await;

        let mut ctx = create_test_ctx();
        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
            .await
            .unwrap();

        let aggr_expr = AggregateExpr {
            func: AggregateFunc::SumUInt64,
            expr: ScalarExpr::Column(0),
            distinct: false,
        };
        let expected = TypedPlan {
            schema: RelationType::new(vec![
                // sum(number)
                ColumnType::new(CDT::uint64_datatype(), true),
                // window start
                ColumnType::new(CDT::timestamp_millisecond_datatype(), true),
                // window end
                ColumnType::new(CDT::timestamp_millisecond_datatype(), true),
            ])
            .with_key(vec![2])
            .with_time_index(Some(1))
            .into_named(vec![
                Some("sum(numbers_with_ts.number)".to_string()),
                Some("window_start".to_string()),
                Some("window_end".to_string()),
            ]),
            plan: Plan::Mfp {
                input: Box::new(
                    Plan::Reduce {
                        input: Box::new(
                            Plan::Get {
                                id: crate::expr::Id::Global(GlobalId::User(1)),
                            }
                            .with_types(
                                RelationType::new(vec![
                                    ColumnType::new(ConcreteDataType::uint32_datatype(), false),
                                    ColumnType::new(
                                        ConcreteDataType::timestamp_millisecond_datatype(),
                                        false,
                                    ),
                                ])
                                .into_named(vec![
                                    Some("number".to_string()),
                                    Some("ts".to_string()),
                                ]),
                            )
                            .mfp(MapFilterProject::new(2).into_safe())
                            .unwrap(),
                        ),
                        key_val_plan: KeyValPlan {
                            key_plan: MapFilterProject::new(2)
                                .map(vec![
                                    ScalarExpr::Column(1).call_unary(
                                        UnaryFunc::TumbleWindowFloor {
                                            window_size: Duration::from_nanos(3_600_000_000_000),
                                            start_time: None,
                                        },
                                    ),
                                    ScalarExpr::Column(1).call_unary(
                                        UnaryFunc::TumbleWindowCeiling {
                                            window_size: Duration::from_nanos(3_600_000_000_000),
                                            start_time: None,
                                        },
                                    ),
                                ])
                                .unwrap()
                                .project(vec![2, 3])
                                .unwrap()
                                .into_safe(),
                            val_plan: MapFilterProject::new(2)
                                .map(vec![ScalarExpr::Column(0).cast(CDT::uint64_datatype())])
                                .unwrap()
                                .project(vec![2])
                                .unwrap()
                                .into_safe(),
                        },
                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
                            full_aggrs: vec![aggr_expr.clone()],
                            simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
                            distinct_aggrs: vec![],
                        }),
                    }
                    .with_types(
                        RelationType::new(vec![
                            // window start
                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true),
                            // window end
                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true),
                            // sum
                            ColumnType::new(CDT::uint64_datatype(), true),
                        ])
                        .with_key(vec![1])
                        .with_time_index(Some(0))
                        .into_named(vec![None, None, None]),
                    ),
                ),
                mfp: MapFilterProject::new(3)
                    .map(vec![
                        ScalarExpr::Column(2),
                        ScalarExpr::Column(0),
                        ScalarExpr::Column(1),
                    ])
                    .unwrap()
                    .project(vec![3, 4, 5])
                    .unwrap(),
            },
        };
        assert_eq!(flow_plan, expected);
    }
940
    #[tokio::test]
    async fn test_tumble_parse() {
        // `tumble` with an explicit start time: window funcs must carry the
        // parsed timestamp.
        let engine = create_test_query_engine();
        let sql = "SELECT sum(number) FROM numbers_with_ts GROUP BY tumble(ts, '1 hour', '2021-07-01 00:00:00')";
        let plan = sql_to_substrait(engine.clone(), sql).await;

        let mut ctx = create_test_ctx();
        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
            .await
            .unwrap();

        let aggr_expr = AggregateExpr {
            func: AggregateFunc::SumUInt64,
            expr: ScalarExpr::Column(0),
            distinct: false,
        };
        let expected = TypedPlan {
            schema: RelationType::new(vec![
                // sum(number)
                ColumnType::new(CDT::uint64_datatype(), true),
                // window start
                ColumnType::new(CDT::timestamp_millisecond_datatype(), true),
                // window end
                ColumnType::new(CDT::timestamp_millisecond_datatype(), true),
            ])
            .with_key(vec![2])
            .with_time_index(Some(1))
            .into_named(vec![
                Some("sum(numbers_with_ts.number)".to_string()),
                Some("window_start".to_string()),
                Some("window_end".to_string()),
            ]),
            plan: Plan::Mfp {
                input: Box::new(
                    Plan::Reduce {
                        input: Box::new(
                            Plan::Get {
                                id: crate::expr::Id::Global(GlobalId::User(1)),
                            }
                            .with_types(
                                RelationType::new(vec![
                                    ColumnType::new(ConcreteDataType::uint32_datatype(), false),
                                    ColumnType::new(
                                        ConcreteDataType::timestamp_millisecond_datatype(),
                                        false,
                                    ),
                                ])
                                .into_named(vec![
                                    Some("number".to_string()),
                                    Some("ts".to_string()),
                                ]),
                            )
                            .mfp(MapFilterProject::new(2).into_safe())
                            .unwrap(),
                        ),
                        key_val_plan: KeyValPlan {
                            key_plan: MapFilterProject::new(2)
                                .map(vec![
                                    ScalarExpr::Column(1).call_unary(
                                        UnaryFunc::TumbleWindowFloor {
                                            window_size: Duration::from_nanos(3_600_000_000_000),
                                            start_time: Some(Timestamp::new_millisecond(
                                                1625097600000,
                                            )),
                                        },
                                    ),
                                    ScalarExpr::Column(1).call_unary(
                                        UnaryFunc::TumbleWindowCeiling {
                                            window_size: Duration::from_nanos(3_600_000_000_000),
                                            start_time: Some(Timestamp::new_millisecond(
                                                1625097600000,
                                            )),
                                        },
                                    ),
                                ])
                                .unwrap()
                                .project(vec![2, 3])
                                .unwrap()
                                .into_safe(),
                            val_plan: MapFilterProject::new(2)
                                .map(vec![ScalarExpr::Column(0).cast(CDT::uint64_datatype())])
                                .unwrap()
                                .project(vec![2])
                                .unwrap()
                                .into_safe(),
                        },
                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
                            full_aggrs: vec![aggr_expr.clone()],
                            simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
                            distinct_aggrs: vec![],
                        }),
                    }
                    .with_types(
                        RelationType::new(vec![
                            // window start
                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true),
                            // window end
                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true),
                            // sum
                            ColumnType::new(CDT::uint64_datatype(), true),
                        ])
                        .with_key(vec![1])
                        .with_time_index(Some(0))
                        .into_unnamed(),
                    ),
                ),
                mfp: MapFilterProject::new(3)
                    .map(vec![
                        ScalarExpr::Column(2),
                        ScalarExpr::Column(0),
                        ScalarExpr::Column(1),
                    ])
                    .unwrap()
                    .project(vec![3, 4, 5])
                    .unwrap(),
            },
        };
        assert_eq!(flow_plan, expected);
    }
1054
    #[tokio::test]
    async fn test_avg_group_by() {
        // `avg` with a plain (non-window) GROUP BY: lowered to sum/count with
        // a divide-by-zero guard; the group column is the key.
        let engine = create_test_query_engine();
        let sql = "SELECT avg(number), number FROM numbers GROUP BY number";
        let plan = sql_to_substrait(engine.clone(), sql).await;

        let mut ctx = create_test_ctx();
        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;

        let aggr_exprs = vec![
            AggregateExpr {
                func: AggregateFunc::SumUInt64,
                expr: ScalarExpr::Column(0),
                distinct: false,
            },
            AggregateExpr {
                func: AggregateFunc::Count,
                expr: ScalarExpr::Column(1),
                distinct: false,
            },
        ];
        // avg = if count != 0 { sum as f64 / count as f64 } else { null }
        let avg_expr = ScalarExpr::If {
            cond: Box::new(ScalarExpr::Column(2).call_binary(
                ScalarExpr::Literal(Value::from(0i64), CDT::int64_datatype()),
                BinaryFunc::NotEq,
            )),
            then: Box::new(
                ScalarExpr::Column(1)
                    .cast(CDT::float64_datatype())
                    .call_binary(
                        ScalarExpr::Column(2).cast(CDT::float64_datatype()),
                        BinaryFunc::DivFloat64,
                    ),
            ),
            els: Box::new(ScalarExpr::Literal(Value::Null, CDT::float64_datatype())),
        };
        let expected = TypedPlan {
            schema: RelationType::new(vec![
                // avg(number)
                ColumnType::new(CDT::float64_datatype(), true),
                // number
                ColumnType::new(CDT::uint32_datatype(), false),
            ])
            .with_key(vec![1])
            .into_named(vec![
                Some("avg(numbers.number)".to_string()),
                Some("number".to_string()),
            ]),
            plan: Plan::Mfp {
                input: Box::new(
                    Plan::Reduce {
                        input: Box::new(
                            Plan::Get {
                                id: crate::expr::Id::Global(GlobalId::User(0)),
                            }
                            .with_types(
                                RelationType::new(vec![ColumnType::new(
                                    ConcreteDataType::uint32_datatype(),
                                    false,
                                )])
                                .into_named(vec![Some("number".to_string())]),
                            )
                            .mfp(
                                MapFilterProject::new(1)
                                    .project(vec![0])
                                    .unwrap()
                                    .into_safe(),
                            )
                            .unwrap(),
                        ),
                        key_val_plan: KeyValPlan {
                            key_plan: MapFilterProject::new(1)
                                .map(vec![ScalarExpr::Column(0)])
                                .unwrap()
                                .project(vec![1])
                                .unwrap()
                                .into_safe(),
                            val_plan: MapFilterProject::new(1)
                                .map(vec![
                                    ScalarExpr::Column(0).cast(CDT::uint64_datatype()),
                                    ScalarExpr::Column(0),
                                ])
                                .unwrap()
                                .project(vec![1, 2])
                                .unwrap()
                                .into_safe(),
                        },
                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
                            full_aggrs: aggr_exprs.clone(),
                            simple_aggrs: vec![
                                AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0),
                                AggrWithIndex::new(aggr_exprs[1].clone(), 1, 1),
                            ],
                            distinct_aggrs: vec![],
                        }),
                    }
                    .with_types(
                        RelationType::new(vec![
                            // number (group key)
                            ColumnType::new(ConcreteDataType::uint32_datatype(), false),
                            // sum(number)
                            ColumnType::new(ConcreteDataType::uint64_datatype(), true),
                            // count(number)
                            ColumnType::new(ConcreteDataType::int64_datatype(), true),
                        ])
                        .with_key(vec![0])
                        .into_named(vec![
                            Some("number".to_string()),
                            None,
                            None,
                        ]),
                    ),
                ),
                mfp: MapFilterProject::new(3)
                    .map(vec![
                        avg_expr,
                        ScalarExpr::Column(0),
                    ])
                    .unwrap()
                    .project(vec![3, 4])
                    .unwrap(),
            },
        };
        assert_eq!(flow_plan.unwrap(), expected);
    }
1176
    #[tokio::test]
    async fn test_avg() {
        // Global `avg` (no GROUP BY): empty key plan, sum/count lowering with
        // a divide-by-zero guard in the post-reduce Mfp.
        let engine = create_test_query_engine();
        let sql = "SELECT avg(number) FROM numbers";
        let plan = sql_to_substrait(engine.clone(), sql).await;

        let mut ctx = create_test_ctx();

        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
            .await
            .unwrap();

        let aggr_exprs = vec![
            AggregateExpr {
                func: AggregateFunc::SumUInt64,
                expr: ScalarExpr::Column(0),
                distinct: false,
            },
            AggregateExpr {
                func: AggregateFunc::Count,
                expr: ScalarExpr::Column(1),
                distinct: false,
            },
        ];
        // avg = if count != 0 { sum as f64 / count as f64 } else { null }
        let avg_expr = ScalarExpr::If {
            cond: Box::new(ScalarExpr::Column(1).call_binary(
                ScalarExpr::Literal(Value::from(0i64), CDT::int64_datatype()),
                BinaryFunc::NotEq,
            )),
            then: Box::new(
                ScalarExpr::Column(0)
                    .cast(CDT::float64_datatype())
                    .call_binary(
                        ScalarExpr::Column(1).cast(CDT::float64_datatype()),
                        BinaryFunc::DivFloat64,
                    ),
            ),
            els: Box::new(ScalarExpr::Literal(Value::Null, CDT::float64_datatype())),
        };
        let input = Box::new(
            Plan::Get {
                id: crate::expr::Id::Global(GlobalId::User(0)),
            }
            .with_types(
                RelationType::new(vec![ColumnType::new(
                    ConcreteDataType::uint32_datatype(),
                    false,
                )])
                .into_named(vec![Some("number".to_string())]),
            ),
        );
        let expected = TypedPlan {
            schema: RelationType::new(vec![ColumnType::new(CDT::float64_datatype(), true)])
                .into_named(vec![Some("avg(numbers.number)".to_string())]),
            plan: Plan::Mfp {
                input: Box::new(
                    Plan::Reduce {
                        input: Box::new(
                            Plan::Mfp {
                                input: input.clone(),
                                mfp: MapFilterProject::new(1).project(vec![0]).unwrap(),
                            }
                            .with_types(
                                RelationType::new(vec![ColumnType::new(
                                    CDT::uint32_datatype(),
                                    false,
                                )])
                                .into_named(vec![Some("number".to_string())]),
                            ),
                        ),
                        key_val_plan: KeyValPlan {
                            // No grouping columns: the key plan projects nothing.
                            key_plan: MapFilterProject::new(1)
                                .project(vec![])
                                .unwrap()
                                .into_safe(),
                            val_plan: MapFilterProject::new(1)
                                .map(vec![
                                    ScalarExpr::Column(0).cast(CDT::uint64_datatype()),
                                    ScalarExpr::Column(0),
                                ])
                                .unwrap()
                                .project(vec![1, 2])
                                .unwrap()
                                .into_safe(),
                        },
                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
                            full_aggrs: aggr_exprs.clone(),
                            simple_aggrs: vec![
                                AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0),
                                AggrWithIndex::new(aggr_exprs[1].clone(), 1, 1),
                            ],
                            distinct_aggrs: vec![],
                        }),
                    }
                    .with_types(
                        RelationType::new(vec![
                            // sum(number)
                            ColumnType::new(ConcreteDataType::uint64_datatype(), true),
                            // count(number)
                            ColumnType::new(ConcreteDataType::int64_datatype(), true),
                        ])
                        .into_named(vec![None, None]),
                    ),
                ),
                mfp: MapFilterProject::new(2)
                    .map(vec![
                        avg_expr,
                    ])
                    .unwrap()
                    .project(vec![2])
                    .unwrap(),
            },
        };
        assert_eq!(flow_plan, expected);
    }
1291
1292 #[tokio::test]
1293 async fn test_sum() {
1294 let engine = create_test_query_engine();
1295 let sql = "SELECT sum(number) FROM numbers";
1296 let plan = sql_to_substrait(engine.clone(), sql).await;
1297
1298 let mut ctx = create_test_ctx();
1299 let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;
1300
1301 let aggr_expr = AggregateExpr {
1302 func: AggregateFunc::SumUInt64,
1303 expr: ScalarExpr::Column(0),
1304 distinct: false,
1305 };
1306 let expected = TypedPlan {
1307 schema: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)])
1308 .into_named(vec![Some("sum(numbers.number)".to_string())]),
1309 plan: Plan::Reduce {
1310 input: Box::new(
1311 Plan::Get {
1312 id: crate::expr::Id::Global(GlobalId::User(0)),
1313 }
1314 .with_types(
1315 RelationType::new(vec![ColumnType::new(
1316 ConcreteDataType::uint32_datatype(),
1317 false,
1318 )])
1319 .into_named(vec![Some("number".to_string())]),
1320 )
1321 .mfp(MapFilterProject::new(1).into_safe())
1322 .unwrap(),
1323 ),
1324 key_val_plan: KeyValPlan {
1325 key_plan: MapFilterProject::new(1)
1326 .project(vec![])
1327 .unwrap()
1328 .into_safe(),
1329 val_plan: MapFilterProject::new(1)
1330 .map(vec![ScalarExpr::Column(0)
1331 .call_unary(UnaryFunc::Cast(CDT::uint64_datatype()))])
1332 .unwrap()
1333 .project(vec![1])
1334 .unwrap()
1335 .into_safe(),
1336 },
1337 reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
1338 full_aggrs: vec![aggr_expr.clone()],
1339 simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
1340 distinct_aggrs: vec![],
1341 }),
1342 },
1343 };
1344 assert_eq!(flow_plan.unwrap(), expected);
1345 }
1346
1347 #[tokio::test]
1348 async fn test_distinct_number() {
1349 let engine = create_test_query_engine();
1350 let sql = "SELECT DISTINCT number FROM numbers";
1351 let plan = sql_to_substrait(engine.clone(), sql).await;
1352
1353 let mut ctx = create_test_ctx();
1354 let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
1355 .await
1356 .unwrap();
1357
1358 let expected = TypedPlan {
1359 schema: RelationType::new(vec![
1360 ColumnType::new(CDT::uint32_datatype(), false), ])
1362 .with_key(vec![0])
1363 .into_named(vec![Some("number".to_string())]),
1364 plan: Plan::Reduce {
1365 input: Box::new(
1366 Plan::Get {
1367 id: crate::expr::Id::Global(GlobalId::User(0)),
1368 }
1369 .with_types(
1370 RelationType::new(vec![ColumnType::new(
1371 ConcreteDataType::uint32_datatype(),
1372 false,
1373 )])
1374 .into_named(vec![Some("number".to_string())]),
1375 )
1376 .mfp(MapFilterProject::new(1).into_safe())
1377 .unwrap(),
1378 ),
1379 key_val_plan: KeyValPlan {
1380 key_plan: MapFilterProject::new(1)
1381 .map(vec![ScalarExpr::Column(0)])
1382 .unwrap()
1383 .project(vec![1])
1384 .unwrap()
1385 .into_safe(),
1386 val_plan: MapFilterProject::new(1)
1387 .project(vec![0])
1388 .unwrap()
1389 .into_safe(),
1390 },
1391 reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
1392 full_aggrs: vec![],
1393 simple_aggrs: vec![],
1394 distinct_aggrs: vec![],
1395 }),
1396 },
1397 };
1398
1399 assert_eq!(flow_plan, expected);
1400 }
1401
    /// `SELECT sum(number), number FROM numbers GROUP BY number`.
    ///
    /// The Reduce keys on `number` and accumulates the u64-widened sum; it
    /// emits columns as [group key, sum], so a final Mfp reorders them into
    /// the SELECT order [sum, group key].
    #[tokio::test]
    async fn test_sum_group_by() {
        let engine = create_test_query_engine();
        let sql = "SELECT sum(number), number FROM numbers GROUP BY number";
        let plan = sql_to_substrait(engine.clone(), sql).await;

        let mut ctx = create_test_ctx();
        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
            .await
            .unwrap();

        let aggr_expr = AggregateExpr {
            func: AggregateFunc::SumUInt64,
            expr: ScalarExpr::Column(0),
            distinct: false,
        };
        let expected = TypedPlan {
            schema: RelationType::new(vec![
                ColumnType::new(CDT::uint64_datatype(), true),
                ColumnType::new(CDT::uint32_datatype(), false),
            ])
            // The group-by column is a unique key of the output.
            .with_key(vec![1])
            .into_named(vec![
                Some("sum(numbers.number)".to_string()),
                Some("number".to_string()),
            ]),
            plan: Plan::Mfp {
                input: Box::new(
                    Plan::Reduce {
                        input: Box::new(
                            Plan::Get {
                                id: crate::expr::Id::Global(GlobalId::User(0)),
                            }
                            .with_types(
                                RelationType::new(vec![ColumnType::new(
                                    ConcreteDataType::uint32_datatype(),
                                    false,
                                )])
                                .into_named(vec![Some("number".to_string())]),
                            )
                            .mfp(MapFilterProject::new(1).into_safe())
                            .unwrap(),
                        ),
                        key_val_plan: KeyValPlan {
                            // Group key: the `number` column itself.
                            key_plan: MapFilterProject::new(1)
                                .map(vec![ScalarExpr::Column(0)])
                                .unwrap()
                                .project(vec![1])
                                .unwrap()
                                .into_safe(),
                            // Value: `number` widened to u64 for the SUM accumulator.
                            val_plan: MapFilterProject::new(1)
                                .map(vec![ScalarExpr::Column(0)
                                    .call_unary(UnaryFunc::Cast(CDT::uint64_datatype()))])
                                .unwrap()
                                .project(vec![1])
                                .unwrap()
                                .into_safe(),
                        },
                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
                            full_aggrs: vec![aggr_expr.clone()],
                            simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
                            distinct_aggrs: vec![],
                        }),
                    }
                    .with_types(
                        // Reduce output layout: [group key (u32), sum (u64)].
                        RelationType::new(vec![
                            ColumnType::new(CDT::uint32_datatype(), false),
                            ColumnType::new(CDT::uint64_datatype(), true),
                        ])
                        .with_key(vec![0])
                        .into_named(vec![Some("number".to_string()), None]),
                    ),
                ),
                // Swap [key, sum] into the SELECT order [sum, key].
                mfp: MapFilterProject::new(2)
                    .map(vec![ScalarExpr::Column(1), ScalarExpr::Column(0)])
                    .unwrap()
                    .project(vec![2, 3])
                    .unwrap(),
            },
        };

        assert_eq!(flow_plan, expected);
    }
1485
1486 #[tokio::test]
1487 async fn test_sum_add() {
1488 let engine = create_test_query_engine();
1489 let sql = "SELECT sum(number+number) FROM numbers";
1490 let plan = sql_to_substrait(engine.clone(), sql).await;
1491
1492 let mut ctx = create_test_ctx();
1493 let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;
1494
1495 let aggr_expr = AggregateExpr {
1496 func: AggregateFunc::SumUInt64,
1497 expr: ScalarExpr::Column(0),
1498 distinct: false,
1499 };
1500 let expected = TypedPlan {
1501 schema: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)])
1502 .into_named(vec![Some(
1503 "sum(numbers.number + numbers.number)".to_string(),
1504 )]),
1505 plan: Plan::Reduce {
1506 input: Box::new(
1507 Plan::Mfp {
1508 input: Box::new(
1509 Plan::Get {
1510 id: crate::expr::Id::Global(GlobalId::User(0)),
1511 }
1512 .with_types(
1513 RelationType::new(vec![ColumnType::new(
1514 ConcreteDataType::uint32_datatype(),
1515 false,
1516 )])
1517 .into_named(vec![Some("number".to_string())]),
1518 ),
1519 ),
1520 mfp: MapFilterProject::new(1),
1521 }
1522 .with_types(
1523 RelationType::new(vec![ColumnType::new(
1524 ConcreteDataType::uint32_datatype(),
1525 false,
1526 )])
1527 .into_named(vec![Some("number".to_string())]),
1528 ),
1529 ),
1530 key_val_plan: KeyValPlan {
1531 key_plan: MapFilterProject::new(1)
1532 .project(vec![])
1533 .unwrap()
1534 .into_safe(),
1535 val_plan: MapFilterProject::new(1)
1536 .map(vec![ScalarExpr::Column(0)
1537 .call_binary(ScalarExpr::Column(0), BinaryFunc::AddUInt32)
1538 .call_unary(UnaryFunc::Cast(CDT::uint64_datatype()))])
1539 .unwrap()
1540 .project(vec![1])
1541 .unwrap()
1542 .into_safe(),
1543 },
1544 reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
1545 full_aggrs: vec![aggr_expr.clone()],
1546 simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
1547 distinct_aggrs: vec![],
1548 }),
1549 },
1550 };
1551 assert_eq!(flow_plan.unwrap(), expected);
1552 }
1553
    /// `(max(number) - min(number))/30.0` grouped by a 30-second `date_bin`
    /// time window over `numbers_with_ts`.
    ///
    /// MAX and MIN become two accumulators in the Reduce (both reading the
    /// `number` value column); the window expression is evaluated in the key
    /// plan via a DataFusion scalar function carried as a serialized substrait
    /// expression, and the subtraction/division happen in the post-reduce Mfp.
    #[tokio::test]
    async fn test_cast_max_min() {
        let engine = create_test_query_engine();
        let sql = "SELECT (max(number) - min(number))/30.0, date_bin(INTERVAL '30 second', CAST(ts AS TimestampMillisecond)) as time_window from numbers_with_ts GROUP BY time_window";
        let plan = sql_to_substrait(engine.clone(), sql).await;

        let mut ctx = create_test_ctx();
        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;

        // Both accumulators read the same value column (`number`, Column(0)).
        let aggr_exprs = vec![
            AggregateExpr {
                func: AggregateFunc::MaxUInt32,
                expr: ScalarExpr::Column(0),
                distinct: false,
            },
            AggregateExpr {
                func: AggregateFunc::MinUInt32,
                expr: ScalarExpr::Column(0),
                distinct: false,
            },
        ];
        let expected = TypedPlan {
            schema: RelationType::new(vec![
                ColumnType::new(CDT::float64_datatype(), true),
                ColumnType::new(CDT::timestamp_millisecond_datatype(), true),
            ])
            // The time-window column serves as the time index of the output.
            .with_time_index(Some(1))
            .into_named(vec![
                Some(
                    "max(numbers_with_ts.number) - min(numbers_with_ts.number) / Float64(30)"
                        .to_string(),
                ),
                Some("time_window".to_string()),
            ]),
            plan: Plan::Mfp {
                input: Box::new(
                    Plan::Reduce {
                        input: Box::new(
                            Plan::Get {
                                id: crate::expr::Id::Global(GlobalId::User(1)),
                            }
                            .with_types(
                                RelationType::new(vec![
                                    ColumnType::new(ConcreteDataType::uint32_datatype(), false),
                                    ColumnType::new(ConcreteDataType::timestamp_millisecond_datatype(), false),
                                ])
                                .into_named(vec![
                                    Some("number".to_string()),
                                    Some("ts".to_string()),
                                ]),
                            )
                            .mfp(MapFilterProject::new(2).into_safe())
                            .unwrap(),
                        ),

                        key_val_plan: KeyValPlan {
                            // Group key: date_bin(INTERVAL '30 second', ts),
                            // evaluated by wrapping the DataFusion scalar function
                            // in a raw (serialized substrait) expression.
                            key_plan: MapFilterProject::new(2)
                                .map(vec![ScalarExpr::CallDf {
                                    df_scalar_fn: DfScalarFunction::try_from_raw_fn(
                                        RawDfScalarFn {
                                            // Serialized substrait encoding of the
                                            // `date_bin` call over the two `exprs` below.
                                            f: BytesMut::from(
                                                b"\x08\x02\"\x0f\x1a\r\n\x0b\xa2\x02\x08\n\0\x12\x04\x10\x1e \t\"\n\x1a\x08\x12\x06\n\x04\x12\x02\x08\x01".as_ref(),
                                            ),
                                            // Schema of the function's inputs:
                                            // [interval literal, ts].
                                            input_schema: RelationType::new(vec![ColumnType::new(
                                                ConcreteDataType::interval_month_day_nano_datatype(),
                                                true,
                                            ),ColumnType::new(
                                                ConcreteDataType::timestamp_millisecond_datatype(),
                                                false,
                                            )])
                                            .into_unnamed(),
                                            extensions: FunctionExtensions::from_iter([
                                                (0, "subtract".to_string()),
                                                (1, "divide".to_string()),
                                                (2, "date_bin".to_string()),
                                                (3, "max".to_string()),
                                                (4, "min".to_string()),
                                            ]),
                                        },
                                    )
                                    .await
                                    .unwrap(),
                                    // Arguments: a 30-second interval (30e9 ns) and `ts`.
                                    exprs: vec![
                                        ScalarExpr::Literal(
                                            Value::IntervalMonthDayNano(IntervalMonthDayNano::new(0, 0, 30000000000)),
                                            CDT::interval_month_day_nano_datatype()
                                        ),
                                        ScalarExpr::Column(1)
                                    ],
                                }])
                                .unwrap()
                                .project(vec![2])
                                .unwrap()
                                .into_safe(),
                            // Values pass both input columns through unchanged.
                            val_plan: MapFilterProject::new(2)
                                .into_safe(),
                        },
                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
                            full_aggrs: aggr_exprs.clone(),
                            // Both aggregates read value column 0; outputs land in
                            // slots 0 (max) and 1 (min).
                            simple_aggrs: vec![AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0),
                            AggrWithIndex::new(aggr_exprs[1].clone(), 0, 1)],
                            distinct_aggrs: vec![],
                        }),
                    }
                    .with_types(
                        // Reduce output layout: [time window (key), max, min].
                        RelationType::new(vec![
                            ColumnType::new(
                                ConcreteDataType::timestamp_millisecond_datatype(),
                                true,
                            ),
                            ColumnType::new(ConcreteDataType::uint32_datatype(), true),
                            ColumnType::new(ConcreteDataType::uint32_datatype(), true),
                        ])
                        .with_time_index(Some(0))
                        .into_unnamed(),
                    ),
                ),
                // (max - min) as f64 divided by 30.0, plus the window column,
                // reordered into the SELECT order.
                mfp: MapFilterProject::new(3)
                    .map(vec![
                        ScalarExpr::Column(1)
                            .call_binary(ScalarExpr::Column(2), BinaryFunc::SubUInt32)
                            .cast(CDT::float64_datatype())
                            .call_binary(
                                ScalarExpr::Literal(Value::from(30.0f64), CDT::float64_datatype()),
                                BinaryFunc::DivFloat64,
                            ),
                        ScalarExpr::Column(0),
                    ])
                    .unwrap()
                    .project(vec![3, 4])
                    .unwrap(),
            },
        };

        assert_eq!(flow_plan.unwrap(), expected);
    }
1690}