flow/transform/
aggr.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use itertools::Itertools;
16use snafu::OptionExt;
17use substrait_proto::proto;
18use substrait_proto::proto::aggregate_function::AggregationInvocation;
19use substrait_proto::proto::aggregate_rel::{Grouping, Measure};
20use substrait_proto::proto::function_argument::ArgType;
21
22use crate::error::{Error, NotImplementedSnafu, PlanSnafu};
23use crate::expr::{
24    AggregateExpr, AggregateFunc, MapFilterProject, ScalarExpr, TypedExpr, UnaryFunc,
25};
26use crate::plan::{AccumulablePlan, AggrWithIndex, KeyValPlan, Plan, ReducePlan, TypedPlan};
27use crate::repr::{ColumnType, RelationDesc, RelationType};
28use crate::transform::{substrait_proto, FlownodeContext, FunctionExtensions};
29
30impl TypedExpr {
31    /// Allow `deprecated` due to the usage of deprecated grouping_expressions on datafusion to substrait side
32    #[allow(deprecated)]
33    async fn from_substrait_agg_grouping(
34        ctx: &mut FlownodeContext,
35        grouping_expressions: &[proto::Expression],
36        groupings: &[Grouping],
37        typ: &RelationDesc,
38        extensions: &FunctionExtensions,
39    ) -> Result<Vec<TypedExpr>, Error> {
40        let _ = ctx;
41        let mut group_expr = vec![];
42        match groupings.len() {
43            1 => {
44                // handle case when deprecated grouping_expressions is referenced by index is empty
45                let expressions: Box<dyn Iterator<Item = &proto::Expression> + Send> = if groupings
46                    [0]
47                .expression_references
48                .is_empty()
49                {
50                    Box::new(groupings[0].grouping_expressions.iter())
51                } else {
52                    if groupings[0]
53                        .expression_references
54                        .iter()
55                        .any(|idx| *idx as usize >= grouping_expressions.len())
56                    {
57                        return PlanSnafu {
58                            reason: format!("Invalid grouping expression reference: {:?} for grouping expr: {:?}", 
59                            groupings[0].expression_references,
60                            grouping_expressions
61                        ),
62                        }.fail()?;
63                    }
64                    Box::new(
65                        groupings[0]
66                            .expression_references
67                            .iter()
68                            .map(|idx| &grouping_expressions[*idx as usize]),
69                    )
70                };
71                for e in expressions {
72                    let x = TypedExpr::from_substrait_rex(e, typ, extensions).await?;
73                    group_expr.push(x);
74                }
75            }
76            _ => {
77                return not_impl_err!(
78                    "Grouping sets not support yet, use union all with group by instead."
79                );
80            }
81        };
82        Ok(group_expr)
83    }
84}
85
86impl AggregateExpr {
87    /// Convert list of `Measure` into Flow's AggregateExpr
88    ///
89    /// Return both the AggregateExpr and a MapFilterProject that is the final output of the aggregate function
90    async fn from_substrait_agg_measures(
91        ctx: &mut FlownodeContext,
92        measures: &[Measure],
93        typ: &RelationDesc,
94        extensions: &FunctionExtensions,
95    ) -> Result<Vec<AggregateExpr>, Error> {
96        let _ = ctx;
97        let mut all_aggr_exprs = vec![];
98
99        for m in measures {
100            let filter = match m
101                .filter
102                .as_ref()
103                .map(|fil| TypedExpr::from_substrait_rex(fil, typ, extensions))
104            {
105                Some(fut) => Some(fut.await),
106                None => None,
107            }
108            .transpose()?;
109
110            let aggr_expr = match &m.measure {
111                Some(f) => {
112                    let distinct = match f.invocation {
113                        _ if f.invocation == AggregationInvocation::Distinct as i32 => true,
114                        _ if f.invocation == AggregationInvocation::All as i32 => false,
115                        _ => false,
116                    };
117                    AggregateExpr::from_substrait_agg_func(
118                        f, typ, extensions, &filter, // TODO(discord9): impl order_by
119                        &None, distinct,
120                    )
121                    .await?
122                }
123                None => {
124                    return not_impl_err!("Aggregate without aggregate function is not supported")
125                }
126            };
127
128            all_aggr_exprs.extend(aggr_expr);
129        }
130
131        Ok(all_aggr_exprs)
132    }
133
134    /// Convert AggregateFunction into Flow's AggregateExpr
135    ///
136    /// the returned value is a tuple of AggregateExpr and a optional ScalarExpr that if exist is the final output of the aggregate function
137    /// since aggr functions like `avg` need to be transform to `sum(x)/cast(count(x) as x_type)`
138    pub async fn from_substrait_agg_func(
139        f: &proto::AggregateFunction,
140        input_schema: &RelationDesc,
141        extensions: &FunctionExtensions,
142        filter: &Option<TypedExpr>,
143        order_by: &Option<Vec<TypedExpr>>,
144        distinct: bool,
145    ) -> Result<Vec<AggregateExpr>, Error> {
146        // TODO(discord9): impl filter
147        let _ = filter;
148        let _ = order_by;
149        let mut args = vec![];
150        for arg in &f.arguments {
151            let arg_expr = match &arg.arg_type {
152                Some(ArgType::Value(e)) => {
153                    TypedExpr::from_substrait_rex(e, input_schema, extensions).await
154                }
155                _ => not_impl_err!("Aggregated function argument non-Value type not supported"),
156            }?;
157            args.push(arg_expr);
158        }
159
160        if args.len() != 1 {
161            let fn_name = extensions.get(&f.function_reference).cloned();
162            return not_impl_err!(
163                "Aggregated function (name={:?}) with multiple arguments is not supported",
164                fn_name
165            );
166        }
167
168        let arg = if let Some(first) = args.first() {
169            first
170        } else {
171            return not_impl_err!("Aggregated function without arguments is not supported");
172        };
173
174        let fn_name = extensions
175            .get(&f.function_reference)
176            .cloned()
177            .map(|s| s.to_lowercase());
178
179        match fn_name.as_ref().map(|s| s.as_ref()) {
180            Some(function_name) => {
181                let func = AggregateFunc::from_str_and_type(
182                    function_name,
183                    Some(arg.typ.scalar_type.clone()),
184                )?;
185                let exprs = vec![AggregateExpr {
186                    func,
187                    expr: arg.expr.clone(),
188                    distinct,
189                }];
190                Ok(exprs)
191            }
192            None => not_impl_err!(
193                "Aggregated function not found: function anchor = {:?}",
194                f.function_reference
195            ),
196        }
197    }
198}
199
200impl KeyValPlan {
201    /// Generate KeyValPlan from AggregateExpr and group_exprs
202    ///
203    /// will also change aggregate expr to use column ref if necessary
204    fn from_substrait_gen_key_val_plan(
205        aggr_exprs: &mut [AggregateExpr],
206        group_exprs: &[TypedExpr],
207        input_arity: usize,
208    ) -> Result<KeyValPlan, Error> {
209        let group_expr_val = group_exprs
210            .iter()
211            .cloned()
212            .map(|expr| expr.expr.clone())
213            .collect_vec();
214        let output_arity = group_expr_val.len();
215        let key_plan = MapFilterProject::new(input_arity)
216            .map(group_expr_val)?
217            .project(input_arity..input_arity + output_arity)?;
218
219        // val_plan is extracted from aggr_exprs to give aggr function it's necessary input
220        // and since aggr func need inputs that is column ref, we just add a prefix mfp to transform any expr that is not into a column ref
221        let val_plan = {
222            let need_mfp = aggr_exprs.iter().any(|agg| agg.expr.as_column().is_none());
223            if need_mfp {
224                // create mfp from aggr_expr, and modify aggr_expr to use the output column of mfp
225                let input_exprs = aggr_exprs
226                    .iter_mut()
227                    .enumerate()
228                    .map(|(idx, aggr)| {
229                        let ret = aggr.expr.clone();
230                        aggr.expr = ScalarExpr::Column(idx);
231                        ret
232                    })
233                    .collect_vec();
234                let aggr_arity = aggr_exprs.len();
235
236                MapFilterProject::new(input_arity)
237                    .map(input_exprs)?
238                    .project(input_arity..input_arity + aggr_arity)?
239            } else {
240                // simply take all inputs as value
241                MapFilterProject::new(input_arity)
242            }
243        };
244        Ok(KeyValPlan {
245            key_plan: key_plan.into_safe(),
246            val_plan: val_plan.into_safe(),
247        })
248    }
249}
250
251/// find out the column that should be time index in group exprs(which is all columns that should be keys)
252/// TODO(discord9): better ways to assign time index
253/// for now, it will found the first column that is timestamp or has a tumble window floor function
254fn find_time_index_in_group_exprs(group_exprs: &[TypedExpr]) -> Option<usize> {
255    group_exprs.iter().position(|expr| {
256        matches!(
257            &expr.expr,
258            ScalarExpr::CallUnary {
259                func: UnaryFunc::TumbleWindowFloor { .. },
260                expr: _
261            }
262        ) || expr.typ.scalar_type.is_timestamp()
263    })
264}
265
266impl TypedPlan {
267    /// Convert AggregateRel into Flow's TypedPlan
268    ///
269    /// The output of aggr plan is:
270    ///
271    /// <group_exprs>..<aggr_exprs>
272    #[async_recursion::async_recursion]
273    pub async fn from_substrait_agg_rel(
274        ctx: &mut FlownodeContext,
275        agg: &proto::AggregateRel,
276        extensions: &FunctionExtensions,
277    ) -> Result<TypedPlan, Error> {
278        let input = if let Some(input) = agg.input.as_ref() {
279            TypedPlan::from_substrait_rel(ctx, input, extensions).await?
280        } else {
281            return not_impl_err!("Aggregate without an input is not supported");
282        };
283
284        let group_exprs = TypedExpr::from_substrait_agg_grouping(
285            ctx,
286            &agg.grouping_expressions,
287            &agg.groupings,
288            &input.schema,
289            extensions,
290        )
291        .await?;
292
293        let time_index = find_time_index_in_group_exprs(&group_exprs);
294
295        let mut aggr_exprs = AggregateExpr::from_substrait_agg_measures(
296            ctx,
297            &agg.measures,
298            &input.schema,
299            extensions,
300        )
301        .await?;
302
303        let key_val_plan = KeyValPlan::from_substrait_gen_key_val_plan(
304            &mut aggr_exprs,
305            &group_exprs,
306            input.schema.typ.column_types.len(),
307        )?;
308
309        // output type is group_exprs + aggr_exprs
310        let output_type = {
311            let mut output_types = Vec::new();
312            // give best effort to get column name
313            let mut output_names = Vec::new();
314
315            // first append group_expr as key, then aggr_expr as value
316            for expr in group_exprs.iter() {
317                output_types.push(expr.typ.clone());
318                let col_name = match &expr.expr {
319                    ScalarExpr::Column(col) => input.schema.get_name(*col).clone(),
320                    // TODO(discord9): impl& use ScalarExpr.display_name, which recursively build expr's name
321                    _ => None,
322                };
323                output_names.push(col_name)
324            }
325
326            for aggr in &aggr_exprs {
327                output_types.push(ColumnType::new_nullable(
328                    aggr.func.signature().output.clone(),
329                ));
330                // TODO(discord9): find a clever way to name them?
331                output_names.push(None);
332            }
333            // TODO(discord9): try best to get time
334            if group_exprs.is_empty() {
335                RelationType::new(output_types)
336            } else {
337                RelationType::new(output_types).with_key((0..group_exprs.len()).collect_vec())
338            }
339            .with_time_index(time_index)
340            .into_named(output_names)
341        };
342
343        // copy aggr_exprs to full_aggrs, and split them into simple_aggrs and distinct_aggrs
344        // also set them input/output column
345        let full_aggrs = aggr_exprs;
346        let mut simple_aggrs = Vec::new();
347        let mut distinct_aggrs = Vec::new();
348        for (output_column, aggr_expr) in full_aggrs.iter().enumerate() {
349            let input_column = aggr_expr.expr.as_column().with_context(|| PlanSnafu {
350                reason: "Expect aggregate argument to be transformed into a column at this point",
351            })?;
352            if aggr_expr.distinct {
353                distinct_aggrs.push(AggrWithIndex::new(
354                    aggr_expr.clone(),
355                    input_column,
356                    output_column,
357                ));
358            } else {
359                simple_aggrs.push(AggrWithIndex::new(
360                    aggr_expr.clone(),
361                    input_column,
362                    output_column,
363                ));
364            }
365        }
366        let accum_plan = AccumulablePlan {
367            full_aggrs,
368            simple_aggrs,
369            distinct_aggrs,
370        };
371        let plan = Plan::Reduce {
372            input: Box::new(input),
373            key_val_plan,
374            reduce_plan: ReducePlan::Accumulable(accum_plan),
375        };
376        // FIX(discord9): deal with key first
377        return Ok(TypedPlan {
378            schema: output_type,
379            plan,
380        });
381    }
382}
383
384#[cfg(test)]
385mod test {
386    use std::time::Duration;
387
388    use bytes::BytesMut;
389    use common_time::{IntervalMonthDayNano, Timestamp};
390    use datatypes::prelude::ConcreteDataType;
391    use datatypes::value::Value;
392    use pretty_assertions::assert_eq;
393
394    use super::*;
395    use crate::expr::{BinaryFunc, DfScalarFunction, GlobalId, RawDfScalarFn};
396    use crate::plan::{Plan, TypedPlan};
397    use crate::repr::{ColumnType, RelationType};
398    use crate::transform::test::{create_test_ctx, create_test_query_engine, sql_to_substrait};
399    use crate::transform::CDT;
400
401    #[tokio::test]
402    async fn test_df_func_basic() {
403        let engine = create_test_query_engine();
404        let sql = "SELECT sum(abs(number)) FROM numbers_with_ts GROUP BY tumble(ts, '1 second', '2021-07-01 00:00:00');";
405        let plan = sql_to_substrait(engine.clone(), sql).await;
406
407        let mut ctx = create_test_ctx();
408        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
409            .await
410            .unwrap();
411
412        let aggr_expr = AggregateExpr {
413            func: AggregateFunc::SumUInt64,
414            expr: ScalarExpr::Column(0),
415            distinct: false,
416        };
417        let expected = TypedPlan {
418            schema: RelationType::new(vec![
419                ColumnType::new(CDT::uint64_datatype(), true), // sum(number)
420                ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
421                ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
422            ])
423            .with_key(vec![2])
424            .with_time_index(Some(1))
425            .into_named(vec![
426                Some("sum(abs(numbers_with_ts.number))".to_string()),
427                Some("window_start".to_string()),
428                Some("window_end".to_string()),
429            ]),
430            plan: Plan::Mfp {
431                input: Box::new(
432                    Plan::Reduce {
433                        input: Box::new(
434                            Plan::Get {
435                                id: crate::expr::Id::Global(GlobalId::User(1)),
436                            }
437                            .with_types(
438                                RelationType::new(vec![
439                                    ColumnType::new(ConcreteDataType::uint32_datatype(), false),
440                                    ColumnType::new(
441                                        ConcreteDataType::timestamp_millisecond_datatype(),
442                                        false,
443                                    ),
444                                ])
445                                .into_named(vec![
446                                    Some("number".to_string()),
447                                    Some("ts".to_string()),
448                                ]),
449                            )
450                            .mfp(MapFilterProject::new(2).into_safe())
451                            .unwrap(),
452                        ),
453                        key_val_plan: KeyValPlan {
454                            key_plan: MapFilterProject::new(2)
455                                .map(vec![
456                                    ScalarExpr::Column(1).call_unary(
457                                        UnaryFunc::TumbleWindowFloor {
458                                            window_size: Duration::from_nanos(1_000_000_000),
459                                            start_time: Some(Timestamp::new_millisecond(
460                                                1625097600000,
461                                            )),
462                                        },
463                                    ),
464                                    ScalarExpr::Column(1).call_unary(
465                                        UnaryFunc::TumbleWindowCeiling {
466                                            window_size: Duration::from_nanos(1_000_000_000),
467                                            start_time: Some(Timestamp::new_millisecond(
468                                                1625097600000,
469                                            )),
470                                        },
471                                    ),
472                                ])
473                                .unwrap()
474                                .project(vec![2, 3])
475                                .unwrap()
476                                .into_safe(),
477                            val_plan: MapFilterProject::new(2)
478                                .map(vec![ScalarExpr::CallDf {
479                                    df_scalar_fn: DfScalarFunction::try_from_raw_fn(
480                                        RawDfScalarFn {
481                                            f: BytesMut::from(
482                                                b"\x08\x02\"\x08\x1a\x06\x12\x04\n\x02\x12\0"
483                                                    .as_ref(),
484                                            ),
485                                            input_schema: RelationType::new(vec![ColumnType::new(
486                                                ConcreteDataType::uint32_datatype(),
487                                                false,
488                                            )])
489                                            .into_unnamed(),
490                                            extensions: FunctionExtensions::from_iter(
491                                                [
492                                                    (0, "tumble_start".to_string()),
493                                                    (1, "tumble_end".to_string()),
494                                                    (2, "abs".to_string()),
495                                                    (3, "sum".to_string()),
496                                                ]
497                                                .into_iter(),
498                                            ),
499                                        },
500                                    )
501                                    .await
502                                    .unwrap(),
503                                    exprs: vec![ScalarExpr::Column(0)],
504                                }
505                                .cast(CDT::uint64_datatype())])
506                                .unwrap()
507                                .project(vec![2])
508                                .unwrap()
509                                .into_safe(),
510                        },
511                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
512                            full_aggrs: vec![aggr_expr.clone()],
513                            simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
514                            distinct_aggrs: vec![],
515                        }),
516                    }
517                    .with_types(
518                        RelationType::new(vec![
519                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
520                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
521                            ColumnType::new(CDT::uint64_datatype(), true), //sum(number)
522                        ])
523                        .with_key(vec![1])
524                        .with_time_index(Some(0))
525                        .into_unnamed(),
526                    ),
527                ),
528                mfp: MapFilterProject::new(3)
529                    .map(vec![
530                        ScalarExpr::Column(2),
531                        ScalarExpr::Column(0),
532                        ScalarExpr::Column(1),
533                    ])
534                    .unwrap()
535                    .project(vec![3, 4, 5])
536                    .unwrap(),
537            },
538        };
539        assert_eq!(flow_plan, expected);
540    }
541
542    #[tokio::test]
543    async fn test_df_func_expr_tree() {
544        let engine = create_test_query_engine();
545        let sql = "SELECT abs(sum(number)) FROM numbers_with_ts GROUP BY tumble(ts, '1 second', '2021-07-01 00:00:00');";
546        let plan = sql_to_substrait(engine.clone(), sql).await;
547
548        let mut ctx = create_test_ctx();
549        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
550            .await
551            .unwrap();
552
553        let aggr_expr = AggregateExpr {
554            func: AggregateFunc::SumUInt64,
555            expr: ScalarExpr::Column(0),
556            distinct: false,
557        };
558        let expected = TypedPlan {
559            schema: RelationType::new(vec![
560                ColumnType::new(CDT::uint64_datatype(), true), // sum(number)
561                ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
562                ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
563            ])
564            .with_key(vec![2])
565            .with_time_index(Some(1))
566            .into_named(vec![
567                Some("abs(sum(numbers_with_ts.number))".to_string()),
568                Some("window_start".to_string()),
569                Some("window_end".to_string()),
570            ]),
571            plan: Plan::Mfp {
572                input: Box::new(
573                    Plan::Reduce {
574                        input: Box::new(
575                            Plan::Get {
576                                id: crate::expr::Id::Global(GlobalId::User(1)),
577                            }
578                            .with_types(
579                                RelationType::new(vec![
580                                    ColumnType::new(ConcreteDataType::uint32_datatype(), false),
581                                    ColumnType::new(
582                                        ConcreteDataType::timestamp_millisecond_datatype(),
583                                        false,
584                                    ),
585                                ])
586                                .into_named(vec![
587                                    Some("number".to_string()),
588                                    Some("ts".to_string()),
589                                ]),
590                            )
591                            .mfp(MapFilterProject::new(2).into_safe())
592                            .unwrap(),
593                        ),
594                        key_val_plan: KeyValPlan {
595                            key_plan: MapFilterProject::new(2)
596                                .map(vec![
597                                    ScalarExpr::Column(1).call_unary(
598                                        UnaryFunc::TumbleWindowFloor {
599                                            window_size: Duration::from_nanos(1_000_000_000),
600                                            start_time: Some(Timestamp::new_millisecond(
601                                                1625097600000,
602                                            )),
603                                        },
604                                    ),
605                                    ScalarExpr::Column(1).call_unary(
606                                        UnaryFunc::TumbleWindowCeiling {
607                                            window_size: Duration::from_nanos(1_000_000_000),
608                                            start_time: Some(Timestamp::new_millisecond(
609                                                1625097600000,
610                                            )),
611                                        },
612                                    ),
613                                ])
614                                .unwrap()
615                                .project(vec![2, 3])
616                                .unwrap()
617                                .into_safe(),
618                            val_plan: MapFilterProject::new(2)
619                                .map(vec![ScalarExpr::Column(0).cast(CDT::uint64_datatype())])
620                                .unwrap()
621                                .project(vec![2])
622                                .unwrap()
623                                .into_safe(),
624                        },
625                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
626                            full_aggrs: vec![aggr_expr.clone()],
627                            simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
628                            distinct_aggrs: vec![],
629                        }),
630                    }
631                    .with_types(
632                        RelationType::new(vec![
633                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
634                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
635                            ColumnType::new(CDT::uint64_datatype(), true), //sum(number)
636                        ])
637                        .with_key(vec![1])
638                        .with_time_index(Some(0))
639                        .into_named(vec![None, None, None]),
640                    ),
641                ),
642                mfp: MapFilterProject::new(3)
643                    .map(vec![
644                        ScalarExpr::CallDf {
645                            df_scalar_fn: DfScalarFunction::try_from_raw_fn(RawDfScalarFn {
646                                f: BytesMut::from(b"\"\x08\x1a\x06\x12\x04\n\x02\x12\0".as_ref()),
647                                input_schema: RelationType::new(vec![ColumnType::new(
648                                    ConcreteDataType::uint64_datatype(),
649                                    true,
650                                )])
651                                .into_unnamed(),
652                                extensions: FunctionExtensions::from_iter(
653                                    [
654                                        (0, "abs".to_string()),
655                                        (1, "tumble_start".to_string()),
656                                        (2, "tumble_end".to_string()),
657                                        (3, "sum".to_string()),
658                                    ]
659                                    .into_iter(),
660                                ),
661                            })
662                            .await
663                            .unwrap(),
664                            exprs: vec![ScalarExpr::Column(2)],
665                        },
666                        ScalarExpr::Column(0),
667                        ScalarExpr::Column(1),
668                    ])
669                    .unwrap()
670                    .project(vec![3, 4, 5])
671                    .unwrap(),
672            },
673        };
674        assert_eq!(flow_plan, expected);
675    }
676
677    /// TODO(discord9): add more illegal sql tests
678    #[tokio::test]
679    async fn test_tumble_composite() {
680        let engine = create_test_query_engine();
681        let sql =
682            "SELECT number, avg(number) FROM numbers_with_ts GROUP BY tumble(ts, '1 hour'), number";
683        let plan = sql_to_substrait(engine.clone(), sql).await;
684
685        let mut ctx = create_test_ctx();
686        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
687            .await
688            .unwrap();
689
690        let aggr_exprs = vec![
691            AggregateExpr {
692                func: AggregateFunc::SumUInt64,
693                expr: ScalarExpr::Column(0),
694                distinct: false,
695            },
696            AggregateExpr {
697                func: AggregateFunc::Count,
698                expr: ScalarExpr::Column(1),
699                distinct: false,
700            },
701        ];
702        let avg_expr = ScalarExpr::If {
703            cond: Box::new(ScalarExpr::Column(4).call_binary(
704                ScalarExpr::Literal(Value::from(0i64), CDT::int64_datatype()),
705                BinaryFunc::NotEq,
706            )),
707            then: Box::new(
708                ScalarExpr::Column(3)
709                    .cast(CDT::float64_datatype())
710                    .call_binary(
711                        ScalarExpr::Column(4).cast(CDT::float64_datatype()),
712                        BinaryFunc::DivFloat64,
713                    ),
714            ),
715            els: Box::new(ScalarExpr::Literal(Value::Null, CDT::float64_datatype())),
716        };
717        let expected = TypedPlan {
718            plan: Plan::Mfp {
719                input: Box::new(
720                    Plan::Reduce {
721                        input: Box::new(
722                            Plan::Get {
723                                id: crate::expr::Id::Global(GlobalId::User(1)),
724                            }
725                            .with_types(
726                                RelationType::new(vec![
727                                    ColumnType::new(ConcreteDataType::uint32_datatype(), false),
728                                    ColumnType::new(
729                                        ConcreteDataType::timestamp_millisecond_datatype(),
730                                        false,
731                                    ),
732                                ])
733                                .into_named(vec![
734                                    Some("number".to_string()),
735                                    Some("ts".to_string()),
736                                ]),
737                            )
738                            .mfp(MapFilterProject::new(2).into_safe())
739                            .unwrap(),
740                        ),
741                        key_val_plan: KeyValPlan {
742                            key_plan: MapFilterProject::new(2)
743                                .map(vec![
744                                    ScalarExpr::Column(1).call_unary(
745                                        UnaryFunc::TumbleWindowFloor {
746                                            window_size: Duration::from_nanos(3_600_000_000_000),
747                                            start_time: None,
748                                        },
749                                    ),
750                                    ScalarExpr::Column(1).call_unary(
751                                        UnaryFunc::TumbleWindowCeiling {
752                                            window_size: Duration::from_nanos(3_600_000_000_000),
753                                            start_time: None,
754                                        },
755                                    ),
756                                    ScalarExpr::Column(0),
757                                ])
758                                .unwrap()
759                                .project(vec![2, 3, 4])
760                                .unwrap()
761                                .into_safe(),
762                            val_plan: MapFilterProject::new(2)
763                                .map(vec![
764                                    ScalarExpr::Column(0).cast(CDT::uint64_datatype()),
765                                    ScalarExpr::Column(0),
766                                ])
767                                .unwrap()
768                                .project(vec![2, 3])
769                                .unwrap()
770                                .into_safe(),
771                        },
772                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
773                            full_aggrs: aggr_exprs.clone(),
774                            simple_aggrs: vec![
775                                AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0),
776                                AggrWithIndex::new(aggr_exprs[1].clone(), 1, 1),
777                            ],
778                            distinct_aggrs: vec![],
779                        }),
780                    }
781                    .with_types(
782                        RelationType::new(vec![
783                            // keys
784                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start(time index)
785                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end(pk)
786                            ColumnType::new(CDT::uint32_datatype(), false), // number(pk)
787                            // values
788                            ColumnType::new(CDT::uint64_datatype(), true), // avg.sum(number)
789                            ColumnType::new(CDT::int64_datatype(), true),  // avg.count(number)
790                        ])
791                        .with_key(vec![1, 2])
792                        .with_time_index(Some(0))
793                        .into_named(vec![
794                            None,
795                            None,
796                            Some("number".to_string()),
797                            None,
798                            None,
799                        ]),
800                    ),
801                ),
802                mfp: MapFilterProject::new(5)
803                    .map(vec![
804                        ScalarExpr::Column(2), // number(pk)
805                        avg_expr,
806                        ScalarExpr::Column(0), // window start
807                        ScalarExpr::Column(1), // window end
808                    ])
809                    .unwrap()
810                    .project(vec![5, 6, 7, 8])
811                    .unwrap(),
812            },
813            schema: RelationType::new(vec![
814                ColumnType::new(CDT::uint32_datatype(), false), // number
815                ColumnType::new(CDT::float64_datatype(), true), // avg(number)
816                ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
817                ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
818            ])
819            .with_key(vec![0, 3])
820            .with_time_index(Some(2))
821            .into_named(vec![
822                Some("number".to_string()),
823                Some("avg(numbers_with_ts.number)".to_string()),
824                Some("window_start".to_string()),
825                Some("window_end".to_string()),
826            ]),
827        };
828        assert_eq!(flow_plan, expected);
829    }
830
831    #[tokio::test]
832    async fn test_tumble_parse_optional() {
833        let engine = create_test_query_engine();
834        let sql = "SELECT sum(number) FROM numbers_with_ts GROUP BY tumble(ts, '1 hour')";
835        let plan = sql_to_substrait(engine.clone(), sql).await;
836
837        let mut ctx = create_test_ctx();
838        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
839            .await
840            .unwrap();
841
842        let aggr_expr = AggregateExpr {
843            func: AggregateFunc::SumUInt64,
844            expr: ScalarExpr::Column(0),
845            distinct: false,
846        };
847        let expected = TypedPlan {
848            schema: RelationType::new(vec![
849                ColumnType::new(CDT::uint64_datatype(), true), // sum(number)
850                ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
851                ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
852            ])
853            .with_key(vec![2])
854            .with_time_index(Some(1))
855            .into_named(vec![
856                Some("sum(numbers_with_ts.number)".to_string()),
857                Some("window_start".to_string()),
858                Some("window_end".to_string()),
859            ]),
860            plan: Plan::Mfp {
861                input: Box::new(
862                    Plan::Reduce {
863                        input: Box::new(
864                            Plan::Get {
865                                id: crate::expr::Id::Global(GlobalId::User(1)),
866                            }
867                            .with_types(
868                                RelationType::new(vec![
869                                    ColumnType::new(ConcreteDataType::uint32_datatype(), false),
870                                    ColumnType::new(
871                                        ConcreteDataType::timestamp_millisecond_datatype(),
872                                        false,
873                                    ),
874                                ])
875                                .into_named(vec![
876                                    Some("number".to_string()),
877                                    Some("ts".to_string()),
878                                ]),
879                            )
880                            .mfp(MapFilterProject::new(2).into_safe())
881                            .unwrap(),
882                        ),
883                        key_val_plan: KeyValPlan {
884                            key_plan: MapFilterProject::new(2)
885                                .map(vec![
886                                    ScalarExpr::Column(1).call_unary(
887                                        UnaryFunc::TumbleWindowFloor {
888                                            window_size: Duration::from_nanos(3_600_000_000_000),
889                                            start_time: None,
890                                        },
891                                    ),
892                                    ScalarExpr::Column(1).call_unary(
893                                        UnaryFunc::TumbleWindowCeiling {
894                                            window_size: Duration::from_nanos(3_600_000_000_000),
895                                            start_time: None,
896                                        },
897                                    ),
898                                ])
899                                .unwrap()
900                                .project(vec![2, 3])
901                                .unwrap()
902                                .into_safe(),
903                            val_plan: MapFilterProject::new(2)
904                                .map(vec![ScalarExpr::Column(0).cast(CDT::uint64_datatype())])
905                                .unwrap()
906                                .project(vec![2])
907                                .unwrap()
908                                .into_safe(),
909                        },
910                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
911                            full_aggrs: vec![aggr_expr.clone()],
912                            simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
913                            distinct_aggrs: vec![],
914                        }),
915                    }
916                    .with_types(
917                        RelationType::new(vec![
918                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
919                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
920                            ColumnType::new(CDT::uint64_datatype(), true), //sum(number)
921                        ])
922                        .with_key(vec![1])
923                        .with_time_index(Some(0))
924                        .into_named(vec![None, None, None]),
925                    ),
926                ),
927                mfp: MapFilterProject::new(3)
928                    .map(vec![
929                        ScalarExpr::Column(2),
930                        ScalarExpr::Column(0),
931                        ScalarExpr::Column(1),
932                    ])
933                    .unwrap()
934                    .project(vec![3, 4, 5])
935                    .unwrap(),
936            },
937        };
938        assert_eq!(flow_plan, expected);
939    }
940
941    #[tokio::test]
942    async fn test_tumble_parse() {
943        let engine = create_test_query_engine();
944        let sql = "SELECT sum(number) FROM numbers_with_ts GROUP BY tumble(ts, '1 hour', '2021-07-01 00:00:00')";
945        let plan = sql_to_substrait(engine.clone(), sql).await;
946
947        let mut ctx = create_test_ctx();
948        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
949            .await
950            .unwrap();
951
952        let aggr_expr = AggregateExpr {
953            func: AggregateFunc::SumUInt64,
954            expr: ScalarExpr::Column(0),
955            distinct: false,
956        };
957        let expected = TypedPlan {
958            schema: RelationType::new(vec![
959                ColumnType::new(CDT::uint64_datatype(), true), // sum(number)
960                ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
961                ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
962            ])
963            .with_key(vec![2])
964            .with_time_index(Some(1))
965            .into_named(vec![
966                Some("sum(numbers_with_ts.number)".to_string()),
967                Some("window_start".to_string()),
968                Some("window_end".to_string()),
969            ]),
970            plan: Plan::Mfp {
971                input: Box::new(
972                    Plan::Reduce {
973                        input: Box::new(
974                            Plan::Get {
975                                id: crate::expr::Id::Global(GlobalId::User(1)),
976                            }
977                            .with_types(
978                                RelationType::new(vec![
979                                    ColumnType::new(ConcreteDataType::uint32_datatype(), false),
980                                    ColumnType::new(
981                                        ConcreteDataType::timestamp_millisecond_datatype(),
982                                        false,
983                                    ),
984                                ])
985                                .into_named(vec![
986                                    Some("number".to_string()),
987                                    Some("ts".to_string()),
988                                ]),
989                            )
990                            .mfp(MapFilterProject::new(2).into_safe())
991                            .unwrap(),
992                        ),
993                        key_val_plan: KeyValPlan {
994                            key_plan: MapFilterProject::new(2)
995                                .map(vec![
996                                    ScalarExpr::Column(1).call_unary(
997                                        UnaryFunc::TumbleWindowFloor {
998                                            window_size: Duration::from_nanos(3_600_000_000_000),
999                                            start_time: Some(Timestamp::new_millisecond(
1000                                                1625097600000,
1001                                            )),
1002                                        },
1003                                    ),
1004                                    ScalarExpr::Column(1).call_unary(
1005                                        UnaryFunc::TumbleWindowCeiling {
1006                                            window_size: Duration::from_nanos(3_600_000_000_000),
1007                                            start_time: Some(Timestamp::new_millisecond(
1008                                                1625097600000,
1009                                            )),
1010                                        },
1011                                    ),
1012                                ])
1013                                .unwrap()
1014                                .project(vec![2, 3])
1015                                .unwrap()
1016                                .into_safe(),
1017                            val_plan: MapFilterProject::new(2)
1018                                .map(vec![ScalarExpr::Column(0).cast(CDT::uint64_datatype())])
1019                                .unwrap()
1020                                .project(vec![2])
1021                                .unwrap()
1022                                .into_safe(),
1023                        },
1024                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
1025                            full_aggrs: vec![aggr_expr.clone()],
1026                            simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
1027                            distinct_aggrs: vec![],
1028                        }),
1029                    }
1030                    .with_types(
1031                        RelationType::new(vec![
1032                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
1033                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
1034                            ColumnType::new(CDT::uint64_datatype(), true), //sum(number)
1035                        ])
1036                        .with_key(vec![1])
1037                        .with_time_index(Some(0))
1038                        .into_unnamed(),
1039                    ),
1040                ),
1041                mfp: MapFilterProject::new(3)
1042                    .map(vec![
1043                        ScalarExpr::Column(2),
1044                        ScalarExpr::Column(0),
1045                        ScalarExpr::Column(1),
1046                    ])
1047                    .unwrap()
1048                    .project(vec![3, 4, 5])
1049                    .unwrap(),
1050            },
1051        };
1052        assert_eq!(flow_plan, expected);
1053    }
1054
1055    #[tokio::test]
1056    async fn test_avg_group_by() {
1057        let engine = create_test_query_engine();
1058        let sql = "SELECT avg(number), number FROM numbers GROUP BY number";
1059        let plan = sql_to_substrait(engine.clone(), sql).await;
1060
1061        let mut ctx = create_test_ctx();
1062        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;
1063
1064        let aggr_exprs = vec![
1065            AggregateExpr {
1066                func: AggregateFunc::SumUInt64,
1067                expr: ScalarExpr::Column(0),
1068                distinct: false,
1069            },
1070            AggregateExpr {
1071                func: AggregateFunc::Count,
1072                expr: ScalarExpr::Column(1),
1073                distinct: false,
1074            },
1075        ];
1076        let avg_expr = ScalarExpr::If {
1077            cond: Box::new(ScalarExpr::Column(2).call_binary(
1078                ScalarExpr::Literal(Value::from(0i64), CDT::int64_datatype()),
1079                BinaryFunc::NotEq,
1080            )),
1081            then: Box::new(
1082                ScalarExpr::Column(1)
1083                    .cast(CDT::float64_datatype())
1084                    .call_binary(
1085                        ScalarExpr::Column(2).cast(CDT::float64_datatype()),
1086                        BinaryFunc::DivFloat64,
1087                    ),
1088            ),
1089            els: Box::new(ScalarExpr::Literal(Value::Null, CDT::float64_datatype())),
1090        };
1091        let expected = TypedPlan {
1092            schema: RelationType::new(vec![
1093                ColumnType::new(CDT::float64_datatype(), true), // avg(number: u32) -> f64
1094                ColumnType::new(CDT::uint32_datatype(), false), // number
1095            ])
1096            .with_key(vec![1])
1097            .into_named(vec![
1098                Some("avg(numbers.number)".to_string()),
1099                Some("number".to_string()),
1100            ]),
1101            plan: Plan::Mfp {
1102                input: Box::new(
1103                    Plan::Reduce {
1104                        input: Box::new(
1105                            Plan::Get {
1106                                id: crate::expr::Id::Global(GlobalId::User(0)),
1107                            }
1108                            .with_types(
1109                                RelationType::new(vec![ColumnType::new(
1110                                    ConcreteDataType::uint32_datatype(),
1111                                    false,
1112                                )])
1113                                .into_named(vec![Some("number".to_string())]),
1114                            )
1115                            .mfp(
1116                                MapFilterProject::new(1)
1117                                    .project(vec![0])
1118                                    .unwrap()
1119                                    .into_safe(),
1120                            )
1121                            .unwrap(),
1122                        ),
1123                        key_val_plan: KeyValPlan {
1124                            key_plan: MapFilterProject::new(1)
1125                                .map(vec![ScalarExpr::Column(0)])
1126                                .unwrap()
1127                                .project(vec![1])
1128                                .unwrap()
1129                                .into_safe(),
1130                            val_plan: MapFilterProject::new(1)
1131                                .map(vec![
1132                                    ScalarExpr::Column(0).cast(CDT::uint64_datatype()),
1133                                    ScalarExpr::Column(0),
1134                                ])
1135                                .unwrap()
1136                                .project(vec![1, 2])
1137                                .unwrap()
1138                                .into_safe(),
1139                        },
1140                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
1141                            full_aggrs: aggr_exprs.clone(),
1142                            simple_aggrs: vec![
1143                                AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0),
1144                                AggrWithIndex::new(aggr_exprs[1].clone(), 1, 1),
1145                            ],
1146                            distinct_aggrs: vec![],
1147                        }),
1148                    }
1149                    .with_types(
1150                        RelationType::new(vec![
1151                            ColumnType::new(ConcreteDataType::uint32_datatype(), false), // key: number
1152                            ColumnType::new(ConcreteDataType::uint64_datatype(), true),  // sum
1153                            ColumnType::new(ConcreteDataType::int64_datatype(), true),   // count
1154                        ])
1155                        .with_key(vec![0])
1156                        .into_named(vec![
1157                            Some("number".to_string()),
1158                            None,
1159                            None,
1160                        ]),
1161                    ),
1162                ),
1163                mfp: MapFilterProject::new(3)
1164                    .map(vec![
1165                        avg_expr, // col 3
1166                        ScalarExpr::Column(0),
1167                        // TODO(discord9): optimize mfp so to remove indirect ref
1168                    ])
1169                    .unwrap()
1170                    .project(vec![3, 4])
1171                    .unwrap(),
1172            },
1173        };
1174        assert_eq!(flow_plan.unwrap(), expected);
1175    }
1176
1177    #[tokio::test]
1178    async fn test_avg() {
1179        let engine = create_test_query_engine();
1180        let sql = "SELECT avg(number) FROM numbers";
1181        let plan = sql_to_substrait(engine.clone(), sql).await;
1182
1183        let mut ctx = create_test_ctx();
1184
1185        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
1186            .await
1187            .unwrap();
1188
1189        let aggr_exprs = vec![
1190            AggregateExpr {
1191                func: AggregateFunc::SumUInt64,
1192                expr: ScalarExpr::Column(0),
1193                distinct: false,
1194            },
1195            AggregateExpr {
1196                func: AggregateFunc::Count,
1197                expr: ScalarExpr::Column(1),
1198                distinct: false,
1199            },
1200        ];
1201        let avg_expr = ScalarExpr::If {
1202            cond: Box::new(ScalarExpr::Column(1).call_binary(
1203                ScalarExpr::Literal(Value::from(0i64), CDT::int64_datatype()),
1204                BinaryFunc::NotEq,
1205            )),
1206            then: Box::new(
1207                ScalarExpr::Column(0)
1208                    .cast(CDT::float64_datatype())
1209                    .call_binary(
1210                        ScalarExpr::Column(1).cast(CDT::float64_datatype()),
1211                        BinaryFunc::DivFloat64,
1212                    ),
1213            ),
1214            els: Box::new(ScalarExpr::Literal(Value::Null, CDT::float64_datatype())),
1215        };
1216        let input = Box::new(
1217            Plan::Get {
1218                id: crate::expr::Id::Global(GlobalId::User(0)),
1219            }
1220            .with_types(
1221                RelationType::new(vec![ColumnType::new(
1222                    ConcreteDataType::uint32_datatype(),
1223                    false,
1224                )])
1225                .into_named(vec![Some("number".to_string())]),
1226            ),
1227        );
1228        let expected = TypedPlan {
1229            schema: RelationType::new(vec![ColumnType::new(CDT::float64_datatype(), true)])
1230                .into_named(vec![Some("avg(numbers.number)".to_string())]),
1231            plan: Plan::Mfp {
1232                input: Box::new(
1233                    Plan::Reduce {
1234                        input: Box::new(
1235                            Plan::Mfp {
1236                                input: input.clone(),
1237                                mfp: MapFilterProject::new(1).project(vec![0]).unwrap(),
1238                            }
1239                            .with_types(
1240                                RelationType::new(vec![ColumnType::new(
1241                                    CDT::uint32_datatype(),
1242                                    false,
1243                                )])
1244                                .into_named(vec![Some("number".to_string())]),
1245                            ),
1246                        ),
1247                        key_val_plan: KeyValPlan {
1248                            key_plan: MapFilterProject::new(1)
1249                                .project(vec![])
1250                                .unwrap()
1251                                .into_safe(),
1252                            val_plan: MapFilterProject::new(1)
1253                                .map(vec![
1254                                    ScalarExpr::Column(0).cast(CDT::uint64_datatype()),
1255                                    ScalarExpr::Column(0),
1256                                ])
1257                                .unwrap()
1258                                .project(vec![1, 2])
1259                                .unwrap()
1260                                .into_safe(),
1261                        },
1262                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
1263                            full_aggrs: aggr_exprs.clone(),
1264                            simple_aggrs: vec![
1265                                AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0),
1266                                AggrWithIndex::new(aggr_exprs[1].clone(), 1, 1),
1267                            ],
1268                            distinct_aggrs: vec![],
1269                        }),
1270                    }
1271                    .with_types(
1272                        RelationType::new(vec![
1273                            ColumnType::new(ConcreteDataType::uint64_datatype(), true), // sum
1274                            ColumnType::new(ConcreteDataType::int64_datatype(), true),  // count
1275                        ])
1276                        .into_named(vec![None, None]),
1277                    ),
1278                ),
1279                mfp: MapFilterProject::new(2)
1280                    .map(vec![
1281                        avg_expr,
1282                        // TODO(discord9): optimize mfp so to remove indirect ref
1283                    ])
1284                    .unwrap()
1285                    .project(vec![2])
1286                    .unwrap(),
1287            },
1288        };
1289        assert_eq!(flow_plan, expected);
1290    }
1291
1292    #[tokio::test]
1293    async fn test_sum() {
1294        let engine = create_test_query_engine();
1295        let sql = "SELECT sum(number) FROM numbers";
1296        let plan = sql_to_substrait(engine.clone(), sql).await;
1297
1298        let mut ctx = create_test_ctx();
1299        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;
1300
1301        let aggr_expr = AggregateExpr {
1302            func: AggregateFunc::SumUInt64,
1303            expr: ScalarExpr::Column(0),
1304            distinct: false,
1305        };
1306        let expected = TypedPlan {
1307            schema: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)])
1308                .into_named(vec![Some("sum(numbers.number)".to_string())]),
1309            plan: Plan::Reduce {
1310                input: Box::new(
1311                    Plan::Get {
1312                        id: crate::expr::Id::Global(GlobalId::User(0)),
1313                    }
1314                    .with_types(
1315                        RelationType::new(vec![ColumnType::new(
1316                            ConcreteDataType::uint32_datatype(),
1317                            false,
1318                        )])
1319                        .into_named(vec![Some("number".to_string())]),
1320                    )
1321                    .mfp(MapFilterProject::new(1).into_safe())
1322                    .unwrap(),
1323                ),
1324                key_val_plan: KeyValPlan {
1325                    key_plan: MapFilterProject::new(1)
1326                        .project(vec![])
1327                        .unwrap()
1328                        .into_safe(),
1329                    val_plan: MapFilterProject::new(1)
1330                        .map(vec![ScalarExpr::Column(0)
1331                            .call_unary(UnaryFunc::Cast(CDT::uint64_datatype()))])
1332                        .unwrap()
1333                        .project(vec![1])
1334                        .unwrap()
1335                        .into_safe(),
1336                },
1337                reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
1338                    full_aggrs: vec![aggr_expr.clone()],
1339                    simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
1340                    distinct_aggrs: vec![],
1341                }),
1342            },
1343        };
1344        assert_eq!(flow_plan.unwrap(), expected);
1345    }
1346
1347    #[tokio::test]
1348    async fn test_distinct_number() {
1349        let engine = create_test_query_engine();
1350        let sql = "SELECT DISTINCT number FROM numbers";
1351        let plan = sql_to_substrait(engine.clone(), sql).await;
1352
1353        let mut ctx = create_test_ctx();
1354        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
1355            .await
1356            .unwrap();
1357
1358        let expected = TypedPlan {
1359            schema: RelationType::new(vec![
1360                ColumnType::new(CDT::uint32_datatype(), false), // col number
1361            ])
1362            .with_key(vec![0])
1363            .into_named(vec![Some("number".to_string())]),
1364            plan: Plan::Reduce {
1365                input: Box::new(
1366                    Plan::Get {
1367                        id: crate::expr::Id::Global(GlobalId::User(0)),
1368                    }
1369                    .with_types(
1370                        RelationType::new(vec![ColumnType::new(
1371                            ConcreteDataType::uint32_datatype(),
1372                            false,
1373                        )])
1374                        .into_named(vec![Some("number".to_string())]),
1375                    )
1376                    .mfp(MapFilterProject::new(1).into_safe())
1377                    .unwrap(),
1378                ),
1379                key_val_plan: KeyValPlan {
1380                    key_plan: MapFilterProject::new(1)
1381                        .map(vec![ScalarExpr::Column(0)])
1382                        .unwrap()
1383                        .project(vec![1])
1384                        .unwrap()
1385                        .into_safe(),
1386                    val_plan: MapFilterProject::new(1)
1387                        .project(vec![0])
1388                        .unwrap()
1389                        .into_safe(),
1390                },
1391                reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
1392                    full_aggrs: vec![],
1393                    simple_aggrs: vec![],
1394                    distinct_aggrs: vec![],
1395                }),
1396            },
1397        };
1398
1399        assert_eq!(flow_plan, expected);
1400    }
1401
1402    #[tokio::test]
1403    async fn test_sum_group_by() {
1404        let engine = create_test_query_engine();
1405        let sql = "SELECT sum(number), number FROM numbers GROUP BY number";
1406        let plan = sql_to_substrait(engine.clone(), sql).await;
1407
1408        let mut ctx = create_test_ctx();
1409        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
1410            .await
1411            .unwrap();
1412
1413        let aggr_expr = AggregateExpr {
1414            func: AggregateFunc::SumUInt64,
1415            expr: ScalarExpr::Column(0),
1416            distinct: false,
1417        };
1418        let expected = TypedPlan {
1419            schema: RelationType::new(vec![
1420                ColumnType::new(CDT::uint64_datatype(), true), // col sum(number)
1421                ColumnType::new(CDT::uint32_datatype(), false), // col number
1422            ])
1423            .with_key(vec![1])
1424            .into_named(vec![
1425                Some("sum(numbers.number)".to_string()),
1426                Some("number".to_string()),
1427            ]),
1428            plan: Plan::Mfp {
1429                input: Box::new(
1430                    Plan::Reduce {
1431                        input: Box::new(
1432                            Plan::Get {
1433                                id: crate::expr::Id::Global(GlobalId::User(0)),
1434                            }
1435                            .with_types(
1436                                RelationType::new(vec![ColumnType::new(
1437                                    ConcreteDataType::uint32_datatype(),
1438                                    false,
1439                                )])
1440                                .into_named(vec![Some("number".to_string())]),
1441                            )
1442                            .mfp(MapFilterProject::new(1).into_safe())
1443                            .unwrap(),
1444                        ),
1445                        key_val_plan: KeyValPlan {
1446                            key_plan: MapFilterProject::new(1)
1447                                .map(vec![ScalarExpr::Column(0)])
1448                                .unwrap()
1449                                .project(vec![1])
1450                                .unwrap()
1451                                .into_safe(),
1452                            val_plan: MapFilterProject::new(1)
1453                                .map(vec![ScalarExpr::Column(0)
1454                                    .call_unary(UnaryFunc::Cast(CDT::uint64_datatype()))])
1455                                .unwrap()
1456                                .project(vec![1])
1457                                .unwrap()
1458                                .into_safe(),
1459                        },
1460                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
1461                            full_aggrs: vec![aggr_expr.clone()],
1462                            simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
1463                            distinct_aggrs: vec![],
1464                        }),
1465                    }
1466                    .with_types(
1467                        RelationType::new(vec![
1468                            ColumnType::new(CDT::uint32_datatype(), false), // col number
1469                            ColumnType::new(CDT::uint64_datatype(), true),  // col sum(number)
1470                        ])
1471                        .with_key(vec![0])
1472                        .into_named(vec![Some("number".to_string()), None]),
1473                    ),
1474                ),
1475                mfp: MapFilterProject::new(2)
1476                    .map(vec![ScalarExpr::Column(1), ScalarExpr::Column(0)])
1477                    .unwrap()
1478                    .project(vec![2, 3])
1479                    .unwrap(),
1480            },
1481        };
1482
1483        assert_eq!(flow_plan, expected);
1484    }
1485
1486    #[tokio::test]
1487    async fn test_sum_add() {
1488        let engine = create_test_query_engine();
1489        let sql = "SELECT sum(number+number) FROM numbers";
1490        let plan = sql_to_substrait(engine.clone(), sql).await;
1491
1492        let mut ctx = create_test_ctx();
1493        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;
1494
1495        let aggr_expr = AggregateExpr {
1496            func: AggregateFunc::SumUInt64,
1497            expr: ScalarExpr::Column(0),
1498            distinct: false,
1499        };
1500        let expected = TypedPlan {
1501            schema: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)])
1502                .into_named(vec![Some(
1503                    "sum(numbers.number + numbers.number)".to_string(),
1504                )]),
1505            plan: Plan::Reduce {
1506                input: Box::new(
1507                    Plan::Mfp {
1508                        input: Box::new(
1509                            Plan::Get {
1510                                id: crate::expr::Id::Global(GlobalId::User(0)),
1511                            }
1512                            .with_types(
1513                                RelationType::new(vec![ColumnType::new(
1514                                    ConcreteDataType::uint32_datatype(),
1515                                    false,
1516                                )])
1517                                .into_named(vec![Some("number".to_string())]),
1518                            ),
1519                        ),
1520                        mfp: MapFilterProject::new(1),
1521                    }
1522                    .with_types(
1523                        RelationType::new(vec![ColumnType::new(
1524                            ConcreteDataType::uint32_datatype(),
1525                            false,
1526                        )])
1527                        .into_named(vec![Some("number".to_string())]),
1528                    ),
1529                ),
1530                key_val_plan: KeyValPlan {
1531                    key_plan: MapFilterProject::new(1)
1532                        .project(vec![])
1533                        .unwrap()
1534                        .into_safe(),
1535                    val_plan: MapFilterProject::new(1)
1536                        .map(vec![ScalarExpr::Column(0)
1537                            .call_binary(ScalarExpr::Column(0), BinaryFunc::AddUInt32)
1538                            .call_unary(UnaryFunc::Cast(CDT::uint64_datatype()))])
1539                        .unwrap()
1540                        .project(vec![1])
1541                        .unwrap()
1542                        .into_safe(),
1543                },
1544                reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
1545                    full_aggrs: vec![aggr_expr.clone()],
1546                    simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
1547                    distinct_aggrs: vec![],
1548                }),
1549            },
1550        };
1551        assert_eq!(flow_plan.unwrap(), expected);
1552    }
1553
    #[tokio::test]
    async fn test_cast_max_min() {
        // Windowed aggregate: `date_bin(30s, ts)` becomes the group key of a `Reduce`, `max`
        // and `min` over `number` are accumulable aggregates, and the arithmetic on their
        // results (`(max - min) / 30.0`) is applied by an `Mfp` placed after the `Reduce`.
        let engine = create_test_query_engine();
        let sql = "SELECT (max(number) - min(number))/30.0, date_bin(INTERVAL '30 second', CAST(ts AS TimestampMillisecond)) as time_window from numbers_with_ts GROUP BY time_window";
        let plan = sql_to_substrait(engine.clone(), sql).await;

        let mut ctx = create_test_ctx();
        // Left as a `Result`; it is unwrapped in the final assertion.
        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;

        // Both aggregates read column 0 of the value plan, i.e. `number` (the value plan
        // below applies no map/filter/projection to the two input columns).
        let aggr_exprs = vec![
            AggregateExpr {
                func: AggregateFunc::MaxUInt32,
                expr: ScalarExpr::Column(0),
                distinct: false,
            },
            AggregateExpr {
                func: AggregateFunc::MinUInt32,
                expr: ScalarExpr::Column(0),
                distinct: false,
            },
        ];
        let expected = TypedPlan {
            // Output in SELECT order: [f64 ratio, time_window], with time_window as time index.
            schema: RelationType::new(vec![
                ColumnType::new(CDT::float64_datatype(), true),
                ColumnType::new(CDT::timestamp_millisecond_datatype(), true),
            ])
            .with_time_index(Some(1))
            .into_named(vec![
                // Expression display name; note it is rendered without the parentheses the
                // SQL text has (presumably as printed by the SQL planner).
                Some(
                    "max(numbers_with_ts.number) - min(numbers_with_ts.number) / Float64(30)"
                        .to_string(),
                ),
                Some("time_window".to_string()),
            ]),
            plan: Plan::Mfp {
                input: Box::new(
                    Plan::Reduce {
                        // Scan of `numbers_with_ts` (global id 1): (number: u32, ts: timestamp_ms).
                        input: Box::new(
                            Plan::Get {
                                id: crate::expr::Id::Global(GlobalId::User(1)),
                            }
                            .with_types(
                                RelationType::new(vec![
                                    ColumnType::new(ConcreteDataType::uint32_datatype(), false),
                                    ColumnType::new(ConcreteDataType::timestamp_millisecond_datatype(), false),
                                ])
                                .into_named(vec![
                                    Some("number".to_string()),
                                    Some("ts".to_string()),
                                ]),
                            )
                            .mfp(MapFilterProject::new(2).into_safe())
                            .unwrap(),
                        ),

                        key_val_plan: KeyValPlan {
                            // Key: `date_bin(<30s interval literal>, ts)` evaluated through a
                            // DataFusion scalar function (`CallDf`), appended as column 2 and
                            // then projected as the sole key column.
                            key_plan: MapFilterProject::new(2)
                                .map(vec![ScalarExpr::CallDf {
                                    df_scalar_fn: DfScalarFunction::try_from_raw_fn(
                                        RawDfScalarFn {
                                            // Opaque encoded function body. NOTE(review): presumably
                                            // the serialized substrait expression for `date_bin`
                                            // produced by the planner; regenerate if it changes.
                                            f: BytesMut::from(
                                                b"\x08\x02\"\x0f\x1a\r\n\x0b\xa2\x02\x08\n\0\x12\x04\x10\x1e \t\"\n\x1a\x08\x12\x06\n\x04\x12\x02\x08\x01".as_ref(),
                                            ),
                                            // Argument types: the interval literal, then `ts`.
                                            input_schema: RelationType::new(vec![ColumnType::new(
                                                ConcreteDataType::interval_month_day_nano_datatype(),
                                                true,
                                            ),ColumnType::new(
                                                ConcreteDataType::timestamp_millisecond_datatype(),
                                                false,
                                            )])
                                            .into_unnamed(),
                                            // Function anchor table; presumably mirrors the
                                            // extensions registered in the planned substrait query.
                                            extensions: FunctionExtensions::from_iter([
                                                    (0, "subtract".to_string()),
                                                    (1, "divide".to_string()),
                                                    (2, "date_bin".to_string()),
                                                    (3, "max".to_string()),
                                                    (4, "min".to_string()),
                                                ]),
                                        },
                                    )
                                    .await
                                    .unwrap(),
                                    exprs: vec![
                                        // (months, days, nanoseconds) = 30_000_000_000 ns, i.e. the
                                        // `INTERVAL '30 second'` from the query.
                                        ScalarExpr::Literal(
                                            Value::IntervalMonthDayNano(IntervalMonthDayNano::new(0, 0, 30000000000)),
                                            CDT::interval_month_day_nano_datatype()
                                        ),
                                        // `ts`
                                        ScalarExpr::Column(1)
                                        ],
                                }])
                                .unwrap()
                                .project(vec![2])
                                .unwrap()
                                .into_safe(),
                            // Value: both input columns passed through with no operations.
                            val_plan: MapFilterProject::new(2)
                                .into_safe(),
                        },
                        // Both aggregates read value column 0. NOTE(review): the last argument
                        // of `AggrWithIndex::new` appears to be the aggregate's output slot
                        // (0 = max, 1 = min) — confirm against its definition.
                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
                            full_aggrs: aggr_exprs.clone(),
                            simple_aggrs: vec![AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0),
                            AggrWithIndex::new(aggr_exprs[1].clone(), 0, 1)],
                            distinct_aggrs: vec![],
                        }),
                    }
                    // Reduce output: key column first, then aggregates: [time_window, max, min].
                    .with_types(
                        RelationType::new(vec![
                            ColumnType::new(
                                ConcreteDataType::timestamp_millisecond_datatype(),
                                true,
                            ), // time_window
                            ColumnType::new(ConcreteDataType::uint32_datatype(), true), // max
                            ColumnType::new(ConcreteDataType::uint32_datatype(), true), // min
                        ])
                        .with_time_index(Some(0))
                        .into_unnamed(),
                    ),
                ),
                // Appends `(max - min) as f64 / 30.0` (column 3) and a copy of `time_window`
                // (column 4), then keeps only those two, in SELECT order.
                mfp: MapFilterProject::new(3)
                    .map(vec![
                        ScalarExpr::Column(1)
                            .call_binary(ScalarExpr::Column(2), BinaryFunc::SubUInt32)
                            .cast(CDT::float64_datatype())
                            .call_binary(
                                ScalarExpr::Literal(Value::from(30.0f64), CDT::float64_datatype()),
                                BinaryFunc::DivFloat64,
                            ),
                        ScalarExpr::Column(0),
                    ])
                    .unwrap()
                    .project(vec![3, 4])
                    .unwrap(),
            },
        };

        assert_eq!(flow_plan.unwrap(), expected);
    }
1690}