use itertools::Itertools;
use snafu::OptionExt;
use substrait_proto::proto;
use substrait_proto::proto::aggregate_function::AggregationInvocation;
use substrait_proto::proto::aggregate_rel::{Grouping, Measure};
use substrait_proto::proto::function_argument::ArgType;

use crate::error::{Error, NotImplementedSnafu, PlanSnafu};
use crate::expr::{
    AggregateExpr, AggregateFunc, MapFilterProject, ScalarExpr, TypedExpr, UnaryFunc,
};
use crate::plan::{AccumulablePlan, AggrWithIndex, KeyValPlan, Plan, ReducePlan, TypedPlan};
use crate::repr::{ColumnType, RelationDesc, RelationType};
use crate::transform::{FlownodeContext, FunctionExtensions, substrait_proto};

impl TypedExpr {
    #[allow(deprecated)]
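    /// Convert a Substrait aggregate's groupings into a list of `TypedExpr`.
    /// Only a single grouping is supported; `expression_references`, when
    /// present, are resolved against `grouping_expressions`.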
    async fn from_substrait_agg_grouping(
        ctx: &mut FlownodeContext,
        grouping_expressions: &[proto::Expression],
        groupings: &[Grouping],
        typ: &RelationDesc,
        extensions: &FunctionExtensions,
    ) -> Result<Vec<TypedExpr>, Error> {
        let _ = ctx;
        let mut group_expr = vec![];
        match groupings.len() {
            1 => {
                let expressions: Box<dyn Iterator<Item = &proto::Expression> + Send> =
                    if groupings[0].expression_references.is_empty() {
                        Box::new(groupings[0].grouping_expressions.iter())
                    } else {
                        if groupings[0]
                            .expression_references
                            .iter()
                            .any(|idx| *idx as usize >= grouping_expressions.len())
                        {
                            return PlanSnafu {
                                reason: format!(
                                    "Invalid grouping expression reference: {:?} for grouping expr: {:?}",
                                    groupings[0].expression_references, grouping_expressions
                                ),
                            }
                            .fail()?;
                        }
                        Box::new(
                            groupings[0]
                                .expression_references
                                .iter()
                                .map(|idx| &grouping_expressions[*idx as usize]),
                        )
                    };
                for e in expressions {
                    let x = TypedExpr::from_substrait_rex(e, typ, extensions).await?;
                    group_expr.push(x);
                }
            }
            _ => {
                return not_impl_err!(
                    "Grouping sets are not supported yet; use UNION ALL with GROUP BY instead."
                );
            }
        };
        Ok(group_expr)
    }
}

impl AggregateExpr {
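    /// Convert an `AggregateRel`'s `measures` into a list of `AggregateExpr`,
    /// applying each measure's optional filter expression.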
    async fn from_substrait_agg_measures(
        ctx: &mut FlownodeContext,
        measures: &[Measure],
        typ: &RelationDesc,
        extensions: &FunctionExtensions,
    ) -> Result<Vec<AggregateExpr>, Error> {
        let _ = ctx;
        let mut all_aggr_exprs = vec![];

        for m in measures {
            let filter = match m
                .filter
                .as_ref()
                .map(|fil| TypedExpr::from_substrait_rex(fil, typ, extensions))
            {
                Some(fut) => Some(fut.await),
                None => None,
            }
            .transpose()?;

            let aggr_expr = match &m.measure {
                Some(f) => {
                    // Any invocation other than DISTINCT (including ALL and
                    // unspecified) is treated as a plain aggregate.
                    let distinct = f.invocation == AggregationInvocation::Distinct as i32;
                    AggregateExpr::from_substrait_agg_func(
                        f, typ, extensions, &filter, &None, distinct,
                    )
                    .await?
                }
                None => {
                    return not_impl_err!("Aggregate without aggregate function is not supported");
                }
            };

            all_aggr_exprs.extend(aggr_expr);
        }

        Ok(all_aggr_exprs)
    }

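    /// Convert a single Substrait `AggregateFunction` into `AggregateExpr`s,
    /// resolving the function name through `extensions`. Only unary aggregate
    /// functions are supported; `filter` and `order_by` are currently ignored.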
    pub async fn from_substrait_agg_func(
        f: &proto::AggregateFunction,
        input_schema: &RelationDesc,
        extensions: &FunctionExtensions,
        filter: &Option<TypedExpr>,
        order_by: &Option<Vec<TypedExpr>>,
        distinct: bool,
    ) -> Result<Vec<AggregateExpr>, Error> {
        // Filter and order-by clauses on aggregates are not supported yet.
        let _ = filter;
        let _ = order_by;
        let mut args = vec![];
        for arg in &f.arguments {
            let arg_expr = match &arg.arg_type {
                Some(ArgType::Value(e)) => {
                    TypedExpr::from_substrait_rex(e, input_schema, extensions).await
                }
                _ => not_impl_err!("Aggregate function arguments of non-Value type are not supported"),
            }?;
            args.push(arg_expr);
        }

        if args.len() != 1 {
            let fn_name = extensions.get(&f.function_reference).cloned();
            return not_impl_err!(
                "Aggregate function (name={:?}) with multiple arguments is not supported",
                fn_name
            );
        }

        let arg = if let Some(first) = args.first() {
            first
        } else {
            return not_impl_err!("Aggregate function without arguments is not supported");
        };

        let fn_name = extensions
            .get(&f.function_reference)
            .cloned()
            .map(|s| s.to_lowercase());

        match fn_name.as_ref().map(|s| s.as_ref()) {
            Some(function_name) => {
                let func = AggregateFunc::from_str_and_type(
                    function_name,
                    Some(arg.typ.scalar_type.clone()),
                )?;
                let exprs = vec![AggregateExpr {
                    func,
                    expr: arg.expr.clone(),
                    distinct,
                }];
                Ok(exprs)
            }
            None => not_impl_err!(
                "Aggregate function not found: function anchor = {:?}",
                f.function_reference
            ),
        }
    }
}

impl KeyValPlan {
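    /// Generate the key plan (group-by expressions) and value plan (aggregate
    /// arguments) from the group and aggregate expressions. As a side effect,
    /// each aggregate's argument is rewritten to a column reference into the
    /// value plan's output.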
    fn from_substrait_gen_key_val_plan(
        aggr_exprs: &mut [AggregateExpr],
        group_exprs: &[TypedExpr],
        input_arity: usize,
    ) -> Result<KeyValPlan, Error> {
        let group_expr_val = group_exprs
            .iter()
            .map(|expr| expr.expr.clone())
            .collect_vec();
        let output_arity = group_expr_val.len();
        let key_plan = MapFilterProject::new(input_arity)
            .map(group_expr_val)?
            .project(input_arity..input_arity + output_arity)?;

        let val_plan = {
            // If any aggregate argument is not a plain column reference, an MFP is
            // needed to evaluate the argument expressions first.
            let need_mfp = aggr_exprs.iter().any(|agg| agg.expr.as_column().is_none());
            if need_mfp {
                // Evaluate the original argument expressions, then rewrite each
                // aggregate to read its argument from the corresponding output column.
                let input_exprs = aggr_exprs
                    .iter_mut()
                    .enumerate()
                    .map(|(idx, aggr)| {
                        let ret = aggr.expr.clone();
                        aggr.expr = ScalarExpr::Column(idx);
                        ret
                    })
                    .collect_vec();
                let aggr_arity = aggr_exprs.len();

                MapFilterProject::new(input_arity)
                    .map(input_exprs)?
                    .project(input_arity..input_arity + aggr_arity)?
            } else {
                // All arguments are already column references; the identity mapping suffices.
                MapFilterProject::new(input_arity)
            }
        };
        Ok(KeyValPlan {
            key_plan: key_plan.into_safe(),
            val_plan: val_plan.into_safe(),
        })
    }
}

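/// Find the first group expression that is either a `TumbleWindowFloor` call
/// or has a timestamp type, and use its position as the time index.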
fn find_time_index_in_group_exprs(group_exprs: &[TypedExpr]) -> Option<usize> {
    group_exprs.iter().position(|expr| {
        matches!(
            &expr.expr,
            ScalarExpr::CallUnary {
                func: UnaryFunc::TumbleWindowFloor { .. },
                expr: _
            }
        ) || expr.typ.scalar_type.is_timestamp()
    })
}

impl TypedPlan {
    #[async_recursion::async_recursion]
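    /// Convert a Substrait `AggregateRel` into a `TypedPlan` by building the
    /// key/value plans and an accumulable reduce plan on top of the input.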
    pub async fn from_substrait_agg_rel(
        ctx: &mut FlownodeContext,
        agg: &proto::AggregateRel,
        extensions: &FunctionExtensions,
    ) -> Result<TypedPlan, Error> {
        let input = if let Some(input) = agg.input.as_ref() {
            TypedPlan::from_substrait_rel(ctx, input, extensions).await?
        } else {
            return not_impl_err!("Aggregate without an input is not supported");
        };

        let group_exprs = TypedExpr::from_substrait_agg_grouping(
            ctx,
            &agg.grouping_expressions,
            &agg.groupings,
            &input.schema,
            extensions,
        )
        .await?;

        let time_index = find_time_index_in_group_exprs(&group_exprs);

        let mut aggr_exprs = AggregateExpr::from_substrait_agg_measures(
            ctx,
            &agg.measures,
            &input.schema,
            extensions,
        )
        .await?;

        let key_val_plan = KeyValPlan::from_substrait_gen_key_val_plan(
            &mut aggr_exprs,
            &group_exprs,
            input.schema.typ.column_types.len(),
        )?;

        let output_type = {
            let mut output_types = Vec::new();
            let mut output_names = Vec::new();

            // Group-by columns come first in the output; keep a column's name only
            // if the group expression is a plain column reference.
            for expr in group_exprs.iter() {
                output_types.push(expr.typ.clone());
                let col_name = match &expr.expr {
                    ScalarExpr::Column(col) => input.schema.get_name(*col).clone(),
                    _ => None,
                };
                output_names.push(col_name);
            }

            // Aggregate outputs follow the group-by columns; they are unnamed here.
            for aggr in &aggr_exprs {
                output_types.push(ColumnType::new_nullable(
                    aggr.func.signature().output.clone(),
                ));
                output_names.push(None);
            }
            // The group-by columns, if any, form the key of the output relation.
            if group_exprs.is_empty() {
                RelationType::new(output_types)
            } else {
                RelationType::new(output_types).with_key((0..group_exprs.len()).collect_vec())
            }
            .with_time_index(time_index)
            .into_named(output_names)
        };

        // By this point every aggregate argument has been rewritten to a column
        // reference into the value plan, so record the input/output column indices.
        let full_aggrs = aggr_exprs;
        let mut simple_aggrs = Vec::new();
        let mut distinct_aggrs = Vec::new();
        for (output_column, aggr_expr) in full_aggrs.iter().enumerate() {
            let input_column = aggr_expr.expr.as_column().with_context(|| PlanSnafu {
                reason: "Expect aggregate argument to be transformed into a column at this point",
            })?;
            if aggr_expr.distinct {
                distinct_aggrs.push(AggrWithIndex::new(
                    aggr_expr.clone(),
                    input_column,
                    output_column,
                ));
            } else {
                simple_aggrs.push(AggrWithIndex::new(
                    aggr_expr.clone(),
                    input_column,
                    output_column,
                ));
            }
        }
        let accum_plan = AccumulablePlan {
            full_aggrs,
            simple_aggrs,
            distinct_aggrs,
        };
        let plan = Plan::Reduce {
            input: Box::new(input),
            key_val_plan,
            reduce_plan: ReducePlan::Accumulable(accum_plan),
        };
        Ok(TypedPlan {
            schema: output_type,
            plan,
        })
    }
}

#[cfg(test)]
mod test {
    use std::time::Duration;

    use bytes::BytesMut;
    use common_time::{IntervalMonthDayNano, Timestamp};
    use datatypes::data_type::ConcreteDataType as CDT;
    use datatypes::prelude::ConcreteDataType;
    use datatypes::value::Value;
    use pretty_assertions::assert_eq;

    use super::*;
    use crate::expr::{BinaryFunc, DfScalarFunction, GlobalId, RawDfScalarFn};
    use crate::plan::{Plan, TypedPlan};
    use crate::repr::{ColumnType, RelationType};
    use crate::transform::test::{create_test_ctx, create_test_query_engine, sql_to_substrait};

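    /// `sum(abs(number))` over a 1-second tumble window: the `abs` call is
    /// evaluated through a DataFusion scalar function inside the value plan.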
    #[tokio::test]
    async fn test_df_func_basic() {
        let engine = create_test_query_engine();
        let sql = "SELECT sum(abs(number)) FROM numbers_with_ts GROUP BY tumble(ts, '1 second', '2021-07-01 00:00:00');";
        let plan = sql_to_substrait(engine.clone(), sql).await;

        let mut ctx = create_test_ctx();
        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
            .await
            .unwrap();

        let aggr_expr = AggregateExpr {
            func: AggregateFunc::SumUInt64,
            expr: ScalarExpr::Column(0),
            distinct: false,
        };
        let expected = TypedPlan {
            schema: RelationType::new(vec![
                ColumnType::new(CDT::uint64_datatype(), true), // sum(abs(number))
                ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
                ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
            ])
            .with_key(vec![2])
            .with_time_index(Some(1))
            .into_named(vec![
                Some("sum(abs(numbers_with_ts.number))".to_string()),
                Some("window_start".to_string()),
                Some("window_end".to_string()),
            ]),
            plan: Plan::Mfp {
                input: Box::new(
                    Plan::Reduce {
                        input: Box::new(
                            Plan::Get {
                                id: crate::expr::Id::Global(GlobalId::User(1)),
                            }
                            .with_types(
                                RelationType::new(vec![
                                    ColumnType::new(ConcreteDataType::uint32_datatype(), false),
                                    ColumnType::new(
                                        ConcreteDataType::timestamp_millisecond_datatype(),
                                        false,
                                    ),
                                ])
                                .into_named(vec![
                                    Some("number".to_string()),
                                    Some("ts".to_string()),
                                ]),
                            )
                            .mfp(MapFilterProject::new(2).into_safe())
                            .unwrap(),
                        ),
                        key_val_plan: KeyValPlan {
                            key_plan: MapFilterProject::new(2)
                                .map(vec![
                                    ScalarExpr::Column(1).call_unary(
                                        UnaryFunc::TumbleWindowFloor {
                                            window_size: Duration::from_nanos(1_000_000_000),
                                            start_time: Some(Timestamp::new_millisecond(
                                                1625097600000,
                                            )),
                                        },
                                    ),
                                    ScalarExpr::Column(1).call_unary(
                                        UnaryFunc::TumbleWindowCeiling {
                                            window_size: Duration::from_nanos(1_000_000_000),
                                            start_time: Some(Timestamp::new_millisecond(
                                                1625097600000,
                                            )),
                                        },
                                    ),
                                ])
                                .unwrap()
                                .project(vec![2, 3])
                                .unwrap()
                                .into_safe(),
                            val_plan: MapFilterProject::new(2)
                                .map(vec![ScalarExpr::CallDf {
                                    df_scalar_fn: DfScalarFunction::try_from_raw_fn(
                                        RawDfScalarFn {
                                            f: BytesMut::from(
                                                b"\x08\x02\"\x08\x1a\x06\x12\x04\n\x02\x12\0"
                                                    .as_ref(),
                                            ),
                                            input_schema: RelationType::new(vec![ColumnType::new(
                                                ConcreteDataType::uint32_datatype(),
                                                false,
                                            )])
                                            .into_unnamed(),
                                            extensions: FunctionExtensions::from_iter([
                                                (0, "tumble_start".to_string()),
                                                (1, "tumble_end".to_string()),
                                                (2, "abs".to_string()),
                                                (3, "sum".to_string()),
                                            ]),
                                        },
                                    )
                                    .await
                                    .unwrap(),
                                    exprs: vec![ScalarExpr::Column(0)],
                                }
                                .cast(CDT::uint64_datatype())])
                                .unwrap()
                                .project(vec![2])
                                .unwrap()
                                .into_safe(),
                        },
                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
                            full_aggrs: vec![aggr_expr.clone()],
                            simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
                            distinct_aggrs: vec![],
                        }),
                    }
                    .with_types(
                        RelationType::new(vec![
                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
                            ColumnType::new(CDT::uint64_datatype(), true), // sum(abs(number))
                        ])
                        .with_key(vec![1])
                        .with_time_index(Some(0))
                        .into_unnamed(),
                    ),
                ),
                mfp: MapFilterProject::new(3)
                    .map(vec![
                        ScalarExpr::Column(2),
                        ScalarExpr::Column(0),
                        ScalarExpr::Column(1),
                    ])
                    .unwrap()
                    .project(vec![3, 4, 5])
                    .unwrap(),
            },
        };
        assert_eq!(flow_plan, expected);
    }

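    /// `abs(sum(number))`: the DataFusion scalar function wraps the aggregate
    /// output, so it runs in the MFP stage after the reduce.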
    #[tokio::test]
    async fn test_df_func_expr_tree() {
        let engine = create_test_query_engine();
        let sql = "SELECT abs(sum(number)) FROM numbers_with_ts GROUP BY tumble(ts, '1 second', '2021-07-01 00:00:00');";
        let plan = sql_to_substrait(engine.clone(), sql).await;

        let mut ctx = create_test_ctx();
        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
            .await
            .unwrap();

        let aggr_expr = AggregateExpr {
            func: AggregateFunc::SumUInt64,
            expr: ScalarExpr::Column(0),
            distinct: false,
        };
        let expected = TypedPlan {
            schema: RelationType::new(vec![
                ColumnType::new(CDT::uint64_datatype(), true), // abs(sum(number))
                ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
                ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
            ])
            .with_key(vec![2])
            .with_time_index(Some(1))
            .into_named(vec![
                Some("abs(sum(numbers_with_ts.number))".to_string()),
                Some("window_start".to_string()),
                Some("window_end".to_string()),
            ]),
            plan: Plan::Mfp {
                input: Box::new(
                    Plan::Reduce {
                        input: Box::new(
                            Plan::Get {
                                id: crate::expr::Id::Global(GlobalId::User(1)),
                            }
                            .with_types(
                                RelationType::new(vec![
                                    ColumnType::new(ConcreteDataType::uint32_datatype(), false),
                                    ColumnType::new(
                                        ConcreteDataType::timestamp_millisecond_datatype(),
                                        false,
                                    ),
                                ])
                                .into_named(vec![
                                    Some("number".to_string()),
                                    Some("ts".to_string()),
                                ]),
                            )
                            .mfp(MapFilterProject::new(2).into_safe())
                            .unwrap(),
                        ),
                        key_val_plan: KeyValPlan {
                            key_plan: MapFilterProject::new(2)
                                .map(vec![
                                    ScalarExpr::Column(1).call_unary(
                                        UnaryFunc::TumbleWindowFloor {
                                            window_size: Duration::from_nanos(1_000_000_000),
                                            start_time: Some(Timestamp::new_millisecond(
                                                1625097600000,
                                            )),
                                        },
                                    ),
                                    ScalarExpr::Column(1).call_unary(
                                        UnaryFunc::TumbleWindowCeiling {
                                            window_size: Duration::from_nanos(1_000_000_000),
                                            start_time: Some(Timestamp::new_millisecond(
                                                1625097600000,
                                            )),
                                        },
                                    ),
                                ])
                                .unwrap()
                                .project(vec![2, 3])
                                .unwrap()
                                .into_safe(),
                            val_plan: MapFilterProject::new(2)
                                .map(vec![ScalarExpr::Column(0).cast(CDT::uint64_datatype())])
                                .unwrap()
                                .project(vec![2])
                                .unwrap()
                                .into_safe(),
                        },
                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
                            full_aggrs: vec![aggr_expr.clone()],
                            simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
                            distinct_aggrs: vec![],
                        }),
                    }
                    .with_types(
                        RelationType::new(vec![
                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
                            ColumnType::new(CDT::uint64_datatype(), true), // sum(number)
                        ])
                        .with_key(vec![1])
                        .with_time_index(Some(0))
                        .into_named(vec![None, None, None]),
                    ),
                ),
                mfp: MapFilterProject::new(3)
                    .map(vec![
                        ScalarExpr::CallDf {
                            df_scalar_fn: DfScalarFunction::try_from_raw_fn(RawDfScalarFn {
                                f: BytesMut::from(b"\"\x08\x1a\x06\x12\x04\n\x02\x12\0".as_ref()),
                                input_schema: RelationType::new(vec![ColumnType::new(
                                    ConcreteDataType::uint64_datatype(),
                                    true,
                                )])
                                .into_unnamed(),
                                extensions: FunctionExtensions::from_iter([
                                    (0, "abs".to_string()),
                                    (1, "tumble_start".to_string()),
                                    (2, "tumble_end".to_string()),
                                    (3, "sum".to_string()),
                                ]),
                            })
                            .await
                            .unwrap(),
                            exprs: vec![ScalarExpr::Column(2)],
                        },
                        ScalarExpr::Column(0),
                        ScalarExpr::Column(1),
                    ])
                    .unwrap()
                    .project(vec![3, 4, 5])
                    .unwrap(),
            },
        };
        assert_eq!(flow_plan, expected);
    }

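    /// Composite group-by key: a tumble window plus a plain column.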
    #[tokio::test]
    async fn test_tumble_composite() {
        let engine = create_test_query_engine();
        let sql =
            "SELECT number, avg(number) FROM numbers_with_ts GROUP BY tumble(ts, '1 hour'), number";
        let plan = sql_to_substrait(engine.clone(), sql).await;

        let mut ctx = create_test_ctx();
        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
            .await
            .unwrap();

        let aggr_exprs = vec![
            AggregateExpr {
                func: AggregateFunc::SumUInt64,
                expr: ScalarExpr::Column(0),
                distinct: false,
            },
            AggregateExpr {
                func: AggregateFunc::Count,
                expr: ScalarExpr::Column(1),
                distinct: false,
            },
        ];
        // avg(number) is rewritten to sum(number) / count(number), guarded
        // against division by zero.
        let avg_expr = ScalarExpr::If {
            cond: Box::new(ScalarExpr::Column(4).call_binary(
                ScalarExpr::Literal(Value::from(0i64), CDT::int64_datatype()),
                BinaryFunc::NotEq,
            )),
            then: Box::new(
                ScalarExpr::Column(3)
                    .cast(CDT::float64_datatype())
                    .call_binary(
                        ScalarExpr::Column(4).cast(CDT::float64_datatype()),
                        BinaryFunc::DivFloat64,
                    ),
            ),
            els: Box::new(ScalarExpr::Literal(Value::Null, CDT::float64_datatype())),
        };
        let expected = TypedPlan {
            plan: Plan::Mfp {
                input: Box::new(
                    Plan::Reduce {
                        input: Box::new(
                            Plan::Get {
                                id: crate::expr::Id::Global(GlobalId::User(1)),
                            }
                            .with_types(
                                RelationType::new(vec![
                                    ColumnType::new(ConcreteDataType::uint32_datatype(), false),
                                    ColumnType::new(
                                        ConcreteDataType::timestamp_millisecond_datatype(),
                                        false,
                                    ),
                                ])
                                .into_named(vec![
                                    Some("number".to_string()),
                                    Some("ts".to_string()),
                                ]),
                            )
                            .mfp(MapFilterProject::new(2).into_safe())
                            .unwrap(),
                        ),
                        key_val_plan: KeyValPlan {
                            key_plan: MapFilterProject::new(2)
                                .map(vec![
                                    ScalarExpr::Column(1).call_unary(
                                        UnaryFunc::TumbleWindowFloor {
                                            window_size: Duration::from_nanos(3_600_000_000_000),
                                            start_time: None,
                                        },
                                    ),
                                    ScalarExpr::Column(1).call_unary(
                                        UnaryFunc::TumbleWindowCeiling {
                                            window_size: Duration::from_nanos(3_600_000_000_000),
                                            start_time: None,
                                        },
                                    ),
                                    ScalarExpr::Column(0),
                                ])
                                .unwrap()
                                .project(vec![2, 3, 4])
                                .unwrap()
                                .into_safe(),
                            val_plan: MapFilterProject::new(2)
                                .map(vec![
                                    ScalarExpr::Column(0).cast(CDT::uint64_datatype()),
                                    ScalarExpr::Column(0),
                                ])
                                .unwrap()
                                .project(vec![2, 3])
                                .unwrap()
                                .into_safe(),
                        },
                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
                            full_aggrs: aggr_exprs.clone(),
                            simple_aggrs: vec![
                                AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0),
                                AggrWithIndex::new(aggr_exprs[1].clone(), 1, 1),
                            ],
                            distinct_aggrs: vec![],
                        }),
                    }
                    .with_types(
                        RelationType::new(vec![
                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
                            ColumnType::new(CDT::uint32_datatype(), false), // number
                            ColumnType::new(CDT::uint64_datatype(), true),  // sum(number)
                            ColumnType::new(CDT::int64_datatype(), true),   // count(number)
                        ])
                        .with_key(vec![1, 2])
                        .with_time_index(Some(0))
                        .into_named(vec![
                            None,
                            None,
                            Some("number".to_string()),
                            None,
                            None,
                        ]),
                    ),
                ),
                mfp: MapFilterProject::new(5)
                    .map(vec![
                        ScalarExpr::Column(2), // number
                        avg_expr,              // avg(number)
                        ScalarExpr::Column(0), // window start
                        ScalarExpr::Column(1), // window end
                    ])
                    .unwrap()
                    .project(vec![5, 6, 7, 8])
                    .unwrap(),
            },
            schema: RelationType::new(vec![
                ColumnType::new(CDT::uint32_datatype(), false), // number
                ColumnType::new(CDT::float64_datatype(), true), // avg(number)
                ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
                ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
            ])
            .with_key(vec![0, 3])
            .with_time_index(Some(2))
            .into_named(vec![
                Some("number".to_string()),
                Some("avg(numbers_with_ts.number)".to_string()),
                Some("window_start".to_string()),
                Some("window_end".to_string()),
            ]),
        };
        assert_eq!(flow_plan, expected);
    }

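    /// Tumble window with the optional start time omitted.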
    #[tokio::test]
    async fn test_tumble_parse_optional() {
        let engine = create_test_query_engine();
        let sql = "SELECT sum(number) FROM numbers_with_ts GROUP BY tumble(ts, '1 hour')";
        let plan = sql_to_substrait(engine.clone(), sql).await;

        let mut ctx = create_test_ctx();
        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
            .await
            .unwrap();

        let aggr_expr = AggregateExpr {
            func: AggregateFunc::SumUInt64,
            expr: ScalarExpr::Column(0),
            distinct: false,
        };
        let expected = TypedPlan {
            schema: RelationType::new(vec![
                ColumnType::new(CDT::uint64_datatype(), true), // sum(number)
                ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
                ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
            ])
            .with_key(vec![2])
            .with_time_index(Some(1))
            .into_named(vec![
                Some("sum(numbers_with_ts.number)".to_string()),
                Some("window_start".to_string()),
                Some("window_end".to_string()),
            ]),
            plan: Plan::Mfp {
                input: Box::new(
                    Plan::Reduce {
                        input: Box::new(
                            Plan::Get {
                                id: crate::expr::Id::Global(GlobalId::User(1)),
                            }
                            .with_types(
                                RelationType::new(vec![
                                    ColumnType::new(ConcreteDataType::uint32_datatype(), false),
                                    ColumnType::new(
                                        ConcreteDataType::timestamp_millisecond_datatype(),
                                        false,
                                    ),
                                ])
                                .into_named(vec![
                                    Some("number".to_string()),
                                    Some("ts".to_string()),
                                ]),
                            )
                            .mfp(MapFilterProject::new(2).into_safe())
                            .unwrap(),
                        ),
                        key_val_plan: KeyValPlan {
                            key_plan: MapFilterProject::new(2)
                                .map(vec![
                                    ScalarExpr::Column(1).call_unary(
                                        UnaryFunc::TumbleWindowFloor {
                                            window_size: Duration::from_nanos(3_600_000_000_000),
                                            start_time: None,
                                        },
                                    ),
                                    ScalarExpr::Column(1).call_unary(
                                        UnaryFunc::TumbleWindowCeiling {
                                            window_size: Duration::from_nanos(3_600_000_000_000),
                                            start_time: None,
                                        },
                                    ),
                                ])
                                .unwrap()
                                .project(vec![2, 3])
                                .unwrap()
                                .into_safe(),
                            val_plan: MapFilterProject::new(2)
                                .map(vec![ScalarExpr::Column(0).cast(CDT::uint64_datatype())])
                                .unwrap()
                                .project(vec![2])
                                .unwrap()
                                .into_safe(),
                        },
                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
                            full_aggrs: vec![aggr_expr.clone()],
                            simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
                            distinct_aggrs: vec![],
                        }),
                    }
                    .with_types(
                        RelationType::new(vec![
                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
                            ColumnType::new(CDT::uint64_datatype(), true), // sum(number)
                        ])
                        .with_key(vec![1])
                        .with_time_index(Some(0))
                        .into_named(vec![None, None, None]),
                    ),
                ),
                mfp: MapFilterProject::new(3)
                    .map(vec![
                        ScalarExpr::Column(2),
                        ScalarExpr::Column(0),
                        ScalarExpr::Column(1),
                    ])
                    .unwrap()
                    .project(vec![3, 4, 5])
                    .unwrap(),
            },
        };
        assert_eq!(flow_plan, expected);
    }

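    /// Tumble window with an explicit start time.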
    #[tokio::test]
    async fn test_tumble_parse() {
        let engine = create_test_query_engine();
        let sql = "SELECT sum(number) FROM numbers_with_ts GROUP BY tumble(ts, '1 hour', '2021-07-01 00:00:00')";
        let plan = sql_to_substrait(engine.clone(), sql).await;

        let mut ctx = create_test_ctx();
        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
            .await
            .unwrap();

        let aggr_expr = AggregateExpr {
            func: AggregateFunc::SumUInt64,
            expr: ScalarExpr::Column(0),
            distinct: false,
        };
        let expected = TypedPlan {
            schema: RelationType::new(vec![
                ColumnType::new(CDT::uint64_datatype(), true), // sum(number)
                ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
                ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
            ])
            .with_key(vec![2])
            .with_time_index(Some(1))
            .into_named(vec![
                Some("sum(numbers_with_ts.number)".to_string()),
                Some("window_start".to_string()),
                Some("window_end".to_string()),
            ]),
            plan: Plan::Mfp {
                input: Box::new(
                    Plan::Reduce {
                        input: Box::new(
                            Plan::Get {
                                id: crate::expr::Id::Global(GlobalId::User(1)),
                            }
                            .with_types(
                                RelationType::new(vec![
                                    ColumnType::new(ConcreteDataType::uint32_datatype(), false),
                                    ColumnType::new(
                                        ConcreteDataType::timestamp_millisecond_datatype(),
                                        false,
                                    ),
                                ])
                                .into_named(vec![
                                    Some("number".to_string()),
                                    Some("ts".to_string()),
                                ]),
                            )
                            .mfp(MapFilterProject::new(2).into_safe())
                            .unwrap(),
                        ),
                        key_val_plan: KeyValPlan {
                            key_plan: MapFilterProject::new(2)
                                .map(vec![
                                    ScalarExpr::Column(1).call_unary(
                                        UnaryFunc::TumbleWindowFloor {
                                            window_size: Duration::from_nanos(3_600_000_000_000),
                                            start_time: Some(Timestamp::new_millisecond(
                                                1625097600000,
                                            )),
                                        },
                                    ),
                                    ScalarExpr::Column(1).call_unary(
                                        UnaryFunc::TumbleWindowCeiling {
                                            window_size: Duration::from_nanos(3_600_000_000_000),
                                            start_time: Some(Timestamp::new_millisecond(
                                                1625097600000,
                                            )),
                                        },
                                    ),
                                ])
                                .unwrap()
                                .project(vec![2, 3])
                                .unwrap()
                                .into_safe(),
                            val_plan: MapFilterProject::new(2)
                                .map(vec![ScalarExpr::Column(0).cast(CDT::uint64_datatype())])
                                .unwrap()
                                .project(vec![2])
                                .unwrap()
                                .into_safe(),
                        },
                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
                            full_aggrs: vec![aggr_expr.clone()],
                            simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
                            distinct_aggrs: vec![],
                        }),
                    }
                    .with_types(
                        RelationType::new(vec![
                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
                            ColumnType::new(CDT::uint64_datatype(), true), // sum(number)
                        ])
                        .with_key(vec![1])
                        .with_time_index(Some(0))
                        .into_unnamed(),
                    ),
                ),
                mfp: MapFilterProject::new(3)
                    .map(vec![
                        ScalarExpr::Column(2),
                        ScalarExpr::Column(0),
                        ScalarExpr::Column(1),
                    ])
                    .unwrap()
                    .project(vec![3, 4, 5])
                    .unwrap(),
            },
        };
        assert_eq!(flow_plan, expected);
    }

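    /// `avg(number)` with a group-by key, lowered to sum/count plus a
    /// divide-by-zero guard.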
    #[tokio::test]
    async fn test_avg_group_by() {
        let engine = create_test_query_engine();
        let sql = "SELECT avg(number), number FROM numbers GROUP BY number";
        let plan = sql_to_substrait(engine.clone(), sql).await;

        let mut ctx = create_test_ctx();
        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;

        let aggr_exprs = vec![
            AggregateExpr {
                func: AggregateFunc::SumUInt64,
                expr: ScalarExpr::Column(0),
                distinct: false,
            },
            AggregateExpr {
                func: AggregateFunc::Count,
                expr: ScalarExpr::Column(1),
                distinct: false,
            },
        ];
        // avg(number) is rewritten to sum(number) / count(number), guarded
        // against division by zero.
        let avg_expr = ScalarExpr::If {
            cond: Box::new(ScalarExpr::Column(2).call_binary(
                ScalarExpr::Literal(Value::from(0i64), CDT::int64_datatype()),
                BinaryFunc::NotEq,
            )),
            then: Box::new(
                ScalarExpr::Column(1)
                    .cast(CDT::float64_datatype())
                    .call_binary(
                        ScalarExpr::Column(2).cast(CDT::float64_datatype()),
                        BinaryFunc::DivFloat64,
                    ),
            ),
            els: Box::new(ScalarExpr::Literal(Value::Null, CDT::float64_datatype())),
        };
        let expected = TypedPlan {
            schema: RelationType::new(vec![
                ColumnType::new(CDT::float64_datatype(), true), // avg(number)
                ColumnType::new(CDT::uint32_datatype(), false), // number
            ])
            .with_key(vec![1])
            .into_named(vec![
                Some("avg(numbers.number)".to_string()),
                Some("number".to_string()),
            ]),
            plan: Plan::Mfp {
                input: Box::new(
                    Plan::Reduce {
                        input: Box::new(
                            Plan::Get {
                                id: crate::expr::Id::Global(GlobalId::User(0)),
                            }
                            .with_types(
                                RelationType::new(vec![ColumnType::new(
                                    ConcreteDataType::uint32_datatype(),
                                    false,
                                )])
                                .into_named(vec![Some("number".to_string())]),
                            )
                            .mfp(
                                MapFilterProject::new(1)
                                    .project(vec![0])
                                    .unwrap()
                                    .into_safe(),
                            )
                            .unwrap(),
                        ),
                        key_val_plan: KeyValPlan {
                            key_plan: MapFilterProject::new(1)
                                .map(vec![ScalarExpr::Column(0)])
                                .unwrap()
                                .project(vec![1])
                                .unwrap()
                                .into_safe(),
                            val_plan: MapFilterProject::new(1)
                                .map(vec![
                                    ScalarExpr::Column(0).cast(CDT::uint64_datatype()),
                                    ScalarExpr::Column(0),
                                ])
                                .unwrap()
                                .project(vec![1, 2])
                                .unwrap()
                                .into_safe(),
                        },
                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
                            full_aggrs: aggr_exprs.clone(),
                            simple_aggrs: vec![
                                AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0),
                                AggrWithIndex::new(aggr_exprs[1].clone(), 1, 1),
                            ],
                            distinct_aggrs: vec![],
                        }),
                    }
                    .with_types(
                        RelationType::new(vec![
                            ColumnType::new(ConcreteDataType::uint32_datatype(), false), // number
                            ColumnType::new(ConcreteDataType::uint64_datatype(), true), // sum(number)
                            ColumnType::new(ConcreteDataType::int64_datatype(), true), // count(number)
                        ])
                        .with_key(vec![0])
                        .into_named(vec![Some("number".to_string()), None, None]),
                    ),
                ),
                mfp: MapFilterProject::new(3)
                    .map(vec![
                        avg_expr,              // avg(number)
                        ScalarExpr::Column(0), // number
                    ])
                    .unwrap()
                    .project(vec![3, 4])
                    .unwrap(),
            },
        };
        assert_eq!(flow_plan.unwrap(), expected);
    }

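    /// `avg(number)` without any group-by key.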
    #[tokio::test]
    async fn test_avg() {
        let engine = create_test_query_engine();
        let sql = "SELECT avg(number) FROM numbers";
        let plan = sql_to_substrait(engine.clone(), sql).await;

        let mut ctx = create_test_ctx();

        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
            .await
            .unwrap();

        let aggr_exprs = vec![
            AggregateExpr {
                func: AggregateFunc::SumUInt64,
                expr: ScalarExpr::Column(0),
                distinct: false,
            },
            AggregateExpr {
                func: AggregateFunc::Count,
                expr: ScalarExpr::Column(1),
                distinct: false,
            },
        ];
        // avg(number) is rewritten to sum(number) / count(number), guarded
        // against division by zero.
        let avg_expr = ScalarExpr::If {
            cond: Box::new(ScalarExpr::Column(1).call_binary(
                ScalarExpr::Literal(Value::from(0i64), CDT::int64_datatype()),
                BinaryFunc::NotEq,
            )),
            then: Box::new(
                ScalarExpr::Column(0)
                    .cast(CDT::float64_datatype())
                    .call_binary(
                        ScalarExpr::Column(1).cast(CDT::float64_datatype()),
                        BinaryFunc::DivFloat64,
                    ),
            ),
            els: Box::new(ScalarExpr::Literal(Value::Null, CDT::float64_datatype())),
        };
        let input = Box::new(
            Plan::Get {
                id: crate::expr::Id::Global(GlobalId::User(0)),
            }
            .with_types(
                RelationType::new(vec![ColumnType::new(
                    ConcreteDataType::uint32_datatype(),
                    false,
                )])
                .into_named(vec![Some("number".to_string())]),
            ),
        );
        let expected = TypedPlan {
            schema: RelationType::new(vec![ColumnType::new(CDT::float64_datatype(), true)])
                .into_named(vec![Some("avg(numbers.number)".to_string())]),
            plan: Plan::Mfp {
                input: Box::new(
                    Plan::Reduce {
                        input: Box::new(
                            Plan::Mfp {
                                input: input.clone(),
                                mfp: MapFilterProject::new(1).project(vec![0]).unwrap(),
                            }
                            .with_types(
                                RelationType::new(vec![ColumnType::new(
                                    CDT::uint32_datatype(),
                                    false,
                                )])
                                .into_named(vec![Some("number".to_string())]),
                            ),
                        ),
                        key_val_plan: KeyValPlan {
                            key_plan: MapFilterProject::new(1)
                                .project(vec![])
                                .unwrap()
                                .into_safe(),
                            val_plan: MapFilterProject::new(1)
                                .map(vec![
                                    ScalarExpr::Column(0).cast(CDT::uint64_datatype()),
                                    ScalarExpr::Column(0),
                                ])
                                .unwrap()
                                .project(vec![1, 2])
                                .unwrap()
                                .into_safe(),
                        },
                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
                            full_aggrs: aggr_exprs.clone(),
                            simple_aggrs: vec![
                                AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0),
                                AggrWithIndex::new(aggr_exprs[1].clone(), 1, 1),
                            ],
                            distinct_aggrs: vec![],
                        }),
                    }
                    .with_types(
                        RelationType::new(vec![
                            ColumnType::new(ConcreteDataType::uint64_datatype(), true), // sum(number)
                            ColumnType::new(ConcreteDataType::int64_datatype(), true), // count(number)
                        ])
                        .into_named(vec![None, None]),
                    ),
                ),
                mfp: MapFilterProject::new(2)
                    .map(vec![avg_expr])
                    .unwrap()
                    .project(vec![2])
                    .unwrap(),
            },
        };
        assert_eq!(flow_plan, expected);
    }

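    /// Plain `sum(number)` with no group-by key.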
    #[tokio::test]
    async fn test_sum() {
        let engine = create_test_query_engine();
        let sql = "SELECT sum(number) FROM numbers";
        let plan = sql_to_substrait(engine.clone(), sql).await;

        let mut ctx = create_test_ctx();
        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;

        let aggr_expr = AggregateExpr {
            func: AggregateFunc::SumUInt64,
            expr: ScalarExpr::Column(0),
            distinct: false,
        };
        let expected = TypedPlan {
            schema: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)])
                .into_named(vec![Some("sum(numbers.number)".to_string())]),
            plan: Plan::Reduce {
                input: Box::new(
                    Plan::Get {
                        id: crate::expr::Id::Global(GlobalId::User(0)),
                    }
                    .with_types(
                        RelationType::new(vec![ColumnType::new(
                            ConcreteDataType::uint32_datatype(),
                            false,
                        )])
                        .into_named(vec![Some("number".to_string())]),
                    )
                    .mfp(MapFilterProject::new(1).into_safe())
                    .unwrap(),
                ),
                key_val_plan: KeyValPlan {
                    key_plan: MapFilterProject::new(1)
                        .project(vec![])
                        .unwrap()
                        .into_safe(),
                    val_plan: MapFilterProject::new(1)
                        .map(vec![
                            ScalarExpr::Column(0)
                                .call_unary(UnaryFunc::Cast(CDT::uint64_datatype())),
                        ])
                        .unwrap()
                        .project(vec![1])
                        .unwrap()
                        .into_safe(),
                },
                reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
                    full_aggrs: vec![aggr_expr.clone()],
                    simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
                    distinct_aggrs: vec![],
                }),
            },
        };
        assert_eq!(flow_plan.unwrap(), expected);
    }

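    /// `SELECT DISTINCT` lowers to a reduce with a group-by key and no aggregates.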
    #[tokio::test]
    async fn test_distinct_number() {
        let engine = create_test_query_engine();
        let sql = "SELECT DISTINCT number FROM numbers";
        let plan = sql_to_substrait(engine.clone(), sql).await;

        let mut ctx = create_test_ctx();
        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
            .await
            .unwrap();

        let expected = TypedPlan {
            schema: RelationType::new(vec![
                ColumnType::new(CDT::uint32_datatype(), false), // number
            ])
            .with_key(vec![0])
            .into_named(vec![Some("number".to_string())]),
            plan: Plan::Reduce {
                input: Box::new(
                    Plan::Get {
                        id: crate::expr::Id::Global(GlobalId::User(0)),
                    }
                    .with_types(
                        RelationType::new(vec![ColumnType::new(
                            ConcreteDataType::uint32_datatype(),
                            false,
                        )])
                        .into_named(vec![Some("number".to_string())]),
                    )
                    .mfp(MapFilterProject::new(1).into_safe())
                    .unwrap(),
                ),
                key_val_plan: KeyValPlan {
                    key_plan: MapFilterProject::new(1)
                        .map(vec![ScalarExpr::Column(0)])
                        .unwrap()
                        .project(vec![1])
                        .unwrap()
                        .into_safe(),
                    val_plan: MapFilterProject::new(1)
                        .project(vec![0])
                        .unwrap()
                        .into_safe(),
                },
                reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
                    full_aggrs: vec![],
                    simple_aggrs: vec![],
                    distinct_aggrs: vec![],
                }),
            },
        };

        assert_eq!(flow_plan, expected);
    }

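    /// `sum(number)` grouped by the same column that is also projected.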
    #[tokio::test]
    async fn test_sum_group_by() {
        let engine = create_test_query_engine();
        let sql = "SELECT sum(number), number FROM numbers GROUP BY number";
        let plan = sql_to_substrait(engine.clone(), sql).await;

        let mut ctx = create_test_ctx();
        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
            .await
            .unwrap();

        let aggr_expr = AggregateExpr {
            func: AggregateFunc::SumUInt64,
            expr: ScalarExpr::Column(0),
            distinct: false,
        };
        let expected = TypedPlan {
            schema: RelationType::new(vec![
                ColumnType::new(CDT::uint64_datatype(), true), // sum(number)
                ColumnType::new(CDT::uint32_datatype(), false), // number
            ])
            .with_key(vec![1])
            .into_named(vec![
                Some("sum(numbers.number)".to_string()),
                Some("number".to_string()),
            ]),
            plan: Plan::Mfp {
                input: Box::new(
                    Plan::Reduce {
                        input: Box::new(
                            Plan::Get {
                                id: crate::expr::Id::Global(GlobalId::User(0)),
                            }
                            .with_types(
                                RelationType::new(vec![ColumnType::new(
                                    ConcreteDataType::uint32_datatype(),
                                    false,
                                )])
                                .into_named(vec![Some("number".to_string())]),
                            )
                            .mfp(MapFilterProject::new(1).into_safe())
                            .unwrap(),
                        ),
                        key_val_plan: KeyValPlan {
                            key_plan: MapFilterProject::new(1)
                                .map(vec![ScalarExpr::Column(0)])
                                .unwrap()
                                .project(vec![1])
                                .unwrap()
                                .into_safe(),
                            val_plan: MapFilterProject::new(1)
                                .map(vec![
                                    ScalarExpr::Column(0)
                                        .call_unary(UnaryFunc::Cast(CDT::uint64_datatype())),
                                ])
                                .unwrap()
                                .project(vec![1])
                                .unwrap()
                                .into_safe(),
                        },
                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
                            full_aggrs: vec![aggr_expr.clone()],
                            simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
                            distinct_aggrs: vec![],
                        }),
                    }
                    .with_types(
                        RelationType::new(vec![
                            ColumnType::new(CDT::uint32_datatype(), false), // number
                            ColumnType::new(CDT::uint64_datatype(), true),  // sum(number)
                        ])
                        .with_key(vec![0])
                        .into_named(vec![Some("number".to_string()), None]),
                    ),
                ),
                mfp: MapFilterProject::new(2)
                    .map(vec![ScalarExpr::Column(1), ScalarExpr::Column(0)])
                    .unwrap()
                    .project(vec![2, 3])
                    .unwrap(),
            },
        };

        assert_eq!(flow_plan, expected);
    }

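    /// The aggregate argument `number + number` is evaluated in the value plan
    /// before the reduce.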
    #[tokio::test]
    async fn test_sum_add() {
        let engine = create_test_query_engine();
        let sql = "SELECT sum(number+number) FROM numbers";
        let plan = sql_to_substrait(engine.clone(), sql).await;

        let mut ctx = create_test_ctx();
        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;

        let aggr_expr = AggregateExpr {
            func: AggregateFunc::SumUInt64,
            expr: ScalarExpr::Column(0),
            distinct: false,
        };
        let expected = TypedPlan {
            schema: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)])
                .into_named(vec![Some(
                    "sum(numbers.number + numbers.number)".to_string(),
                )]),
            plan: Plan::Reduce {
                input: Box::new(
                    Plan::Mfp {
                        input: Box::new(
                            Plan::Get {
                                id: crate::expr::Id::Global(GlobalId::User(0)),
                            }
                            .with_types(
                                RelationType::new(vec![ColumnType::new(
                                    ConcreteDataType::uint32_datatype(),
                                    false,
                                )])
                                .into_named(vec![Some("number".to_string())]),
                            ),
                        ),
                        mfp: MapFilterProject::new(1),
                    }
                    .with_types(
                        RelationType::new(vec![ColumnType::new(
                            ConcreteDataType::uint32_datatype(),
                            false,
                        )])
                        .into_named(vec![Some("number".to_string())]),
                    ),
                ),
                key_val_plan: KeyValPlan {
                    key_plan: MapFilterProject::new(1)
                        .project(vec![])
                        .unwrap()
                        .into_safe(),
                    val_plan: MapFilterProject::new(1)
                        .map(vec![
                            ScalarExpr::Column(0)
                                .call_binary(ScalarExpr::Column(0), BinaryFunc::AddUInt32)
                                .call_unary(UnaryFunc::Cast(CDT::uint64_datatype())),
                        ])
                        .unwrap()
                        .project(vec![1])
                        .unwrap()
                        .into_safe(),
                },
                reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
                    full_aggrs: vec![aggr_expr.clone()],
                    simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
                    distinct_aggrs: vec![],
                }),
            },
        };
        assert_eq!(flow_plan.unwrap(), expected);
    }

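    /// `max - min` over a `date_bin` time window; the window expression is
    /// computed by a DataFusion scalar function in the key plan.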
    #[tokio::test]
    async fn test_cast_max_min() {
        let engine = create_test_query_engine();
        let sql = "SELECT (max(number) - min(number))/30.0, date_bin(INTERVAL '30 second', CAST(ts AS TimestampMillisecond)) as time_window from numbers_with_ts GROUP BY time_window";
        let plan = sql_to_substrait(engine.clone(), sql).await;

        let mut ctx = create_test_ctx();
        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;

        let aggr_exprs = vec![
            AggregateExpr {
                func: AggregateFunc::MaxUInt32,
                expr: ScalarExpr::Column(0),
                distinct: false,
            },
            AggregateExpr {
                func: AggregateFunc::MinUInt32,
                expr: ScalarExpr::Column(0),
                distinct: false,
            },
        ];
        let expected = TypedPlan {
            schema: RelationType::new(vec![
                ColumnType::new(CDT::float64_datatype(), true),
                ColumnType::new(CDT::timestamp_millisecond_datatype(), true),
            ])
            .with_time_index(Some(1))
            .into_named(vec![
                Some(
                    "max(numbers_with_ts.number) - min(numbers_with_ts.number) / Float64(30)"
                        .to_string(),
                ),
                Some("time_window".to_string()),
            ]),
            plan: Plan::Mfp {
                input: Box::new(
                    Plan::Reduce {
                        input: Box::new(
                            Plan::Get {
                                id: crate::expr::Id::Global(GlobalId::User(1)),
                            }
                            .with_types(
                                RelationType::new(vec![
                                    ColumnType::new(ConcreteDataType::uint32_datatype(), false),
                                    ColumnType::new(
                                        ConcreteDataType::timestamp_millisecond_datatype(),
                                        false,
                                    ),
                                ])
                                .into_named(vec![
                                    Some("number".to_string()),
                                    Some("ts".to_string()),
                                ]),
                            )
                            .mfp(MapFilterProject::new(2).into_safe())
                            .unwrap(),
                        ),
                        key_val_plan: KeyValPlan {
                            key_plan: MapFilterProject::new(2)
                                .map(vec![ScalarExpr::CallDf {
                                    df_scalar_fn: DfScalarFunction::try_from_raw_fn(
                                        RawDfScalarFn {
                                            f: BytesMut::from(
                                                b"\x08\x02\"\x0f\x1a\r\n\x0b\xa2\x02\x08\n\0\x12\x04\x10\x1e \t\"\n\x1a\x08\x12\x06\n\x04\x12\x02\x08\x01".as_ref(),
                                            ),
                                            input_schema: RelationType::new(vec![
                                                ColumnType::new(
                                                    ConcreteDataType::interval_month_day_nano_datatype(),
                                                    true,
                                                ),
                                                ColumnType::new(
                                                    ConcreteDataType::timestamp_millisecond_datatype(),
                                                    false,
                                                ),
                                            ])
                                            .into_unnamed(),
                                            extensions: FunctionExtensions::from_iter([
                                                (0, "subtract".to_string()),
                                                (1, "divide".to_string()),
                                                (2, "date_bin".to_string()),
                                                (3, "max".to_string()),
                                                (4, "min".to_string()),
                                            ]),
                                        },
                                    )
                                    .await
                                    .unwrap(),
                                    exprs: vec![
                                        ScalarExpr::Literal(
                                            Value::IntervalMonthDayNano(
                                                IntervalMonthDayNano::new(0, 0, 30_000_000_000),
                                            ),
                                            CDT::interval_month_day_nano_datatype(),
                                        ),
                                        ScalarExpr::Column(1),
                                    ],
                                }])
                                .unwrap()
                                .project(vec![2])
                                .unwrap()
                                .into_safe(),
                            val_plan: MapFilterProject::new(2).into_safe(),
                        },
                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
                            full_aggrs: aggr_exprs.clone(),
                            simple_aggrs: vec![
                                AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0),
                                AggrWithIndex::new(aggr_exprs[1].clone(), 0, 1),
                            ],
                            distinct_aggrs: vec![],
                        }),
                    }
                    .with_types(
                        RelationType::new(vec![
                            ColumnType::new(
                                ConcreteDataType::timestamp_millisecond_datatype(),
                                true,
                            ), // time_window
                            ColumnType::new(ConcreteDataType::uint32_datatype(), true), // max(number)
                            ColumnType::new(ConcreteDataType::uint32_datatype(), true), // min(number)
                        ])
                        .with_time_index(Some(0))
                        .into_unnamed(),
                    ),
                ),
                mfp: MapFilterProject::new(3)
                    .map(vec![
                        ScalarExpr::Column(1)
                            .call_binary(ScalarExpr::Column(2), BinaryFunc::SubUInt32)
                            .cast(CDT::float64_datatype())
                            .call_binary(
                                ScalarExpr::Literal(Value::from(30.0f64), CDT::float64_datatype()),
                                BinaryFunc::DivFloat64,
                            ),
                        ScalarExpr::Column(0),
                    ])
                    .unwrap()
                    .project(vec![3, 4])
                    .unwrap(),
            },
        };

        assert_eq!(flow_plan.unwrap(), expected);
    }
}