query/dist_plan/
region_pruner.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! [`ConstraintPruner`] prunes partition info based on given expressions.
16
17use ahash::{HashMap, HashSet};
18use common_telemetry::debug;
19use datatypes::prelude::ConcreteDataType;
20use datatypes::value::{OrderedFloat, Value};
21use partition::collider::{AtomicExpr, Collider};
22use partition::expr::{Operand, PartitionExpr};
23use partition::manager::PartitionInfo;
24use partition::overlap::atomic_exprs_overlap;
25use store_api::storage::RegionId;
26
27use crate::error::Result;
28
29pub struct ConstraintPruner;
30
31impl ConstraintPruner {
32    /// Prune regions using constraint satisfaction approach
33    ///
34    /// Takes query expressions and partition info, returns matching region IDs
35    pub fn prune_regions(
36        query_expressions: &[PartitionExpr],
37        partitions: &[PartitionInfo],
38        column_datatypes: HashMap<String, ConcreteDataType>,
39    ) -> Result<Vec<RegionId>> {
40        let start = std::time::Instant::now();
41        if query_expressions.is_empty() || partitions.is_empty() {
42            // No constraints, return all regions
43            return Ok(partitions.iter().map(|p| p.id).collect());
44        }
45
46        // Collect all partition expressions for unified normalization
47        let mut expression_to_partition = Vec::with_capacity(partitions.len());
48        let mut all_partition_expressions = Vec::with_capacity(partitions.len());
49        for partition in partitions {
50            if let Some(expr) = &partition.partition_expr {
51                expression_to_partition.push(partition.id);
52                all_partition_expressions.push(expr.clone());
53            }
54        }
55        if all_partition_expressions.is_empty() {
56            return Ok(partitions.iter().map(|p| p.id).collect());
57        }
58
59        // Create unified collider with both query and partition expressions for consistent normalization
60        let mut all_expressions = query_expressions.to_vec();
61        all_expressions.extend(all_partition_expressions.iter().cloned());
62        if !Self::normalize_datatype(&mut all_expressions, &column_datatypes) {
63            return Ok(partitions.iter().map(|p| p.id).collect());
64        }
65
66        let collider = match Collider::new(&all_expressions) {
67            Ok(collider) => collider,
68            Err(err) => {
69                debug!(
70                    "Failed to create unified collider: {}, returning all regions conservatively",
71                    err
72                );
73                return Ok(partitions.iter().map(|p| p.id).collect());
74            }
75        };
76
77        // Extract query atomic expressions (first N expressions in the collider)
78        let query_atomics: Vec<&AtomicExpr> = collider
79            .atomic_exprs
80            .iter()
81            .filter(|atomic| atomic.source_expr_index < query_expressions.len())
82            .collect();
83
84        let mut candidate_regions = HashSet::default();
85
86        for region_atomics in collider
87            .atomic_exprs
88            .iter()
89            .filter(|atomic| atomic.source_expr_index >= query_expressions.len())
90        {
91            if Self::atomic_sets_overlap(&query_atomics, region_atomics) {
92                let partition_expr_index =
93                    region_atomics.source_expr_index - query_expressions.len();
94                candidate_regions.insert(expression_to_partition[partition_expr_index]);
95            }
96        }
97
98        debug!(
99            "Constraint pruning (cost {}ms): {} -> {} regions",
100            start.elapsed().as_millis(),
101            partitions.len(),
102            candidate_regions.len()
103        );
104
105        Ok(candidate_regions.into_iter().collect())
106    }
107
108    fn atomic_sets_overlap(query_atomics: &[&AtomicExpr], partition_atomic: &AtomicExpr) -> bool {
109        query_atomics
110            .iter()
111            .any(|qa| atomic_exprs_overlap(qa, partition_atomic))
112    }
113
114    fn normalize_datatype(
115        all_expressions: &mut Vec<PartitionExpr>,
116        column_datatypes: &HashMap<String, ConcreteDataType>,
117    ) -> bool {
118        for expr in all_expressions {
119            if !Self::normalize_expr_datatype(&mut expr.lhs, &mut expr.rhs, column_datatypes) {
120                return false;
121            }
122        }
123        true
124    }
125
126    fn normalize_expr_datatype(
127        lhs: &mut Operand,
128        rhs: &mut Operand,
129        column_datatypes: &HashMap<String, ConcreteDataType>,
130    ) -> bool {
131        match (lhs, rhs) {
132            (Operand::Expr(lhs_expr), Operand::Expr(rhs_expr)) => {
133                Self::normalize_expr_datatype(
134                    &mut lhs_expr.lhs,
135                    &mut lhs_expr.rhs,
136                    column_datatypes,
137                ) && Self::normalize_expr_datatype(
138                    &mut rhs_expr.lhs,
139                    &mut rhs_expr.rhs,
140                    column_datatypes,
141                )
142            }
143            (Operand::Column(col_name), Operand::Value(val))
144            | (Operand::Value(val), Operand::Column(col_name)) => {
145                let Some(datatype) = column_datatypes.get(col_name) else {
146                    debug!("Column {} not found from type set, skip pruning", col_name);
147                    return false;
148                };
149
150                match datatype {
151                    ConcreteDataType::Int8(_)
152                    | ConcreteDataType::Int16(_)
153                    | ConcreteDataType::Int32(_)
154                    | ConcreteDataType::Int64(_) => {
155                        let Some(new_lit) = val.as_i64() else {
156                            debug!("Value {:?} cannot be converted to i64", val);
157                            return false;
158                        };
159                        *val = Value::Int64(new_lit);
160                    }
161
162                    ConcreteDataType::UInt8(_)
163                    | ConcreteDataType::UInt16(_)
164                    | ConcreteDataType::UInt32(_)
165                    | ConcreteDataType::UInt64(_) => {
166                        let Some(new_lit) = val.as_u64() else {
167                            debug!("Value {:?} cannot be converted to u64", val);
168                            return false;
169                        };
170                        *val = Value::UInt64(new_lit);
171                    }
172
173                    ConcreteDataType::Float32(_) | ConcreteDataType::Float64(_) => {
174                        let Some(new_lit) = val.as_f64_lossy() else {
175                            debug!("Value {:?} cannot be converted to f64", val);
176                            return false;
177                        };
178
179                        *val = Value::Float64(OrderedFloat(new_lit));
180                    }
181
182                    ConcreteDataType::String(_) | ConcreteDataType::Boolean(_) => {
183                        // no operation needed
184                    }
185
186                    ConcreteDataType::Decimal128(_)
187                    | ConcreteDataType::Binary(_)
188                    | ConcreteDataType::Date(_)
189                    | ConcreteDataType::Timestamp(_)
190                    | ConcreteDataType::Time(_)
191                    | ConcreteDataType::Duration(_)
192                    | ConcreteDataType::Interval(_)
193                    | ConcreteDataType::List(_)
194                    | ConcreteDataType::Dictionary(_)
195                    | ConcreteDataType::Struct(_)
196                    | ConcreteDataType::Json(_)
197                    | ConcreteDataType::Null(_)
198                    | ConcreteDataType::Vector(_) => {
199                        debug!("Unsupported data type {datatype}");
200                        return false;
201                    }
202                }
203
204                true
205            }
206            _ => false,
207        }
208    }
209}
// Value range and atomic overlap logic lives in the `partition` crate
// (see `partition::overlap::atomic_exprs_overlap`).
211
#[cfg(test)]
mod tests {
    use datatypes::value::Value;
    use partition::expr::{Operand, PartitionExpr, RestrictedOp, col};
    use store_api::storage::RegionId;

    use super::*;

    /// Builds the `PartitionInfo` of region `region_id` in table 1.
    fn create_test_partition_info(region_id: u64, expr: Option<PartitionExpr>) -> PartitionInfo {
        PartitionInfo {
            id: RegionId::new(1, region_id as u32),
            partition_expr: expr,
        }
    }

    /// Builds region `region_id` constrained to `user_id >= lower AND user_id < upper`.
    fn user_id_range_partition(region_id: u64, lower: i64, upper: i64) -> PartitionInfo {
        let expr = col("user_id")
            .gt_eq(Value::Int64(lower))
            .and(col("user_id").lt(Value::Int64(upper)));
        create_test_partition_info(region_id, Some(expr))
    }

    /// Builds regions `1..=count`, where region `n` covers `[(n - 1) * 100, n * 100)`.
    fn user_id_partitions(count: u64) -> Vec<PartitionInfo> {
        (1..=count)
            .map(|region_id| {
                let lower = (region_id as i64 - 1) * 100;
                user_id_range_partition(region_id, lower, lower + 100)
            })
            .collect()
    }

    /// Column types for the `user_id` column shared by these tests.
    fn user_id_datatypes() -> HashMap<String, ConcreteDataType> {
        let mut column_datatypes = HashMap::default();
        column_datatypes.insert("user_id".to_string(), ConcreteDataType::int64_datatype());
        column_datatypes
    }

    /// Builds `user_id = v0 OR user_id = v1 OR ...` as a left-deep tree of `Or` nodes.
    fn user_id_equals_any(values: &[i64]) -> PartitionExpr {
        let mut equalities = values
            .iter()
            .map(|&value| col("user_id").eq(Value::Int64(value)));
        let first = equalities.next().expect("at least one value is required");
        equalities.fold(first, |acc, next| {
            PartitionExpr::new(Operand::Expr(acc), RestrictedOp::Or, Operand::Expr(next))
        })
    }

    /// Runs the pruner over `partitions` using the `user_id` column types.
    fn prune(query_exprs: &[PartitionExpr], partitions: &[PartitionInfo]) -> Vec<RegionId> {
        ConstraintPruner::prune_regions(query_exprs, partitions, user_id_datatypes()).unwrap()
    }

    #[test]
    fn test_constraint_pruning_equality() {
        // Regions cover [0, 100), [100, 200) and [200, 300)
        let partitions = user_id_partitions(3);

        // Query: user_id = 150 (should only match Region 2)
        let pruned = prune(&[col("user_id").eq(Value::Int64(150))], &partitions);

        // Should include Region 2, and potentially others due to conservative approach
        assert!(pruned.contains(&RegionId::new(1, 2)));
    }

    #[test]
    fn test_constraint_pruning_in_list() {
        // Regions cover [0, 100), [100, 200) and [200, 300)
        let partitions = user_id_partitions(3);

        // Query: user_id IN (50, 150, 250) - should match all regions
        let pruned = prune(&[user_id_equals_any(&[50, 150, 250])], &partitions);

        // Should include regions that can satisfy any of the values
        assert!(!pruned.is_empty());
    }

    #[test]
    fn test_constraint_pruning_range() {
        // Regions cover [0, 100), [100, 200) and [200, 300)
        let partitions = user_id_partitions(3);

        // Query: user_id >= 150 (should include regions that can satisfy this constraint)
        let pruned = prune(&[col("user_id").gt_eq(Value::Int64(150))], &partitions);

        // With constraint-based approach:
        // Region 1: [0, 100) - user_id >= 150 is not satisfiable
        // Region 2: [100, 200) - user_id >= 150 is satisfiable in range [150, 200)
        // Region 3: [200, 300) - user_id >= 150 is satisfiable (all values >= 200 satisfy user_id >= 150)
        // Conservative approach may include more regions, but should at least include regions 2 and 3
        assert!(pruned.len() >= 2);
        assert!(pruned.contains(&RegionId::new(1, 2))); // Region 2 should be included
        assert!(pruned.contains(&RegionId::new(1, 3))); // Region 3 should be included
    }

    #[test]
    fn test_prune_regions_no_constraints() {
        let partitions = [
            create_test_partition_info(1, None),
            create_test_partition_info(2, None),
        ];

        let pruned =
            ConstraintPruner::prune_regions(&[], &partitions, HashMap::default()).unwrap();

        // No constraints should return all regions
        assert_eq!(pruned.len(), 2);
    }

    #[test]
    fn test_prune_regions_with_simple_equality() {
        // Regions cover [0, 100), [100, 200) and [200, 300)
        let partitions = user_id_partitions(3);

        // Query: user_id = 150 (should only match Region 2 which contains values [100, 200))
        let pruned = prune(&[col("user_id").eq(Value::Int64(150))], &partitions);

        // user_id = 150 should match Region 2 ([100, 200)) and potentially others due to conservative approach
        assert!(pruned.contains(&RegionId::new(1, 2)));
    }

    #[test]
    fn test_prune_regions_with_or_constraint() {
        // Regions cover [0, 100), [100, 200) and [200, 300)
        let partitions = user_id_partitions(3);

        // Query: user_id = 50 OR user_id = 150 OR user_id = 250 - should match all 3 regions
        let pruned = prune(&[user_id_equals_any(&[50, 150, 250])], &partitions);

        // Should match all 3 regions: 50 matches Region 1, 150 matches Region 2, 250 matches Region 3
        assert_eq!(pruned.len(), 3);
        for region in 1..=3 {
            assert!(pruned.contains(&RegionId::new(1, region)));
        }
    }

    #[test]
    fn test_constraint_pruning_no_match() {
        // Regions cover [0, 100) and [100, 200)
        let partitions = user_id_partitions(2);

        // Query: user_id = 300 (should match no regions)
        let pruned = prune(&[col("user_id").eq(Value::Int64(300))], &partitions);

        // Should match no regions since 300 is outside both partition ranges
        assert_eq!(pruned.len(), 0);
    }

    #[test]
    fn test_constraint_pruning_partial_match() {
        // Regions cover [0, 100) and [100, 200)
        let partitions = user_id_partitions(2);

        // Query: user_id >= 50 (should match both regions partially)
        let pruned = prune(&[col("user_id").gt_eq(Value::Int64(50))], &partitions);

        // Region 1: [0,100) intersects with [50,∞) -> includes [50,100)
        // Region 2: [100,200) is fully contained in [50,∞)
        assert_eq!(pruned.len(), 2);
        assert!(pruned.contains(&RegionId::new(1, 1)));
        assert!(pruned.contains(&RegionId::new(1, 2)));
    }
}