common_function/aggrs/aggr_wrapper/
fix_order.rs1use std::sync::Arc;
16
17use common_telemetry::debug;
18use datafusion::config::ConfigOptions;
19use datafusion::optimizer::AnalyzerRule;
20use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
21use datafusion_expr::{AggregateUDF, Expr, ExprSchemable, LogicalPlan};
22
23use crate::aggrs::aggr_wrapper::StateWrapper;
24
25#[derive(Debug, Default)]
31pub struct FixStateUdafOrderingAnalyzer;
32
33impl AnalyzerRule for FixStateUdafOrderingAnalyzer {
34 fn name(&self) -> &str {
35 "FixStateUdafOrderingAnalyzer"
36 }
37
38 fn analyze(
39 &self,
40 plan: LogicalPlan,
41 _config: &ConfigOptions,
42 ) -> datafusion_common::Result<LogicalPlan> {
43 plan.rewrite_with_subqueries(&mut FixOrderingRewriter::new(true))
44 .map(|t| t.data)
45 }
46}
47
48#[derive(Debug, Default)]
52pub struct UnFixStateUdafOrderingAnalyzer;
53
54impl AnalyzerRule for UnFixStateUdafOrderingAnalyzer {
55 fn name(&self) -> &str {
56 "UnFixStateUdafOrderingAnalyzer"
57 }
58
59 fn analyze(
60 &self,
61 plan: LogicalPlan,
62 _config: &ConfigOptions,
63 ) -> datafusion_common::Result<LogicalPlan> {
64 plan.rewrite_with_subqueries(&mut FixOrderingRewriter::new(false))
65 .map(|t| t.data)
66 }
67}
68
/// Bottom-up plan rewriter used by both the fix and unfix analyzer rules. It rewrites
/// `StateWrapper` aggregate calls and recomputes schemas affected by those rewrites.
struct FixOrderingRewriter {
    /// Becomes `true` once any aggregate expression has been rewritten. From then on every
    /// node visited gets its schema recomputed. It is never reset.
    is_dirty: bool,
    /// `true`: take ordering fields and `DISTINCT` from each aggregate call site.
    /// `false`: clear them.
    is_fix: bool,
}

impl FixOrderingRewriter {
    /// Creates a rewriter in fix (`is_fix == true`) or unfix (`is_fix == false`) mode,
    /// with no rewrites recorded yet.
    pub fn new(is_fix: bool) -> Self {
        Self {
            is_fix,
            is_dirty: false,
        }
    }
}
85
86impl TreeNodeRewriter for FixOrderingRewriter {
87 type Node = LogicalPlan;
88
89 fn f_up(
92 &mut self,
93 node: Self::Node,
94 ) -> datafusion_common::Result<datafusion_common::tree_node::Transformed<Self::Node>> {
95 let LogicalPlan::Aggregate(mut aggregate) = node else {
96 return if self.is_dirty {
97 let node = node.recompute_schema()?;
98 Ok(Transformed::yes(node))
99 } else {
100 Ok(Transformed::no(node))
101 };
102 };
103
104 for aggr_expr in &mut aggregate.aggr_expr {
106 let new_aggr_expr = aggr_expr
107 .clone()
108 .transform_up(|expr| rewrite_expr(expr, &aggregate.input, self.is_fix))?;
109
110 if new_aggr_expr.transformed {
111 *aggr_expr = new_aggr_expr.data;
112 self.is_dirty = true;
113 }
114 }
115
116 if self.is_dirty {
117 let node = LogicalPlan::Aggregate(aggregate).recompute_schema()?;
118 debug!(
119 "FixStateUdafOrderingAnalyzer: plan schema's field changed to {:?}",
120 node.schema().fields()
121 );
122
123 Ok(Transformed::yes(node))
124 } else {
125 Ok(Transformed::no(LogicalPlan::Aggregate(aggregate)))
126 }
127 }
128}
129
130fn rewrite_expr(
136 expr: Expr,
137 aggregate_input: &Arc<LogicalPlan>,
138 is_fix: bool,
139) -> Result<Transformed<Expr>, datafusion_common::DataFusionError> {
140 let Expr::AggregateFunction(aggregate_function) = expr else {
141 return Ok(Transformed::no(expr));
142 };
143
144 let Some(old_state_wrapper) = aggregate_function
145 .func
146 .inner()
147 .as_any()
148 .downcast_ref::<StateWrapper>()
149 else {
150 return Ok(Transformed::no(Expr::AggregateFunction(aggregate_function)));
151 };
152
153 let mut state_wrapper = old_state_wrapper.clone();
154 if is_fix {
155 let order_by = aggregate_function.params.order_by.clone();
157 let ordering_fields: Vec<_> = order_by
158 .iter()
159 .map(|sort_expr| {
160 sort_expr
161 .expr
162 .to_field(&aggregate_input.schema())
163 .map(|(_, f)| f)
164 })
165 .collect::<datafusion_common::Result<Vec<_>>>()?;
166 let distinct = aggregate_function.params.distinct;
167
168 state_wrapper.ordering = ordering_fields;
170 state_wrapper.distinct = distinct;
171 } else {
172 state_wrapper.ordering = vec![];
174 state_wrapper.distinct = false;
175 }
176
177 debug!(
178 "FixStateUdafOrderingAnalyzer: fix state udaf from {old_state_wrapper:?} to {:?}",
179 state_wrapper
180 );
181
182 let mut aggregate_function = aggregate_function;
183
184 aggregate_function.func = Arc::new(AggregateUDF::new_from_impl(state_wrapper));
185
186 Ok(Transformed::yes(Expr::AggregateFunction(
187 aggregate_function,
188 )))
189}