1use std::hash::{Hash, Hasher};
26use std::sync::Arc;
27
28use arrow::array::{ArrayRef, BooleanArray, StructArray};
29use arrow_schema::{FieldRef, Fields};
30use common_telemetry::debug;
31use datafusion::functions_aggregate::all_default_aggregate_functions;
32use datafusion::functions_aggregate::count::Count;
33use datafusion::functions_aggregate::min_max::{Max, Min};
34use datafusion::optimizer::AnalyzerRule;
35use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
36use datafusion::physical_planner::create_aggregate_expr_and_maybe_filter;
37use datafusion_common::{Column, ScalarValue};
38use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
39use datafusion_expr::function::StateFieldsArgs;
40use datafusion_expr::{
41 Accumulator, Aggregate, AggregateUDF, AggregateUDFImpl, EmitTo, Expr, ExprSchemable,
42 GroupsAccumulator, LogicalPlan, Signature,
43};
44use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
45use datatypes::arrow::datatypes::{DataType, Field};
46
47use crate::aggrs::aggr_wrapper::fix_order::FixStateUdafOrderingAnalyzer;
48use crate::function_registry::{FUNCTION_REGISTRY, FunctionRegistry};
49
50pub mod fix_order;
51#[cfg(test)]
52mod tests;
53
/// Returns the registry name of the "state" wrapper for the aggregate function
/// named `aggr_name`, i.e. `__{aggr_name}_state`.
pub fn aggr_state_func_name(aggr_name: &str) -> String {
    ["__", aggr_name, "_state"].concat()
}
60
/// Returns the registry name of the "merge" wrapper for the aggregate function
/// named `aggr_name`, i.e. `__{aggr_name}_merge`.
pub fn aggr_merge_func_name(aggr_name: &str) -> String {
    ["__", aggr_name, "_merge"].concat()
}
67
68pub fn is_all_aggr_exprs_steppable(aggr_exprs: &[Expr]) -> bool {
73 aggr_exprs.iter().all(|expr| {
74 if let Some(aggr_func) = get_aggr_func(expr) {
75 if aggr_func.params.distinct {
76 return false;
79 }
80
81 FUNCTION_REGISTRY.is_aggr_func_exist(&aggr_state_func_name(aggr_func.func.name()))
83 } else {
84 false
85 }
86 })
87}
88
89pub fn get_aggr_func(expr: &Expr) -> Option<&datafusion_expr::expr::AggregateFunction> {
90 let mut expr_ref = expr;
91 while let Expr::Alias(alias) = expr_ref {
92 expr_ref = &alias.expr;
93 }
94 if let Expr::AggregateFunction(aggr_func) = expr_ref {
95 Some(aggr_func)
96 } else {
97 None
98 }
99}
100
/// Stateless helper that registers the state wrappers of the supported
/// aggregate functions and rewrites an aggregate plan into a lower "state" step
/// plus an upper "merge" step.
#[derive(Debug, Clone)]
pub struct StateMergeHelper;
107
108#[allow(unused)]
110#[derive(Debug, Clone)]
111pub struct StepAggrPlan {
112 pub upper_merge: LogicalPlan,
114 pub lower_state: LogicalPlan,
116}
117
118impl StateMergeHelper {
119 pub fn register(registry: &FunctionRegistry) {
122 let all_default = all_default_aggregate_functions();
123 let greptime_custom_aggr_functions = registry.aggregate_functions();
124
125 let supported = all_default
127 .into_iter()
128 .chain(greptime_custom_aggr_functions.into_iter().map(Arc::new))
129 .collect::<Vec<_>>();
130 debug!(
131 "Registering state functions for supported: {:?}",
132 supported.iter().map(|f| f.name()).collect::<Vec<_>>()
133 );
134
135 let state_func = supported.into_iter().filter_map(|f| {
136 StateWrapper::new((*f).clone())
137 .inspect_err(
138 |e| common_telemetry::error!(e; "Failed to register state function for {:?}", f),
139 )
140 .ok()
141 .map(AggregateUDF::new_from_impl)
142 });
143
144 for func in state_func {
145 registry.register_aggr(func);
146 }
147 }
148
149 pub fn split_aggr_node(aggr_plan: Aggregate) -> datafusion_common::Result<StepAggrPlan> {
152 let aggr = {
153 let aggr_plan = TypeCoercion::new().analyze(
155 LogicalPlan::Aggregate(aggr_plan).clone(),
156 &Default::default(),
157 )?;
158 if let LogicalPlan::Aggregate(aggr) = aggr_plan {
159 aggr
160 } else {
161 return Err(datafusion_common::DataFusionError::Internal(format!(
162 "Failed to coerce expressions in aggregate plan, expected Aggregate, got: {:?}",
163 aggr_plan
164 )));
165 }
166 };
167 let mut lower_aggr_exprs = vec![];
168 let mut upper_aggr_exprs = vec![];
169
170 let upper_group_exprs = aggr
173 .group_expr
174 .iter()
175 .map(|c| c.qualified_name())
176 .map(|(r, c)| Expr::Column(Column::new(r, c)))
177 .collect();
178
179 for aggr_expr in aggr.aggr_expr.iter() {
180 let Some(aggr_func) = get_aggr_func(aggr_expr) else {
181 return Err(datafusion_common::DataFusionError::NotImplemented(format!(
182 "Unsupported aggregate expression for step aggr optimize: {:?}",
183 aggr_expr
184 )));
185 };
186
187 let original_input_fields = aggr_func
188 .params
189 .args
190 .iter()
191 .map(|e| e.to_field(&aggr.input.schema()).map(|(_, field)| field))
192 .collect::<Result<Vec<_>, _>>()?;
193
194 let state_func = StateWrapper::new((*aggr_func.func).clone())?;
196
197 let expr = AggregateFunction {
198 func: Arc::new(state_func.into()),
199 params: aggr_func.params.clone(),
200 };
201 let expr = Expr::AggregateFunction(expr);
202 let lower_state_output_col_name = expr.schema_name().to_string();
203
204 lower_aggr_exprs.push(expr);
205
206 let (original_phy_expr, _filter, _ordering) = create_aggregate_expr_and_maybe_filter(
208 aggr_expr,
209 aggr.input.schema(),
210 aggr.input.schema().as_arrow(),
211 &Default::default(),
212 )?;
213
214 let merge_func = MergeWrapper::new(
215 (*aggr_func.func).clone(),
216 original_phy_expr,
217 original_input_fields,
218 )?;
219 let arg = Expr::Column(Column::new_unqualified(lower_state_output_col_name));
220 let expr = AggregateFunction {
221 func: Arc::new(merge_func.into()),
222 params: AggregateFunctionParams {
226 args: vec![arg],
227 distinct: aggr_func.params.distinct,
228 filter: None,
229 order_by: vec![],
230 null_treatment: aggr_func.params.null_treatment,
231 },
232 };
233
234 let expr = Expr::AggregateFunction(expr).alias(aggr_expr.schema_name().to_string());
237 upper_aggr_exprs.push(expr);
238 }
239
240 let mut lower = aggr.clone();
241 lower.aggr_expr = lower_aggr_exprs;
242 let lower_plan = LogicalPlan::Aggregate(lower);
243
244 let lower_plan = lower_plan.recompute_schema()?;
246
247 let fixed_lower_plan =
250 FixStateUdafOrderingAnalyzer.analyze(lower_plan, &Default::default())?;
251
252 let upper = Aggregate::try_new(
253 Arc::new(fixed_lower_plan.clone()),
254 upper_group_exprs,
255 upper_aggr_exprs.clone(),
256 )?;
257 let aggr_plan = LogicalPlan::Aggregate(aggr);
258
259 let upper_check = upper;
261 let upper_plan = LogicalPlan::Aggregate(upper_check).recompute_schema()?;
262 if *upper_plan.schema() != *aggr_plan.schema() {
263 return Err(datafusion_common::DataFusionError::Internal(format!(
264 "Upper aggregate plan's schema is not the same as the original aggregate plan's schema: \n[transformed]:{}\n[original]:{}",
265 upper_plan.schema(),
266 aggr_plan.schema()
267 )));
268 }
269
270 Ok(StepAggrPlan {
271 lower_state: fixed_lower_plan,
272 upper_merge: upper_plan,
273 })
274 }
275}
276
277#[derive(Debug, Clone, PartialEq, Eq, Hash)]
279pub struct StateWrapper {
280 inner: AggregateUDF,
281 name: String,
282 ordering: Vec<FieldRef>,
284 distinct: bool,
286}
287
288impl StateWrapper {
289 pub fn new(inner: AggregateUDF) -> datafusion_common::Result<Self> {
291 let name = aggr_state_func_name(inner.name());
292 Ok(Self {
293 inner,
294 name,
295 ordering: vec![],
296 distinct: false,
297 })
298 }
299
300 pub fn inner(&self) -> &AggregateUDF {
301 &self.inner
302 }
303
304 pub fn deduce_aggr_return_type(
308 &self,
309 acc_args: &datafusion_expr::function::AccumulatorArgs,
310 ) -> datafusion_common::Result<FieldRef> {
311 let input_fields = acc_args
312 .exprs
313 .iter()
314 .map(|e| e.return_field(acc_args.schema))
315 .collect::<Result<Vec<_>, _>>()?;
316 self.inner.return_field(&input_fields).inspect_err(|e| {
317 common_telemetry::error!(
318 "StateWrapper: {:#?}\nacc_args:{:?}\nerror:{:?}",
319 &self,
320 &acc_args,
321 e
322 );
323 })
324 }
325
326 fn fix_inner_acc_args<'b>(
327 &self,
328 mut acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
329 ) -> datafusion_common::Result<datafusion_expr::function::AccumulatorArgs<'b>> {
330 acc_args.return_field = self.deduce_aggr_return_type(&acc_args)?;
331 Ok(acc_args)
332 }
333}
334
335impl AggregateUDFImpl for StateWrapper {
336 fn accumulator<'a, 'b>(
337 &'a self,
338 acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
339 ) -> datafusion_common::Result<Box<dyn Accumulator>> {
340 let state_type = acc_args.return_type().clone();
342 let inner = self.inner.accumulator(self.fix_inner_acc_args(acc_args)?)?;
343
344 Ok(Box::new(StateAccum::new(inner, state_type)?))
345 }
346
347 fn groups_accumulator_supported(
348 &self,
349 acc_args: datafusion_expr::function::AccumulatorArgs,
350 ) -> bool {
351 self.fix_inner_acc_args(acc_args)
352 .map(|args| self.inner.inner().groups_accumulator_supported(args))
353 .unwrap_or(false)
354 }
355
356 fn create_groups_accumulator(
357 &self,
358 acc_args: datafusion_expr::function::AccumulatorArgs,
359 ) -> datafusion_common::Result<Box<dyn GroupsAccumulator>> {
360 let state_type = acc_args.return_type().clone();
361 let inner = self
362 .inner
363 .inner()
364 .create_groups_accumulator(self.fix_inner_acc_args(acc_args)?)?;
365 Ok(Box::new(StateGroupsAccum::new(inner, state_type)?))
366 }
367
368 fn as_any(&self) -> &dyn std::any::Any {
369 self
370 }
371 fn name(&self) -> &str {
372 self.name.as_str()
373 }
374
375 fn is_nullable(&self) -> bool {
376 self.inner.is_nullable()
377 }
378
379 fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
382 let input_fields = &arg_types
383 .iter()
384 .map(|x| Arc::new(Field::new("x", x.clone(), false)))
385 .collect::<Vec<_>>();
386
387 let state_fields_args = StateFieldsArgs {
388 name: self.inner().name(),
389 input_fields,
390 return_field: self.inner.return_field(input_fields)?,
391 ordering_fields: &self.ordering,
393 is_distinct: self.distinct,
394 };
395 let state_fields = self.inner.state_fields(state_fields_args)?;
396
397 let state_fields = state_fields
398 .into_iter()
399 .map(|f| {
400 let mut f = f.as_ref().clone();
401 f.set_nullable(true);
403 Arc::new(f)
404 })
405 .collect::<Vec<_>>();
406
407 let struct_field = DataType::Struct(state_fields.into());
408 Ok(struct_field)
409 }
410
411 fn state_fields(
413 &self,
414 args: datafusion_expr::function::StateFieldsArgs,
415 ) -> datafusion_common::Result<Vec<FieldRef>> {
416 let state_fields_args = StateFieldsArgs {
417 name: args.name,
418 input_fields: args.input_fields,
419 return_field: self.inner.return_field(args.input_fields)?,
420 ordering_fields: args.ordering_fields,
421 is_distinct: args.is_distinct,
422 };
423 self.inner.state_fields(state_fields_args)
424 }
425
426 fn signature(&self) -> &Signature {
428 self.inner.signature()
429 }
430
431 fn coerce_types(&self, arg_types: &[DataType]) -> datafusion_common::Result<Vec<DataType>> {
433 self.inner.coerce_types(arg_types)
434 }
435
436 fn value_from_stats(
437 &self,
438 statistics_args: &datafusion_expr::StatisticsArgs,
439 ) -> Option<ScalarValue> {
440 let inner = self.inner().inner().as_any();
441 let can_use_stat = inner.is::<Count>() || inner.is::<Max>() || inner.is::<Min>();
444 if !can_use_stat {
445 return None;
446 }
447
448 let state_type = if let DataType::Struct(fields) = &statistics_args.return_type {
450 if fields.is_empty() {
451 return None;
452 }
453 fields[0].data_type().clone()
454 } else {
455 return None;
456 };
457
458 let fixed_args = datafusion_expr::StatisticsArgs {
459 statistics: statistics_args.statistics,
460 return_type: &state_type,
461 is_distinct: statistics_args.is_distinct,
462 exprs: statistics_args.exprs,
463 };
464
465 let ret = self.inner().value_from_stats(&fixed_args)?;
466
467 let fields = if let DataType::Struct(fields) = &statistics_args.return_type {
469 fields
470 } else {
471 return None;
472 };
473
474 let array = ret.to_array().ok()?;
475
476 let struct_array = StructArray::new(fields.clone(), vec![array], None);
477 let ret = ScalarValue::Struct(Arc::new(struct_array));
478 Some(ret)
479 }
480}
481
482#[derive(Debug)]
485pub struct StateAccum {
486 inner: Box<dyn Accumulator>,
487 state_fields: Fields,
488}
489
490pub struct StateGroupsAccum {
491 inner: Box<dyn GroupsAccumulator>,
492 state_fields: Fields,
493}
494
495impl StateGroupsAccum {
496 fn new(
497 inner: Box<dyn GroupsAccumulator>,
498 state_type: DataType,
499 ) -> datafusion_common::Result<Self> {
500 let DataType::Struct(fields) = state_type else {
501 return Err(datafusion_common::DataFusionError::Internal(format!(
502 "Expected a struct type for state, got: {:?}",
503 state_type
504 )));
505 };
506 Ok(Self {
507 inner,
508 state_fields: fields,
509 })
510 }
511
512 fn wrap_state_arrays(&self, arrays: Vec<ArrayRef>) -> datafusion_common::Result<ArrayRef> {
513 let array_type = arrays
514 .iter()
515 .map(|array| array.data_type().clone())
516 .collect::<Vec<_>>();
517 let expected_type = self
518 .state_fields
519 .iter()
520 .map(|field| field.data_type().clone())
521 .collect::<Vec<_>>();
522 if array_type != expected_type {
523 debug!(
524 "State mismatch, expected: {}, got: {} for expected fields: {:?} and given array types: {:?}",
525 self.state_fields.len(),
526 arrays.len(),
527 self.state_fields,
528 array_type,
529 );
530 let guess_schema = arrays
531 .iter()
532 .enumerate()
533 .map(|(index, array)| {
534 Field::new(
535 format!("col_{index}[mismatch_state]").as_str(),
536 array.data_type().clone(),
537 true,
538 )
539 })
540 .collect::<Fields>();
541 let array = StructArray::try_new(guess_schema, arrays, None)?;
542 return Ok(Arc::new(array));
543 }
544
545 Ok(Arc::new(StructArray::try_new(
546 self.state_fields.clone(),
547 arrays,
548 None,
549 )?))
550 }
551}
552
553impl GroupsAccumulator for StateGroupsAccum {
554 fn update_batch(
555 &mut self,
556 values: &[ArrayRef],
557 group_indices: &[usize],
558 opt_filter: Option<&BooleanArray>,
559 total_num_groups: usize,
560 ) -> datafusion_common::Result<()> {
561 self.inner
562 .update_batch(values, group_indices, opt_filter, total_num_groups)
563 }
564
565 fn merge_batch(
566 &mut self,
567 values: &[ArrayRef],
568 group_indices: &[usize],
569 opt_filter: Option<&BooleanArray>,
570 total_num_groups: usize,
571 ) -> datafusion_common::Result<()> {
572 self.inner
573 .merge_batch(values, group_indices, opt_filter, total_num_groups)
574 }
575
576 fn evaluate(&mut self, emit_to: EmitTo) -> datafusion_common::Result<ArrayRef> {
577 let state = self.inner.state(emit_to)?;
578 self.wrap_state_arrays(state)
579 }
580
581 fn state(&mut self, emit_to: EmitTo) -> datafusion_common::Result<Vec<ArrayRef>> {
582 self.inner.state(emit_to)
583 }
584
585 fn convert_to_state(
586 &self,
587 values: &[ArrayRef],
588 opt_filter: Option<&BooleanArray>,
589 ) -> datafusion_common::Result<Vec<ArrayRef>> {
590 self.inner.convert_to_state(values, opt_filter)
591 }
592
593 fn supports_convert_to_state(&self) -> bool {
594 self.inner.supports_convert_to_state()
595 }
596
597 fn size(&self) -> usize {
598 self.inner.size()
599 }
600}
601
602impl StateAccum {
603 pub fn new(
604 inner: Box<dyn Accumulator>,
605 state_type: DataType,
606 ) -> datafusion_common::Result<Self> {
607 let DataType::Struct(fields) = state_type else {
608 return Err(datafusion_common::DataFusionError::Internal(format!(
609 "Expected a struct type for state, got: {:?}",
610 state_type
611 )));
612 };
613 Ok(Self {
614 inner,
615 state_fields: fields,
616 })
617 }
618}
619
620impl Accumulator for StateAccum {
621 fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
622 let state = self.inner.state()?;
623
624 let array = state
625 .iter()
626 .map(|s| s.to_array())
627 .collect::<Result<Vec<_>, _>>()?;
628 let array_type = array
629 .iter()
630 .map(|a| a.data_type().clone())
631 .collect::<Vec<_>>();
632 let expected_type: Vec<_> = self
633 .state_fields
634 .iter()
635 .map(|f| f.data_type().clone())
636 .collect();
637 if array_type != expected_type {
638 debug!(
639 "State mismatch, expected: {}, got: {} for expected fields: {:?} and given array types: {:?}",
640 self.state_fields.len(),
641 array.len(),
642 self.state_fields,
643 array_type,
644 );
645 let guess_schema = array
646 .iter()
647 .enumerate()
648 .map(|(index, array)| {
649 Field::new(
650 format!("col_{index}[mismatch_state]").as_str(),
651 array.data_type().clone(),
652 true,
653 )
654 })
655 .collect::<Fields>();
656 let arr = StructArray::try_new(guess_schema, array, None)?;
657
658 return Ok(ScalarValue::Struct(Arc::new(arr)));
659 }
660
661 let struct_array = StructArray::try_new(self.state_fields.clone(), array, None)?;
662 Ok(ScalarValue::Struct(Arc::new(struct_array)))
663 }
664
665 fn merge_batch(
666 &mut self,
667 states: &[datatypes::arrow::array::ArrayRef],
668 ) -> datafusion_common::Result<()> {
669 self.inner.merge_batch(states)
670 }
671
672 fn update_batch(
673 &mut self,
674 values: &[datatypes::arrow::array::ArrayRef],
675 ) -> datafusion_common::Result<()> {
676 self.inner.update_batch(values)
677 }
678
679 fn size(&self) -> usize {
680 self.inner.size()
681 }
682
683 fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
684 self.inner.state()
685 }
686}
687
688#[derive(Debug, Clone)]
693pub struct MergeWrapper {
694 inner: AggregateUDF,
695 name: String,
696 merge_signature: Signature,
697 original_phy_expr: Arc<AggregateFunctionExpr>,
699 return_field: FieldRef,
700}
701impl MergeWrapper {
702 pub fn new(
703 inner: AggregateUDF,
704 original_phy_expr: Arc<AggregateFunctionExpr>,
705 original_input_fields: Vec<FieldRef>,
706 ) -> datafusion_common::Result<Self> {
707 let name = aggr_merge_func_name(inner.name());
708 let merge_signature = Signature::user_defined(datafusion_expr::Volatility::Immutable);
710 let return_field = inner.return_field(&original_input_fields)?.clone();
711
712 Ok(Self {
713 inner,
714 name,
715 merge_signature,
716 original_phy_expr,
717 return_field,
718 })
719 }
720
721 pub fn inner(&self) -> &AggregateUDF {
722 &self.inner
723 }
724}
725
726impl AggregateUDFImpl for MergeWrapper {
727 fn accumulator<'a, 'b>(
728 &'a self,
729 acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
730 ) -> datafusion_common::Result<Box<dyn Accumulator>> {
731 if acc_args.exprs.len() != 1
732 || !matches!(
733 acc_args.exprs[0].data_type(acc_args.schema)?,
734 DataType::Struct(_)
735 )
736 {
737 return Err(datafusion_common::DataFusionError::Internal(format!(
738 "Expected one struct type as input, got: {:?}",
739 acc_args.schema
740 )));
741 }
742 let input_type = acc_args.exprs[0].data_type(acc_args.schema)?;
743 let DataType::Struct(fields) = input_type else {
744 return Err(datafusion_common::DataFusionError::Internal(format!(
745 "Expected a struct type for input, got: {:?}",
746 input_type
747 )));
748 };
749
750 let inner_accum = self.original_phy_expr.create_accumulator()?;
751 Ok(Box::new(MergeAccum::new(inner_accum, &fields)))
752 }
753
754 fn as_any(&self) -> &dyn std::any::Any {
755 self
756 }
757 fn name(&self) -> &str {
758 self.name.as_str()
759 }
760
761 fn is_nullable(&self) -> bool {
762 self.inner.is_nullable()
763 }
764
765 fn return_type(&self, _arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
768 Ok(self.return_field.data_type().clone())
770 }
771
772 fn return_field(&self, _arg_fields: &[FieldRef]) -> datafusion_common::Result<FieldRef> {
774 Ok(self.return_field.clone())
775 }
776
777 fn signature(&self) -> &Signature {
778 &self.merge_signature
779 }
780
781 fn coerce_types(&self, arg_types: &[DataType]) -> datafusion_common::Result<Vec<DataType>> {
783 if arg_types.len() != 1 || !matches!(arg_types.first(), Some(DataType::Struct(_))) {
785 return Err(datafusion_common::DataFusionError::Internal(format!(
786 "Expected one struct type as input, got: {:?}",
787 arg_types
788 )));
789 }
790 Ok(arg_types.to_vec())
791 }
792
793 fn state_fields(
795 &self,
796 _args: datafusion_expr::function::StateFieldsArgs,
797 ) -> datafusion_common::Result<Vec<FieldRef>> {
798 self.original_phy_expr.state_fields()
799 }
800}
801
802impl PartialEq for MergeWrapper {
803 fn eq(&self, other: &Self) -> bool {
804 self.inner == other.inner
805 }
806}
807
808impl Eq for MergeWrapper {}
809
810impl Hash for MergeWrapper {
811 fn hash<H: Hasher>(&self, state: &mut H) {
812 self.inner.hash(state);
813 }
814}
815
816#[derive(Debug)]
820pub struct MergeAccum {
821 inner: Box<dyn Accumulator>,
822 state_fields: Fields,
823}
824
825impl MergeAccum {
826 pub fn new(inner: Box<dyn Accumulator>, state_fields: &Fields) -> Self {
827 Self {
828 inner,
829 state_fields: state_fields.clone(),
830 }
831 }
832}
833
834impl Accumulator for MergeAccum {
835 fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
836 self.inner.evaluate()
837 }
838
839 fn merge_batch(&mut self, states: &[arrow::array::ArrayRef]) -> datafusion_common::Result<()> {
840 self.inner.merge_batch(states)
841 }
842
843 fn update_batch(&mut self, values: &[arrow::array::ArrayRef]) -> datafusion_common::Result<()> {
844 let value = values.first().ok_or_else(|| {
845 datafusion_common::DataFusionError::Internal("No values provided for merge".to_string())
846 })?;
847 let struct_arr = value
849 .as_any()
850 .downcast_ref::<StructArray>()
851 .ok_or_else(|| {
852 datafusion_common::DataFusionError::Internal(format!(
853 "Expected StructArray, got: {:?}",
854 value.data_type()
855 ))
856 })?;
857 let fields = struct_arr.fields();
858 if fields != &self.state_fields {
859 debug!(
860 "State fields mismatch, expected: {:?}, got: {:?}",
861 self.state_fields, fields
862 );
863 }
865
866 let state_columns = struct_arr.columns();
869 self.inner.merge_batch(state_columns)
870 }
871
872 fn size(&self) -> usize {
873 self.inner.size()
874 }
875
876 fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
877 self.inner.state()
878 }
879}