common_function/aggrs/aggr_wrapper.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Wrapper for making aggregate functions out of the state/merge functions of original aggregate functions.
//!
//! I.e. for an aggregate function `foo`, we will have a state function `foo_state` and a merge function `foo_merge`.
//!
//! `foo_state`'s input args are the same as `foo`'s, and its output is a state object.
//! Note that `foo_state` might have multiple output columns, so its output is a struct array
//! in which each output column is a struct field.
//! `foo_merge`'s input arg is the same as `foo_state`'s output, and its output is the same as `foo`'s output.
//!
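//! For example, `SELECT sum(x) FROM t GROUP BY k` is conceptually evaluated in two steps
//! (a sketch, using `sum` as the stand-in aggregate):
//!
//! ```text
//! datanode:  __sum_state(x)       -- partial aggregation, output is a struct-typed state
//! frontend:  __sum_merge(state)   -- merges partial states, output is the final sum(x)
//! ```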
//!

use std::hash::{Hash, Hasher};
use std::sync::Arc;

use arrow::array::StructArray;
use arrow_schema::{FieldRef, Fields};
use common_telemetry::debug;
use datafusion::functions_aggregate::all_default_aggregate_functions;
use datafusion::functions_aggregate::count::Count;
use datafusion::functions_aggregate::min_max::{Max, Min};
use datafusion::optimizer::AnalyzerRule;
use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
use datafusion::physical_planner::create_aggregate_expr_and_maybe_filter;
use datafusion_common::{Column, ScalarValue};
use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
use datafusion_expr::function::StateFieldsArgs;
use datafusion_expr::{
    Accumulator, Aggregate, AggregateUDF, AggregateUDFImpl, Expr, ExprSchemable, LogicalPlan,
    Signature,
};
use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
use datatypes::arrow::datatypes::{DataType, Field};

use crate::aggrs::aggr_wrapper::fix_order::FixStateUdafOrderingAnalyzer;
use crate::function_registry::{FUNCTION_REGISTRY, FunctionRegistry};

pub mod fix_order;
#[cfg(test)]
mod tests;

/// Returns the name of the state function for the given aggregate function name.
/// The state function is used to compute the state of the aggregate function.
/// The state function's name is in the format `__<aggr_name>_state`.
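///
/// A tiny sketch of the naming scheme (illustrative, not a doctest):
///
/// ```ignore
/// assert_eq!(aggr_state_func_name("sum"), "__sum_state");
/// ```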
pub fn aggr_state_func_name(aggr_name: &str) -> String {
    format!("__{}_state", aggr_name)
}

/// Returns the name of the merge function for the given aggregate function name.
/// The merge function is used to merge the states of the state functions.
/// The merge function's name is in the format `__<aggr_name>_merge`.
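///
/// A tiny sketch of the naming scheme (illustrative, not a doctest):
///
/// ```ignore
/// assert_eq!(aggr_merge_func_name("sum"), "__sum_merge");
/// ```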
pub fn aggr_merge_func_name(aggr_name: &str) -> String {
    format!("__{}_merge", aggr_name)
}

/// Check if all the given aggregate expressions are steppable,
/// i.e. if each can be split into multiple steps:
/// first call `state(input)` on the datanode, then
/// call `calc(merge(state))` on the frontend to get the final result.
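///
/// A usage sketch (illustrative only; assumes `aggr` is a `datafusion_expr::Aggregate` node):
///
/// ```ignore
/// if is_all_aggr_exprs_steppable(&aggr.aggr_expr) {
///     // Safe to split: run the lower plan on datanodes and the upper plan on the frontend.
///     let step = StateMergeHelper::split_aggr_node(aggr)?;
/// }
/// ```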
pub fn is_all_aggr_exprs_steppable(aggr_exprs: &[Expr]) -> bool {
    aggr_exprs.iter().all(|expr| {
        if let Some(aggr_func) = get_aggr_func(expr) {
            if aggr_func.params.distinct {
                // Distinct aggregate functions are not steppable (yet).
                // TODO(discord9): support distinct aggregate functions.
                return false;
            }

            // Check whether the corresponding state function exists in the registry.
            FUNCTION_REGISTRY.is_aggr_func_exist(&aggr_state_func_name(aggr_func.func.name()))
        } else {
            false
        }
    })
}
pub fn get_aggr_func(expr: &Expr) -> Option<&datafusion_expr::expr::AggregateFunction> {
    let mut expr_ref = expr;
    while let Expr::Alias(alias) = expr_ref {
        expr_ref = &alias.expr;
    }
    if let Expr::AggregateFunction(aggr_func) = expr_ref {
        Some(aggr_func)
    } else {
        None
    }
}

/// A helper for making aggregate functions out of the state and merge functions of the original aggregate function.
/// It derives the state function and the merge function from the original aggregate function.
///
/// Notice state functions may have multiple output columns, so their return type is always a struct array, and the merge function is used to merge the states of the state functions.
#[derive(Debug, Clone)]
pub struct StateMergeHelper;

/// A struct to hold the two aggregate plans, one for the state function (lower) and one for the merge function (upper).
#[allow(unused)]
#[derive(Debug, Clone)]
pub struct StepAggrPlan {
    /// Upper merge plan, i.e. the aggregate plan that merges the states produced by the state function.
    pub upper_merge: LogicalPlan,
    /// Lower state plan, i.e. the aggregate plan that computes the state of the aggregate function.
    pub lower_state: LogicalPlan,
}

impl StateMergeHelper {
    /// Register the `state` functions of all supported aggregate functions.
    /// Note that the `merge` functions can't be registered here, as each needs to be created from the original aggregate function with the given input types.
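    ///
    /// A registration sketch (illustrative; `FUNCTION_REGISTRY` is the global registry imported
    /// above):
    ///
    /// ```ignore
    /// StateMergeHelper::register(&FUNCTION_REGISTRY);
    /// ```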
    pub fn register(registry: &FunctionRegistry) {
        let all_default = all_default_aggregate_functions();
        let greptime_custom_aggr_functions = registry.aggregate_functions();

        // If a custom aggregate function has the same name as a default aggregate function, the custom one overrides it.
        let supported = all_default
            .into_iter()
            .chain(greptime_custom_aggr_functions.into_iter().map(Arc::new))
            .collect::<Vec<_>>();
        debug!(
            "Registering state functions for supported: {:?}",
            supported.iter().map(|f| f.name()).collect::<Vec<_>>()
        );

        let state_func = supported.into_iter().filter_map(|f| {
            StateWrapper::new((*f).clone())
                .inspect_err(
                    |e| common_telemetry::error!(e; "Failed to register state function for {:?}", f),
                )
                .ok()
                .map(AggregateUDF::new_from_impl)
        });

        for func in state_func {
            registry.register_aggr(func);
        }
    }

    /// Split an aggregate plan into two aggregate plans, one for the state function and one for the merge function.
    ///
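    /// Schematically (a sketch; the actual function and column names are generated), an aggregate
    /// node computing `sum(x)` grouped by `k` is rewritten into:
    ///
    /// ```text
    /// upper_merge: Aggregate: groupBy=[[k]], aggr=[[__sum_merge(<state column>) AS sum(x)]]
    /// lower_state:   Aggregate: groupBy=[[k]], aggr=[[__sum_state(x)]]
    /// ```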
    pub fn split_aggr_node(aggr_plan: Aggregate) -> datafusion_common::Result<StepAggrPlan> {
        let aggr = {
            // Certain aggregate functions need type coercion to work correctly, so analyze the plan first.
            let aggr_plan = TypeCoercion::new()
                .analyze(LogicalPlan::Aggregate(aggr_plan), &Default::default())?;
            if let LogicalPlan::Aggregate(aggr) = aggr_plan {
                aggr
            } else {
                return Err(datafusion_common::DataFusionError::Internal(format!(
                    "Failed to coerce expressions in aggregate plan, expected Aggregate, got: {:?}",
                    aggr_plan
                )));
            }
        };
        let mut lower_aggr_exprs = vec![];
        let mut upper_aggr_exprs = vec![];

        // Group exprs for the upper plan should refer to the lower plan's output group exprs as columns,
        // to avoid re-computing the group exprs.
        let upper_group_exprs = aggr
            .group_expr
            .iter()
            .map(|c| c.qualified_name())
            .map(|(r, c)| Expr::Column(Column::new(r, c)))
            .collect();

        for aggr_expr in aggr.aggr_expr.iter() {
            let Some(aggr_func) = get_aggr_func(aggr_expr) else {
                return Err(datafusion_common::DataFusionError::NotImplemented(format!(
                    "Unsupported aggregate expression for step aggr optimize: {:?}",
                    aggr_expr
                )));
            };

            let original_input_fields = aggr_func
                .params
                .args
                .iter()
                .map(|e| e.to_field(&aggr.input.schema()).map(|(_, field)| field))
                .collect::<Result<Vec<_>, _>>()?;

            // First create the state function from the original aggregate function.
            let state_func = StateWrapper::new((*aggr_func.func).clone())?;

            let expr = AggregateFunction {
                func: Arc::new(state_func.into()),
                params: aggr_func.params.clone(),
            };
            let expr = Expr::AggregateFunction(expr);
            let lower_state_output_col_name = expr.schema_name().to_string();

            lower_aggr_exprs.push(expr);

            // Then create the merge function using the physical expression of the original aggregate function.
            let (original_phy_expr, _filter, _ordering) = create_aggregate_expr_and_maybe_filter(
                aggr_expr,
                aggr.input.schema(),
                aggr.input.schema().as_arrow(),
                &Default::default(),
            )?;

            let merge_func = MergeWrapper::new(
                (*aggr_func.func).clone(),
                original_phy_expr,
                original_input_fields,
            )?;
            let arg = Expr::Column(Column::new_unqualified(lower_state_output_col_name));
            let expr = AggregateFunction {
                func: Arc::new(merge_func.into()),
                // Notice filter/order_by are not supported in the merge function, as they are not
                // meaningful in the merge phase. Also notice the order by is only removed in the
                // outer logical plan; the physical plan still has it and hence can create a correct
                // accumulator with order by.
                params: AggregateFunctionParams {
                    args: vec![arg],
                    distinct: aggr_func.params.distinct,
                    filter: None,
                    order_by: vec![],
                    null_treatment: aggr_func.params.null_treatment,
                },
            };

            // Alias to the original aggregate expr's schema name, so the parent plan can refer to it
            // correctly.
            let expr = Expr::AggregateFunction(expr).alias(aggr_expr.schema_name().to_string());
            upper_aggr_exprs.push(expr);
        }

        let mut lower = aggr.clone();
        lower.aggr_expr = lower_aggr_exprs;
        let lower_plan = LogicalPlan::Aggregate(lower);

        // Update the aggregate's output schema.
        let lower_plan = lower_plan.recompute_schema()?;

        // Should only affect the two UDAFs `first_value`/`last_value`,
        // which are the only ones with a meaningful order by field.
        let fixed_lower_plan =
            FixStateUdafOrderingAnalyzer.analyze(lower_plan, &Default::default())?;

        let upper = Aggregate::try_new(
            Arc::new(fixed_lower_plan.clone()),
            upper_group_exprs,
            upper_aggr_exprs.clone(),
        )?;
        let aggr_plan = LogicalPlan::Aggregate(aggr);

        // The upper plan's output schema should be the same as the original aggregate plan's output schema.
        let upper_check = upper;
        let upper_plan = LogicalPlan::Aggregate(upper_check).recompute_schema()?;
        if *upper_plan.schema() != *aggr_plan.schema() {
            return Err(datafusion_common::DataFusionError::Internal(format!(
                "Upper aggregate plan's schema is not the same as the original aggregate plan's schema: \n[transformed]:{}\n[original]:{}",
                upper_plan.schema(),
                aggr_plan.schema()
            )));
        }

        Ok(StepAggrPlan {
            lower_state: fixed_lower_plan,
            upper_merge: upper_plan,
        })
    }
}

/// Wrapper that exposes the state function of the original aggregate function as an aggregate function (`__<aggr_name>_state`).
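///
/// A construction sketch (illustrative; assumes `sum_udaf()` from
/// `datafusion::functions_aggregate::sum` is in scope):
///
/// ```ignore
/// let state_udaf = AggregateUDF::new_from_impl(StateWrapper::new((*sum_udaf()).clone())?);
/// assert_eq!(state_udaf.name(), "__sum_state");
/// ```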
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct StateWrapper {
    inner: AggregateUDF,
    name: String,
    /// Defaults to empty; might get fixed by the analyzer later.
    ordering: Vec<FieldRef>,
    /// Defaults to false; might get fixed by the analyzer later.
    distinct: bool,
}

impl StateWrapper {
    /// Creates the state function wrapper for the given original aggregate function.
    pub fn new(inner: AggregateUDF) -> datafusion_common::Result<Self> {
        let name = aggr_state_func_name(inner.name());
        Ok(Self {
            inner,
            name,
            ordering: vec![],
            distinct: false,
        })
    }

    pub fn inner(&self) -> &AggregateUDF {
        &self.inner
    }

    /// Deduce the return type of the original aggregate function
    /// based on the accumulator arguments.
    pub fn deduce_aggr_return_type(
        &self,
        acc_args: &datafusion_expr::function::AccumulatorArgs,
    ) -> datafusion_common::Result<FieldRef> {
        let input_fields = acc_args
            .exprs
            .iter()
            .map(|e| e.return_field(acc_args.schema))
            .collect::<Result<Vec<_>, _>>()?;
        self.inner.return_field(&input_fields).inspect_err(|e| {
            common_telemetry::error!(
                "StateWrapper: {:#?}\nacc_args:{:?}\nerror:{:?}",
                &self,
                &acc_args,
                e
            );
        })
    }
}

impl AggregateUDFImpl for StateWrapper {
    fn accumulator<'a, 'b>(
        &'a self,
        acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
    ) -> datafusion_common::Result<Box<dyn Accumulator>> {
        // Fix and recover the proper accumulator args for the original aggregate function.
        let state_type = acc_args.return_type().clone();
        let inner = {
            let mut new_acc_args = acc_args.clone();
            new_acc_args.return_field = self.deduce_aggr_return_type(&acc_args)?;
            self.inner.accumulator(new_acc_args)?
        };

        Ok(Box::new(StateAccum::new(inner, state_type)?))
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
    fn name(&self) -> &str {
        self.name.as_str()
    }

    fn is_nullable(&self) -> bool {
        self.inner.is_nullable()
    }

    /// Returns the state fields packed into a struct type as the output type.
    ///
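    /// For example, for an `avg` aggregate the inner UDAF typically reports a count field and a
    /// sum field as its state (exact names, types, and order are determined by the inner UDAF),
    /// so the returned type would be roughly:
    ///
    /// ```text
    /// Struct("count": UInt64, "sum": Float64)   -- illustrative only
    /// ```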
    fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
        let input_fields = &arg_types
            .iter()
            .map(|x| Arc::new(Field::new("x", x.clone(), false)))
            .collect::<Vec<_>>();

        let state_fields_args = StateFieldsArgs {
            name: self.inner().name(),
            input_fields,
            return_field: self.inner.return_field(input_fields)?,
            // These args are also needed, as they are vital to constructing the state fields correctly.
            ordering_fields: &self.ordering,
            is_distinct: self.distinct,
        };
        let state_fields = self.inner.state_fields(state_fields_args)?;

        let state_fields = state_fields
            .into_iter()
            .map(|f| {
                let mut f = f.as_ref().clone();
                // The state can be null when there are no input rows, so make all fields nullable.
                f.set_nullable(true);
                Arc::new(f)
            })
            .collect::<Vec<_>>();

        let struct_field = DataType::Struct(state_fields.into());
        Ok(struct_field)
    }

    /// The state function's output fields are the same as the original aggregate function's state fields.
    fn state_fields(
        &self,
        args: datafusion_expr::function::StateFieldsArgs,
    ) -> datafusion_common::Result<Vec<FieldRef>> {
        let state_fields_args = StateFieldsArgs {
            name: args.name,
            input_fields: args.input_fields,
            return_field: self.inner.return_field(args.input_fields)?,
            ordering_fields: args.ordering_fields,
            is_distinct: args.is_distinct,
        };
        self.inner.state_fields(state_fields_args)
    }

    /// The state function's signature is the same as the original aggregate function's signature.
    fn signature(&self) -> &Signature {
        self.inner.signature()
    }

    /// Type coercion is delegated to the original aggregate function.
    fn coerce_types(&self, arg_types: &[DataType]) -> datafusion_common::Result<Vec<DataType>> {
        self.inner.coerce_types(arg_types)
    }

    fn value_from_stats(
        &self,
        statistics_args: &datafusion_expr::StatisticsArgs,
    ) -> Option<ScalarValue> {
        let inner = self.inner().inner().as_any();
        // Only count/min/max need special handling here for getting the result from statistics.
        // The result of count/min/max is also the result of the corresponding state function,
        // so it can be returned directly (wrapped into a struct below).
        let can_use_stat = inner.is::<Count>() || inner.is::<Max>() || inner.is::<Min>();
        if !can_use_stat {
            return None;
        }

        // Fix the return type by extracting the first field's data type from the struct type.
        let state_type = if let DataType::Struct(fields) = &statistics_args.return_type {
            if fields.is_empty() {
                return None;
            }
            fields[0].data_type().clone()
        } else {
            return None;
        };

        let fixed_args = datafusion_expr::StatisticsArgs {
            statistics: statistics_args.statistics,
            return_type: &state_type,
            is_distinct: statistics_args.is_distinct,
            exprs: statistics_args.exprs,
        };

        let ret = self.inner().value_from_stats(&fixed_args)?;

        // Wrap the result into a struct scalar value.
        let fields = if let DataType::Struct(fields) = &statistics_args.return_type {
            fields
        } else {
            return None;
        };

        let array = ret.to_array().ok()?;

        let struct_array = StructArray::new(fields.clone(), vec![array], None);
        let ret = ScalarValue::Struct(Arc::new(struct_array));
        Some(ret)
    }
}

/// Accumulator for the state function: its input is the same as the original aggregate function's
/// input, and its output is the state function's output (the state packed into a struct).
#[derive(Debug)]
pub struct StateAccum {
    inner: Box<dyn Accumulator>,
    state_fields: Fields,
}

impl StateAccum {
    pub fn new(
        inner: Box<dyn Accumulator>,
        state_type: DataType,
    ) -> datafusion_common::Result<Self> {
        let DataType::Struct(fields) = state_type else {
            return Err(datafusion_common::DataFusionError::Internal(format!(
                "Expected a struct type for state, got: {:?}",
                state_type
            )));
        };
        Ok(Self {
            inner,
            state_fields: fields,
        })
    }
}

impl Accumulator for StateAccum {
    fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
        let state = self.inner.state()?;

        let array = state
            .iter()
            .map(|s| s.to_array())
            .collect::<Result<Vec<_>, _>>()?;
        let array_type = array
            .iter()
            .map(|a| a.data_type().clone())
            .collect::<Vec<_>>();
        let expected_type: Vec<_> = self
            .state_fields
            .iter()
            .map(|f| f.data_type().clone())
            .collect();
        if array_type != expected_type {
            debug!(
                "State type mismatch, expected {} fields, got {} arrays; expected fields: {:?}, given array types: {:?}",
                self.state_fields.len(),
                array.len(),
                self.state_fields,
                array_type,
            );
            let guess_schema = array
                .iter()
                .enumerate()
                .map(|(index, array)| {
                    Field::new(
                        format!("col_{index}[mismatch_state]").as_str(),
                        array.data_type().clone(),
                        true,
                    )
                })
                .collect::<Fields>();
            let arr = StructArray::try_new(guess_schema, array, None)?;

            return Ok(ScalarValue::Struct(Arc::new(arr)));
        }

        let struct_array = StructArray::try_new(self.state_fields.clone(), array, None)?;
        Ok(ScalarValue::Struct(Arc::new(struct_array)))
    }

    fn merge_batch(
        &mut self,
        states: &[datatypes::arrow::array::ArrayRef],
    ) -> datafusion_common::Result<()> {
        self.inner.merge_batch(states)
    }

    fn update_batch(
        &mut self,
        values: &[datatypes::arrow::array::ArrayRef],
    ) -> datafusion_common::Result<()> {
        self.inner.update_batch(values)
    }

    fn size(&self) -> usize {
        self.inner.size()
    }

    fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
        self.inner.state()
    }
}

/// TODO(discord9): mark this function as non-ser/de able
///
/// This wrapper shouldn't be registered as a UDAF, as it contains extra data that is not
/// serializable and changes for different logical plans.
#[derive(Debug, Clone)]
pub struct MergeWrapper {
    inner: AggregateUDF,
    name: String,
    merge_signature: Signature,
    /// The original physical expression of the aggregate function; we can't store the original
    /// aggregate function directly, as `PhysicalExpr` doesn't implement `Any`.
    original_phy_expr: Arc<AggregateFunctionExpr>,
    return_field: FieldRef,
}

impl MergeWrapper {
    pub fn new(
        inner: AggregateUDF,
        original_phy_expr: Arc<AggregateFunctionExpr>,
        original_input_fields: Vec<FieldRef>,
    ) -> datafusion_common::Result<Self> {
        let name = aggr_merge_func_name(inner.name());
        // The input type is actually a struct type whose fields are the state fields of the
        // original aggregate function.
        let merge_signature = Signature::user_defined(datafusion_expr::Volatility::Immutable);
        let return_field = inner.return_field(&original_input_fields)?.clone();

        Ok(Self {
            inner,
            name,
            merge_signature,
            original_phy_expr,
            return_field,
        })
    }

    pub fn inner(&self) -> &AggregateUDF {
        &self.inner
    }
}

impl AggregateUDFImpl for MergeWrapper {
    fn accumulator<'a, 'b>(
        &'a self,
        acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
    ) -> datafusion_common::Result<Box<dyn Accumulator>> {
        if acc_args.exprs.len() != 1
            || !matches!(
                acc_args.exprs[0].data_type(acc_args.schema)?,
                DataType::Struct(_)
            )
        {
            return Err(datafusion_common::DataFusionError::Internal(format!(
                "Expected one struct type as input, got: {:?}",
                acc_args.schema
            )));
        }
        let input_type = acc_args.exprs[0].data_type(acc_args.schema)?;
        let DataType::Struct(fields) = input_type else {
            return Err(datafusion_common::DataFusionError::Internal(format!(
                "Expected a struct type for input, got: {:?}",
                input_type
            )));
        };

        let inner_accum = self.original_phy_expr.create_accumulator()?;
        Ok(Box::new(MergeAccum::new(inner_accum, &fields)))
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
    fn name(&self) -> &str {
        self.name.as_str()
    }

    fn is_nullable(&self) -> bool {
        self.inner.is_nullable()
    }

    /// Notice the `arg_types` here are actually the state fields' data types,
    /// so return the fixed return type instead of using `arg_types` to determine it.
    fn return_type(&self, _arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
        // The return type is the same as the original aggregate function's return type.
        Ok(self.return_field.data_type().clone())
    }

    /// Similar to `return_type`, just return the fixed return field.
    fn return_field(&self, _arg_fields: &[FieldRef]) -> datafusion_common::Result<FieldRef> {
        Ok(self.return_field.clone())
    }

    fn signature(&self) -> &Signature {
        &self.merge_signature
    }

    /// Coerce types does nothing beyond validation, as the optimizer should already produce the struct type.
    fn coerce_types(&self, arg_types: &[DataType]) -> datafusion_common::Result<Vec<DataType>> {
        // Just check that there is exactly one argument and that it is a struct type.
        if arg_types.len() != 1 || !matches!(arg_types.first(), Some(DataType::Struct(_))) {
            return Err(datafusion_common::DataFusionError::Internal(format!(
                "Expected one struct type as input, got: {:?}",
                arg_types
            )));
        }
        Ok(arg_types.to_vec())
    }

    /// Just return the original aggregate function's state fields.
    fn state_fields(
        &self,
        _args: datafusion_expr::function::StateFieldsArgs,
    ) -> datafusion_common::Result<Vec<FieldRef>> {
        self.original_phy_expr.state_fields()
    }
}

impl PartialEq for MergeWrapper {
    fn eq(&self, other: &Self) -> bool {
        self.inner == other.inner
    }
}

impl Eq for MergeWrapper {}

impl Hash for MergeWrapper {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.inner.hash(state);
    }
}

/// The merge accumulator, which modifies `update_batch`'s behavior to accept one struct array
/// containing the state fields of the original aggregate function, and merges said states into the
/// original accumulator. The output is the same as the original aggregate function's.
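///
/// Data flow sketch (illustrative; the actual field names depend on the inner UDAF):
///
/// ```text
/// update_batch([StructArray { count, sum }]) -> inner.merge_batch([count column, sum column])
/// ```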
#[derive(Debug)]
pub struct MergeAccum {
    inner: Box<dyn Accumulator>,
    state_fields: Fields,
}

impl MergeAccum {
    pub fn new(inner: Box<dyn Accumulator>, state_fields: &Fields) -> Self {
        Self {
            inner,
            state_fields: state_fields.clone(),
        }
    }
}

impl Accumulator for MergeAccum {
    fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
        self.inner.evaluate()
    }

    fn merge_batch(&mut self, states: &[arrow::array::ArrayRef]) -> datafusion_common::Result<()> {
        self.inner.merge_batch(states)
    }

    fn update_batch(&mut self, values: &[arrow::array::ArrayRef]) -> datafusion_common::Result<()> {
        let value = values.first().ok_or_else(|| {
            datafusion_common::DataFusionError::Internal("No values provided for merge".to_string())
        })?;
        // The input values are states from other accumulators, so we merge them.
        let struct_arr = value
            .as_any()
            .downcast_ref::<StructArray>()
            .ok_or_else(|| {
                datafusion_common::DataFusionError::Internal(format!(
                    "Expected StructArray, got: {:?}",
                    value.data_type()
                ))
            })?;
        let fields = struct_arr.fields();
        if fields != &self.state_fields {
            debug!(
                "State fields mismatch, expected: {:?}, got: {:?}",
                self.state_fields, fields
            );
            // A state fields mismatch might still be acceptable to datafusion, so continue.
        }

        // Now the fields should be the same, so we can merge the batch
        // by passing the columns through, as the order should be the same.
        let state_columns = struct_arr.columns();
        self.inner.merge_batch(state_columns)
    }

    fn size(&self) -> usize {
        self.inner.size()
    }

    fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
        self.inner.state()
    }
}