common_query/logical_plan/
udaf.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Udaf module contains functions and structs supporting user-defined aggregate functions.
16//!
17//! Modified from DataFusion.
18
19use std::any::Any;
20use std::fmt::{self, Debug, Formatter};
21use std::sync::Arc;
22
23use datafusion::arrow::datatypes::Field;
24use datafusion_common::Result;
25use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
26use datafusion_expr::{
27    Accumulator, AccumulatorFactoryFunction, AggregateUDF as DfAggregateUdf, AggregateUDFImpl,
28};
29use datatypes::arrow::datatypes::DataType as ArrowDataType;
30use datatypes::data_type::DataType;
31
32use crate::function::{
33    to_df_return_type, AccumulatorFunctionImpl, ReturnTypeFunction, StateTypeFunction,
34};
35use crate::logical_plan::accumulator::DfAccumulatorAdaptor;
36use crate::logical_plan::AggregateFunctionCreatorRef;
37use crate::signature::Signature;
38
39/// Logical representation of a user-defined aggregate function (UDAF)
40/// A UDAF is different from a UDF in that it is stateful across batches.
41#[derive(Clone)]
42pub struct AggregateFunction {
43    /// name
44    pub name: String,
45    /// signature
46    pub signature: Signature,
47    /// Return type
48    pub return_type: ReturnTypeFunction,
49    /// actual implementation
50    pub accumulator: AccumulatorFunctionImpl,
51    /// the accumulator's state's description as a function of the return type
52    pub state_type: StateTypeFunction,
53    /// the creator that creates aggregate functions
54    creator: AggregateFunctionCreatorRef,
55}
56
57impl Debug for AggregateFunction {
58    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
59        f.debug_struct("AggregateUDF")
60            .field("name", &self.name)
61            .field("signature", &self.signature)
62            .field("fun", &"<FUNC>")
63            .finish()
64    }
65}
66
67impl PartialEq for AggregateFunction {
68    fn eq(&self, other: &Self) -> bool {
69        self.name == other.name && self.signature == other.signature
70    }
71}
72
73impl AggregateFunction {
74    /// Create a new AggregateUDF
75    pub fn new(
76        name: String,
77        signature: Signature,
78        return_type: ReturnTypeFunction,
79        accumulator: AccumulatorFunctionImpl,
80        state_type: StateTypeFunction,
81        creator: AggregateFunctionCreatorRef,
82    ) -> Self {
83        Self {
84            name,
85            signature,
86            return_type,
87            accumulator,
88            state_type,
89            creator,
90        }
91    }
92}
93
94struct DfUdafAdapter {
95    name: String,
96    signature: datafusion_expr::Signature,
97    return_type_func: datafusion_expr::ReturnTypeFunction,
98    accumulator: AccumulatorFactoryFunction,
99    creator: AggregateFunctionCreatorRef,
100}
101
102impl Debug for DfUdafAdapter {
103    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
104        f.debug_struct("DfUdafAdapter")
105            .field("name", &self.name)
106            .field("signature", &self.signature)
107            .finish()
108    }
109}
110
111impl AggregateUDFImpl for DfUdafAdapter {
112    fn as_any(&self) -> &dyn Any {
113        self
114    }
115
116    fn name(&self) -> &str {
117        &self.name
118    }
119
120    fn signature(&self) -> &datafusion_expr::Signature {
121        &self.signature
122    }
123
124    fn return_type(&self, arg_types: &[ArrowDataType]) -> Result<ArrowDataType> {
125        (self.return_type_func)(arg_types).map(|x| x.as_ref().clone())
126    }
127
128    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
129        (self.accumulator)(acc_args)
130    }
131
132    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
133        let state_types = self.creator.state_types()?;
134        let fields = state_types
135            .into_iter()
136            .enumerate()
137            .map(|(i, t)| {
138                let name = format!("{}_{i}", args.name);
139                Field::new(name, t.as_arrow_type(), true)
140            })
141            .collect::<Vec<_>>();
142        Ok(fields)
143    }
144}
145
146impl From<AggregateFunction> for DfAggregateUdf {
147    fn from(udaf: AggregateFunction) -> Self {
148        DfAggregateUdf::new_from_impl(DfUdafAdapter {
149            name: udaf.name,
150            signature: udaf.signature.into(),
151            return_type_func: to_df_return_type(udaf.return_type),
152            accumulator: to_df_accumulator_func(udaf.accumulator, udaf.creator.clone()),
153            creator: udaf.creator,
154        })
155    }
156}
157
158fn to_df_accumulator_func(
159    accumulator: AccumulatorFunctionImpl,
160    creator: AggregateFunctionCreatorRef,
161) -> AccumulatorFactoryFunction {
162    Arc::new(move |_| {
163        let accumulator = accumulator()?;
164        let creator = creator.clone();
165        Ok(Box::new(DfAccumulatorAdaptor::new(accumulator, creator)) as _)
166    })
167}