flow/transform/
aggr.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use itertools::Itertools;
16use snafu::OptionExt;
17use substrait_proto::proto;
18use substrait_proto::proto::aggregate_function::AggregationInvocation;
19use substrait_proto::proto::aggregate_rel::{Grouping, Measure};
20use substrait_proto::proto::function_argument::ArgType;
21
22use crate::error::{Error, NotImplementedSnafu, PlanSnafu};
23use crate::expr::{
24    AggregateExpr, AggregateFunc, MapFilterProject, ScalarExpr, TypedExpr, UnaryFunc,
25};
26use crate::plan::{AccumulablePlan, AggrWithIndex, KeyValPlan, Plan, ReducePlan, TypedPlan};
27use crate::repr::{ColumnType, RelationDesc, RelationType};
28use crate::transform::{FlownodeContext, FunctionExtensions, substrait_proto};
29
30impl TypedExpr {
31    /// Allow `deprecated` due to the usage of deprecated grouping_expressions on datafusion to substrait side
32    #[allow(deprecated)]
33    async fn from_substrait_agg_grouping(
34        ctx: &mut FlownodeContext,
35        grouping_expressions: &[proto::Expression],
36        groupings: &[Grouping],
37        typ: &RelationDesc,
38        extensions: &FunctionExtensions,
39    ) -> Result<Vec<TypedExpr>, Error> {
40        let _ = ctx;
41        let mut group_expr = vec![];
42        match groupings.len() {
43            1 => {
44                // handle case when deprecated grouping_expressions is referenced by index is empty
45                let expressions: Box<dyn Iterator<Item = &proto::Expression> + Send> = if groupings
46                    [0]
47                .expression_references
48                .is_empty()
49                {
50                    Box::new(groupings[0].grouping_expressions.iter())
51                } else {
52                    if groupings[0]
53                        .expression_references
54                        .iter()
55                        .any(|idx| *idx as usize >= grouping_expressions.len())
56                    {
57                        return PlanSnafu {
58                            reason: format!("Invalid grouping expression reference: {:?} for grouping expr: {:?}", 
59                            groupings[0].expression_references,
60                            grouping_expressions
61                        ),
62                        }.fail()?;
63                    }
64                    Box::new(
65                        groupings[0]
66                            .expression_references
67                            .iter()
68                            .map(|idx| &grouping_expressions[*idx as usize]),
69                    )
70                };
71                for e in expressions {
72                    let x = TypedExpr::from_substrait_rex(e, typ, extensions).await?;
73                    group_expr.push(x);
74                }
75            }
76            _ => {
77                return not_impl_err!(
78                    "Grouping sets not support yet, use union all with group by instead."
79                );
80            }
81        };
82        Ok(group_expr)
83    }
84}
85
86impl AggregateExpr {
87    /// Convert list of `Measure` into Flow's AggregateExpr
88    ///
89    /// Return both the AggregateExpr and a MapFilterProject that is the final output of the aggregate function
90    async fn from_substrait_agg_measures(
91        ctx: &mut FlownodeContext,
92        measures: &[Measure],
93        typ: &RelationDesc,
94        extensions: &FunctionExtensions,
95    ) -> Result<Vec<AggregateExpr>, Error> {
96        let _ = ctx;
97        let mut all_aggr_exprs = vec![];
98
99        for m in measures {
100            let filter = match m
101                .filter
102                .as_ref()
103                .map(|fil| TypedExpr::from_substrait_rex(fil, typ, extensions))
104            {
105                Some(fut) => Some(fut.await),
106                None => None,
107            }
108            .transpose()?;
109
110            let aggr_expr = match &m.measure {
111                Some(f) => {
112                    let distinct = match f.invocation {
113                        _ if f.invocation == AggregationInvocation::Distinct as i32 => true,
114                        _ if f.invocation == AggregationInvocation::All as i32 => false,
115                        _ => false,
116                    };
117                    AggregateExpr::from_substrait_agg_func(
118                        f, typ, extensions, &filter, // TODO(discord9): impl order_by
119                        &None, distinct,
120                    )
121                    .await?
122                }
123                None => {
124                    return not_impl_err!("Aggregate without aggregate function is not supported");
125                }
126            };
127
128            all_aggr_exprs.extend(aggr_expr);
129        }
130
131        Ok(all_aggr_exprs)
132    }
133
134    /// Convert AggregateFunction into Flow's AggregateExpr
135    ///
136    /// the returned value is a tuple of AggregateExpr and a optional ScalarExpr that if exist is the final output of the aggregate function
137    /// since aggr functions like `avg` need to be transform to `sum(x)/cast(count(x) as x_type)`
138    pub async fn from_substrait_agg_func(
139        f: &proto::AggregateFunction,
140        input_schema: &RelationDesc,
141        extensions: &FunctionExtensions,
142        filter: &Option<TypedExpr>,
143        order_by: &Option<Vec<TypedExpr>>,
144        distinct: bool,
145    ) -> Result<Vec<AggregateExpr>, Error> {
146        // TODO(discord9): impl filter
147        let _ = filter;
148        let _ = order_by;
149        let mut args = vec![];
150        for arg in &f.arguments {
151            let arg_expr = match &arg.arg_type {
152                Some(ArgType::Value(e)) => {
153                    TypedExpr::from_substrait_rex(e, input_schema, extensions).await
154                }
155                _ => not_impl_err!("Aggregated function argument non-Value type not supported"),
156            }?;
157            args.push(arg_expr);
158        }
159
160        if args.len() != 1 {
161            let fn_name = extensions.get(&f.function_reference).cloned();
162            return not_impl_err!(
163                "Aggregated function (name={:?}) with multiple arguments is not supported",
164                fn_name
165            );
166        }
167
168        let arg = if let Some(first) = args.first() {
169            first
170        } else {
171            return not_impl_err!("Aggregated function without arguments is not supported");
172        };
173
174        let fn_name = extensions
175            .get(&f.function_reference)
176            .cloned()
177            .map(|s| s.to_lowercase());
178
179        match fn_name.as_ref().map(|s| s.as_ref()) {
180            Some(function_name) => {
181                let func = AggregateFunc::from_str_and_type(
182                    function_name,
183                    Some(arg.typ.scalar_type.clone()),
184                )?;
185                let exprs = vec![AggregateExpr {
186                    func,
187                    expr: arg.expr.clone(),
188                    distinct,
189                }];
190                Ok(exprs)
191            }
192            None => not_impl_err!(
193                "Aggregated function not found: function anchor = {:?}",
194                f.function_reference
195            ),
196        }
197    }
198}
199
200impl KeyValPlan {
201    /// Generate KeyValPlan from AggregateExpr and group_exprs
202    ///
203    /// will also change aggregate expr to use column ref if necessary
204    fn from_substrait_gen_key_val_plan(
205        aggr_exprs: &mut [AggregateExpr],
206        group_exprs: &[TypedExpr],
207        input_arity: usize,
208    ) -> Result<KeyValPlan, Error> {
209        let group_expr_val = group_exprs
210            .iter()
211            .map(|expr| expr.expr.clone())
212            .collect_vec();
213        let output_arity = group_expr_val.len();
214        let key_plan = MapFilterProject::new(input_arity)
215            .map(group_expr_val)?
216            .project(input_arity..input_arity + output_arity)?;
217
218        // val_plan is extracted from aggr_exprs to give aggr function it's necessary input
219        // and since aggr func need inputs that is column ref, we just add a prefix mfp to transform any expr that is not into a column ref
220        let val_plan = {
221            let need_mfp = aggr_exprs.iter().any(|agg| agg.expr.as_column().is_none());
222            if need_mfp {
223                // create mfp from aggr_expr, and modify aggr_expr to use the output column of mfp
224                let input_exprs = aggr_exprs
225                    .iter_mut()
226                    .enumerate()
227                    .map(|(idx, aggr)| {
228                        let ret = aggr.expr.clone();
229                        aggr.expr = ScalarExpr::Column(idx);
230                        ret
231                    })
232                    .collect_vec();
233                let aggr_arity = aggr_exprs.len();
234
235                MapFilterProject::new(input_arity)
236                    .map(input_exprs)?
237                    .project(input_arity..input_arity + aggr_arity)?
238            } else {
239                // simply take all inputs as value
240                MapFilterProject::new(input_arity)
241            }
242        };
243        Ok(KeyValPlan {
244            key_plan: key_plan.into_safe(),
245            val_plan: val_plan.into_safe(),
246        })
247    }
248}
249
250/// find out the column that should be time index in group exprs(which is all columns that should be keys)
251/// TODO(discord9): better ways to assign time index
252/// for now, it will found the first column that is timestamp or has a tumble window floor function
253fn find_time_index_in_group_exprs(group_exprs: &[TypedExpr]) -> Option<usize> {
254    group_exprs.iter().position(|expr| {
255        matches!(
256            &expr.expr,
257            ScalarExpr::CallUnary {
258                func: UnaryFunc::TumbleWindowFloor { .. },
259                expr: _
260            }
261        ) || expr.typ.scalar_type.is_timestamp()
262    })
263}
264
265impl TypedPlan {
266    /// Convert AggregateRel into Flow's TypedPlan
267    ///
268    /// The output of aggr plan is:
269    ///
270    /// <group_exprs>..<aggr_exprs>
271    #[async_recursion::async_recursion]
272    pub async fn from_substrait_agg_rel(
273        ctx: &mut FlownodeContext,
274        agg: &proto::AggregateRel,
275        extensions: &FunctionExtensions,
276    ) -> Result<TypedPlan, Error> {
277        let input = if let Some(input) = agg.input.as_ref() {
278            TypedPlan::from_substrait_rel(ctx, input, extensions).await?
279        } else {
280            return not_impl_err!("Aggregate without an input is not supported");
281        };
282
283        let group_exprs = TypedExpr::from_substrait_agg_grouping(
284            ctx,
285            &agg.grouping_expressions,
286            &agg.groupings,
287            &input.schema,
288            extensions,
289        )
290        .await?;
291
292        let time_index = find_time_index_in_group_exprs(&group_exprs);
293
294        let mut aggr_exprs = AggregateExpr::from_substrait_agg_measures(
295            ctx,
296            &agg.measures,
297            &input.schema,
298            extensions,
299        )
300        .await?;
301
302        let key_val_plan = KeyValPlan::from_substrait_gen_key_val_plan(
303            &mut aggr_exprs,
304            &group_exprs,
305            input.schema.typ.column_types.len(),
306        )?;
307
308        // output type is group_exprs + aggr_exprs
309        let output_type = {
310            let mut output_types = Vec::new();
311            // give best effort to get column name
312            let mut output_names = Vec::new();
313
314            // first append group_expr as key, then aggr_expr as value
315            for expr in group_exprs.iter() {
316                output_types.push(expr.typ.clone());
317                let col_name = match &expr.expr {
318                    ScalarExpr::Column(col) => input.schema.get_name(*col).clone(),
319                    // TODO(discord9): impl& use ScalarExpr.display_name, which recursively build expr's name
320                    _ => None,
321                };
322                output_names.push(col_name)
323            }
324
325            for aggr in &aggr_exprs {
326                output_types.push(ColumnType::new_nullable(
327                    aggr.func.signature().output.clone(),
328                ));
329                // TODO(discord9): find a clever way to name them?
330                output_names.push(None);
331            }
332            // TODO(discord9): try best to get time
333            if group_exprs.is_empty() {
334                RelationType::new(output_types)
335            } else {
336                RelationType::new(output_types).with_key((0..group_exprs.len()).collect_vec())
337            }
338            .with_time_index(time_index)
339            .into_named(output_names)
340        };
341
342        // copy aggr_exprs to full_aggrs, and split them into simple_aggrs and distinct_aggrs
343        // also set them input/output column
344        let full_aggrs = aggr_exprs;
345        let mut simple_aggrs = Vec::new();
346        let mut distinct_aggrs = Vec::new();
347        for (output_column, aggr_expr) in full_aggrs.iter().enumerate() {
348            let input_column = aggr_expr.expr.as_column().with_context(|| PlanSnafu {
349                reason: "Expect aggregate argument to be transformed into a column at this point",
350            })?;
351            if aggr_expr.distinct {
352                distinct_aggrs.push(AggrWithIndex::new(
353                    aggr_expr.clone(),
354                    input_column,
355                    output_column,
356                ));
357            } else {
358                simple_aggrs.push(AggrWithIndex::new(
359                    aggr_expr.clone(),
360                    input_column,
361                    output_column,
362                ));
363            }
364        }
365        let accum_plan = AccumulablePlan {
366            full_aggrs,
367            simple_aggrs,
368            distinct_aggrs,
369        };
370        let plan = Plan::Reduce {
371            input: Box::new(input),
372            key_val_plan,
373            reduce_plan: ReducePlan::Accumulable(accum_plan),
374        };
375        // FIX(discord9): deal with key first
376        return Ok(TypedPlan {
377            schema: output_type,
378            plan,
379        });
380    }
381}
382
383#[cfg(test)]
384mod test {
385    use std::time::Duration;
386
387    use bytes::BytesMut;
388    use common_time::{IntervalMonthDayNano, Timestamp};
389    use datatypes::data_type::ConcreteDataType as CDT;
390    use datatypes::prelude::ConcreteDataType;
391    use datatypes::value::Value;
392    use pretty_assertions::assert_eq;
393
394    use super::*;
395    use crate::expr::{BinaryFunc, DfScalarFunction, GlobalId, RawDfScalarFn};
396    use crate::plan::{Plan, TypedPlan};
397    use crate::repr::{ColumnType, RelationType};
398    use crate::transform::test::{create_test_ctx, create_test_query_engine, sql_to_substrait};
399
400    #[tokio::test]
401    async fn test_df_func_basic() {
402        let engine = create_test_query_engine();
403        let sql = "SELECT sum(abs(number)) FROM numbers_with_ts GROUP BY tumble(ts, '1 second', '2021-07-01 00:00:00');";
404        let plan = sql_to_substrait(engine.clone(), sql).await;
405
406        let mut ctx = create_test_ctx();
407        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
408            .await
409            .unwrap();
410
411        let aggr_expr = AggregateExpr {
412            func: AggregateFunc::SumUInt64,
413            expr: ScalarExpr::Column(0),
414            distinct: false,
415        };
416        let expected =
417            TypedPlan {
418                schema: RelationType::new(vec![
419                    ColumnType::new(CDT::uint64_datatype(), true), // sum(number)
420                    ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
421                    ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
422                ])
423                .with_key(vec![2])
424                .with_time_index(Some(1))
425                .into_named(vec![
426                    Some("sum(abs(numbers_with_ts.number))".to_string()),
427                    Some("window_start".to_string()),
428                    Some("window_end".to_string()),
429                ]),
430                plan: Plan::Mfp {
431                    input: Box::new(
432                        Plan::Reduce {
433                            input: Box::new(
434                                Plan::Get {
435                                    id: crate::expr::Id::Global(GlobalId::User(1)),
436                                }
437                                .with_types(
438                                    RelationType::new(vec![
439                                        ColumnType::new(ConcreteDataType::uint32_datatype(), false),
440                                        ColumnType::new(
441                                            ConcreteDataType::timestamp_millisecond_datatype(),
442                                            false,
443                                        ),
444                                    ])
445                                    .into_named(vec![
446                                        Some("number".to_string()),
447                                        Some("ts".to_string()),
448                                    ]),
449                                )
450                                .mfp(MapFilterProject::new(2).into_safe())
451                                .unwrap(),
452                            ),
453                            key_val_plan: KeyValPlan {
454                                key_plan: MapFilterProject::new(2)
455                                    .map(vec![
456                                        ScalarExpr::Column(1).call_unary(
457                                            UnaryFunc::TumbleWindowFloor {
458                                                window_size: Duration::from_nanos(1_000_000_000),
459                                                start_time: Some(Timestamp::new_millisecond(
460                                                    1625097600000,
461                                                )),
462                                            },
463                                        ),
464                                        ScalarExpr::Column(1).call_unary(
465                                            UnaryFunc::TumbleWindowCeiling {
466                                                window_size: Duration::from_nanos(1_000_000_000),
467                                                start_time: Some(Timestamp::new_millisecond(
468                                                    1625097600000,
469                                                )),
470                                            },
471                                        ),
472                                    ])
473                                    .unwrap()
474                                    .project(vec![2, 3])
475                                    .unwrap()
476                                    .into_safe(),
477                                val_plan: MapFilterProject::new(2)
478                                    .map(vec![ScalarExpr::CallDf {
479                                    df_scalar_fn: DfScalarFunction::try_from_raw_fn(
480                                        RawDfScalarFn {
481                                            f: BytesMut::from(
482                                                b"\x08\x02\"\x08\x1a\x06\x12\x04\n\x02\x12\0"
483                                                    .as_ref(),
484                                            ),
485                                            input_schema: RelationType::new(vec![ColumnType::new(
486                                                ConcreteDataType::uint32_datatype(),
487                                                false,
488                                            )])
489                                            .into_unnamed(),
490                                            extensions: FunctionExtensions::from_iter(
491                                                [
492                                                    (0, "tumble_start".to_string()),
493                                                    (1, "tumble_end".to_string()),
494                                                    (2, "abs".to_string()),
495                                                    (3, "sum".to_string()),
496                                                ]
497                                                .into_iter(),
498                                            ),
499                                        },
500                                    )
501                                    .await
502                                    .unwrap(),
503                                    exprs: vec![ScalarExpr::Column(0)],
504                                }
505                                .cast(CDT::uint64_datatype())])
506                                    .unwrap()
507                                    .project(vec![2])
508                                    .unwrap()
509                                    .into_safe(),
510                            },
511                            reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
512                                full_aggrs: vec![aggr_expr.clone()],
513                                simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
514                                distinct_aggrs: vec![],
515                            }),
516                        }
517                        .with_types(
518                            RelationType::new(vec![
519                                ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
520                                ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
521                                ColumnType::new(CDT::uint64_datatype(), true), //sum(number)
522                            ])
523                            .with_key(vec![1])
524                            .with_time_index(Some(0))
525                            .into_unnamed(),
526                        ),
527                    ),
528                    mfp: MapFilterProject::new(3)
529                        .map(vec![
530                            ScalarExpr::Column(2),
531                            ScalarExpr::Column(0),
532                            ScalarExpr::Column(1),
533                        ])
534                        .unwrap()
535                        .project(vec![3, 4, 5])
536                        .unwrap(),
537                },
538            };
539        assert_eq!(flow_plan, expected);
540    }
541
542    #[tokio::test]
543    async fn test_df_func_expr_tree() {
544        let engine = create_test_query_engine();
545        let sql = "SELECT abs(sum(number)) FROM numbers_with_ts GROUP BY tumble(ts, '1 second', '2021-07-01 00:00:00');";
546        let plan = sql_to_substrait(engine.clone(), sql).await;
547
548        let mut ctx = create_test_ctx();
549        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
550            .await
551            .unwrap();
552
553        let aggr_expr = AggregateExpr {
554            func: AggregateFunc::SumUInt64,
555            expr: ScalarExpr::Column(0),
556            distinct: false,
557        };
558        let expected = TypedPlan {
559            schema: RelationType::new(vec![
560                ColumnType::new(CDT::uint64_datatype(), true), // sum(number)
561                ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
562                ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
563            ])
564            .with_key(vec![2])
565            .with_time_index(Some(1))
566            .into_named(vec![
567                Some("abs(sum(numbers_with_ts.number))".to_string()),
568                Some("window_start".to_string()),
569                Some("window_end".to_string()),
570            ]),
571            plan: Plan::Mfp {
572                input: Box::new(
573                    Plan::Reduce {
574                        input: Box::new(
575                            Plan::Get {
576                                id: crate::expr::Id::Global(GlobalId::User(1)),
577                            }
578                            .with_types(
579                                RelationType::new(vec![
580                                    ColumnType::new(ConcreteDataType::uint32_datatype(), false),
581                                    ColumnType::new(
582                                        ConcreteDataType::timestamp_millisecond_datatype(),
583                                        false,
584                                    ),
585                                ])
586                                .into_named(vec![
587                                    Some("number".to_string()),
588                                    Some("ts".to_string()),
589                                ]),
590                            )
591                            .mfp(MapFilterProject::new(2).into_safe())
592                            .unwrap(),
593                        ),
594                        key_val_plan: KeyValPlan {
595                            key_plan: MapFilterProject::new(2)
596                                .map(vec![
597                                    ScalarExpr::Column(1).call_unary(
598                                        UnaryFunc::TumbleWindowFloor {
599                                            window_size: Duration::from_nanos(1_000_000_000),
600                                            start_time: Some(Timestamp::new_millisecond(
601                                                1625097600000,
602                                            )),
603                                        },
604                                    ),
605                                    ScalarExpr::Column(1).call_unary(
606                                        UnaryFunc::TumbleWindowCeiling {
607                                            window_size: Duration::from_nanos(1_000_000_000),
608                                            start_time: Some(Timestamp::new_millisecond(
609                                                1625097600000,
610                                            )),
611                                        },
612                                    ),
613                                ])
614                                .unwrap()
615                                .project(vec![2, 3])
616                                .unwrap()
617                                .into_safe(),
618                            val_plan: MapFilterProject::new(2)
619                                .map(vec![ScalarExpr::Column(0).cast(CDT::uint64_datatype())])
620                                .unwrap()
621                                .project(vec![2])
622                                .unwrap()
623                                .into_safe(),
624                        },
625                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
626                            full_aggrs: vec![aggr_expr.clone()],
627                            simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
628                            distinct_aggrs: vec![],
629                        }),
630                    }
631                    .with_types(
632                        RelationType::new(vec![
633                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
634                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
635                            ColumnType::new(CDT::uint64_datatype(), true), //sum(number)
636                        ])
637                        .with_key(vec![1])
638                        .with_time_index(Some(0))
639                        .into_named(vec![None, None, None]),
640                    ),
641                ),
642                mfp: MapFilterProject::new(3)
643                    .map(vec![
644                        ScalarExpr::CallDf {
645                            df_scalar_fn: DfScalarFunction::try_from_raw_fn(RawDfScalarFn {
646                                f: BytesMut::from(b"\"\x08\x1a\x06\x12\x04\n\x02\x12\0".as_ref()),
647                                input_schema: RelationType::new(vec![ColumnType::new(
648                                    ConcreteDataType::uint64_datatype(),
649                                    true,
650                                )])
651                                .into_unnamed(),
652                                extensions: FunctionExtensions::from_iter(
653                                    [
654                                        (0, "abs".to_string()),
655                                        (1, "tumble_start".to_string()),
656                                        (2, "tumble_end".to_string()),
657                                        (3, "sum".to_string()),
658                                    ]
659                                    .into_iter(),
660                                ),
661                            })
662                            .await
663                            .unwrap(),
664                            exprs: vec![ScalarExpr::Column(2)],
665                        },
666                        ScalarExpr::Column(0),
667                        ScalarExpr::Column(1),
668                    ])
669                    .unwrap()
670                    .project(vec![3, 4, 5])
671                    .unwrap(),
672            },
673        };
674        assert_eq!(flow_plan, expected);
675    }
676
677    /// TODO(discord9): add more illegal sql tests
678    #[tokio::test]
679    async fn test_tumble_composite() {
680        let engine = create_test_query_engine();
681        let sql =
682            "SELECT number, avg(number) FROM numbers_with_ts GROUP BY tumble(ts, '1 hour'), number";
683        let plan = sql_to_substrait(engine.clone(), sql).await;
684
685        let mut ctx = create_test_ctx();
686        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
687            .await
688            .unwrap();
689
690        let aggr_exprs = vec![
691            AggregateExpr {
692                func: AggregateFunc::SumUInt64,
693                expr: ScalarExpr::Column(0),
694                distinct: false,
695            },
696            AggregateExpr {
697                func: AggregateFunc::Count,
698                expr: ScalarExpr::Column(1),
699                distinct: false,
700            },
701        ];
702        let avg_expr = ScalarExpr::If {
703            cond: Box::new(ScalarExpr::Column(4).call_binary(
704                ScalarExpr::Literal(Value::from(0i64), CDT::int64_datatype()),
705                BinaryFunc::NotEq,
706            )),
707            then: Box::new(
708                ScalarExpr::Column(3)
709                    .cast(CDT::float64_datatype())
710                    .call_binary(
711                        ScalarExpr::Column(4).cast(CDT::float64_datatype()),
712                        BinaryFunc::DivFloat64,
713                    ),
714            ),
715            els: Box::new(ScalarExpr::Literal(Value::Null, CDT::float64_datatype())),
716        };
717        let expected = TypedPlan {
718            plan: Plan::Mfp {
719                input: Box::new(
720                    Plan::Reduce {
721                        input: Box::new(
722                            Plan::Get {
723                                id: crate::expr::Id::Global(GlobalId::User(1)),
724                            }
725                            .with_types(
726                                RelationType::new(vec![
727                                    ColumnType::new(ConcreteDataType::uint32_datatype(), false),
728                                    ColumnType::new(
729                                        ConcreteDataType::timestamp_millisecond_datatype(),
730                                        false,
731                                    ),
732                                ])
733                                .into_named(vec![
734                                    Some("number".to_string()),
735                                    Some("ts".to_string()),
736                                ]),
737                            )
738                            .mfp(MapFilterProject::new(2).into_safe())
739                            .unwrap(),
740                        ),
741                        key_val_plan: KeyValPlan {
742                            key_plan: MapFilterProject::new(2)
743                                .map(vec![
744                                    ScalarExpr::Column(1).call_unary(
745                                        UnaryFunc::TumbleWindowFloor {
746                                            window_size: Duration::from_nanos(3_600_000_000_000),
747                                            start_time: None,
748                                        },
749                                    ),
750                                    ScalarExpr::Column(1).call_unary(
751                                        UnaryFunc::TumbleWindowCeiling {
752                                            window_size: Duration::from_nanos(3_600_000_000_000),
753                                            start_time: None,
754                                        },
755                                    ),
756                                    ScalarExpr::Column(0),
757                                ])
758                                .unwrap()
759                                .project(vec![2, 3, 4])
760                                .unwrap()
761                                .into_safe(),
762                            val_plan: MapFilterProject::new(2)
763                                .map(vec![
764                                    ScalarExpr::Column(0).cast(CDT::uint64_datatype()),
765                                    ScalarExpr::Column(0),
766                                ])
767                                .unwrap()
768                                .project(vec![2, 3])
769                                .unwrap()
770                                .into_safe(),
771                        },
772                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
773                            full_aggrs: aggr_exprs.clone(),
774                            simple_aggrs: vec![
775                                AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0),
776                                AggrWithIndex::new(aggr_exprs[1].clone(), 1, 1),
777                            ],
778                            distinct_aggrs: vec![],
779                        }),
780                    }
781                    .with_types(
782                        RelationType::new(vec![
783                            // keys
784                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start(time index)
785                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end(pk)
786                            ColumnType::new(CDT::uint32_datatype(), false), // number(pk)
787                            // values
788                            ColumnType::new(CDT::uint64_datatype(), true), // avg.sum(number)
789                            ColumnType::new(CDT::int64_datatype(), true),  // avg.count(number)
790                        ])
791                        .with_key(vec![1, 2])
792                        .with_time_index(Some(0))
793                        .into_named(vec![
794                            None,
795                            None,
796                            Some("number".to_string()),
797                            None,
798                            None,
799                        ]),
800                    ),
801                ),
802                mfp: MapFilterProject::new(5)
803                    .map(vec![
804                        ScalarExpr::Column(2), // number(pk)
805                        avg_expr,
806                        ScalarExpr::Column(0), // window start
807                        ScalarExpr::Column(1), // window end
808                    ])
809                    .unwrap()
810                    .project(vec![5, 6, 7, 8])
811                    .unwrap(),
812            },
813            schema: RelationType::new(vec![
814                ColumnType::new(CDT::uint32_datatype(), false), // number
815                ColumnType::new(CDT::float64_datatype(), true), // avg(number)
816                ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
817                ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
818            ])
819            .with_key(vec![0, 3])
820            .with_time_index(Some(2))
821            .into_named(vec![
822                Some("number".to_string()),
823                Some("avg(numbers_with_ts.number)".to_string()),
824                Some("window_start".to_string()),
825                Some("window_end".to_string()),
826            ]),
827        };
828        assert_eq!(flow_plan, expected);
829    }
830
831    #[tokio::test]
832    async fn test_tumble_parse_optional() {
833        let engine = create_test_query_engine();
834        let sql = "SELECT sum(number) FROM numbers_with_ts GROUP BY tumble(ts, '1 hour')";
835        let plan = sql_to_substrait(engine.clone(), sql).await;
836
837        let mut ctx = create_test_ctx();
838        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
839            .await
840            .unwrap();
841
842        let aggr_expr = AggregateExpr {
843            func: AggregateFunc::SumUInt64,
844            expr: ScalarExpr::Column(0),
845            distinct: false,
846        };
847        let expected = TypedPlan {
848            schema: RelationType::new(vec![
849                ColumnType::new(CDT::uint64_datatype(), true), // sum(number)
850                ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
851                ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
852            ])
853            .with_key(vec![2])
854            .with_time_index(Some(1))
855            .into_named(vec![
856                Some("sum(numbers_with_ts.number)".to_string()),
857                Some("window_start".to_string()),
858                Some("window_end".to_string()),
859            ]),
860            plan: Plan::Mfp {
861                input: Box::new(
862                    Plan::Reduce {
863                        input: Box::new(
864                            Plan::Get {
865                                id: crate::expr::Id::Global(GlobalId::User(1)),
866                            }
867                            .with_types(
868                                RelationType::new(vec![
869                                    ColumnType::new(ConcreteDataType::uint32_datatype(), false),
870                                    ColumnType::new(
871                                        ConcreteDataType::timestamp_millisecond_datatype(),
872                                        false,
873                                    ),
874                                ])
875                                .into_named(vec![
876                                    Some("number".to_string()),
877                                    Some("ts".to_string()),
878                                ]),
879                            )
880                            .mfp(MapFilterProject::new(2).into_safe())
881                            .unwrap(),
882                        ),
883                        key_val_plan: KeyValPlan {
884                            key_plan: MapFilterProject::new(2)
885                                .map(vec![
886                                    ScalarExpr::Column(1).call_unary(
887                                        UnaryFunc::TumbleWindowFloor {
888                                            window_size: Duration::from_nanos(3_600_000_000_000),
889                                            start_time: None,
890                                        },
891                                    ),
892                                    ScalarExpr::Column(1).call_unary(
893                                        UnaryFunc::TumbleWindowCeiling {
894                                            window_size: Duration::from_nanos(3_600_000_000_000),
895                                            start_time: None,
896                                        },
897                                    ),
898                                ])
899                                .unwrap()
900                                .project(vec![2, 3])
901                                .unwrap()
902                                .into_safe(),
903                            val_plan: MapFilterProject::new(2)
904                                .map(vec![ScalarExpr::Column(0).cast(CDT::uint64_datatype())])
905                                .unwrap()
906                                .project(vec![2])
907                                .unwrap()
908                                .into_safe(),
909                        },
910                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
911                            full_aggrs: vec![aggr_expr.clone()],
912                            simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
913                            distinct_aggrs: vec![],
914                        }),
915                    }
916                    .with_types(
917                        RelationType::new(vec![
918                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
919                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
920                            ColumnType::new(CDT::uint64_datatype(), true), //sum(number)
921                        ])
922                        .with_key(vec![1])
923                        .with_time_index(Some(0))
924                        .into_named(vec![None, None, None]),
925                    ),
926                ),
927                mfp: MapFilterProject::new(3)
928                    .map(vec![
929                        ScalarExpr::Column(2),
930                        ScalarExpr::Column(0),
931                        ScalarExpr::Column(1),
932                    ])
933                    .unwrap()
934                    .project(vec![3, 4, 5])
935                    .unwrap(),
936            },
937        };
938        assert_eq!(flow_plan, expected);
939    }
940
941    #[tokio::test]
942    async fn test_tumble_parse() {
943        let engine = create_test_query_engine();
944        let sql = "SELECT sum(number) FROM numbers_with_ts GROUP BY tumble(ts, '1 hour', '2021-07-01 00:00:00')";
945        let plan = sql_to_substrait(engine.clone(), sql).await;
946
947        let mut ctx = create_test_ctx();
948        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
949            .await
950            .unwrap();
951
952        let aggr_expr = AggregateExpr {
953            func: AggregateFunc::SumUInt64,
954            expr: ScalarExpr::Column(0),
955            distinct: false,
956        };
957        let expected = TypedPlan {
958            schema: RelationType::new(vec![
959                ColumnType::new(CDT::uint64_datatype(), true), // sum(number)
960                ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
961                ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
962            ])
963            .with_key(vec![2])
964            .with_time_index(Some(1))
965            .into_named(vec![
966                Some("sum(numbers_with_ts.number)".to_string()),
967                Some("window_start".to_string()),
968                Some("window_end".to_string()),
969            ]),
970            plan: Plan::Mfp {
971                input: Box::new(
972                    Plan::Reduce {
973                        input: Box::new(
974                            Plan::Get {
975                                id: crate::expr::Id::Global(GlobalId::User(1)),
976                            }
977                            .with_types(
978                                RelationType::new(vec![
979                                    ColumnType::new(ConcreteDataType::uint32_datatype(), false),
980                                    ColumnType::new(
981                                        ConcreteDataType::timestamp_millisecond_datatype(),
982                                        false,
983                                    ),
984                                ])
985                                .into_named(vec![
986                                    Some("number".to_string()),
987                                    Some("ts".to_string()),
988                                ]),
989                            )
990                            .mfp(MapFilterProject::new(2).into_safe())
991                            .unwrap(),
992                        ),
993                        key_val_plan: KeyValPlan {
994                            key_plan: MapFilterProject::new(2)
995                                .map(vec![
996                                    ScalarExpr::Column(1).call_unary(
997                                        UnaryFunc::TumbleWindowFloor {
998                                            window_size: Duration::from_nanos(3_600_000_000_000),
999                                            start_time: Some(Timestamp::new_millisecond(
1000                                                1625097600000,
1001                                            )),
1002                                        },
1003                                    ),
1004                                    ScalarExpr::Column(1).call_unary(
1005                                        UnaryFunc::TumbleWindowCeiling {
1006                                            window_size: Duration::from_nanos(3_600_000_000_000),
1007                                            start_time: Some(Timestamp::new_millisecond(
1008                                                1625097600000,
1009                                            )),
1010                                        },
1011                                    ),
1012                                ])
1013                                .unwrap()
1014                                .project(vec![2, 3])
1015                                .unwrap()
1016                                .into_safe(),
1017                            val_plan: MapFilterProject::new(2)
1018                                .map(vec![ScalarExpr::Column(0).cast(CDT::uint64_datatype())])
1019                                .unwrap()
1020                                .project(vec![2])
1021                                .unwrap()
1022                                .into_safe(),
1023                        },
1024                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
1025                            full_aggrs: vec![aggr_expr.clone()],
1026                            simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
1027                            distinct_aggrs: vec![],
1028                        }),
1029                    }
1030                    .with_types(
1031                        RelationType::new(vec![
1032                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
1033                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
1034                            ColumnType::new(CDT::uint64_datatype(), true), //sum(number)
1035                        ])
1036                        .with_key(vec![1])
1037                        .with_time_index(Some(0))
1038                        .into_unnamed(),
1039                    ),
1040                ),
1041                mfp: MapFilterProject::new(3)
1042                    .map(vec![
1043                        ScalarExpr::Column(2),
1044                        ScalarExpr::Column(0),
1045                        ScalarExpr::Column(1),
1046                    ])
1047                    .unwrap()
1048                    .project(vec![3, 4, 5])
1049                    .unwrap(),
1050            },
1051        };
1052        assert_eq!(flow_plan, expected);
1053    }
1054
1055    #[tokio::test]
1056    async fn test_avg_group_by() {
1057        let engine = create_test_query_engine();
1058        let sql = "SELECT avg(number), number FROM numbers GROUP BY number";
1059        let plan = sql_to_substrait(engine.clone(), sql).await;
1060
1061        let mut ctx = create_test_ctx();
1062        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;
1063
1064        let aggr_exprs = vec![
1065            AggregateExpr {
1066                func: AggregateFunc::SumUInt64,
1067                expr: ScalarExpr::Column(0),
1068                distinct: false,
1069            },
1070            AggregateExpr {
1071                func: AggregateFunc::Count,
1072                expr: ScalarExpr::Column(1),
1073                distinct: false,
1074            },
1075        ];
1076        let avg_expr = ScalarExpr::If {
1077            cond: Box::new(ScalarExpr::Column(2).call_binary(
1078                ScalarExpr::Literal(Value::from(0i64), CDT::int64_datatype()),
1079                BinaryFunc::NotEq,
1080            )),
1081            then: Box::new(
1082                ScalarExpr::Column(1)
1083                    .cast(CDT::float64_datatype())
1084                    .call_binary(
1085                        ScalarExpr::Column(2).cast(CDT::float64_datatype()),
1086                        BinaryFunc::DivFloat64,
1087                    ),
1088            ),
1089            els: Box::new(ScalarExpr::Literal(Value::Null, CDT::float64_datatype())),
1090        };
1091        let expected = TypedPlan {
1092            schema: RelationType::new(vec![
1093                ColumnType::new(CDT::float64_datatype(), true), // avg(number: u32) -> f64
1094                ColumnType::new(CDT::uint32_datatype(), false), // number
1095            ])
1096            .with_key(vec![1])
1097            .into_named(vec![
1098                Some("avg(numbers.number)".to_string()),
1099                Some("number".to_string()),
1100            ]),
1101            plan: Plan::Mfp {
1102                input: Box::new(
1103                    Plan::Reduce {
1104                        input: Box::new(
1105                            Plan::Get {
1106                                id: crate::expr::Id::Global(GlobalId::User(0)),
1107                            }
1108                            .with_types(
1109                                RelationType::new(vec![ColumnType::new(
1110                                    ConcreteDataType::uint32_datatype(),
1111                                    false,
1112                                )])
1113                                .into_named(vec![Some("number".to_string())]),
1114                            )
1115                            .mfp(
1116                                MapFilterProject::new(1)
1117                                    .project(vec![0])
1118                                    .unwrap()
1119                                    .into_safe(),
1120                            )
1121                            .unwrap(),
1122                        ),
1123                        key_val_plan: KeyValPlan {
1124                            key_plan: MapFilterProject::new(1)
1125                                .map(vec![ScalarExpr::Column(0)])
1126                                .unwrap()
1127                                .project(vec![1])
1128                                .unwrap()
1129                                .into_safe(),
1130                            val_plan: MapFilterProject::new(1)
1131                                .map(vec![
1132                                    ScalarExpr::Column(0).cast(CDT::uint64_datatype()),
1133                                    ScalarExpr::Column(0),
1134                                ])
1135                                .unwrap()
1136                                .project(vec![1, 2])
1137                                .unwrap()
1138                                .into_safe(),
1139                        },
1140                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
1141                            full_aggrs: aggr_exprs.clone(),
1142                            simple_aggrs: vec![
1143                                AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0),
1144                                AggrWithIndex::new(aggr_exprs[1].clone(), 1, 1),
1145                            ],
1146                            distinct_aggrs: vec![],
1147                        }),
1148                    }
1149                    .with_types(
1150                        RelationType::new(vec![
1151                            ColumnType::new(ConcreteDataType::uint32_datatype(), false), // key: number
1152                            ColumnType::new(ConcreteDataType::uint64_datatype(), true),  // sum
1153                            ColumnType::new(ConcreteDataType::int64_datatype(), true),   // count
1154                        ])
1155                        .with_key(vec![0])
1156                        .into_named(vec![
1157                            Some("number".to_string()),
1158                            None,
1159                            None,
1160                        ]),
1161                    ),
1162                ),
1163                mfp: MapFilterProject::new(3)
1164                    .map(vec![
1165                        avg_expr, // col 3
1166                        ScalarExpr::Column(0),
1167                        // TODO(discord9): optimize mfp so to remove indirect ref
1168                    ])
1169                    .unwrap()
1170                    .project(vec![3, 4])
1171                    .unwrap(),
1172            },
1173        };
1174        assert_eq!(flow_plan.unwrap(), expected);
1175    }
1176
1177    #[tokio::test]
1178    async fn test_avg() {
1179        let engine = create_test_query_engine();
1180        let sql = "SELECT avg(number) FROM numbers";
1181        let plan = sql_to_substrait(engine.clone(), sql).await;
1182
1183        let mut ctx = create_test_ctx();
1184
1185        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
1186            .await
1187            .unwrap();
1188
1189        let aggr_exprs = vec![
1190            AggregateExpr {
1191                func: AggregateFunc::SumUInt64,
1192                expr: ScalarExpr::Column(0),
1193                distinct: false,
1194            },
1195            AggregateExpr {
1196                func: AggregateFunc::Count,
1197                expr: ScalarExpr::Column(1),
1198                distinct: false,
1199            },
1200        ];
1201        let avg_expr = ScalarExpr::If {
1202            cond: Box::new(ScalarExpr::Column(1).call_binary(
1203                ScalarExpr::Literal(Value::from(0i64), CDT::int64_datatype()),
1204                BinaryFunc::NotEq,
1205            )),
1206            then: Box::new(
1207                ScalarExpr::Column(0)
1208                    .cast(CDT::float64_datatype())
1209                    .call_binary(
1210                        ScalarExpr::Column(1).cast(CDT::float64_datatype()),
1211                        BinaryFunc::DivFloat64,
1212                    ),
1213            ),
1214            els: Box::new(ScalarExpr::Literal(Value::Null, CDT::float64_datatype())),
1215        };
1216        let input = Box::new(
1217            Plan::Get {
1218                id: crate::expr::Id::Global(GlobalId::User(0)),
1219            }
1220            .with_types(
1221                RelationType::new(vec![ColumnType::new(
1222                    ConcreteDataType::uint32_datatype(),
1223                    false,
1224                )])
1225                .into_named(vec![Some("number".to_string())]),
1226            ),
1227        );
1228        let expected = TypedPlan {
1229            schema: RelationType::new(vec![ColumnType::new(CDT::float64_datatype(), true)])
1230                .into_named(vec![Some("avg(numbers.number)".to_string())]),
1231            plan: Plan::Mfp {
1232                input: Box::new(
1233                    Plan::Reduce {
1234                        input: Box::new(
1235                            Plan::Mfp {
1236                                input: input.clone(),
1237                                mfp: MapFilterProject::new(1).project(vec![0]).unwrap(),
1238                            }
1239                            .with_types(
1240                                RelationType::new(vec![ColumnType::new(
1241                                    CDT::uint32_datatype(),
1242                                    false,
1243                                )])
1244                                .into_named(vec![Some("number".to_string())]),
1245                            ),
1246                        ),
1247                        key_val_plan: KeyValPlan {
1248                            key_plan: MapFilterProject::new(1)
1249                                .project(vec![])
1250                                .unwrap()
1251                                .into_safe(),
1252                            val_plan: MapFilterProject::new(1)
1253                                .map(vec![
1254                                    ScalarExpr::Column(0).cast(CDT::uint64_datatype()),
1255                                    ScalarExpr::Column(0),
1256                                ])
1257                                .unwrap()
1258                                .project(vec![1, 2])
1259                                .unwrap()
1260                                .into_safe(),
1261                        },
1262                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
1263                            full_aggrs: aggr_exprs.clone(),
1264                            simple_aggrs: vec![
1265                                AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0),
1266                                AggrWithIndex::new(aggr_exprs[1].clone(), 1, 1),
1267                            ],
1268                            distinct_aggrs: vec![],
1269                        }),
1270                    }
1271                    .with_types(
1272                        RelationType::new(vec![
1273                            ColumnType::new(ConcreteDataType::uint64_datatype(), true), // sum
1274                            ColumnType::new(ConcreteDataType::int64_datatype(), true),  // count
1275                        ])
1276                        .into_named(vec![None, None]),
1277                    ),
1278                ),
1279                mfp: MapFilterProject::new(2)
1280                    .map(vec![
1281                        avg_expr,
1282                        // TODO(discord9): optimize mfp so to remove indirect ref
1283                    ])
1284                    .unwrap()
1285                    .project(vec![2])
1286                    .unwrap(),
1287            },
1288        };
1289        assert_eq!(flow_plan, expected);
1290    }
1291
1292    #[tokio::test]
1293    async fn test_sum() {
1294        let engine = create_test_query_engine();
1295        let sql = "SELECT sum(number) FROM numbers";
1296        let plan = sql_to_substrait(engine.clone(), sql).await;
1297
1298        let mut ctx = create_test_ctx();
1299        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;
1300
1301        let aggr_expr = AggregateExpr {
1302            func: AggregateFunc::SumUInt64,
1303            expr: ScalarExpr::Column(0),
1304            distinct: false,
1305        };
1306        let expected = TypedPlan {
1307            schema: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)])
1308                .into_named(vec![Some("sum(numbers.number)".to_string())]),
1309            plan: Plan::Reduce {
1310                input: Box::new(
1311                    Plan::Get {
1312                        id: crate::expr::Id::Global(GlobalId::User(0)),
1313                    }
1314                    .with_types(
1315                        RelationType::new(vec![ColumnType::new(
1316                            ConcreteDataType::uint32_datatype(),
1317                            false,
1318                        )])
1319                        .into_named(vec![Some("number".to_string())]),
1320                    )
1321                    .mfp(MapFilterProject::new(1).into_safe())
1322                    .unwrap(),
1323                ),
1324                key_val_plan: KeyValPlan {
1325                    key_plan: MapFilterProject::new(1)
1326                        .project(vec![])
1327                        .unwrap()
1328                        .into_safe(),
1329                    val_plan: MapFilterProject::new(1)
1330                        .map(vec![
1331                            ScalarExpr::Column(0)
1332                                .call_unary(UnaryFunc::Cast(CDT::uint64_datatype())),
1333                        ])
1334                        .unwrap()
1335                        .project(vec![1])
1336                        .unwrap()
1337                        .into_safe(),
1338                },
1339                reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
1340                    full_aggrs: vec![aggr_expr.clone()],
1341                    simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
1342                    distinct_aggrs: vec![],
1343                }),
1344            },
1345        };
1346        assert_eq!(flow_plan.unwrap(), expected);
1347    }
1348
1349    #[tokio::test]
1350    async fn test_distinct_number() {
1351        let engine = create_test_query_engine();
1352        let sql = "SELECT DISTINCT number FROM numbers";
1353        let plan = sql_to_substrait(engine.clone(), sql).await;
1354
1355        let mut ctx = create_test_ctx();
1356        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
1357            .await
1358            .unwrap();
1359
1360        let expected = TypedPlan {
1361            schema: RelationType::new(vec![
1362                ColumnType::new(CDT::uint32_datatype(), false), // col number
1363            ])
1364            .with_key(vec![0])
1365            .into_named(vec![Some("number".to_string())]),
1366            plan: Plan::Reduce {
1367                input: Box::new(
1368                    Plan::Get {
1369                        id: crate::expr::Id::Global(GlobalId::User(0)),
1370                    }
1371                    .with_types(
1372                        RelationType::new(vec![ColumnType::new(
1373                            ConcreteDataType::uint32_datatype(),
1374                            false,
1375                        )])
1376                        .into_named(vec![Some("number".to_string())]),
1377                    )
1378                    .mfp(MapFilterProject::new(1).into_safe())
1379                    .unwrap(),
1380                ),
1381                key_val_plan: KeyValPlan {
1382                    key_plan: MapFilterProject::new(1)
1383                        .map(vec![ScalarExpr::Column(0)])
1384                        .unwrap()
1385                        .project(vec![1])
1386                        .unwrap()
1387                        .into_safe(),
1388                    val_plan: MapFilterProject::new(1)
1389                        .project(vec![0])
1390                        .unwrap()
1391                        .into_safe(),
1392                },
1393                reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
1394                    full_aggrs: vec![],
1395                    simple_aggrs: vec![],
1396                    distinct_aggrs: vec![],
1397                }),
1398            },
1399        };
1400
1401        assert_eq!(flow_plan, expected);
1402    }
1403
1404    #[tokio::test]
1405    async fn test_sum_group_by() {
1406        let engine = create_test_query_engine();
1407        let sql = "SELECT sum(number), number FROM numbers GROUP BY number";
1408        let plan = sql_to_substrait(engine.clone(), sql).await;
1409
1410        let mut ctx = create_test_ctx();
1411        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
1412            .await
1413            .unwrap();
1414
1415        let aggr_expr = AggregateExpr {
1416            func: AggregateFunc::SumUInt64,
1417            expr: ScalarExpr::Column(0),
1418            distinct: false,
1419        };
1420        let expected = TypedPlan {
1421            schema: RelationType::new(vec![
1422                ColumnType::new(CDT::uint64_datatype(), true), // col sum(number)
1423                ColumnType::new(CDT::uint32_datatype(), false), // col number
1424            ])
1425            .with_key(vec![1])
1426            .into_named(vec![
1427                Some("sum(numbers.number)".to_string()),
1428                Some("number".to_string()),
1429            ]),
1430            plan: Plan::Mfp {
1431                input: Box::new(
1432                    Plan::Reduce {
1433                        input: Box::new(
1434                            Plan::Get {
1435                                id: crate::expr::Id::Global(GlobalId::User(0)),
1436                            }
1437                            .with_types(
1438                                RelationType::new(vec![ColumnType::new(
1439                                    ConcreteDataType::uint32_datatype(),
1440                                    false,
1441                                )])
1442                                .into_named(vec![Some("number".to_string())]),
1443                            )
1444                            .mfp(MapFilterProject::new(1).into_safe())
1445                            .unwrap(),
1446                        ),
1447                        key_val_plan: KeyValPlan {
1448                            key_plan: MapFilterProject::new(1)
1449                                .map(vec![ScalarExpr::Column(0)])
1450                                .unwrap()
1451                                .project(vec![1])
1452                                .unwrap()
1453                                .into_safe(),
1454                            val_plan: MapFilterProject::new(1)
1455                                .map(vec![
1456                                    ScalarExpr::Column(0)
1457                                        .call_unary(UnaryFunc::Cast(CDT::uint64_datatype())),
1458                                ])
1459                                .unwrap()
1460                                .project(vec![1])
1461                                .unwrap()
1462                                .into_safe(),
1463                        },
1464                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
1465                            full_aggrs: vec![aggr_expr.clone()],
1466                            simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
1467                            distinct_aggrs: vec![],
1468                        }),
1469                    }
1470                    .with_types(
1471                        RelationType::new(vec![
1472                            ColumnType::new(CDT::uint32_datatype(), false), // col number
1473                            ColumnType::new(CDT::uint64_datatype(), true),  // col sum(number)
1474                        ])
1475                        .with_key(vec![0])
1476                        .into_named(vec![Some("number".to_string()), None]),
1477                    ),
1478                ),
1479                mfp: MapFilterProject::new(2)
1480                    .map(vec![ScalarExpr::Column(1), ScalarExpr::Column(0)])
1481                    .unwrap()
1482                    .project(vec![2, 3])
1483                    .unwrap(),
1484            },
1485        };
1486
1487        assert_eq!(flow_plan, expected);
1488    }
1489
1490    #[tokio::test]
1491    async fn test_sum_add() {
1492        let engine = create_test_query_engine();
1493        let sql = "SELECT sum(number+number) FROM numbers";
1494        let plan = sql_to_substrait(engine.clone(), sql).await;
1495
1496        let mut ctx = create_test_ctx();
1497        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;
1498
1499        let aggr_expr = AggregateExpr {
1500            func: AggregateFunc::SumUInt64,
1501            expr: ScalarExpr::Column(0),
1502            distinct: false,
1503        };
1504        let expected = TypedPlan {
1505            schema: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)])
1506                .into_named(vec![Some(
1507                    "sum(numbers.number + numbers.number)".to_string(),
1508                )]),
1509            plan: Plan::Reduce {
1510                input: Box::new(
1511                    Plan::Mfp {
1512                        input: Box::new(
1513                            Plan::Get {
1514                                id: crate::expr::Id::Global(GlobalId::User(0)),
1515                            }
1516                            .with_types(
1517                                RelationType::new(vec![ColumnType::new(
1518                                    ConcreteDataType::uint32_datatype(),
1519                                    false,
1520                                )])
1521                                .into_named(vec![Some("number".to_string())]),
1522                            ),
1523                        ),
1524                        mfp: MapFilterProject::new(1),
1525                    }
1526                    .with_types(
1527                        RelationType::new(vec![ColumnType::new(
1528                            ConcreteDataType::uint32_datatype(),
1529                            false,
1530                        )])
1531                        .into_named(vec![Some("number".to_string())]),
1532                    ),
1533                ),
1534                key_val_plan: KeyValPlan {
1535                    key_plan: MapFilterProject::new(1)
1536                        .project(vec![])
1537                        .unwrap()
1538                        .into_safe(),
1539                    val_plan: MapFilterProject::new(1)
1540                        .map(vec![
1541                            ScalarExpr::Column(0)
1542                                .call_binary(ScalarExpr::Column(0), BinaryFunc::AddUInt32)
1543                                .call_unary(UnaryFunc::Cast(CDT::uint64_datatype())),
1544                        ])
1545                        .unwrap()
1546                        .project(vec![1])
1547                        .unwrap()
1548                        .into_safe(),
1549                },
1550                reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
1551                    full_aggrs: vec![aggr_expr.clone()],
1552                    simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
1553                    distinct_aggrs: vec![],
1554                }),
1555            },
1556        };
1557        assert_eq!(flow_plan.unwrap(), expected);
1558    }
1559
1560    #[tokio::test]
1561    async fn test_cast_max_min() {
1562        let engine = create_test_query_engine();
1563        let sql = "SELECT (max(number) - min(number))/30.0, date_bin(INTERVAL '30 second', CAST(ts AS TimestampMillisecond)) as time_window from numbers_with_ts GROUP BY time_window";
1564        let plan = sql_to_substrait(engine.clone(), sql).await;
1565
1566        let mut ctx = create_test_ctx();
1567        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;
1568
1569        let aggr_exprs = vec![
1570            AggregateExpr {
1571                func: AggregateFunc::MaxUInt32,
1572                expr: ScalarExpr::Column(0),
1573                distinct: false,
1574            },
1575            AggregateExpr {
1576                func: AggregateFunc::MinUInt32,
1577                expr: ScalarExpr::Column(0),
1578                distinct: false,
1579            },
1580        ];
1581        let expected = TypedPlan {
1582            schema: RelationType::new(vec![
1583                ColumnType::new(CDT::float64_datatype(), true),
1584                ColumnType::new(CDT::timestamp_millisecond_datatype(), true),
1585            ])
1586            .with_time_index(Some(1))
1587            .into_named(vec![
1588                Some(
1589                    "max(numbers_with_ts.number) - min(numbers_with_ts.number) / Float64(30)"
1590                        .to_string(),
1591                ),
1592                Some("time_window".to_string()),
1593            ]),
1594            plan: Plan::Mfp {
1595                input: Box::new(
1596                    Plan::Reduce {
1597                        input: Box::new(
1598                            Plan::Get {
1599                                id: crate::expr::Id::Global(GlobalId::User(1)),
1600                            }
1601                            .with_types(
1602                                RelationType::new(vec![
1603                                    ColumnType::new(ConcreteDataType::uint32_datatype(), false),
1604                                    ColumnType::new(ConcreteDataType::timestamp_millisecond_datatype(), false),
1605                                ])
1606                                .into_named(vec![
1607                                    Some("number".to_string()),
1608                                    Some("ts".to_string()),
1609                                ]),
1610                            )
1611                            .mfp(MapFilterProject::new(2).into_safe())
1612                            .unwrap(),
1613                        ),
1614
1615                        key_val_plan: KeyValPlan {
1616                            key_plan: MapFilterProject::new(2)
1617                                .map(vec![ScalarExpr::CallDf {
1618                                    df_scalar_fn: DfScalarFunction::try_from_raw_fn(
1619                                        RawDfScalarFn {
1620                                            f: BytesMut::from(
1621                                                b"\x08\x02\"\x0f\x1a\r\n\x0b\xa2\x02\x08\n\0\x12\x04\x10\x1e \t\"\n\x1a\x08\x12\x06\n\x04\x12\x02\x08\x01".as_ref(),
1622                                            ),
1623                                            input_schema: RelationType::new(vec![ColumnType::new(
1624                                                ConcreteDataType::interval_month_day_nano_datatype(),
1625                                                true,
1626                                            ),ColumnType::new(
1627                                                ConcreteDataType::timestamp_millisecond_datatype(),
1628                                                false,
1629                                            )])
1630                                            .into_unnamed(),
1631                                            extensions: FunctionExtensions::from_iter([
1632                                                    (0, "subtract".to_string()),
1633                                                    (1, "divide".to_string()),
1634                                                    (2, "date_bin".to_string()),
1635                                                    (3, "max".to_string()),
1636                                                    (4, "min".to_string()),
1637                                                ]),
1638                                        },
1639                                    )
1640                                    .await
1641                                    .unwrap(),
1642                                    exprs: vec![
1643                                        ScalarExpr::Literal(
1644                                            Value::IntervalMonthDayNano(IntervalMonthDayNano::new(0, 0, 30000000000)),
1645                                            CDT::interval_month_day_nano_datatype()
1646                                        ),
1647                                        ScalarExpr::Column(1)
1648                                        ],
1649                                }])
1650                                .unwrap()
1651                                .project(vec![2])
1652                                .unwrap()
1653                                .into_safe(),
1654                            val_plan: MapFilterProject::new(2)
1655                                .into_safe(),
1656                        },
1657                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
1658                            full_aggrs: aggr_exprs.clone(),
1659                            simple_aggrs: vec![AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0),
1660                            AggrWithIndex::new(aggr_exprs[1].clone(), 0, 1)],
1661                            distinct_aggrs: vec![],
1662                        }),
1663                    }
1664                    .with_types(
1665                        RelationType::new(vec![
1666                            ColumnType::new(
1667                                ConcreteDataType::timestamp_millisecond_datatype(),
1668                                true,
1669                            ), // time_window
1670                            ColumnType::new(ConcreteDataType::uint32_datatype(), true), // max
1671                            ColumnType::new(ConcreteDataType::uint32_datatype(), true), // min
1672                        ])
1673                        .with_time_index(Some(0))
1674                        .into_unnamed(),
1675                    ),
1676                ),
1677                mfp: MapFilterProject::new(3)
1678                    .map(vec![
1679                        ScalarExpr::Column(1)
1680                            .call_binary(ScalarExpr::Column(2), BinaryFunc::SubUInt32)
1681                            .cast(CDT::float64_datatype())
1682                            .call_binary(
1683                                ScalarExpr::Literal(Value::from(30.0f64), CDT::float64_datatype()),
1684                                BinaryFunc::DivFloat64,
1685                            ),
1686                        ScalarExpr::Column(0),
1687                    ])
1688                    .unwrap()
1689                    .project(vec![3, 4])
1690                    .unwrap(),
1691            },
1692        };
1693
1694        assert_eq!(flow_plan.unwrap(), expected);
1695    }
1696}