flow/expr/relation/
accum.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
//! Accumulators for aggregate functions that are accumulable, e.g. sum/count.
//!
//! An accumulator is only restored from a row and updated each time the dataflow needs to process a new batch of rows,
//! so the overhead is acceptable.
//!
//! Currently supports sum, count, any, all and min/max (with one caveat: min/max can't support deletes in an aggregate).
//! TODO: think of better ways to avoid ser/de every time an accumulator needs to be updated, since it's in a tight loop
22
23use std::any::type_name;
24use std::fmt::Display;
25
26use common_decimal::Decimal128;
27use datatypes::data_type::ConcreteDataType;
28use datatypes::value::{OrderedF32, OrderedF64, OrderedFloat, Value};
29use enum_dispatch::enum_dispatch;
30use serde::{Deserialize, Serialize};
31use snafu::ensure;
32
33use crate::expr::error::{InternalSnafu, OverflowSnafu, TryFromValueSnafu, TypeMismatchSnafu};
34use crate::expr::signature::GenericFn;
35use crate::expr::{AggregateFunc, EvalError};
36use crate::repr::Diff;
37
38/// Accumulates values for the various types of accumulable aggregations.
39#[enum_dispatch]
40pub trait Accumulator: Sized {
41    fn into_state(self) -> Vec<Value>;
42
43    fn update(
44        &mut self,
45        aggr_fn: &AggregateFunc,
46        value: Value,
47        diff: Diff,
48    ) -> Result<(), EvalError>;
49
50    fn update_batch<I>(&mut self, aggr_fn: &AggregateFunc, value_diffs: I) -> Result<(), EvalError>
51    where
52        I: IntoIterator<Item = (Value, Diff)>,
53    {
54        for (v, d) in value_diffs {
55            self.update(aggr_fn, v, d)?;
56        }
57        Ok(())
58    }
59
60    fn eval(&self, aggr_fn: &AggregateFunc) -> Result<Value, EvalError>;
61}
62
/// Bool accumulator, used for `Any` `All` `Max/MinBool`
///
/// Keeps two separate counters, one per boolean value; every update adds the
/// incoming `diff` to the counter that matches the value.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Bool {
    /// The number of `true` values observed.
    trues: Diff,
    /// The number of `false` values observed.
    falses: Diff,
}
71
72impl Bool {
73    /// Expect two `Diff` type values, one for `true` and one for `false`.
74    pub fn try_from_iter<I>(iter: &mut I) -> Result<Self, EvalError>
75    where
76        I: Iterator<Item = Value>,
77    {
78        Ok(Self {
79            trues: Diff::try_from(iter.next().ok_or_else(fail_accum::<Self>)?)
80                .map_err(err_try_from_val)?,
81            falses: Diff::try_from(iter.next().ok_or_else(fail_accum::<Self>)?)
82                .map_err(err_try_from_val)?,
83        })
84    }
85}
86
87impl TryFrom<Vec<Value>> for Bool {
88    type Error = EvalError;
89
90    fn try_from(state: Vec<Value>) -> Result<Self, Self::Error> {
91        ensure!(
92            state.len() == 2,
93            InternalSnafu {
94                reason: "Bool Accumulator state should have 2 values",
95            }
96        );
97        let mut iter = state.into_iter();
98
99        Self::try_from_iter(&mut iter)
100    }
101}
102
103impl Accumulator for Bool {
104    fn into_state(self) -> Vec<Value> {
105        vec![self.trues.into(), self.falses.into()]
106    }
107
108    /// Null values are ignored
109    fn update(
110        &mut self,
111        aggr_fn: &AggregateFunc,
112        value: Value,
113        diff: Diff,
114    ) -> Result<(), EvalError> {
115        ensure!(
116            matches!(
117                aggr_fn,
118                AggregateFunc::Any
119                    | AggregateFunc::All
120                    | AggregateFunc::MaxBool
121                    | AggregateFunc::MinBool
122            ),
123            InternalSnafu {
124                reason: format!(
125                    "Bool Accumulator does not support this aggregation function: {:?}",
126                    aggr_fn
127                ),
128            }
129        );
130
131        match value {
132            Value::Boolean(true) => self.trues += diff,
133            Value::Boolean(false) => self.falses += diff,
134            Value::Null => (), // ignore nulls
135            x => {
136                return Err(TypeMismatchSnafu {
137                    expected: ConcreteDataType::boolean_datatype(),
138                    actual: x.data_type(),
139                }
140                .build());
141            }
142        };
143        Ok(())
144    }
145
146    fn eval(&self, aggr_fn: &AggregateFunc) -> Result<Value, EvalError> {
147        match aggr_fn {
148            AggregateFunc::Any => Ok(Value::from(self.trues > 0)),
149            AggregateFunc::All => Ok(Value::from(self.falses == 0)),
150            AggregateFunc::MaxBool => Ok(Value::from(self.trues > 0)),
151            AggregateFunc::MinBool => Ok(Value::from(self.falses == 0)),
152            _ => Err(InternalSnafu {
153                reason: format!(
154                    "Bool Accumulator does not support this aggregation function: {:?}",
155                    aggr_fn
156                ),
157            }
158            .build()),
159        }
160    }
161}
162
/// Accumulates simple numeric values for sum over integer.
///
/// The running sum is kept in an `i128`, which is wider than any of the
/// integer inputs (`Int16`..`Int64`, `UInt16`..`UInt64`) accepted by `update`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct SimpleNumber {
    /// The accumulation of all non-NULL values observed.
    accum: i128,
    /// The number of non-NULL values observed.
    non_nulls: Diff,
}
171
172impl SimpleNumber {
173    /// Expect one `Decimal128` and one `Diff` type values.
174    /// The `Decimal128` type is used to store the sum of all non-NULL values.
175    /// The `Diff` type is used to count the number of non-NULL values.
176    pub fn try_from_iter<I>(iter: &mut I) -> Result<Self, EvalError>
177    where
178        I: Iterator<Item = Value>,
179    {
180        Ok(Self {
181            accum: Decimal128::try_from(iter.next().ok_or_else(fail_accum::<Self>)?)
182                .map_err(err_try_from_val)?
183                .val(),
184            non_nulls: Diff::try_from(iter.next().ok_or_else(fail_accum::<Self>)?)
185                .map_err(err_try_from_val)?,
186        })
187    }
188}
189
190impl TryFrom<Vec<Value>> for SimpleNumber {
191    type Error = EvalError;
192
193    fn try_from(state: Vec<Value>) -> Result<Self, Self::Error> {
194        ensure!(
195            state.len() == 2,
196            InternalSnafu {
197                reason: "Number Accumulator state should have 2 values",
198            }
199        );
200        let mut iter = state.into_iter();
201        Self::try_from_iter(&mut iter)
202    }
203}
204
205impl Accumulator for SimpleNumber {
206    fn into_state(self) -> Vec<Value> {
207        vec![
208            Value::Decimal128(Decimal128::new(self.accum, 38, 0)),
209            self.non_nulls.into(),
210        ]
211    }
212
213    fn update(
214        &mut self,
215        aggr_fn: &AggregateFunc,
216        value: Value,
217        diff: Diff,
218    ) -> Result<(), EvalError> {
219        ensure!(
220            matches!(
221                aggr_fn,
222                AggregateFunc::SumInt16
223                    | AggregateFunc::SumInt32
224                    | AggregateFunc::SumInt64
225                    | AggregateFunc::SumUInt16
226                    | AggregateFunc::SumUInt32
227                    | AggregateFunc::SumUInt64
228            ),
229            InternalSnafu {
230                reason: format!(
231                    "SimpleNumber Accumulator does not support this aggregation function: {:?}",
232                    aggr_fn
233                ),
234            }
235        );
236
237        let v = match (aggr_fn, value) {
238            (AggregateFunc::SumInt16, Value::Int16(x)) => i128::from(x),
239            (AggregateFunc::SumInt32, Value::Int32(x)) => i128::from(x),
240            (AggregateFunc::SumInt64, Value::Int64(x)) => i128::from(x),
241            (AggregateFunc::SumUInt16, Value::UInt16(x)) => i128::from(x),
242            (AggregateFunc::SumUInt32, Value::UInt32(x)) => i128::from(x),
243            (AggregateFunc::SumUInt64, Value::UInt64(x)) => i128::from(x),
244            (_f, Value::Null) => return Ok(()), // ignore null
245            (f, v) => {
246                let expected_datatype = f.signature().input;
247                return Err(TypeMismatchSnafu {
248                    expected: expected_datatype[0].clone(),
249                    actual: v.data_type(),
250                }
251                .build())?;
252            }
253        };
254
255        self.accum += v * i128::from(diff);
256
257        self.non_nulls += diff;
258        Ok(())
259    }
260
261    fn eval(&self, aggr_fn: &AggregateFunc) -> Result<Value, EvalError> {
262        match aggr_fn {
263            AggregateFunc::SumInt16 | AggregateFunc::SumInt32 | AggregateFunc::SumInt64 => {
264                i64::try_from(self.accum)
265                    .map_err(|_e| OverflowSnafu {}.build())
266                    .map(Value::from)
267            }
268            AggregateFunc::SumUInt16 | AggregateFunc::SumUInt32 | AggregateFunc::SumUInt64 => {
269                u64::try_from(self.accum)
270                    .map_err(|_e| OverflowSnafu {}.build())
271                    .map(Value::from)
272            }
273            _ => Err(InternalSnafu {
274                reason: format!(
275                    "SimpleNumber Accumulator does not support this aggregation function: {:?}",
276                    aggr_fn
277                ),
278            }
279            .build()),
280        }
281    }
282}
/// Accumulates float values for sum over floating numbers.
///
/// Special values (NaN, +inf, -inf) are not added into `accum`; they are
/// counted separately in their own fields instead.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Float {
    /// Accumulates non-special float values, i.e. not NaN, +inf, -inf.
    /// accum will be set to zero if `non_nulls` is zero.
    accum: OrderedF64,
    /// Counts +inf
    pos_infs: Diff,
    /// Counts -inf
    neg_infs: Diff,
    /// Counts NaNs
    nans: Diff,
    /// Counts non-NULL values
    non_nulls: Diff,
}
298
299impl Float {
300    /// Expect first value to be `OrderedF64` and the rest four values to be `Diff` type values.
301    pub fn try_from_iter<I>(iter: &mut I) -> Result<Self, EvalError>
302    where
303        I: Iterator<Item = Value>,
304    {
305        let mut ret = Self {
306            accum: OrderedF64::try_from(iter.next().ok_or_else(fail_accum::<Self>)?)
307                .map_err(err_try_from_val)?,
308            pos_infs: Diff::try_from(iter.next().ok_or_else(fail_accum::<Self>)?)
309                .map_err(err_try_from_val)?,
310            neg_infs: Diff::try_from(iter.next().ok_or_else(fail_accum::<Self>)?)
311                .map_err(err_try_from_val)?,
312            nans: Diff::try_from(iter.next().ok_or_else(fail_accum::<Self>)?)
313                .map_err(err_try_from_val)?,
314            non_nulls: Diff::try_from(iter.next().ok_or_else(fail_accum::<Self>)?)
315                .map_err(err_try_from_val)?,
316        };
317
318        // This prevent counter-intuitive behavior of summing over no values having non-zero results
319        if ret.non_nulls == 0 {
320            ret.accum = OrderedFloat::from(0.0);
321        }
322
323        Ok(ret)
324    }
325}
326
327impl TryFrom<Vec<Value>> for Float {
328    type Error = EvalError;
329
330    fn try_from(state: Vec<Value>) -> Result<Self, Self::Error> {
331        ensure!(
332            state.len() == 5,
333            InternalSnafu {
334                reason: "Float Accumulator state should have 5 values",
335            }
336        );
337
338        let mut iter = state.into_iter();
339
340        let mut ret = Self {
341            accum: OrderedF64::try_from(iter.next().unwrap()).map_err(err_try_from_val)?,
342            pos_infs: Diff::try_from(iter.next().unwrap()).map_err(err_try_from_val)?,
343            neg_infs: Diff::try_from(iter.next().unwrap()).map_err(err_try_from_val)?,
344            nans: Diff::try_from(iter.next().unwrap()).map_err(err_try_from_val)?,
345            non_nulls: Diff::try_from(iter.next().unwrap()).map_err(err_try_from_val)?,
346        };
347
348        // This prevent counter-intuitive behavior of summing over no values
349        if ret.non_nulls == 0 {
350            ret.accum = OrderedFloat::from(0.0);
351        }
352
353        Ok(ret)
354    }
355}
356
357impl Accumulator for Float {
358    fn into_state(self) -> Vec<Value> {
359        vec![
360            self.accum.into(),
361            self.pos_infs.into(),
362            self.neg_infs.into(),
363            self.nans.into(),
364            self.non_nulls.into(),
365        ]
366    }
367
368    /// sum ignore null
369    fn update(
370        &mut self,
371        aggr_fn: &AggregateFunc,
372        value: Value,
373        diff: Diff,
374    ) -> Result<(), EvalError> {
375        ensure!(
376            matches!(
377                aggr_fn,
378                AggregateFunc::SumFloat32 | AggregateFunc::SumFloat64
379            ),
380            InternalSnafu {
381                reason: format!(
382                    "Float Accumulator does not support this aggregation function: {:?}",
383                    aggr_fn
384                ),
385            }
386        );
387
388        let x = match (aggr_fn, value) {
389            (AggregateFunc::SumFloat32, Value::Float32(x)) => OrderedF64::from(*x as f64),
390            (AggregateFunc::SumFloat64, Value::Float64(x)) => OrderedF64::from(x),
391            (_f, Value::Null) => return Ok(()), // ignore null
392            (f, v) => {
393                let expected_datatype = f.signature().input;
394                return Err(TypeMismatchSnafu {
395                    expected: expected_datatype[0].clone(),
396                    actual: v.data_type(),
397                }
398                .build())?;
399            }
400        };
401
402        if x.is_nan() {
403            self.nans += diff;
404        } else if x.is_infinite() {
405            if x.is_sign_positive() {
406                self.pos_infs += diff;
407            } else {
408                self.neg_infs += diff;
409            }
410        } else {
411            self.accum += *(x * OrderedF64::from(diff as f64));
412        }
413
414        self.non_nulls += diff;
415        Ok(())
416    }
417
418    fn eval(&self, aggr_fn: &AggregateFunc) -> Result<Value, EvalError> {
419        match aggr_fn {
420            AggregateFunc::SumFloat32 => Ok(Value::Float32(OrderedF32::from(self.accum.0 as f32))),
421            AggregateFunc::SumFloat64 => Ok(Value::Float64(self.accum)),
422            _ => Err(InternalSnafu {
423                reason: format!(
424                    "Float Accumulator does not support this aggregation function: {:?}",
425                    aggr_fn
426                ),
427            }
428            .build()),
429        }
430    }
431}
432
/// Accumulates a single `Ord`ed `Value`, useful for min/max aggregations.
///
/// Also used for `Count`, in which case only `non_nulls` is meaningful.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct OrdValue {
    /// The current min/max value; `None` is stored in the state as `Value::Null`.
    val: Option<Value>,
    /// The number of non-NULL values observed.
    non_nulls: Diff,
}
439
440impl OrdValue {
441    pub fn try_from_iter<I>(iter: &mut I) -> Result<Self, EvalError>
442    where
443        I: Iterator<Item = Value>,
444    {
445        Ok(Self {
446            val: {
447                let v = iter.next().ok_or_else(fail_accum::<Self>)?;
448                if v == Value::Null { None } else { Some(v) }
449            },
450            non_nulls: Diff::try_from(iter.next().ok_or_else(fail_accum::<Self>)?)
451                .map_err(err_try_from_val)?,
452        })
453    }
454}
455
456impl TryFrom<Vec<Value>> for OrdValue {
457    type Error = EvalError;
458
459    fn try_from(state: Vec<Value>) -> Result<Self, Self::Error> {
460        ensure!(
461            state.len() == 2,
462            InternalSnafu {
463                reason: "OrdValue Accumulator state should have 2 values",
464            }
465        );
466
467        let mut iter = state.into_iter();
468
469        Ok(Self {
470            val: {
471                let v = iter.next().unwrap();
472                if v == Value::Null { None } else { Some(v) }
473            },
474            non_nulls: Diff::try_from(iter.next().unwrap()).map_err(err_try_from_val)?,
475        })
476    }
477}
478
479impl Accumulator for OrdValue {
480    fn into_state(self) -> Vec<Value> {
481        vec![self.val.unwrap_or(Value::Null), self.non_nulls.into()]
482    }
483
484    /// min/max try to find results in all non-null values, if all values are null, the result is null.
485    /// count(col_name) gives the number of non-null values, count(*) gives the number of rows including nulls.
486    /// TODO(discord9): add count(*) as a aggr function
487    fn update(
488        &mut self,
489        aggr_fn: &AggregateFunc,
490        value: Value,
491        diff: Diff,
492    ) -> Result<(), EvalError> {
493        ensure!(
494            aggr_fn.is_max() || aggr_fn.is_min() || matches!(aggr_fn, AggregateFunc::Count),
495            InternalSnafu {
496                reason: format!(
497                    "OrdValue Accumulator does not support this aggregation function: {:?}",
498                    aggr_fn
499                ),
500            }
501        );
502        if diff <= 0 && (aggr_fn.is_max() || aggr_fn.is_min()) {
503            return Err(InternalSnafu {
504                reason: "OrdValue Accumulator does not support non-monotonic input for min/max aggregation".to_string(),
505            }.build());
506        }
507
508        // if aggr_fn is count, the incoming value type doesn't matter in type checking
509        // otherwise, type need to be the same or value can be null
510        let check_type_aggr_fn_and_arg_value =
511            ty_eq_without_precision(value.data_type(), aggr_fn.signature().input[0].clone())
512                || matches!(aggr_fn, AggregateFunc::Count)
513                || value.is_null();
514        let check_type_aggr_fn_and_self_val = self
515            .val
516            .as_ref()
517            .map(|zelf| {
518                ty_eq_without_precision(zelf.data_type(), aggr_fn.signature().input[0].clone())
519            })
520            .unwrap_or(true)
521            || matches!(aggr_fn, AggregateFunc::Count);
522
523        if !check_type_aggr_fn_and_arg_value {
524            return Err(TypeMismatchSnafu {
525                expected: aggr_fn.signature().input[0].clone(),
526                actual: value.data_type(),
527            }
528            .build());
529        } else if !check_type_aggr_fn_and_self_val {
530            return Err(TypeMismatchSnafu {
531                expected: aggr_fn.signature().input[0].clone(),
532                actual: self
533                    .val
534                    .as_ref()
535                    .map(|v| v.data_type())
536                    .unwrap_or(ConcreteDataType::null_datatype()),
537            }
538            .build());
539        }
540
541        let is_null = value.is_null();
542        if is_null {
543            return Ok(());
544        }
545
546        if !is_null {
547            // compile count(*) to count(true) to include null/non-nulls
548            // And the counts of non-null values are updated here
549            self.non_nulls += diff;
550
551            match aggr_fn.signature().generic_fn {
552                GenericFn::Max => {
553                    self.val = self
554                        .val
555                        .clone()
556                        .map(|v| v.max(value.clone()))
557                        .or_else(|| Some(value))
558                }
559                GenericFn::Min => {
560                    self.val = self
561                        .val
562                        .clone()
563                        .map(|v| v.min(value.clone()))
564                        .or_else(|| Some(value))
565                }
566
567                GenericFn::Count => (),
568                _ => unreachable!("already checked by ensure!"),
569            }
570        };
571        // min/max ignore nulls
572
573        Ok(())
574    }
575
576    fn eval(&self, aggr_fn: &AggregateFunc) -> Result<Value, EvalError> {
577        if aggr_fn.is_max() || aggr_fn.is_min() {
578            Ok(self.val.clone().unwrap_or(Value::Null))
579        } else if matches!(aggr_fn, AggregateFunc::Count) {
580            Ok(self.non_nulls.into())
581        } else {
582            Err(InternalSnafu {
583                reason: format!(
584                    "OrdValue Accumulator does not support this aggregation function: {:?}",
585                    aggr_fn
586                ),
587            }
588            .build())
589        }
590    }
591}
592
/// Accumulates values for the various types of accumulable aggregations.
///
/// Each variant wraps the accumulator state for one family of aggregate
/// functions; the `Accumulator` methods are forwarded to the wrapped state
/// through `enum_dispatch`.
///
/// Integer sums are accumulated in an `i128` (see `SimpleNumber`), float sums
/// in an `OrderedF64` with separate counters for NaN and infinities (see `Float`).
///
/// TODO(discord9): check for overflowing
#[enum_dispatch(Accumulator)]
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum Accum {
    /// Accumulates boolean values.
    Bool(Bool),
    /// Accumulates simple numeric values.
    SimpleNumber(SimpleNumber),
    /// Accumulates float values.
    Float(Float),
    /// Accumulate Values that impl `Ord`
    OrdValue(OrdValue),
}
614
615impl Accum {
616    /// create a new accumulator from given aggregate function
617    pub fn new_accum(aggr_fn: &AggregateFunc) -> Result<Self, EvalError> {
618        Ok(match aggr_fn {
619            AggregateFunc::Any
620            | AggregateFunc::All
621            | AggregateFunc::MaxBool
622            | AggregateFunc::MinBool => Self::from(Bool {
623                trues: 0,
624                falses: 0,
625            }),
626            AggregateFunc::SumInt16
627            | AggregateFunc::SumInt32
628            | AggregateFunc::SumInt64
629            | AggregateFunc::SumUInt16
630            | AggregateFunc::SumUInt32
631            | AggregateFunc::SumUInt64 => Self::from(SimpleNumber {
632                accum: 0,
633                non_nulls: 0,
634            }),
635            AggregateFunc::SumFloat32 | AggregateFunc::SumFloat64 => Self::from(Float {
636                accum: OrderedF64::from(0.0),
637                pos_infs: 0,
638                neg_infs: 0,
639                nans: 0,
640                non_nulls: 0,
641            }),
642            f if f.is_max() || f.is_min() || matches!(f, AggregateFunc::Count) => {
643                Self::from(OrdValue {
644                    val: None,
645                    non_nulls: 0,
646                })
647            }
648            f => {
649                return Err(InternalSnafu {
650                    reason: format!(
651                        "Accumulator does not support this aggregation function: {:?}",
652                        f
653                    ),
654                }
655                .build());
656            }
657        })
658    }
659
660    pub fn try_from_iter(
661        aggr_fn: &AggregateFunc,
662        iter: &mut impl Iterator<Item = Value>,
663    ) -> Result<Self, EvalError> {
664        match aggr_fn {
665            AggregateFunc::Any
666            | AggregateFunc::All
667            | AggregateFunc::MaxBool
668            | AggregateFunc::MinBool => Ok(Self::from(Bool::try_from_iter(iter)?)),
669            AggregateFunc::SumInt16
670            | AggregateFunc::SumInt32
671            | AggregateFunc::SumInt64
672            | AggregateFunc::SumUInt16
673            | AggregateFunc::SumUInt32
674            | AggregateFunc::SumUInt64 => Ok(Self::from(SimpleNumber::try_from_iter(iter)?)),
675            AggregateFunc::SumFloat32 | AggregateFunc::SumFloat64 => {
676                Ok(Self::from(Float::try_from_iter(iter)?))
677            }
678            f if f.is_max() || f.is_min() || matches!(f, AggregateFunc::Count) => {
679                Ok(Self::from(OrdValue::try_from_iter(iter)?))
680            }
681            f => Err(InternalSnafu {
682                reason: format!(
683                    "Accumulator does not support this aggregation function: {:?}",
684                    f
685                ),
686            }
687            .build()),
688        }
689    }
690
691    /// try to convert a vector of value into given aggregate function's accumulator
692    pub fn try_into_accum(aggr_fn: &AggregateFunc, state: Vec<Value>) -> Result<Self, EvalError> {
693        match aggr_fn {
694            AggregateFunc::Any
695            | AggregateFunc::All
696            | AggregateFunc::MaxBool
697            | AggregateFunc::MinBool => Ok(Self::from(Bool::try_from(state)?)),
698            AggregateFunc::SumInt16
699            | AggregateFunc::SumInt32
700            | AggregateFunc::SumInt64
701            | AggregateFunc::SumUInt16
702            | AggregateFunc::SumUInt32
703            | AggregateFunc::SumUInt64 => Ok(Self::from(SimpleNumber::try_from(state)?)),
704            AggregateFunc::SumFloat32 | AggregateFunc::SumFloat64 => {
705                Ok(Self::from(Float::try_from(state)?))
706            }
707            f if f.is_max() || f.is_min() || matches!(f, AggregateFunc::Count) => {
708                Ok(Self::from(OrdValue::try_from(state)?))
709            }
710            f => Err(InternalSnafu {
711                reason: format!(
712                    "Accumulator does not support this aggregation function: {:?}",
713                    f
714                ),
715            }
716            .build()),
717        }
718    }
719}
720
721fn fail_accum<T>() -> EvalError {
722    InternalSnafu {
723        reason: format!(
724            "list of values exhausted before a accum of type {} can be build from it",
725            type_name::<T>()
726        ),
727    }
728    .build()
729}
730
731fn err_try_from_val<T: Display>(reason: T) -> EvalError {
732    TryFromValueSnafu {
733        msg: reason.to_string(),
734    }
735    .build()
736}
737
738/// compare type while ignore their precision, including `TimeStamp`, `Time`,
739/// `Duration`, `Interval`
740fn ty_eq_without_precision(left: ConcreteDataType, right: ConcreteDataType) -> bool {
741    left == right
742        || matches!(left, ConcreteDataType::Timestamp(..))
743            && matches!(right, ConcreteDataType::Timestamp(..))
744        || matches!(left, ConcreteDataType::Time(..)) && matches!(right, ConcreteDataType::Time(..))
745        || matches!(left, ConcreteDataType::Duration(..))
746            && matches!(right, ConcreteDataType::Duration(..))
747        || matches!(left, ConcreteDataType::Interval(..))
748            && matches!(right, ConcreteDataType::Interval(..))
749}
750
751#[allow(clippy::too_many_lines)]
752#[cfg(test)]
753mod test {
754    use common_time::Timestamp;
755
756    use super::*;
757
758    #[test]
759    fn test_accum() {
760        let testcases = vec![
761            (
762                AggregateFunc::SumInt32,
763                vec![(Value::Int32(1), 1), (Value::Null, 1)],
764                (
765                    Value::Int64(1),
766                    vec![Value::Decimal128(Decimal128::new(1, 38, 0)), 1i64.into()],
767                ),
768            ),
769            (
770                AggregateFunc::SumFloat32,
771                vec![(Value::Float32(OrderedF32::from(1.0)), 1), (Value::Null, 1)],
772                (
773                    Value::Float32(OrderedF32::from(1.0)),
774                    vec![
775                        Value::Float64(OrderedF64::from(1.0)),
776                        0i64.into(),
777                        0i64.into(),
778                        0i64.into(),
779                        1i64.into(),
780                    ],
781                ),
782            ),
783            (
784                AggregateFunc::MaxInt32,
785                vec![(Value::Int32(1), 1), (Value::Int32(2), 1), (Value::Null, 1)],
786                (Value::Int32(2), vec![Value::Int32(2), 2i64.into()]),
787            ),
788            (
789                AggregateFunc::MinInt32,
790                vec![(Value::Int32(2), 1), (Value::Int32(1), 1), (Value::Null, 1)],
791                (Value::Int32(1), vec![Value::Int32(1), 2i64.into()]),
792            ),
793            (
794                AggregateFunc::MaxFloat32,
795                vec![
796                    (Value::Float32(OrderedF32::from(1.0)), 1),
797                    (Value::Float32(OrderedF32::from(2.0)), 1),
798                    (Value::Null, 1),
799                ],
800                (
801                    Value::Float32(OrderedF32::from(2.0)),
802                    vec![Value::Float32(OrderedF32::from(2.0)), 2i64.into()],
803                ),
804            ),
805            (
806                AggregateFunc::MaxDateTime,
807                vec![
808                    (Value::Timestamp(Timestamp::from(0)), 1),
809                    (Value::Timestamp(Timestamp::from(1)), 1),
810                    (Value::Null, 1),
811                ],
812                (
813                    Value::Timestamp(Timestamp::from(1)),
814                    vec![Value::Timestamp(Timestamp::from(1)), 2i64.into()],
815                ),
816            ),
817            (
818                AggregateFunc::Count,
819                vec![
820                    (Value::Int32(1), 1),
821                    (Value::Int32(2), 1),
822                    (Value::Null, 1),
823                    (Value::Null, 1),
824                ],
825                (2i64.into(), vec![Value::Null, 2i64.into()]),
826            ),
827            (
828                AggregateFunc::Any,
829                vec![
830                    (Value::Boolean(false), 1),
831                    (Value::Boolean(false), 1),
832                    (Value::Boolean(true), 1),
833                    (Value::Null, 1),
834                ],
835                (
836                    Value::Boolean(true),
837                    vec![Value::from(1i64), Value::from(2i64)],
838                ),
839            ),
840            (
841                AggregateFunc::All,
842                vec![
843                    (Value::Boolean(false), 1),
844                    (Value::Boolean(false), 1),
845                    (Value::Boolean(true), 1),
846                    (Value::Null, 1),
847                ],
848                (
849                    Value::Boolean(false),
850                    vec![Value::from(1i64), Value::from(2i64)],
851                ),
852            ),
853            (
854                AggregateFunc::MaxBool,
855                vec![
856                    (Value::Boolean(false), 1),
857                    (Value::Boolean(false), 1),
858                    (Value::Boolean(true), 1),
859                    (Value::Null, 1),
860                ],
861                (
862                    Value::Boolean(true),
863                    vec![Value::from(1i64), Value::from(2i64)],
864                ),
865            ),
866            (
867                AggregateFunc::MinBool,
868                vec![
869                    (Value::Boolean(false), 1),
870                    (Value::Boolean(false), 1),
871                    (Value::Boolean(true), 1),
872                    (Value::Null, 1),
873                ],
874                (
875                    Value::Boolean(false),
876                    vec![Value::from(1i64), Value::from(2i64)],
877                ),
878            ),
879        ];
880
881        for (aggr_fn, input, (eval_res, state)) in testcases {
882            let create_and_insert = || -> Result<Accum, EvalError> {
883                let mut acc = Accum::new_accum(&aggr_fn)?;
884                acc.update_batch(&aggr_fn, input.clone())?;
885                let row = acc.into_state();
886                let acc = Accum::try_into_accum(&aggr_fn, row.clone())?;
887                let alter_acc = Accum::try_from_iter(&aggr_fn, &mut row.into_iter())?;
888                assert_eq!(acc, alter_acc);
889                Ok(acc)
890            };
891            let acc = match create_and_insert() {
892                Ok(acc) => acc,
893                Err(err) => panic!(
894                    "Failed to create accum for {:?} with input {:?} with error: {:?}",
895                    aggr_fn, input, err
896                ),
897            };
898
899            if acc.eval(&aggr_fn).unwrap() != eval_res {
900                panic!(
901                    "Failed to eval accum for {:?} with input {:?}, expect {:?}, got {:?}",
902                    aggr_fn,
903                    input,
904                    eval_res,
905                    acc.eval(&aggr_fn).unwrap()
906                );
907            }
908            let actual_state = acc.into_state();
909            if actual_state != state {
910                panic!(
911                    "Failed to cast into state from accum for {:?} with input {:?}, expect state {:?}, got state {:?}",
912                    aggr_fn, input, state, actual_state
913                );
914            }
915        }
916    }
917    #[test]
       // Exercises the error paths of the individual accumulator types (`Bool`,
       // `SimpleNumber`, `Float`, `OrdValue`): restoring from a state the type rejects,
       // calling `update`/`eval` with an aggregate function the accumulator refuses,
       // and feeding a value whose type does not match the aggregate function.
       // Each scope below builds its own accumulator, so the scopes are independent,
       // but the statements inside one scope are order-dependent because every
       // successful `update` mutates the accumulator later assertions look at.
918    fn test_fail_path_accum() {
           // --- Bool: restoring from the one-element state `[Null]` must fail ---
919        {
920            let bool_accum = Bool::try_from(vec![Value::Null]);
921            assert!(matches!(bool_accum, Err(EvalError::Internal { .. })));
922        }
923
           // --- Bool: a two-element state restores; then probe serde and bad calls ---
924        {
925            let mut bool_accum = Bool::try_from(vec![1i64.into(), 1i64.into()]).unwrap();
926            // serde: a JSON round trip must give back an equal accumulator
927            let bool_accum_serde = serde_json::to_string(&bool_accum).unwrap();
928            let bool_accum_de = serde_json::from_str::<Bool>(&bool_accum_serde).unwrap();
929            assert_eq!(bool_accum, bool_accum_de);
               // updating with `MaxDate` is expected to be reported as `Internal`
930            assert!(matches!(
931                bool_accum.update(&AggregateFunc::MaxDate, 1.into(), 1),
932                Err(EvalError::Internal { .. })
933            ));
               // `Any` is accepted as a function, but an integer value is a `TypeMismatch`
934            assert!(matches!(
935                bool_accum.update(&AggregateFunc::Any, 1.into(), 1),
936                Err(EvalError::TypeMismatch { .. })
937            ));
               // evaluating with `MaxDate` is rejected the same way as updating with it
938            assert!(matches!(
939                bool_accum.eval(&AggregateFunc::MaxDate),
940                Err(EvalError::Internal { .. })
941            ));
942        }
943
           // --- SimpleNumber ---
944        {
               // the one-element state `[Null]` cannot be restored
945            let ret = SimpleNumber::try_from(vec![Value::Null]);
946            assert!(matches!(ret, Err(EvalError::Internal { .. })));
               // a (Decimal128, i64) pair is a state that restores successfully
947            let mut accum =
948                SimpleNumber::try_from(vec![Decimal128::new(0, 38, 0).into(), 0i64.into()])
949                    .unwrap();
950
               // `All` is expected to be rejected by `update` with `Internal`
951            assert!(matches!(
952                accum.update(&AggregateFunc::All, 0.into(), 1),
953                Err(EvalError::Internal { .. })
954            ));
               // `SumInt64` fed with an i32 value is a `TypeMismatch`
955            assert!(matches!(
956                accum.update(&AggregateFunc::SumInt64, 0i32.into(), 1),
957                Err(EvalError::TypeMismatch { .. })
958            ));
               // `All` is rejected by `eval` as well
959            assert!(matches!(
960                accum.eval(&AggregateFunc::All),
961                Err(EvalError::Internal { .. })
962            ));
               // Both updates below succeed on their own; the overflow is only reported
               // when the result is evaluated as `SumInt64` (1 + i64::MAX does not fit
               // in an i64).
               // NOTE(review): presumably the running sum is kept in the wider
               // Decimal128 seen in the restored state — confirm against `SimpleNumber`.
963            accum
964                .update(&AggregateFunc::SumInt64, 1i64.into(), 1)
965                .unwrap();
966            accum
967                .update(&AggregateFunc::SumInt64, i64::MAX.into(), 1)
968                .unwrap();
969            assert!(matches!(
970                accum.eval(&AggregateFunc::SumInt64),
971                Err(EvalError::Overflow { .. })
972            ));
973        }
974
           // --- Float ---
975        {
               // a four-element state is rejected ...
976            let ret = Float::try_from(vec![2f64.into(), 0i64.into(), 0i64.into(), 0i64.into()]);
977            assert!(matches!(ret, Err(EvalError::Internal { .. })));
               // ... while the same values plus a fifth element restore successfully.
               // NOTE(review): the meaning of the individual state slots is not visible
               // here — see `Float`'s `TryFrom<Vec<Value>>` implementation.
978            let mut accum = Float::try_from(vec![
979                2f64.into(),
980                0i64.into(),
981                0i64.into(),
982                0i64.into(),
983                1i64.into(),
984            ])
985            .unwrap();
               // retract (diff = -1) the same 2.0 that the restored state started with
986            accum
987                .update(&AggregateFunc::SumFloat64, 2f64.into(), -1)
988                .unwrap();
               // `All` is expected to be rejected by `update` with `Internal`
989            assert!(matches!(
990                accum.update(&AggregateFunc::All, 0.into(), 1),
991                Err(EvalError::Internal { .. })
992            ));
               // `SumFloat64` fed with an f32 value is a `TypeMismatch`
993            assert!(matches!(
994                accum.update(&AggregateFunc::SumFloat64, 0.0f32.into(), 1),
995                Err(EvalError::TypeMismatch { .. })
996            ));
997            // no record, no accum: after the retraction above the sum must evaluate to 0.0
998            assert_eq!(
999                accum.eval(&AggregateFunc::SumFloat64).unwrap(),
1000                0.0f64.into()
1001            );
1002
                // `All` is rejected by `eval` as well
1003            assert!(matches!(
1004                accum.eval(&AggregateFunc::All),
1005                Err(EvalError::Internal { .. })
1006            ));
1007
                // Non-finite inputs (+inf, -inf, NaN) must be accepted by `update`;
                // only the success of the calls is checked, not the resulting sum.
1008            accum
1009                .update(&AggregateFunc::SumFloat64, f64::INFINITY.into(), 1)
1010                .unwrap();
1011            accum
1012                .update(&AggregateFunc::SumFloat64, (-f64::INFINITY).into(), 1)
1013                .unwrap();
1014            accum
1015                .update(&AggregateFunc::SumFloat64, f64::NAN.into(), 1)
1016                .unwrap();
1017        }
1018
            // --- OrdValue (min/max) ---
1019        {
                // `[Null]` alone is rejected, `[Null, 0]` restores successfully
1020            let ret = OrdValue::try_from(vec![Value::Null]);
1021            assert!(matches!(ret, Err(EvalError::Internal { .. })));
1022            let mut accum = OrdValue::try_from(vec![Value::Null, 0i64.into()]).unwrap();
                // `All` is expected to be rejected by `update` with `Internal`
1023            assert!(matches!(
1024                accum.update(&AggregateFunc::All, 0.into(), 1),
1025                Err(EvalError::Internal { .. })
1026            ));
                // a matching i16 value is accepted by `MaxInt16`
1027            accum
1028                .update(&AggregateFunc::MaxInt16, 1i16.into(), 1)
1029                .unwrap();
                // an i32 value for `MaxInt16` is a `TypeMismatch`
1030            assert!(matches!(
1031                accum.update(&AggregateFunc::MaxInt16, 0i32.into(), 1),
1032                Err(EvalError::TypeMismatch { .. })
1033            ));
                // a retraction (diff = -1) is refused with `Internal`; the module docs
                // state that min/max cannot support delete
1034            assert!(matches!(
1035                accum.update(&AggregateFunc::MaxInt16, 0i16.into(), -1),
1036                Err(EvalError::Internal { .. })
1037            ));
                // inserting `Null` must succeed
1038            accum
1039                .update(&AggregateFunc::MaxInt16, Value::Null, 1)
1040                .unwrap();
1041        }
1042
1043        // insert uint64 into max_int64 should fail with `TypeMismatch`
1044        {
1045            let mut accum = OrdValue::try_from(vec![Value::Null, 0i64.into()]).unwrap();
1046            assert!(matches!(
1047                accum.update(&AggregateFunc::MaxInt64, 0u64.into(), 1),
1048                Err(EvalError::TypeMismatch { .. })
1049            ));
1050        }
1051    }
1052}