1use std::hash::{Hash, Hasher};
26use std::sync::Arc;
27
28use arrow::array::StructArray;
29use arrow_schema::{FieldRef, Fields};
30use common_telemetry::debug;
31use datafusion::functions_aggregate::all_default_aggregate_functions;
32use datafusion::functions_aggregate::count::Count;
33use datafusion::functions_aggregate::min_max::{Max, Min};
34use datafusion::optimizer::AnalyzerRule;
35use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
36use datafusion::physical_planner::create_aggregate_expr_and_maybe_filter;
37use datafusion_common::{Column, ScalarValue};
38use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
39use datafusion_expr::function::StateFieldsArgs;
40use datafusion_expr::{
41 Accumulator, Aggregate, AggregateUDF, AggregateUDFImpl, Expr, ExprSchemable, LogicalPlan,
42 Signature,
43};
44use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
45use datatypes::arrow::datatypes::{DataType, Field};
46
47use crate::aggrs::aggr_wrapper::fix_order::FixStateUdafOrderingAnalyzer;
48use crate::function_registry::{FUNCTION_REGISTRY, FunctionRegistry};
49
50pub mod fix_order;
51#[cfg(test)]
52mod tests;
53
/// Name under which the "state" wrapper of aggregate `aggr_name` is registered,
/// e.g. `sum` -> `__sum_state`.
pub fn aggr_state_func_name(aggr_name: &str) -> String {
    ["__", aggr_name, "_state"].concat()
}
60
/// Name under which the "merge" wrapper of aggregate `aggr_name` is registered,
/// e.g. `sum` -> `__sum_merge`.
pub fn aggr_merge_func_name(aggr_name: &str) -> String {
    ["__", aggr_name, "_merge"].concat()
}
67
68pub fn is_all_aggr_exprs_steppable(aggr_exprs: &[Expr]) -> bool {
73 aggr_exprs.iter().all(|expr| {
74 if let Some(aggr_func) = get_aggr_func(expr) {
75 if aggr_func.params.distinct {
76 return false;
79 }
80
81 FUNCTION_REGISTRY.is_aggr_func_exist(&aggr_state_func_name(aggr_func.func.name()))
83 } else {
84 false
85 }
86 })
87}
88
89pub fn get_aggr_func(expr: &Expr) -> Option<&datafusion_expr::expr::AggregateFunction> {
90 let mut expr_ref = expr;
91 while let Expr::Alias(alias) = expr_ref {
92 expr_ref = &alias.expr;
93 }
94 if let Expr::AggregateFunction(aggr_func) = expr_ref {
95 Some(aggr_func)
96 } else {
97 None
98 }
99}
100
/// Helper for splitting an aggregate plan into a lower "state" aggregate and an
/// upper "merge" aggregate, and for registering the state wrapper UDAFs.
#[derive(Debug, Clone)]
pub struct StateMergeHelper;
107
/// The two halves produced by `StateMergeHelper::split_aggr_node`.
// NOTE(review): the reason for `allow(unused)` is not visible here (both fields
// are `pub`); confirm whether it is still needed.
#[allow(unused)]
#[derive(Debug, Clone)]
pub struct StepAggrPlan {
    /// Upper aggregate: applies merge wrappers to the lower plan's state columns
    /// and produces the original aggregate's output schema.
    pub upper_merge: LogicalPlan,
    /// Lower aggregate: evaluates each aggregate through its state wrapper, so it
    /// outputs intermediate state instead of final values.
    pub lower_state: LogicalPlan,
}
117
118impl StateMergeHelper {
119 pub fn register(registry: &FunctionRegistry) {
122 let all_default = all_default_aggregate_functions();
123 let greptime_custom_aggr_functions = registry.aggregate_functions();
124
125 let supported = all_default
127 .into_iter()
128 .chain(greptime_custom_aggr_functions.into_iter().map(Arc::new))
129 .collect::<Vec<_>>();
130 debug!(
131 "Registering state functions for supported: {:?}",
132 supported.iter().map(|f| f.name()).collect::<Vec<_>>()
133 );
134
135 let state_func = supported.into_iter().filter_map(|f| {
136 StateWrapper::new((*f).clone())
137 .inspect_err(
138 |e| common_telemetry::error!(e; "Failed to register state function for {:?}", f),
139 )
140 .ok()
141 .map(AggregateUDF::new_from_impl)
142 });
143
144 for func in state_func {
145 registry.register_aggr(func);
146 }
147 }
148
149 pub fn split_aggr_node(aggr_plan: Aggregate) -> datafusion_common::Result<StepAggrPlan> {
152 let aggr = {
153 let aggr_plan = TypeCoercion::new().analyze(
155 LogicalPlan::Aggregate(aggr_plan).clone(),
156 &Default::default(),
157 )?;
158 if let LogicalPlan::Aggregate(aggr) = aggr_plan {
159 aggr
160 } else {
161 return Err(datafusion_common::DataFusionError::Internal(format!(
162 "Failed to coerce expressions in aggregate plan, expected Aggregate, got: {:?}",
163 aggr_plan
164 )));
165 }
166 };
167 let mut lower_aggr_exprs = vec![];
168 let mut upper_aggr_exprs = vec![];
169
170 let upper_group_exprs = aggr
173 .group_expr
174 .iter()
175 .map(|c| c.qualified_name())
176 .map(|(r, c)| Expr::Column(Column::new(r, c)))
177 .collect();
178
179 for aggr_expr in aggr.aggr_expr.iter() {
180 let Some(aggr_func) = get_aggr_func(aggr_expr) else {
181 return Err(datafusion_common::DataFusionError::NotImplemented(format!(
182 "Unsupported aggregate expression for step aggr optimize: {:?}",
183 aggr_expr
184 )));
185 };
186
187 let original_input_fields = aggr_func
188 .params
189 .args
190 .iter()
191 .map(|e| e.to_field(&aggr.input.schema()).map(|(_, field)| field))
192 .collect::<Result<Vec<_>, _>>()?;
193
194 let state_func = StateWrapper::new((*aggr_func.func).clone())?;
196
197 let expr = AggregateFunction {
198 func: Arc::new(state_func.into()),
199 params: aggr_func.params.clone(),
200 };
201 let expr = Expr::AggregateFunction(expr);
202 let lower_state_output_col_name = expr.schema_name().to_string();
203
204 lower_aggr_exprs.push(expr);
205
206 let (original_phy_expr, _filter, _ordering) = create_aggregate_expr_and_maybe_filter(
208 aggr_expr,
209 aggr.input.schema(),
210 aggr.input.schema().as_arrow(),
211 &Default::default(),
212 )?;
213
214 let merge_func = MergeWrapper::new(
215 (*aggr_func.func).clone(),
216 original_phy_expr,
217 original_input_fields,
218 )?;
219 let arg = Expr::Column(Column::new_unqualified(lower_state_output_col_name));
220 let expr = AggregateFunction {
221 func: Arc::new(merge_func.into()),
222 params: AggregateFunctionParams {
226 args: vec![arg],
227 distinct: aggr_func.params.distinct,
228 filter: None,
229 order_by: vec![],
230 null_treatment: aggr_func.params.null_treatment,
231 },
232 };
233
234 let expr = Expr::AggregateFunction(expr).alias(aggr_expr.schema_name().to_string());
237 upper_aggr_exprs.push(expr);
238 }
239
240 let mut lower = aggr.clone();
241 lower.aggr_expr = lower_aggr_exprs;
242 let lower_plan = LogicalPlan::Aggregate(lower);
243
244 let lower_plan = lower_plan.recompute_schema()?;
246
247 let fixed_lower_plan =
250 FixStateUdafOrderingAnalyzer.analyze(lower_plan, &Default::default())?;
251
252 let upper = Aggregate::try_new(
253 Arc::new(fixed_lower_plan.clone()),
254 upper_group_exprs,
255 upper_aggr_exprs.clone(),
256 )?;
257 let aggr_plan = LogicalPlan::Aggregate(aggr);
258
259 let upper_check = upper;
261 let upper_plan = LogicalPlan::Aggregate(upper_check).recompute_schema()?;
262 if *upper_plan.schema() != *aggr_plan.schema() {
263 return Err(datafusion_common::DataFusionError::Internal(format!(
264 "Upper aggregate plan's schema is not the same as the original aggregate plan's schema: \n[transformed]:{}\n[original]:{}",
265 upper_plan.schema(),
266 aggr_plan.schema()
267 )));
268 }
269
270 Ok(StepAggrPlan {
271 lower_state: fixed_lower_plan,
272 upper_merge: upper_plan,
273 })
274 }
275}
276
/// UDAF wrapper for the lower "state" step.
///
/// It runs the inner aggregate, but evaluates to the inner aggregate's
/// intermediate state packed into a struct instead of its final value.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct StateWrapper {
    /// The wrapped aggregate function.
    inner: AggregateUDF,
    /// Registered name; see `aggr_state_func_name`.
    name: String,
    /// Ordering fields passed to the inner function when `return_type` computes
    /// its state fields.
    ordering: Vec<FieldRef>,
    /// Distinct flag passed to the inner function when `return_type` computes
    /// its state fields.
    distinct: bool,
}
287
288impl StateWrapper {
289 pub fn new(inner: AggregateUDF) -> datafusion_common::Result<Self> {
291 let name = aggr_state_func_name(inner.name());
292 Ok(Self {
293 inner,
294 name,
295 ordering: vec![],
296 distinct: false,
297 })
298 }
299
300 pub fn inner(&self) -> &AggregateUDF {
301 &self.inner
302 }
303
304 pub fn deduce_aggr_return_type(
308 &self,
309 acc_args: &datafusion_expr::function::AccumulatorArgs,
310 ) -> datafusion_common::Result<FieldRef> {
311 let input_fields = acc_args
312 .exprs
313 .iter()
314 .map(|e| e.return_field(acc_args.schema))
315 .collect::<Result<Vec<_>, _>>()?;
316 self.inner.return_field(&input_fields).inspect_err(|e| {
317 common_telemetry::error!(
318 "StateWrapper: {:#?}\nacc_args:{:?}\nerror:{:?}",
319 &self,
320 &acc_args,
321 e
322 );
323 })
324 }
325}
326
327impl AggregateUDFImpl for StateWrapper {
328 fn accumulator<'a, 'b>(
329 &'a self,
330 acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
331 ) -> datafusion_common::Result<Box<dyn Accumulator>> {
332 let state_type = acc_args.return_type().clone();
334 let inner = {
335 let mut new_acc_args = acc_args.clone();
336 new_acc_args.return_field = self.deduce_aggr_return_type(&acc_args)?;
337 self.inner.accumulator(new_acc_args)?
338 };
339
340 Ok(Box::new(StateAccum::new(inner, state_type)?))
341 }
342
343 fn as_any(&self) -> &dyn std::any::Any {
344 self
345 }
346 fn name(&self) -> &str {
347 self.name.as_str()
348 }
349
350 fn is_nullable(&self) -> bool {
351 self.inner.is_nullable()
352 }
353
354 fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
357 let input_fields = &arg_types
358 .iter()
359 .map(|x| Arc::new(Field::new("x", x.clone(), false)))
360 .collect::<Vec<_>>();
361
362 let state_fields_args = StateFieldsArgs {
363 name: self.inner().name(),
364 input_fields,
365 return_field: self.inner.return_field(input_fields)?,
366 ordering_fields: &self.ordering,
368 is_distinct: self.distinct,
369 };
370 let state_fields = self.inner.state_fields(state_fields_args)?;
371
372 let state_fields = state_fields
373 .into_iter()
374 .map(|f| {
375 let mut f = f.as_ref().clone();
376 f.set_nullable(true);
378 Arc::new(f)
379 })
380 .collect::<Vec<_>>();
381
382 let struct_field = DataType::Struct(state_fields.into());
383 Ok(struct_field)
384 }
385
386 fn state_fields(
388 &self,
389 args: datafusion_expr::function::StateFieldsArgs,
390 ) -> datafusion_common::Result<Vec<FieldRef>> {
391 let state_fields_args = StateFieldsArgs {
392 name: args.name,
393 input_fields: args.input_fields,
394 return_field: self.inner.return_field(args.input_fields)?,
395 ordering_fields: args.ordering_fields,
396 is_distinct: args.is_distinct,
397 };
398 self.inner.state_fields(state_fields_args)
399 }
400
401 fn signature(&self) -> &Signature {
403 self.inner.signature()
404 }
405
406 fn coerce_types(&self, arg_types: &[DataType]) -> datafusion_common::Result<Vec<DataType>> {
408 self.inner.coerce_types(arg_types)
409 }
410
411 fn value_from_stats(
412 &self,
413 statistics_args: &datafusion_expr::StatisticsArgs,
414 ) -> Option<ScalarValue> {
415 let inner = self.inner().inner().as_any();
416 let can_use_stat = inner.is::<Count>() || inner.is::<Max>() || inner.is::<Min>();
419 if !can_use_stat {
420 return None;
421 }
422
423 let state_type = if let DataType::Struct(fields) = &statistics_args.return_type {
425 if fields.is_empty() {
426 return None;
427 }
428 fields[0].data_type().clone()
429 } else {
430 return None;
431 };
432
433 let fixed_args = datafusion_expr::StatisticsArgs {
434 statistics: statistics_args.statistics,
435 return_type: &state_type,
436 is_distinct: statistics_args.is_distinct,
437 exprs: statistics_args.exprs,
438 };
439
440 let ret = self.inner().value_from_stats(&fixed_args)?;
441
442 let fields = if let DataType::Struct(fields) = &statistics_args.return_type {
444 fields
445 } else {
446 return None;
447 };
448
449 let array = ret.to_array().ok()?;
450
451 let struct_array = StructArray::new(fields.clone(), vec![array], None);
452 let ret = ScalarValue::Struct(Arc::new(struct_array));
453 Some(ret)
454 }
455}
456
/// Accumulator used by `StateWrapper`.
///
/// It drives the inner aggregate's accumulator, but `evaluate` returns the
/// inner state packed into one struct value.
#[derive(Debug)]
pub struct StateAccum {
    /// Accumulator of the wrapped (inner) aggregate function.
    inner: Box<dyn Accumulator>,
    /// Expected fields of the struct produced by `evaluate`.
    state_fields: Fields,
}
464
465impl StateAccum {
466 pub fn new(
467 inner: Box<dyn Accumulator>,
468 state_type: DataType,
469 ) -> datafusion_common::Result<Self> {
470 let DataType::Struct(fields) = state_type else {
471 return Err(datafusion_common::DataFusionError::Internal(format!(
472 "Expected a struct type for state, got: {:?}",
473 state_type
474 )));
475 };
476 Ok(Self {
477 inner,
478 state_fields: fields,
479 })
480 }
481}
482
483impl Accumulator for StateAccum {
484 fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
485 let state = self.inner.state()?;
486
487 let array = state
488 .iter()
489 .map(|s| s.to_array())
490 .collect::<Result<Vec<_>, _>>()?;
491 let array_type = array
492 .iter()
493 .map(|a| a.data_type().clone())
494 .collect::<Vec<_>>();
495 let expected_type: Vec<_> = self
496 .state_fields
497 .iter()
498 .map(|f| f.data_type().clone())
499 .collect();
500 if array_type != expected_type {
501 debug!(
502 "State mismatch, expected: {}, got: {} for expected fields: {:?} and given array types: {:?}",
503 self.state_fields.len(),
504 array.len(),
505 self.state_fields,
506 array_type,
507 );
508 let guess_schema = array
509 .iter()
510 .enumerate()
511 .map(|(index, array)| {
512 Field::new(
513 format!("col_{index}[mismatch_state]").as_str(),
514 array.data_type().clone(),
515 true,
516 )
517 })
518 .collect::<Fields>();
519 let arr = StructArray::try_new(guess_schema, array, None)?;
520
521 return Ok(ScalarValue::Struct(Arc::new(arr)));
522 }
523
524 let struct_array = StructArray::try_new(self.state_fields.clone(), array, None)?;
525 Ok(ScalarValue::Struct(Arc::new(struct_array)))
526 }
527
528 fn merge_batch(
529 &mut self,
530 states: &[datatypes::arrow::array::ArrayRef],
531 ) -> datafusion_common::Result<()> {
532 self.inner.merge_batch(states)
533 }
534
535 fn update_batch(
536 &mut self,
537 values: &[datatypes::arrow::array::ArrayRef],
538 ) -> datafusion_common::Result<()> {
539 self.inner.update_batch(values)
540 }
541
542 fn size(&self) -> usize {
543 self.inner.size()
544 }
545
546 fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
547 self.inner.state()
548 }
549}
550
/// UDAF for the upper "merge" step.
///
/// It consumes the struct state column produced by the lower `StateWrapper`
/// step and yields the original aggregate's final value.
#[derive(Debug, Clone)]
pub struct MergeWrapper {
    /// The original aggregate function being merged.
    inner: AggregateUDF,
    /// Registered name; see `aggr_merge_func_name`.
    name: String,
    /// User-defined signature; argument types are checked in `coerce_types`.
    merge_signature: Signature,
    /// Physical expression of the original aggregate. The merge step reuses
    /// its accumulator and state fields.
    original_phy_expr: Arc<AggregateFunctionExpr>,
    /// Return field of the original aggregate for its original input fields.
    return_field: FieldRef,
}
564impl MergeWrapper {
565 pub fn new(
566 inner: AggregateUDF,
567 original_phy_expr: Arc<AggregateFunctionExpr>,
568 original_input_fields: Vec<FieldRef>,
569 ) -> datafusion_common::Result<Self> {
570 let name = aggr_merge_func_name(inner.name());
571 let merge_signature = Signature::user_defined(datafusion_expr::Volatility::Immutable);
573 let return_field = inner.return_field(&original_input_fields)?.clone();
574
575 Ok(Self {
576 inner,
577 name,
578 merge_signature,
579 original_phy_expr,
580 return_field,
581 })
582 }
583
584 pub fn inner(&self) -> &AggregateUDF {
585 &self.inner
586 }
587}
588
589impl AggregateUDFImpl for MergeWrapper {
590 fn accumulator<'a, 'b>(
591 &'a self,
592 acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
593 ) -> datafusion_common::Result<Box<dyn Accumulator>> {
594 if acc_args.exprs.len() != 1
595 || !matches!(
596 acc_args.exprs[0].data_type(acc_args.schema)?,
597 DataType::Struct(_)
598 )
599 {
600 return Err(datafusion_common::DataFusionError::Internal(format!(
601 "Expected one struct type as input, got: {:?}",
602 acc_args.schema
603 )));
604 }
605 let input_type = acc_args.exprs[0].data_type(acc_args.schema)?;
606 let DataType::Struct(fields) = input_type else {
607 return Err(datafusion_common::DataFusionError::Internal(format!(
608 "Expected a struct type for input, got: {:?}",
609 input_type
610 )));
611 };
612
613 let inner_accum = self.original_phy_expr.create_accumulator()?;
614 Ok(Box::new(MergeAccum::new(inner_accum, &fields)))
615 }
616
617 fn as_any(&self) -> &dyn std::any::Any {
618 self
619 }
620 fn name(&self) -> &str {
621 self.name.as_str()
622 }
623
624 fn is_nullable(&self) -> bool {
625 self.inner.is_nullable()
626 }
627
628 fn return_type(&self, _arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
631 Ok(self.return_field.data_type().clone())
633 }
634
635 fn return_field(&self, _arg_fields: &[FieldRef]) -> datafusion_common::Result<FieldRef> {
637 Ok(self.return_field.clone())
638 }
639
640 fn signature(&self) -> &Signature {
641 &self.merge_signature
642 }
643
644 fn coerce_types(&self, arg_types: &[DataType]) -> datafusion_common::Result<Vec<DataType>> {
646 if arg_types.len() != 1 || !matches!(arg_types.first(), Some(DataType::Struct(_))) {
648 return Err(datafusion_common::DataFusionError::Internal(format!(
649 "Expected one struct type as input, got: {:?}",
650 arg_types
651 )));
652 }
653 Ok(arg_types.to_vec())
654 }
655
656 fn state_fields(
658 &self,
659 _args: datafusion_expr::function::StateFieldsArgs,
660 ) -> datafusion_common::Result<Vec<FieldRef>> {
661 self.original_phy_expr.state_fields()
662 }
663}
664
665impl PartialEq for MergeWrapper {
666 fn eq(&self, other: &Self) -> bool {
667 self.inner == other.inner
668 }
669}
670
671impl Eq for MergeWrapper {}
672
673impl Hash for MergeWrapper {
674 fn hash<H: Hasher>(&self, state: &mut H) {
675 self.inner.hash(state);
676 }
677}
678
/// Accumulator used by `MergeWrapper`.
///
/// It feeds the columns of the incoming state struct into the original
/// aggregate's accumulator via `merge_batch`.
#[derive(Debug)]
pub struct MergeAccum {
    /// Accumulator of the original aggregate expression.
    inner: Box<dyn Accumulator>,
    /// Expected fields of the incoming state struct; a mismatch is only logged.
    state_fields: Fields,
}
687
688impl MergeAccum {
689 pub fn new(inner: Box<dyn Accumulator>, state_fields: &Fields) -> Self {
690 Self {
691 inner,
692 state_fields: state_fields.clone(),
693 }
694 }
695}
696
697impl Accumulator for MergeAccum {
698 fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
699 self.inner.evaluate()
700 }
701
702 fn merge_batch(&mut self, states: &[arrow::array::ArrayRef]) -> datafusion_common::Result<()> {
703 self.inner.merge_batch(states)
704 }
705
706 fn update_batch(&mut self, values: &[arrow::array::ArrayRef]) -> datafusion_common::Result<()> {
707 let value = values.first().ok_or_else(|| {
708 datafusion_common::DataFusionError::Internal("No values provided for merge".to_string())
709 })?;
710 let struct_arr = value
712 .as_any()
713 .downcast_ref::<StructArray>()
714 .ok_or_else(|| {
715 datafusion_common::DataFusionError::Internal(format!(
716 "Expected StructArray, got: {:?}",
717 value.data_type()
718 ))
719 })?;
720 let fields = struct_arr.fields();
721 if fields != &self.state_fields {
722 debug!(
723 "State fields mismatch, expected: {:?}, got: {:?}",
724 self.state_fields, fields
725 );
726 }
728
729 let state_columns = struct_arr.columns();
732 self.inner.merge_batch(state_columns)
733 }
734
735 fn size(&self) -> usize {
736 self.inner.size()
737 }
738
739 fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
740 self.inner.state()
741 }
742}