query/optimizer/
count_nest_aggr.rs1use std::collections::HashSet;
16use std::sync::Arc;
17
18use datafusion::config::ConfigOptions;
19use datafusion::functions_aggregate::count::count_udaf;
20use datafusion::logical_expr::{Extension, LogicalPlan, LogicalPlanBuilder, Sort};
21use datafusion_common::Result;
22use datafusion_common::tree_node::{Transformed, TreeNode};
23use datafusion_expr::{Expr, UserDefinedLogicalNodeCore, lit};
24use promql::extension_plan::{InstantManipulate, SeriesDivide, SeriesNormalize};
25use store_api::metric_engine_consts::DATA_SCHEMA_TSID_COLUMN_NAME;
26
27use crate::QueryEngineContext;
28use crate::optimizer::ExtensionAnalyzerRule;
29
30#[derive(Debug)]
40pub struct CountNestAggrRule;
41
42impl ExtensionAnalyzerRule for CountNestAggrRule {
43 fn analyze(
44 &self,
45 plan: LogicalPlan,
46 _ctx: &QueryEngineContext,
47 _config: &ConfigOptions,
48 ) -> Result<LogicalPlan> {
49 plan.transform_down(&Self::rewrite_plan).map(|x| x.data)
50 }
51}
52
53impl CountNestAggrRule {
54 fn rewrite_plan(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
55 let LogicalPlan::Sort(sort) = plan else {
56 return Ok(Transformed::no(plan));
57 };
58
59 if let Some(rewritten) = Self::try_rewrite_sort(&sort)? {
60 Ok(Transformed::yes(rewritten))
61 } else {
62 Ok(Transformed::no(LogicalPlan::Sort(sort)))
63 }
64 }
65
66 fn try_rewrite_sort(sort: &Sort) -> Result<Option<LogicalPlan>> {
67 if sort.fetch.is_some() {
68 return Ok(None);
69 }
70
71 let LogicalPlan::Aggregate(outer_agg) = sort.input.as_ref() else {
72 return Ok(None);
73 };
74 if outer_agg.group_expr.len() != 1 || outer_agg.aggr_expr.len() != 1 {
75 return Ok(None);
76 }
77 let outer_time_expr = outer_agg.group_expr[0].clone();
78 let outer_count_arg =
79 match Self::aggregate_if(&outer_agg.aggr_expr[0], |name| name == "count") {
80 Some((_, arg)) => arg,
81 None => return Ok(None),
82 };
83
84 let LogicalPlan::Sort(inner_sort) = outer_agg.input.as_ref() else {
85 return Ok(None);
86 };
87 if inner_sort.fetch.is_some() {
88 return Ok(None);
89 }
90
91 let LogicalPlan::Aggregate(inner_agg) = inner_sort.input.as_ref() else {
92 return Ok(None);
93 };
94 if inner_agg.aggr_expr.len() != 1 || inner_agg.group_expr.is_empty() {
95 return Ok(None);
96 }
97 let (inner_is_count, inner_value_expr) =
98 match Self::aggregate_if(&inner_agg.aggr_expr[0], |name| {
99 Self::is_supported_inner_aggregate(name)
100 }) {
101 Some((name, arg)) => (name == "count", arg),
102 None => return Ok(None),
103 };
104 let Expr::Column(_) = inner_value_expr else {
105 return Ok(None);
106 };
107
108 let Expr::Column(outer_count_column) = outer_count_arg else {
109 return Ok(None);
110 };
111 let inner_output_field = inner_agg.schema.field(inner_agg.group_expr.len());
112 if outer_count_column.name != *inner_output_field.name() {
113 return Ok(None);
114 }
115
116 if !Self::is_projection_chain_to_instant(inner_agg.input.as_ref()) {
117 return Ok(None);
118 }
119
120 if !inner_agg
121 .group_expr
122 .iter()
123 .all(|expr| matches!(expr, Expr::Column(_)))
124 {
125 return Ok(None);
126 }
127
128 let Some(time_expr_pos) = inner_agg
129 .group_expr
130 .iter()
131 .position(|expr| expr == &outer_time_expr)
132 else {
133 return Ok(None);
134 };
135
136 let mut presence_group_exprs = Vec::with_capacity(inner_agg.group_expr.len());
137 presence_group_exprs.push(outer_time_expr.clone());
138 presence_group_exprs.extend(
139 inner_agg
140 .group_expr
141 .iter()
142 .enumerate()
143 .filter(|(idx, _)| *idx != time_expr_pos)
144 .map(|(_, expr)| expr.clone()),
145 );
146
147 let mut required_input_columns =
148 Self::collect_required_input_columns(&presence_group_exprs, inner_value_expr);
149 required_input_columns.extend(Self::collect_required_instant_columns(
150 inner_agg.input.as_ref(),
151 ));
152 let presence_source = Self::rebuild_projection_chain_to_instant(
153 inner_agg.input.as_ref(),
154 &required_input_columns,
155 )?;
156
157 let outer_value_name = outer_agg
158 .schema
159 .field(outer_agg.group_expr.len())
160 .name()
161 .clone();
162 let mut presence_input = LogicalPlanBuilder::from(presence_source);
163 if !inner_is_count {
164 presence_input = presence_input.filter(inner_value_expr.clone().is_not_null())?;
165 }
166 let presence_input = presence_input
167 .project(presence_group_exprs.clone())?
168 .distinct()?
169 .build()?;
170
171 let rewritten = LogicalPlanBuilder::from(presence_input)
172 .aggregate(
173 outer_agg.group_expr.clone(),
174 vec![count_udaf().call(vec![lit(1_i64)]).alias(outer_value_name)],
175 )?
176 .sort(sort.expr.clone())?
177 .build()?;
178
179 Ok(Some(rewritten))
180 }
181
182 fn collect_required_input_columns(group_exprs: &[Expr], value_expr: &Expr) -> HashSet<String> {
183 let mut required = HashSet::new();
184
185 for expr in group_exprs {
186 if let Expr::Column(column) = expr {
187 required.insert(column.name.clone());
188 }
189 }
190 if let Expr::Column(column) = value_expr {
191 required.insert(column.name.clone());
194 }
195
196 required
197 }
198
199 fn collect_required_instant_columns(plan: &LogicalPlan) -> HashSet<String> {
200 let mut required = HashSet::new();
201 Self::collect_required_instant_columns_into(plan, &mut required);
202 required
203 }
204
205 fn collect_required_instant_columns_into(plan: &LogicalPlan, required: &mut HashSet<String>) {
206 match plan {
207 LogicalPlan::Projection(projection) => {
208 Self::collect_required_instant_columns_into(projection.input.as_ref(), required);
209 }
210 LogicalPlan::Extension(extension) => {
211 for expr in extension.node.expressions() {
212 if let Expr::Column(column) = expr {
213 required.insert(column.name);
214 }
215 }
216
217 if extension.node.as_any().is::<SeriesDivide>()
218 && extension.node.inputs()[0]
219 .schema()
220 .fields()
221 .iter()
222 .any(|field| field.name() == DATA_SCHEMA_TSID_COLUMN_NAME)
223 {
224 required.insert(DATA_SCHEMA_TSID_COLUMN_NAME.to_string());
225 }
226
227 if let Some(input) = extension.node.inputs().into_iter().next() {
228 Self::collect_required_instant_columns_into(input, required);
229 }
230 }
231 _ => {}
232 }
233 }
234
235 fn aggregate_if<F>(expr: &Expr, accept_name: F) -> Option<(&str, &Expr)>
236 where
237 F: FnOnce(&str) -> bool,
238 {
239 let Expr::AggregateFunction(func) = expr else {
240 return None;
241 };
242 let name = func.func.name();
243 if !accept_name(name)
244 || func.params.filter.is_some()
245 || func.params.distinct
246 || !func.params.order_by.is_empty()
247 || func.params.args.len() != 1
248 {
249 return None;
250 }
251
252 Some((name, &func.params.args[0]))
253 }
254
255 fn is_supported_inner_aggregate(name: &str) -> bool {
256 matches!(
257 name,
258 "count" | "sum" | "avg" | "min" | "max" | "stddev_pop" | "var_pop"
259 )
260 }
261
262 fn is_projection_chain_to_instant(plan: &LogicalPlan) -> bool {
263 let mut current = plan;
264 loop {
265 match current {
266 LogicalPlan::Projection(projection) => current = projection.input.as_ref(),
267 LogicalPlan::Extension(ext) => {
268 return ext.node.as_any().is::<InstantManipulate>();
269 }
270 _ => return false,
271 }
272 }
273 }
274
275 fn rebuild_projection_chain_to_instant(
276 plan: &LogicalPlan,
277 required_columns: &HashSet<String>,
278 ) -> Result<LogicalPlan> {
279 match plan {
280 LogicalPlan::Projection(projection) => {
281 let input = Self::rebuild_projection_chain_to_instant(
282 projection.input.as_ref(),
283 required_columns,
284 )?;
285 LogicalPlanBuilder::from(input)
286 .project(projection.expr.clone())?
287 .build()
288 }
289 LogicalPlan::Extension(extension) => {
290 if let Some(instant) = extension.node.as_any().downcast_ref::<InstantManipulate>() {
291 let input =
292 Self::prune_instant_input(extension.node.inputs()[0], required_columns)?;
293 return Ok(LogicalPlan::Extension(Extension {
294 node: Arc::new(instant.with_exprs_and_inputs(vec![], vec![input])?),
295 }));
296 }
297
298 Ok(plan.clone())
299 }
300 _ => Ok(plan.clone()),
301 }
302 }
303
304 fn prune_instant_input(
305 plan: &LogicalPlan,
306 required_columns: &HashSet<String>,
307 ) -> Result<LogicalPlan> {
308 match plan {
309 LogicalPlan::Extension(extension) => {
310 if let Some(normalize) = extension.node.as_any().downcast_ref::<SeriesNormalize>() {
311 let input =
312 Self::prune_instant_input(extension.node.inputs()[0], required_columns)?;
313 return Ok(LogicalPlan::Extension(Extension {
314 node: Arc::new(normalize.with_exprs_and_inputs(vec![], vec![input])?),
315 }));
316 }
317
318 if let Some(divide) = extension.node.as_any().downcast_ref::<SeriesDivide>() {
319 let divide_input = extension.node.inputs()[0].clone();
320
321 let projection_exprs = divide_input
322 .schema()
323 .fields()
324 .iter()
325 .filter(|field| required_columns.contains(field.name()))
326 .map(|field| {
327 Expr::Column(datafusion_common::Column::from_name(field.name().clone()))
328 })
329 .collect::<Vec<_>>();
330 let projected_input = LogicalPlanBuilder::from(divide_input)
331 .project(projection_exprs)?
332 .build()?;
333
334 return Ok(LogicalPlan::Extension(Extension {
335 node: Arc::new(
336 divide.with_exprs_and_inputs(vec![], vec![projected_input])?,
337 ),
338 }));
339 }
340
341 Ok(plan.clone())
342 }
343 _ => Ok(plan.clone()),
344 }
345 }
346}