common_function/aggrs/
aggr_wrapper.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
//! Wrappers for making aggregate functions out of the state/merge steps of original aggregate functions.
//!
//! I.e. for an aggregate function `foo`, we have a state function `foo_state` and a merge function
//! `foo_merge` (actually named `__foo_state` and `__foo_merge`).
//!
//! `foo_state`'s input args are the same as `foo`'s, and its output is a state object.
//! Note that `foo`'s state might consist of multiple columns, so `foo_state` outputs a struct array
//! in which each state column is a struct field.
//! `foo_merge`'s input arg is the same as `foo_state`'s output, and its output is the same as `foo`'s output.
//!
24
25use std::sync::Arc;
26
27use arrow::array::StructArray;
28use arrow_schema::{FieldRef, Fields};
29use common_telemetry::debug;
30use datafusion::functions_aggregate::all_default_aggregate_functions;
31use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
32use datafusion::optimizer::AnalyzerRule;
33use datafusion::physical_planner::create_aggregate_expr_and_maybe_filter;
34use datafusion_common::{Column, ScalarValue};
35use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
36use datafusion_expr::function::StateFieldsArgs;
37use datafusion_expr::{
38    Accumulator, Aggregate, AggregateUDF, AggregateUDFImpl, Expr, ExprSchemable, LogicalPlan,
39    Signature,
40};
41use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
42use datatypes::arrow::datatypes::{DataType, Field};
43
44use crate::function_registry::FunctionRegistry;
45
/// Returns the name of the state function of the aggregate function named `aggr_name`.
///
/// The state function computes the (possibly multi-column) state of the aggregate function.
/// Its name has the format `__<aggr_name>_state`.
pub fn aggr_state_func_name(aggr_name: &str) -> String {
    ["__", aggr_name, "_state"].concat()
}
52
/// Returns the name of the merge function of the aggregate function named `aggr_name`.
///
/// The merge function merges the states produced by the corresponding state function.
/// Its name has the format `__<aggr_name>_merge`.
pub fn aggr_merge_func_name(aggr_name: &str) -> String {
    ["__", aggr_name, "_merge"].concat()
}
59
/// Helper for turning an aggregate into a state step followed by a merge step.
///
/// It holds no data: it only groups the associated functions that register the state functions
/// of supported aggregate functions, and that split an aggregate plan into a lower plan running
/// the state functions and an upper plan running the merge functions.
///
/// Notice that a state function may have multiple output columns, so its return type is always a
/// struct; the merge function takes that struct as its single argument and merges the states.
#[derive(Debug, Clone)]
pub struct StateMergeHelper;
66
67/// A struct to hold the two aggregate plans, one for the state function(lower) and one for the merge function(upper).
68#[allow(unused)]
69#[derive(Debug, Clone)]
70pub struct StepAggrPlan {
71    /// Upper merge plan, which is the aggregate plan that merges the states of the state function.
72    pub upper_merge: LogicalPlan,
73    /// Lower state plan, which is the aggregate plan that computes the state of the aggregate function.
74    pub lower_state: LogicalPlan,
75}
76
77pub fn get_aggr_func(expr: &Expr) -> Option<&datafusion_expr::expr::AggregateFunction> {
78    let mut expr_ref = expr;
79    while let Expr::Alias(alias) = expr_ref {
80        expr_ref = &alias.expr;
81    }
82    if let Expr::AggregateFunction(aggr_func) = expr_ref {
83        Some(aggr_func)
84    } else {
85        None
86    }
87}
88
89impl StateMergeHelper {
90    /// Register all the `state` function of supported aggregate functions.
91    /// Note that can't register `merge` function here, as it needs to be created from the original aggregate function with given input types.
92    pub fn register(registry: &FunctionRegistry) {
93        let all_default = all_default_aggregate_functions();
94        let greptime_custom_aggr_functions = registry.aggregate_functions();
95
96        // if our custom aggregate function have the same name as the default aggregate function, we will override it.
97        let supported = all_default
98            .into_iter()
99            .chain(greptime_custom_aggr_functions.into_iter().map(Arc::new))
100            .collect::<Vec<_>>();
101        debug!(
102            "Registering state functions for supported: {:?}",
103            supported.iter().map(|f| f.name()).collect::<Vec<_>>()
104        );
105
106        let state_func = supported.into_iter().filter_map(|f| {
107            StateWrapper::new((*f).clone())
108                .inspect_err(
109                    |e| common_telemetry::error!(e; "Failed to register state function for {:?}", f),
110                )
111                .ok()
112                .map(AggregateUDF::new_from_impl)
113        });
114
115        for func in state_func {
116            registry.register_aggr(func);
117        }
118    }
119
120    /// Split an aggregate plan into two aggregate plans, one for the state function and one for the merge function.
121    pub fn split_aggr_node(aggr_plan: Aggregate) -> datafusion_common::Result<StepAggrPlan> {
122        let aggr = {
123            // certain aggr func need type coercion to work correctly, so we need to analyze the plan first.
124            let aggr_plan = TypeCoercion::new().analyze(
125                LogicalPlan::Aggregate(aggr_plan).clone(),
126                &Default::default(),
127            )?;
128            if let LogicalPlan::Aggregate(aggr) = aggr_plan {
129                aggr
130            } else {
131                return Err(datafusion_common::DataFusionError::Internal(format!(
132                    "Failed to coerce expressions in aggregate plan, expected Aggregate, got: {:?}",
133                    aggr_plan
134                )));
135            }
136        };
137        let mut lower_aggr_exprs = vec![];
138        let mut upper_aggr_exprs = vec![];
139
140        for aggr_expr in aggr.aggr_expr.iter() {
141            let Some(aggr_func) = get_aggr_func(aggr_expr) else {
142                return Err(datafusion_common::DataFusionError::NotImplemented(format!(
143                    "Unsupported aggregate expression for step aggr optimize: {:?}",
144                    aggr_expr
145                )));
146            };
147
148            let original_input_types = aggr_func
149                .params
150                .args
151                .iter()
152                .map(|e| e.get_type(&aggr.input.schema()))
153                .collect::<Result<Vec<_>, _>>()?;
154
155            // first create the state function from the original aggregate function.
156            let state_func = StateWrapper::new((*aggr_func.func).clone())?;
157
158            let expr = AggregateFunction {
159                func: Arc::new(state_func.into()),
160                params: aggr_func.params.clone(),
161            };
162            let expr = Expr::AggregateFunction(expr);
163            let lower_state_output_col_name = expr.schema_name().to_string();
164
165            lower_aggr_exprs.push(expr);
166
167            let (original_phy_expr, _filter, _ordering) = create_aggregate_expr_and_maybe_filter(
168                aggr_expr,
169                aggr.input.schema(),
170                aggr.input.schema().as_arrow(),
171                &Default::default(),
172            )?;
173
174            let merge_func = MergeWrapper::new(
175                (*aggr_func.func).clone(),
176                original_phy_expr,
177                original_input_types,
178            )?;
179            let arg = Expr::Column(Column::new_unqualified(lower_state_output_col_name));
180            let expr = AggregateFunction {
181                func: Arc::new(merge_func.into()),
182                params: AggregateFunctionParams {
183                    args: vec![arg],
184                    ..aggr_func.params.clone()
185                },
186            };
187
188            // alias to the original aggregate expr's schema name, so parent plan can refer to it
189            // correctly.
190            let expr = Expr::AggregateFunction(expr).alias(aggr_expr.schema_name().to_string());
191            upper_aggr_exprs.push(expr);
192        }
193
194        let mut lower = aggr.clone();
195        lower.aggr_expr = lower_aggr_exprs;
196        let lower_plan = LogicalPlan::Aggregate(lower);
197
198        // update aggregate's output schema
199        let lower_plan = lower_plan.recompute_schema()?;
200
201        let mut upper = aggr.clone();
202        let aggr_plan = LogicalPlan::Aggregate(aggr);
203        upper.aggr_expr = upper_aggr_exprs;
204        upper.input = Arc::new(lower_plan.clone());
205        // upper schema's output schema should be the same as the original aggregate plan's output schema
206        let upper_check = upper;
207        let upper_plan = LogicalPlan::Aggregate(upper_check).recompute_schema()?;
208        if *upper_plan.schema() != *aggr_plan.schema() {
209            return Err(datafusion_common::DataFusionError::Internal(format!(
210                 "Upper aggregate plan's schema is not the same as the original aggregate plan's schema: \n[transformed]:{}\n[original]:{}",
211                upper_plan.schema(), aggr_plan.schema()
212            )));
213        }
214
215        Ok(StepAggrPlan {
216            lower_state: lower_plan,
217            upper_merge: upper_plan,
218        })
219    }
220}
221
222/// Wrapper to make an aggregate function out of a state function.
223#[derive(Debug, Clone, PartialEq, Eq)]
224pub struct StateWrapper {
225    inner: AggregateUDF,
226    name: String,
227}
228
229impl StateWrapper {
230    /// `state_index`: The index of the state in the output of the state function.
231    pub fn new(inner: AggregateUDF) -> datafusion_common::Result<Self> {
232        let name = aggr_state_func_name(inner.name());
233        Ok(Self { inner, name })
234    }
235
236    pub fn inner(&self) -> &AggregateUDF {
237        &self.inner
238    }
239
240    /// Deduce the return type of the original aggregate function
241    /// based on the accumulator arguments.
242    ///
243    pub fn deduce_aggr_return_type(
244        &self,
245        acc_args: &datafusion_expr::function::AccumulatorArgs,
246    ) -> datafusion_common::Result<FieldRef> {
247        self.inner.return_field(acc_args.schema.fields())
248    }
249}
250
251impl AggregateUDFImpl for StateWrapper {
252    fn accumulator<'a, 'b>(
253        &'a self,
254        acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
255    ) -> datafusion_common::Result<Box<dyn Accumulator>> {
256        // fix and recover proper acc args for the original aggregate function.
257        let state_type = acc_args.return_type().clone();
258        let inner = {
259            let acc_args = datafusion_expr::function::AccumulatorArgs {
260                return_field: self.deduce_aggr_return_type(&acc_args)?,
261                schema: acc_args.schema,
262                ignore_nulls: acc_args.ignore_nulls,
263                order_bys: acc_args.order_bys,
264                is_reversed: acc_args.is_reversed,
265                name: acc_args.name,
266                is_distinct: acc_args.is_distinct,
267                exprs: acc_args.exprs,
268            };
269            self.inner.accumulator(acc_args)?
270        };
271        Ok(Box::new(StateAccum::new(inner, state_type)?))
272    }
273
274    fn as_any(&self) -> &dyn std::any::Any {
275        self
276    }
277    fn name(&self) -> &str {
278        self.name.as_str()
279    }
280
281    fn is_nullable(&self) -> bool {
282        self.inner.is_nullable()
283    }
284
285    /// Return state_fields as the output struct type.
286    ///
287    fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
288        let input_fields = &arg_types
289            .iter()
290            .map(|x| Arc::new(Field::new("x", x.clone(), false)))
291            .collect::<Vec<_>>();
292
293        let state_fields_args = StateFieldsArgs {
294            name: self.inner().name(),
295            input_fields,
296            return_field: self.inner.return_field(input_fields)?,
297            // TODO(discord9): how to get this?, probably ok?
298            ordering_fields: &[],
299            is_distinct: false,
300        };
301        let state_fields = self.inner.state_fields(state_fields_args)?;
302        let struct_field = DataType::Struct(state_fields.into());
303        Ok(struct_field)
304    }
305
306    /// The state function's output fields are the same as the original aggregate function's state fields.
307    fn state_fields(
308        &self,
309        args: datafusion_expr::function::StateFieldsArgs,
310    ) -> datafusion_common::Result<Vec<FieldRef>> {
311        let state_fields_args = StateFieldsArgs {
312            name: args.name,
313            input_fields: args.input_fields,
314            return_field: self.inner.return_field(args.input_fields)?,
315            ordering_fields: args.ordering_fields,
316            is_distinct: args.is_distinct,
317        };
318        self.inner.state_fields(state_fields_args)
319    }
320
321    /// The state function's signature is the same as the original aggregate function's signature,
322    fn signature(&self) -> &Signature {
323        self.inner.signature()
324    }
325
326    /// Coerce types also do nothing, as optimzer should be able to already make struct types
327    fn coerce_types(&self, arg_types: &[DataType]) -> datafusion_common::Result<Vec<DataType>> {
328        self.inner.coerce_types(arg_types)
329    }
330}
331
332/// The wrapper's input is the same as the original aggregate function's input,
333/// and the output is the state function's output.
334#[derive(Debug)]
335pub struct StateAccum {
336    inner: Box<dyn Accumulator>,
337    state_fields: Fields,
338}
339
340impl StateAccum {
341    pub fn new(
342        inner: Box<dyn Accumulator>,
343        state_type: DataType,
344    ) -> datafusion_common::Result<Self> {
345        let DataType::Struct(fields) = state_type else {
346            return Err(datafusion_common::DataFusionError::Internal(format!(
347                "Expected a struct type for state, got: {:?}",
348                state_type
349            )));
350        };
351        Ok(Self {
352            inner,
353            state_fields: fields,
354        })
355    }
356}
357
358impl Accumulator for StateAccum {
359    fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
360        let state = self.inner.state()?;
361
362        let array = state
363            .iter()
364            .map(|s| s.to_array())
365            .collect::<Result<Vec<_>, _>>()?;
366        let struct_array = StructArray::try_new(self.state_fields.clone(), array, None)?;
367        Ok(ScalarValue::Struct(Arc::new(struct_array)))
368    }
369
370    fn merge_batch(
371        &mut self,
372        states: &[datatypes::arrow::array::ArrayRef],
373    ) -> datafusion_common::Result<()> {
374        self.inner.merge_batch(states)
375    }
376
377    fn update_batch(
378        &mut self,
379        values: &[datatypes::arrow::array::ArrayRef],
380    ) -> datafusion_common::Result<()> {
381        self.inner.update_batch(values)
382    }
383
384    fn size(&self) -> usize {
385        self.inner.size()
386    }
387
388    fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
389        self.inner.state()
390    }
391}
392
393/// TODO(discord9): mark this function as non-ser/de able
394///
395/// This wrapper shouldn't be register as a udaf, as it contain extra data that is not serializable.
396/// and changes for different logical plans.
397#[derive(Debug, Clone)]
398pub struct MergeWrapper {
399    inner: AggregateUDF,
400    name: String,
401    merge_signature: Signature,
402    /// The original physical expression of the aggregate function, can't store the original aggregate function directly, as PhysicalExpr didn't implement Any
403    original_phy_expr: Arc<AggregateFunctionExpr>,
404    original_input_types: Vec<DataType>,
405}
406impl MergeWrapper {
407    pub fn new(
408        inner: AggregateUDF,
409        original_phy_expr: Arc<AggregateFunctionExpr>,
410        original_input_types: Vec<DataType>,
411    ) -> datafusion_common::Result<Self> {
412        let name = aggr_merge_func_name(inner.name());
413        // the input type is actually struct type, which is the state fields of the original aggregate function.
414        let merge_signature = Signature::user_defined(datafusion_expr::Volatility::Immutable);
415
416        Ok(Self {
417            inner,
418            name,
419            merge_signature,
420            original_phy_expr,
421            original_input_types,
422        })
423    }
424
425    pub fn inner(&self) -> &AggregateUDF {
426        &self.inner
427    }
428}
429
430impl AggregateUDFImpl for MergeWrapper {
431    fn accumulator<'a, 'b>(
432        &'a self,
433        acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
434    ) -> datafusion_common::Result<Box<dyn Accumulator>> {
435        if acc_args.exprs.len() != 1
436            || !matches!(
437                acc_args.exprs[0].data_type(acc_args.schema)?,
438                DataType::Struct(_)
439            )
440        {
441            return Err(datafusion_common::DataFusionError::Internal(format!(
442                "Expected one struct type as input, got: {:?}",
443                acc_args.schema
444            )));
445        }
446        let input_type = acc_args.exprs[0].data_type(acc_args.schema)?;
447        let DataType::Struct(fields) = input_type else {
448            return Err(datafusion_common::DataFusionError::Internal(format!(
449                "Expected a struct type for input, got: {:?}",
450                input_type
451            )));
452        };
453
454        let inner_accum = self.original_phy_expr.create_accumulator()?;
455        Ok(Box::new(MergeAccum::new(inner_accum, &fields)))
456    }
457
458    fn as_any(&self) -> &dyn std::any::Any {
459        self
460    }
461    fn name(&self) -> &str {
462        self.name.as_str()
463    }
464
465    fn is_nullable(&self) -> bool {
466        self.inner.is_nullable()
467    }
468
469    /// Notice here the `arg_types` is actually the `state_fields`'s data types,
470    /// so return fixed return type instead of using `arg_types` to determine the return type.
471    fn return_type(&self, _arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
472        // The return type is the same as the original aggregate function's return type.
473        let ret_type = self.inner.return_type(&self.original_input_types)?;
474        Ok(ret_type)
475    }
476    fn signature(&self) -> &Signature {
477        &self.merge_signature
478    }
479
480    /// Coerce types also do nothing, as optimzer should be able to already make struct types
481    fn coerce_types(&self, arg_types: &[DataType]) -> datafusion_common::Result<Vec<DataType>> {
482        // just check if the arg_types are only one and is struct array
483        if arg_types.len() != 1 || !matches!(arg_types.first(), Some(DataType::Struct(_))) {
484            return Err(datafusion_common::DataFusionError::Internal(format!(
485                "Expected one struct type as input, got: {:?}",
486                arg_types
487            )));
488        }
489        Ok(arg_types.to_vec())
490    }
491
492    /// Just return the original aggregate function's state fields.
493    fn state_fields(
494        &self,
495        _args: datafusion_expr::function::StateFieldsArgs,
496    ) -> datafusion_common::Result<Vec<FieldRef>> {
497        self.original_phy_expr.state_fields()
498    }
499}
500
501/// The merge accumulator, which modify `update_batch`'s behavior to accept one struct array which
502/// include the state fields of original aggregate function, and merge said states into original accumulator
503/// the output is the same as original aggregate function
504#[derive(Debug)]
505pub struct MergeAccum {
506    inner: Box<dyn Accumulator>,
507    state_fields: Fields,
508}
509
510impl MergeAccum {
511    pub fn new(inner: Box<dyn Accumulator>, state_fields: &Fields) -> Self {
512        Self {
513            inner,
514            state_fields: state_fields.clone(),
515        }
516    }
517}
518
519impl Accumulator for MergeAccum {
520    fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
521        self.inner.evaluate()
522    }
523
524    fn merge_batch(&mut self, states: &[arrow::array::ArrayRef]) -> datafusion_common::Result<()> {
525        self.inner.merge_batch(states)
526    }
527
528    fn update_batch(&mut self, values: &[arrow::array::ArrayRef]) -> datafusion_common::Result<()> {
529        let value = values.first().ok_or_else(|| {
530            datafusion_common::DataFusionError::Internal("No values provided for merge".to_string())
531        })?;
532        // The input values are states from other accumulators, so we merge them.
533        let struct_arr = value
534            .as_any()
535            .downcast_ref::<StructArray>()
536            .ok_or_else(|| {
537                datafusion_common::DataFusionError::Internal(format!(
538                    "Expected StructArray, got: {:?}",
539                    value.data_type()
540                ))
541            })?;
542        let fields = struct_arr.fields();
543        if fields != &self.state_fields {
544            return Err(datafusion_common::DataFusionError::Internal(format!(
545                "Expected state fields: {:?}, got: {:?}",
546                self.state_fields, fields
547            )));
548        }
549
550        // now fields should be the same, so we can merge the batch
551        // by pass the columns as order should be the same
552        let state_columns = struct_arr.columns();
553        self.inner.merge_batch(state_columns)
554    }
555
556    fn size(&self) -> usize {
557        self.inner.size()
558    }
559
560    fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
561        self.inner.state()
562    }
563}
564
// Unit tests for this module live in the separate `tests` submodule file.
#[cfg(test)]
mod tests;