partition/
multi_dim.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::any::Any;
16use std::cmp::Ordering;
17use std::collections::hash_map::Entry;
18use std::collections::HashMap;
19use std::sync::{Arc, RwLock};
20
21use datafusion_expr::ColumnarValue;
22use datafusion_physical_expr::PhysicalExpr;
23use datatypes::arrow;
24use datatypes::arrow::array::{BooleanArray, BooleanBufferBuilder, RecordBatch};
25use datatypes::arrow::buffer::BooleanBuffer;
26use datatypes::arrow::datatypes::Schema;
27use datatypes::prelude::Value;
28use datatypes::vectors::{Helper, VectorRef};
29use serde::{Deserialize, Serialize};
30use snafu::{ensure, OptionExt, ResultExt};
31use store_api::storage::RegionNumber;
32
33use crate::checker::PartitionChecker;
34use crate::error::{self, Result, UndefinedColumnSnafu};
35use crate::expr::{Operand, PartitionExpr, RestrictedOp};
36use crate::partition::RegionMask;
37use crate::PartitionRule;
38
39/// The default region number when no partition exprs are matched.
40const DEFAULT_REGION: RegionNumber = 0;
41
42type PhysicalExprCache = Option<(Vec<Arc<dyn PhysicalExpr>>, Arc<Schema>)>;
43
44/// Multi-Dimiension partition rule. RFC [here](https://github.com/GreptimeTeam/greptimedb/blob/main/docs/rfcs/2024-02-21-multi-dimension-partition-rule/rfc.md)
45///
46/// This partition rule is defined by a set of simple expressions on the partition
47/// key columns. Compare to RANGE partition, which can be considered as
48/// single-dimension rule, this will evaluate expression on each column separately.
49#[derive(Debug, Serialize, Deserialize)]
50pub struct MultiDimPartitionRule {
51    /// Allow list of which columns can be used for partitioning.
52    partition_columns: Vec<String>,
53    /// Name to index of `partition_columns`. Used for quick lookup.
54    name_to_index: HashMap<String, usize>,
55    /// Region number for each partition. This list has the same length as `exprs`
56    /// (dispiting the default region).
57    regions: Vec<RegionNumber>,
58    /// Partition expressions.
59    exprs: Vec<PartitionExpr>,
60    /// Cache of physical expressions.
61    #[serde(skip)]
62    physical_expr_cache: RwLock<PhysicalExprCache>,
63}
64
65impl MultiDimPartitionRule {
66    /// Create a new [`MultiDimPartitionRule`].
67    ///
68    /// If `check_exprs` is true, the function will check if the expressions are valid. This is
69    /// required when constructing a new partition rule like `CREATE TABLE` or `ALTER TABLE`.
70    pub fn try_new(
71        partition_columns: Vec<String>,
72        regions: Vec<RegionNumber>,
73        exprs: Vec<PartitionExpr>,
74        check_exprs: bool,
75    ) -> Result<Self> {
76        let name_to_index = partition_columns
77            .iter()
78            .enumerate()
79            .map(|(i, name)| (name.clone(), i))
80            .collect::<HashMap<_, _>>();
81
82        let rule = Self {
83            partition_columns,
84            name_to_index,
85            regions,
86            exprs,
87            physical_expr_cache: RwLock::new(None),
88        };
89
90        if check_exprs {
91            let checker = PartitionChecker::try_new(&rule)?;
92            checker.check()?;
93        }
94
95        Ok(rule)
96    }
97
98    pub fn exprs(&self) -> &[PartitionExpr] {
99        &self.exprs
100    }
101
102    fn find_region(&self, values: &[Value]) -> Result<RegionNumber> {
103        ensure!(
104            values.len() == self.partition_columns.len(),
105            error::RegionKeysSizeSnafu {
106                expect: self.partition_columns.len(),
107                actual: values.len(),
108            }
109        );
110
111        for (region_index, expr) in self.exprs.iter().enumerate() {
112            if self.evaluate_expr(expr, values)? {
113                return Ok(self.regions[region_index]);
114            }
115        }
116
117        // return the default region number
118        Ok(DEFAULT_REGION)
119    }
120
121    fn evaluate_expr(&self, expr: &PartitionExpr, values: &[Value]) -> Result<bool> {
122        match (expr.lhs.as_ref(), expr.rhs.as_ref()) {
123            (Operand::Column(name), Operand::Value(r)) => {
124                let index = self.name_to_index.get(name).unwrap();
125                let l = &values[*index];
126                Self::perform_op(l, &expr.op, r)
127            }
128            (Operand::Value(l), Operand::Column(name)) => {
129                let index = self.name_to_index.get(name).unwrap();
130                let r = &values[*index];
131                Self::perform_op(l, &expr.op, r)
132            }
133            (Operand::Expr(lhs), Operand::Expr(rhs)) => {
134                let lhs = self.evaluate_expr(lhs, values)?;
135                let rhs = self.evaluate_expr(rhs, values)?;
136                match expr.op {
137                    RestrictedOp::And => Ok(lhs && rhs),
138                    RestrictedOp::Or => Ok(lhs || rhs),
139                    _ => unreachable!(),
140                }
141            }
142            _ => unreachable!(),
143        }
144    }
145
146    fn perform_op(lhs: &Value, op: &RestrictedOp, rhs: &Value) -> Result<bool> {
147        let result = match op {
148            RestrictedOp::Eq => lhs.eq(rhs),
149            RestrictedOp::NotEq => lhs.ne(rhs),
150            RestrictedOp::Lt => lhs.partial_cmp(rhs) == Some(Ordering::Less),
151            RestrictedOp::LtEq => {
152                let result = lhs.partial_cmp(rhs);
153                result == Some(Ordering::Less) || result == Some(Ordering::Equal)
154            }
155            RestrictedOp::Gt => lhs.partial_cmp(rhs) == Some(Ordering::Greater),
156            RestrictedOp::GtEq => {
157                let result = lhs.partial_cmp(rhs);
158                result == Some(Ordering::Greater) || result == Some(Ordering::Equal)
159            }
160            RestrictedOp::And | RestrictedOp::Or => unreachable!(),
161        };
162
163        Ok(result)
164    }
165
166    pub fn row_at(&self, cols: &[VectorRef], index: usize, row: &mut [Value]) -> Result<()> {
167        for (col_idx, col) in cols.iter().enumerate() {
168            row[col_idx] = col.get(index);
169        }
170        Ok(())
171    }
172
173    pub fn record_batch_to_cols(&self, record_batch: &RecordBatch) -> Result<Vec<VectorRef>> {
174        self.partition_columns
175            .iter()
176            .map(|col_name| {
177                record_batch
178                    .column_by_name(col_name)
179                    .context(UndefinedColumnSnafu { column: col_name })
180                    .and_then(|array| {
181                        Helper::try_into_vector(array).context(error::ConvertToVectorSnafu)
182                    })
183            })
184            .collect::<Result<Vec<_>>>()
185    }
186
187    pub fn split_record_batch_naive(
188        &self,
189        record_batch: &RecordBatch,
190    ) -> Result<HashMap<RegionNumber, BooleanArray>> {
191        let num_rows = record_batch.num_rows();
192
193        let mut result = self
194            .regions
195            .iter()
196            .map(|region| {
197                let mut builder = BooleanBufferBuilder::new(num_rows);
198                builder.append_n(num_rows, false);
199                (*region, builder)
200            })
201            .collect::<HashMap<_, _>>();
202
203        let cols = self.record_batch_to_cols(record_batch)?;
204        let mut current_row = vec![Value::Null; self.partition_columns.len()];
205        for row_idx in 0..num_rows {
206            self.row_at(&cols, row_idx, &mut current_row)?;
207            let current_region = self.find_region(&current_row)?;
208            let region_mask = result
209                .get_mut(&current_region)
210                .unwrap_or_else(|| panic!("Region {} must be initialized", current_region));
211            region_mask.set_bit(row_idx, true);
212        }
213
214        Ok(result
215            .into_iter()
216            .map(|(region, mut mask)| (region, BooleanArray::new(mask.finish(), None)))
217            .collect())
218    }
219
220    pub fn split_record_batch(
221        &self,
222        record_batch: &RecordBatch,
223    ) -> Result<HashMap<RegionNumber, RegionMask>> {
224        let num_rows = record_batch.num_rows();
225        if self.regions.len() == 1 {
226            return Ok([(
227                self.regions[0],
228                RegionMask::from(BooleanArray::from(vec![true; num_rows])),
229            )]
230            .into_iter()
231            .collect());
232        }
233        let physical_exprs = {
234            let cache_read_guard = self.physical_expr_cache.read().unwrap();
235            if let Some((cached_exprs, schema)) = cache_read_guard.as_ref()
236                && schema == record_batch.schema_ref()
237            {
238                cached_exprs.clone()
239            } else {
240                drop(cache_read_guard); // Release the read lock before acquiring write lock
241
242                let schema = record_batch.schema();
243                let new_cache = self
244                    .exprs
245                    .iter()
246                    .map(|e| e.try_as_physical_expr(&schema))
247                    .collect::<Result<Vec<_>>>()?;
248
249                let mut cache_write_guard = self.physical_expr_cache.write().unwrap();
250                cache_write_guard.replace((new_cache.clone(), schema));
251                new_cache
252            }
253        };
254
255        let mut result: HashMap<u32, RegionMask> = physical_exprs
256            .iter()
257            .zip(self.regions.iter())
258            .filter_map(|(expr, region_num)| {
259                let col_val = match expr
260                    .evaluate(record_batch)
261                    .context(error::EvaluateRecordBatchSnafu)
262                {
263                    Ok(array) => array,
264                    Err(e) => {
265                        return Some(Err(e));
266                    }
267                };
268                let ColumnarValue::Array(column) = col_val else {
269                    unreachable!("Expected an array")
270                };
271                let array =
272                    match column
273                        .as_any()
274                        .downcast_ref::<BooleanArray>()
275                        .with_context(|| error::UnexpectedColumnTypeSnafu {
276                            data_type: column.data_type().clone(),
277                        }) {
278                        Ok(array) => array,
279                        Err(e) => {
280                            return Some(Err(e));
281                        }
282                    };
283                let selected_rows = array.true_count();
284                if selected_rows == 0 {
285                    // skip empty region in results.
286                    return None;
287                }
288                Some(Ok((
289                    *region_num,
290                    RegionMask::new(array.clone(), selected_rows),
291                )))
292            })
293            .collect::<error::Result<_>>()?;
294
295        let selected = if result.len() == 1 {
296            result.values().next().unwrap().array().clone()
297        } else {
298            let mut selected = BooleanArray::new(BooleanBuffer::new_unset(num_rows), None);
299            for region_mask in result.values() {
300                selected = arrow::compute::kernels::boolean::or(&selected, region_mask.array())
301                    .context(error::ComputeArrowKernelSnafu)?;
302            }
303            selected
304        };
305
306        // fast path: all rows are selected
307        if selected.true_count() == num_rows {
308            return Ok(result);
309        }
310
311        // find unselected rows and assign to default region
312        let unselected = arrow::compute::kernels::boolean::not(&selected)
313            .context(error::ComputeArrowKernelSnafu)?;
314        match result.entry(DEFAULT_REGION) {
315            Entry::Occupied(mut o) => {
316                // merge default region with unselected rows.
317                let default_region_mask = RegionMask::from(
318                    arrow::compute::kernels::boolean::or(o.get().array(), &unselected)
319                        .context(error::ComputeArrowKernelSnafu)?,
320                );
321                o.insert(default_region_mask);
322            }
323            Entry::Vacant(v) => {
324                // default region has no rows, simply put all unselected rows to default region.
325                v.insert(RegionMask::from(unselected));
326            }
327        }
328        Ok(result)
329    }
330}
331
332impl PartitionRule for MultiDimPartitionRule {
333    fn as_any(&self) -> &dyn Any {
334        self
335    }
336
337    fn partition_columns(&self) -> Vec<String> {
338        self.partition_columns.clone()
339    }
340
341    fn find_region(&self, values: &[Value]) -> Result<RegionNumber> {
342        self.find_region(values)
343    }
344
345    fn split_record_batch(
346        &self,
347        record_batch: &RecordBatch,
348    ) -> Result<HashMap<RegionNumber, RegionMask>> {
349        self.split_record_batch(record_batch)
350    }
351}
352
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use super::*;
    use crate::error::{self, Error};
    use crate::expr::col;

    #[test]
    fn test_find_region() {
        // PARTITION ON COLUMNS (b) (
        //     b < 'hz',
        //     b >= 'hz' AND b < 'sh',
        //     b >= 'sh'
        // )
        let rule = MultiDimPartitionRule::try_new(
            vec!["b".to_string()],
            vec![1, 2, 3],
            vec![
                PartitionExpr::new(
                    Operand::Column("b".to_string()),
                    RestrictedOp::Lt,
                    Operand::Value(datatypes::value::Value::String("hz".into())),
                ),
                PartitionExpr::new(
                    Operand::Expr(PartitionExpr::new(
                        Operand::Column("b".to_string()),
                        RestrictedOp::GtEq,
                        Operand::Value(datatypes::value::Value::String("hz".into())),
                    )),
                    RestrictedOp::And,
                    Operand::Expr(PartitionExpr::new(
                        Operand::Column("b".to_string()),
                        RestrictedOp::Lt,
                        Operand::Value(datatypes::value::Value::String("sh".into())),
                    )),
                ),
                PartitionExpr::new(
                    Operand::Column("b".to_string()),
                    RestrictedOp::GtEq,
                    Operand::Value(datatypes::value::Value::String("sh".into())),
                ),
            ],
            true,
        )
        .unwrap();
        // A row must carry exactly one value per partition column.
        assert_matches!(
            rule.find_region(&["foo".into(), 1000_i32.into()]),
            Err(error::Error::RegionKeysSize {
                expect: 1,
                actual: 2,
                ..
            })
        );
        assert_matches!(rule.find_region(&["foo".into()]), Ok(1));
        assert_matches!(rule.find_region(&["bar".into()]), Ok(1));
        assert_matches!(rule.find_region(&["hz".into()]), Ok(2));
        assert_matches!(rule.find_region(&["hzz".into()]), Ok(2));
        assert_matches!(rule.find_region(&["sh".into()]), Ok(3));
        assert_matches!(rule.find_region(&["zzzz".into()]), Ok(3));
    }

    #[test]
    fn invalid_expr_case_1() {
        // A column compared against a compound expression is rejected:
        // PARTITION ON COLUMNS (a, b) (
        //     b <= (b >= 'hz' AND b < 'sh')
        // )
        let rule = MultiDimPartitionRule::try_new(
            vec!["a".to_string(), "b".to_string()],
            vec![1],
            vec![PartitionExpr::new(
                Operand::Column("b".to_string()),
                RestrictedOp::LtEq,
                Operand::Expr(PartitionExpr::new(
                    Operand::Expr(PartitionExpr::new(
                        Operand::Column("b".to_string()),
                        RestrictedOp::GtEq,
                        Operand::Value(datatypes::value::Value::String("hz".into())),
                    )),
                    RestrictedOp::And,
                    Operand::Expr(PartitionExpr::new(
                        Operand::Column("b".to_string()),
                        RestrictedOp::Lt,
                        Operand::Value(datatypes::value::Value::String("sh".into())),
                    )),
                )),
            )],
            true,
        );

        // check rule
        assert_matches!(rule.unwrap_err(), Error::InvalidExpr { .. });
    }

    #[test]
    fn invalid_expr_case_2() {
        // AND applied to a bare value instead of a sub-expression is rejected:
        // PARTITION ON COLUMNS (a, b) (
        //     b >= 'hz' AND 'sh'
        // )
        let rule = MultiDimPartitionRule::try_new(
            vec!["a".to_string(), "b".to_string()],
            vec![1],
            vec![PartitionExpr::new(
                Operand::Expr(PartitionExpr::new(
                    Operand::Column("b".to_string()),
                    RestrictedOp::GtEq,
                    Operand::Value(datatypes::value::Value::String("hz".into())),
                )),
                RestrictedOp::And,
                Operand::Value(datatypes::value::Value::String("sh".into())),
            )],
            true,
        );

        // check rule
        assert_matches!(rule.unwrap_err(), Error::InvalidExpr { .. });
    }

    /// ```ignore
    ///          │          │
    ///          │          │
    /// ─────────┼──────────┼────────────► b
    ///          │          │
    ///          │          │
    ///      b <= h     b >= s
    /// ```
    #[test]
    fn empty_expr_case_1() {
        // Values of `b` strictly between 'h' and 's' are covered by no expression:
        // PARTITION ON COLUMNS (a, b) (
        //     b <= 'h',
        //     b >= 's'
        // )
        let rule = MultiDimPartitionRule::try_new(
            vec!["a".to_string(), "b".to_string()],
            vec![1, 2],
            vec![
                PartitionExpr::new(
                    Operand::Column("b".to_string()),
                    RestrictedOp::LtEq,
                    Operand::Value(datatypes::value::Value::String("h".into())),
                ),
                PartitionExpr::new(
                    Operand::Column("b".to_string()),
                    RestrictedOp::GtEq,
                    Operand::Value(datatypes::value::Value::String("s".into())),
                ),
            ],
            true,
        );

        // check rule
        assert_matches!(rule.unwrap_err(), Error::CheckpointNotCovered { .. });
    }

    /// ```text
    ///     a
    ///     ▲
    ///     │                   ‖
    ///     │                   ‖
    /// 200 │         ┌─────────┤
    ///     │         │         │
    ///     │         │         │
    ///     │         │         │
    /// 100 │   ======┴─────────┘
    ///     │
    ///     └──────────────────────────►b
    ///              10          20
    /// ```
    #[test]
    fn empty_expr_case_2() {
        // PARTITION ON COLUMNS (a, b) (
        //     a >= 100 AND b <= 10  OR  a > 100 AND a <= 200 AND b <= 10  OR  a >= 200 AND b > 10 AND b <= 20  OR  a > 200 AND b <= 20,
        //     a < 100 AND b <= 20  OR  a >= 100 AND b >= 20
        // )
        let rule = MultiDimPartitionRule::try_new(
            vec!["a".to_string(), "b".to_string()],
            vec![1, 2],
            vec![
                PartitionExpr::new(
                    Operand::Expr(PartitionExpr::new(
                        Operand::Expr(PartitionExpr::new(
                            // a >= 100 AND b <= 10
                            Operand::Expr(PartitionExpr::new(
                                Operand::Expr(PartitionExpr::new(
                                    Operand::Column("a".to_string()),
                                    RestrictedOp::GtEq,
                                    Operand::Value(datatypes::value::Value::Int64(100)),
                                )),
                                RestrictedOp::And,
                                Operand::Expr(PartitionExpr::new(
                                    Operand::Column("b".to_string()),
                                    RestrictedOp::LtEq,
                                    Operand::Value(datatypes::value::Value::Int64(10)),
                                )),
                            )),
                            RestrictedOp::Or,
                            // a > 100 AND a <= 200 AND b <= 10
                            Operand::Expr(PartitionExpr::new(
                                Operand::Expr(PartitionExpr::new(
                                    Operand::Expr(PartitionExpr::new(
                                        Operand::Column("a".to_string()),
                                        RestrictedOp::Gt,
                                        Operand::Value(datatypes::value::Value::Int64(100)),
                                    )),
                                    RestrictedOp::And,
                                    Operand::Expr(PartitionExpr::new(
                                        Operand::Column("a".to_string()),
                                        RestrictedOp::LtEq,
                                        Operand::Value(datatypes::value::Value::Int64(200)),
                                    )),
                                )),
                                RestrictedOp::And,
                                Operand::Expr(PartitionExpr::new(
                                    Operand::Column("b".to_string()),
                                    RestrictedOp::LtEq,
                                    Operand::Value(datatypes::value::Value::Int64(10)),
                                )),
                            )),
                        )),
                        RestrictedOp::Or,
                        // a >= 200 AND b > 10 AND b <= 20
                        Operand::Expr(PartitionExpr::new(
                            Operand::Expr(PartitionExpr::new(
                                Operand::Expr(PartitionExpr::new(
                                    Operand::Column("a".to_string()),
                                    RestrictedOp::GtEq,
                                    Operand::Value(datatypes::value::Value::Int64(200)),
                                )),
                                RestrictedOp::And,
                                Operand::Expr(PartitionExpr::new(
                                    Operand::Column("b".to_string()),
                                    RestrictedOp::Gt,
                                    Operand::Value(datatypes::value::Value::Int64(10)),
                                )),
                            )),
                            RestrictedOp::And,
                            Operand::Expr(PartitionExpr::new(
                                Operand::Column("b".to_string()),
                                RestrictedOp::LtEq,
                                Operand::Value(datatypes::value::Value::Int64(20)),
                            )),
                        )),
                    )),
                    RestrictedOp::Or,
                    // a > 200 AND b <= 20
                    Operand::Expr(PartitionExpr::new(
                        Operand::Expr(PartitionExpr::new(
                            Operand::Column("a".to_string()),
                            RestrictedOp::Gt,
                            Operand::Value(datatypes::value::Value::Int64(200)),
                        )),
                        RestrictedOp::And,
                        Operand::Expr(PartitionExpr::new(
                            Operand::Column("b".to_string()),
                            RestrictedOp::LtEq,
                            Operand::Value(datatypes::value::Value::Int64(20)),
                        )),
                    )),
                ),
                PartitionExpr::new(
                    // a < 100 AND b <= 20
                    Operand::Expr(PartitionExpr::new(
                        Operand::Expr(PartitionExpr::new(
                            Operand::Column("a".to_string()),
                            RestrictedOp::Lt,
                            Operand::Value(datatypes::value::Value::Int64(100)),
                        )),
                        RestrictedOp::And,
                        Operand::Expr(PartitionExpr::new(
                            Operand::Column("b".to_string()),
                            RestrictedOp::LtEq,
                            Operand::Value(datatypes::value::Value::Int64(20)),
                        )),
                    )),
                    RestrictedOp::Or,
                    // a >= 100 AND b >= 20
                    Operand::Expr(PartitionExpr::new(
                        Operand::Expr(PartitionExpr::new(
                            Operand::Column("a".to_string()),
                            RestrictedOp::GtEq,
                            Operand::Value(datatypes::value::Value::Int64(100)),
                        )),
                        RestrictedOp::And,
                        Operand::Expr(PartitionExpr::new(
                            Operand::Column("b".to_string()),
                            RestrictedOp::GtEq,
                            Operand::Value(datatypes::value::Value::Int64(20)),
                        )),
                    )),
                ),
            ],
            true,
        );

        // check rule
        assert_matches!(rule.unwrap_err(), Error::CheckpointNotCovered { .. });
    }

    #[test]
    fn duplicate_expr_case_1() {
        // Values of `a` in [10, 20] are matched by both expressions:
        // PARTITION ON COLUMNS (a, b) (
        //     a <= 20,
        //     a >= 10
        // )
        let rule = MultiDimPartitionRule::try_new(
            vec!["a".to_string(), "b".to_string()],
            vec![1, 2],
            vec![
                PartitionExpr::new(
                    Operand::Column("a".to_string()),
                    RestrictedOp::LtEq,
                    Operand::Value(datatypes::value::Value::Int64(20)),
                ),
                PartitionExpr::new(
                    Operand::Column("a".to_string()),
                    RestrictedOp::GtEq,
                    Operand::Value(datatypes::value::Value::Int64(10)),
                ),
            ],
            true,
        );

        // check rule
        assert_matches!(rule.unwrap_err(), Error::CheckpointOverlapped { .. });
    }

    #[test]
    fn duplicate_expr_case_2() {
        // `a != 20` overlaps both `a <= 20` and `a > 20`:
        // PARTITION ON COLUMNS (a, b) (
        //     a != 20,
        //     a <= 20,
        //     a > 20
        // )
        let rule = MultiDimPartitionRule::try_new(
            vec!["a".to_string(), "b".to_string()],
            vec![1, 2],
            vec![
                PartitionExpr::new(
                    Operand::Column("a".to_string()),
                    RestrictedOp::NotEq,
                    Operand::Value(datatypes::value::Value::Int64(20)),
                ),
                PartitionExpr::new(
                    Operand::Column("a".to_string()),
                    RestrictedOp::LtEq,
                    Operand::Value(datatypes::value::Value::Int64(20)),
                ),
                PartitionExpr::new(
                    Operand::Column("a".to_string()),
                    RestrictedOp::Gt,
                    Operand::Value(datatypes::value::Value::Int64(20)),
                ),
            ],
            true,
        );

        // check rule
        assert_matches!(rule.unwrap_err(), Error::CheckpointOverlapped { .. });
    }

    /// ```ignore
    /// value
    ///                                 │
    ///                                 │
    ///    value=10 --------------------│
    ///                                 │
    /// ────────────────────────────────┼──► host
    ///                                 │
    ///                             host=server10
    /// ```
    #[test]
    fn test_partial_divided() {
        // `host < 'server10'` is split further on `value`; `host >= 'server10'` is not.
        // The rule only needs to pass the checker.
        let _rule = MultiDimPartitionRule::try_new(
            vec!["host".to_string(), "value".to_string()],
            vec![0, 1, 2, 3],
            vec![
                col("host")
                    .lt(Value::String("server10".into()))
                    .and(col("value").lt(Value::Int64(10))),
                col("host")
                    .lt(Value::String("server10".into()))
                    .and(col("value").gt_eq(Value::Int64(10))),
                col("host").gt_eq(Value::String("server10".into())),
            ],
            true,
        )
        .unwrap();
    }
}
742
#[cfg(test)]
mod test_split_record_batch {
    use std::sync::Arc;

    use datatypes::arrow::array::{Int64Array, StringArray};
    use datatypes::arrow::datatypes::{DataType, Field, Schema};
    use datatypes::arrow::record_batch::RecordBatch;
    use rand::Rng;

    use super::*;
    use crate::expr::col;

    /// Schema shared by every test in this module: a non-nullable `host` (Utf8)
    /// column followed by a non-nullable `value` (Int64) column.
    fn test_schema() -> Arc<Schema> {
        Arc::new(Schema::new(vec![
            Field::new("host", DataType::Utf8, false),
            Field::new("value", DataType::Int64, false),
        ]))
    }

    /// Builds a batch of `num_rows` random rows matching [`test_schema`].
    ///
    /// `host` is drawn from `"server0"` ..= `"server19"` and `value` from `0..20`.
    /// The RNG is unseeded, so every run exercises different data.
    fn generate_random_record_batch(num_rows: usize) -> RecordBatch {
        let schema = test_schema();
        let mut rng = rand::thread_rng();
        let mut host_array = Vec::with_capacity(num_rows);
        let mut value_array = Vec::with_capacity(num_rows);
        for _ in 0..num_rows {
            host_array.push(format!("server{}", rng.gen_range(0..20)));
            value_array.push(rng.gen_range(0..20));
        }
        let host_array = StringArray::from(host_array);
        let value_array = Int64Array::from(value_array);
        RecordBatch::try_new(schema, vec![Arc::new(host_array), Arc::new(value_array)]).unwrap()
    }

    /// Compares `split_record_batch` against `split_record_batch_naive` for a rule
    /// whose expressions only constrain the `host` column.
    #[test]
    fn test_split_record_batch_by_one_column() {
        // Create a simple MultiDimPartitionRule
        let rule = MultiDimPartitionRule::try_new(
            vec!["host".to_string(), "value".to_string()],
            vec![0, 1],
            vec![
                col("host").lt(Value::String("server1".into())),
                col("host").gt_eq(Value::String("server1".into())),
            ],
            true,
        )
        .unwrap();

        let batch = generate_random_record_batch(1000);
        // Split the batch
        let result = rule.split_record_batch(&batch).unwrap();
        let expected = rule.split_record_batch_naive(&batch).unwrap();
        assert_eq!(result.len(), expected.len());
        for (region, value) in &result {
            assert_eq!(
                value.array(),
                expected.get(region).unwrap(),
                "failed on region: {}",
                region
            );
        }
    }

    /// Splitting a zero-row batch must yield exactly one entry, and that entry
    /// must not select any row.
    #[test]
    fn test_split_record_batch_empty() {
        // Create a simple MultiDimPartitionRule.
        // NOTE(review): only one region number is supplied for two partition
        // expressions; confirm this mismatch is intentional for this test.
        let rule = MultiDimPartitionRule::try_new(
            vec!["host".to_string()],
            vec![1],
            vec![
                col("host").lt(Value::String("server1".into())),
                col("host").gt_eq(Value::String("server1".into())),
            ],
            true,
        )
        .unwrap();

        let schema = test_schema();
        let host_array = StringArray::from(Vec::<&str>::new());
        let value_array = Int64Array::from(Vec::<i64>::new());
        let batch = RecordBatch::try_new(schema, vec![Arc::new(host_array), Arc::new(value_array)])
            .unwrap();

        let result = rule.split_record_batch(&batch).unwrap();
        assert_eq!(result.len(), 1);
        // A zero-row batch has nothing to select, whichever region is returned.
        assert!(result.values().all(|mask| mask.selected_rows() == 0));
    }

    /// Compares `split_record_batch` against `split_record_batch_naive` for a rule
    /// whose expressions combine conditions on both `host` and `value`.
    #[test]
    fn test_split_record_batch_by_two_columns() {
        let rule = MultiDimPartitionRule::try_new(
            vec!["host".to_string(), "value".to_string()],
            vec![0, 1, 2, 3],
            vec![
                col("host")
                    .lt(Value::String("server10".into()))
                    .and(col("value").lt(Value::Int64(10))),
                col("host")
                    .lt(Value::String("server10".into()))
                    .and(col("value").gt_eq(Value::Int64(10))),
                col("host")
                    .gt_eq(Value::String("server10".into()))
                    .and(col("value").lt(Value::Int64(10))),
                col("host")
                    .gt_eq(Value::String("server10".into()))
                    .and(col("value").gt_eq(Value::Int64(10))),
            ],
            true,
        )
        .unwrap();

        let batch = generate_random_record_batch(1000);
        let result = rule.split_record_batch(&batch).unwrap();
        let expected = rule.split_record_batch_naive(&batch).unwrap();
        assert_eq!(result.len(), expected.len());
        for (region, value) in &result {
            // Report the region on failure, matching the one-column test.
            assert_eq!(
                value.array(),
                expected.get(region).unwrap(),
                "failed on region: {}",
                region
            );
        }
    }

    /// Every row is matched by one of the two expressions on `value`, so only the
    /// two declared regions should appear in the result.
    #[test]
    fn test_all_rows_selected() {
        // Test the fast path where all rows are selected by some partition
        let rule = MultiDimPartitionRule::try_new(
            vec!["value".to_string()],
            vec![1, 2],
            vec![
                col("value").lt(Value::Int64(30)),
                col("value").gt_eq(Value::Int64(30)),
            ],
            true,
        )
        .unwrap();

        let schema = test_schema();
        let host_array = StringArray::from(vec!["server1", "server2", "server3", "server4"]);
        let value_array = Int64Array::from(vec![10, 20, 30, 40]);
        let batch = RecordBatch::try_new(schema, vec![Arc::new(host_array), Arc::new(value_array)])
            .unwrap();

        let result = rule.split_record_batch(&batch).unwrap();

        // Check that we have 2 regions and no default region
        assert_eq!(result.len(), 2);
        assert!(result.contains_key(&1));
        assert!(result.contains_key(&2));
        assert!(!result.contains_key(&DEFAULT_REGION));

        // Verify each region has the correct number of rows
        assert_eq!(result.get(&1).unwrap().selected_rows(), 2); // values < 30
        assert_eq!(result.get(&2).unwrap().selected_rows(), 2); // values >= 30
    }
}