partition/
multi_dim.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::any::Any;
16use std::cmp::Ordering;
17use std::collections::hash_map::Entry;
18use std::collections::HashMap;
19use std::sync::{Arc, RwLock};
20
21use datafusion_expr::ColumnarValue;
22use datafusion_physical_expr::PhysicalExpr;
23use datatypes::arrow;
24use datatypes::arrow::array::{BooleanArray, BooleanBufferBuilder, RecordBatch};
25use datatypes::arrow::buffer::BooleanBuffer;
26use datatypes::arrow::datatypes::Schema;
27use datatypes::prelude::Value;
28use datatypes::vectors::{Helper, VectorRef};
29use serde::{Deserialize, Serialize};
30use snafu::{ensure, OptionExt, ResultExt};
31use store_api::storage::RegionNumber;
32
33use crate::error::{
34    self, ConjunctExprWithNonExprSnafu, InvalidExprSnafu, Result, UnclosedValueSnafu,
35    UndefinedColumnSnafu,
36};
37use crate::expr::{Operand, PartitionExpr, RestrictedOp};
38use crate::partition::RegionMask;
39use crate::PartitionRule;
40
/// The default region number, used for rows that match none of the partition exprs.
const DEFAULT_REGION: RegionNumber = 0;
43
/// Cached physical expressions paired with the schema they were built against.
/// `None` while nothing has been cached yet.
type PhysicalExprCache = Option<(Vec<Arc<dyn PhysicalExpr>>, Arc<Schema>)>;
45
/// Multi-dimension partition rule. RFC [here](https://github.com/GreptimeTeam/greptimedb/blob/main/docs/rfcs/2024-02-21-multi-dimension-partition-rule/rfc.md)
///
/// This partition rule is defined by a set of simple expressions on the partition
/// key columns. Compared to RANGE partition, which can be considered as a
/// single-dimension rule, this will evaluate expressions on each column separately.
#[derive(Debug, Serialize, Deserialize)]
pub struct MultiDimPartitionRule {
    /// Allow list of which columns can be used for partitioning.
    partition_columns: Vec<String>,
    /// Name to index of `partition_columns`. Used for quick lookup.
    name_to_index: HashMap<String, usize>,
    /// Region number for each partition. This list has the same length as `exprs`
    /// (apart from the default region).
    regions: Vec<RegionNumber>,
    /// Partition expressions.
    exprs: Vec<PartitionExpr>,
    /// Cache of the physical expressions built from `exprs`, together with the
    /// schema they were built against. Skipped during (de)serialization.
    #[serde(skip)]
    physical_expr_cache: RwLock<PhysicalExprCache>,
}
66
67impl MultiDimPartitionRule {
68    pub fn try_new(
69        partition_columns: Vec<String>,
70        regions: Vec<RegionNumber>,
71        exprs: Vec<PartitionExpr>,
72    ) -> Result<Self> {
73        let name_to_index = partition_columns
74            .iter()
75            .enumerate()
76            .map(|(i, name)| (name.clone(), i))
77            .collect::<HashMap<_, _>>();
78
79        let rule = Self {
80            partition_columns,
81            name_to_index,
82            regions,
83            exprs,
84            physical_expr_cache: RwLock::new(None),
85        };
86
87        let mut checker = RuleChecker::new(&rule);
88        checker.check()?;
89
90        Ok(rule)
91    }
92
93    fn find_region(&self, values: &[Value]) -> Result<RegionNumber> {
94        ensure!(
95            values.len() == self.partition_columns.len(),
96            error::RegionKeysSizeSnafu {
97                expect: self.partition_columns.len(),
98                actual: values.len(),
99            }
100        );
101
102        for (region_index, expr) in self.exprs.iter().enumerate() {
103            if self.evaluate_expr(expr, values)? {
104                return Ok(self.regions[region_index]);
105            }
106        }
107
108        // return the default region number
109        Ok(DEFAULT_REGION)
110    }
111
112    fn evaluate_expr(&self, expr: &PartitionExpr, values: &[Value]) -> Result<bool> {
113        match (expr.lhs.as_ref(), expr.rhs.as_ref()) {
114            (Operand::Column(name), Operand::Value(r)) => {
115                let index = self.name_to_index.get(name).unwrap();
116                let l = &values[*index];
117                Self::perform_op(l, &expr.op, r)
118            }
119            (Operand::Value(l), Operand::Column(name)) => {
120                let index = self.name_to_index.get(name).unwrap();
121                let r = &values[*index];
122                Self::perform_op(l, &expr.op, r)
123            }
124            (Operand::Expr(lhs), Operand::Expr(rhs)) => {
125                let lhs = self.evaluate_expr(lhs, values)?;
126                let rhs = self.evaluate_expr(rhs, values)?;
127                match expr.op {
128                    RestrictedOp::And => Ok(lhs && rhs),
129                    RestrictedOp::Or => Ok(lhs || rhs),
130                    _ => unreachable!(),
131                }
132            }
133            _ => unreachable!(),
134        }
135    }
136
137    fn perform_op(lhs: &Value, op: &RestrictedOp, rhs: &Value) -> Result<bool> {
138        let result = match op {
139            RestrictedOp::Eq => lhs.eq(rhs),
140            RestrictedOp::NotEq => lhs.ne(rhs),
141            RestrictedOp::Lt => lhs.partial_cmp(rhs) == Some(Ordering::Less),
142            RestrictedOp::LtEq => {
143                let result = lhs.partial_cmp(rhs);
144                result == Some(Ordering::Less) || result == Some(Ordering::Equal)
145            }
146            RestrictedOp::Gt => lhs.partial_cmp(rhs) == Some(Ordering::Greater),
147            RestrictedOp::GtEq => {
148                let result = lhs.partial_cmp(rhs);
149                result == Some(Ordering::Greater) || result == Some(Ordering::Equal)
150            }
151            RestrictedOp::And | RestrictedOp::Or => unreachable!(),
152        };
153
154        Ok(result)
155    }
156
157    pub fn row_at(&self, cols: &[VectorRef], index: usize, row: &mut [Value]) -> Result<()> {
158        for (col_idx, col) in cols.iter().enumerate() {
159            row[col_idx] = col.get(index);
160        }
161        Ok(())
162    }
163
164    pub fn record_batch_to_cols(&self, record_batch: &RecordBatch) -> Result<Vec<VectorRef>> {
165        self.partition_columns
166            .iter()
167            .map(|col_name| {
168                record_batch
169                    .column_by_name(col_name)
170                    .context(error::UndefinedColumnSnafu { column: col_name })
171                    .and_then(|array| {
172                        Helper::try_into_vector(array).context(error::ConvertToVectorSnafu)
173                    })
174            })
175            .collect::<Result<Vec<_>>>()
176    }
177
178    pub fn split_record_batch_naive(
179        &self,
180        record_batch: &RecordBatch,
181    ) -> Result<HashMap<RegionNumber, BooleanArray>> {
182        let num_rows = record_batch.num_rows();
183
184        let mut result = self
185            .regions
186            .iter()
187            .map(|region| {
188                let mut builder = BooleanBufferBuilder::new(num_rows);
189                builder.append_n(num_rows, false);
190                (*region, builder)
191            })
192            .collect::<HashMap<_, _>>();
193
194        let cols = self.record_batch_to_cols(record_batch)?;
195        let mut current_row = vec![Value::Null; self.partition_columns.len()];
196        for row_idx in 0..num_rows {
197            self.row_at(&cols, row_idx, &mut current_row)?;
198            let current_region = self.find_region(&current_row)?;
199            let region_mask = result
200                .get_mut(&current_region)
201                .unwrap_or_else(|| panic!("Region {} must be initialized", current_region));
202            region_mask.set_bit(row_idx, true);
203        }
204
205        Ok(result
206            .into_iter()
207            .map(|(region, mut mask)| (region, BooleanArray::new(mask.finish(), None)))
208            .collect())
209    }
210
    /// Splits `record_batch` into per-region row masks by evaluating the
    /// partition expressions on the whole batch at once.
    ///
    /// Regions whose expression selects no row are left out of the result.
    /// Rows selected by no expression are assigned to [DEFAULT_REGION].
    ///
    /// # Errors
    ///
    /// Returns an error when an expression cannot be turned into a physical
    /// expression, fails to evaluate, does not evaluate to a boolean array, or
    /// when an arrow boolean kernel fails.
    ///
    /// # Panics
    ///
    /// Panics when the cache lock is poisoned or when an expression evaluates
    /// to something other than an array.
    pub fn split_record_batch(
        &self,
        record_batch: &RecordBatch,
    ) -> Result<HashMap<RegionNumber, RegionMask>> {
        let num_rows = record_batch.num_rows();
        // With a single region there is nothing to decide: every row goes to it.
        if self.regions.len() == 1 {
            return Ok([(
                self.regions[0],
                RegionMask::from(BooleanArray::from(vec![true; num_rows])),
            )]
            .into_iter()
            .collect());
        }
        // Reuse the cached physical expressions when they were built for the
        // same schema as this batch; otherwise rebuild them and refresh the cache.
        let physical_exprs = {
            let cache_read_guard = self.physical_expr_cache.read().unwrap();
            if let Some((cached_exprs, schema)) = cache_read_guard.as_ref()
                && schema == record_batch.schema_ref()
            {
                cached_exprs.clone()
            } else {
                drop(cache_read_guard); // Release the read lock before acquiring write lock

                let schema = record_batch.schema();
                let new_cache = self
                    .exprs
                    .iter()
                    .map(|e| e.try_as_physical_expr(&schema))
                    .collect::<Result<Vec<_>>>()?;

                let mut cache_write_guard = self.physical_expr_cache.write().unwrap();
                cache_write_guard.replace((new_cache.clone(), schema));
                new_cache
            }
        };

        // Evaluate each expression against the batch; the i-th expression
        // produces the mask of the i-th region.
        // NOTE(review): the masks are used exactly as produced by the expressions.
        // If they can contain null slots, confirm how `true_count`, `or` and `not`
        // treat them, so that such rows are not left out of every region.
        let mut result: HashMap<u32, RegionMask> = physical_exprs
            .iter()
            .zip(self.regions.iter())
            .filter_map(|(expr, region_num)| {
                let col_val = match expr
                    .evaluate(record_batch)
                    .context(error::EvaluateRecordBatchSnafu)
                {
                    Ok(array) => array,
                    Err(e) => {
                        return Some(Err(e));
                    }
                };
                let ColumnarValue::Array(column) = col_val else {
                    unreachable!("Expected an array")
                };
                let array =
                    match column
                        .as_any()
                        .downcast_ref::<BooleanArray>()
                        .with_context(|| error::UnexpectedColumnTypeSnafu {
                            data_type: column.data_type().clone(),
                        }) {
                        Ok(array) => array,
                        Err(e) => {
                            return Some(Err(e));
                        }
                    };
                let selected_rows = array.true_count();
                if selected_rows == 0 {
                    // skip empty region in results.
                    return None;
                }
                Some(Ok((
                    *region_num,
                    RegionMask::new(array.clone(), selected_rows),
                )))
            })
            .collect::<error::Result<_>>()?;

        // Union of all region masks: the rows claimed by at least one region.
        // With a single mask the union is that mask itself.
        let selected = if result.len() == 1 {
            result.values().next().unwrap().array().clone()
        } else {
            let mut selected = BooleanArray::new(BooleanBuffer::new_unset(num_rows), None);
            for region_mask in result.values() {
                selected = arrow::compute::kernels::boolean::or(&selected, region_mask.array())
                    .context(error::ComputeArrowKernelSnafu)?;
            }
            selected
        };

        // fast path: all rows are selected
        if selected.true_count() == num_rows {
            return Ok(result);
        }

        // find unselected rows and assign to default region
        let unselected = arrow::compute::kernels::boolean::not(&selected)
            .context(error::ComputeArrowKernelSnafu)?;
        match result.entry(DEFAULT_REGION) {
            Entry::Occupied(mut o) => {
                // merge default region with unselected rows.
                let default_region_mask = RegionMask::from(
                    arrow::compute::kernels::boolean::or(o.get().array(), &unselected)
                        .context(error::ComputeArrowKernelSnafu)?,
                );
                o.insert(default_region_mask);
            }
            Entry::Vacant(v) => {
                // default region has no rows, simply put all unselected rows to default region.
                v.insert(RegionMask::from(unselected));
            }
        }
        Ok(result)
    }
321}
322
323impl PartitionRule for MultiDimPartitionRule {
324    fn as_any(&self) -> &dyn Any {
325        self
326    }
327
328    fn partition_columns(&self) -> Vec<String> {
329        self.partition_columns.clone()
330    }
331
332    fn find_region(&self, values: &[Value]) -> Result<RegionNumber> {
333        self.find_region(values)
334    }
335
336    fn split_record_batch(
337        &self,
338        record_batch: &RecordBatch,
339    ) -> Result<HashMap<RegionNumber, RegionMask>> {
340        self.split_record_batch(record_batch)
341    }
342}
343
/// Helper for [RuleChecker]
///
/// Maps each value that a partition expression compares one column against
/// to the [SplitPoint] bookkeeping collected for that value.
type Axis = HashMap<Value, SplitPoint>;
346
/// Helper for [RuleChecker]
///
/// Bookkeeping for a single value on an [Axis].
struct SplitPoint {
    /// Whether some expression includes the value itself (`=`, `<=` or `>=`).
    is_equal: bool,
    /// Incremented by `<` / `<=` comparisons and decremented by `>` / `>=`
    /// comparisons against the value.
    less_than_counter: isize,
}
352
/// Checks if the rule set covers all the possible values.
///
/// Note that this checker has false negatives on duplicated exprs, e.g.
/// `a != 20`, `a <= 20` and `a > 20`.
///
/// It works on the observation that each projected split point should be included (`is_equal`)
/// and have a balanced `<` and `>` counter.
struct RuleChecker<'a> {
    /// One [Axis] per partition column, in the order of `partition_columns`.
    axis: Vec<Axis>,
    /// The rule being checked.
    rule: &'a MultiDimPartitionRule,
}
364
365impl<'a> RuleChecker<'a> {
366    pub fn new(rule: &'a MultiDimPartitionRule) -> Self {
367        let mut projections = Vec::with_capacity(rule.partition_columns.len());
368        projections.resize_with(rule.partition_columns.len(), Default::default);
369
370        Self {
371            axis: projections,
372            rule,
373        }
374    }
375
376    pub fn check(&mut self) -> Result<()> {
377        for expr in &self.rule.exprs {
378            self.walk_expr(expr)?
379        }
380
381        self.check_axis()
382    }
383
384    #[allow(clippy::mutable_key_type)]
385    fn walk_expr(&mut self, expr: &PartitionExpr) -> Result<()> {
386        // recursively check the expr
387        match expr.op {
388            RestrictedOp::And | RestrictedOp::Or => {
389                match (expr.lhs.as_ref(), expr.rhs.as_ref()) {
390                    (Operand::Expr(lhs), Operand::Expr(rhs)) => {
391                        self.walk_expr(lhs)?;
392                        self.walk_expr(rhs)?
393                    }
394                    _ => ConjunctExprWithNonExprSnafu { expr: expr.clone() }.fail()?,
395                }
396
397                return Ok(());
398            }
399            // Not conjunction
400            _ => {}
401        }
402
403        let (col, val) = match (expr.lhs.as_ref(), expr.rhs.as_ref()) {
404            (Operand::Expr(_), _)
405            | (_, Operand::Expr(_))
406            | (Operand::Column(_), Operand::Column(_))
407            | (Operand::Value(_), Operand::Value(_)) => {
408                InvalidExprSnafu { expr: expr.clone() }.fail()?
409            }
410
411            (Operand::Column(col), Operand::Value(val))
412            | (Operand::Value(val), Operand::Column(col)) => (col, val),
413        };
414
415        let col_index =
416            *self
417                .rule
418                .name_to_index
419                .get(col)
420                .with_context(|| UndefinedColumnSnafu {
421                    column: col.clone(),
422                })?;
423        let axis = &mut self.axis[col_index];
424        let split_point = axis.entry(val.clone()).or_insert(SplitPoint {
425            is_equal: false,
426            less_than_counter: 0,
427        });
428        match expr.op {
429            RestrictedOp::Eq => {
430                split_point.is_equal = true;
431            }
432            RestrictedOp::NotEq => {
433                // less_than +1 -1
434            }
435            RestrictedOp::Lt => {
436                split_point.less_than_counter += 1;
437            }
438            RestrictedOp::LtEq => {
439                split_point.less_than_counter += 1;
440                split_point.is_equal = true;
441            }
442            RestrictedOp::Gt => {
443                split_point.less_than_counter -= 1;
444            }
445            RestrictedOp::GtEq => {
446                split_point.less_than_counter -= 1;
447                split_point.is_equal = true;
448            }
449            RestrictedOp::And | RestrictedOp::Or => {
450                unreachable!("conjunct expr should be handled above")
451            }
452        }
453
454        Ok(())
455    }
456
457    /// Return if the rule is legal.
458    fn check_axis(&self) -> Result<()> {
459        for (col_index, axis) in self.axis.iter().enumerate() {
460            for (val, split_point) in axis {
461                if split_point.less_than_counter != 0 || !split_point.is_equal {
462                    UnclosedValueSnafu {
463                        value: format!("{val:?}"),
464                        column: self.rule.partition_columns[col_index].clone(),
465                    }
466                    .fail()?;
467                }
468            }
469        }
470        Ok(())
471    }
472}
473
474#[cfg(test)]
475mod tests {
476    use std::assert_matches::assert_matches;
477
478    use super::*;
479    use crate::error::{self, Error};
480
481    #[test]
482    fn test_find_region() {
483        // PARTITION ON COLUMNS (b) (
484        //     b < 'hz',
485        //     b >= 'hz' AND b < 'sh',
486        //     b >= 'sh'
487        // )
488        let rule = MultiDimPartitionRule::try_new(
489            vec!["b".to_string()],
490            vec![1, 2, 3],
491            vec![
492                PartitionExpr::new(
493                    Operand::Column("b".to_string()),
494                    RestrictedOp::Lt,
495                    Operand::Value(datatypes::value::Value::String("hz".into())),
496                ),
497                PartitionExpr::new(
498                    Operand::Expr(PartitionExpr::new(
499                        Operand::Column("b".to_string()),
500                        RestrictedOp::GtEq,
501                        Operand::Value(datatypes::value::Value::String("hz".into())),
502                    )),
503                    RestrictedOp::And,
504                    Operand::Expr(PartitionExpr::new(
505                        Operand::Column("b".to_string()),
506                        RestrictedOp::Lt,
507                        Operand::Value(datatypes::value::Value::String("sh".into())),
508                    )),
509                ),
510                PartitionExpr::new(
511                    Operand::Column("b".to_string()),
512                    RestrictedOp::GtEq,
513                    Operand::Value(datatypes::value::Value::String("sh".into())),
514                ),
515            ],
516        )
517        .unwrap();
518        assert_matches!(
519            rule.find_region(&["foo".into(), 1000_i32.into()]),
520            Err(error::Error::RegionKeysSize {
521                expect: 1,
522                actual: 2,
523                ..
524            })
525        );
526        assert_matches!(rule.find_region(&["foo".into()]), Ok(1));
527        assert_matches!(rule.find_region(&["bar".into()]), Ok(1));
528        assert_matches!(rule.find_region(&["hz".into()]), Ok(2));
529        assert_matches!(rule.find_region(&["hzz".into()]), Ok(2));
530        assert_matches!(rule.find_region(&["sh".into()]), Ok(3));
531        assert_matches!(rule.find_region(&["zzzz".into()]), Ok(3));
532    }
533
534    #[test]
535    fn invalid_expr_case_1() {
536        // PARTITION ON COLUMNS (b) (
537        //     b <= b >= 'hz' AND b < 'sh',
538        // )
539        let rule = MultiDimPartitionRule::try_new(
540            vec!["a".to_string(), "b".to_string()],
541            vec![1],
542            vec![PartitionExpr::new(
543                Operand::Column("b".to_string()),
544                RestrictedOp::LtEq,
545                Operand::Expr(PartitionExpr::new(
546                    Operand::Expr(PartitionExpr::new(
547                        Operand::Column("b".to_string()),
548                        RestrictedOp::GtEq,
549                        Operand::Value(datatypes::value::Value::String("hz".into())),
550                    )),
551                    RestrictedOp::And,
552                    Operand::Expr(PartitionExpr::new(
553                        Operand::Column("b".to_string()),
554                        RestrictedOp::Lt,
555                        Operand::Value(datatypes::value::Value::String("sh".into())),
556                    )),
557                )),
558            )],
559        );
560
561        // check rule
562        assert_matches!(rule.unwrap_err(), Error::InvalidExpr { .. });
563    }
564
565    #[test]
566    fn invalid_expr_case_2() {
567        // PARTITION ON COLUMNS (b) (
568        //     b >= 'hz' AND 'sh',
569        // )
570        let rule = MultiDimPartitionRule::try_new(
571            vec!["a".to_string(), "b".to_string()],
572            vec![1],
573            vec![PartitionExpr::new(
574                Operand::Expr(PartitionExpr::new(
575                    Operand::Column("b".to_string()),
576                    RestrictedOp::GtEq,
577                    Operand::Value(datatypes::value::Value::String("hz".into())),
578                )),
579                RestrictedOp::And,
580                Operand::Value(datatypes::value::Value::String("sh".into())),
581            )],
582        );
583
584        // check rule
585        assert_matches!(rule.unwrap_err(), Error::ConjunctExprWithNonExpr { .. });
586    }
587
588    /// ```ignore
589    ///          │          │               
590    ///          │          │               
591    /// ─────────┼──────────┼────────────► b
592    ///          │          │               
593    ///          │          │               
594    ///      b <= h     b >= s            
595    /// ```
596    #[test]
597    fn empty_expr_case_1() {
598        // PARTITION ON COLUMNS (b) (
599        //     b <= 'h',
600        //     b >= 's'
601        // )
602        let rule = MultiDimPartitionRule::try_new(
603            vec!["a".to_string(), "b".to_string()],
604            vec![1, 2],
605            vec![
606                PartitionExpr::new(
607                    Operand::Column("b".to_string()),
608                    RestrictedOp::LtEq,
609                    Operand::Value(datatypes::value::Value::String("h".into())),
610                ),
611                PartitionExpr::new(
612                    Operand::Column("b".to_string()),
613                    RestrictedOp::GtEq,
614                    Operand::Value(datatypes::value::Value::String("s".into())),
615                ),
616            ],
617        );
618
619        // check rule
620        assert_matches!(rule.unwrap_err(), Error::UnclosedValue { .. });
621    }
622
623    /// ```
624    ///     a                                                  
625    ///     ▲                                        
626    ///     │                   ‖        
627    ///     │                   ‖        
628    /// 200 │         ┌─────────┤        
629    ///     │         │         │        
630    ///     │         │         │        
631    ///     │         │         │        
632    /// 100 │   ======┴─────────┘        
633    ///     │                            
634    ///     └──────────────────────────►b
635    ///              10          20      
636    /// ```
637    #[test]
638    fn empty_expr_case_2() {
639        // PARTITION ON COLUMNS (b) (
640        //     a >= 100 AND b <= 10  OR  a > 100 AND a <= 200 AND b <= 10  OR  a >= 200 AND b > 10 AND b <= 20  OR  a > 200 AND b <= 20
641        //     a < 100 AND b <= 20  OR  a >= 100 AND b > 20
642        // )
643        let rule = MultiDimPartitionRule::try_new(
644            vec!["a".to_string(), "b".to_string()],
645            vec![1, 2],
646            vec![
647                PartitionExpr::new(
648                    Operand::Expr(PartitionExpr::new(
649                        Operand::Expr(PartitionExpr::new(
650                            //  a >= 100 AND b <= 10
651                            Operand::Expr(PartitionExpr::new(
652                                Operand::Expr(PartitionExpr::new(
653                                    Operand::Column("a".to_string()),
654                                    RestrictedOp::GtEq,
655                                    Operand::Value(datatypes::value::Value::Int64(100)),
656                                )),
657                                RestrictedOp::And,
658                                Operand::Expr(PartitionExpr::new(
659                                    Operand::Column("b".to_string()),
660                                    RestrictedOp::LtEq,
661                                    Operand::Value(datatypes::value::Value::Int64(10)),
662                                )),
663                            )),
664                            RestrictedOp::Or,
665                            // a > 100 AND a <= 200 AND b <= 10
666                            Operand::Expr(PartitionExpr::new(
667                                Operand::Expr(PartitionExpr::new(
668                                    Operand::Expr(PartitionExpr::new(
669                                        Operand::Column("a".to_string()),
670                                        RestrictedOp::Gt,
671                                        Operand::Value(datatypes::value::Value::Int64(100)),
672                                    )),
673                                    RestrictedOp::And,
674                                    Operand::Expr(PartitionExpr::new(
675                                        Operand::Column("a".to_string()),
676                                        RestrictedOp::LtEq,
677                                        Operand::Value(datatypes::value::Value::Int64(200)),
678                                    )),
679                                )),
680                                RestrictedOp::And,
681                                Operand::Expr(PartitionExpr::new(
682                                    Operand::Column("b".to_string()),
683                                    RestrictedOp::LtEq,
684                                    Operand::Value(datatypes::value::Value::Int64(10)),
685                                )),
686                            )),
687                        )),
688                        RestrictedOp::Or,
689                        // a >= 200 AND b > 10 AND b <= 20
690                        Operand::Expr(PartitionExpr::new(
691                            Operand::Expr(PartitionExpr::new(
692                                Operand::Expr(PartitionExpr::new(
693                                    Operand::Column("a".to_string()),
694                                    RestrictedOp::GtEq,
695                                    Operand::Value(datatypes::value::Value::Int64(200)),
696                                )),
697                                RestrictedOp::And,
698                                Operand::Expr(PartitionExpr::new(
699                                    Operand::Column("b".to_string()),
700                                    RestrictedOp::Gt,
701                                    Operand::Value(datatypes::value::Value::Int64(10)),
702                                )),
703                            )),
704                            RestrictedOp::And,
705                            Operand::Expr(PartitionExpr::new(
706                                Operand::Column("b".to_string()),
707                                RestrictedOp::LtEq,
708                                Operand::Value(datatypes::value::Value::Int64(20)),
709                            )),
710                        )),
711                    )),
712                    RestrictedOp::Or,
713                    // a > 200 AND b <= 20
714                    Operand::Expr(PartitionExpr::new(
715                        Operand::Expr(PartitionExpr::new(
716                            Operand::Column("a".to_string()),
717                            RestrictedOp::Gt,
718                            Operand::Value(datatypes::value::Value::Int64(200)),
719                        )),
720                        RestrictedOp::And,
721                        Operand::Expr(PartitionExpr::new(
722                            Operand::Column("b".to_string()),
723                            RestrictedOp::LtEq,
724                            Operand::Value(datatypes::value::Value::Int64(20)),
725                        )),
726                    )),
727                ),
728                PartitionExpr::new(
729                    // a < 100 AND b <= 20
730                    Operand::Expr(PartitionExpr::new(
731                        Operand::Expr(PartitionExpr::new(
732                            Operand::Column("a".to_string()),
733                            RestrictedOp::Lt,
734                            Operand::Value(datatypes::value::Value::Int64(100)),
735                        )),
736                        RestrictedOp::And,
737                        Operand::Expr(PartitionExpr::new(
738                            Operand::Column("b".to_string()),
739                            RestrictedOp::LtEq,
740                            Operand::Value(datatypes::value::Value::Int64(20)),
741                        )),
742                    )),
743                    RestrictedOp::Or,
744                    // a >= 100 AND b > 20
745                    Operand::Expr(PartitionExpr::new(
746                        Operand::Expr(PartitionExpr::new(
747                            Operand::Column("a".to_string()),
748                            RestrictedOp::GtEq,
749                            Operand::Value(datatypes::value::Value::Int64(100)),
750                        )),
751                        RestrictedOp::And,
752                        Operand::Expr(PartitionExpr::new(
753                            Operand::Column("b".to_string()),
754                            RestrictedOp::GtEq,
755                            Operand::Value(datatypes::value::Value::Int64(20)),
756                        )),
757                    )),
758                ),
759            ],
760        );
761
762        // check rule
763        assert_matches!(rule.unwrap_err(), Error::UnclosedValue { .. });
764    }
765
766    #[test]
767    fn duplicate_expr_case_1() {
768        // PARTITION ON COLUMNS (a) (
769        //     a <= 20,
770        //     a >= 10
771        // )
772        let rule = MultiDimPartitionRule::try_new(
773            vec!["a".to_string(), "b".to_string()],
774            vec![1, 2],
775            vec![
776                PartitionExpr::new(
777                    Operand::Column("a".to_string()),
778                    RestrictedOp::LtEq,
779                    Operand::Value(datatypes::value::Value::Int64(20)),
780                ),
781                PartitionExpr::new(
782                    Operand::Column("a".to_string()),
783                    RestrictedOp::GtEq,
784                    Operand::Value(datatypes::value::Value::Int64(10)),
785                ),
786            ],
787        );
788
789        // check rule
790        assert_matches!(rule.unwrap_err(), Error::UnclosedValue { .. });
791    }
792
793    #[test]
794    #[ignore = "checker cannot detect this kind of duplicate for now"]
795    fn duplicate_expr_case_2() {
796        // PARTITION ON COLUMNS (a) (
797        //     a != 20,
798        //     a <= 20,
799        //     a > 20,
800        // )
801        let rule = MultiDimPartitionRule::try_new(
802            vec!["a".to_string(), "b".to_string()],
803            vec![1, 2],
804            vec![
805                PartitionExpr::new(
806                    Operand::Column("a".to_string()),
807                    RestrictedOp::NotEq,
808                    Operand::Value(datatypes::value::Value::Int64(20)),
809                ),
810                PartitionExpr::new(
811                    Operand::Column("a".to_string()),
812                    RestrictedOp::LtEq,
813                    Operand::Value(datatypes::value::Value::Int64(20)),
814                ),
815                PartitionExpr::new(
816                    Operand::Column("a".to_string()),
817                    RestrictedOp::Gt,
818                    Operand::Value(datatypes::value::Value::Int64(20)),
819                ),
820            ],
821        );
822
823        // check rule
824        assert!(rule.is_err());
825    }
826}
827
#[cfg(test)]
mod test_split_record_batch {
    use std::sync::Arc;

    use datatypes::arrow::array::{Int64Array, StringArray};
    use datatypes::arrow::datatypes::{DataType, Field, Schema};
    use datatypes::arrow::record_batch::RecordBatch;
    use rand::Rng;

    use super::*;
    use crate::expr::col;

    /// Two-column schema used by every test in this module: a non-nullable
    /// Utf8 `host` column followed by a non-nullable Int64 `value` column.
    fn test_schema() -> Arc<Schema> {
        let fields = vec![
            Field::new("host", DataType::Utf8, false),
            Field::new("value", DataType::Int64, false),
        ];
        Arc::new(Schema::new(fields))
    }

    /// Builds a record batch over [`test_schema`] from explicit column contents.
    ///
    /// Panics if the batch cannot be constructed (e.g. the two columns differ
    /// in length).
    fn batch_of(hosts: Vec<&str>, values: Vec<i64>) -> RecordBatch {
        RecordBatch::try_new(
            test_schema(),
            vec![
                Arc::new(StringArray::from(hosts)),
                Arc::new(Int64Array::from(values)),
            ],
        )
        .unwrap()
    }

    /// Builds a batch of `num_rows` rows whose hosts are `server0`..`server19`
    /// and whose values lie in `0..20`, both drawn from the thread-local RNG.
    fn generate_random_record_batch(num_rows: usize) -> RecordBatch {
        let mut rng = rand::thread_rng();
        // Host and value are drawn alternately, one pair per row.
        let (hosts, values): (Vec<String>, Vec<i64>) = (0..num_rows)
            .map(|_| {
                let host = format!("server{}", rng.gen_range(0..20));
                let value: i64 = rng.gen_range(0..20);
                (host, value)
            })
            .unzip();

        RecordBatch::try_new(
            test_schema(),
            vec![
                Arc::new(StringArray::from(hosts)),
                Arc::new(Int64Array::from(values)),
            ],
        )
        .unwrap()
    }

    /// Splitting on a single column must agree with the naive implementation.
    #[test]
    fn test_split_record_batch_by_one_column() {
        // Two regions, split on `host` at "server1".
        let pivot = || Value::String("server1".into());
        let rule = MultiDimPartitionRule::try_new(
            vec![String::from("host"), String::from("value")],
            vec![0, 1],
            vec![col("host").lt(pivot()), col("host").gt_eq(pivot())],
        )
        .unwrap();

        let batch = generate_random_record_batch(1000);

        // Compare the real split against the naive reference, region by region.
        let actual = rule.split_record_batch(&batch).unwrap();
        let reference = rule.split_record_batch_naive(&batch).unwrap();
        assert_eq!(actual.len(), reference.len());
        for (region, mask) in &actual {
            assert_eq!(
                mask.array(),
                reference.get(region).unwrap(),
                "failed on region: {}",
                region
            );
        }
    }

    /// Splitting an empty batch with a single-region rule yields one entry.
    #[test]
    fn test_split_record_batch_empty() {
        // Single region selected by `host = 'server1'`.
        let host_is_server1 = PartitionExpr::new(
            Operand::Column("host".to_string()),
            RestrictedOp::Eq,
            Operand::Value(Value::String("server1".into())),
        );
        let rule = MultiDimPartitionRule::try_new(
            vec![String::from("host")],
            vec![1],
            vec![host_is_server1],
        )
        .unwrap();

        // A batch with the test schema and zero rows.
        let batch = batch_of(Vec::new(), Vec::new());

        let result = rule.split_record_batch(&batch).unwrap();
        assert_eq!(result.len(), 1);
    }

    /// Splitting on two columns must agree with the naive implementation.
    #[test]
    fn test_split_record_batch_by_two_columns() {
        // Four regions: the cross product of a host split at "server10" and a
        // value split at 10.
        let host_lt = || col("host").lt(Value::String("server10".into()));
        let host_ge = || col("host").gt_eq(Value::String("server10".into()));
        let value_lt = || col("value").lt(Value::Int64(10));
        let value_ge = || col("value").gt_eq(Value::Int64(10));

        let rule = MultiDimPartitionRule::try_new(
            vec![String::from("host"), String::from("value")],
            vec![0, 1, 2, 3],
            vec![
                host_lt().and(value_lt()),
                host_lt().and(value_ge()),
                host_ge().and(value_lt()),
                host_ge().and(value_ge()),
            ],
        )
        .unwrap();

        let batch = generate_random_record_batch(1000);
        let actual = rule.split_record_batch(&batch).unwrap();
        let reference = rule.split_record_batch_naive(&batch).unwrap();
        assert_eq!(actual.len(), reference.len());
        for (region, mask) in &actual {
            assert_eq!(mask.array(), reference.get(region).unwrap());
        }
    }

    /// With a row matching none of the expressions, every mask produced by the
    /// real split must still equal the naive implementation's mask.
    #[test]
    fn test_default_region() {
        let host_lt = || col("host").lt(Value::String("server10".into()));
        let host_ge = || col("host").gt_eq(Value::String("server10".into()));
        let value_is = |v: i64| col("value").eq(Value::Int64(v));

        let rule = MultiDimPartitionRule::try_new(
            vec![String::from("host"), String::from("value")],
            vec![0, 1, 2, 3],
            vec![
                host_lt().and(value_is(10)),
                host_lt().and(value_is(20)),
                host_ge().and(value_is(10)),
                host_ge().and(value_is(20)),
            ],
        )
        .unwrap();

        // The third row (value = 30) satisfies none of the four expressions.
        let batch = batch_of(
            vec!["server1", "server1", "server1", "server100"],
            vec![10, 20, 30, 10],
        );

        let actual = rule.split_record_batch(&batch).unwrap();
        let reference = rule.split_record_batch_naive(&batch).unwrap();
        for (region, mask) in &actual {
            assert_eq!(mask.array(), reference.get(region).unwrap());
        }
    }

    /// Rows matched by no expression are collected under `DEFAULT_REGION`
    /// when that region is not one of the rule's own regions.
    #[test]
    fn test_default_region_with_unselected_rows() {
        // Create a rule where some rows won't match any partition
        let value_is = |v: i64| col("value").eq(Value::Int64(v));
        let rule = MultiDimPartitionRule::try_new(
            vec![String::from("host"), String::from("value")],
            vec![1, 2, 3],
            vec![value_is(10), value_is(20), value_is(30)],
        )
        .unwrap();

        let batch = batch_of(
            vec!["server1", "server2", "server3", "server4", "server5"],
            vec![10, 20, 30, 40, 50],
        );

        let result = rule.split_record_batch(&batch).unwrap();
        let rows_in = |region: RegionNumber| result.get(&region).unwrap().selected_rows();

        // Three defined regions plus the default one.
        assert_eq!(result.len(), 4);

        // The default region (0) holds the unselected rows: values 40 and 50.
        assert!(result.contains_key(&DEFAULT_REGION));
        assert_eq!(rows_in(DEFAULT_REGION), 2);

        // Each defined region receives exactly one row.
        assert_eq!(rows_in(1), 1); // value = 10
        assert_eq!(rows_in(2), 1); // value = 20
        assert_eq!(rows_in(3), 1); // value = 30
    }

    /// When region 0 is also one of the rule's regions, its mask contains
    /// both the rows it selects explicitly and the rows matched by nothing.
    #[test]
    fn test_default_region_with_existing_default() {
        // Create a rule where some rows are explicitly assigned to default region
        // and some rows are implicitly assigned to default region
        let value_is = |v: i64| col("value").eq(Value::Int64(v));
        let rule = MultiDimPartitionRule::try_new(
            vec![String::from("host"), String::from("value")],
            vec![0, 1, 2],
            vec![
                value_is(10), // Explicitly assign value=10 to region 0 (default)
                value_is(20),
                value_is(30),
            ],
        )
        .unwrap();

        let batch = batch_of(
            vec!["server1", "server2", "server3", "server4", "server5"],
            vec![10, 20, 30, 40, 50],
        );

        let result = rule.split_record_batch(&batch).unwrap();
        let rows_in = |region: RegionNumber| result.get(&region).unwrap().selected_rows();

        // No extra region is added: still exactly three.
        assert_eq!(result.len(), 3);

        // The default region holds value=10 (explicit) plus 40 and 50 (unselected).
        assert!(result.contains_key(&DEFAULT_REGION));
        assert_eq!(rows_in(DEFAULT_REGION), 3);

        // The remaining regions receive one row each.
        assert_eq!(rows_in(1), 1); // value = 20
        assert_eq!(rows_in(2), 1); // value = 30
    }

    /// When every row is selected by some expression, only the rule's own
    /// regions appear in the result.
    #[test]
    fn test_all_rows_selected() {
        // Test the fast path where all rows are selected by some partition
        let rule = MultiDimPartitionRule::try_new(
            vec![String::from("value")],
            vec![1, 2],
            vec![
                col("value").lt(Value::Int64(30)),
                col("value").gt_eq(Value::Int64(30)),
            ],
        )
        .unwrap();

        let batch = batch_of(
            vec!["server1", "server2", "server3", "server4"],
            vec![10, 20, 30, 40],
        );

        let result = rule.split_record_batch(&batch).unwrap();
        let rows_in = |region: RegionNumber| result.get(&region).unwrap().selected_rows();

        // Exactly the two defined regions, hence no default region.
        assert_eq!(result.len(), 2);
        assert!(result.contains_key(&1));
        assert!(result.contains_key(&2));

        // The rows split evenly around 30.
        assert_eq!(rows_in(1), 2); // values < 30
        assert_eq!(rows_in(2), 2); // values >= 30
    }
}