partition/
multi_dim.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::any::Any;
16use std::cmp::Ordering;
17use std::collections::HashMap;
18use std::sync::{Arc, RwLock};
19
20use datafusion_expr::ColumnarValue;
21use datafusion_physical_expr::PhysicalExpr;
22use datatypes::arrow;
23use datatypes::arrow::array::{BooleanArray, BooleanBufferBuilder, RecordBatch};
24use datatypes::arrow::buffer::BooleanBuffer;
25use datatypes::arrow::datatypes::Schema;
26use datatypes::prelude::Value;
27use datatypes::vectors::{Helper, VectorRef};
28use serde::{Deserialize, Serialize};
29use snafu::{ensure, OptionExt, ResultExt};
30use store_api::storage::RegionNumber;
31
32use crate::error::{
33    self, ConjunctExprWithNonExprSnafu, InvalidExprSnafu, Result, UnclosedValueSnafu,
34    UndefinedColumnSnafu,
35};
36use crate::expr::{Operand, PartitionExpr, RestrictedOp};
37use crate::PartitionRule;
38
/// The default region number when no partition exprs are matched.
const DEFAULT_REGION: RegionNumber = 0;

/// Physical expressions compiled from the partition exprs, paired with the schema
/// they were compiled against. `None` until `split_record_batch` fills it.
type PhysicalExprCache = Option<(Vec<Arc<dyn PhysicalExpr>>, Arc<Schema>)>;
43
/// Multi-dimension partition rule. RFC [here](https://github.com/GreptimeTeam/greptimedb/blob/main/docs/rfcs/2024-02-21-multi-dimension-partition-rule/rfc.md)
///
/// This partition rule is defined by a set of simple expressions on the partition
/// key columns. Compared to RANGE partition, which can be considered as a
/// single-dimension rule, this will evaluate expressions on each column separately.
#[derive(Debug, Serialize, Deserialize)]
pub struct MultiDimPartitionRule {
    /// Allow list of which columns can be used for partitioning.
    partition_columns: Vec<String>,
    /// Name to index of `partition_columns`. Used for quick lookup.
    name_to_index: HashMap<String, usize>,
    /// Region number for each partition. This list is expected to have the same
    /// length as `exprs` (not counting the default region).
    // NOTE(review): `try_new` does not verify that the two lengths match — confirm callers do.
    regions: Vec<RegionNumber>,
    /// Partition expressions.
    exprs: Vec<PartitionExpr>,
    /// Cache of physical expressions. Skipped by serde, so it is not part of the
    /// serialized form.
    #[serde(skip)]
    physical_expr_cache: RwLock<PhysicalExprCache>,
}
64
65impl MultiDimPartitionRule {
66    pub fn try_new(
67        partition_columns: Vec<String>,
68        regions: Vec<RegionNumber>,
69        exprs: Vec<PartitionExpr>,
70    ) -> Result<Self> {
71        let name_to_index = partition_columns
72            .iter()
73            .enumerate()
74            .map(|(i, name)| (name.clone(), i))
75            .collect::<HashMap<_, _>>();
76
77        let rule = Self {
78            partition_columns,
79            name_to_index,
80            regions,
81            exprs,
82            physical_expr_cache: RwLock::new(None),
83        };
84
85        let mut checker = RuleChecker::new(&rule);
86        checker.check()?;
87
88        Ok(rule)
89    }
90
91    fn find_region(&self, values: &[Value]) -> Result<RegionNumber> {
92        ensure!(
93            values.len() == self.partition_columns.len(),
94            error::RegionKeysSizeSnafu {
95                expect: self.partition_columns.len(),
96                actual: values.len(),
97            }
98        );
99
100        for (region_index, expr) in self.exprs.iter().enumerate() {
101            if self.evaluate_expr(expr, values)? {
102                return Ok(self.regions[region_index]);
103            }
104        }
105
106        // return the default region number
107        Ok(DEFAULT_REGION)
108    }
109
110    fn evaluate_expr(&self, expr: &PartitionExpr, values: &[Value]) -> Result<bool> {
111        match (expr.lhs.as_ref(), expr.rhs.as_ref()) {
112            (Operand::Column(name), Operand::Value(r)) => {
113                let index = self.name_to_index.get(name).unwrap();
114                let l = &values[*index];
115                Self::perform_op(l, &expr.op, r)
116            }
117            (Operand::Value(l), Operand::Column(name)) => {
118                let index = self.name_to_index.get(name).unwrap();
119                let r = &values[*index];
120                Self::perform_op(l, &expr.op, r)
121            }
122            (Operand::Expr(lhs), Operand::Expr(rhs)) => {
123                let lhs = self.evaluate_expr(lhs, values)?;
124                let rhs = self.evaluate_expr(rhs, values)?;
125                match expr.op {
126                    RestrictedOp::And => Ok(lhs && rhs),
127                    RestrictedOp::Or => Ok(lhs || rhs),
128                    _ => unreachable!(),
129                }
130            }
131            _ => unreachable!(),
132        }
133    }
134
135    fn perform_op(lhs: &Value, op: &RestrictedOp, rhs: &Value) -> Result<bool> {
136        let result = match op {
137            RestrictedOp::Eq => lhs.eq(rhs),
138            RestrictedOp::NotEq => lhs.ne(rhs),
139            RestrictedOp::Lt => lhs.partial_cmp(rhs) == Some(Ordering::Less),
140            RestrictedOp::LtEq => {
141                let result = lhs.partial_cmp(rhs);
142                result == Some(Ordering::Less) || result == Some(Ordering::Equal)
143            }
144            RestrictedOp::Gt => lhs.partial_cmp(rhs) == Some(Ordering::Greater),
145            RestrictedOp::GtEq => {
146                let result = lhs.partial_cmp(rhs);
147                result == Some(Ordering::Greater) || result == Some(Ordering::Equal)
148            }
149            RestrictedOp::And | RestrictedOp::Or => unreachable!(),
150        };
151
152        Ok(result)
153    }
154
155    pub fn row_at(&self, cols: &[VectorRef], index: usize, row: &mut [Value]) -> Result<()> {
156        for (col_idx, col) in cols.iter().enumerate() {
157            row[col_idx] = col.get(index);
158        }
159        Ok(())
160    }
161
162    pub fn record_batch_to_cols(&self, record_batch: &RecordBatch) -> Result<Vec<VectorRef>> {
163        self.partition_columns
164            .iter()
165            .map(|col_name| {
166                record_batch
167                    .column_by_name(col_name)
168                    .context(error::UndefinedColumnSnafu { column: col_name })
169                    .and_then(|array| {
170                        Helper::try_into_vector(array).context(error::ConvertToVectorSnafu)
171                    })
172            })
173            .collect::<Result<Vec<_>>>()
174    }
175
176    pub fn split_record_batch_naive(
177        &self,
178        record_batch: &RecordBatch,
179    ) -> Result<HashMap<RegionNumber, BooleanArray>> {
180        let num_rows = record_batch.num_rows();
181
182        let mut result = self
183            .regions
184            .iter()
185            .map(|region| {
186                let mut builder = BooleanBufferBuilder::new(num_rows);
187                builder.append_n(num_rows, false);
188                (*region, builder)
189            })
190            .collect::<HashMap<_, _>>();
191
192        let cols = self.record_batch_to_cols(record_batch)?;
193        let mut current_row = vec![Value::Null; self.partition_columns.len()];
194        for row_idx in 0..num_rows {
195            self.row_at(&cols, row_idx, &mut current_row)?;
196            let current_region = self.find_region(&current_row)?;
197            let region_mask = result
198                .get_mut(&current_region)
199                .unwrap_or_else(|| panic!("Region {} must be initialized", current_region));
200            region_mask.set_bit(row_idx, true);
201        }
202
203        Ok(result
204            .into_iter()
205            .map(|(region, mut mask)| (region, BooleanArray::new(mask.finish(), None)))
206            .collect())
207    }
208
209    pub fn split_record_batch(
210        &self,
211        record_batch: &RecordBatch,
212    ) -> Result<HashMap<RegionNumber, BooleanArray>> {
213        let num_rows = record_batch.num_rows();
214        let physical_exprs = {
215            let cache_read_guard = self.physical_expr_cache.read().unwrap();
216            if let Some((cached_exprs, schema)) = cache_read_guard.as_ref()
217                && schema == record_batch.schema_ref()
218            {
219                cached_exprs.clone()
220            } else {
221                drop(cache_read_guard); // Release the read lock before acquiring write lock
222
223                let schema = record_batch.schema();
224                let new_cache = self
225                    .exprs
226                    .iter()
227                    .map(|e| e.try_as_physical_expr(&schema))
228                    .collect::<Result<Vec<_>>>()?;
229
230                let mut cache_write_guard = self.physical_expr_cache.write().unwrap();
231                cache_write_guard.replace((new_cache.clone(), schema));
232                new_cache
233            }
234        };
235
236        let mut result: HashMap<u32, BooleanArray> = physical_exprs
237            .iter()
238            .zip(self.regions.iter())
239            .map(|(expr, region_num)| {
240                let ColumnarValue::Array(column) = expr
241                    .evaluate(record_batch)
242                    .context(error::EvaluateRecordBatchSnafu)?
243                else {
244                    unreachable!("Expected an array")
245                };
246                Ok((
247                    *region_num,
248                    column
249                        .as_any()
250                        .downcast_ref::<BooleanArray>()
251                        .with_context(|| error::UnexpectedColumnTypeSnafu {
252                            data_type: column.data_type().clone(),
253                        })?
254                        .clone(),
255                ))
256            })
257            .collect::<error::Result<_>>()?;
258
259        let mut selected = BooleanArray::new(BooleanBuffer::new_unset(num_rows), None);
260        for region_selection in result.values() {
261            selected = arrow::compute::kernels::boolean::or(&selected, region_selection)
262                .context(error::ComputeArrowKernelSnafu)?;
263        }
264
265        // fast path: all rows are selected
266        if selected.true_count() == num_rows {
267            return Ok(result);
268        }
269
270        // find unselected rows and assign to default region
271        let unselected = arrow::compute::kernels::boolean::not(&selected)
272            .context(error::ComputeArrowKernelSnafu)?;
273        let default_region_selection = result
274            .entry(DEFAULT_REGION)
275            .or_insert_with(|| unselected.clone());
276        *default_region_selection =
277            arrow::compute::kernels::boolean::or(default_region_selection, &unselected)
278                .context(error::ComputeArrowKernelSnafu)?;
279        Ok(result)
280    }
281}
282
283impl PartitionRule for MultiDimPartitionRule {
284    fn as_any(&self) -> &dyn Any {
285        self
286    }
287
288    fn partition_columns(&self) -> Vec<String> {
289        self.partition_columns.clone()
290    }
291
292    fn find_region(&self, values: &[Value]) -> Result<RegionNumber> {
293        self.find_region(values)
294    }
295
296    fn split_record_batch(
297        &self,
298        record_batch: &RecordBatch,
299    ) -> Result<HashMap<RegionNumber, BooleanArray>> {
300        self.split_record_batch(record_batch)
301    }
302}
303
/// Helper for [RuleChecker]: the split points seen on one partition column,
/// keyed by the literal value the column is compared against.
type Axis = HashMap<Value, SplitPoint>;

/// Helper for [RuleChecker]: what has been recorded about one literal value
/// on one partition column.
struct SplitPoint {
    /// Set once an `=`, `<=` or `>=` comparison against this value has been seen.
    is_equal: bool,
    /// Running balance of the `<` / `<=` comparisons (counted as +1) against the
    /// `>` / `>=` comparisons (counted as -1) recorded for this value.
    less_than_counter: isize,
}
312
/// Check if the rule set covers all the possible values.
///
/// Note this checker has false negatives on duplicated exprs. E.g.:
/// `a != 20`, `a <= 20` and `a > 20`.
///
/// It works on the observation that each projected split point should be included (`is_equal`)
/// and have a balanced `<` and `>` counter.
struct RuleChecker<'a> {
    /// One [Axis] per partition column, in the order of the rule's `partition_columns`.
    axis: Vec<Axis>,
    /// The rule being checked.
    rule: &'a MultiDimPartitionRule,
}
324
325impl<'a> RuleChecker<'a> {
326    pub fn new(rule: &'a MultiDimPartitionRule) -> Self {
327        let mut projections = Vec::with_capacity(rule.partition_columns.len());
328        projections.resize_with(rule.partition_columns.len(), Default::default);
329
330        Self {
331            axis: projections,
332            rule,
333        }
334    }
335
336    pub fn check(&mut self) -> Result<()> {
337        for expr in &self.rule.exprs {
338            self.walk_expr(expr)?
339        }
340
341        self.check_axis()
342    }
343
344    #[allow(clippy::mutable_key_type)]
345    fn walk_expr(&mut self, expr: &PartitionExpr) -> Result<()> {
346        // recursively check the expr
347        match expr.op {
348            RestrictedOp::And | RestrictedOp::Or => {
349                match (expr.lhs.as_ref(), expr.rhs.as_ref()) {
350                    (Operand::Expr(lhs), Operand::Expr(rhs)) => {
351                        self.walk_expr(lhs)?;
352                        self.walk_expr(rhs)?
353                    }
354                    _ => ConjunctExprWithNonExprSnafu { expr: expr.clone() }.fail()?,
355                }
356
357                return Ok(());
358            }
359            // Not conjunction
360            _ => {}
361        }
362
363        let (col, val) = match (expr.lhs.as_ref(), expr.rhs.as_ref()) {
364            (Operand::Expr(_), _)
365            | (_, Operand::Expr(_))
366            | (Operand::Column(_), Operand::Column(_))
367            | (Operand::Value(_), Operand::Value(_)) => {
368                InvalidExprSnafu { expr: expr.clone() }.fail()?
369            }
370
371            (Operand::Column(col), Operand::Value(val))
372            | (Operand::Value(val), Operand::Column(col)) => (col, val),
373        };
374
375        let col_index =
376            *self
377                .rule
378                .name_to_index
379                .get(col)
380                .with_context(|| UndefinedColumnSnafu {
381                    column: col.clone(),
382                })?;
383        let axis = &mut self.axis[col_index];
384        let split_point = axis.entry(val.clone()).or_insert(SplitPoint {
385            is_equal: false,
386            less_than_counter: 0,
387        });
388        match expr.op {
389            RestrictedOp::Eq => {
390                split_point.is_equal = true;
391            }
392            RestrictedOp::NotEq => {
393                // less_than +1 -1
394            }
395            RestrictedOp::Lt => {
396                split_point.less_than_counter += 1;
397            }
398            RestrictedOp::LtEq => {
399                split_point.less_than_counter += 1;
400                split_point.is_equal = true;
401            }
402            RestrictedOp::Gt => {
403                split_point.less_than_counter -= 1;
404            }
405            RestrictedOp::GtEq => {
406                split_point.less_than_counter -= 1;
407                split_point.is_equal = true;
408            }
409            RestrictedOp::And | RestrictedOp::Or => {
410                unreachable!("conjunct expr should be handled above")
411            }
412        }
413
414        Ok(())
415    }
416
417    /// Return if the rule is legal.
418    fn check_axis(&self) -> Result<()> {
419        for (col_index, axis) in self.axis.iter().enumerate() {
420            for (val, split_point) in axis {
421                if split_point.less_than_counter != 0 || !split_point.is_equal {
422                    UnclosedValueSnafu {
423                        value: format!("{val:?}"),
424                        column: self.rule.partition_columns[col_index].clone(),
425                    }
426                    .fail()?;
427                }
428            }
429        }
430        Ok(())
431    }
432}
433
434#[cfg(test)]
435mod tests {
436    use std::assert_matches::assert_matches;
437
438    use super::*;
439    use crate::error::{self, Error};
440
441    #[test]
442    fn test_find_region() {
443        // PARTITION ON COLUMNS (b) (
444        //     b < 'hz',
445        //     b >= 'hz' AND b < 'sh',
446        //     b >= 'sh'
447        // )
448        let rule = MultiDimPartitionRule::try_new(
449            vec!["b".to_string()],
450            vec![1, 2, 3],
451            vec![
452                PartitionExpr::new(
453                    Operand::Column("b".to_string()),
454                    RestrictedOp::Lt,
455                    Operand::Value(datatypes::value::Value::String("hz".into())),
456                ),
457                PartitionExpr::new(
458                    Operand::Expr(PartitionExpr::new(
459                        Operand::Column("b".to_string()),
460                        RestrictedOp::GtEq,
461                        Operand::Value(datatypes::value::Value::String("hz".into())),
462                    )),
463                    RestrictedOp::And,
464                    Operand::Expr(PartitionExpr::new(
465                        Operand::Column("b".to_string()),
466                        RestrictedOp::Lt,
467                        Operand::Value(datatypes::value::Value::String("sh".into())),
468                    )),
469                ),
470                PartitionExpr::new(
471                    Operand::Column("b".to_string()),
472                    RestrictedOp::GtEq,
473                    Operand::Value(datatypes::value::Value::String("sh".into())),
474                ),
475            ],
476        )
477        .unwrap();
478        assert_matches!(
479            rule.find_region(&["foo".into(), 1000_i32.into()]),
480            Err(error::Error::RegionKeysSize {
481                expect: 1,
482                actual: 2,
483                ..
484            })
485        );
486        assert_matches!(rule.find_region(&["foo".into()]), Ok(1));
487        assert_matches!(rule.find_region(&["bar".into()]), Ok(1));
488        assert_matches!(rule.find_region(&["hz".into()]), Ok(2));
489        assert_matches!(rule.find_region(&["hzz".into()]), Ok(2));
490        assert_matches!(rule.find_region(&["sh".into()]), Ok(3));
491        assert_matches!(rule.find_region(&["zzzz".into()]), Ok(3));
492    }
493
494    #[test]
495    fn invalid_expr_case_1() {
496        // PARTITION ON COLUMNS (b) (
497        //     b <= b >= 'hz' AND b < 'sh',
498        // )
499        let rule = MultiDimPartitionRule::try_new(
500            vec!["a".to_string(), "b".to_string()],
501            vec![1],
502            vec![PartitionExpr::new(
503                Operand::Column("b".to_string()),
504                RestrictedOp::LtEq,
505                Operand::Expr(PartitionExpr::new(
506                    Operand::Expr(PartitionExpr::new(
507                        Operand::Column("b".to_string()),
508                        RestrictedOp::GtEq,
509                        Operand::Value(datatypes::value::Value::String("hz".into())),
510                    )),
511                    RestrictedOp::And,
512                    Operand::Expr(PartitionExpr::new(
513                        Operand::Column("b".to_string()),
514                        RestrictedOp::Lt,
515                        Operand::Value(datatypes::value::Value::String("sh".into())),
516                    )),
517                )),
518            )],
519        );
520
521        // check rule
522        assert_matches!(rule.unwrap_err(), Error::InvalidExpr { .. });
523    }
524
525    #[test]
526    fn invalid_expr_case_2() {
527        // PARTITION ON COLUMNS (b) (
528        //     b >= 'hz' AND 'sh',
529        // )
530        let rule = MultiDimPartitionRule::try_new(
531            vec!["a".to_string(), "b".to_string()],
532            vec![1],
533            vec![PartitionExpr::new(
534                Operand::Expr(PartitionExpr::new(
535                    Operand::Column("b".to_string()),
536                    RestrictedOp::GtEq,
537                    Operand::Value(datatypes::value::Value::String("hz".into())),
538                )),
539                RestrictedOp::And,
540                Operand::Value(datatypes::value::Value::String("sh".into())),
541            )],
542        );
543
544        // check rule
545        assert_matches!(rule.unwrap_err(), Error::ConjunctExprWithNonExpr { .. });
546    }
547
548    /// ```ignore
549    ///          │          │               
550    ///          │          │               
551    /// ─────────┼──────────┼────────────► b
552    ///          │          │               
553    ///          │          │               
554    ///      b <= h     b >= s            
555    /// ```
556    #[test]
557    fn empty_expr_case_1() {
558        // PARTITION ON COLUMNS (b) (
559        //     b <= 'h',
560        //     b >= 's'
561        // )
562        let rule = MultiDimPartitionRule::try_new(
563            vec!["a".to_string(), "b".to_string()],
564            vec![1, 2],
565            vec![
566                PartitionExpr::new(
567                    Operand::Column("b".to_string()),
568                    RestrictedOp::LtEq,
569                    Operand::Value(datatypes::value::Value::String("h".into())),
570                ),
571                PartitionExpr::new(
572                    Operand::Column("b".to_string()),
573                    RestrictedOp::GtEq,
574                    Operand::Value(datatypes::value::Value::String("s".into())),
575                ),
576            ],
577        );
578
579        // check rule
580        assert_matches!(rule.unwrap_err(), Error::UnclosedValue { .. });
581    }
582
583    /// ```
584    ///     a                                                  
585    ///     ▲                                        
586    ///     │                   ‖        
587    ///     │                   ‖        
588    /// 200 │         ┌─────────┤        
589    ///     │         │         │        
590    ///     │         │         │        
591    ///     │         │         │        
592    /// 100 │   ======┴─────────┘        
593    ///     │                            
594    ///     └──────────────────────────►b
595    ///              10          20      
596    /// ```
597    #[test]
598    fn empty_expr_case_2() {
599        // PARTITION ON COLUMNS (b) (
600        //     a >= 100 AND b <= 10  OR  a > 100 AND a <= 200 AND b <= 10  OR  a >= 200 AND b > 10 AND b <= 20  OR  a > 200 AND b <= 20
601        //     a < 100 AND b <= 20  OR  a >= 100 AND b > 20
602        // )
603        let rule = MultiDimPartitionRule::try_new(
604            vec!["a".to_string(), "b".to_string()],
605            vec![1, 2],
606            vec![
607                PartitionExpr::new(
608                    Operand::Expr(PartitionExpr::new(
609                        Operand::Expr(PartitionExpr::new(
610                            //  a >= 100 AND b <= 10
611                            Operand::Expr(PartitionExpr::new(
612                                Operand::Expr(PartitionExpr::new(
613                                    Operand::Column("a".to_string()),
614                                    RestrictedOp::GtEq,
615                                    Operand::Value(datatypes::value::Value::Int64(100)),
616                                )),
617                                RestrictedOp::And,
618                                Operand::Expr(PartitionExpr::new(
619                                    Operand::Column("b".to_string()),
620                                    RestrictedOp::LtEq,
621                                    Operand::Value(datatypes::value::Value::Int64(10)),
622                                )),
623                            )),
624                            RestrictedOp::Or,
625                            // a > 100 AND a <= 200 AND b <= 10
626                            Operand::Expr(PartitionExpr::new(
627                                Operand::Expr(PartitionExpr::new(
628                                    Operand::Expr(PartitionExpr::new(
629                                        Operand::Column("a".to_string()),
630                                        RestrictedOp::Gt,
631                                        Operand::Value(datatypes::value::Value::Int64(100)),
632                                    )),
633                                    RestrictedOp::And,
634                                    Operand::Expr(PartitionExpr::new(
635                                        Operand::Column("a".to_string()),
636                                        RestrictedOp::LtEq,
637                                        Operand::Value(datatypes::value::Value::Int64(200)),
638                                    )),
639                                )),
640                                RestrictedOp::And,
641                                Operand::Expr(PartitionExpr::new(
642                                    Operand::Column("b".to_string()),
643                                    RestrictedOp::LtEq,
644                                    Operand::Value(datatypes::value::Value::Int64(10)),
645                                )),
646                            )),
647                        )),
648                        RestrictedOp::Or,
649                        // a >= 200 AND b > 10 AND b <= 20
650                        Operand::Expr(PartitionExpr::new(
651                            Operand::Expr(PartitionExpr::new(
652                                Operand::Expr(PartitionExpr::new(
653                                    Operand::Column("a".to_string()),
654                                    RestrictedOp::GtEq,
655                                    Operand::Value(datatypes::value::Value::Int64(200)),
656                                )),
657                                RestrictedOp::And,
658                                Operand::Expr(PartitionExpr::new(
659                                    Operand::Column("b".to_string()),
660                                    RestrictedOp::Gt,
661                                    Operand::Value(datatypes::value::Value::Int64(10)),
662                                )),
663                            )),
664                            RestrictedOp::And,
665                            Operand::Expr(PartitionExpr::new(
666                                Operand::Column("b".to_string()),
667                                RestrictedOp::LtEq,
668                                Operand::Value(datatypes::value::Value::Int64(20)),
669                            )),
670                        )),
671                    )),
672                    RestrictedOp::Or,
673                    // a > 200 AND b <= 20
674                    Operand::Expr(PartitionExpr::new(
675                        Operand::Expr(PartitionExpr::new(
676                            Operand::Column("a".to_string()),
677                            RestrictedOp::Gt,
678                            Operand::Value(datatypes::value::Value::Int64(200)),
679                        )),
680                        RestrictedOp::And,
681                        Operand::Expr(PartitionExpr::new(
682                            Operand::Column("b".to_string()),
683                            RestrictedOp::LtEq,
684                            Operand::Value(datatypes::value::Value::Int64(20)),
685                        )),
686                    )),
687                ),
688                PartitionExpr::new(
689                    // a < 100 AND b <= 20
690                    Operand::Expr(PartitionExpr::new(
691                        Operand::Expr(PartitionExpr::new(
692                            Operand::Column("a".to_string()),
693                            RestrictedOp::Lt,
694                            Operand::Value(datatypes::value::Value::Int64(100)),
695                        )),
696                        RestrictedOp::And,
697                        Operand::Expr(PartitionExpr::new(
698                            Operand::Column("b".to_string()),
699                            RestrictedOp::LtEq,
700                            Operand::Value(datatypes::value::Value::Int64(20)),
701                        )),
702                    )),
703                    RestrictedOp::Or,
704                    // a >= 100 AND b > 20
705                    Operand::Expr(PartitionExpr::new(
706                        Operand::Expr(PartitionExpr::new(
707                            Operand::Column("a".to_string()),
708                            RestrictedOp::GtEq,
709                            Operand::Value(datatypes::value::Value::Int64(100)),
710                        )),
711                        RestrictedOp::And,
712                        Operand::Expr(PartitionExpr::new(
713                            Operand::Column("b".to_string()),
714                            RestrictedOp::GtEq,
715                            Operand::Value(datatypes::value::Value::Int64(20)),
716                        )),
717                    )),
718                ),
719            ],
720        );
721
722        // check rule
723        assert_matches!(rule.unwrap_err(), Error::UnclosedValue { .. });
724    }
725
726    #[test]
727    fn duplicate_expr_case_1() {
728        // PARTITION ON COLUMNS (a) (
729        //     a <= 20,
730        //     a >= 10
731        // )
732        let rule = MultiDimPartitionRule::try_new(
733            vec!["a".to_string(), "b".to_string()],
734            vec![1, 2],
735            vec![
736                PartitionExpr::new(
737                    Operand::Column("a".to_string()),
738                    RestrictedOp::LtEq,
739                    Operand::Value(datatypes::value::Value::Int64(20)),
740                ),
741                PartitionExpr::new(
742                    Operand::Column("a".to_string()),
743                    RestrictedOp::GtEq,
744                    Operand::Value(datatypes::value::Value::Int64(10)),
745                ),
746            ],
747        );
748
749        // check rule
750        assert_matches!(rule.unwrap_err(), Error::UnclosedValue { .. });
751    }
752
753    #[test]
754    #[ignore = "checker cannot detect this kind of duplicate for now"]
755    fn duplicate_expr_case_2() {
756        // PARTITION ON COLUMNS (a) (
757        //     a != 20,
758        //     a <= 20,
759        //     a > 20,
760        // )
761        let rule = MultiDimPartitionRule::try_new(
762            vec!["a".to_string(), "b".to_string()],
763            vec![1, 2],
764            vec![
765                PartitionExpr::new(
766                    Operand::Column("a".to_string()),
767                    RestrictedOp::NotEq,
768                    Operand::Value(datatypes::value::Value::Int64(20)),
769                ),
770                PartitionExpr::new(
771                    Operand::Column("a".to_string()),
772                    RestrictedOp::LtEq,
773                    Operand::Value(datatypes::value::Value::Int64(20)),
774                ),
775                PartitionExpr::new(
776                    Operand::Column("a".to_string()),
777                    RestrictedOp::Gt,
778                    Operand::Value(datatypes::value::Value::Int64(20)),
779                ),
780            ],
781        );
782
783        // check rule
784        assert!(rule.is_err());
785    }
786}
787
/// Tests for `MultiDimPartitionRule::split_record_batch`, most of which check it
/// against `split_record_batch_naive` on the same input.
#[cfg(test)]
mod test_split_record_batch {
    use std::sync::Arc;

    use datatypes::arrow::array::{Int64Array, StringArray};
    use datatypes::arrow::datatypes::{DataType, Field, Schema};
    use datatypes::arrow::record_batch::RecordBatch;
    use rand::Rng;

    use super::*;
    use crate::expr::col;

    /// Schema shared by all tests: a non-null Utf8 `host` column followed by a
    /// non-null Int64 `value` column.
    fn test_schema() -> Arc<Schema> {
        Arc::new(Schema::new(vec![
            Field::new("host", DataType::Utf8, false),
            Field::new("value", DataType::Int64, false),
        ]))
    }

    /// Builds a batch of `num_rows` rows matching [`test_schema`], where `host` is
    /// `"server{k}"` and `value` is `k`, each `k` drawn independently from `0..20`.
    fn generate_random_record_batch(num_rows: usize) -> RecordBatch {
        let schema = test_schema();
        let mut rng = rand::thread_rng();
        let mut host_array = Vec::with_capacity(num_rows);
        let mut value_array: Vec<i64> = Vec::with_capacity(num_rows);
        for _ in 0..num_rows {
            host_array.push(format!("server{}", rng.gen_range(0..20)));
            value_array.push(rng.gen_range(0..20));
        }
        let host_array = StringArray::from(host_array);
        let value_array = Int64Array::from(value_array);
        RecordBatch::try_new(schema, vec![Arc::new(host_array), Arc::new(value_array)]).unwrap()
    }

    /// Splits `batch` with both `split_record_batch` and `split_record_batch_naive`
    /// and asserts that the two results hold the same regions with equal values.
    ///
    /// Every failure message names the offending region so that a mismatch on a
    /// randomly generated batch can be diagnosed.
    fn assert_split_matches_naive(rule: &MultiDimPartitionRule, batch: &RecordBatch) {
        let result = rule.split_record_batch(batch).unwrap();
        let expected = rule.split_record_batch_naive(batch).unwrap();
        assert_eq!(
            result.len(),
            expected.len(),
            "number of regions differs from the naive result"
        );
        for (region, value) in &result {
            let expected_value = expected
                .get(region)
                .unwrap_or_else(|| panic!("region {} is missing from the naive result", region));
            assert_eq!(value, expected_value, "failed on region: {}", region);
        }
    }

    #[test]
    fn test_split_record_batch_by_one_column() {
        // Two regions split on `host` alone; `value` is a partition column that no
        // expression refers to.
        let rule = MultiDimPartitionRule::try_new(
            vec!["host".to_string(), "value".to_string()],
            vec![0, 1],
            vec![
                col("host").lt(Value::String("server1".into())),
                col("host").gt_eq(Value::String("server1".into())),
            ],
        )
        .unwrap();

        let batch = generate_random_record_batch(1000);
        assert_split_matches_naive(&rule, &batch);
    }

    #[test]
    fn test_split_record_batch_empty() {
        // A single-expression rule: `host = 'server1'` mapped to region 1.
        let rule = MultiDimPartitionRule::try_new(
            vec!["host".to_string()],
            vec![1],
            vec![PartitionExpr::new(
                Operand::Column("host".to_string()),
                RestrictedOp::Eq,
                Operand::Value(Value::String("server1".into())),
            )],
        )
        .unwrap();

        // A batch with the test schema but zero rows.
        let schema = test_schema();
        let host_array = StringArray::from(Vec::<&str>::new());
        let value_array = Int64Array::from(Vec::<i64>::new());
        let batch = RecordBatch::try_new(schema, vec![Arc::new(host_array), Arc::new(value_array)])
            .unwrap();

        // Even with no rows, the split yields exactly one entry.
        let result = rule.split_record_batch(&batch).unwrap();
        assert_eq!(result.len(), 1);
    }

    #[test]
    fn test_split_record_batch_by_two_columns() {
        // Four regions forming a 2x2 grid: `host` split at "server10" and `value`
        // split at 10.
        let rule = MultiDimPartitionRule::try_new(
            vec!["host".to_string(), "value".to_string()],
            vec![0, 1, 2, 3],
            vec![
                col("host")
                    .lt(Value::String("server10".into()))
                    .and(col("value").lt(Value::Int64(10))),
                col("host")
                    .lt(Value::String("server10".into()))
                    .and(col("value").gt_eq(Value::Int64(10))),
                col("host")
                    .gt_eq(Value::String("server10".into()))
                    .and(col("value").lt(Value::Int64(10))),
                col("host")
                    .gt_eq(Value::String("server10".into()))
                    .and(col("value").gt_eq(Value::Int64(10))),
            ],
        )
        .unwrap();

        let batch = generate_random_record_batch(1000);
        assert_split_matches_naive(&rule, &batch);
    }

    #[test]
    fn test_default_region() {
        // The expressions only accept `value = 10` or `value = 20`, so a row with
        // any other value satisfies none of them.
        let rule = MultiDimPartitionRule::try_new(
            vec!["host".to_string(), "value".to_string()],
            vec![0, 1, 2, 3],
            vec![
                col("host")
                    .lt(Value::String("server10".into()))
                    .and(col("value").eq(Value::Int64(10))),
                col("host")
                    .lt(Value::String("server10".into()))
                    .and(col("value").eq(Value::Int64(20))),
                col("host")
                    .gt_eq(Value::String("server10".into()))
                    .and(col("value").eq(Value::Int64(10))),
                col("host")
                    .gt_eq(Value::String("server10".into()))
                    .and(col("value").eq(Value::Int64(20))),
            ],
        )
        .unwrap();

        // The third row (`server1`, 30) matches no expression.
        let schema = test_schema();
        let host_array = StringArray::from(vec!["server1", "server1", "server1", "server100"]);
        let value_array = Int64Array::from(vec![10, 20, 30, 10]);
        let batch = RecordBatch::try_new(schema, vec![Arc::new(host_array), Arc::new(value_array)])
            .unwrap();
        assert_split_matches_naive(&rule, &batch);
    }
}