common_query/logical_plan/
udaf.rs1use std::any::Any;
20use std::fmt::{self, Debug, Formatter};
21use std::sync::Arc;
22
23use datafusion::arrow::datatypes::Field;
24use datafusion_common::Result;
25use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
26use datafusion_expr::{
27 Accumulator, AccumulatorFactoryFunction, AggregateUDF as DfAggregateUdf, AggregateUDFImpl,
28};
29use datatypes::arrow::datatypes::DataType as ArrowDataType;
30use datatypes::data_type::DataType;
31
32use crate::function::{
33 to_df_return_type, AccumulatorFunctionImpl, ReturnTypeFunction, StateTypeFunction,
34};
35use crate::logical_plan::accumulator::DfAccumulatorAdaptor;
36use crate::logical_plan::AggregateFunctionCreatorRef;
37use crate::signature::Signature;
38
/// A user-defined aggregate function.
///
/// It can be turned into a DataFusion `AggregateUDF` via `From<AggregateFunction>`.
#[derive(Clone)]
pub struct AggregateFunction {
    /// Function name; forwarded unchanged to the DataFusion adapter.
    pub name: String,
    /// Accepted argument types; converted into a DataFusion signature with `.into()`.
    pub signature: Signature,
    /// Return-type callback; adapted for DataFusion through `to_df_return_type`.
    pub return_type: ReturnTypeFunction,
    /// Zero-argument factory invoked once per DataFusion accumulator request.
    /// Each result is wrapped in a `DfAccumulatorAdaptor`.
    pub accumulator: AccumulatorFunctionImpl,
    /// State-type callback.
    ///
    /// NOTE(review): the `From<AggregateFunction> for DfAggregateUdf` conversion does
    /// not read this field; DataFusion state fields come from `creator.state_types()` instead.
    pub state_type: StateTypeFunction,
    /// Creator handle. It is cloned into every accumulator adaptor and queried for the
    /// intermediate state types when DataFusion asks for state fields.
    creator: AggregateFunctionCreatorRef,
}
56
57impl Debug for AggregateFunction {
58 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
59 f.debug_struct("AggregateUDF")
60 .field("name", &self.name)
61 .field("signature", &self.signature)
62 .field("fun", &"<FUNC>")
63 .finish()
64 }
65}
66
67impl PartialEq for AggregateFunction {
68 fn eq(&self, other: &Self) -> bool {
69 self.name == other.name && self.signature == other.signature
70 }
71}
72
73impl AggregateFunction {
74 pub fn new(
76 name: String,
77 signature: Signature,
78 return_type: ReturnTypeFunction,
79 accumulator: AccumulatorFunctionImpl,
80 state_type: StateTypeFunction,
81 creator: AggregateFunctionCreatorRef,
82 ) -> Self {
83 Self {
84 name,
85 signature,
86 return_type,
87 accumulator,
88 state_type,
89 creator,
90 }
91 }
92}
93
/// Adapter exposing an [`AggregateFunction`] to DataFusion through [`AggregateUDFImpl`].
struct DfUdafAdapter {
    /// Function name reported to DataFusion.
    name: String,
    /// DataFusion signature converted from the crate's own [`Signature`].
    signature: datafusion_expr::Signature,
    /// Return-type resolver produced by `to_df_return_type`.
    return_type_func: datafusion_expr::ReturnTypeFunction,
    /// Factory built by `to_df_accumulator_func`.
    /// It wraps each crate accumulator in a `DfAccumulatorAdaptor`.
    accumulator: AccumulatorFactoryFunction,
    /// Creator queried for the intermediate state types in `state_fields`.
    creator: AggregateFunctionCreatorRef,
}
101
102impl Debug for DfUdafAdapter {
103 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
104 f.debug_struct("DfUdafAdapter")
105 .field("name", &self.name)
106 .field("signature", &self.signature)
107 .finish()
108 }
109}
110
111impl AggregateUDFImpl for DfUdafAdapter {
112 fn as_any(&self) -> &dyn Any {
113 self
114 }
115
116 fn name(&self) -> &str {
117 &self.name
118 }
119
120 fn signature(&self) -> &datafusion_expr::Signature {
121 &self.signature
122 }
123
124 fn return_type(&self, arg_types: &[ArrowDataType]) -> Result<ArrowDataType> {
125 (self.return_type_func)(arg_types).map(|x| x.as_ref().clone())
126 }
127
128 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
129 (self.accumulator)(acc_args)
130 }
131
132 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
133 let state_types = self.creator.state_types()?;
134 let fields = state_types
135 .into_iter()
136 .enumerate()
137 .map(|(i, t)| {
138 let name = format!("{}_{i}", args.name);
139 Field::new(name, t.as_arrow_type(), true)
140 })
141 .collect::<Vec<_>>();
142 Ok(fields)
143 }
144}
145
146impl From<AggregateFunction> for DfAggregateUdf {
147 fn from(udaf: AggregateFunction) -> Self {
148 DfAggregateUdf::new_from_impl(DfUdafAdapter {
149 name: udaf.name,
150 signature: udaf.signature.into(),
151 return_type_func: to_df_return_type(udaf.return_type),
152 accumulator: to_df_accumulator_func(udaf.accumulator, udaf.creator.clone()),
153 creator: udaf.creator,
154 })
155 }
156}
157
158fn to_df_accumulator_func(
159 accumulator: AccumulatorFunctionImpl,
160 creator: AggregateFunctionCreatorRef,
161) -> AccumulatorFactoryFunction {
162 Arc::new(move |_| {
163 let accumulator = accumulator()?;
164 let creator = creator.clone();
165 Ok(Box::new(DfAccumulatorAdaptor::new(accumulator, creator)) as _)
166 })
167}