flow/transform/
aggr.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use itertools::Itertools;
16use snafu::OptionExt;
17use substrait_proto::proto;
18use substrait_proto::proto::aggregate_function::AggregationInvocation;
19use substrait_proto::proto::aggregate_rel::{Grouping, Measure};
20use substrait_proto::proto::function_argument::ArgType;
21
22use crate::error::{Error, NotImplementedSnafu, PlanSnafu};
23use crate::expr::{
24    AggregateExpr, AggregateFunc, MapFilterProject, ScalarExpr, TypedExpr, UnaryFunc,
25};
26use crate::plan::{AccumulablePlan, AggrWithIndex, KeyValPlan, Plan, ReducePlan, TypedPlan};
27use crate::repr::{ColumnType, RelationDesc, RelationType};
28use crate::transform::{FlownodeContext, FunctionExtensions, substrait_proto};
29
30impl TypedExpr {
31    /// Allow `deprecated` due to the usage of deprecated grouping_expressions on datafusion to substrait side
32    #[allow(deprecated)]
33    async fn from_substrait_agg_grouping(
34        ctx: &mut FlownodeContext,
35        grouping_expressions: &[proto::Expression],
36        groupings: &[Grouping],
37        typ: &RelationDesc,
38        extensions: &FunctionExtensions,
39    ) -> Result<Vec<TypedExpr>, Error> {
40        let _ = ctx;
41        let mut group_expr = vec![];
42        match groupings.len() {
43            1 => {
44                // handle case when deprecated grouping_expressions is referenced by index is empty
45                let expressions: Box<dyn Iterator<Item = &proto::Expression> + Send> = if groupings
46                    [0]
47                .expression_references
48                .is_empty()
49                {
50                    Box::new(groupings[0].grouping_expressions.iter())
51                } else {
52                    if groupings[0]
53                        .expression_references
54                        .iter()
55                        .any(|idx| *idx as usize >= grouping_expressions.len())
56                    {
57                        return PlanSnafu {
58                            reason: format!("Invalid grouping expression reference: {:?} for grouping expr: {:?}", 
59                            groupings[0].expression_references,
60                            grouping_expressions
61                        ),
62                        }.fail()?;
63                    }
64                    Box::new(
65                        groupings[0]
66                            .expression_references
67                            .iter()
68                            .map(|idx| &grouping_expressions[*idx as usize]),
69                    )
70                };
71                for e in expressions {
72                    let x = TypedExpr::from_substrait_rex(e, typ, extensions).await?;
73                    group_expr.push(x);
74                }
75            }
76            _ => {
77                return not_impl_err!(
78                    "Grouping sets not support yet, use union all with group by instead."
79                );
80            }
81        };
82        Ok(group_expr)
83    }
84}
85
86impl AggregateExpr {
87    /// Convert list of `Measure` into Flow's AggregateExpr
88    ///
89    /// Return both the AggregateExpr and a MapFilterProject that is the final output of the aggregate function
90    async fn from_substrait_agg_measures(
91        ctx: &mut FlownodeContext,
92        measures: &[Measure],
93        typ: &RelationDesc,
94        extensions: &FunctionExtensions,
95    ) -> Result<Vec<AggregateExpr>, Error> {
96        let _ = ctx;
97        let mut all_aggr_exprs = vec![];
98
99        for m in measures {
100            let filter = match m
101                .filter
102                .as_ref()
103                .map(|fil| TypedExpr::from_substrait_rex(fil, typ, extensions))
104            {
105                Some(fut) => Some(fut.await),
106                None => None,
107            }
108            .transpose()?;
109
110            let aggr_expr = match &m.measure {
111                Some(f) => {
112                    let distinct = match f.invocation {
113                        _ if f.invocation == AggregationInvocation::Distinct as i32 => true,
114                        _ if f.invocation == AggregationInvocation::All as i32 => false,
115                        _ => false,
116                    };
117                    AggregateExpr::from_substrait_agg_func(
118                        f, typ, extensions, &filter, // TODO(discord9): impl order_by
119                        &None, distinct,
120                    )
121                    .await?
122                }
123                None => {
124                    return not_impl_err!("Aggregate without aggregate function is not supported");
125                }
126            };
127
128            all_aggr_exprs.extend(aggr_expr);
129        }
130
131        Ok(all_aggr_exprs)
132    }
133
134    /// Convert AggregateFunction into Flow's AggregateExpr
135    ///
136    /// the returned value is a tuple of AggregateExpr and a optional ScalarExpr that if exist is the final output of the aggregate function
137    /// since aggr functions like `avg` need to be transform to `sum(x)/cast(count(x) as x_type)`
138    pub async fn from_substrait_agg_func(
139        f: &proto::AggregateFunction,
140        input_schema: &RelationDesc,
141        extensions: &FunctionExtensions,
142        filter: &Option<TypedExpr>,
143        order_by: &Option<Vec<TypedExpr>>,
144        distinct: bool,
145    ) -> Result<Vec<AggregateExpr>, Error> {
146        // TODO(discord9): impl filter
147        let _ = filter;
148        let _ = order_by;
149        let mut args = vec![];
150        for arg in &f.arguments {
151            let arg_expr = match &arg.arg_type {
152                Some(ArgType::Value(e)) => {
153                    TypedExpr::from_substrait_rex(e, input_schema, extensions).await
154                }
155                _ => not_impl_err!("Aggregated function argument non-Value type not supported"),
156            }?;
157            args.push(arg_expr);
158        }
159
160        if args.len() != 1 {
161            let fn_name = extensions.get(&f.function_reference).cloned();
162            return not_impl_err!(
163                "Aggregated function (name={:?}) with multiple arguments is not supported",
164                fn_name
165            );
166        }
167
168        let arg = if let Some(first) = args.first() {
169            first
170        } else {
171            return not_impl_err!("Aggregated function without arguments is not supported");
172        };
173
174        let fn_name = extensions
175            .get(&f.function_reference)
176            .cloned()
177            .map(|s| s.to_lowercase());
178
179        match fn_name.as_ref().map(|s| s.as_ref()) {
180            Some(function_name) => {
181                let func = AggregateFunc::from_str_and_type(
182                    function_name,
183                    Some(arg.typ.scalar_type.clone()),
184                )?;
185                let exprs = vec![AggregateExpr {
186                    func,
187                    expr: arg.expr.clone(),
188                    distinct,
189                }];
190                Ok(exprs)
191            }
192            None => not_impl_err!(
193                "Aggregated function not found: function anchor = {:?}",
194                f.function_reference
195            ),
196        }
197    }
198}
199
200impl KeyValPlan {
201    /// Generate KeyValPlan from AggregateExpr and group_exprs
202    ///
203    /// will also change aggregate expr to use column ref if necessary
204    fn from_substrait_gen_key_val_plan(
205        aggr_exprs: &mut [AggregateExpr],
206        group_exprs: &[TypedExpr],
207        input_arity: usize,
208    ) -> Result<KeyValPlan, Error> {
209        let group_expr_val = group_exprs
210            .iter()
211            .cloned()
212            .map(|expr| expr.expr.clone())
213            .collect_vec();
214        let output_arity = group_expr_val.len();
215        let key_plan = MapFilterProject::new(input_arity)
216            .map(group_expr_val)?
217            .project(input_arity..input_arity + output_arity)?;
218
219        // val_plan is extracted from aggr_exprs to give aggr function it's necessary input
220        // and since aggr func need inputs that is column ref, we just add a prefix mfp to transform any expr that is not into a column ref
221        let val_plan = {
222            let need_mfp = aggr_exprs.iter().any(|agg| agg.expr.as_column().is_none());
223            if need_mfp {
224                // create mfp from aggr_expr, and modify aggr_expr to use the output column of mfp
225                let input_exprs = aggr_exprs
226                    .iter_mut()
227                    .enumerate()
228                    .map(|(idx, aggr)| {
229                        let ret = aggr.expr.clone();
230                        aggr.expr = ScalarExpr::Column(idx);
231                        ret
232                    })
233                    .collect_vec();
234                let aggr_arity = aggr_exprs.len();
235
236                MapFilterProject::new(input_arity)
237                    .map(input_exprs)?
238                    .project(input_arity..input_arity + aggr_arity)?
239            } else {
240                // simply take all inputs as value
241                MapFilterProject::new(input_arity)
242            }
243        };
244        Ok(KeyValPlan {
245            key_plan: key_plan.into_safe(),
246            val_plan: val_plan.into_safe(),
247        })
248    }
249}
250
251/// find out the column that should be time index in group exprs(which is all columns that should be keys)
252/// TODO(discord9): better ways to assign time index
253/// for now, it will found the first column that is timestamp or has a tumble window floor function
254fn find_time_index_in_group_exprs(group_exprs: &[TypedExpr]) -> Option<usize> {
255    group_exprs.iter().position(|expr| {
256        matches!(
257            &expr.expr,
258            ScalarExpr::CallUnary {
259                func: UnaryFunc::TumbleWindowFloor { .. },
260                expr: _
261            }
262        ) || expr.typ.scalar_type.is_timestamp()
263    })
264}
265
266impl TypedPlan {
267    /// Convert AggregateRel into Flow's TypedPlan
268    ///
269    /// The output of aggr plan is:
270    ///
271    /// <group_exprs>..<aggr_exprs>
272    #[async_recursion::async_recursion]
273    pub async fn from_substrait_agg_rel(
274        ctx: &mut FlownodeContext,
275        agg: &proto::AggregateRel,
276        extensions: &FunctionExtensions,
277    ) -> Result<TypedPlan, Error> {
278        let input = if let Some(input) = agg.input.as_ref() {
279            TypedPlan::from_substrait_rel(ctx, input, extensions).await?
280        } else {
281            return not_impl_err!("Aggregate without an input is not supported");
282        };
283
284        let group_exprs = TypedExpr::from_substrait_agg_grouping(
285            ctx,
286            &agg.grouping_expressions,
287            &agg.groupings,
288            &input.schema,
289            extensions,
290        )
291        .await?;
292
293        let time_index = find_time_index_in_group_exprs(&group_exprs);
294
295        let mut aggr_exprs = AggregateExpr::from_substrait_agg_measures(
296            ctx,
297            &agg.measures,
298            &input.schema,
299            extensions,
300        )
301        .await?;
302
303        let key_val_plan = KeyValPlan::from_substrait_gen_key_val_plan(
304            &mut aggr_exprs,
305            &group_exprs,
306            input.schema.typ.column_types.len(),
307        )?;
308
309        // output type is group_exprs + aggr_exprs
310        let output_type = {
311            let mut output_types = Vec::new();
312            // give best effort to get column name
313            let mut output_names = Vec::new();
314
315            // first append group_expr as key, then aggr_expr as value
316            for expr in group_exprs.iter() {
317                output_types.push(expr.typ.clone());
318                let col_name = match &expr.expr {
319                    ScalarExpr::Column(col) => input.schema.get_name(*col).clone(),
320                    // TODO(discord9): impl& use ScalarExpr.display_name, which recursively build expr's name
321                    _ => None,
322                };
323                output_names.push(col_name)
324            }
325
326            for aggr in &aggr_exprs {
327                output_types.push(ColumnType::new_nullable(
328                    aggr.func.signature().output.clone(),
329                ));
330                // TODO(discord9): find a clever way to name them?
331                output_names.push(None);
332            }
333            // TODO(discord9): try best to get time
334            if group_exprs.is_empty() {
335                RelationType::new(output_types)
336            } else {
337                RelationType::new(output_types).with_key((0..group_exprs.len()).collect_vec())
338            }
339            .with_time_index(time_index)
340            .into_named(output_names)
341        };
342
343        // copy aggr_exprs to full_aggrs, and split them into simple_aggrs and distinct_aggrs
344        // also set them input/output column
345        let full_aggrs = aggr_exprs;
346        let mut simple_aggrs = Vec::new();
347        let mut distinct_aggrs = Vec::new();
348        for (output_column, aggr_expr) in full_aggrs.iter().enumerate() {
349            let input_column = aggr_expr.expr.as_column().with_context(|| PlanSnafu {
350                reason: "Expect aggregate argument to be transformed into a column at this point",
351            })?;
352            if aggr_expr.distinct {
353                distinct_aggrs.push(AggrWithIndex::new(
354                    aggr_expr.clone(),
355                    input_column,
356                    output_column,
357                ));
358            } else {
359                simple_aggrs.push(AggrWithIndex::new(
360                    aggr_expr.clone(),
361                    input_column,
362                    output_column,
363                ));
364            }
365        }
366        let accum_plan = AccumulablePlan {
367            full_aggrs,
368            simple_aggrs,
369            distinct_aggrs,
370        };
371        let plan = Plan::Reduce {
372            input: Box::new(input),
373            key_val_plan,
374            reduce_plan: ReducePlan::Accumulable(accum_plan),
375        };
376        // FIX(discord9): deal with key first
377        return Ok(TypedPlan {
378            schema: output_type,
379            plan,
380        });
381    }
382}
383
384#[cfg(test)]
385mod test {
386    use std::time::Duration;
387
388    use bytes::BytesMut;
389    use common_time::{IntervalMonthDayNano, Timestamp};
390    use datatypes::data_type::ConcreteDataType as CDT;
391    use datatypes::prelude::ConcreteDataType;
392    use datatypes::value::Value;
393    use pretty_assertions::assert_eq;
394
395    use super::*;
396    use crate::expr::{BinaryFunc, DfScalarFunction, GlobalId, RawDfScalarFn};
397    use crate::plan::{Plan, TypedPlan};
398    use crate::repr::{ColumnType, RelationType};
399    use crate::transform::test::{create_test_ctx, create_test_query_engine, sql_to_substrait};
400
401    #[tokio::test]
402    async fn test_df_func_basic() {
403        let engine = create_test_query_engine();
404        let sql = "SELECT sum(abs(number)) FROM numbers_with_ts GROUP BY tumble(ts, '1 second', '2021-07-01 00:00:00');";
405        let plan = sql_to_substrait(engine.clone(), sql).await;
406
407        let mut ctx = create_test_ctx();
408        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
409            .await
410            .unwrap();
411
412        let aggr_expr = AggregateExpr {
413            func: AggregateFunc::SumUInt64,
414            expr: ScalarExpr::Column(0),
415            distinct: false,
416        };
417        let expected =
418            TypedPlan {
419                schema: RelationType::new(vec![
420                    ColumnType::new(CDT::uint64_datatype(), true), // sum(number)
421                    ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
422                    ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
423                ])
424                .with_key(vec![2])
425                .with_time_index(Some(1))
426                .into_named(vec![
427                    Some("sum(abs(numbers_with_ts.number))".to_string()),
428                    Some("window_start".to_string()),
429                    Some("window_end".to_string()),
430                ]),
431                plan: Plan::Mfp {
432                    input: Box::new(
433                        Plan::Reduce {
434                            input: Box::new(
435                                Plan::Get {
436                                    id: crate::expr::Id::Global(GlobalId::User(1)),
437                                }
438                                .with_types(
439                                    RelationType::new(vec![
440                                        ColumnType::new(ConcreteDataType::uint32_datatype(), false),
441                                        ColumnType::new(
442                                            ConcreteDataType::timestamp_millisecond_datatype(),
443                                            false,
444                                        ),
445                                    ])
446                                    .into_named(vec![
447                                        Some("number".to_string()),
448                                        Some("ts".to_string()),
449                                    ]),
450                                )
451                                .mfp(MapFilterProject::new(2).into_safe())
452                                .unwrap(),
453                            ),
454                            key_val_plan: KeyValPlan {
455                                key_plan: MapFilterProject::new(2)
456                                    .map(vec![
457                                        ScalarExpr::Column(1).call_unary(
458                                            UnaryFunc::TumbleWindowFloor {
459                                                window_size: Duration::from_nanos(1_000_000_000),
460                                                start_time: Some(Timestamp::new_millisecond(
461                                                    1625097600000,
462                                                )),
463                                            },
464                                        ),
465                                        ScalarExpr::Column(1).call_unary(
466                                            UnaryFunc::TumbleWindowCeiling {
467                                                window_size: Duration::from_nanos(1_000_000_000),
468                                                start_time: Some(Timestamp::new_millisecond(
469                                                    1625097600000,
470                                                )),
471                                            },
472                                        ),
473                                    ])
474                                    .unwrap()
475                                    .project(vec![2, 3])
476                                    .unwrap()
477                                    .into_safe(),
478                                val_plan: MapFilterProject::new(2)
479                                    .map(vec![ScalarExpr::CallDf {
480                                    df_scalar_fn: DfScalarFunction::try_from_raw_fn(
481                                        RawDfScalarFn {
482                                            f: BytesMut::from(
483                                                b"\x08\x02\"\x08\x1a\x06\x12\x04\n\x02\x12\0"
484                                                    .as_ref(),
485                                            ),
486                                            input_schema: RelationType::new(vec![ColumnType::new(
487                                                ConcreteDataType::uint32_datatype(),
488                                                false,
489                                            )])
490                                            .into_unnamed(),
491                                            extensions: FunctionExtensions::from_iter(
492                                                [
493                                                    (0, "tumble_start".to_string()),
494                                                    (1, "tumble_end".to_string()),
495                                                    (2, "abs".to_string()),
496                                                    (3, "sum".to_string()),
497                                                ]
498                                                .into_iter(),
499                                            ),
500                                        },
501                                    )
502                                    .await
503                                    .unwrap(),
504                                    exprs: vec![ScalarExpr::Column(0)],
505                                }
506                                .cast(CDT::uint64_datatype())])
507                                    .unwrap()
508                                    .project(vec![2])
509                                    .unwrap()
510                                    .into_safe(),
511                            },
512                            reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
513                                full_aggrs: vec![aggr_expr.clone()],
514                                simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
515                                distinct_aggrs: vec![],
516                            }),
517                        }
518                        .with_types(
519                            RelationType::new(vec![
520                                ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
521                                ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
522                                ColumnType::new(CDT::uint64_datatype(), true), //sum(number)
523                            ])
524                            .with_key(vec![1])
525                            .with_time_index(Some(0))
526                            .into_unnamed(),
527                        ),
528                    ),
529                    mfp: MapFilterProject::new(3)
530                        .map(vec![
531                            ScalarExpr::Column(2),
532                            ScalarExpr::Column(0),
533                            ScalarExpr::Column(1),
534                        ])
535                        .unwrap()
536                        .project(vec![3, 4, 5])
537                        .unwrap(),
538                },
539            };
540        assert_eq!(flow_plan, expected);
541    }
542
543    #[tokio::test]
544    async fn test_df_func_expr_tree() {
545        let engine = create_test_query_engine();
546        let sql = "SELECT abs(sum(number)) FROM numbers_with_ts GROUP BY tumble(ts, '1 second', '2021-07-01 00:00:00');";
547        let plan = sql_to_substrait(engine.clone(), sql).await;
548
549        let mut ctx = create_test_ctx();
550        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
551            .await
552            .unwrap();
553
554        let aggr_expr = AggregateExpr {
555            func: AggregateFunc::SumUInt64,
556            expr: ScalarExpr::Column(0),
557            distinct: false,
558        };
559        let expected = TypedPlan {
560            schema: RelationType::new(vec![
561                ColumnType::new(CDT::uint64_datatype(), true), // sum(number)
562                ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
563                ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
564            ])
565            .with_key(vec![2])
566            .with_time_index(Some(1))
567            .into_named(vec![
568                Some("abs(sum(numbers_with_ts.number))".to_string()),
569                Some("window_start".to_string()),
570                Some("window_end".to_string()),
571            ]),
572            plan: Plan::Mfp {
573                input: Box::new(
574                    Plan::Reduce {
575                        input: Box::new(
576                            Plan::Get {
577                                id: crate::expr::Id::Global(GlobalId::User(1)),
578                            }
579                            .with_types(
580                                RelationType::new(vec![
581                                    ColumnType::new(ConcreteDataType::uint32_datatype(), false),
582                                    ColumnType::new(
583                                        ConcreteDataType::timestamp_millisecond_datatype(),
584                                        false,
585                                    ),
586                                ])
587                                .into_named(vec![
588                                    Some("number".to_string()),
589                                    Some("ts".to_string()),
590                                ]),
591                            )
592                            .mfp(MapFilterProject::new(2).into_safe())
593                            .unwrap(),
594                        ),
595                        key_val_plan: KeyValPlan {
596                            key_plan: MapFilterProject::new(2)
597                                .map(vec![
598                                    ScalarExpr::Column(1).call_unary(
599                                        UnaryFunc::TumbleWindowFloor {
600                                            window_size: Duration::from_nanos(1_000_000_000),
601                                            start_time: Some(Timestamp::new_millisecond(
602                                                1625097600000,
603                                            )),
604                                        },
605                                    ),
606                                    ScalarExpr::Column(1).call_unary(
607                                        UnaryFunc::TumbleWindowCeiling {
608                                            window_size: Duration::from_nanos(1_000_000_000),
609                                            start_time: Some(Timestamp::new_millisecond(
610                                                1625097600000,
611                                            )),
612                                        },
613                                    ),
614                                ])
615                                .unwrap()
616                                .project(vec![2, 3])
617                                .unwrap()
618                                .into_safe(),
619                            val_plan: MapFilterProject::new(2)
620                                .map(vec![ScalarExpr::Column(0).cast(CDT::uint64_datatype())])
621                                .unwrap()
622                                .project(vec![2])
623                                .unwrap()
624                                .into_safe(),
625                        },
626                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
627                            full_aggrs: vec![aggr_expr.clone()],
628                            simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
629                            distinct_aggrs: vec![],
630                        }),
631                    }
632                    .with_types(
633                        RelationType::new(vec![
634                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
635                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
636                            ColumnType::new(CDT::uint64_datatype(), true), //sum(number)
637                        ])
638                        .with_key(vec![1])
639                        .with_time_index(Some(0))
640                        .into_named(vec![None, None, None]),
641                    ),
642                ),
643                mfp: MapFilterProject::new(3)
644                    .map(vec![
645                        ScalarExpr::CallDf {
646                            df_scalar_fn: DfScalarFunction::try_from_raw_fn(RawDfScalarFn {
647                                f: BytesMut::from(b"\"\x08\x1a\x06\x12\x04\n\x02\x12\0".as_ref()),
648                                input_schema: RelationType::new(vec![ColumnType::new(
649                                    ConcreteDataType::uint64_datatype(),
650                                    true,
651                                )])
652                                .into_unnamed(),
653                                extensions: FunctionExtensions::from_iter(
654                                    [
655                                        (0, "abs".to_string()),
656                                        (1, "tumble_start".to_string()),
657                                        (2, "tumble_end".to_string()),
658                                        (3, "sum".to_string()),
659                                    ]
660                                    .into_iter(),
661                                ),
662                            })
663                            .await
664                            .unwrap(),
665                            exprs: vec![ScalarExpr::Column(2)],
666                        },
667                        ScalarExpr::Column(0),
668                        ScalarExpr::Column(1),
669                    ])
670                    .unwrap()
671                    .project(vec![3, 4, 5])
672                    .unwrap(),
673            },
674        };
675        assert_eq!(flow_plan, expected);
676    }
677
678    /// TODO(discord9): add more illegal sql tests
679    #[tokio::test]
680    async fn test_tumble_composite() {
681        let engine = create_test_query_engine();
682        let sql =
683            "SELECT number, avg(number) FROM numbers_with_ts GROUP BY tumble(ts, '1 hour'), number";
684        let plan = sql_to_substrait(engine.clone(), sql).await;
685
686        let mut ctx = create_test_ctx();
687        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
688            .await
689            .unwrap();
690
691        let aggr_exprs = vec![
692            AggregateExpr {
693                func: AggregateFunc::SumUInt64,
694                expr: ScalarExpr::Column(0),
695                distinct: false,
696            },
697            AggregateExpr {
698                func: AggregateFunc::Count,
699                expr: ScalarExpr::Column(1),
700                distinct: false,
701            },
702        ];
703        let avg_expr = ScalarExpr::If {
704            cond: Box::new(ScalarExpr::Column(4).call_binary(
705                ScalarExpr::Literal(Value::from(0i64), CDT::int64_datatype()),
706                BinaryFunc::NotEq,
707            )),
708            then: Box::new(
709                ScalarExpr::Column(3)
710                    .cast(CDT::float64_datatype())
711                    .call_binary(
712                        ScalarExpr::Column(4).cast(CDT::float64_datatype()),
713                        BinaryFunc::DivFloat64,
714                    ),
715            ),
716            els: Box::new(ScalarExpr::Literal(Value::Null, CDT::float64_datatype())),
717        };
718        let expected = TypedPlan {
719            plan: Plan::Mfp {
720                input: Box::new(
721                    Plan::Reduce {
722                        input: Box::new(
723                            Plan::Get {
724                                id: crate::expr::Id::Global(GlobalId::User(1)),
725                            }
726                            .with_types(
727                                RelationType::new(vec![
728                                    ColumnType::new(ConcreteDataType::uint32_datatype(), false),
729                                    ColumnType::new(
730                                        ConcreteDataType::timestamp_millisecond_datatype(),
731                                        false,
732                                    ),
733                                ])
734                                .into_named(vec![
735                                    Some("number".to_string()),
736                                    Some("ts".to_string()),
737                                ]),
738                            )
739                            .mfp(MapFilterProject::new(2).into_safe())
740                            .unwrap(),
741                        ),
742                        key_val_plan: KeyValPlan {
743                            key_plan: MapFilterProject::new(2)
744                                .map(vec![
745                                    ScalarExpr::Column(1).call_unary(
746                                        UnaryFunc::TumbleWindowFloor {
747                                            window_size: Duration::from_nanos(3_600_000_000_000),
748                                            start_time: None,
749                                        },
750                                    ),
751                                    ScalarExpr::Column(1).call_unary(
752                                        UnaryFunc::TumbleWindowCeiling {
753                                            window_size: Duration::from_nanos(3_600_000_000_000),
754                                            start_time: None,
755                                        },
756                                    ),
757                                    ScalarExpr::Column(0),
758                                ])
759                                .unwrap()
760                                .project(vec![2, 3, 4])
761                                .unwrap()
762                                .into_safe(),
763                            val_plan: MapFilterProject::new(2)
764                                .map(vec![
765                                    ScalarExpr::Column(0).cast(CDT::uint64_datatype()),
766                                    ScalarExpr::Column(0),
767                                ])
768                                .unwrap()
769                                .project(vec![2, 3])
770                                .unwrap()
771                                .into_safe(),
772                        },
773                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
774                            full_aggrs: aggr_exprs.clone(),
775                            simple_aggrs: vec![
776                                AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0),
777                                AggrWithIndex::new(aggr_exprs[1].clone(), 1, 1),
778                            ],
779                            distinct_aggrs: vec![],
780                        }),
781                    }
782                    .with_types(
783                        RelationType::new(vec![
784                            // keys
785                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start(time index)
786                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end(pk)
787                            ColumnType::new(CDT::uint32_datatype(), false), // number(pk)
788                            // values
789                            ColumnType::new(CDT::uint64_datatype(), true), // avg.sum(number)
790                            ColumnType::new(CDT::int64_datatype(), true),  // avg.count(number)
791                        ])
792                        .with_key(vec![1, 2])
793                        .with_time_index(Some(0))
794                        .into_named(vec![
795                            None,
796                            None,
797                            Some("number".to_string()),
798                            None,
799                            None,
800                        ]),
801                    ),
802                ),
803                mfp: MapFilterProject::new(5)
804                    .map(vec![
805                        ScalarExpr::Column(2), // number(pk)
806                        avg_expr,
807                        ScalarExpr::Column(0), // window start
808                        ScalarExpr::Column(1), // window end
809                    ])
810                    .unwrap()
811                    .project(vec![5, 6, 7, 8])
812                    .unwrap(),
813            },
814            schema: RelationType::new(vec![
815                ColumnType::new(CDT::uint32_datatype(), false), // number
816                ColumnType::new(CDT::float64_datatype(), true), // avg(number)
817                ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
818                ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
819            ])
820            .with_key(vec![0, 3])
821            .with_time_index(Some(2))
822            .into_named(vec![
823                Some("number".to_string()),
824                Some("avg(numbers_with_ts.number)".to_string()),
825                Some("window_start".to_string()),
826                Some("window_end".to_string()),
827            ]),
828        };
829        assert_eq!(flow_plan, expected);
830    }
831
832    #[tokio::test]
833    async fn test_tumble_parse_optional() {
834        let engine = create_test_query_engine();
835        let sql = "SELECT sum(number) FROM numbers_with_ts GROUP BY tumble(ts, '1 hour')";
836        let plan = sql_to_substrait(engine.clone(), sql).await;
837
838        let mut ctx = create_test_ctx();
839        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
840            .await
841            .unwrap();
842
843        let aggr_expr = AggregateExpr {
844            func: AggregateFunc::SumUInt64,
845            expr: ScalarExpr::Column(0),
846            distinct: false,
847        };
848        let expected = TypedPlan {
849            schema: RelationType::new(vec![
850                ColumnType::new(CDT::uint64_datatype(), true), // sum(number)
851                ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
852                ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
853            ])
854            .with_key(vec![2])
855            .with_time_index(Some(1))
856            .into_named(vec![
857                Some("sum(numbers_with_ts.number)".to_string()),
858                Some("window_start".to_string()),
859                Some("window_end".to_string()),
860            ]),
861            plan: Plan::Mfp {
862                input: Box::new(
863                    Plan::Reduce {
864                        input: Box::new(
865                            Plan::Get {
866                                id: crate::expr::Id::Global(GlobalId::User(1)),
867                            }
868                            .with_types(
869                                RelationType::new(vec![
870                                    ColumnType::new(ConcreteDataType::uint32_datatype(), false),
871                                    ColumnType::new(
872                                        ConcreteDataType::timestamp_millisecond_datatype(),
873                                        false,
874                                    ),
875                                ])
876                                .into_named(vec![
877                                    Some("number".to_string()),
878                                    Some("ts".to_string()),
879                                ]),
880                            )
881                            .mfp(MapFilterProject::new(2).into_safe())
882                            .unwrap(),
883                        ),
884                        key_val_plan: KeyValPlan {
885                            key_plan: MapFilterProject::new(2)
886                                .map(vec![
887                                    ScalarExpr::Column(1).call_unary(
888                                        UnaryFunc::TumbleWindowFloor {
889                                            window_size: Duration::from_nanos(3_600_000_000_000),
890                                            start_time: None,
891                                        },
892                                    ),
893                                    ScalarExpr::Column(1).call_unary(
894                                        UnaryFunc::TumbleWindowCeiling {
895                                            window_size: Duration::from_nanos(3_600_000_000_000),
896                                            start_time: None,
897                                        },
898                                    ),
899                                ])
900                                .unwrap()
901                                .project(vec![2, 3])
902                                .unwrap()
903                                .into_safe(),
904                            val_plan: MapFilterProject::new(2)
905                                .map(vec![ScalarExpr::Column(0).cast(CDT::uint64_datatype())])
906                                .unwrap()
907                                .project(vec![2])
908                                .unwrap()
909                                .into_safe(),
910                        },
911                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
912                            full_aggrs: vec![aggr_expr.clone()],
913                            simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
914                            distinct_aggrs: vec![],
915                        }),
916                    }
917                    .with_types(
918                        RelationType::new(vec![
919                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
920                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
921                            ColumnType::new(CDT::uint64_datatype(), true), //sum(number)
922                        ])
923                        .with_key(vec![1])
924                        .with_time_index(Some(0))
925                        .into_named(vec![None, None, None]),
926                    ),
927                ),
928                mfp: MapFilterProject::new(3)
929                    .map(vec![
930                        ScalarExpr::Column(2),
931                        ScalarExpr::Column(0),
932                        ScalarExpr::Column(1),
933                    ])
934                    .unwrap()
935                    .project(vec![3, 4, 5])
936                    .unwrap(),
937            },
938        };
939        assert_eq!(flow_plan, expected);
940    }
941
942    #[tokio::test]
943    async fn test_tumble_parse() {
944        let engine = create_test_query_engine();
945        let sql = "SELECT sum(number) FROM numbers_with_ts GROUP BY tumble(ts, '1 hour', '2021-07-01 00:00:00')";
946        let plan = sql_to_substrait(engine.clone(), sql).await;
947
948        let mut ctx = create_test_ctx();
949        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
950            .await
951            .unwrap();
952
953        let aggr_expr = AggregateExpr {
954            func: AggregateFunc::SumUInt64,
955            expr: ScalarExpr::Column(0),
956            distinct: false,
957        };
958        let expected = TypedPlan {
959            schema: RelationType::new(vec![
960                ColumnType::new(CDT::uint64_datatype(), true), // sum(number)
961                ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
962                ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
963            ])
964            .with_key(vec![2])
965            .with_time_index(Some(1))
966            .into_named(vec![
967                Some("sum(numbers_with_ts.number)".to_string()),
968                Some("window_start".to_string()),
969                Some("window_end".to_string()),
970            ]),
971            plan: Plan::Mfp {
972                input: Box::new(
973                    Plan::Reduce {
974                        input: Box::new(
975                            Plan::Get {
976                                id: crate::expr::Id::Global(GlobalId::User(1)),
977                            }
978                            .with_types(
979                                RelationType::new(vec![
980                                    ColumnType::new(ConcreteDataType::uint32_datatype(), false),
981                                    ColumnType::new(
982                                        ConcreteDataType::timestamp_millisecond_datatype(),
983                                        false,
984                                    ),
985                                ])
986                                .into_named(vec![
987                                    Some("number".to_string()),
988                                    Some("ts".to_string()),
989                                ]),
990                            )
991                            .mfp(MapFilterProject::new(2).into_safe())
992                            .unwrap(),
993                        ),
994                        key_val_plan: KeyValPlan {
995                            key_plan: MapFilterProject::new(2)
996                                .map(vec![
997                                    ScalarExpr::Column(1).call_unary(
998                                        UnaryFunc::TumbleWindowFloor {
999                                            window_size: Duration::from_nanos(3_600_000_000_000),
1000                                            start_time: Some(Timestamp::new_millisecond(
1001                                                1625097600000,
1002                                            )),
1003                                        },
1004                                    ),
1005                                    ScalarExpr::Column(1).call_unary(
1006                                        UnaryFunc::TumbleWindowCeiling {
1007                                            window_size: Duration::from_nanos(3_600_000_000_000),
1008                                            start_time: Some(Timestamp::new_millisecond(
1009                                                1625097600000,
1010                                            )),
1011                                        },
1012                                    ),
1013                                ])
1014                                .unwrap()
1015                                .project(vec![2, 3])
1016                                .unwrap()
1017                                .into_safe(),
1018                            val_plan: MapFilterProject::new(2)
1019                                .map(vec![ScalarExpr::Column(0).cast(CDT::uint64_datatype())])
1020                                .unwrap()
1021                                .project(vec![2])
1022                                .unwrap()
1023                                .into_safe(),
1024                        },
1025                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
1026                            full_aggrs: vec![aggr_expr.clone()],
1027                            simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
1028                            distinct_aggrs: vec![],
1029                        }),
1030                    }
1031                    .with_types(
1032                        RelationType::new(vec![
1033                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
1034                            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
1035                            ColumnType::new(CDT::uint64_datatype(), true), //sum(number)
1036                        ])
1037                        .with_key(vec![1])
1038                        .with_time_index(Some(0))
1039                        .into_unnamed(),
1040                    ),
1041                ),
1042                mfp: MapFilterProject::new(3)
1043                    .map(vec![
1044                        ScalarExpr::Column(2),
1045                        ScalarExpr::Column(0),
1046                        ScalarExpr::Column(1),
1047                    ])
1048                    .unwrap()
1049                    .project(vec![3, 4, 5])
1050                    .unwrap(),
1051            },
1052        };
1053        assert_eq!(flow_plan, expected);
1054    }
1055
1056    #[tokio::test]
1057    async fn test_avg_group_by() {
1058        let engine = create_test_query_engine();
1059        let sql = "SELECT avg(number), number FROM numbers GROUP BY number";
1060        let plan = sql_to_substrait(engine.clone(), sql).await;
1061
1062        let mut ctx = create_test_ctx();
1063        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;
1064
1065        let aggr_exprs = vec![
1066            AggregateExpr {
1067                func: AggregateFunc::SumUInt64,
1068                expr: ScalarExpr::Column(0),
1069                distinct: false,
1070            },
1071            AggregateExpr {
1072                func: AggregateFunc::Count,
1073                expr: ScalarExpr::Column(1),
1074                distinct: false,
1075            },
1076        ];
1077        let avg_expr = ScalarExpr::If {
1078            cond: Box::new(ScalarExpr::Column(2).call_binary(
1079                ScalarExpr::Literal(Value::from(0i64), CDT::int64_datatype()),
1080                BinaryFunc::NotEq,
1081            )),
1082            then: Box::new(
1083                ScalarExpr::Column(1)
1084                    .cast(CDT::float64_datatype())
1085                    .call_binary(
1086                        ScalarExpr::Column(2).cast(CDT::float64_datatype()),
1087                        BinaryFunc::DivFloat64,
1088                    ),
1089            ),
1090            els: Box::new(ScalarExpr::Literal(Value::Null, CDT::float64_datatype())),
1091        };
1092        let expected = TypedPlan {
1093            schema: RelationType::new(vec![
1094                ColumnType::new(CDT::float64_datatype(), true), // avg(number: u32) -> f64
1095                ColumnType::new(CDT::uint32_datatype(), false), // number
1096            ])
1097            .with_key(vec![1])
1098            .into_named(vec![
1099                Some("avg(numbers.number)".to_string()),
1100                Some("number".to_string()),
1101            ]),
1102            plan: Plan::Mfp {
1103                input: Box::new(
1104                    Plan::Reduce {
1105                        input: Box::new(
1106                            Plan::Get {
1107                                id: crate::expr::Id::Global(GlobalId::User(0)),
1108                            }
1109                            .with_types(
1110                                RelationType::new(vec![ColumnType::new(
1111                                    ConcreteDataType::uint32_datatype(),
1112                                    false,
1113                                )])
1114                                .into_named(vec![Some("number".to_string())]),
1115                            )
1116                            .mfp(
1117                                MapFilterProject::new(1)
1118                                    .project(vec![0])
1119                                    .unwrap()
1120                                    .into_safe(),
1121                            )
1122                            .unwrap(),
1123                        ),
1124                        key_val_plan: KeyValPlan {
1125                            key_plan: MapFilterProject::new(1)
1126                                .map(vec![ScalarExpr::Column(0)])
1127                                .unwrap()
1128                                .project(vec![1])
1129                                .unwrap()
1130                                .into_safe(),
1131                            val_plan: MapFilterProject::new(1)
1132                                .map(vec![
1133                                    ScalarExpr::Column(0).cast(CDT::uint64_datatype()),
1134                                    ScalarExpr::Column(0),
1135                                ])
1136                                .unwrap()
1137                                .project(vec![1, 2])
1138                                .unwrap()
1139                                .into_safe(),
1140                        },
1141                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
1142                            full_aggrs: aggr_exprs.clone(),
1143                            simple_aggrs: vec![
1144                                AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0),
1145                                AggrWithIndex::new(aggr_exprs[1].clone(), 1, 1),
1146                            ],
1147                            distinct_aggrs: vec![],
1148                        }),
1149                    }
1150                    .with_types(
1151                        RelationType::new(vec![
1152                            ColumnType::new(ConcreteDataType::uint32_datatype(), false), // key: number
1153                            ColumnType::new(ConcreteDataType::uint64_datatype(), true),  // sum
1154                            ColumnType::new(ConcreteDataType::int64_datatype(), true),   // count
1155                        ])
1156                        .with_key(vec![0])
1157                        .into_named(vec![
1158                            Some("number".to_string()),
1159                            None,
1160                            None,
1161                        ]),
1162                    ),
1163                ),
1164                mfp: MapFilterProject::new(3)
1165                    .map(vec![
1166                        avg_expr, // col 3
1167                        ScalarExpr::Column(0),
1168                        // TODO(discord9): optimize mfp so to remove indirect ref
1169                    ])
1170                    .unwrap()
1171                    .project(vec![3, 4])
1172                    .unwrap(),
1173            },
1174        };
1175        assert_eq!(flow_plan.unwrap(), expected);
1176    }
1177
1178    #[tokio::test]
1179    async fn test_avg() {
1180        let engine = create_test_query_engine();
1181        let sql = "SELECT avg(number) FROM numbers";
1182        let plan = sql_to_substrait(engine.clone(), sql).await;
1183
1184        let mut ctx = create_test_ctx();
1185
1186        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
1187            .await
1188            .unwrap();
1189
1190        let aggr_exprs = vec![
1191            AggregateExpr {
1192                func: AggregateFunc::SumUInt64,
1193                expr: ScalarExpr::Column(0),
1194                distinct: false,
1195            },
1196            AggregateExpr {
1197                func: AggregateFunc::Count,
1198                expr: ScalarExpr::Column(1),
1199                distinct: false,
1200            },
1201        ];
1202        let avg_expr = ScalarExpr::If {
1203            cond: Box::new(ScalarExpr::Column(1).call_binary(
1204                ScalarExpr::Literal(Value::from(0i64), CDT::int64_datatype()),
1205                BinaryFunc::NotEq,
1206            )),
1207            then: Box::new(
1208                ScalarExpr::Column(0)
1209                    .cast(CDT::float64_datatype())
1210                    .call_binary(
1211                        ScalarExpr::Column(1).cast(CDT::float64_datatype()),
1212                        BinaryFunc::DivFloat64,
1213                    ),
1214            ),
1215            els: Box::new(ScalarExpr::Literal(Value::Null, CDT::float64_datatype())),
1216        };
1217        let input = Box::new(
1218            Plan::Get {
1219                id: crate::expr::Id::Global(GlobalId::User(0)),
1220            }
1221            .with_types(
1222                RelationType::new(vec![ColumnType::new(
1223                    ConcreteDataType::uint32_datatype(),
1224                    false,
1225                )])
1226                .into_named(vec![Some("number".to_string())]),
1227            ),
1228        );
1229        let expected = TypedPlan {
1230            schema: RelationType::new(vec![ColumnType::new(CDT::float64_datatype(), true)])
1231                .into_named(vec![Some("avg(numbers.number)".to_string())]),
1232            plan: Plan::Mfp {
1233                input: Box::new(
1234                    Plan::Reduce {
1235                        input: Box::new(
1236                            Plan::Mfp {
1237                                input: input.clone(),
1238                                mfp: MapFilterProject::new(1).project(vec![0]).unwrap(),
1239                            }
1240                            .with_types(
1241                                RelationType::new(vec![ColumnType::new(
1242                                    CDT::uint32_datatype(),
1243                                    false,
1244                                )])
1245                                .into_named(vec![Some("number".to_string())]),
1246                            ),
1247                        ),
1248                        key_val_plan: KeyValPlan {
1249                            key_plan: MapFilterProject::new(1)
1250                                .project(vec![])
1251                                .unwrap()
1252                                .into_safe(),
1253                            val_plan: MapFilterProject::new(1)
1254                                .map(vec![
1255                                    ScalarExpr::Column(0).cast(CDT::uint64_datatype()),
1256                                    ScalarExpr::Column(0),
1257                                ])
1258                                .unwrap()
1259                                .project(vec![1, 2])
1260                                .unwrap()
1261                                .into_safe(),
1262                        },
1263                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
1264                            full_aggrs: aggr_exprs.clone(),
1265                            simple_aggrs: vec![
1266                                AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0),
1267                                AggrWithIndex::new(aggr_exprs[1].clone(), 1, 1),
1268                            ],
1269                            distinct_aggrs: vec![],
1270                        }),
1271                    }
1272                    .with_types(
1273                        RelationType::new(vec![
1274                            ColumnType::new(ConcreteDataType::uint64_datatype(), true), // sum
1275                            ColumnType::new(ConcreteDataType::int64_datatype(), true),  // count
1276                        ])
1277                        .into_named(vec![None, None]),
1278                    ),
1279                ),
1280                mfp: MapFilterProject::new(2)
1281                    .map(vec![
1282                        avg_expr,
1283                        // TODO(discord9): optimize mfp so to remove indirect ref
1284                    ])
1285                    .unwrap()
1286                    .project(vec![2])
1287                    .unwrap(),
1288            },
1289        };
1290        assert_eq!(flow_plan, expected);
1291    }
1292
1293    #[tokio::test]
1294    async fn test_sum() {
1295        let engine = create_test_query_engine();
1296        let sql = "SELECT sum(number) FROM numbers";
1297        let plan = sql_to_substrait(engine.clone(), sql).await;
1298
1299        let mut ctx = create_test_ctx();
1300        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;
1301
1302        let aggr_expr = AggregateExpr {
1303            func: AggregateFunc::SumUInt64,
1304            expr: ScalarExpr::Column(0),
1305            distinct: false,
1306        };
1307        let expected = TypedPlan {
1308            schema: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)])
1309                .into_named(vec![Some("sum(numbers.number)".to_string())]),
1310            plan: Plan::Reduce {
1311                input: Box::new(
1312                    Plan::Get {
1313                        id: crate::expr::Id::Global(GlobalId::User(0)),
1314                    }
1315                    .with_types(
1316                        RelationType::new(vec![ColumnType::new(
1317                            ConcreteDataType::uint32_datatype(),
1318                            false,
1319                        )])
1320                        .into_named(vec![Some("number".to_string())]),
1321                    )
1322                    .mfp(MapFilterProject::new(1).into_safe())
1323                    .unwrap(),
1324                ),
1325                key_val_plan: KeyValPlan {
1326                    key_plan: MapFilterProject::new(1)
1327                        .project(vec![])
1328                        .unwrap()
1329                        .into_safe(),
1330                    val_plan: MapFilterProject::new(1)
1331                        .map(vec![
1332                            ScalarExpr::Column(0)
1333                                .call_unary(UnaryFunc::Cast(CDT::uint64_datatype())),
1334                        ])
1335                        .unwrap()
1336                        .project(vec![1])
1337                        .unwrap()
1338                        .into_safe(),
1339                },
1340                reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
1341                    full_aggrs: vec![aggr_expr.clone()],
1342                    simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
1343                    distinct_aggrs: vec![],
1344                }),
1345            },
1346        };
1347        assert_eq!(flow_plan.unwrap(), expected);
1348    }
1349
1350    #[tokio::test]
1351    async fn test_distinct_number() {
1352        let engine = create_test_query_engine();
1353        let sql = "SELECT DISTINCT number FROM numbers";
1354        let plan = sql_to_substrait(engine.clone(), sql).await;
1355
1356        let mut ctx = create_test_ctx();
1357        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
1358            .await
1359            .unwrap();
1360
1361        let expected = TypedPlan {
1362            schema: RelationType::new(vec![
1363                ColumnType::new(CDT::uint32_datatype(), false), // col number
1364            ])
1365            .with_key(vec![0])
1366            .into_named(vec![Some("number".to_string())]),
1367            plan: Plan::Reduce {
1368                input: Box::new(
1369                    Plan::Get {
1370                        id: crate::expr::Id::Global(GlobalId::User(0)),
1371                    }
1372                    .with_types(
1373                        RelationType::new(vec![ColumnType::new(
1374                            ConcreteDataType::uint32_datatype(),
1375                            false,
1376                        )])
1377                        .into_named(vec![Some("number".to_string())]),
1378                    )
1379                    .mfp(MapFilterProject::new(1).into_safe())
1380                    .unwrap(),
1381                ),
1382                key_val_plan: KeyValPlan {
1383                    key_plan: MapFilterProject::new(1)
1384                        .map(vec![ScalarExpr::Column(0)])
1385                        .unwrap()
1386                        .project(vec![1])
1387                        .unwrap()
1388                        .into_safe(),
1389                    val_plan: MapFilterProject::new(1)
1390                        .project(vec![0])
1391                        .unwrap()
1392                        .into_safe(),
1393                },
1394                reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
1395                    full_aggrs: vec![],
1396                    simple_aggrs: vec![],
1397                    distinct_aggrs: vec![],
1398                }),
1399            },
1400        };
1401
1402        assert_eq!(flow_plan, expected);
1403    }
1404
1405    #[tokio::test]
1406    async fn test_sum_group_by() {
1407        let engine = create_test_query_engine();
1408        let sql = "SELECT sum(number), number FROM numbers GROUP BY number";
1409        let plan = sql_to_substrait(engine.clone(), sql).await;
1410
1411        let mut ctx = create_test_ctx();
1412        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
1413            .await
1414            .unwrap();
1415
1416        let aggr_expr = AggregateExpr {
1417            func: AggregateFunc::SumUInt64,
1418            expr: ScalarExpr::Column(0),
1419            distinct: false,
1420        };
1421        let expected = TypedPlan {
1422            schema: RelationType::new(vec![
1423                ColumnType::new(CDT::uint64_datatype(), true), // col sum(number)
1424                ColumnType::new(CDT::uint32_datatype(), false), // col number
1425            ])
1426            .with_key(vec![1])
1427            .into_named(vec![
1428                Some("sum(numbers.number)".to_string()),
1429                Some("number".to_string()),
1430            ]),
1431            plan: Plan::Mfp {
1432                input: Box::new(
1433                    Plan::Reduce {
1434                        input: Box::new(
1435                            Plan::Get {
1436                                id: crate::expr::Id::Global(GlobalId::User(0)),
1437                            }
1438                            .with_types(
1439                                RelationType::new(vec![ColumnType::new(
1440                                    ConcreteDataType::uint32_datatype(),
1441                                    false,
1442                                )])
1443                                .into_named(vec![Some("number".to_string())]),
1444                            )
1445                            .mfp(MapFilterProject::new(1).into_safe())
1446                            .unwrap(),
1447                        ),
1448                        key_val_plan: KeyValPlan {
1449                            key_plan: MapFilterProject::new(1)
1450                                .map(vec![ScalarExpr::Column(0)])
1451                                .unwrap()
1452                                .project(vec![1])
1453                                .unwrap()
1454                                .into_safe(),
1455                            val_plan: MapFilterProject::new(1)
1456                                .map(vec![
1457                                    ScalarExpr::Column(0)
1458                                        .call_unary(UnaryFunc::Cast(CDT::uint64_datatype())),
1459                                ])
1460                                .unwrap()
1461                                .project(vec![1])
1462                                .unwrap()
1463                                .into_safe(),
1464                        },
1465                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
1466                            full_aggrs: vec![aggr_expr.clone()],
1467                            simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
1468                            distinct_aggrs: vec![],
1469                        }),
1470                    }
1471                    .with_types(
1472                        RelationType::new(vec![
1473                            ColumnType::new(CDT::uint32_datatype(), false), // col number
1474                            ColumnType::new(CDT::uint64_datatype(), true),  // col sum(number)
1475                        ])
1476                        .with_key(vec![0])
1477                        .into_named(vec![Some("number".to_string()), None]),
1478                    ),
1479                ),
1480                mfp: MapFilterProject::new(2)
1481                    .map(vec![ScalarExpr::Column(1), ScalarExpr::Column(0)])
1482                    .unwrap()
1483                    .project(vec![2, 3])
1484                    .unwrap(),
1485            },
1486        };
1487
1488        assert_eq!(flow_plan, expected);
1489    }
1490
1491    #[tokio::test]
1492    async fn test_sum_add() {
1493        let engine = create_test_query_engine();
1494        let sql = "SELECT sum(number+number) FROM numbers";
1495        let plan = sql_to_substrait(engine.clone(), sql).await;
1496
1497        let mut ctx = create_test_ctx();
1498        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;
1499
1500        let aggr_expr = AggregateExpr {
1501            func: AggregateFunc::SumUInt64,
1502            expr: ScalarExpr::Column(0),
1503            distinct: false,
1504        };
1505        let expected = TypedPlan {
1506            schema: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)])
1507                .into_named(vec![Some(
1508                    "sum(numbers.number + numbers.number)".to_string(),
1509                )]),
1510            plan: Plan::Reduce {
1511                input: Box::new(
1512                    Plan::Mfp {
1513                        input: Box::new(
1514                            Plan::Get {
1515                                id: crate::expr::Id::Global(GlobalId::User(0)),
1516                            }
1517                            .with_types(
1518                                RelationType::new(vec![ColumnType::new(
1519                                    ConcreteDataType::uint32_datatype(),
1520                                    false,
1521                                )])
1522                                .into_named(vec![Some("number".to_string())]),
1523                            ),
1524                        ),
1525                        mfp: MapFilterProject::new(1),
1526                    }
1527                    .with_types(
1528                        RelationType::new(vec![ColumnType::new(
1529                            ConcreteDataType::uint32_datatype(),
1530                            false,
1531                        )])
1532                        .into_named(vec![Some("number".to_string())]),
1533                    ),
1534                ),
1535                key_val_plan: KeyValPlan {
1536                    key_plan: MapFilterProject::new(1)
1537                        .project(vec![])
1538                        .unwrap()
1539                        .into_safe(),
1540                    val_plan: MapFilterProject::new(1)
1541                        .map(vec![
1542                            ScalarExpr::Column(0)
1543                                .call_binary(ScalarExpr::Column(0), BinaryFunc::AddUInt32)
1544                                .call_unary(UnaryFunc::Cast(CDT::uint64_datatype())),
1545                        ])
1546                        .unwrap()
1547                        .project(vec![1])
1548                        .unwrap()
1549                        .into_safe(),
1550                },
1551                reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
1552                    full_aggrs: vec![aggr_expr.clone()],
1553                    simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
1554                    distinct_aggrs: vec![],
1555                }),
1556            },
1557        };
1558        assert_eq!(flow_plan.unwrap(), expected);
1559    }
1560
1561    #[tokio::test]
1562    async fn test_cast_max_min() {
1563        let engine = create_test_query_engine();
1564        let sql = "SELECT (max(number) - min(number))/30.0, date_bin(INTERVAL '30 second', CAST(ts AS TimestampMillisecond)) as time_window from numbers_with_ts GROUP BY time_window";
1565        let plan = sql_to_substrait(engine.clone(), sql).await;
1566
1567        let mut ctx = create_test_ctx();
1568        let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;
1569
1570        let aggr_exprs = vec![
1571            AggregateExpr {
1572                func: AggregateFunc::MaxUInt32,
1573                expr: ScalarExpr::Column(0),
1574                distinct: false,
1575            },
1576            AggregateExpr {
1577                func: AggregateFunc::MinUInt32,
1578                expr: ScalarExpr::Column(0),
1579                distinct: false,
1580            },
1581        ];
1582        let expected = TypedPlan {
1583            schema: RelationType::new(vec![
1584                ColumnType::new(CDT::float64_datatype(), true),
1585                ColumnType::new(CDT::timestamp_millisecond_datatype(), true),
1586            ])
1587            .with_time_index(Some(1))
1588            .into_named(vec![
1589                Some(
1590                    "max(numbers_with_ts.number) - min(numbers_with_ts.number) / Float64(30)"
1591                        .to_string(),
1592                ),
1593                Some("time_window".to_string()),
1594            ]),
1595            plan: Plan::Mfp {
1596                input: Box::new(
1597                    Plan::Reduce {
1598                        input: Box::new(
1599                            Plan::Get {
1600                                id: crate::expr::Id::Global(GlobalId::User(1)),
1601                            }
1602                            .with_types(
1603                                RelationType::new(vec![
1604                                    ColumnType::new(ConcreteDataType::uint32_datatype(), false),
1605                                    ColumnType::new(ConcreteDataType::timestamp_millisecond_datatype(), false),
1606                                ])
1607                                .into_named(vec![
1608                                    Some("number".to_string()),
1609                                    Some("ts".to_string()),
1610                                ]),
1611                            )
1612                            .mfp(MapFilterProject::new(2).into_safe())
1613                            .unwrap(),
1614                        ),
1615
1616                        key_val_plan: KeyValPlan {
1617                            key_plan: MapFilterProject::new(2)
1618                                .map(vec![ScalarExpr::CallDf {
1619                                    df_scalar_fn: DfScalarFunction::try_from_raw_fn(
1620                                        RawDfScalarFn {
1621                                            f: BytesMut::from(
1622                                                b"\x08\x02\"\x0f\x1a\r\n\x0b\xa2\x02\x08\n\0\x12\x04\x10\x1e \t\"\n\x1a\x08\x12\x06\n\x04\x12\x02\x08\x01".as_ref(),
1623                                            ),
1624                                            input_schema: RelationType::new(vec![ColumnType::new(
1625                                                ConcreteDataType::interval_month_day_nano_datatype(),
1626                                                true,
1627                                            ),ColumnType::new(
1628                                                ConcreteDataType::timestamp_millisecond_datatype(),
1629                                                false,
1630                                            )])
1631                                            .into_unnamed(),
1632                                            extensions: FunctionExtensions::from_iter([
1633                                                    (0, "subtract".to_string()),
1634                                                    (1, "divide".to_string()),
1635                                                    (2, "date_bin".to_string()),
1636                                                    (3, "max".to_string()),
1637                                                    (4, "min".to_string()),
1638                                                ]),
1639                                        },
1640                                    )
1641                                    .await
1642                                    .unwrap(),
1643                                    exprs: vec![
1644                                        ScalarExpr::Literal(
1645                                            Value::IntervalMonthDayNano(IntervalMonthDayNano::new(0, 0, 30000000000)),
1646                                            CDT::interval_month_day_nano_datatype()
1647                                        ),
1648                                        ScalarExpr::Column(1)
1649                                        ],
1650                                }])
1651                                .unwrap()
1652                                .project(vec![2])
1653                                .unwrap()
1654                                .into_safe(),
1655                            val_plan: MapFilterProject::new(2)
1656                                .into_safe(),
1657                        },
1658                        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
1659                            full_aggrs: aggr_exprs.clone(),
1660                            simple_aggrs: vec![AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0),
1661                            AggrWithIndex::new(aggr_exprs[1].clone(), 0, 1)],
1662                            distinct_aggrs: vec![],
1663                        }),
1664                    }
1665                    .with_types(
1666                        RelationType::new(vec![
1667                            ColumnType::new(
1668                                ConcreteDataType::timestamp_millisecond_datatype(),
1669                                true,
1670                            ), // time_window
1671                            ColumnType::new(ConcreteDataType::uint32_datatype(), true), // max
1672                            ColumnType::new(ConcreteDataType::uint32_datatype(), true), // min
1673                        ])
1674                        .with_time_index(Some(0))
1675                        .into_unnamed(),
1676                    ),
1677                ),
1678                mfp: MapFilterProject::new(3)
1679                    .map(vec![
1680                        ScalarExpr::Column(1)
1681                            .call_binary(ScalarExpr::Column(2), BinaryFunc::SubUInt32)
1682                            .cast(CDT::float64_datatype())
1683                            .call_binary(
1684                                ScalarExpr::Literal(Value::from(30.0f64), CDT::float64_datatype()),
1685                                BinaryFunc::DivFloat64,
1686                            ),
1687                        ScalarExpr::Column(0),
1688                    ])
1689                    .unwrap()
1690                    .project(vec![3, 4])
1691                    .unwrap(),
1692            },
1693        };
1694
1695        assert_eq!(flow_plan.unwrap(), expected);
1696    }
1697}