common_function/aggrs/aggr_wrapper/
fix_order.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use common_telemetry::debug;
18use datafusion::config::ConfigOptions;
19use datafusion::optimizer::AnalyzerRule;
20use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
21use datafusion_expr::{AggregateUDF, Expr, ExprSchemable, LogicalPlan};
22
23use crate::aggrs::aggr_wrapper::StateWrapper;
24
25/// Traverse the plan, found all `__<aggr_name>_state` and fix their ordering fields
26/// if their input aggr is with order by, this is currently only useful for `first_value` and `last_value` udaf
27///
28/// should be applied to datanode's query engine
29/// TODO(discord9): proper way to extend substrait's serde ability to allow carry more info for custom udaf with more info
30#[derive(Debug, Default)]
31pub struct FixStateUdafOrderingAnalyzer;
32
33impl AnalyzerRule for FixStateUdafOrderingAnalyzer {
34    fn name(&self) -> &str {
35        "FixStateUdafOrderingAnalyzer"
36    }
37
38    fn analyze(
39        &self,
40        plan: LogicalPlan,
41        _config: &ConfigOptions,
42    ) -> datafusion_common::Result<LogicalPlan> {
43        plan.rewrite_with_subqueries(&mut FixOrderingRewriter::new(true))
44            .map(|t| t.data)
45    }
46}
47
48/// Traverse the plan, found all `__<aggr_name>_state` and remove their ordering fields
49/// this is currently only useful for `first_value` and `last_value` udaf when need to encode to substrait
50///
51#[derive(Debug, Default)]
52pub struct UnFixStateUdafOrderingAnalyzer;
53
54impl AnalyzerRule for UnFixStateUdafOrderingAnalyzer {
55    fn name(&self) -> &str {
56        "UnFixStateUdafOrderingAnalyzer"
57    }
58
59    fn analyze(
60        &self,
61        plan: LogicalPlan,
62        _config: &ConfigOptions,
63    ) -> datafusion_common::Result<LogicalPlan> {
64        plan.rewrite_with_subqueries(&mut FixOrderingRewriter::new(false))
65            .map(|t| t.data)
66    }
67}
68
/// Plan rewriter shared by the "fix" and "un-fix" analyzer rules.
struct FixOrderingRewriter {
    /// Set as soon as one aggregate expression has been rewritten; from then on the
    /// schema of every node visited afterwards is recomputed (bottom-up).
    is_dirty: bool,
    /// Direction of the rewrite:
    /// `true` adds the ordering fields taken from the outer aggregate expression,
    /// `false` removes the ordering fields.
    is_fix: bool,
}

impl FixOrderingRewriter {
    /// Creates a rewriter that has not modified anything yet.
    ///
    /// `is_fix` selects whether ordering fields are added (`true`) or removed (`false`).
    pub fn new(is_fix: bool) -> Self {
        let is_dirty = false;
        Self { is_dirty, is_fix }
    }
}
85
86impl TreeNodeRewriter for FixOrderingRewriter {
87    type Node = LogicalPlan;
88
89    /// found all `__<aggr_name>_state` and fix their ordering fields
90    /// if their input aggr is with order by
91    fn f_up(
92        &mut self,
93        node: Self::Node,
94    ) -> datafusion_common::Result<datafusion_common::tree_node::Transformed<Self::Node>> {
95        let LogicalPlan::Aggregate(mut aggregate) = node else {
96            return if self.is_dirty {
97                let node = node.recompute_schema()?;
98                Ok(Transformed::yes(node))
99            } else {
100                Ok(Transformed::no(node))
101            };
102        };
103
104        // regex to match state udaf name
105        for aggr_expr in &mut aggregate.aggr_expr {
106            let new_aggr_expr = aggr_expr
107                .clone()
108                .transform_up(|expr| rewrite_expr(expr, &aggregate.input, self.is_fix))?;
109
110            if new_aggr_expr.transformed {
111                *aggr_expr = new_aggr_expr.data;
112                self.is_dirty = true;
113            }
114        }
115
116        if self.is_dirty {
117            let node = LogicalPlan::Aggregate(aggregate).recompute_schema()?;
118            debug!(
119                "FixStateUdafOrderingAnalyzer: plan schema's field changed to {:?}",
120                node.schema().fields()
121            );
122
123            Ok(Transformed::yes(node))
124        } else {
125            Ok(Transformed::no(LogicalPlan::Aggregate(aggregate)))
126        }
127    }
128}
129
130/// first see the aggr node in expr
131/// as it could be nested aggr like alias(aggr(sort))
132/// if contained aggr expr have a order by, and the aggr name match the regex
133/// then we need to fix the ordering field of the state udaf
134/// to be the same as the aggr expr
135fn rewrite_expr(
136    expr: Expr,
137    aggregate_input: &Arc<LogicalPlan>,
138    is_fix: bool,
139) -> Result<Transformed<Expr>, datafusion_common::DataFusionError> {
140    let Expr::AggregateFunction(aggregate_function) = expr else {
141        return Ok(Transformed::no(expr));
142    };
143
144    let Some(old_state_wrapper) = aggregate_function
145        .func
146        .inner()
147        .as_any()
148        .downcast_ref::<StateWrapper>()
149    else {
150        return Ok(Transformed::no(Expr::AggregateFunction(aggregate_function)));
151    };
152
153    let mut state_wrapper = old_state_wrapper.clone();
154    if is_fix {
155        // then always fix the ordering field&distinct flag and more
156        let order_by = aggregate_function.params.order_by.clone();
157        let ordering_fields: Vec<_> = order_by
158            .iter()
159            .map(|sort_expr| {
160                sort_expr
161                    .expr
162                    .to_field(&aggregate_input.schema())
163                    .map(|(_, f)| f)
164            })
165            .collect::<datafusion_common::Result<Vec<_>>>()?;
166        let distinct = aggregate_function.params.distinct;
167
168        // fixing up
169        state_wrapper.ordering = ordering_fields;
170        state_wrapper.distinct = distinct;
171    } else {
172        // remove the ordering field & distinct flag
173        state_wrapper.ordering = vec![];
174        state_wrapper.distinct = false;
175    }
176
177    debug!(
178        "FixStateUdafOrderingAnalyzer: fix state udaf from {old_state_wrapper:?} to {:?}",
179        state_wrapper
180    );
181
182    let mut aggregate_function = aggregate_function;
183
184    aggregate_function.func = Arc::new(AggregateUDF::new_from_impl(state_wrapper));
185
186    Ok(Transformed::yes(Expr::AggregateFunction(
187        aggregate_function,
188    )))
189}