1use std::sync::Arc;
26
27use arrow::array::StructArray;
28use arrow_schema::Fields;
29use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
30use datafusion::optimizer::AnalyzerRule;
31use datafusion::physical_planner::create_aggregate_expr_and_maybe_filter;
32use datafusion_common::{Column, ScalarValue};
33use datafusion_expr::expr::AggregateFunction;
34use datafusion_expr::function::StateFieldsArgs;
35use datafusion_expr::{
36 Accumulator, Aggregate, AggregateUDF, AggregateUDFImpl, Expr, ExprSchemable, LogicalPlan,
37 Signature,
38};
39use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
40use datatypes::arrow::datatypes::{DataType, Field};
41
/// Returns the name of the "state" variant of the aggregate function `aggr_name`.
///
/// The name has the form `__<aggr_name>_state`.
pub fn aggr_state_func_name(aggr_name: &str) -> String {
    ["__", aggr_name, "_state"].concat()
}
48
/// Returns the name of the "merge" variant of the aggregate function `aggr_name`.
///
/// The name has the form `__<aggr_name>_merge`.
pub fn aggr_merge_func_name(aggr_name: &str) -> String {
    ["__", aggr_name, "_merge"].concat()
}
55
/// Stateless namespace type: its associated function `split_aggr_node` rewrites
/// an aggregate plan into a lower "state" plan and an upper "merge" plan.
#[derive(Debug, Clone)]
pub struct StateMergeHelper;
62
/// The pair of plans returned by `StateMergeHelper::split_aggr_node`.
// NOTE(review): `allow(unused)` presumably silences warnings while nothing reads
// these fields yet — confirm, and drop the attribute once they are consumed.
#[allow(unused)]
#[derive(Debug, Clone)]
pub struct StepAggrPlan {
    /// Aggregate plan that reads from `lower_state` and merges the states it produces.
    pub upper_merge: Arc<LogicalPlan>,
    /// Aggregate plan that evaluates the state-producing aggregate expressions.
    pub lower_state: Arc<LogicalPlan>,
}
72
73pub fn get_aggr_func(expr: &Expr) -> Option<&datafusion_expr::expr::AggregateFunction> {
74 let mut expr_ref = expr;
75 while let Expr::Alias(alias) = expr_ref {
76 expr_ref = &alias.expr;
77 }
78 if let Expr::AggregateFunction(aggr_func) = expr_ref {
79 Some(aggr_func)
80 } else {
81 None
82 }
83}
84
85impl StateMergeHelper {
86 pub fn split_aggr_node(aggr_plan: Aggregate) -> datafusion_common::Result<StepAggrPlan> {
88 let aggr = {
89 let aggr_plan = TypeCoercion::new().analyze(
91 LogicalPlan::Aggregate(aggr_plan).clone(),
92 &Default::default(),
93 )?;
94 if let LogicalPlan::Aggregate(aggr) = aggr_plan {
95 aggr
96 } else {
97 return Err(datafusion_common::DataFusionError::Internal(format!(
98 "Failed to coerce expressions in aggregate plan, expected Aggregate, got: {:?}",
99 aggr_plan
100 )));
101 }
102 };
103 let mut lower_aggr_exprs = vec![];
104 let mut upper_aggr_exprs = vec![];
105
106 for aggr_expr in aggr.aggr_expr.iter() {
107 let Some(aggr_func) = get_aggr_func(aggr_expr) else {
108 return Err(datafusion_common::DataFusionError::NotImplemented(format!(
109 "Unsupported aggregate expression for step aggr optimize: {:?}",
110 aggr_expr
111 )));
112 };
113
114 let original_input_types = aggr_func
115 .args
116 .iter()
117 .map(|e| e.get_type(&aggr.input.schema()))
118 .collect::<Result<Vec<_>, _>>()?;
119
120 let state_func = StateWrapper::new((*aggr_func.func).clone())?;
122
123 let expr = AggregateFunction {
124 func: Arc::new(state_func.into()),
125 args: aggr_func.args.clone(),
126 distinct: aggr_func.distinct,
127 filter: aggr_func.filter.clone(),
128 order_by: aggr_func.order_by.clone(),
129 null_treatment: aggr_func.null_treatment,
130 };
131 let expr = Expr::AggregateFunction(expr);
132 let lower_state_output_col_name = expr.schema_name().to_string();
133
134 lower_aggr_exprs.push(expr);
135
136 let (original_phy_expr, _filter, _ordering) = create_aggregate_expr_and_maybe_filter(
137 aggr_expr,
138 aggr.input.schema(),
139 aggr.input.schema().as_arrow(),
140 &Default::default(),
141 )?;
142
143 let merge_func = MergeWrapper::new(
144 (*aggr_func.func).clone(),
145 original_phy_expr,
146 original_input_types,
147 )?;
148 let arg = Expr::Column(Column::new_unqualified(lower_state_output_col_name));
149 let expr = AggregateFunction {
150 func: Arc::new(merge_func.into()),
151 args: vec![arg],
152 distinct: aggr_func.distinct,
153 filter: aggr_func.filter.clone(),
154 order_by: aggr_func.order_by.clone(),
155 null_treatment: aggr_func.null_treatment,
156 };
157
158 let expr = Expr::AggregateFunction(expr).alias(aggr_expr.schema_name().to_string());
161 upper_aggr_exprs.push(expr);
162 }
163
164 let mut lower = aggr.clone();
165 lower.aggr_expr = lower_aggr_exprs;
166 let lower_plan = LogicalPlan::Aggregate(lower);
167
168 let lower_plan = Arc::new(lower_plan.recompute_schema()?);
170
171 let mut upper = aggr.clone();
172 let aggr_plan = LogicalPlan::Aggregate(aggr);
173 upper.aggr_expr = upper_aggr_exprs;
174 upper.input = lower_plan.clone();
175 let upper_check = upper.clone();
177 let upper_plan = Arc::new(LogicalPlan::Aggregate(upper_check).recompute_schema()?);
178 if *upper_plan.schema() != *aggr_plan.schema() {
179 return Err(datafusion_common::DataFusionError::Internal(format!(
180 "Upper aggregate plan's schema is not the same as the original aggregate plan's schema: \n[transformed]:{}\n[ original]{}",
181 upper_plan.schema(), aggr_plan.schema()
182 )));
183 }
184
185 Ok(StepAggrPlan {
186 lower_state: lower_plan,
187 upper_merge: upper_plan,
188 })
189 }
190}
191
192#[derive(Debug, Clone, PartialEq, Eq)]
194pub struct StateWrapper {
195 inner: AggregateUDF,
196 name: String,
197}
198
199impl StateWrapper {
200 pub fn new(inner: AggregateUDF) -> datafusion_common::Result<Self> {
202 let name = aggr_state_func_name(inner.name());
203 Ok(Self { inner, name })
204 }
205
206 pub fn inner(&self) -> &AggregateUDF {
207 &self.inner
208 }
209
210 pub fn deduce_aggr_return_type(
214 &self,
215 acc_args: &datafusion_expr::function::AccumulatorArgs,
216 ) -> datafusion_common::Result<DataType> {
217 let input_exprs = acc_args.exprs;
218 let input_schema = acc_args.schema;
219 let input_types = input_exprs
220 .iter()
221 .map(|e| e.data_type(input_schema))
222 .collect::<Result<Vec<_>, _>>()?;
223 let return_type = self.inner.return_type(&input_types)?;
224 Ok(return_type)
225 }
226}
227
228impl AggregateUDFImpl for StateWrapper {
229 fn accumulator<'a, 'b>(
230 &'a self,
231 acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
232 ) -> datafusion_common::Result<Box<dyn Accumulator>> {
233 let state_type = acc_args.return_type.clone();
235 let inner = {
236 let old_return_type = self.deduce_aggr_return_type(&acc_args)?;
237 let acc_args = datafusion_expr::function::AccumulatorArgs {
238 return_type: &old_return_type,
239 schema: acc_args.schema,
240 ignore_nulls: acc_args.ignore_nulls,
241 ordering_req: acc_args.ordering_req,
242 is_reversed: acc_args.is_reversed,
243 name: acc_args.name,
244 is_distinct: acc_args.is_distinct,
245 exprs: acc_args.exprs,
246 };
247 self.inner.accumulator(acc_args)?
248 };
249 Ok(Box::new(StateAccum::new(inner, state_type)?))
250 }
251
252 fn as_any(&self) -> &dyn std::any::Any {
253 self
254 }
255 fn name(&self) -> &str {
256 self.name.as_str()
257 }
258
259 fn is_nullable(&self) -> bool {
260 self.inner.is_nullable()
261 }
262
263 fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
266 let old_return_type = self.inner.return_type(arg_types)?;
267 let state_fields_args = StateFieldsArgs {
268 name: self.inner().name(),
269 input_types: arg_types,
270 return_type: &old_return_type,
271 ordering_fields: &[],
273 is_distinct: false,
274 };
275 let state_fields = self.inner.state_fields(state_fields_args)?;
276 let struct_field = DataType::Struct(state_fields.into());
277 Ok(struct_field)
278 }
279
280 fn state_fields(
282 &self,
283 args: datafusion_expr::function::StateFieldsArgs,
284 ) -> datafusion_common::Result<Vec<Field>> {
285 let old_return_type = self.inner.return_type(args.input_types)?;
286 let state_fields_args = StateFieldsArgs {
287 name: args.name,
288 input_types: args.input_types,
289 return_type: &old_return_type,
290 ordering_fields: args.ordering_fields,
291 is_distinct: args.is_distinct,
292 };
293 self.inner.state_fields(state_fields_args)
294 }
295
296 fn signature(&self) -> &Signature {
298 self.inner.signature()
299 }
300
301 fn coerce_types(&self, arg_types: &[DataType]) -> datafusion_common::Result<Vec<DataType>> {
303 self.inner.coerce_types(arg_types)
304 }
305}
306
307#[derive(Debug)]
310pub struct StateAccum {
311 inner: Box<dyn Accumulator>,
312 state_fields: Fields,
313}
314
315impl StateAccum {
316 pub fn new(
317 inner: Box<dyn Accumulator>,
318 state_type: DataType,
319 ) -> datafusion_common::Result<Self> {
320 let DataType::Struct(fields) = state_type else {
321 return Err(datafusion_common::DataFusionError::Internal(format!(
322 "Expected a struct type for state, got: {:?}",
323 state_type
324 )));
325 };
326 Ok(Self {
327 inner,
328 state_fields: fields,
329 })
330 }
331}
332
333impl Accumulator for StateAccum {
334 fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
335 let state = self.inner.state()?;
336
337 let array = state
338 .iter()
339 .map(|s| s.to_array())
340 .collect::<Result<Vec<_>, _>>()?;
341 let struct_array = StructArray::try_new(self.state_fields.clone(), array, None)?;
342 Ok(ScalarValue::Struct(Arc::new(struct_array)))
343 }
344
345 fn merge_batch(
346 &mut self,
347 states: &[datatypes::arrow::array::ArrayRef],
348 ) -> datafusion_common::Result<()> {
349 self.inner.merge_batch(states)
350 }
351
352 fn update_batch(
353 &mut self,
354 values: &[datatypes::arrow::array::ArrayRef],
355 ) -> datafusion_common::Result<()> {
356 self.inner.update_batch(values)
357 }
358
359 fn size(&self) -> usize {
360 self.inner.size()
361 }
362
363 fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
364 self.inner.state()
365 }
366}
367
368#[derive(Debug, Clone)]
373pub struct MergeWrapper {
374 inner: AggregateUDF,
375 name: String,
376 merge_signature: Signature,
377 original_phy_expr: Arc<AggregateFunctionExpr>,
379 original_input_types: Vec<DataType>,
380}
381impl MergeWrapper {
382 pub fn new(
383 inner: AggregateUDF,
384 original_phy_expr: Arc<AggregateFunctionExpr>,
385 original_input_types: Vec<DataType>,
386 ) -> datafusion_common::Result<Self> {
387 let name = aggr_merge_func_name(inner.name());
388 let merge_signature = Signature::user_defined(datafusion_expr::Volatility::Immutable);
390
391 Ok(Self {
392 inner,
393 name,
394 merge_signature,
395 original_phy_expr,
396 original_input_types,
397 })
398 }
399
400 pub fn inner(&self) -> &AggregateUDF {
401 &self.inner
402 }
403}
404
405impl AggregateUDFImpl for MergeWrapper {
406 fn accumulator<'a, 'b>(
407 &'a self,
408 acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
409 ) -> datafusion_common::Result<Box<dyn Accumulator>> {
410 if acc_args.schema.fields().len() != 1
411 || !matches!(acc_args.schema.field(0).data_type(), DataType::Struct(_))
412 {
413 return Err(datafusion_common::DataFusionError::Internal(format!(
414 "Expected one struct type as input, got: {:?}",
415 acc_args.schema
416 )));
417 }
418 let input_type = acc_args.schema.field(0).data_type();
419 let DataType::Struct(fields) = input_type else {
420 return Err(datafusion_common::DataFusionError::Internal(format!(
421 "Expected a struct type for input, got: {:?}",
422 input_type
423 )));
424 };
425
426 let inner_accum = self.original_phy_expr.create_accumulator()?;
427 Ok(Box::new(MergeAccum::new(inner_accum, fields)))
428 }
429
430 fn as_any(&self) -> &dyn std::any::Any {
431 self
432 }
433 fn name(&self) -> &str {
434 self.name.as_str()
435 }
436
437 fn is_nullable(&self) -> bool {
438 self.inner.is_nullable()
439 }
440
441 fn return_type(&self, _arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
444 let ret_type = self.inner.return_type(&self.original_input_types)?;
446 Ok(ret_type)
447 }
448 fn signature(&self) -> &Signature {
449 &self.merge_signature
450 }
451
452 fn coerce_types(&self, arg_types: &[DataType]) -> datafusion_common::Result<Vec<DataType>> {
454 if arg_types.len() != 1 || !matches!(arg_types.first(), Some(DataType::Struct(_))) {
456 return Err(datafusion_common::DataFusionError::Internal(format!(
457 "Expected one struct type as input, got: {:?}",
458 arg_types
459 )));
460 }
461 Ok(arg_types.to_vec())
462 }
463
464 fn state_fields(
466 &self,
467 _args: datafusion_expr::function::StateFieldsArgs,
468 ) -> datafusion_common::Result<Vec<Field>> {
469 self.original_phy_expr.state_fields()
470 }
471}
472
473#[derive(Debug)]
477pub struct MergeAccum {
478 inner: Box<dyn Accumulator>,
479 state_fields: Fields,
480}
481
482impl MergeAccum {
483 pub fn new(inner: Box<dyn Accumulator>, state_fields: &Fields) -> Self {
484 Self {
485 inner,
486 state_fields: state_fields.clone(),
487 }
488 }
489}
490
491impl Accumulator for MergeAccum {
492 fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
493 self.inner.evaluate()
494 }
495
496 fn merge_batch(&mut self, states: &[arrow::array::ArrayRef]) -> datafusion_common::Result<()> {
497 self.inner.merge_batch(states)
498 }
499
500 fn update_batch(&mut self, values: &[arrow::array::ArrayRef]) -> datafusion_common::Result<()> {
501 let value = values.first().ok_or_else(|| {
502 datafusion_common::DataFusionError::Internal("No values provided for merge".to_string())
503 })?;
504 let struct_arr = value
506 .as_any()
507 .downcast_ref::<StructArray>()
508 .ok_or_else(|| {
509 datafusion_common::DataFusionError::Internal(format!(
510 "Expected StructArray, got: {:?}",
511 value.data_type()
512 ))
513 })?;
514 let fields = struct_arr.fields();
515 if fields != &self.state_fields {
516 return Err(datafusion_common::DataFusionError::Internal(format!(
517 "Expected state fields: {:?}, got: {:?}",
518 self.state_fields, fields
519 )));
520 }
521
522 let state_columns = struct_arr.columns();
525 self.inner.merge_batch(state_columns)
526 }
527
528 fn size(&self) -> usize {
529 self.inner.size()
530 }
531
532 fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
533 self.inner.state()
534 }
535}
536
537#[cfg(test)]
538mod tests;