query/dist_plan/
commutativity.rs1use std::collections::HashSet;
16use std::sync::Arc;
17
18use datafusion_expr::{Expr, LogicalPlan, UserDefinedLogicalNode};
19use promql::extension_plan::{
20 EmptyMetric, InstantManipulate, RangeManipulate, SeriesDivide, SeriesNormalize,
21};
22
23use crate::dist_plan::merge_sort::{merge_sort_transformer, MergeSortLogicalPlan};
24use crate::dist_plan::MergeScanLogicalPlan;
25
26#[allow(dead_code)]
27pub enum Commutativity {
28 Commutative,
29 PartialCommutative,
30 ConditionalCommutative(Option<Transformer>),
31 TransformedCommutative(Option<Transformer>),
32 NonCommutative,
33 Unimplemented,
34 Unsupported,
36}
37
/// Stateless namespace for the classification routines; every method is an associated
/// function, so no instance is needed to call them.
pub struct Categorizer {}
39
40impl Categorizer {
41 pub fn check_plan(plan: &LogicalPlan, partition_cols: Option<Vec<String>>) -> Commutativity {
42 let partition_cols = partition_cols.unwrap_or_default();
43
44 match plan {
45 LogicalPlan::Projection(proj) => {
46 for expr in &proj.expr {
47 let commutativity = Self::check_expr(expr);
48 if !matches!(commutativity, Commutativity::Commutative) {
49 return commutativity;
50 }
51 }
52 Commutativity::Commutative
53 }
54 LogicalPlan::Filter(filter) => Self::check_expr(&filter.predicate),
56 LogicalPlan::Window(_) => Commutativity::Unimplemented,
57 LogicalPlan::Aggregate(aggr) => {
58 if !Self::check_partition(&aggr.group_expr, &partition_cols) {
59 return Commutativity::NonCommutative;
60 }
61 for expr in &aggr.aggr_expr {
62 let commutativity = Self::check_expr(expr);
63 if !matches!(commutativity, Commutativity::Commutative) {
64 return commutativity;
65 }
66 }
67 Commutativity::Commutative
68 }
69 LogicalPlan::Sort(_) => {
70 if partition_cols.is_empty() {
71 return Commutativity::Commutative;
72 }
73
74 Commutativity::ConditionalCommutative(Some(Arc::new(merge_sort_transformer)))
78 }
79 LogicalPlan::Join(_) => Commutativity::NonCommutative,
80 LogicalPlan::Repartition(_) => {
81 Commutativity::Unimplemented
83 }
84 LogicalPlan::Union(_) => Commutativity::Unimplemented,
85 LogicalPlan::TableScan(_) => Commutativity::Commutative,
86 LogicalPlan::EmptyRelation(_) => Commutativity::NonCommutative,
87 LogicalPlan::Subquery(_) => Commutativity::Unimplemented,
88 LogicalPlan::SubqueryAlias(_) => Commutativity::Unimplemented,
89 LogicalPlan::Limit(limit) => {
90 if partition_cols.is_empty() && limit.fetch.is_some() {
93 Commutativity::Commutative
94 } else if limit.skip.is_none() && limit.fetch.is_some() {
95 Commutativity::PartialCommutative
96 } else {
97 Commutativity::Unimplemented
98 }
99 }
100 LogicalPlan::Extension(extension) => {
101 Self::check_extension_plan(extension.node.as_ref() as _, &partition_cols)
102 }
103 LogicalPlan::Distinct(_) => {
104 if partition_cols.is_empty() {
105 Commutativity::Commutative
106 } else {
107 Commutativity::Unimplemented
108 }
109 }
110 LogicalPlan::Unnest(_) => Commutativity::Commutative,
111 LogicalPlan::Statement(_) => Commutativity::Unsupported,
112 LogicalPlan::Values(_) => Commutativity::Unsupported,
113 LogicalPlan::Explain(_) => Commutativity::Unsupported,
114 LogicalPlan::Analyze(_) => Commutativity::Unsupported,
115 LogicalPlan::DescribeTable(_) => Commutativity::Unsupported,
116 LogicalPlan::Dml(_) => Commutativity::Unsupported,
117 LogicalPlan::Ddl(_) => Commutativity::Unsupported,
118 LogicalPlan::Copy(_) => Commutativity::Unsupported,
119 LogicalPlan::RecursiveQuery(_) => Commutativity::Unsupported,
120 }
121 }
122
123 pub fn check_extension_plan(
124 plan: &dyn UserDefinedLogicalNode,
125 partition_cols: &[String],
126 ) -> Commutativity {
127 match plan.name() {
128 name if name == SeriesDivide::name() => {
129 let series_divide = plan.as_any().downcast_ref::<SeriesDivide>().unwrap();
130 let tags = series_divide.tags().iter().collect::<HashSet<_>>();
131 for partition_col in partition_cols {
132 if !tags.contains(partition_col) {
133 return Commutativity::NonCommutative;
134 }
135 }
136 Commutativity::Commutative
137 }
138 name if name == SeriesNormalize::name()
139 || name == InstantManipulate::name()
140 || name == RangeManipulate::name() =>
141 {
142 Commutativity::Commutative
145 }
146 name if name == EmptyMetric::name()
147 || name == MergeScanLogicalPlan::name()
148 || name == MergeSortLogicalPlan::name() =>
149 {
150 Commutativity::Unimplemented
151 }
152 _ => Commutativity::Unsupported,
153 }
154 }
155
156 pub fn check_expr(expr: &Expr) -> Commutativity {
157 match expr {
158 Expr::Column(_)
159 | Expr::ScalarVariable(_, _)
160 | Expr::Literal(_)
161 | Expr::BinaryExpr(_)
162 | Expr::Not(_)
163 | Expr::IsNotNull(_)
164 | Expr::IsNull(_)
165 | Expr::IsTrue(_)
166 | Expr::IsFalse(_)
167 | Expr::IsNotTrue(_)
168 | Expr::IsNotFalse(_)
169 | Expr::Negative(_)
170 | Expr::Between(_)
171 | Expr::Exists(_)
172 | Expr::InList(_) => Commutativity::Commutative,
173 Expr::ScalarFunction(_udf) => Commutativity::Commutative,
174 Expr::AggregateFunction(_udaf) => Commutativity::Commutative,
175
176 Expr::Like(_)
177 | Expr::SimilarTo(_)
178 | Expr::IsUnknown(_)
179 | Expr::IsNotUnknown(_)
180 | Expr::Case(_)
181 | Expr::Cast(_)
182 | Expr::TryCast(_)
183 | Expr::WindowFunction(_)
184 | Expr::InSubquery(_)
185 | Expr::ScalarSubquery(_)
186 | Expr::Wildcard { .. } => Commutativity::Unimplemented,
187
188 Expr::Alias(alias) => Self::check_expr(&alias.expr),
189
190 Expr::Unnest(_)
191 | Expr::GroupingSet(_)
192 | Expr::Placeholder(_)
193 | Expr::OuterReferenceColumn(_, _) => Commutativity::Unimplemented,
194 }
195 }
196
197 fn check_partition(exprs: &[Expr], partition_cols: &[String]) -> bool {
200 let mut ref_cols = HashSet::new();
201 for expr in exprs {
202 expr.add_column_refs(&mut ref_cols);
203 }
204 let ref_cols = ref_cols
205 .into_iter()
206 .map(|c| c.name.clone())
207 .collect::<HashSet<_>>();
208 for col in partition_cols {
209 if !ref_cols.contains(col) {
210 return false;
211 }
212 }
213
214 true
215 }
216}
217
218pub type Transformer = Arc<dyn Fn(&LogicalPlan) -> Option<LogicalPlan>>;
219
220pub fn partial_commutative_transformer(plan: &LogicalPlan) -> Option<LogicalPlan> {
221 Some(plan.clone())
222}
223
#[cfg(test)]
mod test {
    use datafusion_expr::{LogicalPlanBuilder, Sort};

    use super::*;

    /// A `Sort` checked against an empty partition column list must be `Commutative`.
    #[test]
    fn sort_on_empty_partition() {
        let input = LogicalPlanBuilder::empty(false).build().unwrap();
        let sort = Sort {
            expr: Vec::new(),
            input: Arc::new(input),
            fetch: None,
        };
        let plan = LogicalPlan::Sort(sort);

        let verdict = Categorizer::check_plan(&plan, Some(Vec::new()));

        assert!(matches!(verdict, Commutativity::Commutative));
    }
}