1use std::collections::HashMap;
16use std::sync::OnceLock;
17
18use datatypes::prelude::ConcreteDataType;
19use datatypes::value::Value;
20use datatypes::vectors::VectorRef;
21use serde::{Deserialize, Serialize};
22use smallvec::smallvec;
23use snafu::OptionExt;
24use strum::{EnumIter, IntoEnumIterator};
25
26use crate::error::{Error, InvalidQuerySnafu};
27use crate::expr::error::EvalError;
28use crate::expr::relation::accum::{Accum, Accumulator};
29use crate::expr::signature::{GenericFn, Signature};
30use crate::expr::VectorDiff;
31use crate::repr::Diff;
32
/// Aggregate functions, each specialized to one concrete input data type.
///
/// Every variant pairs a generic aggregate (`max`, `min`, `sum`, `count`, `bool_or`,
/// `bool_and`; see `AggregateFunc::from_str_and_type`) with the [`ConcreteDataType`] it is
/// specialized on. The exact input/output types of each variant are given by
/// `AggregateFunc::signature`.
///
/// Because `Ord`, `PartialOrd` and `EnumIter` are derived, the declaration order of the
/// variants determines both their ordering and the order yielded by `AggregateFunc::iter()`.
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Hash, EnumIter)]
pub enum AggregateFunc {
    // `max` specializations. Input and output share the same type; note that `MaxDateTime`
    // is specialized on `timestamp_microsecond_datatype` and `MaxTimestamp` on
    // `timestamp_second_datatype` (see `signature`).
    MaxInt16,
    MaxInt32,
    MaxInt64,
    MaxUInt16,
    MaxUInt32,
    MaxUInt64,
    MaxFloat32,
    MaxFloat64,
    MaxBool,
    MaxString,
    MaxDate,
    MaxDateTime,
    MaxTimestamp,
    MaxTime,
    MaxDuration,
    MaxInterval,

    // `min` specializations; same type mapping as the `max` group above.
    MinInt16,
    MinInt32,
    MinInt64,
    MinUInt16,
    MinUInt32,
    MinUInt64,
    MinFloat32,
    MinFloat64,
    MinBool,
    MinString,
    MinDate,
    MinDateTime,
    MinTimestamp,
    MinTime,
    MinDuration,
    MinInterval,

    // `sum` specializations. Signed integer inputs produce `int64`, unsigned integer inputs
    // produce `uint64`, and float inputs keep their own type (see `signature`).
    SumInt16,
    SumInt32,
    SumInt64,
    SumUInt16,
    SumUInt32,
    SumUInt64,
    SumFloat32,
    SumFloat64,

    // `count`: a single specialization with a null-typed input and an `int64` output.
    Count,
    // Resolved from the SQL name `bool_or`; boolean input and output.
    Any,
    // Resolved from the SQL name `bool_and`; boolean input and output.
    All,
}
93
94impl AggregateFunc {
95 pub fn is_max(&self) -> bool {
97 self.signature().generic_fn == GenericFn::Max
98 }
99
100 pub fn is_min(&self) -> bool {
102 self.signature().generic_fn == GenericFn::Min
103 }
104
105 pub fn eval_diff_accumulable<A, I>(
111 &self,
112 accum: A,
113 value_diffs: I,
114 ) -> Result<(Value, Vec<Value>), EvalError>
115 where
116 A: IntoIterator<Item = Value>,
117 I: IntoIterator<Item = (Value, Diff)>,
118 {
119 let mut accum = accum.into_iter().peekable();
120
121 let mut accum = if accum.peek().is_none() {
122 Accum::new_accum(self)?
123 } else {
124 Accum::try_from_iter(self, &mut accum)?
125 };
126 accum.update_batch(self, value_diffs)?;
127 let res = accum.eval(self)?;
128 Ok((res, accum.into_state()))
129 }
130
131 pub fn eval_batch<A>(
133 &self,
134 accum: A,
135 vector: VectorRef,
136 diff: Option<VectorRef>,
137 ) -> Result<(Value, Vec<Value>), EvalError>
138 where
139 A: IntoIterator<Item = Value>,
140 {
141 let mut accum = accum.into_iter().peekable();
142
143 let mut accum = if accum.peek().is_none() {
144 Accum::new_accum(self)?
145 } else {
146 Accum::try_from_iter(self, &mut accum)?
147 };
148
149 let vector_diff = VectorDiff::try_new(vector, diff)?;
150
151 accum.update_batch(self, vector_diff)?;
152
153 let res = accum.eval(self)?;
154 Ok((res, accum.into_state()))
155 }
156}
157
/// Expands to a `match $value { ... }` that maps every `AggregateFunc` variant to its
/// `Signature`.
///
/// * `{ $($user_arm)* }` — hand-written match arm(s), emitted verbatim first and followed by
///   a comma.
/// * `[ Variant => (args), ... ]` — each entry expands to
///   `Self::Variant => gen_one_siginature!(args)`. Because it uses `Self::`, the macro must
///   be invoked inside an `impl` block of the matched enum.
macro_rules! generate_signature {
    ($value:ident,
        { $($user_arm:tt)* },
        [ $(
            $auto_arm:ident=>($($arg:ident),*)
            ),*
        ]
    ) => {
        match $value {
            // Hand-written arms come first, exactly as given.
            $($user_arm)*,
            // One generated arm per `Variant => (args)` entry.
            $(
                Self::$auto_arm => gen_one_siginature!($($arg),*),
            )*
        }
    };
}
175
/// Builds a single `Signature` from `ConcreteDataType` constructor names and a
/// `GenericFn` variant name. (The name is spelled `siginature` (sic); `generate_signature!`
/// refers to it by this spelling.)
///
/// * `(ty, Generic)` — input is `[ty(), ty()]`, output is `ty()`.
/// * `(in_ty, out_ty, Generic)` — input is `[in_ty()]`, output is `out_ty()`.
///
/// NOTE(review): the two-argument form declares two inputs, while the three-argument form
/// declares one. `AggregateFunc::from_str_and_type` only reads `input[0]`, so the lookup is
/// unaffected. Confirm that no other caller relies on `input.len()` before changing either
/// form.
macro_rules! gen_one_siginature {
    (
        $con_type:ident, $generic:ident
    ) => {
        // Same-typed form: input and output share `$con_type`.
        Signature {
            input: smallvec![ConcreteDataType::$con_type(), ConcreteDataType::$con_type(),],
            output: ConcreteDataType::$con_type(),
            generic_fn: GenericFn::$generic,
        }
    };
    (
        $in_type:ident, $out_type:ident, $generic:ident
    ) => {
        // Widening form: a single `$in_type` input produces an `$out_type` output.
        Signature {
            input: smallvec![ConcreteDataType::$in_type()],
            output: ConcreteDataType::$out_type(),
            generic_fn: GenericFn::$generic,
        }
    };
}
197
/// Lookup table from `(generic function, input data type)` to the matching
/// [`AggregateFunc`] variant.
///
/// Filled once, on first use, by `AggregateFunc::from_str_and_type`. The key is each
/// variant's `signature()` generic function and its first input type.
static SPECIALIZATION: OnceLock<HashMap<(GenericFn, ConcreteDataType), AggregateFunc>> =
    OnceLock::new();
200
201impl AggregateFunc {
202 pub fn from_str_and_type(
206 name: &str,
207 arg_type: Option<ConcreteDataType>,
208 ) -> Result<Self, Error> {
209 let rule = SPECIALIZATION.get_or_init(|| {
210 let mut spec = HashMap::new();
211 for func in Self::iter() {
212 let sig = func.signature();
213 spec.insert((sig.generic_fn, sig.input[0].clone()), func);
214 }
215 spec
216 });
217
218 let generic_fn = match name {
219 "max" => GenericFn::Max,
220 "min" => GenericFn::Min,
221 "sum" => GenericFn::Sum,
222 "count" => GenericFn::Count,
223 "bool_or" => GenericFn::Any,
224 "bool_and" => GenericFn::All,
225 _ => {
226 return InvalidQuerySnafu {
227 reason: format!("Unknown aggregate function: {}", name),
228 }
229 .fail();
230 }
231 };
232 let input_type = if matches!(generic_fn, GenericFn::Count) {
233 ConcreteDataType::null_datatype()
234 } else {
235 arg_type.unwrap_or_else(ConcreteDataType::null_datatype)
236 };
237 rule.get(&(generic_fn, input_type.clone()))
238 .cloned()
239 .with_context(|| InvalidQuerySnafu {
240 reason: format!(
241 "No specialization found for aggregate function {:?} with input type {:?}",
242 generic_fn, input_type
243 ),
244 })
245 }
246
247 pub fn signature(&self) -> Signature {
252 generate_signature!(self, {
253 AggregateFunc::Count => Signature {
254 input: smallvec![ConcreteDataType::null_datatype()],
255 output: ConcreteDataType::int64_datatype(),
256 generic_fn: GenericFn::Count,
257 }
258 },[
259 MaxInt16 => (int16_datatype, Max),
260 MaxInt32 => (int32_datatype, Max),
261 MaxInt64 => (int64_datatype, Max),
262 MaxUInt16 => (uint16_datatype, Max),
263 MaxUInt32 => (uint32_datatype, Max),
264 MaxUInt64 => (uint64_datatype, Max),
265 MaxFloat32 => (float32_datatype, Max),
266 MaxFloat64 => (float64_datatype, Max),
267 MaxBool => (boolean_datatype, Max),
268 MaxString => (string_datatype, Max),
269 MaxDate => (date_datatype, Max),
270 MaxDateTime => (timestamp_microsecond_datatype, Max),
271 MaxTimestamp => (timestamp_second_datatype, Max),
272 MaxTime => (time_second_datatype, Max),
273 MaxDuration => (duration_second_datatype, Max),
274 MaxInterval => (interval_year_month_datatype, Max),
275 MinInt16 => (int16_datatype, Min),
276 MinInt32 => (int32_datatype, Min),
277 MinInt64 => (int64_datatype, Min),
278 MinUInt16 => (uint16_datatype, Min),
279 MinUInt32 => (uint32_datatype, Min),
280 MinUInt64 => (uint64_datatype, Min),
281 MinFloat32 => (float32_datatype, Min),
282 MinFloat64 => (float64_datatype, Min),
283 MinBool => (boolean_datatype, Min),
284 MinString => (string_datatype, Min),
285 MinDate => (date_datatype, Min),
286 MinDateTime => (timestamp_microsecond_datatype, Min),
287 MinTimestamp => (timestamp_second_datatype, Min),
288 MinTime => (time_second_datatype, Min),
289 MinDuration => (duration_second_datatype, Min),
290 MinInterval => (interval_year_month_datatype, Min),
291 SumInt16 => (int16_datatype, int64_datatype, Sum),
292 SumInt32 => (int32_datatype, int64_datatype, Sum),
293 SumInt64 => (int64_datatype, int64_datatype, Sum),
294 SumUInt16 => (uint16_datatype, uint64_datatype, Sum),
295 SumUInt32 => (uint32_datatype, uint64_datatype, Sum),
296 SumUInt64 => (uint64_datatype, uint64_datatype, Sum),
297 SumFloat32 => (float32_datatype, Sum),
298 SumFloat64 => (float64_datatype, Sum),
299 Any => (boolean_datatype, Any),
300 All => (boolean_datatype, All)
301 ])
302 }
303}