1use std::hash::{Hash, Hasher};
26use std::sync::Arc;
27
28use arrow::array::StructArray;
29use arrow_schema::{FieldRef, Fields};
30use common_telemetry::debug;
31use datafusion::functions_aggregate::all_default_aggregate_functions;
32use datafusion::functions_aggregate::count::Count;
33use datafusion::functions_aggregate::min_max::{Max, Min};
34use datafusion::optimizer::AnalyzerRule;
35use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
36use datafusion::physical_planner::create_aggregate_expr_and_maybe_filter;
37use datafusion_common::{Column, ScalarValue};
38use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
39use datafusion_expr::function::StateFieldsArgs;
40use datafusion_expr::{
41 Accumulator, Aggregate, AggregateUDF, AggregateUDFImpl, Expr, ExprSchemable, LogicalPlan,
42 Signature,
43};
44use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
45use datatypes::arrow::datatypes::{DataType, Field};
46
47use crate::aggrs::aggr_wrapper::fix_order::FixStateUdafOrderingAnalyzer;
48use crate::function_registry::{FUNCTION_REGISTRY, FunctionRegistry};
49
50pub mod fix_order;
51#[cfg(test)]
52mod tests;
53
/// Returns the name under which the "state" wrapper of the aggregate `aggr_name` is
/// registered, e.g. `sum` -> `__sum_state`.
pub fn aggr_state_func_name(aggr_name: &str) -> String {
    ["__", aggr_name, "_state"].concat()
}
60
/// Returns the name of the "merge" wrapper of the aggregate `aggr_name`,
/// e.g. `sum` -> `__sum_merge`.
pub fn aggr_merge_func_name(aggr_name: &str) -> String {
    let mut merged = String::with_capacity(aggr_name.len() + 8);
    merged.push_str("__");
    merged.push_str(aggr_name);
    merged.push_str("_merge");
    merged
}
67
68pub fn is_all_aggr_exprs_steppable(aggr_exprs: &[Expr]) -> bool {
73 aggr_exprs.iter().all(|expr| {
74 if let Some(aggr_func) = get_aggr_func(expr) {
75 if aggr_func.params.distinct {
76 return false;
79 }
80
81 FUNCTION_REGISTRY.is_aggr_func_exist(&aggr_state_func_name(aggr_func.func.name()))
83 } else {
84 false
85 }
86 })
87}
88
89pub fn get_aggr_func(expr: &Expr) -> Option<&datafusion_expr::expr::AggregateFunction> {
90 let mut expr_ref = expr;
91 while let Expr::Alias(alias) = expr_ref {
92 expr_ref = &alias.expr;
93 }
94 if let Expr::AggregateFunction(aggr_func) = expr_ref {
95 Some(aggr_func)
96 } else {
97 None
98 }
99}
100
/// Stateless helper for step (two-phase) aggregation.
///
/// It registers `__<name>_state` wrappers for supported aggregate functions. It also splits an
/// `Aggregate` plan into a lower state-producing stage and an upper merging stage.
#[derive(Debug, Clone)]
pub struct StateMergeHelper;
107
108#[allow(unused)]
110#[derive(Debug, Clone)]
111pub struct StepAggrPlan {
112 pub upper_merge: LogicalPlan,
114 pub lower_state: LogicalPlan,
116}
117
118impl StateMergeHelper {
119 pub fn register(registry: &FunctionRegistry) {
122 let all_default = all_default_aggregate_functions();
123 let greptime_custom_aggr_functions = registry.aggregate_functions();
124
125 let supported = all_default
127 .into_iter()
128 .chain(greptime_custom_aggr_functions.into_iter().map(Arc::new))
129 .collect::<Vec<_>>();
130 debug!(
131 "Registering state functions for supported: {:?}",
132 supported.iter().map(|f| f.name()).collect::<Vec<_>>()
133 );
134
135 let state_func = supported.into_iter().filter_map(|f| {
136 StateWrapper::new((*f).clone())
137 .inspect_err(
138 |e| common_telemetry::error!(e; "Failed to register state function for {:?}", f),
139 )
140 .ok()
141 .map(AggregateUDF::new_from_impl)
142 });
143
144 for func in state_func {
145 registry.register_aggr(func);
146 }
147 }
148
149 pub fn split_aggr_node(aggr_plan: Aggregate) -> datafusion_common::Result<StepAggrPlan> {
152 let aggr = {
153 let aggr_plan = TypeCoercion::new().analyze(
155 LogicalPlan::Aggregate(aggr_plan).clone(),
156 &Default::default(),
157 )?;
158 if let LogicalPlan::Aggregate(aggr) = aggr_plan {
159 aggr
160 } else {
161 return Err(datafusion_common::DataFusionError::Internal(format!(
162 "Failed to coerce expressions in aggregate plan, expected Aggregate, got: {:?}",
163 aggr_plan
164 )));
165 }
166 };
167 let mut lower_aggr_exprs = vec![];
168 let mut upper_aggr_exprs = vec![];
169
170 let upper_group_exprs = aggr
173 .group_expr
174 .iter()
175 .map(|c| c.qualified_name())
176 .map(|(r, c)| Expr::Column(Column::new(r, c)))
177 .collect();
178
179 for aggr_expr in aggr.aggr_expr.iter() {
180 let Some(aggr_func) = get_aggr_func(aggr_expr) else {
181 return Err(datafusion_common::DataFusionError::NotImplemented(format!(
182 "Unsupported aggregate expression for step aggr optimize: {:?}",
183 aggr_expr
184 )));
185 };
186
187 let original_input_types = aggr_func
188 .params
189 .args
190 .iter()
191 .map(|e| e.get_type(&aggr.input.schema()))
192 .collect::<Result<Vec<_>, _>>()?;
193
194 let state_func = StateWrapper::new((*aggr_func.func).clone())?;
196
197 let expr = AggregateFunction {
198 func: Arc::new(state_func.into()),
199 params: aggr_func.params.clone(),
200 };
201 let expr = Expr::AggregateFunction(expr);
202 let lower_state_output_col_name = expr.schema_name().to_string();
203
204 lower_aggr_exprs.push(expr);
205
206 let (original_phy_expr, _filter, _ordering) = create_aggregate_expr_and_maybe_filter(
208 aggr_expr,
209 aggr.input.schema(),
210 aggr.input.schema().as_arrow(),
211 &Default::default(),
212 )?;
213
214 let merge_func = MergeWrapper::new(
215 (*aggr_func.func).clone(),
216 original_phy_expr,
217 original_input_types,
218 )?;
219 let arg = Expr::Column(Column::new_unqualified(lower_state_output_col_name));
220 let expr = AggregateFunction {
221 func: Arc::new(merge_func.into()),
222 params: AggregateFunctionParams {
226 args: vec![arg],
227 distinct: aggr_func.params.distinct,
228 filter: None,
229 order_by: vec![],
230 null_treatment: aggr_func.params.null_treatment,
231 },
232 };
233
234 let expr = Expr::AggregateFunction(expr).alias(aggr_expr.schema_name().to_string());
237 upper_aggr_exprs.push(expr);
238 }
239
240 let mut lower = aggr.clone();
241 lower.aggr_expr = lower_aggr_exprs;
242 let lower_plan = LogicalPlan::Aggregate(lower);
243
244 let lower_plan = lower_plan.recompute_schema()?;
246
247 let fixed_lower_plan =
250 FixStateUdafOrderingAnalyzer.analyze(lower_plan, &Default::default())?;
251
252 let upper = Aggregate::try_new(
253 Arc::new(fixed_lower_plan.clone()),
254 upper_group_exprs,
255 upper_aggr_exprs.clone(),
256 )?;
257 let aggr_plan = LogicalPlan::Aggregate(aggr);
258
259 let upper_check = upper;
261 let upper_plan = LogicalPlan::Aggregate(upper_check).recompute_schema()?;
262 if *upper_plan.schema() != *aggr_plan.schema() {
263 return Err(datafusion_common::DataFusionError::Internal(format!(
264 "Upper aggregate plan's schema is not the same as the original aggregate plan's schema: \n[transformed]:{}\n[original]:{}",
265 upper_plan.schema(),
266 aggr_plan.schema()
267 )));
268 }
269
270 Ok(StepAggrPlan {
271 lower_state: fixed_lower_plan,
272 upper_merge: upper_plan,
273 })
274 }
275}
276
277#[derive(Debug, Clone, PartialEq, Eq, Hash)]
279pub struct StateWrapper {
280 inner: AggregateUDF,
281 name: String,
282 ordering: Vec<FieldRef>,
284 distinct: bool,
286}
287
288impl StateWrapper {
289 pub fn new(inner: AggregateUDF) -> datafusion_common::Result<Self> {
291 let name = aggr_state_func_name(inner.name());
292 Ok(Self {
293 inner,
294 name,
295 ordering: vec![],
296 distinct: false,
297 })
298 }
299
300 pub fn inner(&self) -> &AggregateUDF {
301 &self.inner
302 }
303
304 pub fn deduce_aggr_return_type(
308 &self,
309 acc_args: &datafusion_expr::function::AccumulatorArgs,
310 ) -> datafusion_common::Result<FieldRef> {
311 let input_fields = acc_args
312 .exprs
313 .iter()
314 .map(|e| e.return_field(acc_args.schema))
315 .collect::<Result<Vec<_>, _>>()?;
316 self.inner.return_field(&input_fields).inspect_err(|e| {
317 common_telemetry::error!(
318 "StateWrapper: {:#?}\nacc_args:{:?}\nerror:{:?}",
319 &self,
320 &acc_args,
321 e
322 );
323 })
324 }
325}
326
327impl AggregateUDFImpl for StateWrapper {
328 fn accumulator<'a, 'b>(
329 &'a self,
330 acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
331 ) -> datafusion_common::Result<Box<dyn Accumulator>> {
332 let state_type = acc_args.return_type().clone();
334 let inner = {
335 let acc_args = datafusion_expr::function::AccumulatorArgs {
336 return_field: self.deduce_aggr_return_type(&acc_args)?,
337 schema: acc_args.schema,
338 ignore_nulls: acc_args.ignore_nulls,
339 order_bys: acc_args.order_bys,
340 is_reversed: acc_args.is_reversed,
341 name: acc_args.name,
342 is_distinct: acc_args.is_distinct,
343 exprs: acc_args.exprs,
344 };
345 self.inner.accumulator(acc_args)?
346 };
347
348 Ok(Box::new(StateAccum::new(inner, state_type)?))
349 }
350
351 fn as_any(&self) -> &dyn std::any::Any {
352 self
353 }
354 fn name(&self) -> &str {
355 self.name.as_str()
356 }
357
358 fn is_nullable(&self) -> bool {
359 self.inner.is_nullable()
360 }
361
362 fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
365 let input_fields = &arg_types
366 .iter()
367 .map(|x| Arc::new(Field::new("x", x.clone(), false)))
368 .collect::<Vec<_>>();
369
370 let state_fields_args = StateFieldsArgs {
371 name: self.inner().name(),
372 input_fields,
373 return_field: self.inner.return_field(input_fields)?,
374 ordering_fields: &self.ordering,
376 is_distinct: self.distinct,
377 };
378 let state_fields = self.inner.state_fields(state_fields_args)?;
379
380 let state_fields = state_fields
381 .into_iter()
382 .map(|f| {
383 let mut f = f.as_ref().clone();
384 f.set_nullable(true);
386 Arc::new(f)
387 })
388 .collect::<Vec<_>>();
389
390 let struct_field = DataType::Struct(state_fields.into());
391 Ok(struct_field)
392 }
393
394 fn state_fields(
396 &self,
397 args: datafusion_expr::function::StateFieldsArgs,
398 ) -> datafusion_common::Result<Vec<FieldRef>> {
399 let state_fields_args = StateFieldsArgs {
400 name: args.name,
401 input_fields: args.input_fields,
402 return_field: self.inner.return_field(args.input_fields)?,
403 ordering_fields: args.ordering_fields,
404 is_distinct: args.is_distinct,
405 };
406 self.inner.state_fields(state_fields_args)
407 }
408
409 fn signature(&self) -> &Signature {
411 self.inner.signature()
412 }
413
414 fn coerce_types(&self, arg_types: &[DataType]) -> datafusion_common::Result<Vec<DataType>> {
416 self.inner.coerce_types(arg_types)
417 }
418
419 fn value_from_stats(
420 &self,
421 statistics_args: &datafusion_expr::StatisticsArgs,
422 ) -> Option<ScalarValue> {
423 let inner = self.inner().inner().as_any();
424 let can_use_stat = inner.is::<Count>() || inner.is::<Max>() || inner.is::<Min>();
427 if !can_use_stat {
428 return None;
429 }
430
431 let state_type = if let DataType::Struct(fields) = &statistics_args.return_type {
433 if fields.is_empty() {
434 return None;
435 }
436 fields[0].data_type().clone()
437 } else {
438 return None;
439 };
440
441 let fixed_args = datafusion_expr::StatisticsArgs {
442 statistics: statistics_args.statistics,
443 return_type: &state_type,
444 is_distinct: statistics_args.is_distinct,
445 exprs: statistics_args.exprs,
446 };
447
448 let ret = self.inner().value_from_stats(&fixed_args)?;
449
450 let fields = if let DataType::Struct(fields) = &statistics_args.return_type {
452 fields
453 } else {
454 return None;
455 };
456
457 let array = ret.to_array().ok()?;
458
459 let struct_array = StructArray::new(fields.clone(), vec![array], None);
460 let ret = ScalarValue::Struct(Arc::new(struct_array));
461 Some(ret)
462 }
463}
464
465#[derive(Debug)]
468pub struct StateAccum {
469 inner: Box<dyn Accumulator>,
470 state_fields: Fields,
471}
472
473impl StateAccum {
474 pub fn new(
475 inner: Box<dyn Accumulator>,
476 state_type: DataType,
477 ) -> datafusion_common::Result<Self> {
478 let DataType::Struct(fields) = state_type else {
479 return Err(datafusion_common::DataFusionError::Internal(format!(
480 "Expected a struct type for state, got: {:?}",
481 state_type
482 )));
483 };
484 Ok(Self {
485 inner,
486 state_fields: fields,
487 })
488 }
489}
490
491impl Accumulator for StateAccum {
492 fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
493 let state = self.inner.state()?;
494
495 let array = state
496 .iter()
497 .map(|s| s.to_array())
498 .collect::<Result<Vec<_>, _>>()?;
499 let array_type = array
500 .iter()
501 .map(|a| a.data_type().clone())
502 .collect::<Vec<_>>();
503 let expected_type: Vec<_> = self
504 .state_fields
505 .iter()
506 .map(|f| f.data_type().clone())
507 .collect();
508 if array_type != expected_type {
509 debug!(
510 "State mismatch, expected: {}, got: {} for expected fields: {:?} and given array types: {:?}",
511 self.state_fields.len(),
512 array.len(),
513 self.state_fields,
514 array_type,
515 );
516 let guess_schema = array
517 .iter()
518 .enumerate()
519 .map(|(index, array)| {
520 Field::new(
521 format!("col_{index}[mismatch_state]").as_str(),
522 array.data_type().clone(),
523 true,
524 )
525 })
526 .collect::<Fields>();
527 let arr = StructArray::try_new(guess_schema, array, None)?;
528
529 return Ok(ScalarValue::Struct(Arc::new(arr)));
530 }
531
532 let struct_array = StructArray::try_new(self.state_fields.clone(), array, None)?;
533 Ok(ScalarValue::Struct(Arc::new(struct_array)))
534 }
535
536 fn merge_batch(
537 &mut self,
538 states: &[datatypes::arrow::array::ArrayRef],
539 ) -> datafusion_common::Result<()> {
540 self.inner.merge_batch(states)
541 }
542
543 fn update_batch(
544 &mut self,
545 values: &[datatypes::arrow::array::ArrayRef],
546 ) -> datafusion_common::Result<()> {
547 self.inner.update_batch(values)
548 }
549
550 fn size(&self) -> usize {
551 self.inner.size()
552 }
553
554 fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
555 self.inner.state()
556 }
557}
558
559#[derive(Debug, Clone)]
564pub struct MergeWrapper {
565 inner: AggregateUDF,
566 name: String,
567 merge_signature: Signature,
568 original_phy_expr: Arc<AggregateFunctionExpr>,
570 return_type: DataType,
571}
572impl MergeWrapper {
573 pub fn new(
574 inner: AggregateUDF,
575 original_phy_expr: Arc<AggregateFunctionExpr>,
576 original_input_types: Vec<DataType>,
577 ) -> datafusion_common::Result<Self> {
578 let name = aggr_merge_func_name(inner.name());
579 let merge_signature = Signature::user_defined(datafusion_expr::Volatility::Immutable);
581 let return_type = inner.return_type(&original_input_types)?;
582
583 Ok(Self {
584 inner,
585 name,
586 merge_signature,
587 original_phy_expr,
588 return_type,
589 })
590 }
591
592 pub fn inner(&self) -> &AggregateUDF {
593 &self.inner
594 }
595}
596
597impl AggregateUDFImpl for MergeWrapper {
598 fn accumulator<'a, 'b>(
599 &'a self,
600 acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
601 ) -> datafusion_common::Result<Box<dyn Accumulator>> {
602 if acc_args.exprs.len() != 1
603 || !matches!(
604 acc_args.exprs[0].data_type(acc_args.schema)?,
605 DataType::Struct(_)
606 )
607 {
608 return Err(datafusion_common::DataFusionError::Internal(format!(
609 "Expected one struct type as input, got: {:?}",
610 acc_args.schema
611 )));
612 }
613 let input_type = acc_args.exprs[0].data_type(acc_args.schema)?;
614 let DataType::Struct(fields) = input_type else {
615 return Err(datafusion_common::DataFusionError::Internal(format!(
616 "Expected a struct type for input, got: {:?}",
617 input_type
618 )));
619 };
620
621 let inner_accum = self.original_phy_expr.create_accumulator()?;
622 Ok(Box::new(MergeAccum::new(inner_accum, &fields)))
623 }
624
625 fn as_any(&self) -> &dyn std::any::Any {
626 self
627 }
628 fn name(&self) -> &str {
629 self.name.as_str()
630 }
631
632 fn is_nullable(&self) -> bool {
633 self.inner.is_nullable()
634 }
635
636 fn return_type(&self, _arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
639 Ok(self.return_type.clone())
641 }
642 fn signature(&self) -> &Signature {
643 &self.merge_signature
644 }
645
646 fn coerce_types(&self, arg_types: &[DataType]) -> datafusion_common::Result<Vec<DataType>> {
648 if arg_types.len() != 1 || !matches!(arg_types.first(), Some(DataType::Struct(_))) {
650 return Err(datafusion_common::DataFusionError::Internal(format!(
651 "Expected one struct type as input, got: {:?}",
652 arg_types
653 )));
654 }
655 Ok(arg_types.to_vec())
656 }
657
658 fn state_fields(
660 &self,
661 _args: datafusion_expr::function::StateFieldsArgs,
662 ) -> datafusion_common::Result<Vec<FieldRef>> {
663 self.original_phy_expr.state_fields()
664 }
665}
666
667impl PartialEq for MergeWrapper {
668 fn eq(&self, other: &Self) -> bool {
669 self.inner == other.inner
670 }
671}
672
673impl Eq for MergeWrapper {}
674
675impl Hash for MergeWrapper {
676 fn hash<H: Hasher>(&self, state: &mut H) {
677 self.inner.hash(state);
678 }
679}
680
681#[derive(Debug)]
685pub struct MergeAccum {
686 inner: Box<dyn Accumulator>,
687 state_fields: Fields,
688}
689
690impl MergeAccum {
691 pub fn new(inner: Box<dyn Accumulator>, state_fields: &Fields) -> Self {
692 Self {
693 inner,
694 state_fields: state_fields.clone(),
695 }
696 }
697}
698
699impl Accumulator for MergeAccum {
700 fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
701 self.inner.evaluate()
702 }
703
704 fn merge_batch(&mut self, states: &[arrow::array::ArrayRef]) -> datafusion_common::Result<()> {
705 self.inner.merge_batch(states)
706 }
707
708 fn update_batch(&mut self, values: &[arrow::array::ArrayRef]) -> datafusion_common::Result<()> {
709 let value = values.first().ok_or_else(|| {
710 datafusion_common::DataFusionError::Internal("No values provided for merge".to_string())
711 })?;
712 let struct_arr = value
714 .as_any()
715 .downcast_ref::<StructArray>()
716 .ok_or_else(|| {
717 datafusion_common::DataFusionError::Internal(format!(
718 "Expected StructArray, got: {:?}",
719 value.data_type()
720 ))
721 })?;
722 let fields = struct_arr.fields();
723 if fields != &self.state_fields {
724 debug!(
725 "State fields mismatch, expected: {:?}, got: {:?}",
726 self.state_fields, fields
727 );
728 }
730
731 let state_columns = struct_arr.columns();
734 self.inner.merge_batch(state_columns)
735 }
736
737 fn size(&self) -> usize {
738 self.inner.size()
739 }
740
741 fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
742 self.inner.state()
743 }
744}