1use std::sync::Arc;
26
27use arrow::array::StructArray;
28use arrow_schema::{FieldRef, Fields};
29use common_telemetry::debug;
30use datafusion::functions_aggregate::all_default_aggregate_functions;
31use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
32use datafusion::optimizer::AnalyzerRule;
33use datafusion::physical_planner::create_aggregate_expr_and_maybe_filter;
34use datafusion_common::{Column, ScalarValue};
35use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
36use datafusion_expr::function::StateFieldsArgs;
37use datafusion_expr::{
38 Accumulator, Aggregate, AggregateUDF, AggregateUDFImpl, Expr, ExprSchemable, LogicalPlan,
39 Signature,
40};
41use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
42use datatypes::arrow::datatypes::{DataType, Field};
43
44use crate::function_registry::FunctionRegistry;
45
/// Builds the name of the `state` variant of the aggregate function `aggr_name`.
///
/// The result has the form `__<aggr_name>_state`.
pub fn aggr_state_func_name(aggr_name: &str) -> String {
    ["__", aggr_name, "_state"].concat()
}
52
/// Builds the name of the `merge` variant of the aggregate function `aggr_name`.
///
/// The result has the form `__<aggr_name>_merge`.
pub fn aggr_merge_func_name(aggr_name: &str) -> String {
    ["__", aggr_name, "_merge"].concat()
}
59
/// Namespace type for the state/merge helpers.
///
/// Its associated functions register the `state` wrappers of the supported aggregate
/// functions and split an aggregate plan into a lower `state` step and an upper `merge` step.
#[derive(Debug, Clone)]
pub struct StateMergeHelper;
66
/// The pair of aggregate plans produced by `StateMergeHelper::split_aggr_node`.
#[allow(unused)]
#[derive(Debug, Clone)]
pub struct StepAggrPlan {
    /// Upper aggregate plan: applies the `merge` functions and reads from `lower_state`.
    pub upper_merge: LogicalPlan,
    /// Lower aggregate plan: applies the `state` functions to the original input.
    pub lower_state: LogicalPlan,
}
76
77pub fn get_aggr_func(expr: &Expr) -> Option<&datafusion_expr::expr::AggregateFunction> {
78 let mut expr_ref = expr;
79 while let Expr::Alias(alias) = expr_ref {
80 expr_ref = &alias.expr;
81 }
82 if let Expr::AggregateFunction(aggr_func) = expr_ref {
83 Some(aggr_func)
84 } else {
85 None
86 }
87}
88
89impl StateMergeHelper {
90 pub fn register(registry: &FunctionRegistry) {
93 let all_default = all_default_aggregate_functions();
94 let greptime_custom_aggr_functions = registry.aggregate_functions();
95
96 let supported = all_default
98 .into_iter()
99 .chain(greptime_custom_aggr_functions.into_iter().map(Arc::new))
100 .collect::<Vec<_>>();
101 debug!(
102 "Registering state functions for supported: {:?}",
103 supported.iter().map(|f| f.name()).collect::<Vec<_>>()
104 );
105
106 let state_func = supported.into_iter().filter_map(|f| {
107 StateWrapper::new((*f).clone())
108 .inspect_err(
109 |e| common_telemetry::error!(e; "Failed to register state function for {:?}", f),
110 )
111 .ok()
112 .map(AggregateUDF::new_from_impl)
113 });
114
115 for func in state_func {
116 registry.register_aggr(func);
117 }
118 }
119
120 pub fn split_aggr_node(aggr_plan: Aggregate) -> datafusion_common::Result<StepAggrPlan> {
122 let aggr = {
123 let aggr_plan = TypeCoercion::new().analyze(
125 LogicalPlan::Aggregate(aggr_plan).clone(),
126 &Default::default(),
127 )?;
128 if let LogicalPlan::Aggregate(aggr) = aggr_plan {
129 aggr
130 } else {
131 return Err(datafusion_common::DataFusionError::Internal(format!(
132 "Failed to coerce expressions in aggregate plan, expected Aggregate, got: {:?}",
133 aggr_plan
134 )));
135 }
136 };
137 let mut lower_aggr_exprs = vec![];
138 let mut upper_aggr_exprs = vec![];
139
140 for aggr_expr in aggr.aggr_expr.iter() {
141 let Some(aggr_func) = get_aggr_func(aggr_expr) else {
142 return Err(datafusion_common::DataFusionError::NotImplemented(format!(
143 "Unsupported aggregate expression for step aggr optimize: {:?}",
144 aggr_expr
145 )));
146 };
147
148 let original_input_types = aggr_func
149 .params
150 .args
151 .iter()
152 .map(|e| e.get_type(&aggr.input.schema()))
153 .collect::<Result<Vec<_>, _>>()?;
154
155 let state_func = StateWrapper::new((*aggr_func.func).clone())?;
157
158 let expr = AggregateFunction {
159 func: Arc::new(state_func.into()),
160 params: aggr_func.params.clone(),
161 };
162 let expr = Expr::AggregateFunction(expr);
163 let lower_state_output_col_name = expr.schema_name().to_string();
164
165 lower_aggr_exprs.push(expr);
166
167 let (original_phy_expr, _filter, _ordering) = create_aggregate_expr_and_maybe_filter(
168 aggr_expr,
169 aggr.input.schema(),
170 aggr.input.schema().as_arrow(),
171 &Default::default(),
172 )?;
173
174 let merge_func = MergeWrapper::new(
175 (*aggr_func.func).clone(),
176 original_phy_expr,
177 original_input_types,
178 )?;
179 let arg = Expr::Column(Column::new_unqualified(lower_state_output_col_name));
180 let expr = AggregateFunction {
181 func: Arc::new(merge_func.into()),
182 params: AggregateFunctionParams {
183 args: vec![arg],
184 ..aggr_func.params.clone()
185 },
186 };
187
188 let expr = Expr::AggregateFunction(expr).alias(aggr_expr.schema_name().to_string());
191 upper_aggr_exprs.push(expr);
192 }
193
194 let mut lower = aggr.clone();
195 lower.aggr_expr = lower_aggr_exprs;
196 let lower_plan = LogicalPlan::Aggregate(lower);
197
198 let lower_plan = lower_plan.recompute_schema()?;
200
201 let mut upper = aggr.clone();
202 let aggr_plan = LogicalPlan::Aggregate(aggr);
203 upper.aggr_expr = upper_aggr_exprs;
204 upper.input = Arc::new(lower_plan.clone());
205 let upper_check = upper;
207 let upper_plan = LogicalPlan::Aggregate(upper_check).recompute_schema()?;
208 if *upper_plan.schema() != *aggr_plan.schema() {
209 return Err(datafusion_common::DataFusionError::Internal(format!(
210 "Upper aggregate plan's schema is not the same as the original aggregate plan's schema: \n[transformed]:{}\n[original]:{}",
211 upper_plan.schema(), aggr_plan.schema()
212 )));
213 }
214
215 Ok(StepAggrPlan {
216 lower_state: lower_plan,
217 upper_merge: upper_plan,
218 })
219 }
220}
221
/// Aggregate UDF that wraps another aggregate function and exposes its intermediate state.
///
/// Its return type is a struct built from the wrapped function's state fields.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StateWrapper {
    /// The wrapped aggregate function.
    inner: AggregateUDF,
    /// Name of this wrapper, derived from the wrapped function's name.
    name: String,
}
228
229impl StateWrapper {
230 pub fn new(inner: AggregateUDF) -> datafusion_common::Result<Self> {
232 let name = aggr_state_func_name(inner.name());
233 Ok(Self { inner, name })
234 }
235
236 pub fn inner(&self) -> &AggregateUDF {
237 &self.inner
238 }
239
240 pub fn deduce_aggr_return_type(
244 &self,
245 acc_args: &datafusion_expr::function::AccumulatorArgs,
246 ) -> datafusion_common::Result<FieldRef> {
247 self.inner.return_field(acc_args.schema.fields())
248 }
249}
250
251impl AggregateUDFImpl for StateWrapper {
252 fn accumulator<'a, 'b>(
253 &'a self,
254 acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
255 ) -> datafusion_common::Result<Box<dyn Accumulator>> {
256 let state_type = acc_args.return_type().clone();
258 let inner = {
259 let acc_args = datafusion_expr::function::AccumulatorArgs {
260 return_field: self.deduce_aggr_return_type(&acc_args)?,
261 schema: acc_args.schema,
262 ignore_nulls: acc_args.ignore_nulls,
263 order_bys: acc_args.order_bys,
264 is_reversed: acc_args.is_reversed,
265 name: acc_args.name,
266 is_distinct: acc_args.is_distinct,
267 exprs: acc_args.exprs,
268 };
269 self.inner.accumulator(acc_args)?
270 };
271 Ok(Box::new(StateAccum::new(inner, state_type)?))
272 }
273
274 fn as_any(&self) -> &dyn std::any::Any {
275 self
276 }
277 fn name(&self) -> &str {
278 self.name.as_str()
279 }
280
281 fn is_nullable(&self) -> bool {
282 self.inner.is_nullable()
283 }
284
285 fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
288 let input_fields = &arg_types
289 .iter()
290 .map(|x| Arc::new(Field::new("x", x.clone(), false)))
291 .collect::<Vec<_>>();
292
293 let state_fields_args = StateFieldsArgs {
294 name: self.inner().name(),
295 input_fields,
296 return_field: self.inner.return_field(input_fields)?,
297 ordering_fields: &[],
299 is_distinct: false,
300 };
301 let state_fields = self.inner.state_fields(state_fields_args)?;
302 let struct_field = DataType::Struct(state_fields.into());
303 Ok(struct_field)
304 }
305
306 fn state_fields(
308 &self,
309 args: datafusion_expr::function::StateFieldsArgs,
310 ) -> datafusion_common::Result<Vec<FieldRef>> {
311 let state_fields_args = StateFieldsArgs {
312 name: args.name,
313 input_fields: args.input_fields,
314 return_field: self.inner.return_field(args.input_fields)?,
315 ordering_fields: args.ordering_fields,
316 is_distinct: args.is_distinct,
317 };
318 self.inner.state_fields(state_fields_args)
319 }
320
321 fn signature(&self) -> &Signature {
323 self.inner.signature()
324 }
325
326 fn coerce_types(&self, arg_types: &[DataType]) -> datafusion_common::Result<Vec<DataType>> {
328 self.inner.coerce_types(arg_types)
329 }
330}
331
/// Accumulator that delegates to a wrapped accumulator and evaluates to the wrapped
/// accumulator's state packed into a single struct value.
#[derive(Debug)]
pub struct StateAccum {
    /// The wrapped accumulator.
    inner: Box<dyn Accumulator>,
    /// Fields of the struct value produced by `evaluate`.
    state_fields: Fields,
}
339
340impl StateAccum {
341 pub fn new(
342 inner: Box<dyn Accumulator>,
343 state_type: DataType,
344 ) -> datafusion_common::Result<Self> {
345 let DataType::Struct(fields) = state_type else {
346 return Err(datafusion_common::DataFusionError::Internal(format!(
347 "Expected a struct type for state, got: {:?}",
348 state_type
349 )));
350 };
351 Ok(Self {
352 inner,
353 state_fields: fields,
354 })
355 }
356}
357
358impl Accumulator for StateAccum {
359 fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
360 let state = self.inner.state()?;
361
362 let array = state
363 .iter()
364 .map(|s| s.to_array())
365 .collect::<Result<Vec<_>, _>>()?;
366 let struct_array = StructArray::try_new(self.state_fields.clone(), array, None)?;
367 Ok(ScalarValue::Struct(Arc::new(struct_array)))
368 }
369
370 fn merge_batch(
371 &mut self,
372 states: &[datatypes::arrow::array::ArrayRef],
373 ) -> datafusion_common::Result<()> {
374 self.inner.merge_batch(states)
375 }
376
377 fn update_batch(
378 &mut self,
379 values: &[datatypes::arrow::array::ArrayRef],
380 ) -> datafusion_common::Result<()> {
381 self.inner.update_batch(values)
382 }
383
384 fn size(&self) -> usize {
385 self.inner.size()
386 }
387
388 fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
389 self.inner.state()
390 }
391}
392
/// Aggregate UDF that merges the struct-typed states of another aggregate function and
/// returns that function's return type.
#[derive(Debug, Clone)]
pub struct MergeWrapper {
    /// The wrapped (original) aggregate function.
    inner: AggregateUDF,
    /// Name of this wrapper, derived from the wrapped function's name.
    name: String,
    /// Signature reported by this wrapper: user-defined with immutable volatility.
    merge_signature: Signature,
    /// Physical expression of the original aggregate; used to create the accumulator and to
    /// report the state fields.
    original_phy_expr: Arc<AggregateFunctionExpr>,
    /// Input types of the original aggregate; used to compute the return type.
    original_input_types: Vec<DataType>,
}
406impl MergeWrapper {
407 pub fn new(
408 inner: AggregateUDF,
409 original_phy_expr: Arc<AggregateFunctionExpr>,
410 original_input_types: Vec<DataType>,
411 ) -> datafusion_common::Result<Self> {
412 let name = aggr_merge_func_name(inner.name());
413 let merge_signature = Signature::user_defined(datafusion_expr::Volatility::Immutable);
415
416 Ok(Self {
417 inner,
418 name,
419 merge_signature,
420 original_phy_expr,
421 original_input_types,
422 })
423 }
424
425 pub fn inner(&self) -> &AggregateUDF {
426 &self.inner
427 }
428}
429
430impl AggregateUDFImpl for MergeWrapper {
431 fn accumulator<'a, 'b>(
432 &'a self,
433 acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
434 ) -> datafusion_common::Result<Box<dyn Accumulator>> {
435 if acc_args.exprs.len() != 1
436 || !matches!(
437 acc_args.exprs[0].data_type(acc_args.schema)?,
438 DataType::Struct(_)
439 )
440 {
441 return Err(datafusion_common::DataFusionError::Internal(format!(
442 "Expected one struct type as input, got: {:?}",
443 acc_args.schema
444 )));
445 }
446 let input_type = acc_args.exprs[0].data_type(acc_args.schema)?;
447 let DataType::Struct(fields) = input_type else {
448 return Err(datafusion_common::DataFusionError::Internal(format!(
449 "Expected a struct type for input, got: {:?}",
450 input_type
451 )));
452 };
453
454 let inner_accum = self.original_phy_expr.create_accumulator()?;
455 Ok(Box::new(MergeAccum::new(inner_accum, &fields)))
456 }
457
458 fn as_any(&self) -> &dyn std::any::Any {
459 self
460 }
461 fn name(&self) -> &str {
462 self.name.as_str()
463 }
464
465 fn is_nullable(&self) -> bool {
466 self.inner.is_nullable()
467 }
468
469 fn return_type(&self, _arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
472 let ret_type = self.inner.return_type(&self.original_input_types)?;
474 Ok(ret_type)
475 }
476 fn signature(&self) -> &Signature {
477 &self.merge_signature
478 }
479
480 fn coerce_types(&self, arg_types: &[DataType]) -> datafusion_common::Result<Vec<DataType>> {
482 if arg_types.len() != 1 || !matches!(arg_types.first(), Some(DataType::Struct(_))) {
484 return Err(datafusion_common::DataFusionError::Internal(format!(
485 "Expected one struct type as input, got: {:?}",
486 arg_types
487 )));
488 }
489 Ok(arg_types.to_vec())
490 }
491
492 fn state_fields(
494 &self,
495 _args: datafusion_expr::function::StateFieldsArgs,
496 ) -> datafusion_common::Result<Vec<FieldRef>> {
497 self.original_phy_expr.state_fields()
498 }
499}
500
/// Accumulator that feeds struct-typed state batches into a wrapped accumulator through
/// `merge_batch`.
#[derive(Debug)]
pub struct MergeAccum {
    /// The wrapped accumulator.
    inner: Box<dyn Accumulator>,
    /// Fields every incoming struct array is expected to have.
    state_fields: Fields,
}
509
510impl MergeAccum {
511 pub fn new(inner: Box<dyn Accumulator>, state_fields: &Fields) -> Self {
512 Self {
513 inner,
514 state_fields: state_fields.clone(),
515 }
516 }
517}
518
519impl Accumulator for MergeAccum {
520 fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
521 self.inner.evaluate()
522 }
523
524 fn merge_batch(&mut self, states: &[arrow::array::ArrayRef]) -> datafusion_common::Result<()> {
525 self.inner.merge_batch(states)
526 }
527
528 fn update_batch(&mut self, values: &[arrow::array::ArrayRef]) -> datafusion_common::Result<()> {
529 let value = values.first().ok_or_else(|| {
530 datafusion_common::DataFusionError::Internal("No values provided for merge".to_string())
531 })?;
532 let struct_arr = value
534 .as_any()
535 .downcast_ref::<StructArray>()
536 .ok_or_else(|| {
537 datafusion_common::DataFusionError::Internal(format!(
538 "Expected StructArray, got: {:?}",
539 value.data_type()
540 ))
541 })?;
542 let fields = struct_arr.fields();
543 if fields != &self.state_fields {
544 return Err(datafusion_common::DataFusionError::Internal(format!(
545 "Expected state fields: {:?}, got: {:?}",
546 self.state_fields, fields
547 )));
548 }
549
550 let state_columns = struct_arr.columns();
553 self.inner.merge_batch(state_columns)
554 }
555
556 fn size(&self) -> usize {
557 self.inner.size()
558 }
559
560 fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
561 self.inner.state()
562 }
563}
564
565#[cfg(test)]
566mod tests;