Skip to main content

common_function/aggrs/
aggr_wrapper.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Wrapper for making aggregate functions out of state/merge functions of original aggregate functions.
16//!
//! i.e. for an aggregate function `foo`, we will have a state function `foo_state` and a merge function `foo_merge`.
18//!
19//! `foo_state`'s input args is the same as `foo`'s, and its output is a state object.
20//! Note that `foo_state` might have multiple output columns, so it's a struct array
21//! that each output column is a struct field.
//! `foo_merge`'s input arg is the same as `foo_state`'s output, and its output is the same as `foo`'s output.
23//!
24
25use std::hash::{Hash, Hasher};
26use std::sync::Arc;
27
28use arrow::array::{ArrayRef, BooleanArray, StructArray};
29use arrow_schema::{FieldRef, Fields};
30use common_telemetry::debug;
31use datafusion::functions_aggregate::all_default_aggregate_functions;
32use datafusion::functions_aggregate::count::Count;
33use datafusion::functions_aggregate::min_max::{Max, Min};
34use datafusion::optimizer::AnalyzerRule;
35use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
36use datafusion::physical_planner::create_aggregate_expr_and_maybe_filter;
37use datafusion_common::{Column, ScalarValue};
38use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
39use datafusion_expr::function::StateFieldsArgs;
40use datafusion_expr::{
41    Accumulator, Aggregate, AggregateUDF, AggregateUDFImpl, EmitTo, Expr, ExprSchemable,
42    GroupsAccumulator, LogicalPlan, Signature,
43};
44use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
45use datatypes::arrow::datatypes::{DataType, Field};
46
47use crate::aggrs::aggr_wrapper::fix_order::FixStateUdafOrderingAnalyzer;
48use crate::function_registry::{FUNCTION_REGISTRY, FunctionRegistry};
49
50pub mod fix_order;
51#[cfg(test)]
52mod tests;
53
/// Builds the name of the state function that belongs to the given aggregate function.
///
/// The state function computes the intermediate state of the aggregate function,
/// and its name always has the form `__<aggr_name>_state`.
pub fn aggr_state_func_name(aggr_name: &str) -> String {
    format!("__{aggr_name}_state")
}
60
/// Builds the name of the merge function that belongs to the given aggregate function.
///
/// The merge function merges the states produced by the state function,
/// and its name always has the form `__<aggr_name>_merge`.
pub fn aggr_merge_func_name(aggr_name: &str) -> String {
    format!("__{aggr_name}_merge")
}
67
68/// Check if the given aggregate expression is steppable.
69/// As in if it can be split into multiple steps:
70/// i.e. on datanode first call `state(input)` then
71/// on frontend call `calc(merge(state))` to get the final result.
72pub fn is_all_aggr_exprs_steppable(aggr_exprs: &[Expr]) -> bool {
73    aggr_exprs.iter().all(|expr| {
74        if let Some(aggr_func) = get_aggr_func(expr) {
75            if aggr_func.params.distinct {
76                // Distinct aggregate functions are not steppable(yet).
77                // TODO(discord9): support distinct aggregate functions.
78                return false;
79            }
80
81            // whether the corresponding state function exists in the registry
82            FUNCTION_REGISTRY.is_aggr_func_exist(&aggr_state_func_name(aggr_func.func.name()))
83        } else {
84            false
85        }
86    })
87}
88
89pub fn get_aggr_func(expr: &Expr) -> Option<&datafusion_expr::expr::AggregateFunction> {
90    let mut expr_ref = expr;
91    while let Expr::Alias(alias) = expr_ref {
92        expr_ref = &alias.expr;
93    }
94    if let Expr::AggregateFunction(aggr_func) = expr_ref {
95        Some(aggr_func)
96    } else {
97        None
98    }
99}
100
/// A helper to make aggregate functions out of the state and merge functions of the original aggregate function.
/// It is a unit struct holding no data; its associated functions register the state functions
/// and split an aggregate plan into a state plan and a merge plan.
///
/// Notice state functions may have multiple output columns, so their return type is always a struct array, and the merge function is used to merge the states of the state functions.
#[derive(Debug, Clone)]
pub struct StateMergeHelper;
107
/// A struct to hold the two aggregate plans, one for the state function (lower) and one for the merge function (upper).
///
/// The upper plan takes the lower plan as its input.
#[allow(unused)]
#[derive(Debug, Clone)]
pub struct StepAggrPlan {
    /// Upper merge plan, which is the aggregate plan that merges the states of the state function.
    pub upper_merge: LogicalPlan,
    /// Lower state plan, which is the aggregate plan that computes the state of the aggregate function.
    pub lower_state: LogicalPlan,
}
117
118impl StateMergeHelper {
119    /// Register all the `state` function of supported aggregate functions.
120    /// Note that can't register `merge` function here, as it needs to be created from the original aggregate function with given input types.
121    pub fn register(registry: &FunctionRegistry) {
122        let all_default = all_default_aggregate_functions();
123        let greptime_custom_aggr_functions = registry.aggregate_functions();
124
125        // if our custom aggregate function have the same name as the default aggregate function, we will override it.
126        let supported = all_default
127            .into_iter()
128            .chain(greptime_custom_aggr_functions.into_iter().map(Arc::new))
129            .collect::<Vec<_>>();
130        debug!(
131            "Registering state functions for supported: {:?}",
132            supported.iter().map(|f| f.name()).collect::<Vec<_>>()
133        );
134
135        let state_func = supported.into_iter().filter_map(|f| {
136            StateWrapper::new((*f).clone())
137                .inspect_err(
138                    |e| common_telemetry::error!(e; "Failed to register state function for {:?}", f),
139                )
140                .ok()
141                .map(AggregateUDF::new_from_impl)
142        });
143
144        for func in state_func {
145            registry.register_aggr(func);
146        }
147    }
148
149    /// Split an aggregate plan into two aggregate plans, one for the state function and one for the merge function.
150    ///
151    pub fn split_aggr_node(aggr_plan: Aggregate) -> datafusion_common::Result<StepAggrPlan> {
152        let aggr = {
153            // certain aggr func need type coercion to work correctly, so we need to analyze the plan first.
154            let aggr_plan = TypeCoercion::new().analyze(
155                LogicalPlan::Aggregate(aggr_plan).clone(),
156                &Default::default(),
157            )?;
158            if let LogicalPlan::Aggregate(aggr) = aggr_plan {
159                aggr
160            } else {
161                return Err(datafusion_common::DataFusionError::Internal(format!(
162                    "Failed to coerce expressions in aggregate plan, expected Aggregate, got: {:?}",
163                    aggr_plan
164                )));
165            }
166        };
167        let mut lower_aggr_exprs = vec![];
168        let mut upper_aggr_exprs = vec![];
169
170        // group exprs for upper plan should refer to the output group expr as column from lower plan
171        // to avoid re-compute group exprs again.
172        let upper_group_exprs = aggr
173            .group_expr
174            .iter()
175            .map(|c| c.qualified_name())
176            .map(|(r, c)| Expr::Column(Column::new(r, c)))
177            .collect();
178
179        for aggr_expr in aggr.aggr_expr.iter() {
180            let Some(aggr_func) = get_aggr_func(aggr_expr) else {
181                return Err(datafusion_common::DataFusionError::NotImplemented(format!(
182                    "Unsupported aggregate expression for step aggr optimize: {:?}",
183                    aggr_expr
184                )));
185            };
186
187            let original_input_fields = aggr_func
188                .params
189                .args
190                .iter()
191                .map(|e| e.to_field(&aggr.input.schema()).map(|(_, field)| field))
192                .collect::<Result<Vec<_>, _>>()?;
193
194            // first create the state function from the original aggregate function.
195            let state_func = StateWrapper::new((*aggr_func.func).clone())?;
196
197            let expr = AggregateFunction {
198                func: Arc::new(state_func.into()),
199                params: aggr_func.params.clone(),
200            };
201            let expr = Expr::AggregateFunction(expr);
202            let lower_state_output_col_name = expr.schema_name().to_string();
203
204            lower_aggr_exprs.push(expr);
205
206            // then create the merge function using the physical expression of the original aggregate function
207            let (original_phy_expr, _filter, _ordering) = create_aggregate_expr_and_maybe_filter(
208                aggr_expr,
209                aggr.input.schema(),
210                aggr.input.schema().as_arrow(),
211                &Default::default(),
212            )?;
213
214            let merge_func = MergeWrapper::new(
215                (*aggr_func.func).clone(),
216                original_phy_expr,
217                original_input_fields,
218            )?;
219            let arg = Expr::Column(Column::new_unqualified(lower_state_output_col_name));
220            let expr = AggregateFunction {
221                func: Arc::new(merge_func.into()),
222                // notice filter/order_by is not supported in the merge function, as it's not meaningful to have them in the merge phase.
223                // do notice this order by is only removed in the outer logical plan, the physical plan still have order by and hence
224                // can create correct accumulator with order by.
225                params: AggregateFunctionParams {
226                    args: vec![arg],
227                    distinct: aggr_func.params.distinct,
228                    filter: None,
229                    order_by: vec![],
230                    null_treatment: aggr_func.params.null_treatment,
231                },
232            };
233
234            // alias to the original aggregate expr's schema name, so parent plan can refer to it
235            // correctly.
236            let expr = Expr::AggregateFunction(expr).alias(aggr_expr.schema_name().to_string());
237            upper_aggr_exprs.push(expr);
238        }
239
240        let mut lower = aggr.clone();
241        lower.aggr_expr = lower_aggr_exprs;
242        let lower_plan = LogicalPlan::Aggregate(lower);
243
244        // update aggregate's output schema
245        let lower_plan = lower_plan.recompute_schema()?;
246
247        // should only affect two udaf `first_value/last_value`
248        // which only them have meaningful order by field
249        let fixed_lower_plan =
250            FixStateUdafOrderingAnalyzer.analyze(lower_plan, &Default::default())?;
251
252        let upper = Aggregate::try_new(
253            Arc::new(fixed_lower_plan.clone()),
254            upper_group_exprs,
255            upper_aggr_exprs.clone(),
256        )?;
257        let aggr_plan = LogicalPlan::Aggregate(aggr);
258
259        // upper schema's output schema should be the same as the original aggregate plan's output schema
260        let upper_check = upper;
261        let upper_plan = LogicalPlan::Aggregate(upper_check).recompute_schema()?;
262        if *upper_plan.schema() != *aggr_plan.schema() {
263            return Err(datafusion_common::DataFusionError::Internal(format!(
264                "Upper aggregate plan's schema is not the same as the original aggregate plan's schema: \n[transformed]:{}\n[original]:{}",
265                upper_plan.schema(),
266                aggr_plan.schema()
267            )));
268        }
269
270        Ok(StepAggrPlan {
271            lower_state: fixed_lower_plan,
272            upper_merge: upper_plan,
273        })
274    }
275}
276
277/// Wrapper to make an aggregate function out of a state function.
278#[derive(Debug, Clone, PartialEq, Eq, Hash)]
279pub struct StateWrapper {
280    inner: AggregateUDF,
281    name: String,
282    /// Default to empty, might get fixed by analyzer later
283    ordering: Vec<FieldRef>,
284    /// Default to false, might get fixed by analyzer later
285    distinct: bool,
286}
287
288impl StateWrapper {
289    /// `state_index`: The index of the state in the output of the state function.
290    pub fn new(inner: AggregateUDF) -> datafusion_common::Result<Self> {
291        let name = aggr_state_func_name(inner.name());
292        Ok(Self {
293            inner,
294            name,
295            ordering: vec![],
296            distinct: false,
297        })
298    }
299
300    pub fn inner(&self) -> &AggregateUDF {
301        &self.inner
302    }
303
304    /// Deduce the return type of the original aggregate function
305    /// based on the accumulator arguments.
306    ///
307    pub fn deduce_aggr_return_type(
308        &self,
309        acc_args: &datafusion_expr::function::AccumulatorArgs,
310    ) -> datafusion_common::Result<FieldRef> {
311        let input_fields = acc_args
312            .exprs
313            .iter()
314            .map(|e| e.return_field(acc_args.schema))
315            .collect::<Result<Vec<_>, _>>()?;
316        self.inner.return_field(&input_fields).inspect_err(|e| {
317            common_telemetry::error!(
318                "StateWrapper: {:#?}\nacc_args:{:?}\nerror:{:?}",
319                &self,
320                &acc_args,
321                e
322            );
323        })
324    }
325
326    fn fix_inner_acc_args<'b>(
327        &self,
328        mut acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
329    ) -> datafusion_common::Result<datafusion_expr::function::AccumulatorArgs<'b>> {
330        acc_args.return_field = self.deduce_aggr_return_type(&acc_args)?;
331        Ok(acc_args)
332    }
333}
334
335impl AggregateUDFImpl for StateWrapper {
336    fn accumulator<'a, 'b>(
337        &'a self,
338        acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
339    ) -> datafusion_common::Result<Box<dyn Accumulator>> {
340        // fix and recover proper acc args for the original aggregate function.
341        let state_type = acc_args.return_type().clone();
342        let inner = self.inner.accumulator(self.fix_inner_acc_args(acc_args)?)?;
343
344        Ok(Box::new(StateAccum::new(inner, state_type)?))
345    }
346
347    fn groups_accumulator_supported(
348        &self,
349        acc_args: datafusion_expr::function::AccumulatorArgs,
350    ) -> bool {
351        self.fix_inner_acc_args(acc_args)
352            .map(|args| self.inner.inner().groups_accumulator_supported(args))
353            .unwrap_or(false)
354    }
355
356    fn create_groups_accumulator(
357        &self,
358        acc_args: datafusion_expr::function::AccumulatorArgs,
359    ) -> datafusion_common::Result<Box<dyn GroupsAccumulator>> {
360        let state_type = acc_args.return_type().clone();
361        let inner = self
362            .inner
363            .inner()
364            .create_groups_accumulator(self.fix_inner_acc_args(acc_args)?)?;
365        Ok(Box::new(StateGroupsAccum::new(inner, state_type)?))
366    }
367
368    fn as_any(&self) -> &dyn std::any::Any {
369        self
370    }
371    fn name(&self) -> &str {
372        self.name.as_str()
373    }
374
375    fn is_nullable(&self) -> bool {
376        self.inner.is_nullable()
377    }
378
379    /// Return state_fields as the output struct type.
380    ///
381    fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
382        let input_fields = &arg_types
383            .iter()
384            .map(|x| Arc::new(Field::new("x", x.clone(), false)))
385            .collect::<Vec<_>>();
386
387        let state_fields_args = StateFieldsArgs {
388            name: self.inner().name(),
389            input_fields,
390            return_field: self.inner.return_field(input_fields)?,
391            // those args are also needed as they are vital to construct the state fields correctly.
392            ordering_fields: &self.ordering,
393            is_distinct: self.distinct,
394        };
395        let state_fields = self.inner.state_fields(state_fields_args)?;
396
397        let state_fields = state_fields
398            .into_iter()
399            .map(|f| {
400                let mut f = f.as_ref().clone();
401                // since state can be null when no input rows, so make all fields nullable
402                f.set_nullable(true);
403                Arc::new(f)
404            })
405            .collect::<Vec<_>>();
406
407        let struct_field = DataType::Struct(state_fields.into());
408        Ok(struct_field)
409    }
410
411    /// The state function's output fields are the same as the original aggregate function's state fields.
412    fn state_fields(
413        &self,
414        args: datafusion_expr::function::StateFieldsArgs,
415    ) -> datafusion_common::Result<Vec<FieldRef>> {
416        let state_fields_args = StateFieldsArgs {
417            name: args.name,
418            input_fields: args.input_fields,
419            return_field: self.inner.return_field(args.input_fields)?,
420            ordering_fields: args.ordering_fields,
421            is_distinct: args.is_distinct,
422        };
423        self.inner.state_fields(state_fields_args)
424    }
425
426    /// The state function's signature is the same as the original aggregate function's signature,
427    fn signature(&self) -> &Signature {
428        self.inner.signature()
429    }
430
431    /// Coerce types also do nothing, as optimizer should be able to already make struct types
432    fn coerce_types(&self, arg_types: &[DataType]) -> datafusion_common::Result<Vec<DataType>> {
433        self.inner.coerce_types(arg_types)
434    }
435
436    fn value_from_stats(
437        &self,
438        statistics_args: &datafusion_expr::StatisticsArgs,
439    ) -> Option<ScalarValue> {
440        let inner = self.inner().inner().as_any();
441        // only count/min/max need special handling here, for getting result from statistics
442        // the result of count/min/max is also the result of count_state so can return directly
443        let can_use_stat = inner.is::<Count>() || inner.is::<Max>() || inner.is::<Min>();
444        if !can_use_stat {
445            return None;
446        }
447
448        // fix return type by extract the first field's data type from the struct type
449        let state_type = if let DataType::Struct(fields) = &statistics_args.return_type {
450            if fields.is_empty() {
451                return None;
452            }
453            fields[0].data_type().clone()
454        } else {
455            return None;
456        };
457
458        let fixed_args = datafusion_expr::StatisticsArgs {
459            statistics: statistics_args.statistics,
460            return_type: &state_type,
461            is_distinct: statistics_args.is_distinct,
462            exprs: statistics_args.exprs,
463        };
464
465        let ret = self.inner().value_from_stats(&fixed_args)?;
466
467        // wrap the result into struct scalar value
468        let fields = if let DataType::Struct(fields) = &statistics_args.return_type {
469            fields
470        } else {
471            return None;
472        };
473
474        let array = ret.to_array().ok()?;
475
476        let struct_array = StructArray::new(fields.clone(), vec![array], None);
477        let ret = ScalarValue::Struct(Arc::new(struct_array));
478        Some(ret)
479    }
480}
481
/// Accumulator wrapper whose input is the same as the original aggregate function's input,
/// and whose output is the state function's output.
#[derive(Debug)]
pub struct StateAccum {
    /// Accumulator of the original aggregate function.
    inner: Box<dyn Accumulator>,
    /// Fields of the struct type this accumulator is expected to output.
    state_fields: Fields,
}
489
490pub struct StateGroupsAccum {
491    inner: Box<dyn GroupsAccumulator>,
492    state_fields: Fields,
493}
494
495impl StateGroupsAccum {
496    fn new(
497        inner: Box<dyn GroupsAccumulator>,
498        state_type: DataType,
499    ) -> datafusion_common::Result<Self> {
500        let DataType::Struct(fields) = state_type else {
501            return Err(datafusion_common::DataFusionError::Internal(format!(
502                "Expected a struct type for state, got: {:?}",
503                state_type
504            )));
505        };
506        Ok(Self {
507            inner,
508            state_fields: fields,
509        })
510    }
511
512    fn wrap_state_arrays(&self, arrays: Vec<ArrayRef>) -> datafusion_common::Result<ArrayRef> {
513        let array_type = arrays
514            .iter()
515            .map(|array| array.data_type().clone())
516            .collect::<Vec<_>>();
517        let expected_type = self
518            .state_fields
519            .iter()
520            .map(|field| field.data_type().clone())
521            .collect::<Vec<_>>();
522        if array_type != expected_type {
523            debug!(
524                "State mismatch, expected: {}, got: {} for expected fields: {:?} and given array types: {:?}",
525                self.state_fields.len(),
526                arrays.len(),
527                self.state_fields,
528                array_type,
529            );
530            let guess_schema = arrays
531                .iter()
532                .enumerate()
533                .map(|(index, array)| {
534                    Field::new(
535                        format!("col_{index}[mismatch_state]").as_str(),
536                        array.data_type().clone(),
537                        true,
538                    )
539                })
540                .collect::<Fields>();
541            let array = StructArray::try_new(guess_schema, arrays, None)?;
542            return Ok(Arc::new(array));
543        }
544
545        Ok(Arc::new(StructArray::try_new(
546            self.state_fields.clone(),
547            arrays,
548            None,
549        )?))
550    }
551}
552
// Every method forwards to the inner groups accumulator, except `evaluate`,
// which outputs the inner state instead of the inner final value.
impl GroupsAccumulator for StateGroupsAccum {
    /// Forward the input rows to the original groups accumulator.
    fn update_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        opt_filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> datafusion_common::Result<()> {
        self.inner
            .update_batch(values, group_indices, opt_filter, total_num_groups)
    }

    /// Forward the partial states to the original groups accumulator.
    fn merge_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        opt_filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> datafusion_common::Result<()> {
        self.inner
            .merge_batch(values, group_indices, opt_filter, total_num_groups)
    }

    /// Output the state of the original groups accumulator, packed into one struct array.
    fn evaluate(&mut self, emit_to: EmitTo) -> datafusion_common::Result<ArrayRef> {
        let state = self.inner.state(emit_to)?;
        self.wrap_state_arrays(state)
    }

    /// The state is the (unpacked) state of the original groups accumulator.
    fn state(&mut self, emit_to: EmitTo) -> datafusion_common::Result<Vec<ArrayRef>> {
        self.inner.state(emit_to)
    }

    fn convert_to_state(
        &self,
        values: &[ArrayRef],
        opt_filter: Option<&BooleanArray>,
    ) -> datafusion_common::Result<Vec<ArrayRef>> {
        self.inner.convert_to_state(values, opt_filter)
    }

    fn supports_convert_to_state(&self) -> bool {
        self.inner.supports_convert_to_state()
    }

    /// Only the size of the original groups accumulator is reported.
    fn size(&self) -> usize {
        self.inner.size()
    }
}
601
602impl StateAccum {
603    pub fn new(
604        inner: Box<dyn Accumulator>,
605        state_type: DataType,
606    ) -> datafusion_common::Result<Self> {
607        let DataType::Struct(fields) = state_type else {
608            return Err(datafusion_common::DataFusionError::Internal(format!(
609                "Expected a struct type for state, got: {:?}",
610                state_type
611            )));
612        };
613        Ok(Self {
614            inner,
615            state_fields: fields,
616        })
617    }
618}
619
620impl Accumulator for StateAccum {
621    fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
622        let state = self.inner.state()?;
623
624        let array = state
625            .iter()
626            .map(|s| s.to_array())
627            .collect::<Result<Vec<_>, _>>()?;
628        let array_type = array
629            .iter()
630            .map(|a| a.data_type().clone())
631            .collect::<Vec<_>>();
632        let expected_type: Vec<_> = self
633            .state_fields
634            .iter()
635            .map(|f| f.data_type().clone())
636            .collect();
637        if array_type != expected_type {
638            debug!(
639                "State mismatch, expected: {}, got: {} for expected fields: {:?} and given array types: {:?}",
640                self.state_fields.len(),
641                array.len(),
642                self.state_fields,
643                array_type,
644            );
645            let guess_schema = array
646                .iter()
647                .enumerate()
648                .map(|(index, array)| {
649                    Field::new(
650                        format!("col_{index}[mismatch_state]").as_str(),
651                        array.data_type().clone(),
652                        true,
653                    )
654                })
655                .collect::<Fields>();
656            let arr = StructArray::try_new(guess_schema, array, None)?;
657
658            return Ok(ScalarValue::Struct(Arc::new(arr)));
659        }
660
661        let struct_array = StructArray::try_new(self.state_fields.clone(), array, None)?;
662        Ok(ScalarValue::Struct(Arc::new(struct_array)))
663    }
664
665    fn merge_batch(
666        &mut self,
667        states: &[datatypes::arrow::array::ArrayRef],
668    ) -> datafusion_common::Result<()> {
669        self.inner.merge_batch(states)
670    }
671
672    fn update_batch(
673        &mut self,
674        values: &[datatypes::arrow::array::ArrayRef],
675    ) -> datafusion_common::Result<()> {
676        self.inner.update_batch(values)
677    }
678
679    fn size(&self) -> usize {
680        self.inner.size()
681    }
682
683    fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
684        self.inner.state()
685    }
686}
687
688/// TODO(discord9): mark this function as non-ser/de able
689///
690/// This wrapper shouldn't be register as a udaf, as it contain extra data that is not serializable.
691/// and changes for different logical plans.
692#[derive(Debug, Clone)]
693pub struct MergeWrapper {
694    inner: AggregateUDF,
695    name: String,
696    merge_signature: Signature,
697    /// The original physical expression of the aggregate function, can't store the original aggregate function directly, as PhysicalExpr didn't implement Any
698    original_phy_expr: Arc<AggregateFunctionExpr>,
699    return_field: FieldRef,
700}
701impl MergeWrapper {
702    pub fn new(
703        inner: AggregateUDF,
704        original_phy_expr: Arc<AggregateFunctionExpr>,
705        original_input_fields: Vec<FieldRef>,
706    ) -> datafusion_common::Result<Self> {
707        let name = aggr_merge_func_name(inner.name());
708        // the input type is actually struct type, which is the state fields of the original aggregate function.
709        let merge_signature = Signature::user_defined(datafusion_expr::Volatility::Immutable);
710        let return_field = inner.return_field(&original_input_fields)?.clone();
711
712        Ok(Self {
713            inner,
714            name,
715            merge_signature,
716            original_phy_expr,
717            return_field,
718        })
719    }
720
721    pub fn inner(&self) -> &AggregateUDF {
722        &self.inner
723    }
724}
725
726impl AggregateUDFImpl for MergeWrapper {
727    fn accumulator<'a, 'b>(
728        &'a self,
729        acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
730    ) -> datafusion_common::Result<Box<dyn Accumulator>> {
731        if acc_args.exprs.len() != 1
732            || !matches!(
733                acc_args.exprs[0].data_type(acc_args.schema)?,
734                DataType::Struct(_)
735            )
736        {
737            return Err(datafusion_common::DataFusionError::Internal(format!(
738                "Expected one struct type as input, got: {:?}",
739                acc_args.schema
740            )));
741        }
742        let input_type = acc_args.exprs[0].data_type(acc_args.schema)?;
743        let DataType::Struct(fields) = input_type else {
744            return Err(datafusion_common::DataFusionError::Internal(format!(
745                "Expected a struct type for input, got: {:?}",
746                input_type
747            )));
748        };
749
750        let inner_accum = self.original_phy_expr.create_accumulator()?;
751        Ok(Box::new(MergeAccum::new(inner_accum, &fields)))
752    }
753
754    fn as_any(&self) -> &dyn std::any::Any {
755        self
756    }
757    fn name(&self) -> &str {
758        self.name.as_str()
759    }
760
761    fn is_nullable(&self) -> bool {
762        self.inner.is_nullable()
763    }
764
765    /// Notice here the `arg_types` is actually the `state_fields`'s data types,
766    /// so return fixed return type instead of using `arg_types` to determine the return type.
767    fn return_type(&self, _arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
768        // The return type is the same as the original aggregate function's return type.
769        Ok(self.return_field.data_type().clone())
770    }
771
772    /// Similar to return_type, we just return the fixed return field.
773    fn return_field(&self, _arg_fields: &[FieldRef]) -> datafusion_common::Result<FieldRef> {
774        Ok(self.return_field.clone())
775    }
776
777    fn signature(&self) -> &Signature {
778        &self.merge_signature
779    }
780
781    /// Coerce types also do nothing, as optimizer should be able to already make struct types
782    fn coerce_types(&self, arg_types: &[DataType]) -> datafusion_common::Result<Vec<DataType>> {
783        // just check if the arg_types are only one and is struct array
784        if arg_types.len() != 1 || !matches!(arg_types.first(), Some(DataType::Struct(_))) {
785            return Err(datafusion_common::DataFusionError::Internal(format!(
786                "Expected one struct type as input, got: {:?}",
787                arg_types
788            )));
789        }
790        Ok(arg_types.to_vec())
791    }
792
793    /// Just return the original aggregate function's state fields.
794    fn state_fields(
795        &self,
796        _args: datafusion_expr::function::StateFieldsArgs,
797    ) -> datafusion_common::Result<Vec<FieldRef>> {
798        self.original_phy_expr.state_fields()
799    }
800}
801
802impl PartialEq for MergeWrapper {
803    fn eq(&self, other: &Self) -> bool {
804        self.inner == other.inner
805    }
806}
807
// Marker impl on top of `PartialEq`, which compares only the original aggregate function.
impl Eq for MergeWrapper {}
809
810impl Hash for MergeWrapper {
811    fn hash<H: Hasher>(&self, state: &mut H) {
812        self.inner.hash(state);
813    }
814}
815
816/// The merge accumulator, which modify `update_batch`'s behavior to accept one struct array which
817/// include the state fields of original aggregate function, and merge said states into original accumulator
818/// the output is the same as original aggregate function
819#[derive(Debug)]
820pub struct MergeAccum {
821    inner: Box<dyn Accumulator>,
822    state_fields: Fields,
823}
824
825impl MergeAccum {
826    pub fn new(inner: Box<dyn Accumulator>, state_fields: &Fields) -> Self {
827        Self {
828            inner,
829            state_fields: state_fields.clone(),
830        }
831    }
832}
833
834impl Accumulator for MergeAccum {
835    fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
836        self.inner.evaluate()
837    }
838
839    fn merge_batch(&mut self, states: &[arrow::array::ArrayRef]) -> datafusion_common::Result<()> {
840        self.inner.merge_batch(states)
841    }
842
843    fn update_batch(&mut self, values: &[arrow::array::ArrayRef]) -> datafusion_common::Result<()> {
844        let value = values.first().ok_or_else(|| {
845            datafusion_common::DataFusionError::Internal("No values provided for merge".to_string())
846        })?;
847        // The input values are states from other accumulators, so we merge them.
848        let struct_arr = value
849            .as_any()
850            .downcast_ref::<StructArray>()
851            .ok_or_else(|| {
852                datafusion_common::DataFusionError::Internal(format!(
853                    "Expected StructArray, got: {:?}",
854                    value.data_type()
855                ))
856            })?;
857        let fields = struct_arr.fields();
858        if fields != &self.state_fields {
859            debug!(
860                "State fields mismatch, expected: {:?}, got: {:?}",
861                self.state_fields, fields
862            );
863            // state fields mismatch might be acceptable by datafusion, continue
864        }
865
866        // now fields should be the same, so we can merge the batch
867        // by pass the columns as order should be the same
868        let state_columns = struct_arr.columns();
869        self.inner.merge_batch(state_columns)
870    }
871
872    fn size(&self) -> usize {
873        self.inner.size()
874    }
875
876    fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
877        self.inner.state()
878    }
879}