// query/optimizer/count_nest_aggr.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::collections::HashSet;
16use std::sync::Arc;
17
18use datafusion::config::ConfigOptions;
19use datafusion::functions_aggregate::count::count_udaf;
20use datafusion::logical_expr::{Extension, LogicalPlan, LogicalPlanBuilder, Sort};
21use datafusion_common::Result;
22use datafusion_common::tree_node::{Transformed, TreeNode};
23use datafusion_expr::{Expr, UserDefinedLogicalNodeCore, lit};
24use promql::extension_plan::{InstantManipulate, SeriesDivide, SeriesNormalize};
25use store_api::metric_engine_consts::DATA_SCHEMA_TSID_COLUMN_NAME;
26
27use crate::QueryEngineContext;
28use crate::optimizer::ExtensionAnalyzerRule;
29
30/// Rewrites `count(<presence-preserving-agg>(<vector_selector>) by (...))` into a presence-based
31/// group count.
32///
33/// This stays intentionally narrow:
34/// - the outer aggregate must be plain `count`
35/// - the inner aggregate must be a plain aggregate whose result existence is equivalent to input
36///   group existence
37/// - the inner input must be the direct instant-vector-selector plan
38/// - the outer count must only group by the evaluation timestamp
39#[derive(Debug)]
40pub struct CountNestAggrRule;
41
42impl ExtensionAnalyzerRule for CountNestAggrRule {
43    fn analyze(
44        &self,
45        plan: LogicalPlan,
46        _ctx: &QueryEngineContext,
47        _config: &ConfigOptions,
48    ) -> Result<LogicalPlan> {
49        plan.transform_down(&Self::rewrite_plan).map(|x| x.data)
50    }
51}
52
53impl CountNestAggrRule {
54    fn rewrite_plan(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
55        let LogicalPlan::Sort(sort) = plan else {
56            return Ok(Transformed::no(plan));
57        };
58
59        if let Some(rewritten) = Self::try_rewrite_sort(&sort)? {
60            Ok(Transformed::yes(rewritten))
61        } else {
62            Ok(Transformed::no(LogicalPlan::Sort(sort)))
63        }
64    }
65
66    fn try_rewrite_sort(sort: &Sort) -> Result<Option<LogicalPlan>> {
67        if sort.fetch.is_some() {
68            return Ok(None);
69        }
70
71        let LogicalPlan::Aggregate(outer_agg) = sort.input.as_ref() else {
72            return Ok(None);
73        };
74        if outer_agg.group_expr.len() != 1 || outer_agg.aggr_expr.len() != 1 {
75            return Ok(None);
76        }
77        let outer_time_expr = outer_agg.group_expr[0].clone();
78        let outer_count_arg =
79            match Self::aggregate_if(&outer_agg.aggr_expr[0], |name| name == "count") {
80                Some((_, arg)) => arg,
81                None => return Ok(None),
82            };
83
84        let LogicalPlan::Sort(inner_sort) = outer_agg.input.as_ref() else {
85            return Ok(None);
86        };
87        if inner_sort.fetch.is_some() {
88            return Ok(None);
89        }
90
91        let LogicalPlan::Aggregate(inner_agg) = inner_sort.input.as_ref() else {
92            return Ok(None);
93        };
94        if inner_agg.aggr_expr.len() != 1 || inner_agg.group_expr.is_empty() {
95            return Ok(None);
96        }
97        let (inner_is_count, inner_value_expr) =
98            match Self::aggregate_if(&inner_agg.aggr_expr[0], |name| {
99                Self::is_supported_inner_aggregate(name)
100            }) {
101                Some((name, arg)) => (name == "count", arg),
102                None => return Ok(None),
103            };
104        let Expr::Column(_) = inner_value_expr else {
105            return Ok(None);
106        };
107
108        let Expr::Column(outer_count_column) = outer_count_arg else {
109            return Ok(None);
110        };
111        let inner_output_field = inner_agg.schema.field(inner_agg.group_expr.len());
112        if outer_count_column.name != *inner_output_field.name() {
113            return Ok(None);
114        }
115
116        if !Self::is_projection_chain_to_instant(inner_agg.input.as_ref()) {
117            return Ok(None);
118        }
119
120        if !inner_agg
121            .group_expr
122            .iter()
123            .all(|expr| matches!(expr, Expr::Column(_)))
124        {
125            return Ok(None);
126        }
127
128        let Some(time_expr_pos) = inner_agg
129            .group_expr
130            .iter()
131            .position(|expr| expr == &outer_time_expr)
132        else {
133            return Ok(None);
134        };
135
136        let mut presence_group_exprs = Vec::with_capacity(inner_agg.group_expr.len());
137        presence_group_exprs.push(outer_time_expr.clone());
138        presence_group_exprs.extend(
139            inner_agg
140                .group_expr
141                .iter()
142                .enumerate()
143                .filter(|(idx, _)| *idx != time_expr_pos)
144                .map(|(_, expr)| expr.clone()),
145        );
146
147        let mut required_input_columns =
148            Self::collect_required_input_columns(&presence_group_exprs, inner_value_expr);
149        required_input_columns.extend(Self::collect_required_instant_columns(
150            inner_agg.input.as_ref(),
151        ));
152        let presence_source = Self::rebuild_projection_chain_to_instant(
153            inner_agg.input.as_ref(),
154            &required_input_columns,
155        )?;
156
157        let outer_value_name = outer_agg
158            .schema
159            .field(outer_agg.group_expr.len())
160            .name()
161            .clone();
162        let mut presence_input = LogicalPlanBuilder::from(presence_source);
163        if !inner_is_count {
164            presence_input = presence_input.filter(inner_value_expr.clone().is_not_null())?;
165        }
166        let presence_input = presence_input
167            .project(presence_group_exprs.clone())?
168            .distinct()?
169            .build()?;
170
171        let rewritten = LogicalPlanBuilder::from(presence_input)
172            .aggregate(
173                outer_agg.group_expr.clone(),
174                vec![count_udaf().call(vec![lit(1_i64)]).alias(outer_value_name)],
175            )?
176            .sort(sort.expr.clone())?
177            .build()?;
178
179        Ok(Some(rewritten))
180    }
181
182    fn collect_required_input_columns(group_exprs: &[Expr], value_expr: &Expr) -> HashSet<String> {
183        let mut required = HashSet::new();
184
185        for expr in group_exprs {
186            if let Expr::Column(column) = expr {
187                required.insert(column.name.clone());
188            }
189        }
190        if let Expr::Column(column) = value_expr {
191            // Keep the value column in the pruned instant input so `InstantManipulate`
192            // can still perform stale-NaN filtering before we project down to keys.
193            required.insert(column.name.clone());
194        }
195
196        required
197    }
198
199    fn collect_required_instant_columns(plan: &LogicalPlan) -> HashSet<String> {
200        let mut required = HashSet::new();
201        Self::collect_required_instant_columns_into(plan, &mut required);
202        required
203    }
204
205    fn collect_required_instant_columns_into(plan: &LogicalPlan, required: &mut HashSet<String>) {
206        match plan {
207            LogicalPlan::Projection(projection) => {
208                Self::collect_required_instant_columns_into(projection.input.as_ref(), required);
209            }
210            LogicalPlan::Extension(extension) => {
211                for expr in extension.node.expressions() {
212                    if let Expr::Column(column) = expr {
213                        required.insert(column.name);
214                    }
215                }
216
217                if extension.node.as_any().is::<SeriesDivide>()
218                    && extension.node.inputs()[0]
219                        .schema()
220                        .fields()
221                        .iter()
222                        .any(|field| field.name() == DATA_SCHEMA_TSID_COLUMN_NAME)
223                {
224                    required.insert(DATA_SCHEMA_TSID_COLUMN_NAME.to_string());
225                }
226
227                if let Some(input) = extension.node.inputs().into_iter().next() {
228                    Self::collect_required_instant_columns_into(input, required);
229                }
230            }
231            _ => {}
232        }
233    }
234
235    fn aggregate_if<F>(expr: &Expr, accept_name: F) -> Option<(&str, &Expr)>
236    where
237        F: FnOnce(&str) -> bool,
238    {
239        let Expr::AggregateFunction(func) = expr else {
240            return None;
241        };
242        let name = func.func.name();
243        if !accept_name(name)
244            || func.params.filter.is_some()
245            || func.params.distinct
246            || !func.params.order_by.is_empty()
247            || func.params.args.len() != 1
248        {
249            return None;
250        }
251
252        Some((name, &func.params.args[0]))
253    }
254
255    fn is_supported_inner_aggregate(name: &str) -> bool {
256        matches!(
257            name,
258            "count" | "sum" | "avg" | "min" | "max" | "stddev_pop" | "var_pop"
259        )
260    }
261
262    fn is_projection_chain_to_instant(plan: &LogicalPlan) -> bool {
263        let mut current = plan;
264        loop {
265            match current {
266                LogicalPlan::Projection(projection) => current = projection.input.as_ref(),
267                LogicalPlan::Extension(ext) => {
268                    return ext.node.as_any().is::<InstantManipulate>();
269                }
270                _ => return false,
271            }
272        }
273    }
274
275    fn rebuild_projection_chain_to_instant(
276        plan: &LogicalPlan,
277        required_columns: &HashSet<String>,
278    ) -> Result<LogicalPlan> {
279        match plan {
280            LogicalPlan::Projection(projection) => {
281                let input = Self::rebuild_projection_chain_to_instant(
282                    projection.input.as_ref(),
283                    required_columns,
284                )?;
285                LogicalPlanBuilder::from(input)
286                    .project(projection.expr.clone())?
287                    .build()
288            }
289            LogicalPlan::Extension(extension) => {
290                if let Some(instant) = extension.node.as_any().downcast_ref::<InstantManipulate>() {
291                    let input =
292                        Self::prune_instant_input(extension.node.inputs()[0], required_columns)?;
293                    return Ok(LogicalPlan::Extension(Extension {
294                        node: Arc::new(instant.with_exprs_and_inputs(vec![], vec![input])?),
295                    }));
296                }
297
298                Ok(plan.clone())
299            }
300            _ => Ok(plan.clone()),
301        }
302    }
303
304    fn prune_instant_input(
305        plan: &LogicalPlan,
306        required_columns: &HashSet<String>,
307    ) -> Result<LogicalPlan> {
308        match plan {
309            LogicalPlan::Extension(extension) => {
310                if let Some(normalize) = extension.node.as_any().downcast_ref::<SeriesNormalize>() {
311                    let input =
312                        Self::prune_instant_input(extension.node.inputs()[0], required_columns)?;
313                    return Ok(LogicalPlan::Extension(Extension {
314                        node: Arc::new(normalize.with_exprs_and_inputs(vec![], vec![input])?),
315                    }));
316                }
317
318                if let Some(divide) = extension.node.as_any().downcast_ref::<SeriesDivide>() {
319                    let divide_input = extension.node.inputs()[0].clone();
320
321                    let projection_exprs = divide_input
322                        .schema()
323                        .fields()
324                        .iter()
325                        .filter(|field| required_columns.contains(field.name()))
326                        .map(|field| {
327                            Expr::Column(datafusion_common::Column::from_name(field.name().clone()))
328                        })
329                        .collect::<Vec<_>>();
330                    let projected_input = LogicalPlanBuilder::from(divide_input)
331                        .project(projection_exprs)?
332                        .build()?;
333
334                    return Ok(LogicalPlan::Extension(Extension {
335                        node: Arc::new(
336                            divide.with_exprs_and_inputs(vec![], vec![projected_input])?,
337                        ),
338                    }));
339                }
340
341                Ok(plan.clone())
342            }
343            _ => Ok(plan.clone()),
344        }
345    }
346}