flow/expr/relation/
accum.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
//! Accumulators for aggregate functions that are accumulatable, i.e. sum/count.
//!
//! An accumulator is only restored from a row and updated each time the dataflow needs to
//! process a new batch of rows, so the overhead is acceptable.
//!
//! Currently supports sum, count, any, all and min/max (with one caveat: min/max can't support
//! deletes, i.e. non-positive diffs, during aggregation).
//! TODO: think of better ways to not ser/de every time an accum needs to be updated, since it's in a tight loop
22
23use std::any::type_name;
24use std::fmt::Display;
25
26use common_decimal::Decimal128;
27use datatypes::data_type::ConcreteDataType;
28use datatypes::value::{OrderedF32, OrderedF64, OrderedFloat, Value};
29use enum_dispatch::enum_dispatch;
30use serde::{Deserialize, Serialize};
31use snafu::ensure;
32
33use crate::expr::error::{InternalSnafu, OverflowSnafu, TryFromValueSnafu, TypeMismatchSnafu};
34use crate::expr::signature::GenericFn;
35use crate::expr::{AggregateFunc, EvalError};
36use crate::repr::Diff;
37
38/// Accumulates values for the various types of accumulable aggregations.
39#[enum_dispatch]
40pub trait Accumulator: Sized {
41    fn into_state(self) -> Vec<Value>;
42
43    fn update(
44        &mut self,
45        aggr_fn: &AggregateFunc,
46        value: Value,
47        diff: Diff,
48    ) -> Result<(), EvalError>;
49
50    fn update_batch<I>(&mut self, aggr_fn: &AggregateFunc, value_diffs: I) -> Result<(), EvalError>
51    where
52        I: IntoIterator<Item = (Value, Diff)>,
53    {
54        for (v, d) in value_diffs {
55            self.update(aggr_fn, v, d)?;
56        }
57        Ok(())
58    }
59
60    fn eval(&self, aggr_fn: &AggregateFunc) -> Result<Value, EvalError>;
61}
62
/// Bool accumulator, used for `Any`, `All`, `MaxBool` and `MinBool`.
///
/// Only the net counts (sum of `diff`s) of observed `true`/`false` values are kept;
/// `NULL` inputs are not counted. The serialized state is `[trues, falses]`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Bool {
    /// The number of `true` values observed.
    trues: Diff,
    /// The number of `false` values observed.
    falses: Diff,
}
71
72impl Bool {
73    /// Expect two `Diff` type values, one for `true` and one for `false`.
74    pub fn try_from_iter<I>(iter: &mut I) -> Result<Self, EvalError>
75    where
76        I: Iterator<Item = Value>,
77    {
78        Ok(Self {
79            trues: Diff::try_from(iter.next().ok_or_else(fail_accum::<Self>)?)
80                .map_err(err_try_from_val)?,
81            falses: Diff::try_from(iter.next().ok_or_else(fail_accum::<Self>)?)
82                .map_err(err_try_from_val)?,
83        })
84    }
85}
86
87impl TryFrom<Vec<Value>> for Bool {
88    type Error = EvalError;
89
90    fn try_from(state: Vec<Value>) -> Result<Self, Self::Error> {
91        ensure!(
92            state.len() == 2,
93            InternalSnafu {
94                reason: "Bool Accumulator state should have 2 values",
95            }
96        );
97        let mut iter = state.into_iter();
98
99        Self::try_from_iter(&mut iter)
100    }
101}
102
103impl Accumulator for Bool {
104    fn into_state(self) -> Vec<Value> {
105        vec![self.trues.into(), self.falses.into()]
106    }
107
108    /// Null values are ignored
109    fn update(
110        &mut self,
111        aggr_fn: &AggregateFunc,
112        value: Value,
113        diff: Diff,
114    ) -> Result<(), EvalError> {
115        ensure!(
116            matches!(
117                aggr_fn,
118                AggregateFunc::Any
119                    | AggregateFunc::All
120                    | AggregateFunc::MaxBool
121                    | AggregateFunc::MinBool
122            ),
123            InternalSnafu {
124                reason: format!(
125                    "Bool Accumulator does not support this aggregation function: {:?}",
126                    aggr_fn
127                ),
128            }
129        );
130
131        match value {
132            Value::Boolean(true) => self.trues += diff,
133            Value::Boolean(false) => self.falses += diff,
134            Value::Null => (), // ignore nulls
135            x => {
136                return Err(TypeMismatchSnafu {
137                    expected: ConcreteDataType::boolean_datatype(),
138                    actual: x.data_type(),
139                }
140                .build());
141            }
142        };
143        Ok(())
144    }
145
146    fn eval(&self, aggr_fn: &AggregateFunc) -> Result<Value, EvalError> {
147        match aggr_fn {
148            AggregateFunc::Any => Ok(Value::from(self.trues > 0)),
149            AggregateFunc::All => Ok(Value::from(self.falses == 0)),
150            AggregateFunc::MaxBool => Ok(Value::from(self.trues > 0)),
151            AggregateFunc::MinBool => Ok(Value::from(self.falses == 0)),
152            _ => Err(InternalSnafu {
153                reason: format!(
154                    "Bool Accumulator does not support this aggregation function: {:?}",
155                    aggr_fn
156                ),
157            }
158            .build()),
159        }
160    }
161}
162
/// Accumulates simple numeric values for sum over integers
/// (`SumInt16/32/64` and `SumUInt16/32/64`).
///
/// The running sum is kept in an `i128` and serialized as a `Decimal128` with precision 38
/// and scale 0; it is narrowed to `i64`/`u64` only when evaluated.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct SimpleNumber {
    /// The accumulation of all non-NULL values observed.
    accum: i128,
    /// The number of non-NULL values observed.
    non_nulls: Diff,
}
171
172impl SimpleNumber {
173    /// Expect one `Decimal128` and one `Diff` type values.
174    /// The `Decimal128` type is used to store the sum of all non-NULL values.
175    /// The `Diff` type is used to count the number of non-NULL values.
176    pub fn try_from_iter<I>(iter: &mut I) -> Result<Self, EvalError>
177    where
178        I: Iterator<Item = Value>,
179    {
180        Ok(Self {
181            accum: Decimal128::try_from(iter.next().ok_or_else(fail_accum::<Self>)?)
182                .map_err(err_try_from_val)?
183                .val(),
184            non_nulls: Diff::try_from(iter.next().ok_or_else(fail_accum::<Self>)?)
185                .map_err(err_try_from_val)?,
186        })
187    }
188}
189
190impl TryFrom<Vec<Value>> for SimpleNumber {
191    type Error = EvalError;
192
193    fn try_from(state: Vec<Value>) -> Result<Self, Self::Error> {
194        ensure!(
195            state.len() == 2,
196            InternalSnafu {
197                reason: "Number Accumulator state should have 2 values",
198            }
199        );
200        let mut iter = state.into_iter();
201        Self::try_from_iter(&mut iter)
202    }
203}
204
205impl Accumulator for SimpleNumber {
206    fn into_state(self) -> Vec<Value> {
207        vec![
208            Value::Decimal128(Decimal128::new(self.accum, 38, 0)),
209            self.non_nulls.into(),
210        ]
211    }
212
213    fn update(
214        &mut self,
215        aggr_fn: &AggregateFunc,
216        value: Value,
217        diff: Diff,
218    ) -> Result<(), EvalError> {
219        ensure!(
220            matches!(
221                aggr_fn,
222                AggregateFunc::SumInt16
223                    | AggregateFunc::SumInt32
224                    | AggregateFunc::SumInt64
225                    | AggregateFunc::SumUInt16
226                    | AggregateFunc::SumUInt32
227                    | AggregateFunc::SumUInt64
228            ),
229            InternalSnafu {
230                reason: format!(
231                    "SimpleNumber Accumulator does not support this aggregation function: {:?}",
232                    aggr_fn
233                ),
234            }
235        );
236
237        let v = match (aggr_fn, value) {
238            (AggregateFunc::SumInt16, Value::Int16(x)) => i128::from(x),
239            (AggregateFunc::SumInt32, Value::Int32(x)) => i128::from(x),
240            (AggregateFunc::SumInt64, Value::Int64(x)) => i128::from(x),
241            (AggregateFunc::SumUInt16, Value::UInt16(x)) => i128::from(x),
242            (AggregateFunc::SumUInt32, Value::UInt32(x)) => i128::from(x),
243            (AggregateFunc::SumUInt64, Value::UInt64(x)) => i128::from(x),
244            (_f, Value::Null) => return Ok(()), // ignore null
245            (f, v) => {
246                let expected_datatype = f.signature().input;
247                return Err(TypeMismatchSnafu {
248                    expected: expected_datatype[0].clone(),
249                    actual: v.data_type(),
250                }
251                .build())?;
252            }
253        };
254
255        self.accum += v * i128::from(diff);
256
257        self.non_nulls += diff;
258        Ok(())
259    }
260
261    fn eval(&self, aggr_fn: &AggregateFunc) -> Result<Value, EvalError> {
262        match aggr_fn {
263            AggregateFunc::SumInt16 | AggregateFunc::SumInt32 | AggregateFunc::SumInt64 => {
264                i64::try_from(self.accum)
265                    .map_err(|_e| OverflowSnafu {}.build())
266                    .map(Value::from)
267            }
268            AggregateFunc::SumUInt16 | AggregateFunc::SumUInt32 | AggregateFunc::SumUInt64 => {
269                u64::try_from(self.accum)
270                    .map_err(|_e| OverflowSnafu {}.build())
271                    .map(Value::from)
272            }
273            _ => Err(InternalSnafu {
274                reason: format!(
275                    "SimpleNumber Accumulator does not support this aggregation function: {:?}",
276                    aggr_fn
277                ),
278            }
279            .build()),
280        }
281    }
282}
/// Accumulates float values for sum over floating numbers.
///
/// Special values (NaN, +inf, -inf) are tracked as counts updated by `diff` rather than
/// being added to `accum`. The serialized state is
/// `[accum, pos_infs, neg_infs, nans, non_nulls]`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Float {
    /// Accumulates non-special float values, i.e. not NaN, +inf, -inf.
    /// When the accumulator is restored from its state with `non_nulls` equal to zero,
    /// `accum` is reset to zero so that summing over no values doesn't yield a non-zero result.
    accum: OrderedF64,
    /// Counts +inf
    pos_infs: Diff,
    /// Counts -inf
    neg_infs: Diff,
    /// Counts NaNs
    nans: Diff,
    /// Counts non-NULL values
    non_nulls: Diff,
}
298
299impl Float {
300    /// Expect first value to be `OrderedF64` and the rest four values to be `Diff` type values.
301    pub fn try_from_iter<I>(iter: &mut I) -> Result<Self, EvalError>
302    where
303        I: Iterator<Item = Value>,
304    {
305        let mut ret = Self {
306            accum: OrderedF64::try_from(iter.next().ok_or_else(fail_accum::<Self>)?)
307                .map_err(err_try_from_val)?,
308            pos_infs: Diff::try_from(iter.next().ok_or_else(fail_accum::<Self>)?)
309                .map_err(err_try_from_val)?,
310            neg_infs: Diff::try_from(iter.next().ok_or_else(fail_accum::<Self>)?)
311                .map_err(err_try_from_val)?,
312            nans: Diff::try_from(iter.next().ok_or_else(fail_accum::<Self>)?)
313                .map_err(err_try_from_val)?,
314            non_nulls: Diff::try_from(iter.next().ok_or_else(fail_accum::<Self>)?)
315                .map_err(err_try_from_val)?,
316        };
317
318        // This prevent counter-intuitive behavior of summing over no values having non-zero results
319        if ret.non_nulls == 0 {
320            ret.accum = OrderedFloat::from(0.0);
321        }
322
323        Ok(ret)
324    }
325}
326
327impl TryFrom<Vec<Value>> for Float {
328    type Error = EvalError;
329
330    fn try_from(state: Vec<Value>) -> Result<Self, Self::Error> {
331        ensure!(
332            state.len() == 5,
333            InternalSnafu {
334                reason: "Float Accumulator state should have 5 values",
335            }
336        );
337
338        let mut iter = state.into_iter();
339
340        let mut ret = Self {
341            accum: OrderedF64::try_from(iter.next().unwrap()).map_err(err_try_from_val)?,
342            pos_infs: Diff::try_from(iter.next().unwrap()).map_err(err_try_from_val)?,
343            neg_infs: Diff::try_from(iter.next().unwrap()).map_err(err_try_from_val)?,
344            nans: Diff::try_from(iter.next().unwrap()).map_err(err_try_from_val)?,
345            non_nulls: Diff::try_from(iter.next().unwrap()).map_err(err_try_from_val)?,
346        };
347
348        // This prevent counter-intuitive behavior of summing over no values
349        if ret.non_nulls == 0 {
350            ret.accum = OrderedFloat::from(0.0);
351        }
352
353        Ok(ret)
354    }
355}
356
357impl Accumulator for Float {
358    fn into_state(self) -> Vec<Value> {
359        vec![
360            self.accum.into(),
361            self.pos_infs.into(),
362            self.neg_infs.into(),
363            self.nans.into(),
364            self.non_nulls.into(),
365        ]
366    }
367
368    /// sum ignore null
369    fn update(
370        &mut self,
371        aggr_fn: &AggregateFunc,
372        value: Value,
373        diff: Diff,
374    ) -> Result<(), EvalError> {
375        ensure!(
376            matches!(
377                aggr_fn,
378                AggregateFunc::SumFloat32 | AggregateFunc::SumFloat64
379            ),
380            InternalSnafu {
381                reason: format!(
382                    "Float Accumulator does not support this aggregation function: {:?}",
383                    aggr_fn
384                ),
385            }
386        );
387
388        let x = match (aggr_fn, value) {
389            (AggregateFunc::SumFloat32, Value::Float32(x)) => OrderedF64::from(*x as f64),
390            (AggregateFunc::SumFloat64, Value::Float64(x)) => OrderedF64::from(x),
391            (_f, Value::Null) => return Ok(()), // ignore null
392            (f, v) => {
393                let expected_datatype = f.signature().input;
394                return Err(TypeMismatchSnafu {
395                    expected: expected_datatype[0].clone(),
396                    actual: v.data_type(),
397                }
398                .build())?;
399            }
400        };
401
402        if x.is_nan() {
403            self.nans += diff;
404        } else if x.is_infinite() {
405            if x.is_sign_positive() {
406                self.pos_infs += diff;
407            } else {
408                self.neg_infs += diff;
409            }
410        } else {
411            self.accum += *(x * OrderedF64::from(diff as f64));
412        }
413
414        self.non_nulls += diff;
415        Ok(())
416    }
417
418    fn eval(&self, aggr_fn: &AggregateFunc) -> Result<Value, EvalError> {
419        match aggr_fn {
420            AggregateFunc::SumFloat32 => Ok(Value::Float32(OrderedF32::from(self.accum.0 as f32))),
421            AggregateFunc::SumFloat64 => Ok(Value::Float64(self.accum)),
422            _ => Err(InternalSnafu {
423                reason: format!(
424                    "Float Accumulator does not support this aggregation function: {:?}",
425                    aggr_fn
426                ),
427            }
428            .build()),
429        }
430    }
431}
432
/// Accumulates a single `Ord`ed `Value`, useful for min/max aggregations.
///
/// Also used for `Count`, which only relies on `non_nulls`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct OrdValue {
    /// The current min/max; `None` until a non-NULL value is seen (serialized as
    /// `Value::Null`). Left untouched by `Count`.
    val: Option<Value>,
    /// Net number (sum of `diff`s) of non-NULL values observed.
    non_nulls: Diff,
}
439
440impl OrdValue {
441    pub fn try_from_iter<I>(iter: &mut I) -> Result<Self, EvalError>
442    where
443        I: Iterator<Item = Value>,
444    {
445        Ok(Self {
446            val: {
447                let v = iter.next().ok_or_else(fail_accum::<Self>)?;
448                if v == Value::Null {
449                    None
450                } else {
451                    Some(v)
452                }
453            },
454            non_nulls: Diff::try_from(iter.next().ok_or_else(fail_accum::<Self>)?)
455                .map_err(err_try_from_val)?,
456        })
457    }
458}
459
460impl TryFrom<Vec<Value>> for OrdValue {
461    type Error = EvalError;
462
463    fn try_from(state: Vec<Value>) -> Result<Self, Self::Error> {
464        ensure!(
465            state.len() == 2,
466            InternalSnafu {
467                reason: "OrdValue Accumulator state should have 2 values",
468            }
469        );
470
471        let mut iter = state.into_iter();
472
473        Ok(Self {
474            val: {
475                let v = iter.next().unwrap();
476                if v == Value::Null {
477                    None
478                } else {
479                    Some(v)
480                }
481            },
482            non_nulls: Diff::try_from(iter.next().unwrap()).map_err(err_try_from_val)?,
483        })
484    }
485}
486
487impl Accumulator for OrdValue {
488    fn into_state(self) -> Vec<Value> {
489        vec![self.val.unwrap_or(Value::Null), self.non_nulls.into()]
490    }
491
492    /// min/max try to find results in all non-null values, if all values are null, the result is null.
493    /// count(col_name) gives the number of non-null values, count(*) gives the number of rows including nulls.
494    /// TODO(discord9): add count(*) as a aggr function
495    fn update(
496        &mut self,
497        aggr_fn: &AggregateFunc,
498        value: Value,
499        diff: Diff,
500    ) -> Result<(), EvalError> {
501        ensure!(
502            aggr_fn.is_max() || aggr_fn.is_min() || matches!(aggr_fn, AggregateFunc::Count),
503            InternalSnafu {
504                reason: format!(
505                    "OrdValue Accumulator does not support this aggregation function: {:?}",
506                    aggr_fn
507                ),
508            }
509        );
510        if diff <= 0 && (aggr_fn.is_max() || aggr_fn.is_min()) {
511            return Err(InternalSnafu {
512                reason: "OrdValue Accumulator does not support non-monotonic input for min/max aggregation".to_string(),
513            }.build());
514        }
515
516        // if aggr_fn is count, the incoming value type doesn't matter in type checking
517        // otherwise, type need to be the same or value can be null
518        let check_type_aggr_fn_and_arg_value =
519            ty_eq_without_precision(value.data_type(), aggr_fn.signature().input[0].clone())
520                || matches!(aggr_fn, AggregateFunc::Count)
521                || value.is_null();
522        let check_type_aggr_fn_and_self_val = self
523            .val
524            .as_ref()
525            .map(|zelf| {
526                ty_eq_without_precision(zelf.data_type(), aggr_fn.signature().input[0].clone())
527            })
528            .unwrap_or(true)
529            || matches!(aggr_fn, AggregateFunc::Count);
530
531        if !check_type_aggr_fn_and_arg_value {
532            return Err(TypeMismatchSnafu {
533                expected: aggr_fn.signature().input[0].clone(),
534                actual: value.data_type(),
535            }
536            .build());
537        } else if !check_type_aggr_fn_and_self_val {
538            return Err(TypeMismatchSnafu {
539                expected: aggr_fn.signature().input[0].clone(),
540                actual: self
541                    .val
542                    .as_ref()
543                    .map(|v| v.data_type())
544                    .unwrap_or(ConcreteDataType::null_datatype()),
545            }
546            .build());
547        }
548
549        let is_null = value.is_null();
550        if is_null {
551            return Ok(());
552        }
553
554        if !is_null {
555            // compile count(*) to count(true) to include null/non-nulls
556            // And the counts of non-null values are updated here
557            self.non_nulls += diff;
558
559            match aggr_fn.signature().generic_fn {
560                GenericFn::Max => {
561                    self.val = self
562                        .val
563                        .clone()
564                        .map(|v| v.max(value.clone()))
565                        .or_else(|| Some(value))
566                }
567                GenericFn::Min => {
568                    self.val = self
569                        .val
570                        .clone()
571                        .map(|v| v.min(value.clone()))
572                        .or_else(|| Some(value))
573                }
574
575                GenericFn::Count => (),
576                _ => unreachable!("already checked by ensure!"),
577            }
578        };
579        // min/max ignore nulls
580
581        Ok(())
582    }
583
584    fn eval(&self, aggr_fn: &AggregateFunc) -> Result<Value, EvalError> {
585        if aggr_fn.is_max() || aggr_fn.is_min() {
586            Ok(self.val.clone().unwrap_or(Value::Null))
587        } else if matches!(aggr_fn, AggregateFunc::Count) {
588            Ok(self.non_nulls.into())
589        } else {
590            Err(InternalSnafu {
591                reason: format!(
592                    "OrdValue Accumulator does not support this aggregation function: {:?}",
593                    aggr_fn
594                ),
595            }
596            .build())
597        }
598    }
599}
600
/// Accumulates values for the various types of accumulable aggregations.
///
/// Each variant wraps one concrete accumulator; the `Accumulator` trait methods are
/// forwarded to the wrapped value through `enum_dispatch`.
///
/// Integer sums (`SimpleNumber`) are accumulated in an `i128` and only narrowed to
/// `i64`/`u64` when evaluated; a result that does not fit is reported as an overflow error.
///
/// Float sums (`Float`) keep NaN, +inf and -inf counts separately from the sum of the
/// finite values.
///
/// TODO(discord9): check for overflowing
#[enum_dispatch(Accumulator)]
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum Accum {
    /// Accumulates boolean values.
    Bool(Bool),
    /// Accumulates simple numeric values.
    SimpleNumber(SimpleNumber),
    /// Accumulates float values.
    Float(Float),
    /// Accumulate Values that impl `Ord`
    OrdValue(OrdValue),
}
622
623impl Accum {
624    /// create a new accumulator from given aggregate function
625    pub fn new_accum(aggr_fn: &AggregateFunc) -> Result<Self, EvalError> {
626        Ok(match aggr_fn {
627            AggregateFunc::Any
628            | AggregateFunc::All
629            | AggregateFunc::MaxBool
630            | AggregateFunc::MinBool => Self::from(Bool {
631                trues: 0,
632                falses: 0,
633            }),
634            AggregateFunc::SumInt16
635            | AggregateFunc::SumInt32
636            | AggregateFunc::SumInt64
637            | AggregateFunc::SumUInt16
638            | AggregateFunc::SumUInt32
639            | AggregateFunc::SumUInt64 => Self::from(SimpleNumber {
640                accum: 0,
641                non_nulls: 0,
642            }),
643            AggregateFunc::SumFloat32 | AggregateFunc::SumFloat64 => Self::from(Float {
644                accum: OrderedF64::from(0.0),
645                pos_infs: 0,
646                neg_infs: 0,
647                nans: 0,
648                non_nulls: 0,
649            }),
650            f if f.is_max() || f.is_min() || matches!(f, AggregateFunc::Count) => {
651                Self::from(OrdValue {
652                    val: None,
653                    non_nulls: 0,
654                })
655            }
656            f => {
657                return Err(InternalSnafu {
658                    reason: format!(
659                        "Accumulator does not support this aggregation function: {:?}",
660                        f
661                    ),
662                }
663                .build());
664            }
665        })
666    }
667
668    pub fn try_from_iter(
669        aggr_fn: &AggregateFunc,
670        iter: &mut impl Iterator<Item = Value>,
671    ) -> Result<Self, EvalError> {
672        match aggr_fn {
673            AggregateFunc::Any
674            | AggregateFunc::All
675            | AggregateFunc::MaxBool
676            | AggregateFunc::MinBool => Ok(Self::from(Bool::try_from_iter(iter)?)),
677            AggregateFunc::SumInt16
678            | AggregateFunc::SumInt32
679            | AggregateFunc::SumInt64
680            | AggregateFunc::SumUInt16
681            | AggregateFunc::SumUInt32
682            | AggregateFunc::SumUInt64 => Ok(Self::from(SimpleNumber::try_from_iter(iter)?)),
683            AggregateFunc::SumFloat32 | AggregateFunc::SumFloat64 => {
684                Ok(Self::from(Float::try_from_iter(iter)?))
685            }
686            f if f.is_max() || f.is_min() || matches!(f, AggregateFunc::Count) => {
687                Ok(Self::from(OrdValue::try_from_iter(iter)?))
688            }
689            f => Err(InternalSnafu {
690                reason: format!(
691                    "Accumulator does not support this aggregation function: {:?}",
692                    f
693                ),
694            }
695            .build()),
696        }
697    }
698
699    /// try to convert a vector of value into given aggregate function's accumulator
700    pub fn try_into_accum(aggr_fn: &AggregateFunc, state: Vec<Value>) -> Result<Self, EvalError> {
701        match aggr_fn {
702            AggregateFunc::Any
703            | AggregateFunc::All
704            | AggregateFunc::MaxBool
705            | AggregateFunc::MinBool => Ok(Self::from(Bool::try_from(state)?)),
706            AggregateFunc::SumInt16
707            | AggregateFunc::SumInt32
708            | AggregateFunc::SumInt64
709            | AggregateFunc::SumUInt16
710            | AggregateFunc::SumUInt32
711            | AggregateFunc::SumUInt64 => Ok(Self::from(SimpleNumber::try_from(state)?)),
712            AggregateFunc::SumFloat32 | AggregateFunc::SumFloat64 => {
713                Ok(Self::from(Float::try_from(state)?))
714            }
715            f if f.is_max() || f.is_min() || matches!(f, AggregateFunc::Count) => {
716                Ok(Self::from(OrdValue::try_from(state)?))
717            }
718            f => Err(InternalSnafu {
719                reason: format!(
720                    "Accumulator does not support this aggregation function: {:?}",
721                    f
722                ),
723            }
724            .build()),
725        }
726    }
727}
728
729fn fail_accum<T>() -> EvalError {
730    InternalSnafu {
731        reason: format!(
732            "list of values exhausted before a accum of type {} can be build from it",
733            type_name::<T>()
734        ),
735    }
736    .build()
737}
738
739fn err_try_from_val<T: Display>(reason: T) -> EvalError {
740    TryFromValueSnafu {
741        msg: reason.to_string(),
742    }
743    .build()
744}
745
746/// compare type while ignore their precision, including `TimeStamp`, `Time`,
747/// `Duration`, `Interval`
748fn ty_eq_without_precision(left: ConcreteDataType, right: ConcreteDataType) -> bool {
749    left == right
750        || matches!(left, ConcreteDataType::Timestamp(..))
751            && matches!(right, ConcreteDataType::Timestamp(..))
752        || matches!(left, ConcreteDataType::Time(..)) && matches!(right, ConcreteDataType::Time(..))
753        || matches!(left, ConcreteDataType::Duration(..))
754            && matches!(right, ConcreteDataType::Duration(..))
755        || matches!(left, ConcreteDataType::Interval(..))
756            && matches!(right, ConcreteDataType::Interval(..))
757}
758
759#[allow(clippy::too_many_lines)]
760#[cfg(test)]
761mod test {
762    use common_time::Timestamp;
763
764    use super::*;
765
766    #[test]
767    fn test_accum() {
768        let testcases = vec![
769            (
770                AggregateFunc::SumInt32,
771                vec![(Value::Int32(1), 1), (Value::Null, 1)],
772                (
773                    Value::Int64(1),
774                    vec![Value::Decimal128(Decimal128::new(1, 38, 0)), 1i64.into()],
775                ),
776            ),
777            (
778                AggregateFunc::SumFloat32,
779                vec![(Value::Float32(OrderedF32::from(1.0)), 1), (Value::Null, 1)],
780                (
781                    Value::Float32(OrderedF32::from(1.0)),
782                    vec![
783                        Value::Float64(OrderedF64::from(1.0)),
784                        0i64.into(),
785                        0i64.into(),
786                        0i64.into(),
787                        1i64.into(),
788                    ],
789                ),
790            ),
791            (
792                AggregateFunc::MaxInt32,
793                vec![(Value::Int32(1), 1), (Value::Int32(2), 1), (Value::Null, 1)],
794                (Value::Int32(2), vec![Value::Int32(2), 2i64.into()]),
795            ),
796            (
797                AggregateFunc::MinInt32,
798                vec![(Value::Int32(2), 1), (Value::Int32(1), 1), (Value::Null, 1)],
799                (Value::Int32(1), vec![Value::Int32(1), 2i64.into()]),
800            ),
801            (
802                AggregateFunc::MaxFloat32,
803                vec![
804                    (Value::Float32(OrderedF32::from(1.0)), 1),
805                    (Value::Float32(OrderedF32::from(2.0)), 1),
806                    (Value::Null, 1),
807                ],
808                (
809                    Value::Float32(OrderedF32::from(2.0)),
810                    vec![Value::Float32(OrderedF32::from(2.0)), 2i64.into()],
811                ),
812            ),
813            (
814                AggregateFunc::MaxDateTime,
815                vec![
816                    (Value::Timestamp(Timestamp::from(0)), 1),
817                    (Value::Timestamp(Timestamp::from(1)), 1),
818                    (Value::Null, 1),
819                ],
820                (
821                    Value::Timestamp(Timestamp::from(1)),
822                    vec![Value::Timestamp(Timestamp::from(1)), 2i64.into()],
823                ),
824            ),
825            (
826                AggregateFunc::Count,
827                vec![
828                    (Value::Int32(1), 1),
829                    (Value::Int32(2), 1),
830                    (Value::Null, 1),
831                    (Value::Null, 1),
832                ],
833                (2i64.into(), vec![Value::Null, 2i64.into()]),
834            ),
835            (
836                AggregateFunc::Any,
837                vec![
838                    (Value::Boolean(false), 1),
839                    (Value::Boolean(false), 1),
840                    (Value::Boolean(true), 1),
841                    (Value::Null, 1),
842                ],
843                (
844                    Value::Boolean(true),
845                    vec![Value::from(1i64), Value::from(2i64)],
846                ),
847            ),
848            (
849                AggregateFunc::All,
850                vec![
851                    (Value::Boolean(false), 1),
852                    (Value::Boolean(false), 1),
853                    (Value::Boolean(true), 1),
854                    (Value::Null, 1),
855                ],
856                (
857                    Value::Boolean(false),
858                    vec![Value::from(1i64), Value::from(2i64)],
859                ),
860            ),
861            (
862                AggregateFunc::MaxBool,
863                vec![
864                    (Value::Boolean(false), 1),
865                    (Value::Boolean(false), 1),
866                    (Value::Boolean(true), 1),
867                    (Value::Null, 1),
868                ],
869                (
870                    Value::Boolean(true),
871                    vec![Value::from(1i64), Value::from(2i64)],
872                ),
873            ),
874            (
875                AggregateFunc::MinBool,
876                vec![
877                    (Value::Boolean(false), 1),
878                    (Value::Boolean(false), 1),
879                    (Value::Boolean(true), 1),
880                    (Value::Null, 1),
881                ],
882                (
883                    Value::Boolean(false),
884                    vec![Value::from(1i64), Value::from(2i64)],
885                ),
886            ),
887        ];
888
889        for (aggr_fn, input, (eval_res, state)) in testcases {
890            let create_and_insert = || -> Result<Accum, EvalError> {
891                let mut acc = Accum::new_accum(&aggr_fn)?;
892                acc.update_batch(&aggr_fn, input.clone())?;
893                let row = acc.into_state();
894                let acc = Accum::try_into_accum(&aggr_fn, row.clone())?;
895                let alter_acc = Accum::try_from_iter(&aggr_fn, &mut row.into_iter())?;
896                assert_eq!(acc, alter_acc);
897                Ok(acc)
898            };
899            let acc = match create_and_insert() {
900                Ok(acc) => acc,
901                Err(err) => panic!(
902                    "Failed to create accum for {:?} with input {:?} with error: {:?}",
903                    aggr_fn, input, err
904                ),
905            };
906
907            if acc.eval(&aggr_fn).unwrap() != eval_res {
908                panic!(
909                    "Failed to eval accum for {:?} with input {:?}, expect {:?}, got {:?}",
910                    aggr_fn,
911                    input,
912                    eval_res,
913                    acc.eval(&aggr_fn).unwrap()
914                );
915            }
916            let actual_state = acc.into_state();
917            if actual_state != state {
918                panic!(
919                    "Failed to cast into state from accum for {:?} with input {:?}, expect state {:?}, got state {:?}",
920                    aggr_fn,
921                    input,
922                    state,
923                    actual_state
924                );
925            }
926        }
927    }
928    #[test]
929    fn test_fail_path_accum() {
930        {
931            let bool_accum = Bool::try_from(vec![Value::Null]);
932            assert!(matches!(bool_accum, Err(EvalError::Internal { .. })));
933        }
934
935        {
936            let mut bool_accum = Bool::try_from(vec![1i64.into(), 1i64.into()]).unwrap();
937            // serde
938            let bool_accum_serde = serde_json::to_string(&bool_accum).unwrap();
939            let bool_accum_de = serde_json::from_str::<Bool>(&bool_accum_serde).unwrap();
940            assert_eq!(bool_accum, bool_accum_de);
941            assert!(matches!(
942                bool_accum.update(&AggregateFunc::MaxDate, 1.into(), 1),
943                Err(EvalError::Internal { .. })
944            ));
945            assert!(matches!(
946                bool_accum.update(&AggregateFunc::Any, 1.into(), 1),
947                Err(EvalError::TypeMismatch { .. })
948            ));
949            assert!(matches!(
950                bool_accum.eval(&AggregateFunc::MaxDate),
951                Err(EvalError::Internal { .. })
952            ));
953        }
954
955        {
956            let ret = SimpleNumber::try_from(vec![Value::Null]);
957            assert!(matches!(ret, Err(EvalError::Internal { .. })));
958            let mut accum =
959                SimpleNumber::try_from(vec![Decimal128::new(0, 38, 0).into(), 0i64.into()])
960                    .unwrap();
961
962            assert!(matches!(
963                accum.update(&AggregateFunc::All, 0.into(), 1),
964                Err(EvalError::Internal { .. })
965            ));
966            assert!(matches!(
967                accum.update(&AggregateFunc::SumInt64, 0i32.into(), 1),
968                Err(EvalError::TypeMismatch { .. })
969            ));
970            assert!(matches!(
971                accum.eval(&AggregateFunc::All),
972                Err(EvalError::Internal { .. })
973            ));
974            accum
975                .update(&AggregateFunc::SumInt64, 1i64.into(), 1)
976                .unwrap();
977            accum
978                .update(&AggregateFunc::SumInt64, i64::MAX.into(), 1)
979                .unwrap();
980            assert!(matches!(
981                accum.eval(&AggregateFunc::SumInt64),
982                Err(EvalError::Overflow { .. })
983            ));
984        }
985
986        {
987            let ret = Float::try_from(vec![2f64.into(), 0i64.into(), 0i64.into(), 0i64.into()]);
988            assert!(matches!(ret, Err(EvalError::Internal { .. })));
989            let mut accum = Float::try_from(vec![
990                2f64.into(),
991                0i64.into(),
992                0i64.into(),
993                0i64.into(),
994                1i64.into(),
995            ])
996            .unwrap();
997            accum
998                .update(&AggregateFunc::SumFloat64, 2f64.into(), -1)
999                .unwrap();
1000            assert!(matches!(
1001                accum.update(&AggregateFunc::All, 0.into(), 1),
1002                Err(EvalError::Internal { .. })
1003            ));
1004            assert!(matches!(
1005                accum.update(&AggregateFunc::SumFloat64, 0.0f32.into(), 1),
1006                Err(EvalError::TypeMismatch { .. })
1007            ));
1008            // no record, no accum
1009            assert_eq!(
1010                accum.eval(&AggregateFunc::SumFloat64).unwrap(),
1011                0.0f64.into()
1012            );
1013
1014            assert!(matches!(
1015                accum.eval(&AggregateFunc::All),
1016                Err(EvalError::Internal { .. })
1017            ));
1018
1019            accum
1020                .update(&AggregateFunc::SumFloat64, f64::INFINITY.into(), 1)
1021                .unwrap();
1022            accum
1023                .update(&AggregateFunc::SumFloat64, (-f64::INFINITY).into(), 1)
1024                .unwrap();
1025            accum
1026                .update(&AggregateFunc::SumFloat64, f64::NAN.into(), 1)
1027                .unwrap();
1028        }
1029
1030        {
1031            let ret = OrdValue::try_from(vec![Value::Null]);
1032            assert!(matches!(ret, Err(EvalError::Internal { .. })));
1033            let mut accum = OrdValue::try_from(vec![Value::Null, 0i64.into()]).unwrap();
1034            assert!(matches!(
1035                accum.update(&AggregateFunc::All, 0.into(), 1),
1036                Err(EvalError::Internal { .. })
1037            ));
1038            accum
1039                .update(&AggregateFunc::MaxInt16, 1i16.into(), 1)
1040                .unwrap();
1041            assert!(matches!(
1042                accum.update(&AggregateFunc::MaxInt16, 0i32.into(), 1),
1043                Err(EvalError::TypeMismatch { .. })
1044            ));
1045            assert!(matches!(
1046                accum.update(&AggregateFunc::MaxInt16, 0i16.into(), -1),
1047                Err(EvalError::Internal { .. })
1048            ));
1049            accum
1050                .update(&AggregateFunc::MaxInt16, Value::Null, 1)
1051                .unwrap();
1052        }
1053
1054        // insert uint64 into max_int64 should fail
1055        {
1056            let mut accum = OrdValue::try_from(vec![Value::Null, 0i64.into()]).unwrap();
1057            assert!(matches!(
1058                accum.update(&AggregateFunc::MaxInt64, 0u64.into(), 1),
1059                Err(EvalError::TypeMismatch { .. })
1060            ));
1061        }
1062    }
1063}