1use std::collections::{BTreeMap, BTreeSet};
22use std::ops::Bound;
23
24use datatypes::value::{OrderedF64, Value};
25
26use crate::collider::{AtomicExpr, Collider, GluonOp, NucleonExpr};
27use crate::expr::{Operand, PartitionExpr, RestrictedOp, col};
28
29pub fn simplify_merged_partition_expr(expr: PartitionExpr) -> PartitionExpr {
36 try_simplify_merged_partition_expr(&expr).unwrap_or(expr)
37}
38
39type DenormValues = BTreeMap<String, BTreeMap<OrderedF64, Value>>;
40
41fn try_simplify_merged_partition_expr(expr: &PartitionExpr) -> Option<PartitionExpr> {
42 let collider = Collider::new(std::slice::from_ref(expr)).ok()?;
43 if collider.atomic_exprs.len() <= 1 {
44 return None;
45 }
46
47 let denorm_values = build_denorm_values(&collider)?;
48
49 let mut terms = Vec::with_capacity(collider.atomic_exprs.len());
50 for atomic in &collider.atomic_exprs {
51 terms.push(term_from_atomic(atomic, &denorm_values)?);
52 }
53
54 let terms = simplify_terms(terms)?;
55 build_expr_from_terms(&terms, &denorm_values)
56}
57
58fn build_denorm_values(collider: &Collider<'_>) -> Option<DenormValues> {
59 let mut values = DenormValues::new();
60 for (column, pairs) in &collider.normalized_values {
61 let mut map = BTreeMap::new();
62 for (value, normalized) in pairs {
63 if matches!(value, Value::Null) {
65 return None;
66 }
67 map.insert(*normalized, value.clone());
68 }
69 values.insert(column.clone(), map);
70 }
71 Some(values)
72}
73
74fn term_from_atomic(atomic: &AtomicExpr, denorm_values: &DenormValues) -> Option<Term> {
75 let mut constraints = BTreeMap::new();
76
77 let mut i = 0;
78 while i < atomic.nucleons.len() {
79 let column = atomic.nucleons[i].column();
80 if !denorm_values.contains_key(column) {
81 return None;
82 }
83 let start = i;
84 while i < atomic.nucleons.len() && atomic.nucleons[i].column() == column {
85 i += 1;
86 }
87
88 let interval = interval_from_nucleons(&atomic.nucleons[start..i])?;
89 if !interval.is_unbounded() {
90 constraints.insert(column.to_string(), interval);
91 }
92 }
93
94 Some(Term { constraints })
95}
96
97fn interval_from_nucleons(nucleons: &[NucleonExpr]) -> Option<Interval> {
98 let mut interval = Interval::unbounded();
99 for nucleon in nucleons {
100 interval.apply_nucleon(nucleon.op(), nucleon.value())?;
101 }
102 Some(interval)
103}
104
105#[derive(Debug, Clone, PartialEq, Eq)]
106struct Term {
107 constraints: BTreeMap<String, Interval>,
109}
110
111impl Term {
112 fn is_subset_of(&self, other: &Term) -> bool {
113 for (col, other_interval) in &other.constraints {
115 let Some(self_interval) = self.constraints.get(col) else {
116 return false;
117 };
118 if !self_interval.is_subset_of(other_interval) {
119 return false;
120 }
121 }
122 true
123 }
124}
125
126#[derive(Debug, Clone, PartialEq, Eq)]
127struct Interval {
128 lower: Bound<OrderedF64>,
129 upper: Bound<OrderedF64>,
130}
131
132impl Interval {
133 fn unbounded() -> Self {
134 Self {
135 lower: Bound::Unbounded,
136 upper: Bound::Unbounded,
137 }
138 }
139
140 fn is_unbounded(&self) -> bool {
141 matches!(self.lower, Bound::Unbounded) && matches!(self.upper, Bound::Unbounded)
142 }
143
144 fn apply_nucleon(&mut self, op: &GluonOp, value: OrderedF64) -> Option<()> {
145 match op {
146 GluonOp::Eq => {
147 if !self.contains_value(&value) {
149 return None;
150 }
151 self.lower = Bound::Included(value);
152 self.upper = Bound::Included(value);
153 }
154 GluonOp::Lt => self.update_upper(Bound::Excluded(value)),
155 GluonOp::LtEq => self.update_upper(Bound::Included(value)),
156 GluonOp::Gt => self.update_lower(Bound::Excluded(value)),
157 GluonOp::GtEq => self.update_lower(Bound::Included(value)),
158 GluonOp::NotEq => return None,
159 }
160
161 if self.is_empty() {
162 return None;
163 }
164 Some(())
165 }
166
167 fn contains_value(&self, value: &OrderedF64) -> bool {
168 match &self.lower {
170 Bound::Unbounded => {}
171 Bound::Included(v) if value < v => return false,
172 Bound::Excluded(v) if value <= v => return false,
173 _ => {}
174 }
175 match &self.upper {
176 Bound::Unbounded => {}
177 Bound::Included(v) if value > v => return false,
178 Bound::Excluded(v) if value >= v => return false,
179 _ => {}
180 }
181 true
182 }
183
184 fn update_lower(&mut self, new_lower: Bound<OrderedF64>) {
185 if cmp_lower(&new_lower, &self.lower).is_gt() {
186 self.lower = new_lower;
187 }
188 }
189
190 fn update_upper(&mut self, new_upper: Bound<OrderedF64>) {
191 if cmp_upper(&new_upper, &self.upper).is_lt() {
192 self.upper = new_upper;
193 }
194 }
195
196 fn is_empty(&self) -> bool {
197 let (Bound::Included(lv) | Bound::Excluded(lv), Bound::Included(uv) | Bound::Excluded(uv)) =
198 (&self.lower, &self.upper)
199 else {
200 return false;
201 };
202
203 match lv.cmp(uv) {
204 std::cmp::Ordering::Less => false,
205 std::cmp::Ordering::Greater => true,
206 std::cmp::Ordering::Equal => {
207 matches!(self.lower, Bound::Excluded(_)) || matches!(self.upper, Bound::Excluded(_))
208 }
209 }
210 }
211
212 fn is_subset_of(&self, other: &Interval) -> bool {
213 cmp_lower(&self.lower, &other.lower).is_ge() && cmp_upper(&self.upper, &other.upper).is_le()
215 }
216
217 fn union_if_mergeable(&self, other: &Interval) -> Option<UnionInterval> {
218 let (left, right) = match cmp_lower(&self.lower, &other.lower) {
220 std::cmp::Ordering::Greater => (other, self),
221 _ => (self, other),
222 };
223
224 if has_gap(&left.upper, &right.lower) {
225 return None;
226 }
227
228 let lower = left.lower;
229 let upper = union_upper(&left.upper, &right.upper);
230 let interval = Interval { lower, upper };
231 if interval.is_unbounded() {
232 Some(UnionInterval::Unbounded)
233 } else {
234 Some(UnionInterval::Interval(interval))
235 }
236 }
237}
238
239enum UnionInterval {
240 Unbounded,
241 Interval(Interval),
242}
243
244fn cmp_lower(a: &Bound<OrderedF64>, b: &Bound<OrderedF64>) -> std::cmp::Ordering {
245 fn lower_key(bound: &Bound<OrderedF64>) -> (u8, Option<OrderedF64>, u8) {
249 use Bound::*;
250
251 match bound {
252 Unbounded => (0, None, 0),
253 Included(v) => (1, Some(*v), 0),
254 Excluded(v) => (1, Some(*v), 1),
255 }
256 }
257
258 lower_key(a).cmp(&lower_key(b))
259}
260
261fn cmp_upper(a: &Bound<OrderedF64>, b: &Bound<OrderedF64>) -> std::cmp::Ordering {
262 fn upper_key(bound: &Bound<OrderedF64>) -> (u8, Option<OrderedF64>, u8) {
266 use Bound::*;
267
268 match bound {
269 Unbounded => (1, None, 0),
270 Included(v) => (0, Some(*v), 1),
271 Excluded(v) => (0, Some(*v), 0),
272 }
273 }
274
275 upper_key(a).cmp(&upper_key(b))
276}
277
278fn has_gap(upper: &Bound<OrderedF64>, lower: &Bound<OrderedF64>) -> bool {
279 use Bound::*;
280
281 match (upper, lower) {
282 (Unbounded, _) | (_, Unbounded) => false,
283 (Included(u), Included(l)) => u < l,
284 (Included(u) | Excluded(u), Included(l) | Excluded(l)) => {
285 u < l || (u == l && matches!((upper, lower), (Excluded(_), Excluded(_))))
286 }
287 }
288}
289
290fn union_upper(a: &Bound<OrderedF64>, b: &Bound<OrderedF64>) -> Bound<OrderedF64> {
291 match cmp_upper(a, b) {
292 std::cmp::Ordering::Less => *b,
293 std::cmp::Ordering::Equal | std::cmp::Ordering::Greater => *a,
294 }
295}
296
297fn simplify_terms(mut terms: Vec<Term>) -> Option<Vec<Term>> {
298 let mut unique = Vec::new();
300 for t in terms.drain(..) {
301 if !unique.contains(&t) {
302 unique.push(t);
303 }
304 }
305 terms = unique;
306
307 loop {
308 let mut to_remove = vec![false; terms.len()];
310 for i in 0..terms.len() {
311 for j in 0..terms.len() {
312 if i == j || to_remove[i] {
313 continue;
314 }
315 if terms[i].is_subset_of(&terms[j]) {
316 to_remove[i] = true;
317 }
318 }
319 }
320 let before = terms.len();
321 terms = terms
322 .into_iter()
323 .enumerate()
324 .filter_map(|(idx, t)| (!to_remove[idx]).then_some(t))
325 .collect();
326
327 let mut merged = None;
329 'outer: for i in 0..terms.len() {
330 for j in (i + 1)..terms.len() {
331 if let Some(t) = try_merge_terms(&terms[i], &terms[j]) {
332 merged = Some((i, j, t));
333 break 'outer;
334 }
335 }
336 }
337
338 if let Some((i, j, new_term)) = merged {
339 let mut next = Vec::with_capacity(terms.len() - 1);
340 for (idx, t) in terms.into_iter().enumerate() {
341 if idx != i && idx != j {
342 next.push(t);
343 }
344 }
345 next.push(new_term);
346 terms = next;
347 continue;
348 }
349
350 if terms.len() == before {
352 break;
353 }
354 }
355
356 Some(terms)
357}
358
359fn try_merge_terms(a: &Term, b: &Term) -> Option<Term> {
360 let mut diff_col: Option<&str> = None;
362
363 let mut cols = BTreeSet::new();
364 cols.extend(a.constraints.keys().map(|s| s.as_str()));
365 cols.extend(b.constraints.keys().map(|s| s.as_str()));
366
367 for col in cols {
368 let a_interval = a.constraints.get(col);
369 let b_interval = b.constraints.get(col);
370 if a_interval == b_interval {
371 continue;
372 }
373 if diff_col.is_some() {
374 return None;
375 }
376 diff_col = Some(col);
377 }
378
379 let diff_col = diff_col?;
380 let a_interval = a.constraints.get(diff_col)?;
381 let b_interval = b.constraints.get(diff_col)?;
382
383 let union = a_interval.union_if_mergeable(b_interval)?;
384 let mut constraints = a.constraints.clone();
385 match union {
386 UnionInterval::Unbounded => {
387 constraints.remove(diff_col);
388 }
389 UnionInterval::Interval(interval) => {
390 constraints.insert(diff_col.to_string(), interval);
391 }
392 }
393
394 Some(Term { constraints })
395}
396
397fn build_expr_from_terms(terms: &[Term], denorm_values: &DenormValues) -> Option<PartitionExpr> {
398 let mut term_exprs = Vec::with_capacity(terms.len());
399 for term in terms {
400 let expr = term_to_expr(term, denorm_values)?;
401 term_exprs.push(expr);
402 }
403
404 if term_exprs.is_empty() {
406 return None;
407 }
408
409 if term_exprs.len() == 1 {
410 return Some(term_exprs.pop().unwrap());
411 }
412
413 term_exprs.sort_by_key(|a| a.to_string());
414
415 let mut iter = term_exprs.into_iter();
416 let mut acc = iter.next()?;
417 for next in iter {
418 acc = PartitionExpr::new(Operand::Expr(acc), RestrictedOp::Or, Operand::Expr(next));
419 }
420 Some(acc)
421}
422
423fn term_to_expr(term: &Term, denorm_values: &DenormValues) -> Option<PartitionExpr> {
424 if term.constraints.is_empty() {
426 return None;
427 }
428
429 let mut exprs = Vec::new();
430 for (column, interval) in &term.constraints {
431 exprs.extend(interval_to_exprs(column, interval, denorm_values)?);
432 }
433
434 let mut iter = exprs.into_iter();
435 let mut acc = iter.next()?;
436 for next in iter {
437 acc = acc.and(next);
438 }
439 Some(acc)
440}
441
442fn interval_to_exprs(
443 column: &str,
444 interval: &Interval,
445 denorm_values: &DenormValues,
446) -> Option<Vec<PartitionExpr>> {
447 use Bound::*;
448
449 if interval.is_unbounded() {
450 return Some(vec![]);
451 }
452
453 let col_values = denorm_values.get(column)?;
454
455 let lower = &interval.lower;
456 let upper = &interval.upper;
457
458 match (lower, upper) {
459 (Included(lv), Included(uv)) if lv == uv => {
460 return Some(vec![col(column).eq(col_values.get(lv)?.clone())]);
461 }
462 (Excluded(lv), Excluded(uv)) if lv == uv => return None,
463 (Included(lv), Excluded(uv)) if lv == uv => return None,
464 (Excluded(lv), Included(uv)) if lv == uv => return None,
465 _ => {}
466 }
467
468 let mut exprs = Vec::new();
469 match lower {
470 Unbounded => {}
471 Included(v) => exprs.push(col(column).gt_eq(col_values.get(v)?.clone())),
472 Excluded(v) => exprs.push(col(column).gt(col_values.get(v)?.clone())),
473 }
474 match upper {
475 Unbounded => {}
476 Included(v) => exprs.push(col(column).lt_eq(col_values.get(v)?.clone())),
477 Excluded(v) => exprs.push(col(column).lt(col_values.get(v)?.clone())),
478 }
479
480 Some(exprs)
481}
482
#[cfg(test)]
mod tests {
    use std::ops::Bound;

    use datatypes::value::{OrderedFloat, Value};

    use super::*;
    use crate::expr::Operand;

    /// Builds `lhs OR rhs`.
    fn or(lhs: PartitionExpr, rhs: PartitionExpr) -> PartitionExpr {
        PartitionExpr::new(Operand::Expr(lhs), RestrictedOp::Or, Operand::Expr(rhs))
    }

    /// ORs the two expressions, simplifies the result and renders it as text.
    fn simplify_or_to_string(lhs: PartitionExpr, rhs: PartitionExpr) -> String {
        simplify_merged_partition_expr(or(lhs, rhs)).to_string()
    }

    #[test]
    fn simplify_common_factor_complement() {
        // (device_id < 100 AND area < South) OR (device_id < 100 AND area >= South)
        let device_below_100 = || col("device_id").lt(Value::Int32(100));
        let south = || Value::String("South".into());
        let left = device_below_100().and(col("area").lt(south()));
        let right = device_below_100().and(col("area").gt_eq(south()));
        assert_eq!(simplify_or_to_string(left, right), "device_id < 100");
    }

    #[test]
    fn simplify_adjacent_ranges() {
        // host < h0 OR (host >= h0 AND host < h1)
        let below_h0 = col("host").lt(Value::String("h0".into()));
        let h0_to_h1 = col("host")
            .gt_eq(Value::String("h0".into()))
            .and(col("host").lt(Value::String("h1".into())));
        assert_eq!(simplify_or_to_string(below_h0, h0_to_h1), "host < h1");
    }

    #[test]
    fn simplify_drop_upper_bound() {
        // a > 10 OR (a <= 10 AND a > 0)
        let above_ten = col("a").gt(Value::Int32(10));
        let zero_to_ten = col("a")
            .lt_eq(Value::Int32(10))
            .and(col("a").gt(Value::Int32(0)));
        assert_eq!(simplify_or_to_string(above_ten, zero_to_ten), "a > 0");
    }

    #[test]
    fn do_not_merge_hole_without_not_eq() {
        // a < 10 OR a > 10 leaves a gap at 10, so the expression is kept as-is.
        let merged = or(col("a").lt(Value::Int32(10)), col("a").gt(Value::Int32(10)));
        assert_eq!(simplify_merged_partition_expr(merged.clone()), merged);
    }

    #[test]
    fn interval_bound_helpers() {
        use std::cmp::Ordering::*;

        use Bound::*;

        let v0 = OrderedFloat(0.0f64);
        let v1 = OrderedFloat(1.0f64);

        // Lower bounds from least to most restrictive.
        let lower_order = [
            Unbounded,
            Included(v0),
            Excluded(v0),
            Included(v1),
            Excluded(v1),
        ];
        // Upper bounds from most to least restrictive.
        let upper_order = [
            Excluded(v0),
            Included(v0),
            Excluded(v1),
            Included(v1),
            Unbounded,
        ];

        for pair in lower_order.windows(2) {
            assert_eq!(cmp_lower(&pair[0], &pair[1]), Less);
            assert_eq!(cmp_lower(&pair[1], &pair[0]), Greater);
        }
        for bound in &lower_order {
            assert_eq!(cmp_lower(bound, bound), Equal);
        }

        for pair in upper_order.windows(2) {
            assert_eq!(cmp_upper(&pair[0], &pair[1]), Less);
            assert_eq!(cmp_upper(&pair[1], &pair[0]), Greater);
        }
        for bound in &upper_order {
            assert_eq!(cmp_upper(bound, bound), Equal);
        }

        // (upper bound of the left interval, lower bound of the right one, gap?)
        let gap_cases = [
            (Unbounded, Included(v0), false),
            (Excluded(v0), Unbounded, false),
            (Included(v0), Included(v1), true),
            (Excluded(v0), Included(v1), true),
            (Included(v1), Included(v0), false),
            (Excluded(v1), Included(v0), false),
            (Included(v0), Included(v0), false),
            (Included(v0), Excluded(v0), false),
            (Excluded(v0), Included(v0), false),
            (Excluded(v0), Excluded(v0), true),
        ];
        for (upper, lower, expected) in gap_cases {
            assert_eq!(
                has_gap(&upper, &lower),
                expected,
                "has_gap({:?}, {:?})",
                upper,
                lower
            );
        }

        // (a, b, looser of the two upper bounds)
        let union_cases = [
            (Unbounded, Included(v0), Unbounded),
            (Included(v0), Unbounded, Unbounded),
            (Included(v0), Included(v1), Included(v1)),
            (Excluded(v1), Included(v0), Excluded(v1)),
            (Excluded(v0), Included(v0), Included(v0)),
            (Included(v0), Excluded(v0), Included(v0)),
            (Excluded(v0), Excluded(v0), Excluded(v0)),
            (Included(v0), Included(v0), Included(v0)),
        ];
        for (a, b, expected) in union_cases {
            assert_eq!(union_upper(&a, &b), expected, "union_upper({:?}, {:?})", a, b);
        }
    }
}