flow/expr/relation/func.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::collections::HashMap;
16use std::sync::OnceLock;
17
18use datatypes::prelude::ConcreteDataType;
19use datatypes::value::Value;
20use datatypes::vectors::VectorRef;
21use serde::{Deserialize, Serialize};
22use smallvec::smallvec;
23use snafu::OptionExt;
24use strum::{EnumIter, IntoEnumIterator};
25
26use crate::error::{Error, InvalidQuerySnafu};
27use crate::expr::error::EvalError;
28use crate::expr::relation::accum::{Accum, Accumulator};
29use crate::expr::signature::{GenericFn, Signature};
30use crate::expr::VectorDiff;
31use crate::repr::Diff;
32
/// Aggregate functions that can be applied to a group of rows.
///
/// `Mean` function is deliberately not included as it can be computed from `Sum` and `Count`,
/// whose state can be better managed.
///
/// Type of the input and output of each aggregate function:
///
/// `sum(i16|i32|i64)->i64, sum(u16|u32|u64)->u64, sum(f32)->f32, sum(f64)->f64`
///
/// `count()->i64` (accepts an argument of any type)
///
/// `min/max(T)->T`
///
/// `bool_or(bool)->bool` (`Any`), `bool_and(bool)->bool` (`All`)
///
/// NOTE(review): variant declaration order determines the derived `Ord`/`PartialOrd` and the
/// `EnumIter` iteration order; it may also affect index-based serde formats. Presumably new
/// variants should only be appended — confirm before reordering.
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Hash, EnumIter)]
pub enum AggregateFunc {
    // max(T) -> T
    MaxInt16,
    MaxInt32,
    MaxInt64,
    MaxUInt16,
    MaxUInt32,
    MaxUInt64,
    MaxFloat32,
    MaxFloat64,
    MaxBool,
    MaxString,
    MaxDate,
    MaxDateTime,
    MaxTimestamp,
    MaxTime,
    MaxDuration,
    MaxInterval,

    // min(T) -> T
    MinInt16,
    MinInt32,
    MinInt64,
    MinUInt16,
    MinUInt32,
    MinUInt64,
    MinFloat32,
    MinFloat64,
    MinBool,
    MinString,
    MinDate,
    MinDateTime,
    MinTimestamp,
    MinTime,
    MinDuration,
    MinInterval,

    // sum: signed ints widen to i64, unsigned ints widen to u64, floats keep their type
    SumInt16,
    SumInt32,
    SumInt64,
    SumUInt16,
    SumUInt32,
    SumUInt64,
    SumFloat32,
    SumFloat64,

    // count(any) -> i64
    Count,
    // bool_or(bool) -> bool
    Any,
    // bool_and(bool) -> bool
    All,
}
93
94impl AggregateFunc {
95    /// if this function is a `max`
96    pub fn is_max(&self) -> bool {
97        self.signature().generic_fn == GenericFn::Max
98    }
99
100    /// if this function is a `min`
101    pub fn is_min(&self) -> bool {
102        self.signature().generic_fn == GenericFn::Min
103    }
104
105    /// Eval value, diff with accumulator
106    ///
107    /// Expect self to be accumulable aggregate function, i.e. sum/count
108    ///
109    /// TODO(discord9): deal with overflow&better accumulator
110    pub fn eval_diff_accumulable<A, I>(
111        &self,
112        accum: A,
113        value_diffs: I,
114    ) -> Result<(Value, Vec<Value>), EvalError>
115    where
116        A: IntoIterator<Item = Value>,
117        I: IntoIterator<Item = (Value, Diff)>,
118    {
119        let mut accum = accum.into_iter().peekable();
120
121        let mut accum = if accum.peek().is_none() {
122            Accum::new_accum(self)?
123        } else {
124            Accum::try_from_iter(self, &mut accum)?
125        };
126        accum.update_batch(self, value_diffs)?;
127        let res = accum.eval(self)?;
128        Ok((res, accum.into_state()))
129    }
130
131    /// return output value and new accumulator state
132    pub fn eval_batch<A>(
133        &self,
134        accum: A,
135        vector: VectorRef,
136        diff: Option<VectorRef>,
137    ) -> Result<(Value, Vec<Value>), EvalError>
138    where
139        A: IntoIterator<Item = Value>,
140    {
141        let mut accum = accum.into_iter().peekable();
142
143        let mut accum = if accum.peek().is_none() {
144            Accum::new_accum(self)?
145        } else {
146            Accum::try_from_iter(self, &mut accum)?
147        };
148
149        let vector_diff = VectorDiff::try_new(vector, diff)?;
150
151        accum.update_batch(self, vector_diff)?;
152
153        let res = accum.eval(self)?;
154        Ok((res, accum.into_state()))
155    }
156}
157
/// Expands to a `match` on `$subject` that yields the `Signature` of each aggregate function.
///
/// Invocation shape: `generate_signature!(subject, { manual arms }, [ Variant => (args), ... ])`.
///
/// * `subject` - the identifier to match on (typically `self`).
/// * `{ manual arms }` - hand-written match arms, emitted verbatim before the table arms.
///   A `,` is appended after them, so this group must not be empty.
/// * `[ Variant => (args), ... ]` - table entries. Each one becomes
///   `Self::Variant => gen_one_siginature!(args)`, so the macro has to be expanded inside an
///   `impl` block. The list takes no trailing comma.
macro_rules! generate_signature {
    (
        $subject:ident,
        { $($manual:tt)* },
        [ $( $variant:ident => ( $($spec:ident),* ) ),* ]
    ) => {
        match $subject {
            $($manual)*,
            $( Self::$variant => gen_one_siginature!($($spec),*), )*
        }
    };
}
175
/// Builds one `Signature` value from datatype constructor names and a `GenericFn` variant name.
///
/// * `(input, output, kind)` - `input` is the single-element list `[input]`, and the output
///   is `output`.
/// * `(ty, kind)` - `input` lists `ty` twice (`[ty, ty]`), and the output is `ty`.
///
/// The datatype arguments name `ConcreteDataType` constructor functions (e.g.
/// `int64_datatype`). `Signature`, `ConcreteDataType`, `GenericFn` and `smallvec!` must all be
/// in scope where the macro is expanded.
macro_rules! gen_one_siginature {
    ($input:ident, $output:ident, $kind:ident) => {
        Signature {
            input: smallvec![ConcreteDataType::$input()],
            output: ConcreteDataType::$output(),
            generic_fn: GenericFn::$kind,
        }
    };
    ($ty:ident, $kind:ident) => {
        Signature {
            input: smallvec![ConcreteDataType::$ty(), ConcreteDataType::$ty()],
            output: ConcreteDataType::$ty(),
            generic_fn: GenericFn::$kind,
        }
    };
}
197
/// Lookup table from `(generic function, first input type)` to the concrete `AggregateFunc`
/// variant. It is built once, on first use, by `AggregateFunc::from_str_and_type`.
static SPECIALIZATION: OnceLock<HashMap<(GenericFn, ConcreteDataType), AggregateFunc>> =
    OnceLock::new();
200
201impl AggregateFunc {
202    /// Create a `AggregateFunc` from a string of the function name and given argument type(optional)
203    /// given an None type will be treated as null type,
204    /// which in turn for AggregateFunc like `Count` will be treated as any type
205    pub fn from_str_and_type(
206        name: &str,
207        arg_type: Option<ConcreteDataType>,
208    ) -> Result<Self, Error> {
209        let rule = SPECIALIZATION.get_or_init(|| {
210            let mut spec = HashMap::new();
211            for func in Self::iter() {
212                let sig = func.signature();
213                spec.insert((sig.generic_fn, sig.input[0].clone()), func);
214            }
215            spec
216        });
217
218        let generic_fn = match name {
219            "max" => GenericFn::Max,
220            "min" => GenericFn::Min,
221            "sum" => GenericFn::Sum,
222            "count" => GenericFn::Count,
223            "bool_or" => GenericFn::Any,
224            "bool_and" => GenericFn::All,
225            _ => {
226                return InvalidQuerySnafu {
227                    reason: format!("Unknown aggregate function: {}", name),
228                }
229                .fail();
230            }
231        };
232        let input_type = if matches!(generic_fn, GenericFn::Count) {
233            ConcreteDataType::null_datatype()
234        } else {
235            arg_type.unwrap_or_else(ConcreteDataType::null_datatype)
236        };
237        rule.get(&(generic_fn, input_type.clone()))
238            .cloned()
239            .with_context(|| InvalidQuerySnafu {
240                reason: format!(
241                    "No specialization found for aggregate function {:?} with input type {:?}",
242                    generic_fn, input_type
243                ),
244            })
245    }
246
247    /// all concrete datatypes with precision types will be returned with largest possible variant
248    /// as a exception, count have a signature of `null -> i64`, but it's actually `anytype -> i64`
249    ///
250    /// TODO(discorcd9): fix signature for sum unsign -> u64 sum signed -> i64
251    pub fn signature(&self) -> Signature {
252        generate_signature!(self, {
253            AggregateFunc::Count => Signature {
254                input: smallvec![ConcreteDataType::null_datatype()],
255                output: ConcreteDataType::int64_datatype(),
256                generic_fn: GenericFn::Count,
257            }
258        },[
259            MaxInt16 => (int16_datatype, Max),
260            MaxInt32 => (int32_datatype, Max),
261            MaxInt64 => (int64_datatype, Max),
262            MaxUInt16 => (uint16_datatype, Max),
263            MaxUInt32 => (uint32_datatype, Max),
264            MaxUInt64 => (uint64_datatype, Max),
265            MaxFloat32 => (float32_datatype, Max),
266            MaxFloat64 => (float64_datatype, Max),
267            MaxBool => (boolean_datatype, Max),
268            MaxString => (string_datatype, Max),
269            MaxDate => (date_datatype, Max),
270            MaxDateTime => (timestamp_microsecond_datatype, Max),
271            MaxTimestamp => (timestamp_second_datatype, Max),
272            MaxTime => (time_second_datatype, Max),
273            MaxDuration => (duration_second_datatype, Max),
274            MaxInterval => (interval_year_month_datatype, Max),
275            MinInt16 => (int16_datatype, Min),
276            MinInt32 => (int32_datatype, Min),
277            MinInt64 => (int64_datatype, Min),
278            MinUInt16 => (uint16_datatype, Min),
279            MinUInt32 => (uint32_datatype, Min),
280            MinUInt64 => (uint64_datatype, Min),
281            MinFloat32 => (float32_datatype, Min),
282            MinFloat64 => (float64_datatype, Min),
283            MinBool => (boolean_datatype, Min),
284            MinString => (string_datatype, Min),
285            MinDate => (date_datatype, Min),
286            MinDateTime => (timestamp_microsecond_datatype, Min),
287            MinTimestamp => (timestamp_second_datatype, Min),
288            MinTime => (time_second_datatype, Min),
289            MinDuration => (duration_second_datatype, Min),
290            MinInterval => (interval_year_month_datatype, Min),
291            SumInt16 => (int16_datatype, int64_datatype, Sum),
292            SumInt32 => (int32_datatype, int64_datatype, Sum),
293            SumInt64 => (int64_datatype, int64_datatype, Sum),
294            SumUInt16 => (uint16_datatype, uint64_datatype, Sum),
295            SumUInt32 => (uint32_datatype, uint64_datatype, Sum),
296            SumUInt64 => (uint64_datatype, uint64_datatype, Sum),
297            SumFloat32 => (float32_datatype, Sum),
298            SumFloat64 => (float64_datatype, Sum),
299            Any => (boolean_datatype, Any),
300            All => (boolean_datatype, All)
301        ])
302    }
303}