partition/checker.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::{BTreeMap, HashMap};
use std::sync::Arc;

use datatypes::arrow::array::{BooleanArray, Float64Array, Float64Builder, RecordBatch};
use datatypes::arrow::datatypes::{DataType, Field, Schema};
use datatypes::value::OrderedF64;

use crate::collider::{Collider, CHECK_STEP, NORMALIZE_STEP};
use crate::error::{
    CheckpointNotCoveredSnafu, CheckpointOverlappedSnafu, DuplicateExprSnafu, Result,
};
use crate::expr::{PartitionExpr, RestrictedOp};
use crate::multi_dim::MultiDimPartitionRule;

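/// Validates that the partition exprs of a [`MultiDimPartitionRule`] cover the value
/// space exactly once: every generated checkpoint must match one and only one atomic expr.
///
/// A minimal usage sketch (error handling elided, `rule` assumed to be built elsewhere):
///
/// ```ignore
/// let checker = PartitionChecker::try_new(&rule)?;
/// checker.check()?;
/// ```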
pub struct PartitionChecker<'a> {
    rule: &'a MultiDimPartitionRule,
    collider: Collider<'a>,
}

impl<'a> PartitionChecker<'a> {
    pub fn try_new(rule: &'a MultiDimPartitionRule) -> Result<Self> {
        let collider = Collider::new(rule.exprs())?;
        Ok(Self { rule, collider })
    }

    pub fn check(&self) -> Result<()> {
        self.run()?;
        Ok(())
    }
}

// Logic of checking rules
impl<'a> PartitionChecker<'a> {
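    /// Runs the actual check.
    ///
    /// First, atomic exprs are keyed by their nucleons and duplicates are rejected.
    /// Then a matrix of checkpoints (the Cartesian product of per-column probe values)
    /// is generated, and every atomic expr is evaluated against it: each checkpoint
    /// must be matched by exactly one atomic expr, otherwise the rule either leaves a
    /// gap (`CheckpointNotCovered`) or overlaps (`CheckpointOverlapped`).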
    fn run(&self) -> Result<()> {
        // Sort atomic exprs and check uniqueness
        let mut atomic_exprs = BTreeMap::new();
        for expr in self.collider.atomic_exprs.iter() {
            let key = &expr.nucleons;
            atomic_exprs.insert(key, expr);
        }
        if atomic_exprs.len() != self.collider.atomic_exprs.len() {
            // Find the duplicate for the error message
            for expr in self.collider.atomic_exprs.iter() {
                if atomic_exprs.get(&expr.nucleons).unwrap().source_expr_index
                    != expr.source_expr_index
                {
                    let expr = self.rule.exprs()[expr.source_expr_index].clone();
                    return DuplicateExprSnafu { expr }.fail();
                }
            }
            // Or return a placeholder. This should never happen.
            return DuplicateExprSnafu {
                expr: PartitionExpr::new(
                    crate::expr::Operand::Column("unknown".to_string()),
                    RestrictedOp::Eq,
                    crate::expr::Operand::Column("expr".to_string()),
                ),
            }
            .fail();
        }

        // TODO(ruihang): merge atomic exprs to improve checker's performance

        // Matrix test: build the checkpoint values for every partition column.
        let mut matrix_foundation = HashMap::new();
        for (col, values) in self.collider.normalized_values.iter() {
            if values.is_empty() {
                continue;
            }

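            // Probe one point below the smallest normalized value, every value
            // itself, and one point just above each value, so both open ends and
            // every gap between adjacent values are checked.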
            let mut cornerstones = Vec::with_capacity(values.len() * 2 + 1);
            cornerstones.push(values[0].1 - CHECK_STEP);
            for value in values {
                cornerstones.push(value.1);
                cornerstones.push(value.1 + CHECK_STEP);
            }
            matrix_foundation.insert(col.as_str(), cornerstones);
        }

        // If there are no values, the rule is empty and valid.
        if matrix_foundation.is_empty() {
            return Ok(());
        }

        let matrix_generator = MatrixGenerator::new(matrix_foundation);

        // Process data in batches using iterator
        let mut results = Vec::with_capacity(self.collider.atomic_exprs.len());
        let physical_exprs = self
            .collider
            .atomic_exprs
            .iter()
            .map(|expr| expr.to_physical_expr(matrix_generator.schema()))
            .collect::<Vec<_>>();
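        // Evaluate every atomic expr against each batch of checkpoints. A checkpoint
        // matched by no expr is uncovered; a checkpoint matched by more than one expr
        // means the exprs overlap.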
        for batch in matrix_generator {
            results.clear();
            for physical_expr in &physical_exprs {
                let columnar_result = physical_expr.evaluate(&batch).unwrap();
                let array_result = columnar_result.into_array(batch.num_rows()).unwrap();
                results.push(array_result);
            }
            let boolean_results = results
                .iter()
                .map(|result| result.as_any().downcast_ref::<BooleanArray>().unwrap())
                .collect::<Vec<_>>();

            // Count matching exprs and check coverage for each row in this batch
            for i in 0..batch.num_rows() {
                let mut true_count = 0;
                for result in boolean_results.iter() {
                    if result.value(i) {
                        true_count += 1;
                    }
                }

                if true_count == 0 {
                    return CheckpointNotCoveredSnafu {
                        checkpoint: self.remap_checkpoint(i, &batch),
                    }
                    .fail();
                } else if true_count > 1 {
                    return CheckpointOverlappedSnafu {
                        checkpoint: self.remap_checkpoint(i, &batch),
                    }
                    .fail();
                }
            }
        }

        Ok(())
    }

    /// Remap the normalized checkpoint data to the original values.
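    ///
    /// A normalized value sitting exactly on a `NORMALIZE_STEP` multiple maps back to
    /// an original value and is rendered as `col=value`; a value between two steps is
    /// rendered as `lower<col<upper`, and values outside the known range are rendered
    /// as `col<first` or `col>last`.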
    fn remap_checkpoint(&self, i: usize, batch: &RecordBatch) -> String {
        let normalized_row = batch
            .columns()
            .iter()
            .map(|col| {
                let array = col.as_any().downcast_ref::<Float64Array>().unwrap();
                array.value(i)
            })
            .collect::<Vec<_>>();

        let mut check_point = String::new();
        let schema = batch.schema();
        for (col_index, normalized_value) in normalized_row.iter().enumerate() {
            let col_name = schema.field(col_index).name();

            if col_index > 0 {
                check_point.push_str(", ");
            }

            // Check if point is on NORMALIZE_STEP or between steps
            if let Some(values) = self.collider.normalized_values.get(col_name) {
                let normalize_step = NORMALIZE_STEP.0;

                // Check if the normalized value is on a NORMALIZE_STEP boundary
                let remainder = normalized_value % normalize_step;
                let is_on_step = remainder.abs() < f64::EPSILON
                    || (normalize_step - remainder).abs() < f64::EPSILON * 2.0;

                if is_on_step {
                    let index = (normalized_value / normalize_step).round() as usize;
                    if index < values.len() {
                        let original_value = &values[index].0;
                        check_point.push_str(&format!("{}={}", col_name, original_value));
                    } else {
                        check_point.push_str(&format!("{}=unknown", col_name));
                    }
                } else {
                    let lower_index = (normalized_value / normalize_step).floor() as usize;
                    let upper_index = (normalized_value / normalize_step).ceil() as usize;

                    // Handle edge cases: value is outside the valid range
                    if lower_index == upper_index && lower_index == 0 {
                        // Value is less than the first value
                        let first_original = &values[0].0;
                        check_point.push_str(&format!("{}<{}", col_name, first_original));
                    } else if upper_index == values.len() {
                        // Value is greater than the last value
                        let last_original = &values[values.len() - 1].0;
                        check_point.push_str(&format!("{}>{}", col_name, last_original));
                    } else {
                        // Normal case: value is between two valid values
                        let lower_original = if lower_index < values.len() {
                            values[lower_index].0.to_string()
                        } else {
                            "unknown".to_string()
                        };

                        let upper_original = if upper_index < values.len() {
                            values[upper_index].0.to_string()
                        } else {
                            "unknown".to_string()
                        };

                        check_point.push_str(&format!(
                            "{}<{}<{}",
                            lower_original, col_name, upper_original
                        ));
                    }
                }
            } else {
                // Fallback if column not found in normalized values
                check_point.push_str(&format!("{}:unknown", col_name));
            }
        }

        check_point
    }
}

/// Generates a point matrix containing the Cartesian product of `matrix_foundation`'s values
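///
/// For example (illustrative): a foundation of `{"a": [1, 2], "b": [10, 20, 30]}` yields
/// 6 rows, with the last column (in sorted column-name order) cycling fastest:
/// `(1, 10), (1, 20), (1, 30), (2, 10), (2, 20), (2, 30)`.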
struct MatrixGenerator {
    matrix_foundation: HashMap<String, Vec<OrderedF64>>,
    // Iterator state
    current_index: usize,
    schema: Schema,
    column_names: Vec<String>,
    // Preprocessed attributes
    /// Total number of combinations of `matrix_foundation`'s values
    total_combinations: usize,
    /// Biased suffix product of `matrix_foundation`'s values
    ///
    /// The i-th element is the product of the sizes of all columns after the i-th column.
    /// For example, if `matrix_foundation` is `{"a": [1, 2, 3], "b": [4, 5, 6]}`,
    /// then `biased_suffix_product` is `[3, 1]`.
    biased_suffix_product: Vec<usize>,
}

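/// Maximum number of checkpoint rows generated per `RecordBatch`.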
const MAX_BATCH_SIZE: usize = 8192;

impl MatrixGenerator {
    pub fn new(matrix_foundation: HashMap<&str, Vec<OrderedF64>>) -> Self {
        // Convert to owned HashMap to avoid lifetime issues
        let owned_matrix_foundation: HashMap<String, Vec<OrderedF64>> = matrix_foundation
            .into_iter()
            .map(|(k, v)| (k.to_string(), v))
            .collect();

        let mut fields = owned_matrix_foundation
            .keys()
            .map(|k| Field::new(k.clone(), DataType::Float64, false))
            .collect::<Vec<_>>();
        fields.sort_unstable();
        let schema = Schema::new(fields.clone());

        // Store column names in the same order as fields
        let column_names: Vec<String> = fields.iter().map(|field| field.name().clone()).collect();

        // Calculate total number of combinations and suffix product
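        // E.g. (illustrative): columns ["a", "b"] with 2 and 3 values respectively give
        // `biased_suffix_product = [3, 1]` and 6 total combinations.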
        let mut biased_suffix_product = Vec::with_capacity(column_names.len() + 1);
        let mut product = 1;
        biased_suffix_product.push(product);
        for col_name in column_names.iter().rev() {
            product *= owned_matrix_foundation[col_name].len();
            biased_suffix_product.push(product);
        }
        biased_suffix_product.pop();
        biased_suffix_product.reverse();

        Self {
            matrix_foundation: owned_matrix_foundation,
            current_index: 0,
            schema,
            column_names,
            total_combinations: product,
            biased_suffix_product,
        }
    }

    pub fn schema(&self) -> &Schema {
        &self.schema
    }

    fn generate_batch(&self, start_index: usize, batch_size: usize) -> RecordBatch {
        let actual_batch_size = batch_size.min(self.total_combinations - start_index);

        // Create array builders
        let mut array_builders: Vec<Float64Builder> = Vec::with_capacity(self.column_names.len());
        for _ in 0..self.column_names.len() {
            array_builders.push(Float64Builder::with_capacity(actual_batch_size));
        }

        // Generate combinations for this batch
        for combination_offset in 0..actual_batch_size {
            let combination_index = start_index + combination_offset;

            // For each column, determine which value to use for this combination
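            // Decode the combination index in mixed radix: divide by the column's
            // stride and take the remainder modulo the number of values, so the
            // last (sorted) column cycles fastest.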
            for (col_idx, col_name) in self.column_names.iter().enumerate() {
                let values = &self.matrix_foundation[col_name];
                let stride = self.biased_suffix_product[col_idx];
                let value_index = (combination_index / stride) % values.len();
                let value = *values[value_index].as_ref();

                array_builders[col_idx].append_value(value);
            }
        }

        // Finish arrays and create RecordBatch
        let arrays: Vec<_> = array_builders
            .into_iter()
            .map(|mut builder| Arc::new(builder.finish()) as _)
            .collect();

        RecordBatch::try_new(Arc::new(self.schema.clone()), arrays)
            .expect("Failed to create RecordBatch from generated arrays")
    }
}

impl Iterator for MatrixGenerator {
    type Item = RecordBatch;

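    // Yields batches of at most `MAX_BATCH_SIZE` rows until every combination has
    // been emitted.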
    fn next(&mut self) -> Option<Self::Item> {
        if self.current_index >= self.total_combinations {
            return None;
        }

        let remaining = self.total_combinations - self.current_index;
        let batch_size = remaining.min(MAX_BATCH_SIZE);

        let batch = self.generate_batch(self.current_index, batch_size);
        self.current_index += batch_size;

        Some(batch)
    }
}

#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use datatypes::value::Value;

    use super::*;
    use crate::expr::col;
    use crate::multi_dim::MultiDimPartitionRule;

    #[test]
    fn test_matrix_generator_single_column() {
        let mut matrix_foundation = HashMap::new();
        matrix_foundation.insert(
            "col1",
            vec![
                OrderedF64::from(1.0),
                OrderedF64::from(2.0),
                OrderedF64::from(3.0),
            ],
        );

        let mut generator = MatrixGenerator::new(matrix_foundation);
        let batch = generator.next().unwrap();

        assert_eq!(batch.num_rows(), 3);
        assert_eq!(batch.num_columns(), 1);
        assert_eq!(batch.schema().field(0).name(), "col1");

        let col1_array = batch
            .column(0)
            .as_any()
            .downcast_ref::<datatypes::arrow::array::Float64Array>()
            .unwrap();
        assert_eq!(col1_array.value(0), 1.0);
        assert_eq!(col1_array.value(1), 2.0);
        assert_eq!(col1_array.value(2), 3.0);

        // Should be no more batches for such a small dataset
        assert!(generator.next().is_none());
    }

    #[test]
    fn test_matrix_generator_three_columns_cartesian_product() {
        let mut matrix_foundation = HashMap::new();
        matrix_foundation.insert("a", vec![OrderedF64::from(1.0), OrderedF64::from(2.0)]);
        matrix_foundation.insert("b", vec![OrderedF64::from(10.0), OrderedF64::from(20.0)]);
        matrix_foundation.insert(
            "c",
            vec![
                OrderedF64::from(100.0),
                OrderedF64::from(200.0),
                OrderedF64::from(300.0),
            ],
        );

        let mut generator = MatrixGenerator::new(matrix_foundation);
        let batch = generator.next().unwrap();

        // Should have 2 * 2 * 3 = 12 combinations
        assert_eq!(batch.num_rows(), 12);
        assert_eq!(batch.num_columns(), 3);

        let a_array = batch
            .column(0)
            .as_any()
            .downcast_ref::<datatypes::arrow::array::Float64Array>()
            .unwrap();
        let b_array = batch
            .column(1)
            .as_any()
            .downcast_ref::<datatypes::arrow::array::Float64Array>()
            .unwrap();
        let c_array = batch
            .column(2)
            .as_any()
            .downcast_ref::<datatypes::arrow::array::Float64Array>()
            .unwrap();

        // Verify all 12 combinations (a changes slowest, c changes fastest)
        let expected = vec![
            (1.0, 10.0, 100.0),
            (1.0, 10.0, 200.0),
            (1.0, 10.0, 300.0),
            (1.0, 20.0, 100.0),
            (1.0, 20.0, 200.0),
            (1.0, 20.0, 300.0),
            (2.0, 10.0, 100.0),
            (2.0, 10.0, 200.0),
            (2.0, 10.0, 300.0),
            (2.0, 20.0, 100.0),
            (2.0, 20.0, 200.0),
            (2.0, 20.0, 300.0),
        ];
        #[allow(clippy::needless_range_loop)]
        for i in 0..batch.num_rows() {
            assert_eq!(
                (a_array.value(i), b_array.value(i), c_array.value(i)),
                expected[i]
            );
        }

        // Should be no more batches for such a small dataset
        assert!(generator.next().is_none());
    }

    #[test]
    fn test_matrix_generator_iterator_small_batches() {
        let mut matrix_foundation = HashMap::new();
        matrix_foundation.insert("col1", vec![OrderedF64::from(1.0), OrderedF64::from(2.0)]);
        matrix_foundation.insert(
            "col2",
            vec![
                OrderedF64::from(10.0),
                OrderedF64::from(20.0),
                OrderedF64::from(30.0),
            ],
        );

        let generator = MatrixGenerator::new(matrix_foundation);

        // Total combinations should be 2 * 3 = 6
        assert_eq!(generator.total_combinations, 6);

        let mut total_rows = 0;

        for batch in generator {
            total_rows += batch.num_rows();
            assert_eq!(batch.num_columns(), 2);

            // Verify each batch is valid
            assert!(batch.num_rows() > 0);
            assert!(batch.num_rows() <= MAX_BATCH_SIZE);
        }

        assert_eq!(total_rows, 6);
    }

    #[test]
    fn test_matrix_generator_empty_column_values() {
        let mut matrix_foundation = HashMap::new();
        matrix_foundation.insert("col1", vec![]);

        let mut generator = MatrixGenerator::new(matrix_foundation);

        // Should have 0 total combinations when any column is empty
        assert_eq!(generator.total_combinations, 0);

        // Should have no batches when total combinations is 0
        assert!(generator.next().is_none());
    }

    #[test]
    fn test_matrix_generator_large_dataset_batching() {
        // Create a dataset that will exceed MAX_BATCH_SIZE (8192)
        // 20 * 20 * 21 = 8400 > 8192
        let mut matrix_foundation = HashMap::new();

        let values1: Vec<OrderedF64> = (0..20).map(|i| OrderedF64::from(i as f64)).collect();
        let values2: Vec<OrderedF64> = (0..20)
            .map(|i| OrderedF64::from(i as f64 + 100.0))
            .collect();
        let values3: Vec<OrderedF64> = (0..21)
            .map(|i| OrderedF64::from(i as f64 + 1000.0))
            .collect();

        matrix_foundation.insert("col1", values1);
        matrix_foundation.insert("col2", values2);
        matrix_foundation.insert("col3", values3);

        let generator = MatrixGenerator::new(matrix_foundation);

        assert_eq!(generator.total_combinations, 8400);

        let mut total_rows = 0;
        let mut batch_count = 0;
        let mut first_batch_size = None;

        for batch in generator {
            batch_count += 1;
            let batch_size = batch.num_rows();
            total_rows += batch_size;

            if first_batch_size.is_none() {
                first_batch_size = Some(batch_size);
            }

            // Each batch should not exceed MAX_BATCH_SIZE
            assert!(batch_size <= MAX_BATCH_SIZE);
            assert_eq!(batch.num_columns(), 3);
        }

        assert_eq!(total_rows, 8400);
        assert!(batch_count > 1);
        assert_eq!(first_batch_size.unwrap(), MAX_BATCH_SIZE);
    }

    #[test]
    fn test_remap_checkpoint_values() {
        // Create a rule over two columns: host and value
        let rule = MultiDimPartitionRule::try_new(
            vec!["host".to_string(), "value".to_string()],
            vec![1, 2, 3],
            vec![
                col("host")
                    .lt(Value::Int64(0))
                    .and(col("value").lt(Value::Int64(0))),
                col("host")
                    .lt(Value::Int64(0))
                    .and(col("value").gt_eq(Value::Int64(0))),
                col("host")
                    .gt_eq(Value::Int64(0))
                    .and(col("host").lt(Value::Int64(1)))
                    .and(col("value").lt(Value::Int64(1))),
                col("host")
                    .gt_eq(Value::Int64(0))
                    .and(col("host").lt(Value::Int64(1)))
                    .and(col("value").gt_eq(Value::Int64(1))),
                col("host")
                    .gt_eq(Value::Int64(1))
                    .and(col("host").lt(Value::Int64(2)))
                    .and(col("value").lt(Value::Int64(2))),
                col("host")
                    .gt_eq(Value::Int64(1))
                    .and(col("host").lt(Value::Int64(2)))
                    .and(col("value").gt_eq(Value::Int64(2))),
                col("host")
                    .gt_eq(Value::Int64(2))
                    .and(col("host").lt(Value::Int64(3)))
                    .and(col("value").lt(Value::Int64(3))),
                col("host")
                    .gt_eq(Value::Int64(2))
                    .and(col("host").lt(Value::Int64(3)))
                    .and(col("value").gt_eq(Value::Int64(3))),
                col("host").gt_eq(Value::Int64(3)),
            ],
            true,
        )
        .unwrap();
        let checker = PartitionChecker::try_new(&rule).unwrap();

        let schema = Arc::new(Schema::new(vec![
            Field::new("host", DataType::Float64, false),
            Field::new("value", DataType::Float64, false),
        ]));
        let host_array = Float64Array::from(vec![-0.5, 0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5]);
        let value_array = Float64Array::from(vec![-0.5, 0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5]);
        let batch = RecordBatch::try_new(schema, vec![Arc::new(host_array), Arc::new(value_array)])
            .unwrap();

        let checkpoint = checker.remap_checkpoint(0, &batch);
        assert_eq!(checkpoint, "host<0, value<0");
        let checkpoint = checker.remap_checkpoint(1, &batch);
        assert_eq!(checkpoint, "host=0, value=0");
        let checkpoint = checker.remap_checkpoint(6, &batch);
        assert_eq!(checkpoint, "2<host<3, 2<value<3");
        let checkpoint = checker.remap_checkpoint(7, &batch);
        assert_eq!(checkpoint, "host=3, value=3");
        let checkpoint = checker.remap_checkpoint(8, &batch);
        assert_eq!(checkpoint, "host>3, value>3");
    }
}