partition/
simplify.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Simplification utilities for partition expressions.
//!
//! The main use case is simplifying `MERGE PARTITION` generated expressions:
//! `expr1 OR expr2 OR ...`, where each expr is a conjunction of simple
//! comparisons on partition columns.
20
21use std::collections::{BTreeMap, BTreeSet};
22use std::ops::Bound;
23
24use datatypes::value::{OrderedF64, Value};
25
26use crate::collider::{AtomicExpr, Collider, GluonOp, NucleonExpr};
27use crate::expr::{Operand, PartitionExpr, RestrictedOp, col};
28
29/// Attempts to simplify a merged partition expression (typically an `OR` of multiple partitions)
30/// into an equivalent but shorter expression.
31///
32/// Falls back to the original expression if the simplifier can't prove equivalence.
33///
34/// Note: NULL semantics is not part of this simplification logic.
35pub fn simplify_merged_partition_expr(expr: PartitionExpr) -> PartitionExpr {
36    try_simplify_merged_partition_expr(&expr).unwrap_or(expr)
37}
38
/// Per-column lookup from a collider-normalized value back to the original [`Value`].
///
/// Outer key: column name. Inner key: the normalized `OrderedF64` that intervals store.
/// Used to turn interval bounds back into literal comparisons.
type DenormValues = BTreeMap<String, BTreeMap<OrderedF64, Value>>;
40
41fn try_simplify_merged_partition_expr(expr: &PartitionExpr) -> Option<PartitionExpr> {
42    let collider = Collider::new(std::slice::from_ref(expr)).ok()?;
43    if collider.atomic_exprs.len() <= 1 {
44        return None;
45    }
46
47    let denorm_values = build_denorm_values(&collider)?;
48
49    let mut terms = Vec::with_capacity(collider.atomic_exprs.len());
50    for atomic in &collider.atomic_exprs {
51        terms.push(term_from_atomic(atomic, &denorm_values)?);
52    }
53
54    let terms = simplify_terms(terms)?;
55    build_expr_from_terms(&terms, &denorm_values)
56}
57
58fn build_denorm_values(collider: &Collider<'_>) -> Option<DenormValues> {
59    let mut values = DenormValues::new();
60    for (column, pairs) in &collider.normalized_values {
61        let mut map = BTreeMap::new();
62        for (value, normalized) in pairs {
63            // Keep simplification conservative for NULL semantics.
64            if matches!(value, Value::Null) {
65                return None;
66            }
67            map.insert(*normalized, value.clone());
68        }
69        values.insert(column.clone(), map);
70    }
71    Some(values)
72}
73
74fn term_from_atomic(atomic: &AtomicExpr, denorm_values: &DenormValues) -> Option<Term> {
75    let mut constraints = BTreeMap::new();
76
77    let mut i = 0;
78    while i < atomic.nucleons.len() {
79        let column = atomic.nucleons[i].column();
80        if !denorm_values.contains_key(column) {
81            return None;
82        }
83        let start = i;
84        while i < atomic.nucleons.len() && atomic.nucleons[i].column() == column {
85            i += 1;
86        }
87
88        let interval = interval_from_nucleons(&atomic.nucleons[start..i])?;
89        if !interval.is_unbounded() {
90            constraints.insert(column.to_string(), interval);
91        }
92    }
93
94    Some(Term { constraints })
95}
96
97fn interval_from_nucleons(nucleons: &[NucleonExpr]) -> Option<Interval> {
98    let mut interval = Interval::unbounded();
99    for nucleon in nucleons {
100        interval.apply_nucleon(nucleon.op(), nucleon.value())?;
101    }
102    Some(interval)
103}
104
/// One disjunct of the merged expression: the `AND` of one interval constraint per column.
///
/// Produced by `term_from_atomic` and combined pairwise by `try_merge_terms`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct Term {
    // Only stores constrained columns. Missing column means unbounded.
    constraints: BTreeMap<String, Interval>,
}
110
111impl Term {
112    fn is_subset_of(&self, other: &Term) -> bool {
113        // If `self` doesn't constrain a column that `other` does, `self` can't be a subset.
114        for (col, other_interval) in &other.constraints {
115            let Some(self_interval) = self.constraints.get(col) else {
116                return false;
117            };
118            if !self_interval.is_subset_of(other_interval) {
119                return false;
120            }
121        }
122        true
123    }
124}
125
/// A range over collider-normalized values; each end may be closed, open, or unbounded.
#[derive(Debug, Clone, PartialEq, Eq)]
struct Interval {
    /// Lower end; `Unbounded` stands for -∞.
    lower: Bound<OrderedF64>,
    /// Upper end; `Unbounded` stands for +∞.
    upper: Bound<OrderedF64>,
}
131
132impl Interval {
133    fn unbounded() -> Self {
134        Self {
135            lower: Bound::Unbounded,
136            upper: Bound::Unbounded,
137        }
138    }
139
140    fn is_unbounded(&self) -> bool {
141        matches!(self.lower, Bound::Unbounded) && matches!(self.upper, Bound::Unbounded)
142    }
143
144    fn apply_nucleon(&mut self, op: &GluonOp, value: OrderedF64) -> Option<()> {
145        match op {
146            GluonOp::Eq => {
147                // Ensure existing bounds contain `value`.
148                if !self.contains_value(&value) {
149                    return None;
150                }
151                self.lower = Bound::Included(value);
152                self.upper = Bound::Included(value);
153            }
154            GluonOp::Lt => self.update_upper(Bound::Excluded(value)),
155            GluonOp::LtEq => self.update_upper(Bound::Included(value)),
156            GluonOp::Gt => self.update_lower(Bound::Excluded(value)),
157            GluonOp::GtEq => self.update_lower(Bound::Included(value)),
158            GluonOp::NotEq => return None,
159        }
160
161        if self.is_empty() {
162            return None;
163        }
164        Some(())
165    }
166
167    fn contains_value(&self, value: &OrderedF64) -> bool {
168        // `value` is within [lower, upper] taking inclusiveness into account.
169        match &self.lower {
170            Bound::Unbounded => {}
171            Bound::Included(v) if value < v => return false,
172            Bound::Excluded(v) if value <= v => return false,
173            _ => {}
174        }
175        match &self.upper {
176            Bound::Unbounded => {}
177            Bound::Included(v) if value > v => return false,
178            Bound::Excluded(v) if value >= v => return false,
179            _ => {}
180        }
181        true
182    }
183
184    fn update_lower(&mut self, new_lower: Bound<OrderedF64>) {
185        if cmp_lower(&new_lower, &self.lower).is_gt() {
186            self.lower = new_lower;
187        }
188    }
189
190    fn update_upper(&mut self, new_upper: Bound<OrderedF64>) {
191        if cmp_upper(&new_upper, &self.upper).is_lt() {
192            self.upper = new_upper;
193        }
194    }
195
196    fn is_empty(&self) -> bool {
197        let (Bound::Included(lv) | Bound::Excluded(lv), Bound::Included(uv) | Bound::Excluded(uv)) =
198            (&self.lower, &self.upper)
199        else {
200            return false;
201        };
202
203        match lv.cmp(uv) {
204            std::cmp::Ordering::Less => false,
205            std::cmp::Ordering::Greater => true,
206            std::cmp::Ordering::Equal => {
207                matches!(self.lower, Bound::Excluded(_)) || matches!(self.upper, Bound::Excluded(_))
208            }
209        }
210    }
211
212    fn is_subset_of(&self, other: &Interval) -> bool {
213        // self.lower >= other.lower and self.upper <= other.upper
214        cmp_lower(&self.lower, &other.lower).is_ge() && cmp_upper(&self.upper, &other.upper).is_le()
215    }
216
217    fn union_if_mergeable(&self, other: &Interval) -> Option<UnionInterval> {
218        // Ensure `left` starts no later than `right`.
219        let (left, right) = match cmp_lower(&self.lower, &other.lower) {
220            std::cmp::Ordering::Greater => (other, self),
221            _ => (self, other),
222        };
223
224        if has_gap(&left.upper, &right.lower) {
225            return None;
226        }
227
228        let lower = left.lower;
229        let upper = union_upper(&left.upper, &right.upper);
230        let interval = Interval { lower, upper };
231        if interval.is_unbounded() {
232            Some(UnionInterval::Unbounded)
233        } else {
234            Some(UnionInterval::Interval(interval))
235        }
236    }
237}
238
/// Outcome of a successful `Interval::union_if_mergeable`.
enum UnionInterval {
    /// The union admits every value; callers drop the column's constraint altogether.
    Unbounded,
    /// The union is an interval bounded on at least one side.
    Interval(Interval),
}
243
/// Orders two lower bounds from least to most restrictive.
///
/// - `Unbounded` (-∞) comes first.
/// - Finite bounds are ordered by value.
/// - At the same value, `Included(v)` (which admits `v`) comes before `Excluded(v)`.
///
/// `Greater` therefore means "`a` cuts off more values than `b`".
fn cmp_lower(a: &Bound<OrderedF64>, b: &Bound<OrderedF64>) -> std::cmp::Ordering {
    use std::cmp::Ordering::{Equal, Greater, Less};
    use Bound::{Excluded, Included, Unbounded};

    match (a, b) {
        (Unbounded, Unbounded) => Equal,
        (Unbounded, _) => Less,
        (_, Unbounded) => Greater,
        (Included(x) | Excluded(x), Included(y) | Excluded(y)) => x
            .cmp(y)
            .then_with(|| matches!(a, Excluded(_)).cmp(&matches!(b, Excluded(_)))),
    }
}
260
/// Orders two upper bounds from most to least restrictive.
///
/// - Finite bounds are ordered by value.
/// - At the same value, `Excluded(v)` comes before `Included(v)` (which also admits `v`).
/// - `Unbounded` (+∞) comes last.
fn cmp_upper(a: &Bound<OrderedF64>, b: &Bound<OrderedF64>) -> std::cmp::Ordering {
    use std::cmp::Ordering::{Equal, Greater, Less};
    use Bound::{Excluded, Included, Unbounded};

    match (a, b) {
        (Unbounded, Unbounded) => Equal,
        (Unbounded, _) => Greater,
        (_, Unbounded) => Less,
        (Included(x) | Excluded(x), Included(y) | Excluded(y)) => x
            .cmp(y)
            .then_with(|| matches!(a, Included(_)).cmp(&matches!(b, Included(_)))),
    }
}
277
/// Returns `true` when some value lies strictly between an interval ending at `upper` and
/// another starting at `lower`, i.e. the two intervals can't be joined into one.
fn has_gap(upper: &Bound<OrderedF64>, lower: &Bound<OrderedF64>) -> bool {
    use Bound::{Excluded, Included, Unbounded};

    match (upper, lower) {
        // An infinite end always reaches the other interval.
        (Unbounded, _) | (_, Unbounded) => false,
        // Both ends open: even meeting at the same value leaves that value uncovered.
        (Excluded(u), Excluded(l)) => u <= l,
        // At least one end closed: touching is enough to connect the intervals.
        (Included(u) | Excluded(u), Included(l) | Excluded(l)) => u < l,
    }
}
289
290fn union_upper(a: &Bound<OrderedF64>, b: &Bound<OrderedF64>) -> Bound<OrderedF64> {
291    match cmp_upper(a, b) {
292        std::cmp::Ordering::Less => *b,
293        std::cmp::Ordering::Equal | std::cmp::Ordering::Greater => *a,
294    }
295}
296
297fn simplify_terms(mut terms: Vec<Term>) -> Option<Vec<Term>> {
298    // Dedup exact duplicates.
299    let mut unique = Vec::new();
300    for t in terms.drain(..) {
301        if !unique.contains(&t) {
302            unique.push(t);
303        }
304    }
305    terms = unique;
306
307    loop {
308        // Remove subsumed terms (more restrictive) in `OR`.
309        let mut to_remove = vec![false; terms.len()];
310        for i in 0..terms.len() {
311            for j in 0..terms.len() {
312                if i == j || to_remove[i] {
313                    continue;
314                }
315                if terms[i].is_subset_of(&terms[j]) {
316                    to_remove[i] = true;
317                }
318            }
319        }
320        let before = terms.len();
321        terms = terms
322            .into_iter()
323            .enumerate()
324            .filter_map(|(idx, t)| (!to_remove[idx]).then_some(t))
325            .collect();
326
327        // Try to merge a pair; restart on success.
328        let mut merged = None;
329        'outer: for i in 0..terms.len() {
330            for j in (i + 1)..terms.len() {
331                if let Some(t) = try_merge_terms(&terms[i], &terms[j]) {
332                    merged = Some((i, j, t));
333                    break 'outer;
334                }
335            }
336        }
337
338        if let Some((i, j, new_term)) = merged {
339            let mut next = Vec::with_capacity(terms.len() - 1);
340            for (idx, t) in terms.into_iter().enumerate() {
341                if idx != i && idx != j {
342                    next.push(t);
343                }
344            }
345            next.push(new_term);
346            terms = next;
347            continue;
348        }
349
350        // Stable point: no more merges and no more subsumption.
351        if terms.len() == before {
352            break;
353        }
354    }
355
356    Some(terms)
357}
358
359fn try_merge_terms(a: &Term, b: &Term) -> Option<Term> {
360    // Find the only differing column (treat missing as unbounded).
361    let mut diff_col: Option<&str> = None;
362
363    let mut cols = BTreeSet::new();
364    cols.extend(a.constraints.keys().map(|s| s.as_str()));
365    cols.extend(b.constraints.keys().map(|s| s.as_str()));
366
367    for col in cols {
368        let a_interval = a.constraints.get(col);
369        let b_interval = b.constraints.get(col);
370        if a_interval == b_interval {
371            continue;
372        }
373        if diff_col.is_some() {
374            return None;
375        }
376        diff_col = Some(col);
377    }
378
379    let diff_col = diff_col?;
380    let a_interval = a.constraints.get(diff_col)?;
381    let b_interval = b.constraints.get(diff_col)?;
382
383    let union = a_interval.union_if_mergeable(b_interval)?;
384    let mut constraints = a.constraints.clone();
385    match union {
386        UnionInterval::Unbounded => {
387            constraints.remove(diff_col);
388        }
389        UnionInterval::Interval(interval) => {
390            constraints.insert(diff_col.to_string(), interval);
391        }
392    }
393
394    Some(Term { constraints })
395}
396
397fn build_expr_from_terms(terms: &[Term], denorm_values: &DenormValues) -> Option<PartitionExpr> {
398    let mut term_exprs = Vec::with_capacity(terms.len());
399    for term in terms {
400        let expr = term_to_expr(term, denorm_values)?;
401        term_exprs.push(expr);
402    }
403
404    // Can't represent a tautology in `PartitionExpr`.
405    if term_exprs.is_empty() {
406        return None;
407    }
408
409    if term_exprs.len() == 1 {
410        return Some(term_exprs.pop().unwrap());
411    }
412
413    term_exprs.sort_by_key(|a| a.to_string());
414
415    let mut iter = term_exprs.into_iter();
416    let mut acc = iter.next()?;
417    for next in iter {
418        acc = PartitionExpr::new(Operand::Expr(acc), RestrictedOp::Or, Operand::Expr(next));
419    }
420    Some(acc)
421}
422
423fn term_to_expr(term: &Term, denorm_values: &DenormValues) -> Option<PartitionExpr> {
424    // Empty term would represent a tautology which can't be expressed here.
425    if term.constraints.is_empty() {
426        return None;
427    }
428
429    let mut exprs = Vec::new();
430    for (column, interval) in &term.constraints {
431        exprs.extend(interval_to_exprs(column, interval, denorm_values)?);
432    }
433
434    let mut iter = exprs.into_iter();
435    let mut acc = iter.next()?;
436    for next in iter {
437        acc = acc.and(next);
438    }
439    Some(acc)
440}
441
442fn interval_to_exprs(
443    column: &str,
444    interval: &Interval,
445    denorm_values: &DenormValues,
446) -> Option<Vec<PartitionExpr>> {
447    use Bound::*;
448
449    if interval.is_unbounded() {
450        return Some(vec![]);
451    }
452
453    let col_values = denorm_values.get(column)?;
454
455    let lower = &interval.lower;
456    let upper = &interval.upper;
457
458    match (lower, upper) {
459        (Included(lv), Included(uv)) if lv == uv => {
460            return Some(vec![col(column).eq(col_values.get(lv)?.clone())]);
461        }
462        (Excluded(lv), Excluded(uv)) if lv == uv => return None,
463        (Included(lv), Excluded(uv)) if lv == uv => return None,
464        (Excluded(lv), Included(uv)) if lv == uv => return None,
465        _ => {}
466    }
467
468    let mut exprs = Vec::new();
469    match lower {
470        Unbounded => {}
471        Included(v) => exprs.push(col(column).gt_eq(col_values.get(v)?.clone())),
472        Excluded(v) => exprs.push(col(column).gt(col_values.get(v)?.clone())),
473    }
474    match upper {
475        Unbounded => {}
476        Included(v) => exprs.push(col(column).lt_eq(col_values.get(v)?.clone())),
477        Excluded(v) => exprs.push(col(column).lt(col_values.get(v)?.clone())),
478    }
479
480    Some(exprs)
481}
482
// Unit tests: end-to-end simplification cases, plus direct checks of the bound-ordering
// helpers (`cmp_lower`, `cmp_upper`, `has_gap`, `union_upper`).
#[cfg(test)]
mod tests {
    use std::ops::Bound;

    use datatypes::value::{OrderedFloat, Value};

    use super::*;
    use crate::expr::Operand;

    /// Builds `lhs OR rhs`.
    fn or(lhs: PartitionExpr, rhs: PartitionExpr) -> PartitionExpr {
        PartitionExpr::new(Operand::Expr(lhs), RestrictedOp::Or, Operand::Expr(rhs))
    }

    // Two terms sharing `device_id < 100` whose `area` ranges are complementary merge into
    // the shared factor alone.
    #[test]
    fn simplify_common_factor_complement() {
        // device_id < 100 AND area < 'South'
        let left = col("device_id")
            .lt(Value::Int32(100))
            .and(col("area").lt(Value::String("South".into())));
        // device_id < 100 AND area >= 'South'
        let right = col("device_id")
            .lt(Value::Int32(100))
            .and(col("area").gt_eq(Value::String("South".into())));
        let merged = or(left, right);
        let simplified = simplify_merged_partition_expr(merged);
        assert_eq!(simplified.to_string(), "device_id < 100");
    }

    // Ranges that touch at a closed end join into one range.
    #[test]
    fn simplify_adjacent_ranges() {
        // host < 'h0' OR (host >= 'h0' AND host < 'h1') -> host < 'h1'
        let left = col("host").lt(Value::String("h0".into()));
        let right = col("host")
            .gt_eq(Value::String("h0".into()))
            .and(col("host").lt(Value::String("h1".into())));
        let merged = or(left, right);
        let simplified = simplify_merged_partition_expr(merged);
        assert_eq!(simplified.to_string(), "host < h1");
    }

    // The union's upper end becomes unbounded, so only the lower bound remains.
    #[test]
    fn simplify_drop_upper_bound() {
        // a > 10 OR (a <= 10 AND a > 0) -> a > 0
        let left = col("a").gt(Value::Int32(10));
        let right = col("a")
            .lt_eq(Value::Int32(10))
            .and(col("a").gt(Value::Int32(0)));
        let merged = or(left, right);
        let simplified = simplify_merged_partition_expr(merged);
        assert_eq!(simplified.to_string(), "a > 0");
    }

    // Two open ends meeting at the same value leave a gap, so the input is returned as is.
    #[test]
    fn do_not_merge_hole_without_not_eq() {
        // a < 10 OR a > 10 can't be simplified without `a <> 10`.
        let left = col("a").lt(Value::Int32(10));
        let right = col("a").gt(Value::Int32(10));
        let merged = or(left, right);
        let simplified = simplify_merged_partition_expr(merged.clone());
        assert_eq!(simplified, merged);
    }

    // Direct checks of the bound helpers that the interval logic relies on.
    #[test]
    fn interval_bound_helpers() {
        use std::cmp::Ordering::*;

        use Bound::*;

        let v0 = OrderedFloat(0.0f64);
        let v1 = OrderedFloat(1.0f64);

        // cmp_lower: Unbounded < Included(v) < Excluded(v) and increasing by value.
        let lower_order = [
            Unbounded,
            Included(v0),
            Excluded(v0),
            Included(v1),
            Excluded(v1),
        ];
        for pair in lower_order.windows(2) {
            assert_eq!(cmp_lower(&pair[0], &pair[1]), Less);
            assert_eq!(cmp_lower(&pair[1], &pair[0]), Greater);
        }
        for bound in &lower_order {
            assert_eq!(cmp_lower(bound, bound), Equal);
        }

        // cmp_upper: Excluded(v) < Included(v) and increasing by value; Unbounded is +∞ (largest).
        let upper_order = [
            Excluded(v0),
            Included(v0),
            Excluded(v1),
            Included(v1),
            Unbounded,
        ];
        for pair in upper_order.windows(2) {
            assert_eq!(cmp_upper(&pair[0], &pair[1]), Less);
            assert_eq!(cmp_upper(&pair[1], &pair[0]), Greater);
        }
        for bound in &upper_order {
            assert_eq!(cmp_upper(bound, bound), Equal);
        }

        // has_gap: Unbounded never contributes a gap.
        assert!(!has_gap(&Unbounded, &Included(v0)));
        assert!(!has_gap(&Excluded(v0), &Unbounded));
        // Separated bounds always have a gap.
        assert!(has_gap(&Included(v0), &Included(v1)));
        assert!(has_gap(&Excluded(v0), &Included(v1)));
        assert!(!has_gap(&Included(v1), &Included(v0)));
        assert!(!has_gap(&Excluded(v1), &Included(v0)));
        // Touching at boundary has a gap only if both ends exclude.
        assert!(!has_gap(&Included(v0), &Included(v0)));
        assert!(!has_gap(&Included(v0), &Excluded(v0)));
        assert!(!has_gap(&Excluded(v0), &Included(v0)));
        assert!(has_gap(&Excluded(v0), &Excluded(v0)));

        // union_upper: choose the less restrictive upper bound (max under cmp_upper).
        assert_eq!(union_upper(&Unbounded, &Included(v0)), Unbounded);
        assert_eq!(union_upper(&Included(v0), &Unbounded), Unbounded);
        assert_eq!(union_upper(&Included(v0), &Included(v1)), Included(v1));
        assert_eq!(union_upper(&Excluded(v1), &Included(v0)), Excluded(v1));
        assert_eq!(union_upper(&Excluded(v0), &Included(v0)), Included(v0));
        assert_eq!(union_upper(&Included(v0), &Excluded(v0)), Included(v0));
        assert_eq!(union_upper(&Excluded(v0), &Excluded(v0)), Excluded(v0));
        assert_eq!(union_upper(&Included(v0), &Included(v0)), Included(v0));
    }
}