common_function/aggrs/
aggr_wrapper.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
//! Wrapper for making aggregate functions out of state/merge functions of original aggregate functions.
//!
//! i.e. for an aggregate function `foo`, we will have a state function `foo_state` and a merge function
//! `foo_merge` (their actual names are `__foo_state` and `__foo_merge`).
//!
//! `foo_state`'s input args are the same as `foo`'s, and its output is a state object.
//! Note that `foo_state` might have multiple output columns, so its output is a struct array
//! in which each output column is a struct field.
//! `foo_merge`'s input arg is the same as `foo_state`'s output, and its output is the same as `foo`'s output.
23//!
24
25use std::sync::Arc;
26
27use arrow::array::StructArray;
28use arrow_schema::{FieldRef, Fields};
29use common_telemetry::debug;
30use datafusion::functions_aggregate::all_default_aggregate_functions;
31use datafusion::optimizer::AnalyzerRule;
32use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
33use datafusion::physical_planner::create_aggregate_expr_and_maybe_filter;
34use datafusion_common::{Column, ScalarValue};
35use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
36use datafusion_expr::function::StateFieldsArgs;
37use datafusion_expr::{
38    Accumulator, Aggregate, AggregateUDF, AggregateUDFImpl, Expr, ExprSchemable, LogicalPlan,
39    Signature,
40};
41use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
42use datatypes::arrow::datatypes::{DataType, Field};
43
44use crate::function_registry::FunctionRegistry;
45
/// Builds the name of the state function that belongs to the aggregate function `aggr_name`.
///
/// The state function computes the intermediate state of the aggregate function.
/// Its name has the form `__<aggr_name>_state`.
pub fn aggr_state_func_name(aggr_name: &str) -> String {
    ["__", aggr_name, "_state"].concat()
}
52
/// Builds the name of the merge function that belongs to the aggregate function `aggr_name`.
///
/// The merge function combines the states produced by the matching state function.
/// Its name has the form `__<aggr_name>_merge`.
pub fn aggr_merge_func_name(aggr_name: &str) -> String {
    ["__", aggr_name, "_merge"].concat()
}
59
/// Stateless entry point for working with the `state`/`merge` variants of aggregate functions.
///
/// The type carries no data; its associated functions register the state functions and split an
/// aggregate plan into a lower "state" step and an upper "merge" step.
///
/// Note that a state function may have multiple output columns, so its return type is always a
/// struct, and the merge function is what merges those states back together.
#[derive(Clone, Debug)]
pub struct StateMergeHelper;
66
/// The pair of aggregate plans produced by splitting a single aggregate into two steps:
/// a lower plan that runs the state functions and an upper plan that runs the merge functions.
#[allow(unused)]
#[derive(Debug, Clone)]
pub struct StepAggrPlan {
    /// Upper merge plan: the aggregate plan that merges the states emitted by the lower plan.
    pub upper_merge: LogicalPlan,
    /// Lower state plan: the aggregate plan that computes the state of each aggregate function.
    pub lower_state: LogicalPlan,
}
76
77pub fn get_aggr_func(expr: &Expr) -> Option<&datafusion_expr::expr::AggregateFunction> {
78    let mut expr_ref = expr;
79    while let Expr::Alias(alias) = expr_ref {
80        expr_ref = &alias.expr;
81    }
82    if let Expr::AggregateFunction(aggr_func) = expr_ref {
83        Some(aggr_func)
84    } else {
85        None
86    }
87}
88
89impl StateMergeHelper {
90    /// Register all the `state` function of supported aggregate functions.
91    /// Note that can't register `merge` function here, as it needs to be created from the original aggregate function with given input types.
92    pub fn register(registry: &FunctionRegistry) {
93        let all_default = all_default_aggregate_functions();
94        let greptime_custom_aggr_functions = registry.aggregate_functions();
95
96        // if our custom aggregate function have the same name as the default aggregate function, we will override it.
97        let supported = all_default
98            .into_iter()
99            .chain(greptime_custom_aggr_functions.into_iter().map(Arc::new))
100            .collect::<Vec<_>>();
101        debug!(
102            "Registering state functions for supported: {:?}",
103            supported.iter().map(|f| f.name()).collect::<Vec<_>>()
104        );
105
106        let state_func = supported.into_iter().filter_map(|f| {
107            StateWrapper::new((*f).clone())
108                .inspect_err(
109                    |e| common_telemetry::error!(e; "Failed to register state function for {:?}", f),
110                )
111                .ok()
112                .map(AggregateUDF::new_from_impl)
113        });
114
115        for func in state_func {
116            registry.register_aggr(func);
117        }
118    }
119
120    /// Split an aggregate plan into two aggregate plans, one for the state function and one for the merge function.
121    pub fn split_aggr_node(aggr_plan: Aggregate) -> datafusion_common::Result<StepAggrPlan> {
122        let aggr = {
123            // certain aggr func need type coercion to work correctly, so we need to analyze the plan first.
124            let aggr_plan = TypeCoercion::new().analyze(
125                LogicalPlan::Aggregate(aggr_plan).clone(),
126                &Default::default(),
127            )?;
128            if let LogicalPlan::Aggregate(aggr) = aggr_plan {
129                aggr
130            } else {
131                return Err(datafusion_common::DataFusionError::Internal(format!(
132                    "Failed to coerce expressions in aggregate plan, expected Aggregate, got: {:?}",
133                    aggr_plan
134                )));
135            }
136        };
137        let mut lower_aggr_exprs = vec![];
138        let mut upper_aggr_exprs = vec![];
139
140        for aggr_expr in aggr.aggr_expr.iter() {
141            let Some(aggr_func) = get_aggr_func(aggr_expr) else {
142                return Err(datafusion_common::DataFusionError::NotImplemented(format!(
143                    "Unsupported aggregate expression for step aggr optimize: {:?}",
144                    aggr_expr
145                )));
146            };
147
148            let original_input_types = aggr_func
149                .params
150                .args
151                .iter()
152                .map(|e| e.get_type(&aggr.input.schema()))
153                .collect::<Result<Vec<_>, _>>()?;
154
155            // first create the state function from the original aggregate function.
156            let state_func = StateWrapper::new((*aggr_func.func).clone())?;
157
158            let expr = AggregateFunction {
159                func: Arc::new(state_func.into()),
160                params: aggr_func.params.clone(),
161            };
162            let expr = Expr::AggregateFunction(expr);
163            let lower_state_output_col_name = expr.schema_name().to_string();
164
165            lower_aggr_exprs.push(expr);
166
167            let (original_phy_expr, _filter, _ordering) = create_aggregate_expr_and_maybe_filter(
168                aggr_expr,
169                aggr.input.schema(),
170                aggr.input.schema().as_arrow(),
171                &Default::default(),
172            )?;
173
174            let merge_func = MergeWrapper::new(
175                (*aggr_func.func).clone(),
176                original_phy_expr,
177                original_input_types,
178            )?;
179            let arg = Expr::Column(Column::new_unqualified(lower_state_output_col_name));
180            let expr = AggregateFunction {
181                func: Arc::new(merge_func.into()),
182                params: AggregateFunctionParams {
183                    args: vec![arg],
184                    ..aggr_func.params.clone()
185                },
186            };
187
188            // alias to the original aggregate expr's schema name, so parent plan can refer to it
189            // correctly.
190            let expr = Expr::AggregateFunction(expr).alias(aggr_expr.schema_name().to_string());
191            upper_aggr_exprs.push(expr);
192        }
193
194        let mut lower = aggr.clone();
195        lower.aggr_expr = lower_aggr_exprs;
196        let lower_plan = LogicalPlan::Aggregate(lower);
197
198        // update aggregate's output schema
199        let lower_plan = lower_plan.recompute_schema()?;
200
201        let mut upper = aggr.clone();
202        let aggr_plan = LogicalPlan::Aggregate(aggr);
203        upper.aggr_expr = upper_aggr_exprs;
204        upper.input = Arc::new(lower_plan.clone());
205        // upper schema's output schema should be the same as the original aggregate plan's output schema
206        let upper_check = upper;
207        let upper_plan = LogicalPlan::Aggregate(upper_check).recompute_schema()?;
208        if *upper_plan.schema() != *aggr_plan.schema() {
209            return Err(datafusion_common::DataFusionError::Internal(format!(
210                "Upper aggregate plan's schema is not the same as the original aggregate plan's schema: \n[transformed]:{}\n[original]:{}",
211                upper_plan.schema(),
212                aggr_plan.schema()
213            )));
214        }
215
216        Ok(StepAggrPlan {
217            lower_state: lower_plan,
218            upper_merge: upper_plan,
219        })
220    }
221}
222
/// Wrapper that turns an aggregate function into its "state" function, i.e. an aggregate
/// function whose output is the intermediate state of the original one.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StateWrapper {
    /// The original aggregate function being wrapped.
    inner: AggregateUDF,
    /// Name of the state function, derived from the inner function's name
    /// (see `aggr_state_func_name`).
    name: String,
}
229
230impl StateWrapper {
231    /// `state_index`: The index of the state in the output of the state function.
232    pub fn new(inner: AggregateUDF) -> datafusion_common::Result<Self> {
233        let name = aggr_state_func_name(inner.name());
234        Ok(Self { inner, name })
235    }
236
237    pub fn inner(&self) -> &AggregateUDF {
238        &self.inner
239    }
240
241    /// Deduce the return type of the original aggregate function
242    /// based on the accumulator arguments.
243    ///
244    pub fn deduce_aggr_return_type(
245        &self,
246        acc_args: &datafusion_expr::function::AccumulatorArgs,
247    ) -> datafusion_common::Result<FieldRef> {
248        self.inner.return_field(acc_args.schema.fields())
249    }
250}
251
252impl AggregateUDFImpl for StateWrapper {
253    fn accumulator<'a, 'b>(
254        &'a self,
255        acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
256    ) -> datafusion_common::Result<Box<dyn Accumulator>> {
257        // fix and recover proper acc args for the original aggregate function.
258        let state_type = acc_args.return_type().clone();
259        let inner = {
260            let acc_args = datafusion_expr::function::AccumulatorArgs {
261                return_field: self.deduce_aggr_return_type(&acc_args)?,
262                schema: acc_args.schema,
263                ignore_nulls: acc_args.ignore_nulls,
264                order_bys: acc_args.order_bys,
265                is_reversed: acc_args.is_reversed,
266                name: acc_args.name,
267                is_distinct: acc_args.is_distinct,
268                exprs: acc_args.exprs,
269            };
270            self.inner.accumulator(acc_args)?
271        };
272        Ok(Box::new(StateAccum::new(inner, state_type)?))
273    }
274
275    fn as_any(&self) -> &dyn std::any::Any {
276        self
277    }
278    fn name(&self) -> &str {
279        self.name.as_str()
280    }
281
282    fn is_nullable(&self) -> bool {
283        self.inner.is_nullable()
284    }
285
286    /// Return state_fields as the output struct type.
287    ///
288    fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
289        let input_fields = &arg_types
290            .iter()
291            .map(|x| Arc::new(Field::new("x", x.clone(), false)))
292            .collect::<Vec<_>>();
293
294        let state_fields_args = StateFieldsArgs {
295            name: self.inner().name(),
296            input_fields,
297            return_field: self.inner.return_field(input_fields)?,
298            // TODO(discord9): how to get this?, probably ok?
299            ordering_fields: &[],
300            is_distinct: false,
301        };
302        let state_fields = self.inner.state_fields(state_fields_args)?;
303        let struct_field = DataType::Struct(state_fields.into());
304        Ok(struct_field)
305    }
306
307    /// The state function's output fields are the same as the original aggregate function's state fields.
308    fn state_fields(
309        &self,
310        args: datafusion_expr::function::StateFieldsArgs,
311    ) -> datafusion_common::Result<Vec<FieldRef>> {
312        let state_fields_args = StateFieldsArgs {
313            name: args.name,
314            input_fields: args.input_fields,
315            return_field: self.inner.return_field(args.input_fields)?,
316            ordering_fields: args.ordering_fields,
317            is_distinct: args.is_distinct,
318        };
319        self.inner.state_fields(state_fields_args)
320    }
321
322    /// The state function's signature is the same as the original aggregate function's signature,
323    fn signature(&self) -> &Signature {
324        self.inner.signature()
325    }
326
327    /// Coerce types also do nothing, as optimzer should be able to already make struct types
328    fn coerce_types(&self, arg_types: &[DataType]) -> datafusion_common::Result<Vec<DataType>> {
329        self.inner.coerce_types(arg_types)
330    }
331}
332
/// Accumulator of a state function.
///
/// Its input is the same as the original aggregate function's input, and its output is the
/// original accumulator's state, packed into one struct value.
#[derive(Debug)]
pub struct StateAccum {
    /// Accumulator of the original aggregate function.
    inner: Box<dyn Accumulator>,
    /// Fields of the struct that `evaluate` produces.
    state_fields: Fields,
}
340
341impl StateAccum {
342    pub fn new(
343        inner: Box<dyn Accumulator>,
344        state_type: DataType,
345    ) -> datafusion_common::Result<Self> {
346        let DataType::Struct(fields) = state_type else {
347            return Err(datafusion_common::DataFusionError::Internal(format!(
348                "Expected a struct type for state, got: {:?}",
349                state_type
350            )));
351        };
352        Ok(Self {
353            inner,
354            state_fields: fields,
355        })
356    }
357}
358
359impl Accumulator for StateAccum {
360    fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
361        let state = self.inner.state()?;
362
363        let array = state
364            .iter()
365            .map(|s| s.to_array())
366            .collect::<Result<Vec<_>, _>>()?;
367        let struct_array = StructArray::try_new(self.state_fields.clone(), array, None)?;
368        Ok(ScalarValue::Struct(Arc::new(struct_array)))
369    }
370
371    fn merge_batch(
372        &mut self,
373        states: &[datatypes::arrow::array::ArrayRef],
374    ) -> datafusion_common::Result<()> {
375        self.inner.merge_batch(states)
376    }
377
378    fn update_batch(
379        &mut self,
380        values: &[datatypes::arrow::array::ArrayRef],
381    ) -> datafusion_common::Result<()> {
382        self.inner.update_batch(values)
383    }
384
385    fn size(&self) -> usize {
386        self.inner.size()
387    }
388
389    fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
390        self.inner.state()
391    }
392}
393
/// Wrapper that turns an aggregate function into its "merge" function, i.e. an aggregate
/// function that takes the state struct as input and produces the original function's result.
///
/// TODO(discord9): mark this function as non-ser/de able
///
/// This wrapper shouldn't be registered as a UDAF, as it contains extra data that is not
/// serializable and that changes for different logical plans.
#[derive(Debug, Clone)]
pub struct MergeWrapper {
    /// The original aggregate function being wrapped.
    inner: AggregateUDF,
    /// Name of the merge function, derived from the inner function's name
    /// (see `aggr_merge_func_name`).
    name: String,
    /// Signature of the merge function itself (user-defined; see `coerce_types`).
    merge_signature: Signature,
    /// The original physical expression of the aggregate function. The original aggregate
    /// function can't be stored directly, as PhysicalExpr doesn't implement Any.
    original_phy_expr: Arc<AggregateFunctionExpr>,
    /// Input types of the original aggregate function, used to compute the return type.
    original_input_types: Vec<DataType>,
}
407impl MergeWrapper {
408    pub fn new(
409        inner: AggregateUDF,
410        original_phy_expr: Arc<AggregateFunctionExpr>,
411        original_input_types: Vec<DataType>,
412    ) -> datafusion_common::Result<Self> {
413        let name = aggr_merge_func_name(inner.name());
414        // the input type is actually struct type, which is the state fields of the original aggregate function.
415        let merge_signature = Signature::user_defined(datafusion_expr::Volatility::Immutable);
416
417        Ok(Self {
418            inner,
419            name,
420            merge_signature,
421            original_phy_expr,
422            original_input_types,
423        })
424    }
425
426    pub fn inner(&self) -> &AggregateUDF {
427        &self.inner
428    }
429}
430
431impl AggregateUDFImpl for MergeWrapper {
432    fn accumulator<'a, 'b>(
433        &'a self,
434        acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
435    ) -> datafusion_common::Result<Box<dyn Accumulator>> {
436        if acc_args.exprs.len() != 1
437            || !matches!(
438                acc_args.exprs[0].data_type(acc_args.schema)?,
439                DataType::Struct(_)
440            )
441        {
442            return Err(datafusion_common::DataFusionError::Internal(format!(
443                "Expected one struct type as input, got: {:?}",
444                acc_args.schema
445            )));
446        }
447        let input_type = acc_args.exprs[0].data_type(acc_args.schema)?;
448        let DataType::Struct(fields) = input_type else {
449            return Err(datafusion_common::DataFusionError::Internal(format!(
450                "Expected a struct type for input, got: {:?}",
451                input_type
452            )));
453        };
454
455        let inner_accum = self.original_phy_expr.create_accumulator()?;
456        Ok(Box::new(MergeAccum::new(inner_accum, &fields)))
457    }
458
459    fn as_any(&self) -> &dyn std::any::Any {
460        self
461    }
462    fn name(&self) -> &str {
463        self.name.as_str()
464    }
465
466    fn is_nullable(&self) -> bool {
467        self.inner.is_nullable()
468    }
469
470    /// Notice here the `arg_types` is actually the `state_fields`'s data types,
471    /// so return fixed return type instead of using `arg_types` to determine the return type.
472    fn return_type(&self, _arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
473        // The return type is the same as the original aggregate function's return type.
474        let ret_type = self.inner.return_type(&self.original_input_types)?;
475        Ok(ret_type)
476    }
477    fn signature(&self) -> &Signature {
478        &self.merge_signature
479    }
480
481    /// Coerce types also do nothing, as optimzer should be able to already make struct types
482    fn coerce_types(&self, arg_types: &[DataType]) -> datafusion_common::Result<Vec<DataType>> {
483        // just check if the arg_types are only one and is struct array
484        if arg_types.len() != 1 || !matches!(arg_types.first(), Some(DataType::Struct(_))) {
485            return Err(datafusion_common::DataFusionError::Internal(format!(
486                "Expected one struct type as input, got: {:?}",
487                arg_types
488            )));
489        }
490        Ok(arg_types.to_vec())
491    }
492
493    /// Just return the original aggregate function's state fields.
494    fn state_fields(
495        &self,
496        _args: datafusion_expr::function::StateFieldsArgs,
497    ) -> datafusion_common::Result<Vec<FieldRef>> {
498        self.original_phy_expr.state_fields()
499    }
500}
501
/// The merge accumulator, which modifies `update_batch`'s behavior to accept one struct array
/// that holds the state fields of the original aggregate function, and merges those states into
/// the original accumulator. The output is the same as the original aggregate function's.
#[derive(Debug)]
pub struct MergeAccum {
    /// Accumulator of the original aggregate function.
    inner: Box<dyn Accumulator>,
    /// Fields the incoming state struct is expected to have.
    state_fields: Fields,
}
510
511impl MergeAccum {
512    pub fn new(inner: Box<dyn Accumulator>, state_fields: &Fields) -> Self {
513        Self {
514            inner,
515            state_fields: state_fields.clone(),
516        }
517    }
518}
519
520impl Accumulator for MergeAccum {
521    fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
522        self.inner.evaluate()
523    }
524
525    fn merge_batch(&mut self, states: &[arrow::array::ArrayRef]) -> datafusion_common::Result<()> {
526        self.inner.merge_batch(states)
527    }
528
529    fn update_batch(&mut self, values: &[arrow::array::ArrayRef]) -> datafusion_common::Result<()> {
530        let value = values.first().ok_or_else(|| {
531            datafusion_common::DataFusionError::Internal("No values provided for merge".to_string())
532        })?;
533        // The input values are states from other accumulators, so we merge them.
534        let struct_arr = value
535            .as_any()
536            .downcast_ref::<StructArray>()
537            .ok_or_else(|| {
538                datafusion_common::DataFusionError::Internal(format!(
539                    "Expected StructArray, got: {:?}",
540                    value.data_type()
541                ))
542            })?;
543        let fields = struct_arr.fields();
544        if fields != &self.state_fields {
545            return Err(datafusion_common::DataFusionError::Internal(format!(
546                "Expected state fields: {:?}, got: {:?}",
547                self.state_fields, fields
548            )));
549        }
550
551        // now fields should be the same, so we can merge the batch
552        // by pass the columns as order should be the same
553        let state_columns = struct_arr.columns();
554        self.inner.merge_batch(state_columns)
555    }
556
557    fn size(&self) -> usize {
558        self.inner.size()
559    }
560
561    fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
562        self.inner.state()
563    }
564}
565
566#[cfg(test)]
567mod tests;