common_function/aggrs/
aggr_wrapper.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Wrapper for making aggregate functions out of state/merge functions of original aggregate functions.
//!
//! i.e. for an aggregate function `foo`, we will have a state function `foo_state` and a merge function `foo_merge`.
//!
//! `foo_state`'s input args are the same as `foo`'s, and its output is a state object.
//! Note that `foo_state` might have multiple output columns, so its output is a struct array
//! in which each output column is a struct field.
//! `foo_merge`'s input arg is the same as `foo_state`'s output, and its output is the same as `foo`'s output.
//!
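//! For illustration, a rough sketch with `sum` as an assumed example (the actual
//! state fields are defined by the original aggregate function's `state_fields`):
//!
//! ```text
//! sum(x)                      -- original aggregate
//! __sum_state(x)              -- returns a struct of sum's state fields
//! __sum_merge(<state struct>) -- consumes that struct and returns sum's result
//! ```
//!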

use std::hash::{Hash, Hasher};
use std::sync::Arc;

use arrow::array::StructArray;
use arrow_schema::{FieldRef, Fields};
use common_telemetry::debug;
use datafusion::functions_aggregate::all_default_aggregate_functions;
use datafusion::functions_aggregate::count::Count;
use datafusion::functions_aggregate::min_max::{Max, Min};
use datafusion::optimizer::AnalyzerRule;
use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
use datafusion::physical_planner::create_aggregate_expr_and_maybe_filter;
use datafusion_common::{Column, ScalarValue};
use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
use datafusion_expr::function::StateFieldsArgs;
use datafusion_expr::{
    Accumulator, Aggregate, AggregateUDF, AggregateUDFImpl, Expr, ExprSchemable, LogicalPlan,
    Signature,
};
use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
use datatypes::arrow::datatypes::{DataType, Field};

use crate::aggrs::aggr_wrapper::fix_order::FixStateUdafOrderingAnalyzer;
use crate::function_registry::{FUNCTION_REGISTRY, FunctionRegistry};

pub mod fix_order;
#[cfg(test)]
mod tests;

/// Returns the name of the state function for the given aggregate function name.
/// The state function is used to compute the state of the aggregate function.
/// The state function's name is in the format `__<aggr_name>_state`.
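/// For example, `aggr_state_func_name("sum")` returns `"__sum_state"`.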
pub fn aggr_state_func_name(aggr_name: &str) -> String {
    format!("__{}_state", aggr_name)
}

/// Returns the name of the merge function for the given aggregate function name.
/// The merge function is used to merge the states produced by the state function.
/// The merge function's name is in the format `__<aggr_name>_merge`.
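/// For example, `aggr_merge_func_name("sum")` returns `"__sum_merge"`.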
pub fn aggr_merge_func_name(aggr_name: &str) -> String {
    format!("__{}_merge", aggr_name)
}

/// Check if the given aggregate expressions are steppable,
/// i.e. whether they can be split into multiple steps:
/// on the datanode, first call `state(input)`, then
/// on the frontend, call `calc(merge(state))` to get the final result.
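/// For example (a sketch): `sum(x)` is steppable as long as `__sum_state` is
/// registered, since the final result can be computed as
/// `__sum_merge(__sum_state(x))`, while `count(DISTINCT x)` is not, because
/// distinct aggregates are rejected below.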
pub fn is_all_aggr_exprs_steppable(aggr_exprs: &[Expr]) -> bool {
    aggr_exprs.iter().all(|expr| {
        if let Some(aggr_func) = get_aggr_func(expr) {
            if aggr_func.params.distinct {
                // Distinct aggregate functions are not steppable (yet).
                // TODO(discord9): support distinct aggregate functions.
                return false;
            }

            // whether the corresponding state function exists in the registry
            FUNCTION_REGISTRY.is_aggr_func_exist(&aggr_state_func_name(aggr_func.func.name()))
        } else {
            false
        }
    })
}

pub fn get_aggr_func(expr: &Expr) -> Option<&datafusion_expr::expr::AggregateFunction> {
    let mut expr_ref = expr;
    while let Expr::Alias(alias) = expr_ref {
        expr_ref = &alias.expr;
    }
    if let Expr::AggregateFunction(aggr_func) = expr_ref {
        Some(aggr_func)
    } else {
        None
    }
}

/// A helper for turning an aggregate function into its state and merge counterparts.
/// It registers the `__<aggr_name>_state` functions and splits aggregate plans into
/// a state phase and a merge phase.
///
/// Notice that a state function may have multiple output columns, so its return type is
/// always a struct array; the merge function then merges those states back together.
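///
/// A typical flow (a sketch): call [`StateMergeHelper::register`] once so the
/// `__<aggr_name>_state` UDAFs are known to the registry, then call
/// [`StateMergeHelper::split_aggr_node`] on an [`Aggregate`] node to obtain the
/// lower (state) plan and the upper (merge) plan.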
#[derive(Debug, Clone)]
pub struct StateMergeHelper;

/// A struct to hold the two aggregate plans, one for the state function (lower) and one for the merge function (upper).
#[allow(unused)]
#[derive(Debug, Clone)]
pub struct StepAggrPlan {
    /// Upper merge plan, which is the aggregate plan that merges the states of the state function.
    pub upper_merge: LogicalPlan,
    /// Lower state plan, which is the aggregate plan that computes the state of the aggregate function.
    pub lower_state: LogicalPlan,
}

impl StateMergeHelper {
    /// Register the `state` function of every supported aggregate function.
    /// Note that we can't register the `merge` function here, as it needs to be created from the original aggregate function with the given input types.
    pub fn register(registry: &FunctionRegistry) {
        let all_default = all_default_aggregate_functions();
        let greptime_custom_aggr_functions = registry.aggregate_functions();

        // if one of our custom aggregate functions has the same name as a default aggregate function, it will override the default.
        let supported = all_default
            .into_iter()
            .chain(greptime_custom_aggr_functions.into_iter().map(Arc::new))
            .collect::<Vec<_>>();
        debug!(
            "Registering state functions for supported: {:?}",
            supported.iter().map(|f| f.name()).collect::<Vec<_>>()
        );

        let state_func = supported.into_iter().filter_map(|f| {
            StateWrapper::new((*f).clone())
                .inspect_err(
                    |e| common_telemetry::error!(e; "Failed to register state function for {:?}", f),
                )
                .ok()
                .map(AggregateUDF::new_from_impl)
        });

        for func in state_func {
            registry.register_aggr(func);
        }
    }

    /// Split an aggregate plan into two aggregate plans: a lower one that computes the
    /// state functions and an upper one that merges those states into the final result.
    ///
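    /// For illustration, a rough sketch assuming a table `t(k, v)` and `sum` as the
    /// aggregate (the actual plan display format may differ):
    ///
    /// ```text
    /// -- original:
    /// Aggregate: groupBy=[[t.k]], aggr=[[sum(t.v)]]
    /// -- lower (state) plan:
    /// Aggregate: groupBy=[[t.k]], aggr=[[__sum_state(t.v)]]
    /// -- upper (merge) plan, reading the lower plan's output:
    /// Aggregate: groupBy=[[t.k]], aggr=[[__sum_merge(__sum_state(t.v)) AS sum(t.v)]]
    /// ```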
    pub fn split_aggr_node(aggr_plan: Aggregate) -> datafusion_common::Result<StepAggrPlan> {
        let aggr = {
            // certain aggr funcs need type coercion to work correctly, so we need to analyze the plan first.
            let aggr_plan = TypeCoercion::new().analyze(
                LogicalPlan::Aggregate(aggr_plan).clone(),
                &Default::default(),
            )?;
            if let LogicalPlan::Aggregate(aggr) = aggr_plan {
                aggr
            } else {
                return Err(datafusion_common::DataFusionError::Internal(format!(
                    "Failed to coerce expressions in aggregate plan, expected Aggregate, got: {:?}",
                    aggr_plan
                )));
            }
        };
        let mut lower_aggr_exprs = vec![];
        let mut upper_aggr_exprs = vec![];

        // group exprs for the upper plan should refer to the lower plan's output group exprs as columns,
        // to avoid re-computing the group exprs again.
        let upper_group_exprs = aggr
            .group_expr
            .iter()
            .map(|c| c.qualified_name())
            .map(|(r, c)| Expr::Column(Column::new(r, c)))
            .collect();

        for aggr_expr in aggr.aggr_expr.iter() {
            let Some(aggr_func) = get_aggr_func(aggr_expr) else {
                return Err(datafusion_common::DataFusionError::NotImplemented(format!(
                    "Unsupported aggregate expression for step aggr optimize: {:?}",
                    aggr_expr
                )));
            };

            let original_input_types = aggr_func
                .params
                .args
                .iter()
                .map(|e| e.get_type(&aggr.input.schema()))
                .collect::<Result<Vec<_>, _>>()?;

            // first create the state function from the original aggregate function.
            let state_func = StateWrapper::new((*aggr_func.func).clone())?;

            let expr = AggregateFunction {
                func: Arc::new(state_func.into()),
                params: aggr_func.params.clone(),
            };
            let expr = Expr::AggregateFunction(expr);
            let lower_state_output_col_name = expr.schema_name().to_string();

            lower_aggr_exprs.push(expr);

            // then create the merge function using the physical expression of the original aggregate function
            let (original_phy_expr, _filter, _ordering) = create_aggregate_expr_and_maybe_filter(
                aggr_expr,
                aggr.input.schema(),
                aggr.input.schema().as_arrow(),
                &Default::default(),
            )?;

            let merge_func = MergeWrapper::new(
                (*aggr_func.func).clone(),
                original_phy_expr,
                original_input_types,
            )?;
            let arg = Expr::Column(Column::new_unqualified(lower_state_output_col_name));
            let expr = AggregateFunction {
                func: Arc::new(merge_func.into()),
                // notice filter/order_by is not supported in the merge function, as it's not meaningful to have them in the merge phase.
                // also notice the order by is only removed in the outer logical plan; the physical plan still has the order by and hence
                // can create a correct accumulator with it.
                params: AggregateFunctionParams {
                    args: vec![arg],
                    distinct: aggr_func.params.distinct,
                    filter: None,
                    order_by: vec![],
                    null_treatment: aggr_func.params.null_treatment,
                },
            };

            // alias to the original aggregate expr's schema name, so the parent plan can refer to it
            // correctly.
            let expr = Expr::AggregateFunction(expr).alias(aggr_expr.schema_name().to_string());
            upper_aggr_exprs.push(expr);
        }

        let mut lower = aggr.clone();
        lower.aggr_expr = lower_aggr_exprs;
        let lower_plan = LogicalPlan::Aggregate(lower);

        // update the aggregate's output schema
        let lower_plan = lower_plan.recompute_schema()?;

        // should only affect the two udafs `first_value`/`last_value`,
        // which are the only ones with a meaningful order by field
        let fixed_lower_plan =
            FixStateUdafOrderingAnalyzer.analyze(lower_plan, &Default::default())?;

        let upper = Aggregate::try_new(
            Arc::new(fixed_lower_plan.clone()),
            upper_group_exprs,
            upper_aggr_exprs.clone(),
        )?;
        let aggr_plan = LogicalPlan::Aggregate(aggr);

        // the upper plan's output schema should be the same as the original aggregate plan's output schema
        let upper_check = upper;
        let upper_plan = LogicalPlan::Aggregate(upper_check).recompute_schema()?;
        if *upper_plan.schema() != *aggr_plan.schema() {
            return Err(datafusion_common::DataFusionError::Internal(format!(
                "Upper aggregate plan's schema is not the same as the original aggregate plan's schema: \n[transformed]:{}\n[original]:{}",
                upper_plan.schema(),
                aggr_plan.schema()
            )));
        }

        Ok(StepAggrPlan {
            lower_state: fixed_lower_plan,
            upper_merge: upper_plan,
        })
    }
}

/// Wrapper that exposes the state computation of the original aggregate function as a
/// standalone aggregate function (i.e. the `__<aggr_name>_state` function).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct StateWrapper {
    inner: AggregateUDF,
    name: String,
    /// Defaults to empty; might get fixed by the analyzer later
    ordering: Vec<FieldRef>,
    /// Defaults to false; might get fixed by the analyzer later
    distinct: bool,
}

impl StateWrapper {
    /// Creates a state wrapper for the given aggregate function; the resulting
    /// function is named `__<aggr_name>_state`.
    pub fn new(inner: AggregateUDF) -> datafusion_common::Result<Self> {
        let name = aggr_state_func_name(inner.name());
        Ok(Self {
            inner,
            name,
            ordering: vec![],
            distinct: false,
        })
    }

    pub fn inner(&self) -> &AggregateUDF {
        &self.inner
    }

    /// Deduce the return type of the original aggregate function
    /// based on the accumulator arguments.
    pub fn deduce_aggr_return_type(
        &self,
        acc_args: &datafusion_expr::function::AccumulatorArgs,
    ) -> datafusion_common::Result<FieldRef> {
        let input_fields = acc_args
            .exprs
            .iter()
            .map(|e| e.return_field(acc_args.schema))
            .collect::<Result<Vec<_>, _>>()?;
        self.inner.return_field(&input_fields).inspect_err(|e| {
            common_telemetry::error!(
                "StateWrapper: {:#?}\nacc_args:{:?}\nerror:{:?}",
                &self,
                &acc_args,
                e
            );
        })
    }
}

impl AggregateUDFImpl for StateWrapper {
    fn accumulator<'a, 'b>(
        &'a self,
        acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
    ) -> datafusion_common::Result<Box<dyn Accumulator>> {
        // fix and recover the proper acc args for the original aggregate function.
        let state_type = acc_args.return_type().clone();
        let inner = {
            let acc_args = datafusion_expr::function::AccumulatorArgs {
                return_field: self.deduce_aggr_return_type(&acc_args)?,
                schema: acc_args.schema,
                ignore_nulls: acc_args.ignore_nulls,
                order_bys: acc_args.order_bys,
                is_reversed: acc_args.is_reversed,
                name: acc_args.name,
                is_distinct: acc_args.is_distinct,
                exprs: acc_args.exprs,
            };
            self.inner.accumulator(acc_args)?
        };

        Ok(Box::new(StateAccum::new(inner, state_type)?))
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
    fn name(&self) -> &str {
        self.name.as_str()
    }

    fn is_nullable(&self) -> bool {
        self.inner.is_nullable()
    }

    /// Returns the original aggregate function's state fields packed into a struct type.
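    ///
    /// For example (a sketch, assuming DataFusion's default `avg` over a `Float64`
    /// input, whose state fields are roughly `count: UInt64` and `sum: Float64`):
    /// the state function's return type would be `Struct(count, sum)`, with every
    /// field made nullable below.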
    fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
        let input_fields = &arg_types
            .iter()
            .map(|x| Arc::new(Field::new("x", x.clone(), false)))
            .collect::<Vec<_>>();

        let state_fields_args = StateFieldsArgs {
            name: self.inner().name(),
            input_fields,
            return_field: self.inner.return_field(input_fields)?,
            // these args are also needed, as they are vital to construct the state fields correctly.
            ordering_fields: &self.ordering,
            is_distinct: self.distinct,
        };
        let state_fields = self.inner.state_fields(state_fields_args)?;

        let state_fields = state_fields
            .into_iter()
            .map(|f| {
                let mut f = f.as_ref().clone();
                // the state can be null when there are no input rows, so make all fields nullable
                f.set_nullable(true);
                Arc::new(f)
            })
            .collect::<Vec<_>>();

        let struct_field = DataType::Struct(state_fields.into());
        Ok(struct_field)
    }

    /// The state function's output fields are the same as the original aggregate function's state fields.
    fn state_fields(
        &self,
        args: datafusion_expr::function::StateFieldsArgs,
    ) -> datafusion_common::Result<Vec<FieldRef>> {
        let state_fields_args = StateFieldsArgs {
            name: args.name,
            input_fields: args.input_fields,
            return_field: self.inner.return_field(args.input_fields)?,
            ordering_fields: args.ordering_fields,
            is_distinct: args.is_distinct,
        };
        self.inner.state_fields(state_fields_args)
    }

    /// The state function's signature is the same as the original aggregate function's signature.
    fn signature(&self) -> &Signature {
        self.inner.signature()
    }

    /// Type coercion is delegated to the original aggregate function.
    fn coerce_types(&self, arg_types: &[DataType]) -> datafusion_common::Result<Vec<DataType>> {
        self.inner.coerce_types(arg_types)
    }

    fn value_from_stats(
        &self,
        statistics_args: &datafusion_expr::StatisticsArgs,
    ) -> Option<ScalarValue> {
        let inner = self.inner().inner().as_any();
        // only count/min/max need special handling here, for getting the result from statistics;
        // the result of count/min/max is also its only state field, so it can be returned directly after wrapping into a struct
        let can_use_stat = inner.is::<Count>() || inner.is::<Max>() || inner.is::<Min>();
        if !can_use_stat {
            return None;
        }

        // fix the return type by extracting the first field's data type from the struct type
        let state_type = if let DataType::Struct(fields) = &statistics_args.return_type {
            if fields.is_empty() {
                return None;
            }
            fields[0].data_type().clone()
        } else {
            return None;
        };

        let fixed_args = datafusion_expr::StatisticsArgs {
            statistics: statistics_args.statistics,
            return_type: &state_type,
            is_distinct: statistics_args.is_distinct,
            exprs: statistics_args.exprs,
        };

        let ret = self.inner().value_from_stats(&fixed_args)?;

        // wrap the result into a struct scalar value
        let fields = if let DataType::Struct(fields) = &statistics_args.return_type {
            fields
        } else {
            return None;
        };

        let array = ret.to_array().ok()?;

        let struct_array = StructArray::new(fields.clone(), vec![array], None);
        let ret = ScalarValue::Struct(Arc::new(struct_array));
        Some(ret)
    }
}

/// An accumulator whose input is the same as the original aggregate function's input,
/// and whose output is the state, packed into a struct scalar.
#[derive(Debug)]
pub struct StateAccum {
    inner: Box<dyn Accumulator>,
    state_fields: Fields,
}

impl StateAccum {
    pub fn new(
        inner: Box<dyn Accumulator>,
        state_type: DataType,
    ) -> datafusion_common::Result<Self> {
        let DataType::Struct(fields) = state_type else {
            return Err(datafusion_common::DataFusionError::Internal(format!(
                "Expected a struct type for state, got: {:?}",
                state_type
            )));
        };
        Ok(Self {
            inner,
            state_fields: fields,
        })
    }
}

impl Accumulator for StateAccum {
    fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
        let state = self.inner.state()?;

        let array = state
            .iter()
            .map(|s| s.to_array())
            .collect::<Result<Vec<_>, _>>()?;
        let array_type = array
            .iter()
            .map(|a| a.data_type().clone())
            .collect::<Vec<_>>();
        let expected_type: Vec<_> = self
            .state_fields
            .iter()
            .map(|f| f.data_type().clone())
            .collect();
        if array_type != expected_type {
            debug!(
                "State type mismatch, expected {} fields, got {} arrays; expected fields: {:?}, given array types: {:?}",
                self.state_fields.len(),
                array.len(),
                self.state_fields,
                array_type,
            );
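            // Fall back to a best-effort struct whose fields are named after the column
            // index (e.g. `col_0[mismatch_state]`), so evaluation still produces a value
            // and the mismatch surfaces downstream instead of erroring here.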
            let guess_schema = array
                .iter()
                .enumerate()
                .map(|(index, array)| {
                    Field::new(
                        format!("col_{index}[mismatch_state]").as_str(),
                        array.data_type().clone(),
                        true,
                    )
                })
                .collect::<Fields>();
            let arr = StructArray::try_new(guess_schema, array, None)?;

            return Ok(ScalarValue::Struct(Arc::new(arr)));
        }

        let struct_array = StructArray::try_new(self.state_fields.clone(), array, None)?;
        Ok(ScalarValue::Struct(Arc::new(struct_array)))
    }

    fn merge_batch(
        &mut self,
        states: &[datatypes::arrow::array::ArrayRef],
    ) -> datafusion_common::Result<()> {
        self.inner.merge_batch(states)
    }

    fn update_batch(
        &mut self,
        values: &[datatypes::arrow::array::ArrayRef],
    ) -> datafusion_common::Result<()> {
        self.inner.update_batch(values)
    }

    fn size(&self) -> usize {
        self.inner.size()
    }

    fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
        self.inner.state()
    }
}

/// TODO(discord9): mark this function as not ser/de-able
///
/// This wrapper shouldn't be registered as a udaf, as it contains extra data that is not
/// serializable and that changes between logical plans.
#[derive(Debug, Clone)]
pub struct MergeWrapper {
    inner: AggregateUDF,
    name: String,
    merge_signature: Signature,
    /// The original physical expression of the aggregate function; we can't store the original aggregate function directly, as `PhysicalExpr` doesn't implement `Any`
    original_phy_expr: Arc<AggregateFunctionExpr>,
    return_type: DataType,
}

impl MergeWrapper {
    pub fn new(
        inner: AggregateUDF,
        original_phy_expr: Arc<AggregateFunctionExpr>,
        original_input_types: Vec<DataType>,
    ) -> datafusion_common::Result<Self> {
        let name = aggr_merge_func_name(inner.name());
        // the input type is actually a struct type made of the original aggregate function's state fields.
        let merge_signature = Signature::user_defined(datafusion_expr::Volatility::Immutable);
        let return_type = inner.return_type(&original_input_types)?;

        Ok(Self {
            inner,
            name,
            merge_signature,
            original_phy_expr,
            return_type,
        })
    }

    pub fn inner(&self) -> &AggregateUDF {
        &self.inner
    }
}

impl AggregateUDFImpl for MergeWrapper {
    fn accumulator<'a, 'b>(
        &'a self,
        acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
    ) -> datafusion_common::Result<Box<dyn Accumulator>> {
        if acc_args.exprs.len() != 1
            || !matches!(
                acc_args.exprs[0].data_type(acc_args.schema)?,
                DataType::Struct(_)
            )
        {
            return Err(datafusion_common::DataFusionError::Internal(format!(
                "Expected one struct type as input, got: {:?}",
                acc_args.schema
            )));
        }
        let input_type = acc_args.exprs[0].data_type(acc_args.schema)?;
        let DataType::Struct(fields) = input_type else {
            return Err(datafusion_common::DataFusionError::Internal(format!(
                "Expected a struct type for input, got: {:?}",
                input_type
            )));
        };

        let inner_accum = self.original_phy_expr.create_accumulator()?;
        Ok(Box::new(MergeAccum::new(inner_accum, &fields)))
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
    fn name(&self) -> &str {
        self.name.as_str()
    }

    fn is_nullable(&self) -> bool {
        self.inner.is_nullable()
    }

    /// Notice that `arg_types` here are actually the `state_fields`' data types,
    /// so return the fixed return type instead of deriving it from `arg_types`.
    fn return_type(&self, _arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
        // The return type is the same as the original aggregate function's return type.
        Ok(self.return_type.clone())
    }
    fn signature(&self) -> &Signature {
        &self.merge_signature
    }

    /// Coercion does nothing beyond checking that the input is a single struct type,
    /// as the optimizer should already have produced the struct argument.
    fn coerce_types(&self, arg_types: &[DataType]) -> datafusion_common::Result<Vec<DataType>> {
        // just check that there is exactly one arg and that it is a struct type
        if arg_types.len() != 1 || !matches!(arg_types.first(), Some(DataType::Struct(_))) {
            return Err(datafusion_common::DataFusionError::Internal(format!(
                "Expected one struct type as input, got: {:?}",
                arg_types
            )));
        }
        Ok(arg_types.to_vec())
    }

    /// Just return the original aggregate function's state fields.
    fn state_fields(
        &self,
        _args: datafusion_expr::function::StateFieldsArgs,
    ) -> datafusion_common::Result<Vec<FieldRef>> {
        self.original_phy_expr.state_fields()
    }
}

impl PartialEq for MergeWrapper {
    fn eq(&self, other: &Self) -> bool {
        self.inner == other.inner
    }
}

impl Eq for MergeWrapper {}

impl Hash for MergeWrapper {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.inner.hash(state);
    }
}

/// The merge accumulator, which modifies `update_batch`'s behavior to accept a single struct
/// array holding the state fields of the original aggregate function and merge those states
/// into the original accumulator. The output is the same as the original aggregate function's.
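///
/// For illustration (a sketch, using an `avg`-like state as an assumed example): if the
/// input struct column holds rows like `{count: 2, sum: 3.0}` and `{count: 1, sum: 6.0}`,
/// `update_batch` unpacks the struct and forwards the `count` and `sum` columns to the
/// inner accumulator's `merge_batch`, and `evaluate` then yields `9.0 / 3 = 3.0`.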
#[derive(Debug)]
pub struct MergeAccum {
    inner: Box<dyn Accumulator>,
    state_fields: Fields,
}

impl MergeAccum {
    pub fn new(inner: Box<dyn Accumulator>, state_fields: &Fields) -> Self {
        Self {
            inner,
            state_fields: state_fields.clone(),
        }
    }
}

698
699impl Accumulator for MergeAccum {
700    fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
701        self.inner.evaluate()
702    }
703
704    fn merge_batch(&mut self, states: &[arrow::array::ArrayRef]) -> datafusion_common::Result<()> {
705        self.inner.merge_batch(states)
706    }
707
708    fn update_batch(&mut self, values: &[arrow::array::ArrayRef]) -> datafusion_common::Result<()> {
709        let value = values.first().ok_or_else(|| {
710            datafusion_common::DataFusionError::Internal("No values provided for merge".to_string())
711        })?;
712        // The input values are states from other accumulators, so we merge them.
713        let struct_arr = value
714            .as_any()
715            .downcast_ref::<StructArray>()
716            .ok_or_else(|| {
717                datafusion_common::DataFusionError::Internal(format!(
718                    "Expected StructArray, got: {:?}",
719                    value.data_type()
720                ))
721            })?;
722        let fields = struct_arr.fields();
723        if fields != &self.state_fields {
724            debug!(
725                "State fields mismatch, expected: {:?}, got: {:?}",
726                self.state_fields, fields
727            );
728            // state fields mismatch might be acceptable by datafusion, continue
729        }
730
731        // now fields should be the same, so we can merge the batch
732        // by pass the columns as order should be the same
733        let state_columns = struct_arr.columns();
734        self.inner.merge_batch(state_columns)
735    }
736
737    fn size(&self) -> usize {
738        self.inner.size()
739    }
740
741    fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
742        self.inner.state()
743    }
744}