common_function/aggrs/aggr_wrapper/
fix_order.rs1use std::sync::Arc;
16
17use common_telemetry::debug;
18use datafusion::config::ConfigOptions;
19use datafusion::optimizer::AnalyzerRule;
20use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
21use datafusion_expr::{AggregateUDF, Expr, ExprSchemable, LogicalPlan};
22
23use crate::aggrs::aggr_wrapper::StateWrapper;
24
25#[derive(Debug, Default)]
31pub struct FixStateUdafOrderingAnalyzer;
32
33impl AnalyzerRule for FixStateUdafOrderingAnalyzer {
34    fn name(&self) -> &str {
35        "FixStateUdafOrderingAnalyzer"
36    }
37
38    fn analyze(
39        &self,
40        plan: LogicalPlan,
41        _config: &ConfigOptions,
42    ) -> datafusion_common::Result<LogicalPlan> {
43        plan.rewrite_with_subqueries(&mut FixOrderingRewriter::new(true))
44            .map(|t| t.data)
45    }
46}
47
48#[derive(Debug, Default)]
52pub struct UnFixStateUdafOrderingAnalyzer;
53
54impl AnalyzerRule for UnFixStateUdafOrderingAnalyzer {
55    fn name(&self) -> &str {
56        "UnFixStateUdafOrderingAnalyzer"
57    }
58
59    fn analyze(
60        &self,
61        plan: LogicalPlan,
62        _config: &ConfigOptions,
63    ) -> datafusion_common::Result<LogicalPlan> {
64        plan.rewrite_with_subqueries(&mut FixOrderingRewriter::new(false))
65            .map(|t| t.data)
66    }
67}
68
/// Plan rewriter shared by the "fix" and "unfix" analyzer rules.
struct FixOrderingRewriter {
    /// Becomes `true` as soon as any aggregate expression has been rewritten.
    /// It is never reset, so every node visited afterwards gets its schema
    /// recomputed.
    is_dirty: bool,
    /// `true` to copy ordering/distinct into the state wrapper, `false` to
    /// clear them.
    is_fix: bool,
}

impl FixOrderingRewriter {
    /// Creates a rewriter in "fix" (`is_fix == true`) or "unfix" mode that has
    /// not modified anything yet.
    pub fn new(is_fix: bool) -> Self {
        Self {
            is_fix,
            is_dirty: false,
        }
    }
}
85
86impl TreeNodeRewriter for FixOrderingRewriter {
87    type Node = LogicalPlan;
88
89    fn f_up(
92        &mut self,
93        node: Self::Node,
94    ) -> datafusion_common::Result<datafusion_common::tree_node::Transformed<Self::Node>> {
95        let LogicalPlan::Aggregate(mut aggregate) = node else {
96            return if self.is_dirty {
97                let node = node.recompute_schema()?;
98                Ok(Transformed::yes(node))
99            } else {
100                Ok(Transformed::no(node))
101            };
102        };
103
104        for aggr_expr in &mut aggregate.aggr_expr {
106            let new_aggr_expr = aggr_expr
107                .clone()
108                .transform_up(|expr| rewrite_expr(expr, &aggregate.input, self.is_fix))?;
109
110            if new_aggr_expr.transformed {
111                *aggr_expr = new_aggr_expr.data;
112                self.is_dirty = true;
113            }
114        }
115
116        if self.is_dirty {
117            let node = LogicalPlan::Aggregate(aggregate).recompute_schema()?;
118            debug!(
119                "FixStateUdafOrderingAnalyzer: plan schema's field changed to {:?}",
120                node.schema().fields()
121            );
122
123            Ok(Transformed::yes(node))
124        } else {
125            Ok(Transformed::no(LogicalPlan::Aggregate(aggregate)))
126        }
127    }
128}
129
130fn rewrite_expr(
136    expr: Expr,
137    aggregate_input: &Arc<LogicalPlan>,
138    is_fix: bool,
139) -> Result<Transformed<Expr>, datafusion_common::DataFusionError> {
140    let Expr::AggregateFunction(aggregate_function) = expr else {
141        return Ok(Transformed::no(expr));
142    };
143
144    let Some(old_state_wrapper) = aggregate_function
145        .func
146        .inner()
147        .as_any()
148        .downcast_ref::<StateWrapper>()
149    else {
150        return Ok(Transformed::no(Expr::AggregateFunction(aggregate_function)));
151    };
152
153    let mut state_wrapper = old_state_wrapper.clone();
154    if is_fix {
155        let order_by = aggregate_function.params.order_by.clone();
157        let ordering_fields: Vec<_> = order_by
158            .iter()
159            .map(|sort_expr| {
160                sort_expr
161                    .expr
162                    .to_field(&aggregate_input.schema())
163                    .map(|(_, f)| f)
164            })
165            .collect::<datafusion_common::Result<Vec<_>>>()?;
166        let distinct = aggregate_function.params.distinct;
167
168        state_wrapper.ordering = ordering_fields;
170        state_wrapper.distinct = distinct;
171    } else {
172        state_wrapper.ordering = vec![];
174        state_wrapper.distinct = false;
175    }
176
177    debug!(
178        "FixStateUdafOrderingAnalyzer: fix state udaf from {old_state_wrapper:?} to {:?}",
179        state_wrapper
180    );
181
182    let mut aggregate_function = aggregate_function;
183
184    aggregate_function.func = Arc::new(AggregateUDF::new_from_impl(state_wrapper));
185
186    Ok(Transformed::yes(Expr::AggregateFunction(
187        aggregate_function,
188    )))
189}