common_function/aggrs/
aggr_wrapper.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
//! Wrapper for making aggregate functions out of the state/merge functions of original aggregate functions.
//!
//! i.e. for an aggregate function `foo`, we will have a state function `foo_state` and a merge function `foo_merge`
//! (named `__foo_state` and `__foo_merge`).
//!
//! `foo_state`'s input args are the same as `foo`'s, and its output is a state object.
//! Note that `foo_state` might have multiple output columns, so its output is a struct array
//! in which each output column is a struct field.
//! `foo_merge`'s input arg is the same as `foo_state`'s output, and its output is the same as `foo`'s output.
//!
24
25use std::sync::Arc;
26
27use arrow::array::StructArray;
28use arrow_schema::Fields;
29use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
30use datafusion::optimizer::AnalyzerRule;
31use datafusion::physical_planner::create_aggregate_expr_and_maybe_filter;
32use datafusion_common::{Column, ScalarValue};
33use datafusion_expr::expr::AggregateFunction;
34use datafusion_expr::function::StateFieldsArgs;
35use datafusion_expr::{
36    Accumulator, Aggregate, AggregateUDF, AggregateUDFImpl, Expr, ExprSchemable, LogicalPlan,
37    Signature,
38};
39use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
40use datatypes::arrow::datatypes::{DataType, Field};
41
/// Builds the name of the state function derived from the aggregate function `aggr_name`.
///
/// The state function computes the intermediate state of the original aggregate function.
/// Its name has the form `__<aggr_name>_state`.
pub fn aggr_state_func_name(aggr_name: &str) -> String {
    ["__", aggr_name, "_state"].concat()
}
48
/// Builds the name of the merge function derived from the aggregate function `aggr_name`.
///
/// The merge function merges the states produced by the matching state function.
/// Its name has the form `__<aggr_name>_merge`.
pub fn aggr_merge_func_name(aggr_name: &str) -> String {
    ["__", aggr_name, "_merge"].concat()
}
55
/// Helper that splits an aggregate plan into a lower "state" aggregate and an upper "merge"
/// aggregate. Both are built from the state and merge wrappers of the original aggregate
/// functions.
///
/// This is a stateless unit struct; the work is done by [`StateMergeHelper::split_aggr_node`].
///
/// Notice that a state function may produce multiple output columns, so its return type is
/// always a struct. The merge function consumes that struct and merges the states it holds.
#[derive(Debug, Clone)]
pub struct StateMergeHelper;
62
63/// A struct to hold the two aggregate plans, one for the state function(lower) and one for the merge function(upper).
64#[allow(unused)]
65#[derive(Debug, Clone)]
66pub struct StepAggrPlan {
67    /// Upper merge plan, which is the aggregate plan that merges the states of the state function.
68    pub upper_merge: Arc<LogicalPlan>,
69    /// Lower state plan, which is the aggregate plan that computes the state of the aggregate function.
70    pub lower_state: Arc<LogicalPlan>,
71}
72
73pub fn get_aggr_func(expr: &Expr) -> Option<&datafusion_expr::expr::AggregateFunction> {
74    let mut expr_ref = expr;
75    while let Expr::Alias(alias) = expr_ref {
76        expr_ref = &alias.expr;
77    }
78    if let Expr::AggregateFunction(aggr_func) = expr_ref {
79        Some(aggr_func)
80    } else {
81        None
82    }
83}
84
85impl StateMergeHelper {
86    /// Split an aggregate plan into two aggregate plans, one for the state function and one for the merge function.
87    pub fn split_aggr_node(aggr_plan: Aggregate) -> datafusion_common::Result<StepAggrPlan> {
88        let aggr = {
89            // certain aggr func need type coercion to work correctly, so we need to analyze the plan first.
90            let aggr_plan = TypeCoercion::new().analyze(
91                LogicalPlan::Aggregate(aggr_plan).clone(),
92                &Default::default(),
93            )?;
94            if let LogicalPlan::Aggregate(aggr) = aggr_plan {
95                aggr
96            } else {
97                return Err(datafusion_common::DataFusionError::Internal(format!(
98                    "Failed to coerce expressions in aggregate plan, expected Aggregate, got: {:?}",
99                    aggr_plan
100                )));
101            }
102        };
103        let mut lower_aggr_exprs = vec![];
104        let mut upper_aggr_exprs = vec![];
105
106        for aggr_expr in aggr.aggr_expr.iter() {
107            let Some(aggr_func) = get_aggr_func(aggr_expr) else {
108                return Err(datafusion_common::DataFusionError::NotImplemented(format!(
109                    "Unsupported aggregate expression for step aggr optimize: {:?}",
110                    aggr_expr
111                )));
112            };
113
114            let original_input_types = aggr_func
115                .args
116                .iter()
117                .map(|e| e.get_type(&aggr.input.schema()))
118                .collect::<Result<Vec<_>, _>>()?;
119
120            // first create the state function from the original aggregate function.
121            let state_func = StateWrapper::new((*aggr_func.func).clone())?;
122
123            let expr = AggregateFunction {
124                func: Arc::new(state_func.into()),
125                args: aggr_func.args.clone(),
126                distinct: aggr_func.distinct,
127                filter: aggr_func.filter.clone(),
128                order_by: aggr_func.order_by.clone(),
129                null_treatment: aggr_func.null_treatment,
130            };
131            let expr = Expr::AggregateFunction(expr);
132            let lower_state_output_col_name = expr.schema_name().to_string();
133
134            lower_aggr_exprs.push(expr);
135
136            let (original_phy_expr, _filter, _ordering) = create_aggregate_expr_and_maybe_filter(
137                aggr_expr,
138                aggr.input.schema(),
139                aggr.input.schema().as_arrow(),
140                &Default::default(),
141            )?;
142
143            let merge_func = MergeWrapper::new(
144                (*aggr_func.func).clone(),
145                original_phy_expr,
146                original_input_types,
147            )?;
148            let arg = Expr::Column(Column::new_unqualified(lower_state_output_col_name));
149            let expr = AggregateFunction {
150                func: Arc::new(merge_func.into()),
151                args: vec![arg],
152                distinct: aggr_func.distinct,
153                filter: aggr_func.filter.clone(),
154                order_by: aggr_func.order_by.clone(),
155                null_treatment: aggr_func.null_treatment,
156            };
157
158            // alias to the original aggregate expr's schema name, so parent plan can refer to it
159            // correctly.
160            let expr = Expr::AggregateFunction(expr).alias(aggr_expr.schema_name().to_string());
161            upper_aggr_exprs.push(expr);
162        }
163
164        let mut lower = aggr.clone();
165        lower.aggr_expr = lower_aggr_exprs;
166        let lower_plan = LogicalPlan::Aggregate(lower);
167
168        // update aggregate's output schema
169        let lower_plan = Arc::new(lower_plan.recompute_schema()?);
170
171        let mut upper = aggr.clone();
172        let aggr_plan = LogicalPlan::Aggregate(aggr);
173        upper.aggr_expr = upper_aggr_exprs;
174        upper.input = lower_plan.clone();
175        // upper schema's output schema should be the same as the original aggregate plan's output schema
176        let upper_check = upper.clone();
177        let upper_plan = Arc::new(LogicalPlan::Aggregate(upper_check).recompute_schema()?);
178        if *upper_plan.schema() != *aggr_plan.schema() {
179            return Err(datafusion_common::DataFusionError::Internal(format!(
180                 "Upper aggregate plan's schema is not the same as the original aggregate plan's schema: \n[transformed]:{}\n[   original]{}",
181                upper_plan.schema(), aggr_plan.schema()
182            )));
183        }
184
185        Ok(StepAggrPlan {
186            lower_state: lower_plan,
187            upper_merge: upper_plan,
188        })
189    }
190}
191
192/// Wrapper to make an aggregate function out of a state function.
193#[derive(Debug, Clone, PartialEq, Eq)]
194pub struct StateWrapper {
195    inner: AggregateUDF,
196    name: String,
197}
198
199impl StateWrapper {
200    /// `state_index`: The index of the state in the output of the state function.
201    pub fn new(inner: AggregateUDF) -> datafusion_common::Result<Self> {
202        let name = aggr_state_func_name(inner.name());
203        Ok(Self { inner, name })
204    }
205
206    pub fn inner(&self) -> &AggregateUDF {
207        &self.inner
208    }
209
210    /// Deduce the return type of the original aggregate function
211    /// based on the accumulator arguments.
212    ///
213    pub fn deduce_aggr_return_type(
214        &self,
215        acc_args: &datafusion_expr::function::AccumulatorArgs,
216    ) -> datafusion_common::Result<DataType> {
217        let input_exprs = acc_args.exprs;
218        let input_schema = acc_args.schema;
219        let input_types = input_exprs
220            .iter()
221            .map(|e| e.data_type(input_schema))
222            .collect::<Result<Vec<_>, _>>()?;
223        let return_type = self.inner.return_type(&input_types)?;
224        Ok(return_type)
225    }
226}
227
228impl AggregateUDFImpl for StateWrapper {
229    fn accumulator<'a, 'b>(
230        &'a self,
231        acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
232    ) -> datafusion_common::Result<Box<dyn Accumulator>> {
233        // fix and recover proper acc args for the original aggregate function.
234        let state_type = acc_args.return_type.clone();
235        let inner = {
236            let old_return_type = self.deduce_aggr_return_type(&acc_args)?;
237            let acc_args = datafusion_expr::function::AccumulatorArgs {
238                return_type: &old_return_type,
239                schema: acc_args.schema,
240                ignore_nulls: acc_args.ignore_nulls,
241                ordering_req: acc_args.ordering_req,
242                is_reversed: acc_args.is_reversed,
243                name: acc_args.name,
244                is_distinct: acc_args.is_distinct,
245                exprs: acc_args.exprs,
246            };
247            self.inner.accumulator(acc_args)?
248        };
249        Ok(Box::new(StateAccum::new(inner, state_type)?))
250    }
251
252    fn as_any(&self) -> &dyn std::any::Any {
253        self
254    }
255    fn name(&self) -> &str {
256        self.name.as_str()
257    }
258
259    fn is_nullable(&self) -> bool {
260        self.inner.is_nullable()
261    }
262
263    /// Return state_fields as the output struct type.
264    ///
265    fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
266        let old_return_type = self.inner.return_type(arg_types)?;
267        let state_fields_args = StateFieldsArgs {
268            name: self.inner().name(),
269            input_types: arg_types,
270            return_type: &old_return_type,
271            // TODO(discord9): how to get this?, probably ok?
272            ordering_fields: &[],
273            is_distinct: false,
274        };
275        let state_fields = self.inner.state_fields(state_fields_args)?;
276        let struct_field = DataType::Struct(state_fields.into());
277        Ok(struct_field)
278    }
279
280    /// The state function's output fields are the same as the original aggregate function's state fields.
281    fn state_fields(
282        &self,
283        args: datafusion_expr::function::StateFieldsArgs,
284    ) -> datafusion_common::Result<Vec<Field>> {
285        let old_return_type = self.inner.return_type(args.input_types)?;
286        let state_fields_args = StateFieldsArgs {
287            name: args.name,
288            input_types: args.input_types,
289            return_type: &old_return_type,
290            ordering_fields: args.ordering_fields,
291            is_distinct: args.is_distinct,
292        };
293        self.inner.state_fields(state_fields_args)
294    }
295
296    /// The state function's signature is the same as the original aggregate function's signature,
297    fn signature(&self) -> &Signature {
298        self.inner.signature()
299    }
300
301    /// Coerce types also do nothing, as optimzer should be able to already make struct types
302    fn coerce_types(&self, arg_types: &[DataType]) -> datafusion_common::Result<Vec<DataType>> {
303        self.inner.coerce_types(arg_types)
304    }
305}
306
307/// The wrapper's input is the same as the original aggregate function's input,
308/// and the output is the state function's output.
309#[derive(Debug)]
310pub struct StateAccum {
311    inner: Box<dyn Accumulator>,
312    state_fields: Fields,
313}
314
315impl StateAccum {
316    pub fn new(
317        inner: Box<dyn Accumulator>,
318        state_type: DataType,
319    ) -> datafusion_common::Result<Self> {
320        let DataType::Struct(fields) = state_type else {
321            return Err(datafusion_common::DataFusionError::Internal(format!(
322                "Expected a struct type for state, got: {:?}",
323                state_type
324            )));
325        };
326        Ok(Self {
327            inner,
328            state_fields: fields,
329        })
330    }
331}
332
333impl Accumulator for StateAccum {
334    fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
335        let state = self.inner.state()?;
336
337        let array = state
338            .iter()
339            .map(|s| s.to_array())
340            .collect::<Result<Vec<_>, _>>()?;
341        let struct_array = StructArray::try_new(self.state_fields.clone(), array, None)?;
342        Ok(ScalarValue::Struct(Arc::new(struct_array)))
343    }
344
345    fn merge_batch(
346        &mut self,
347        states: &[datatypes::arrow::array::ArrayRef],
348    ) -> datafusion_common::Result<()> {
349        self.inner.merge_batch(states)
350    }
351
352    fn update_batch(
353        &mut self,
354        values: &[datatypes::arrow::array::ArrayRef],
355    ) -> datafusion_common::Result<()> {
356        self.inner.update_batch(values)
357    }
358
359    fn size(&self) -> usize {
360        self.inner.size()
361    }
362
363    fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
364        self.inner.state()
365    }
366}
367
368/// TODO(discord9): mark this function as non-ser/de able
369///
370/// This wrapper shouldn't be register as a udaf, as it contain extra data that is not serializable.
371/// and changes for different logical plans.
372#[derive(Debug, Clone)]
373pub struct MergeWrapper {
374    inner: AggregateUDF,
375    name: String,
376    merge_signature: Signature,
377    /// The original physical expression of the aggregate function, can't store the original aggregate function directly, as PhysicalExpr didn't implement Any
378    original_phy_expr: Arc<AggregateFunctionExpr>,
379    original_input_types: Vec<DataType>,
380}
381impl MergeWrapper {
382    pub fn new(
383        inner: AggregateUDF,
384        original_phy_expr: Arc<AggregateFunctionExpr>,
385        original_input_types: Vec<DataType>,
386    ) -> datafusion_common::Result<Self> {
387        let name = aggr_merge_func_name(inner.name());
388        // the input type is actually struct type, which is the state fields of the original aggregate function.
389        let merge_signature = Signature::user_defined(datafusion_expr::Volatility::Immutable);
390
391        Ok(Self {
392            inner,
393            name,
394            merge_signature,
395            original_phy_expr,
396            original_input_types,
397        })
398    }
399
400    pub fn inner(&self) -> &AggregateUDF {
401        &self.inner
402    }
403}
404
405impl AggregateUDFImpl for MergeWrapper {
406    fn accumulator<'a, 'b>(
407        &'a self,
408        acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
409    ) -> datafusion_common::Result<Box<dyn Accumulator>> {
410        if acc_args.schema.fields().len() != 1
411            || !matches!(acc_args.schema.field(0).data_type(), DataType::Struct(_))
412        {
413            return Err(datafusion_common::DataFusionError::Internal(format!(
414                "Expected one struct type as input, got: {:?}",
415                acc_args.schema
416            )));
417        }
418        let input_type = acc_args.schema.field(0).data_type();
419        let DataType::Struct(fields) = input_type else {
420            return Err(datafusion_common::DataFusionError::Internal(format!(
421                "Expected a struct type for input, got: {:?}",
422                input_type
423            )));
424        };
425
426        let inner_accum = self.original_phy_expr.create_accumulator()?;
427        Ok(Box::new(MergeAccum::new(inner_accum, fields)))
428    }
429
430    fn as_any(&self) -> &dyn std::any::Any {
431        self
432    }
433    fn name(&self) -> &str {
434        self.name.as_str()
435    }
436
437    fn is_nullable(&self) -> bool {
438        self.inner.is_nullable()
439    }
440
441    /// Notice here the `arg_types` is actually the `state_fields`'s data types,
442    /// so return fixed return type instead of using `arg_types` to determine the return type.
443    fn return_type(&self, _arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
444        // The return type is the same as the original aggregate function's return type.
445        let ret_type = self.inner.return_type(&self.original_input_types)?;
446        Ok(ret_type)
447    }
448    fn signature(&self) -> &Signature {
449        &self.merge_signature
450    }
451
452    /// Coerce types also do nothing, as optimzer should be able to already make struct types
453    fn coerce_types(&self, arg_types: &[DataType]) -> datafusion_common::Result<Vec<DataType>> {
454        // just check if the arg_types are only one and is struct array
455        if arg_types.len() != 1 || !matches!(arg_types.first(), Some(DataType::Struct(_))) {
456            return Err(datafusion_common::DataFusionError::Internal(format!(
457                "Expected one struct type as input, got: {:?}",
458                arg_types
459            )));
460        }
461        Ok(arg_types.to_vec())
462    }
463
464    /// Just return the original aggregate function's state fields.
465    fn state_fields(
466        &self,
467        _args: datafusion_expr::function::StateFieldsArgs,
468    ) -> datafusion_common::Result<Vec<Field>> {
469        self.original_phy_expr.state_fields()
470    }
471}
472
473/// The merge accumulator, which modify `update_batch`'s behavior to accept one struct array which
474/// include the state fields of original aggregate function, and merge said states into original accumulator
475/// the output is the same as original aggregate function
476#[derive(Debug)]
477pub struct MergeAccum {
478    inner: Box<dyn Accumulator>,
479    state_fields: Fields,
480}
481
482impl MergeAccum {
483    pub fn new(inner: Box<dyn Accumulator>, state_fields: &Fields) -> Self {
484        Self {
485            inner,
486            state_fields: state_fields.clone(),
487        }
488    }
489}
490
491impl Accumulator for MergeAccum {
492    fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
493        self.inner.evaluate()
494    }
495
496    fn merge_batch(&mut self, states: &[arrow::array::ArrayRef]) -> datafusion_common::Result<()> {
497        self.inner.merge_batch(states)
498    }
499
500    fn update_batch(&mut self, values: &[arrow::array::ArrayRef]) -> datafusion_common::Result<()> {
501        let value = values.first().ok_or_else(|| {
502            datafusion_common::DataFusionError::Internal("No values provided for merge".to_string())
503        })?;
504        // The input values are states from other accumulators, so we merge them.
505        let struct_arr = value
506            .as_any()
507            .downcast_ref::<StructArray>()
508            .ok_or_else(|| {
509                datafusion_common::DataFusionError::Internal(format!(
510                    "Expected StructArray, got: {:?}",
511                    value.data_type()
512                ))
513            })?;
514        let fields = struct_arr.fields();
515        if fields != &self.state_fields {
516            return Err(datafusion_common::DataFusionError::Internal(format!(
517                "Expected state fields: {:?}, got: {:?}",
518                self.state_fields, fields
519            )));
520        }
521
522        // now fields should be the same, so we can merge the batch
523        // by pass the columns as order should be the same
524        let state_columns = struct_arr.columns();
525        self.inner.merge_batch(state_columns)
526    }
527
528    fn size(&self) -> usize {
529        self.inner.size()
530    }
531
532    fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
533        self.inner.state()
534    }
535}
536
537#[cfg(test)]
538mod tests;