1use std::sync::Arc;
26
27use arrow::array::StructArray;
28use arrow_schema::{FieldRef, Fields};
29use common_telemetry::debug;
30use datafusion::functions_aggregate::all_default_aggregate_functions;
31use datafusion::optimizer::AnalyzerRule;
32use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
33use datafusion::physical_planner::create_aggregate_expr_and_maybe_filter;
34use datafusion_common::{Column, ScalarValue};
35use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
36use datafusion_expr::function::StateFieldsArgs;
37use datafusion_expr::{
38 Accumulator, Aggregate, AggregateUDF, AggregateUDFImpl, Expr, ExprSchemable, LogicalPlan,
39 Signature,
40};
41use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
42use datatypes::arrow::datatypes::{DataType, Field};
43
44use crate::aggrs::aggr_wrapper::fix_order::FixStateUdafOrderingAnalyzer;
45use crate::function_registry::{FUNCTION_REGISTRY, FunctionRegistry};
46
47pub mod fix_order;
48#[cfg(test)]
49mod tests;
50
/// Name under which the state-producing wrapper of `aggr_name` is registered,
/// e.g. `sum` -> `__sum_state`.
pub fn aggr_state_func_name(aggr_name: &str) -> String {
    format!("__{aggr_name}_state")
}
57
/// Name under which the state-merging wrapper of `aggr_name` is registered,
/// e.g. `sum` -> `__sum_merge`.
pub fn aggr_merge_func_name(aggr_name: &str) -> String {
    format!("__{aggr_name}_merge")
}
64
65pub fn is_all_aggr_exprs_steppable(aggr_exprs: &[Expr]) -> bool {
70 aggr_exprs.iter().all(|expr| {
71 if let Some(aggr_func) = get_aggr_func(expr) {
72 if aggr_func.params.distinct {
73 return false;
76 }
77
78 FUNCTION_REGISTRY.is_aggr_func_exist(&aggr_state_func_name(aggr_func.func.name()))
80 } else {
81 false
82 }
83 })
84}
85
86pub fn get_aggr_func(expr: &Expr) -> Option<&datafusion_expr::expr::AggregateFunction> {
87 let mut expr_ref = expr;
88 while let Expr::Alias(alias) = expr_ref {
89 expr_ref = &alias.expr;
90 }
91 if let Expr::AggregateFunction(aggr_func) = expr_ref {
92 Some(aggr_func)
93 } else {
94 None
95 }
96}
97
/// Namespace for the step-aggregation helpers: registering the `__<name>_state`
/// wrapper UDAFs and splitting an `Aggregate` node into a lower "state" aggregate
/// and an upper "merge" aggregate.
#[derive(Debug, Clone)]
pub struct StateMergeHelper;
104
// NOTE(review): the reason for `allow(unused)` is not visible here; confirm it is
// still needed.
#[allow(unused)]
/// The two halves produced by `StateMergeHelper::split_aggr_node`.
#[derive(Debug, Clone)]
pub struct StepAggrPlan {
    /// Aggregate over `lower_state` that applies `__<name>_merge` functions to the
    /// state columns. Its schema is checked to equal the original aggregate's schema.
    pub upper_merge: LogicalPlan,
    /// Aggregate over the original input that computes one `__<name>_state`
    /// struct column per aggregate expression.
    pub lower_state: LogicalPlan,
}
114
115impl StateMergeHelper {
116 pub fn register(registry: &FunctionRegistry) {
119 let all_default = all_default_aggregate_functions();
120 let greptime_custom_aggr_functions = registry.aggregate_functions();
121
122 let supported = all_default
124 .into_iter()
125 .chain(greptime_custom_aggr_functions.into_iter().map(Arc::new))
126 .collect::<Vec<_>>();
127 debug!(
128 "Registering state functions for supported: {:?}",
129 supported.iter().map(|f| f.name()).collect::<Vec<_>>()
130 );
131
132 let state_func = supported.into_iter().filter_map(|f| {
133 StateWrapper::new((*f).clone())
134 .inspect_err(
135 |e| common_telemetry::error!(e; "Failed to register state function for {:?}", f),
136 )
137 .ok()
138 .map(AggregateUDF::new_from_impl)
139 });
140
141 for func in state_func {
142 registry.register_aggr(func);
143 }
144 }
145
146 pub fn split_aggr_node(aggr_plan: Aggregate) -> datafusion_common::Result<StepAggrPlan> {
149 let aggr = {
150 let aggr_plan = TypeCoercion::new().analyze(
152 LogicalPlan::Aggregate(aggr_plan).clone(),
153 &Default::default(),
154 )?;
155 if let LogicalPlan::Aggregate(aggr) = aggr_plan {
156 aggr
157 } else {
158 return Err(datafusion_common::DataFusionError::Internal(format!(
159 "Failed to coerce expressions in aggregate plan, expected Aggregate, got: {:?}",
160 aggr_plan
161 )));
162 }
163 };
164 let mut lower_aggr_exprs = vec![];
165 let mut upper_aggr_exprs = vec![];
166
167 let upper_group_exprs = aggr
170 .group_expr
171 .iter()
172 .map(|c| c.qualified_name())
173 .map(|(r, c)| Expr::Column(Column::new(r, c)))
174 .collect();
175
176 for aggr_expr in aggr.aggr_expr.iter() {
177 let Some(aggr_func) = get_aggr_func(aggr_expr) else {
178 return Err(datafusion_common::DataFusionError::NotImplemented(format!(
179 "Unsupported aggregate expression for step aggr optimize: {:?}",
180 aggr_expr
181 )));
182 };
183
184 let original_input_types = aggr_func
185 .params
186 .args
187 .iter()
188 .map(|e| e.get_type(&aggr.input.schema()))
189 .collect::<Result<Vec<_>, _>>()?;
190
191 let state_func = StateWrapper::new((*aggr_func.func).clone())?;
193
194 let expr = AggregateFunction {
195 func: Arc::new(state_func.into()),
196 params: aggr_func.params.clone(),
197 };
198 let expr = Expr::AggregateFunction(expr);
199 let lower_state_output_col_name = expr.schema_name().to_string();
200
201 lower_aggr_exprs.push(expr);
202
203 let (original_phy_expr, _filter, _ordering) = create_aggregate_expr_and_maybe_filter(
205 aggr_expr,
206 aggr.input.schema(),
207 aggr.input.schema().as_arrow(),
208 &Default::default(),
209 )?;
210
211 let merge_func = MergeWrapper::new(
212 (*aggr_func.func).clone(),
213 original_phy_expr,
214 original_input_types,
215 )?;
216 let arg = Expr::Column(Column::new_unqualified(lower_state_output_col_name));
217 let expr = AggregateFunction {
218 func: Arc::new(merge_func.into()),
219 params: AggregateFunctionParams {
223 args: vec![arg],
224 distinct: aggr_func.params.distinct,
225 filter: None,
226 order_by: vec![],
227 null_treatment: aggr_func.params.null_treatment,
228 },
229 };
230
231 let expr = Expr::AggregateFunction(expr).alias(aggr_expr.schema_name().to_string());
234 upper_aggr_exprs.push(expr);
235 }
236
237 let mut lower = aggr.clone();
238 lower.aggr_expr = lower_aggr_exprs;
239 let lower_plan = LogicalPlan::Aggregate(lower);
240
241 let lower_plan = lower_plan.recompute_schema()?;
243
244 let fixed_lower_plan =
247 FixStateUdafOrderingAnalyzer.analyze(lower_plan, &Default::default())?;
248
249 let upper = Aggregate::try_new(
250 Arc::new(fixed_lower_plan.clone()),
251 upper_group_exprs,
252 upper_aggr_exprs.clone(),
253 )?;
254 let aggr_plan = LogicalPlan::Aggregate(aggr);
255
256 let upper_check = upper;
258 let upper_plan = LogicalPlan::Aggregate(upper_check).recompute_schema()?;
259 if *upper_plan.schema() != *aggr_plan.schema() {
260 return Err(datafusion_common::DataFusionError::Internal(format!(
261 "Upper aggregate plan's schema is not the same as the original aggregate plan's schema: \n[transformed]:{}\n[original]:{}",
262 upper_plan.schema(),
263 aggr_plan.schema()
264 )));
265 }
266
267 Ok(StepAggrPlan {
268 lower_state: fixed_lower_plan,
269 upper_merge: upper_plan,
270 })
271 }
272}
273
/// Aggregate UDF wrapper that outputs the wrapped function's intermediate state,
/// packed into a single struct value, instead of its final result.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StateWrapper {
    /// The wrapped aggregate function.
    inner: AggregateUDF,
    /// Registered name, `__<inner name>_state`.
    name: String,
    /// Ordering fields passed to `inner.state_fields` when computing the return
    /// type; `StateWrapper::new` leaves this empty.
    ordering: Vec<FieldRef>,
    /// `is_distinct` passed to `inner.state_fields` when computing the return
    /// type; `StateWrapper::new` sets this to `false`.
    distinct: bool,
}
284
285impl StateWrapper {
286 pub fn new(inner: AggregateUDF) -> datafusion_common::Result<Self> {
288 let name = aggr_state_func_name(inner.name());
289 Ok(Self {
290 inner,
291 name,
292 ordering: vec![],
293 distinct: false,
294 })
295 }
296
297 pub fn inner(&self) -> &AggregateUDF {
298 &self.inner
299 }
300
301 pub fn deduce_aggr_return_type(
305 &self,
306 acc_args: &datafusion_expr::function::AccumulatorArgs,
307 ) -> datafusion_common::Result<FieldRef> {
308 let input_fields = acc_args
309 .exprs
310 .iter()
311 .map(|e| e.return_field(acc_args.schema))
312 .collect::<Result<Vec<_>, _>>()?;
313 self.inner.return_field(&input_fields).inspect_err(|e| {
314 common_telemetry::error!(
315 "StateWrapper: {:#?}\nacc_args:{:?}\nerror:{:?}",
316 &self,
317 &acc_args,
318 e
319 );
320 })
321 }
322}
323
324impl AggregateUDFImpl for StateWrapper {
325 fn accumulator<'a, 'b>(
326 &'a self,
327 acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
328 ) -> datafusion_common::Result<Box<dyn Accumulator>> {
329 let state_type = acc_args.return_type().clone();
331 let inner = {
332 let acc_args = datafusion_expr::function::AccumulatorArgs {
333 return_field: self.deduce_aggr_return_type(&acc_args)?,
334 schema: acc_args.schema,
335 ignore_nulls: acc_args.ignore_nulls,
336 order_bys: acc_args.order_bys,
337 is_reversed: acc_args.is_reversed,
338 name: acc_args.name,
339 is_distinct: acc_args.is_distinct,
340 exprs: acc_args.exprs,
341 };
342 self.inner.accumulator(acc_args)?
343 };
344
345 Ok(Box::new(StateAccum::new(inner, state_type)?))
346 }
347
348 fn as_any(&self) -> &dyn std::any::Any {
349 self
350 }
351 fn name(&self) -> &str {
352 self.name.as_str()
353 }
354
355 fn is_nullable(&self) -> bool {
356 self.inner.is_nullable()
357 }
358
359 fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
362 let input_fields = &arg_types
363 .iter()
364 .map(|x| Arc::new(Field::new("x", x.clone(), false)))
365 .collect::<Vec<_>>();
366
367 let state_fields_args = StateFieldsArgs {
368 name: self.inner().name(),
369 input_fields,
370 return_field: self.inner.return_field(input_fields)?,
371 ordering_fields: &self.ordering,
373 is_distinct: self.distinct,
374 };
375 let state_fields = self.inner.state_fields(state_fields_args)?;
376
377 let state_fields = state_fields
378 .into_iter()
379 .map(|f| {
380 let mut f = f.as_ref().clone();
381 f.set_nullable(true);
383 Arc::new(f)
384 })
385 .collect::<Vec<_>>();
386
387 let struct_field = DataType::Struct(state_fields.into());
388 Ok(struct_field)
389 }
390
391 fn state_fields(
393 &self,
394 args: datafusion_expr::function::StateFieldsArgs,
395 ) -> datafusion_common::Result<Vec<FieldRef>> {
396 let state_fields_args = StateFieldsArgs {
397 name: args.name,
398 input_fields: args.input_fields,
399 return_field: self.inner.return_field(args.input_fields)?,
400 ordering_fields: args.ordering_fields,
401 is_distinct: args.is_distinct,
402 };
403 self.inner.state_fields(state_fields_args)
404 }
405
406 fn signature(&self) -> &Signature {
408 self.inner.signature()
409 }
410
411 fn coerce_types(&self, arg_types: &[DataType]) -> datafusion_common::Result<Vec<DataType>> {
413 self.inner.coerce_types(arg_types)
414 }
415}
416
/// Accumulator that runs the wrapped accumulator unchanged but evaluates to that
/// accumulator's state, packed into a single struct scalar.
#[derive(Debug)]
pub struct StateAccum {
    /// The wrapped function's accumulator.
    inner: Box<dyn Accumulator>,
    /// Fields of the struct that `evaluate` is expected to produce.
    state_fields: Fields,
}
424
425impl StateAccum {
426 pub fn new(
427 inner: Box<dyn Accumulator>,
428 state_type: DataType,
429 ) -> datafusion_common::Result<Self> {
430 let DataType::Struct(fields) = state_type else {
431 return Err(datafusion_common::DataFusionError::Internal(format!(
432 "Expected a struct type for state, got: {:?}",
433 state_type
434 )));
435 };
436 Ok(Self {
437 inner,
438 state_fields: fields,
439 })
440 }
441}
442
443impl Accumulator for StateAccum {
444 fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
445 let state = self.inner.state()?;
446
447 let array = state
448 .iter()
449 .map(|s| s.to_array())
450 .collect::<Result<Vec<_>, _>>()?;
451 let array_type = array
452 .iter()
453 .map(|a| a.data_type().clone())
454 .collect::<Vec<_>>();
455 let expected_type: Vec<_> = self
456 .state_fields
457 .iter()
458 .map(|f| f.data_type().clone())
459 .collect();
460 if array_type != expected_type {
461 debug!(
462 "State mismatch, expected: {}, got: {} for expected fields: {:?} and given array types: {:?}",
463 self.state_fields.len(),
464 array.len(),
465 self.state_fields,
466 array_type,
467 );
468 let guess_schema = array
469 .iter()
470 .enumerate()
471 .map(|(index, array)| {
472 Field::new(
473 format!("col_{index}[mismatch_state]").as_str(),
474 array.data_type().clone(),
475 true,
476 )
477 })
478 .collect::<Fields>();
479 let arr = StructArray::try_new(guess_schema, array, None)?;
480
481 return Ok(ScalarValue::Struct(Arc::new(arr)));
482 }
483
484 let struct_array = StructArray::try_new(self.state_fields.clone(), array, None)?;
485 Ok(ScalarValue::Struct(Arc::new(struct_array)))
486 }
487
488 fn merge_batch(
489 &mut self,
490 states: &[datatypes::arrow::array::ArrayRef],
491 ) -> datafusion_common::Result<()> {
492 self.inner.merge_batch(states)
493 }
494
495 fn update_batch(
496 &mut self,
497 values: &[datatypes::arrow::array::ArrayRef],
498 ) -> datafusion_common::Result<()> {
499 self.inner.update_batch(values)
500 }
501
502 fn size(&self) -> usize {
503 self.inner.size()
504 }
505
506 fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
507 self.inner.state()
508 }
509}
510
/// Aggregate UDF wrapper that takes the struct column produced by a
/// `__<name>_state` function and merges those states into the original
/// aggregate's final result.
#[derive(Debug, Clone)]
pub struct MergeWrapper {
    /// The original (un-wrapped) aggregate function.
    inner: AggregateUDF,
    /// Registered name, `__<inner name>_merge`.
    name: String,
    /// User-defined, immutable signature (set in `MergeWrapper::new`). Inputs are
    /// validated by `coerce_types`.
    merge_signature: Signature,
    /// Physical expression of the original aggregate. It creates the inner
    /// accumulator and supplies this wrapper's state fields.
    original_phy_expr: Arc<AggregateFunctionExpr>,
    /// Return type of the original aggregate for its original input types.
    return_type: DataType,
}
524impl MergeWrapper {
525 pub fn new(
526 inner: AggregateUDF,
527 original_phy_expr: Arc<AggregateFunctionExpr>,
528 original_input_types: Vec<DataType>,
529 ) -> datafusion_common::Result<Self> {
530 let name = aggr_merge_func_name(inner.name());
531 let merge_signature = Signature::user_defined(datafusion_expr::Volatility::Immutable);
533 let return_type = inner.return_type(&original_input_types)?;
534
535 Ok(Self {
536 inner,
537 name,
538 merge_signature,
539 original_phy_expr,
540 return_type,
541 })
542 }
543
544 pub fn inner(&self) -> &AggregateUDF {
545 &self.inner
546 }
547}
548
549impl AggregateUDFImpl for MergeWrapper {
550 fn accumulator<'a, 'b>(
551 &'a self,
552 acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
553 ) -> datafusion_common::Result<Box<dyn Accumulator>> {
554 if acc_args.exprs.len() != 1
555 || !matches!(
556 acc_args.exprs[0].data_type(acc_args.schema)?,
557 DataType::Struct(_)
558 )
559 {
560 return Err(datafusion_common::DataFusionError::Internal(format!(
561 "Expected one struct type as input, got: {:?}",
562 acc_args.schema
563 )));
564 }
565 let input_type = acc_args.exprs[0].data_type(acc_args.schema)?;
566 let DataType::Struct(fields) = input_type else {
567 return Err(datafusion_common::DataFusionError::Internal(format!(
568 "Expected a struct type for input, got: {:?}",
569 input_type
570 )));
571 };
572
573 let inner_accum = self.original_phy_expr.create_accumulator()?;
574 Ok(Box::new(MergeAccum::new(inner_accum, &fields)))
575 }
576
577 fn as_any(&self) -> &dyn std::any::Any {
578 self
579 }
580 fn name(&self) -> &str {
581 self.name.as_str()
582 }
583
584 fn is_nullable(&self) -> bool {
585 self.inner.is_nullable()
586 }
587
588 fn return_type(&self, _arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
591 Ok(self.return_type.clone())
593 }
594 fn signature(&self) -> &Signature {
595 &self.merge_signature
596 }
597
598 fn coerce_types(&self, arg_types: &[DataType]) -> datafusion_common::Result<Vec<DataType>> {
600 if arg_types.len() != 1 || !matches!(arg_types.first(), Some(DataType::Struct(_))) {
602 return Err(datafusion_common::DataFusionError::Internal(format!(
603 "Expected one struct type as input, got: {:?}",
604 arg_types
605 )));
606 }
607 Ok(arg_types.to_vec())
608 }
609
610 fn state_fields(
612 &self,
613 _args: datafusion_expr::function::StateFieldsArgs,
614 ) -> datafusion_common::Result<Vec<FieldRef>> {
615 self.original_phy_expr.state_fields()
616 }
617}
618
/// Accumulator that treats each input row (a state struct) as a partial state of
/// the wrapped accumulator and merges it in.
#[derive(Debug)]
pub struct MergeAccum {
    /// Accumulator of the original aggregate.
    inner: Box<dyn Accumulator>,
    /// Struct fields the input column is expected to have. They are only compared
    /// against the incoming fields for a debug log.
    state_fields: Fields,
}
627
628impl MergeAccum {
629 pub fn new(inner: Box<dyn Accumulator>, state_fields: &Fields) -> Self {
630 Self {
631 inner,
632 state_fields: state_fields.clone(),
633 }
634 }
635}
636
637impl Accumulator for MergeAccum {
638 fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
639 self.inner.evaluate()
640 }
641
642 fn merge_batch(&mut self, states: &[arrow::array::ArrayRef]) -> datafusion_common::Result<()> {
643 self.inner.merge_batch(states)
644 }
645
646 fn update_batch(&mut self, values: &[arrow::array::ArrayRef]) -> datafusion_common::Result<()> {
647 let value = values.first().ok_or_else(|| {
648 datafusion_common::DataFusionError::Internal("No values provided for merge".to_string())
649 })?;
650 let struct_arr = value
652 .as_any()
653 .downcast_ref::<StructArray>()
654 .ok_or_else(|| {
655 datafusion_common::DataFusionError::Internal(format!(
656 "Expected StructArray, got: {:?}",
657 value.data_type()
658 ))
659 })?;
660 let fields = struct_arr.fields();
661 if fields != &self.state_fields {
662 debug!(
663 "State fields mismatch, expected: {:?}, got: {:?}",
664 self.state_fields, fields
665 );
666 }
668
669 let state_columns = struct_arr.columns();
672 self.inner.merge_batch(state_columns)
673 }
674
675 fn size(&self) -> usize {
676 self.inner.size()
677 }
678
679 fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
680 self.inner.state()
681 }
682}