partition/
multi_dim.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::any::Any;
16use std::cmp::Ordering;
17use std::collections::HashMap;
18use std::collections::hash_map::Entry;
19use std::sync::{Arc, RwLock};
20
21use datafusion_expr::ColumnarValue;
22use datafusion_physical_expr::PhysicalExpr;
23use datatypes::arrow;
24use datatypes::arrow::array::{BooleanArray, BooleanBufferBuilder, RecordBatch};
25use datatypes::arrow::buffer::BooleanBuffer;
26use datatypes::arrow::datatypes::Schema;
27use datatypes::prelude::Value;
28use datatypes::vectors::{Helper, VectorRef};
29use serde::{Deserialize, Serialize};
30use snafu::{OptionExt, ResultExt, ensure};
31use store_api::storage::RegionNumber;
32
33use crate::PartitionRule;
34use crate::checker::PartitionChecker;
35use crate::error::{self, Result, UndefinedColumnSnafu};
36use crate::expr::{Operand, PartitionExpr, RestrictedOp};
37use crate::partition::RegionMask;
38
/// The default region number when no partition exprs are matched.
const DEFAULT_REGION: RegionNumber = 0;

/// Cached physical expressions paired with the schema they were built against.
///
/// `None` means nothing has been cached yet.
type PhysicalExprCache = Option<(Vec<Arc<dyn PhysicalExpr>>, Arc<Schema>)>;
43
/// Multi-dimension partition rule. RFC [here](https://github.com/GreptimeTeam/greptimedb/blob/main/docs/rfcs/2024-02-21-multi-dimension-partition-rule/rfc.md)
///
/// This partition rule is defined by a set of simple expressions on the partition
/// key columns. Compared to RANGE partition, which can be considered as a
/// single-dimension rule, this will evaluate expressions on each column separately.
#[derive(Debug, Serialize, Deserialize)]
pub struct MultiDimPartitionRule {
    /// Allow list of which columns can be used for partitioning.
    partition_columns: Vec<String>,
    /// Name to index of `partition_columns`. Used for quick lookup.
    name_to_index: HashMap<String, usize>,
    /// Region number for each partition. This list has the same length as `exprs`
    /// (not counting the default region).
    regions: Vec<RegionNumber>,
    /// Partition expressions.
    exprs: Vec<PartitionExpr>,
    /// Cache of physical expressions together with the schema they were built for.
    /// Not serialized; starts out empty.
    #[serde(skip)]
    physical_expr_cache: RwLock<PhysicalExprCache>,
}
64
65impl MultiDimPartitionRule {
66    /// Create a new [`MultiDimPartitionRule`].
67    ///
68    /// If `check_exprs` is true, the function will check if the expressions are valid. This is
69    /// required when constructing a new partition rule like `CREATE TABLE` or `ALTER TABLE`.
70    pub fn try_new(
71        partition_columns: Vec<String>,
72        regions: Vec<RegionNumber>,
73        exprs: Vec<PartitionExpr>,
74        check_exprs: bool,
75    ) -> Result<Self> {
76        let name_to_index = partition_columns
77            .iter()
78            .enumerate()
79            .map(|(i, name)| (name.clone(), i))
80            .collect::<HashMap<_, _>>();
81
82        let rule = Self {
83            partition_columns,
84            name_to_index,
85            regions,
86            exprs,
87            physical_expr_cache: RwLock::new(None),
88        };
89
90        if check_exprs {
91            let checker = PartitionChecker::try_new(&rule)?;
92            checker.check()?;
93        }
94
95        Ok(rule)
96    }
97
98    pub fn exprs(&self) -> &[PartitionExpr] {
99        &self.exprs
100    }
101
102    fn find_region(&self, values: &[Value]) -> Result<RegionNumber> {
103        ensure!(
104            values.len() == self.partition_columns.len(),
105            error::RegionKeysSizeSnafu {
106                expect: self.partition_columns.len(),
107                actual: values.len(),
108            }
109        );
110
111        for (region_index, expr) in self.exprs.iter().enumerate() {
112            if self.evaluate_expr(expr, values)? {
113                return Ok(self.regions[region_index]);
114            }
115        }
116
117        // return the default region number
118        Ok(DEFAULT_REGION)
119    }
120
121    fn evaluate_expr(&self, expr: &PartitionExpr, values: &[Value]) -> Result<bool> {
122        match (expr.lhs.as_ref(), expr.rhs.as_ref()) {
123            (Operand::Column(name), Operand::Value(r)) => {
124                let index = self.name_to_index.get(name).unwrap();
125                let l = &values[*index];
126                Self::perform_op(l, &expr.op, r)
127            }
128            (Operand::Value(l), Operand::Column(name)) => {
129                let index = self.name_to_index.get(name).unwrap();
130                let r = &values[*index];
131                Self::perform_op(l, &expr.op, r)
132            }
133            (Operand::Expr(lhs), Operand::Expr(rhs)) => {
134                let lhs = self.evaluate_expr(lhs, values)?;
135                let rhs = self.evaluate_expr(rhs, values)?;
136                match expr.op {
137                    RestrictedOp::And => Ok(lhs && rhs),
138                    RestrictedOp::Or => Ok(lhs || rhs),
139                    _ => unreachable!(),
140                }
141            }
142            _ => unreachable!(),
143        }
144    }
145
146    fn perform_op(lhs: &Value, op: &RestrictedOp, rhs: &Value) -> Result<bool> {
147        let result = match op {
148            RestrictedOp::Eq => lhs.eq(rhs),
149            RestrictedOp::NotEq => lhs.ne(rhs),
150            RestrictedOp::Lt => lhs.partial_cmp(rhs) == Some(Ordering::Less),
151            RestrictedOp::LtEq => {
152                let result = lhs.partial_cmp(rhs);
153                result == Some(Ordering::Less) || result == Some(Ordering::Equal)
154            }
155            RestrictedOp::Gt => lhs.partial_cmp(rhs) == Some(Ordering::Greater),
156            RestrictedOp::GtEq => {
157                let result = lhs.partial_cmp(rhs);
158                result == Some(Ordering::Greater) || result == Some(Ordering::Equal)
159            }
160            RestrictedOp::And | RestrictedOp::Or => unreachable!(),
161        };
162
163        Ok(result)
164    }
165
166    pub fn row_at(&self, cols: &[VectorRef], index: usize, row: &mut [Value]) -> Result<()> {
167        for (col_idx, col) in cols.iter().enumerate() {
168            row[col_idx] = col.get(index);
169        }
170        Ok(())
171    }
172
173    pub fn record_batch_to_cols(&self, record_batch: &RecordBatch) -> Result<Vec<VectorRef>> {
174        self.partition_columns
175            .iter()
176            .map(|col_name| {
177                record_batch
178                    .column_by_name(col_name)
179                    .context(UndefinedColumnSnafu { column: col_name })
180                    .and_then(|array| {
181                        Helper::try_into_vector(array).context(error::ConvertToVectorSnafu)
182                    })
183            })
184            .collect::<Result<Vec<_>>>()
185    }
186
187    pub fn split_record_batch_naive(
188        &self,
189        record_batch: &RecordBatch,
190    ) -> Result<HashMap<RegionNumber, BooleanArray>> {
191        let num_rows = record_batch.num_rows();
192
193        let mut result = self
194            .regions
195            .iter()
196            .map(|region| {
197                let mut builder = BooleanBufferBuilder::new(num_rows);
198                builder.append_n(num_rows, false);
199                (*region, builder)
200            })
201            .collect::<HashMap<_, _>>();
202
203        let cols = self.record_batch_to_cols(record_batch)?;
204        let mut current_row = vec![Value::Null; self.partition_columns.len()];
205        for row_idx in 0..num_rows {
206            self.row_at(&cols, row_idx, &mut current_row)?;
207            let current_region = self.find_region(&current_row)?;
208            let region_mask = result
209                .get_mut(&current_region)
210                .unwrap_or_else(|| panic!("Region {} must be initialized", current_region));
211            region_mask.set_bit(row_idx, true);
212        }
213
214        Ok(result
215            .into_iter()
216            .map(|(region, mut mask)| (region, BooleanArray::new(mask.finish(), None)))
217            .collect())
218    }
219
220    pub fn split_record_batch(
221        &self,
222        record_batch: &RecordBatch,
223    ) -> Result<HashMap<RegionNumber, RegionMask>> {
224        let num_rows = record_batch.num_rows();
225        if self.regions.len() == 1 {
226            return Ok([(
227                self.regions[0],
228                RegionMask::from(BooleanArray::from(vec![true; num_rows])),
229            )]
230            .into_iter()
231            .collect());
232        }
233        let physical_exprs = {
234            let cache_read_guard = self.physical_expr_cache.read().unwrap();
235            if let Some((cached_exprs, schema)) = cache_read_guard.as_ref()
236                && schema == record_batch.schema_ref()
237            {
238                cached_exprs.clone()
239            } else {
240                drop(cache_read_guard); // Release the read lock before acquiring write lock
241
242                let schema = record_batch.schema();
243                let new_cache = self
244                    .exprs
245                    .iter()
246                    .map(|e| e.try_as_physical_expr(&schema))
247                    .collect::<Result<Vec<_>>>()?;
248
249                let mut cache_write_guard = self.physical_expr_cache.write().unwrap();
250                cache_write_guard.replace((new_cache.clone(), schema));
251                new_cache
252            }
253        };
254
255        let mut result: HashMap<u32, RegionMask> = physical_exprs
256            .iter()
257            .zip(self.regions.iter())
258            .filter_map(|(expr, region_num)| {
259                let col_val = match expr
260                    .evaluate(record_batch)
261                    .context(error::EvaluateRecordBatchSnafu)
262                {
263                    Ok(array) => array,
264                    Err(e) => {
265                        return Some(Err(e));
266                    }
267                };
268                let array = match columnar_value_to_boolean_array(col_val, num_rows) {
269                    Ok(array) => array,
270                    Err(e) => {
271                        return Some(Err(e));
272                    }
273                };
274                let selected_rows = array.true_count();
275                if selected_rows == 0 {
276                    // skip empty region in results.
277                    return None;
278                }
279                Some(Ok((*region_num, RegionMask::new(array, selected_rows))))
280            })
281            .collect::<error::Result<_>>()?;
282
283        let selected = if result.len() == 1 {
284            result.values().next().unwrap().array().clone()
285        } else {
286            let mut selected = BooleanArray::new(BooleanBuffer::new_unset(num_rows), None);
287            for region_mask in result.values() {
288                selected = arrow::compute::kernels::boolean::or(&selected, region_mask.array())
289                    .context(error::ComputeArrowKernelSnafu)?;
290            }
291            selected
292        };
293
294        // fast path: all rows are selected
295        if selected.true_count() == num_rows {
296            return Ok(result);
297        }
298
299        // find unselected rows and assign to default region
300        let unselected = arrow::compute::kernels::boolean::not(&selected)
301            .context(error::ComputeArrowKernelSnafu)?;
302        match result.entry(DEFAULT_REGION) {
303            Entry::Occupied(mut o) => {
304                // merge default region with unselected rows.
305                let default_region_mask = RegionMask::from(
306                    arrow::compute::kernels::boolean::or(o.get().array(), &unselected)
307                        .context(error::ComputeArrowKernelSnafu)?,
308                );
309                o.insert(default_region_mask);
310            }
311            Entry::Vacant(v) => {
312                // default region has no rows, simply put all unselected rows to default region.
313                v.insert(RegionMask::from(unselected));
314            }
315        }
316        Ok(result)
317    }
318}
319
320fn columnar_value_to_boolean_array(
321    col_val: ColumnarValue,
322    num_rows: usize,
323) -> Result<BooleanArray> {
324    let column = col_val
325        .into_array(num_rows)
326        .context(error::EvaluateRecordBatchSnafu)?;
327    let array = column
328        .as_any()
329        .downcast_ref::<BooleanArray>()
330        .with_context(|| error::UnexpectedColumnTypeSnafu {
331            data_type: column.data_type().clone(),
332        })?;
333    Ok(array.clone())
334}
335
336impl PartitionRule for MultiDimPartitionRule {
337    fn as_any(&self) -> &dyn Any {
338        self
339    }
340
341    fn partition_columns(&self) -> Vec<String> {
342        self.partition_columns.clone()
343    }
344
345    fn find_region(&self, values: &[Value]) -> Result<RegionNumber> {
346        self.find_region(values)
347    }
348
349    fn split_record_batch(
350        &self,
351        record_batch: &RecordBatch,
352    ) -> Result<HashMap<RegionNumber, RegionMask>> {
353        self.split_record_batch(record_batch)
354    }
355}
356
357#[cfg(test)]
358mod tests {
359    use std::assert_matches::assert_matches;
360
361    use super::*;
362    use crate::error::{self, Error};
363    use crate::expr::col;
364
365    #[test]
366    fn test_find_region() {
367        // PARTITION ON COLUMNS (b) (
368        //     b < 'hz',
369        //     b >= 'hz' AND b < 'sh',
370        //     b >= 'sh'
371        // )
372        let rule = MultiDimPartitionRule::try_new(
373            vec!["b".to_string()],
374            vec![1, 2, 3],
375            vec![
376                PartitionExpr::new(
377                    Operand::Column("b".to_string()),
378                    RestrictedOp::Lt,
379                    Operand::Value(datatypes::value::Value::String("hz".into())),
380                ),
381                PartitionExpr::new(
382                    Operand::Expr(PartitionExpr::new(
383                        Operand::Column("b".to_string()),
384                        RestrictedOp::GtEq,
385                        Operand::Value(datatypes::value::Value::String("hz".into())),
386                    )),
387                    RestrictedOp::And,
388                    Operand::Expr(PartitionExpr::new(
389                        Operand::Column("b".to_string()),
390                        RestrictedOp::Lt,
391                        Operand::Value(datatypes::value::Value::String("sh".into())),
392                    )),
393                ),
394                PartitionExpr::new(
395                    Operand::Column("b".to_string()),
396                    RestrictedOp::GtEq,
397                    Operand::Value(datatypes::value::Value::String("sh".into())),
398                ),
399            ],
400            true,
401        )
402        .unwrap();
403        assert_matches!(
404            rule.find_region(&["foo".into(), 1000_i32.into()]),
405            Err(error::Error::RegionKeysSize {
406                expect: 1,
407                actual: 2,
408                ..
409            })
410        );
411        assert_matches!(rule.find_region(&["foo".into()]), Ok(1));
412        assert_matches!(rule.find_region(&["bar".into()]), Ok(1));
413        assert_matches!(rule.find_region(&["hz".into()]), Ok(2));
414        assert_matches!(rule.find_region(&["hzz".into()]), Ok(2));
415        assert_matches!(rule.find_region(&["sh".into()]), Ok(3));
416        assert_matches!(rule.find_region(&["zzzz".into()]), Ok(3));
417    }
418
419    #[test]
420    fn invalid_expr_case_1() {
421        // PARTITION ON COLUMNS (b) (
422        //     b <= b >= 'hz' AND b < 'sh',
423        // )
424        let rule = MultiDimPartitionRule::try_new(
425            vec!["a".to_string(), "b".to_string()],
426            vec![1],
427            vec![PartitionExpr::new(
428                Operand::Column("b".to_string()),
429                RestrictedOp::LtEq,
430                Operand::Expr(PartitionExpr::new(
431                    Operand::Expr(PartitionExpr::new(
432                        Operand::Column("b".to_string()),
433                        RestrictedOp::GtEq,
434                        Operand::Value(datatypes::value::Value::String("hz".into())),
435                    )),
436                    RestrictedOp::And,
437                    Operand::Expr(PartitionExpr::new(
438                        Operand::Column("b".to_string()),
439                        RestrictedOp::Lt,
440                        Operand::Value(datatypes::value::Value::String("sh".into())),
441                    )),
442                )),
443            )],
444            true,
445        );
446
447        // check rule
448        assert_matches!(rule.unwrap_err(), Error::InvalidExpr { .. });
449    }
450
451    #[test]
452    fn invalid_expr_case_2() {
453        // PARTITION ON COLUMNS (b) (
454        //     b >= 'hz' AND 'sh',
455        // )
456        let rule = MultiDimPartitionRule::try_new(
457            vec!["a".to_string(), "b".to_string()],
458            vec![1],
459            vec![PartitionExpr::new(
460                Operand::Expr(PartitionExpr::new(
461                    Operand::Column("b".to_string()),
462                    RestrictedOp::GtEq,
463                    Operand::Value(datatypes::value::Value::String("hz".into())),
464                )),
465                RestrictedOp::And,
466                Operand::Value(datatypes::value::Value::String("sh".into())),
467            )],
468            true,
469        );
470
471        // check rule
472        assert_matches!(rule.unwrap_err(), Error::InvalidExpr { .. });
473    }
474
475    /// ```ignore
476    ///          │          │
477    ///          │          │
478    /// ─────────┼──────────┼────────────► b
479    ///          │          │
480    ///          │          │
481    ///      b <= h     b >= s
482    /// ```
483    #[test]
484    fn empty_expr_case_1() {
485        // PARTITION ON COLUMNS (b) (
486        //     b <= 'h',
487        //     b >= 's'
488        // )
489        let rule = MultiDimPartitionRule::try_new(
490            vec!["a".to_string(), "b".to_string()],
491            vec![1, 2],
492            vec![
493                PartitionExpr::new(
494                    Operand::Column("b".to_string()),
495                    RestrictedOp::LtEq,
496                    Operand::Value(datatypes::value::Value::String("h".into())),
497                ),
498                PartitionExpr::new(
499                    Operand::Column("b".to_string()),
500                    RestrictedOp::GtEq,
501                    Operand::Value(datatypes::value::Value::String("s".into())),
502                ),
503            ],
504            true,
505        );
506
507        // check rule
508        assert_matches!(rule.unwrap_err(), Error::CheckpointNotCovered { .. });
509    }
510
511    /// ```
512    ///     a
513    ///     ▲
514    ///     │                   ‖
515    ///     │                   ‖
516    /// 200 │         ┌─────────┤
517    ///     │         │         │
518    ///     │         │         │
519    ///     │         │         │
520    /// 100 │   ======┴─────────┘
521    ///     │
522    ///     └──────────────────────────►b
523    ///              10          20
524    /// ```
525    #[test]
526    fn empty_expr_case_2() {
527        // PARTITION ON COLUMNS (b) (
528        //     a >= 100 AND b <= 10  OR  a > 100 AND a <= 200 AND b <= 10  OR  a >= 200 AND b > 10 AND b <= 20  OR  a > 200 AND b <= 20
529        //     a < 100 AND b <= 20  OR  a >= 100 AND b > 20
530        // )
531        let rule = MultiDimPartitionRule::try_new(
532            vec!["a".to_string(), "b".to_string()],
533            vec![1, 2],
534            vec![
535                PartitionExpr::new(
536                    Operand::Expr(PartitionExpr::new(
537                        Operand::Expr(PartitionExpr::new(
538                            //  a >= 100 AND b <= 10
539                            Operand::Expr(PartitionExpr::new(
540                                Operand::Expr(PartitionExpr::new(
541                                    Operand::Column("a".to_string()),
542                                    RestrictedOp::GtEq,
543                                    Operand::Value(datatypes::value::Value::Int64(100)),
544                                )),
545                                RestrictedOp::And,
546                                Operand::Expr(PartitionExpr::new(
547                                    Operand::Column("b".to_string()),
548                                    RestrictedOp::LtEq,
549                                    Operand::Value(datatypes::value::Value::Int64(10)),
550                                )),
551                            )),
552                            RestrictedOp::Or,
553                            // a > 100 AND a <= 200 AND b <= 10
554                            Operand::Expr(PartitionExpr::new(
555                                Operand::Expr(PartitionExpr::new(
556                                    Operand::Expr(PartitionExpr::new(
557                                        Operand::Column("a".to_string()),
558                                        RestrictedOp::Gt,
559                                        Operand::Value(datatypes::value::Value::Int64(100)),
560                                    )),
561                                    RestrictedOp::And,
562                                    Operand::Expr(PartitionExpr::new(
563                                        Operand::Column("a".to_string()),
564                                        RestrictedOp::LtEq,
565                                        Operand::Value(datatypes::value::Value::Int64(200)),
566                                    )),
567                                )),
568                                RestrictedOp::And,
569                                Operand::Expr(PartitionExpr::new(
570                                    Operand::Column("b".to_string()),
571                                    RestrictedOp::LtEq,
572                                    Operand::Value(datatypes::value::Value::Int64(10)),
573                                )),
574                            )),
575                        )),
576                        RestrictedOp::Or,
577                        // a >= 200 AND b > 10 AND b <= 20
578                        Operand::Expr(PartitionExpr::new(
579                            Operand::Expr(PartitionExpr::new(
580                                Operand::Expr(PartitionExpr::new(
581                                    Operand::Column("a".to_string()),
582                                    RestrictedOp::GtEq,
583                                    Operand::Value(datatypes::value::Value::Int64(200)),
584                                )),
585                                RestrictedOp::And,
586                                Operand::Expr(PartitionExpr::new(
587                                    Operand::Column("b".to_string()),
588                                    RestrictedOp::Gt,
589                                    Operand::Value(datatypes::value::Value::Int64(10)),
590                                )),
591                            )),
592                            RestrictedOp::And,
593                            Operand::Expr(PartitionExpr::new(
594                                Operand::Column("b".to_string()),
595                                RestrictedOp::LtEq,
596                                Operand::Value(datatypes::value::Value::Int64(20)),
597                            )),
598                        )),
599                    )),
600                    RestrictedOp::Or,
601                    // a > 200 AND b <= 20
602                    Operand::Expr(PartitionExpr::new(
603                        Operand::Expr(PartitionExpr::new(
604                            Operand::Column("a".to_string()),
605                            RestrictedOp::Gt,
606                            Operand::Value(datatypes::value::Value::Int64(200)),
607                        )),
608                        RestrictedOp::And,
609                        Operand::Expr(PartitionExpr::new(
610                            Operand::Column("b".to_string()),
611                            RestrictedOp::LtEq,
612                            Operand::Value(datatypes::value::Value::Int64(20)),
613                        )),
614                    )),
615                ),
616                PartitionExpr::new(
617                    // a < 100 AND b <= 20
618                    Operand::Expr(PartitionExpr::new(
619                        Operand::Expr(PartitionExpr::new(
620                            Operand::Column("a".to_string()),
621                            RestrictedOp::Lt,
622                            Operand::Value(datatypes::value::Value::Int64(100)),
623                        )),
624                        RestrictedOp::And,
625                        Operand::Expr(PartitionExpr::new(
626                            Operand::Column("b".to_string()),
627                            RestrictedOp::LtEq,
628                            Operand::Value(datatypes::value::Value::Int64(20)),
629                        )),
630                    )),
631                    RestrictedOp::Or,
632                    // a >= 100 AND b > 20
633                    Operand::Expr(PartitionExpr::new(
634                        Operand::Expr(PartitionExpr::new(
635                            Operand::Column("a".to_string()),
636                            RestrictedOp::GtEq,
637                            Operand::Value(datatypes::value::Value::Int64(100)),
638                        )),
639                        RestrictedOp::And,
640                        Operand::Expr(PartitionExpr::new(
641                            Operand::Column("b".to_string()),
642                            RestrictedOp::GtEq,
643                            Operand::Value(datatypes::value::Value::Int64(20)),
644                        )),
645                    )),
646                ),
647            ],
648            true,
649        );
650
651        // check rule
652        assert_matches!(rule.unwrap_err(), Error::CheckpointNotCovered { .. });
653    }
654
655    #[test]
656    fn duplicate_expr_case_1() {
657        // PARTITION ON COLUMNS (a) (
658        //     a <= 20,
659        //     a >= 10
660        // )
661        let rule = MultiDimPartitionRule::try_new(
662            vec!["a".to_string(), "b".to_string()],
663            vec![1, 2],
664            vec![
665                PartitionExpr::new(
666                    Operand::Column("a".to_string()),
667                    RestrictedOp::LtEq,
668                    Operand::Value(datatypes::value::Value::Int64(20)),
669                ),
670                PartitionExpr::new(
671                    Operand::Column("a".to_string()),
672                    RestrictedOp::GtEq,
673                    Operand::Value(datatypes::value::Value::Int64(10)),
674                ),
675            ],
676            true,
677        );
678
679        // check rule
680        assert_matches!(rule.unwrap_err(), Error::CheckpointOverlapped { .. });
681    }
682
683    #[test]
684    fn duplicate_expr_case_2() {
685        // PARTITION ON COLUMNS (a) (
686        //     a != 20,
687        //     a <= 20,
688        //     a > 20,
689        // )
690        let rule = MultiDimPartitionRule::try_new(
691            vec!["a".to_string(), "b".to_string()],
692            vec![1, 2],
693            vec![
694                PartitionExpr::new(
695                    Operand::Column("a".to_string()),
696                    RestrictedOp::NotEq,
697                    Operand::Value(datatypes::value::Value::Int64(20)),
698                ),
699                PartitionExpr::new(
700                    Operand::Column("a".to_string()),
701                    RestrictedOp::LtEq,
702                    Operand::Value(datatypes::value::Value::Int64(20)),
703                ),
704                PartitionExpr::new(
705                    Operand::Column("a".to_string()),
706                    RestrictedOp::Gt,
707                    Operand::Value(datatypes::value::Value::Int64(20)),
708                ),
709            ],
710            true,
711        );
712
713        // check rule
714        assert_matches!(rule.unwrap_err(), Error::CheckpointOverlapped { .. });
715    }
716
717    /// ```ignore
718    /// value
719    ///                                 │
720    ///                                 │
721    ///    value=10 --------------------│
722    ///                                 │
723    /// ────────────────────────────────┼──► host
724    ///                                 │
725    ///                             host=server10
726    /// ```
727    #[test]
728    fn test_partial_divided() {
729        let _rule = MultiDimPartitionRule::try_new(
730            vec!["host".to_string(), "value".to_string()],
731            vec![0, 1, 2, 3],
732            vec![
733                col("host")
734                    .lt(Value::String("server10".into()))
735                    .and(col("value").lt(Value::Int64(10))),
736                col("host")
737                    .lt(Value::String("server10".into()))
738                    .and(col("value").gt_eq(Value::Int64(10))),
739                col("host").gt_eq(Value::String("server10".into())),
740            ],
741            true,
742        )
743        .unwrap();
744    }
745}
746
747#[cfg(test)]
748mod test_split_record_batch {
749    use std::sync::Arc;
750
751    use datafusion_common::ScalarValue;
752    use datatypes::arrow::array::{Int64Array, StringArray};
753    use datatypes::arrow::datatypes::{DataType, Field, Schema};
754    use datatypes::arrow::record_batch::RecordBatch;
755    use rand::Rng;
756
757    use super::*;
758    use crate::expr::{Operand, col};
759
760    fn test_schema() -> Arc<Schema> {
761        Arc::new(Schema::new(vec![
762            Field::new("host", DataType::Utf8, false),
763            Field::new("value", DataType::Int64, false),
764        ]))
765    }
766
767    fn generate_random_record_batch(num_rows: usize) -> RecordBatch {
768        let schema = test_schema();
769        let mut rng = rand::thread_rng();
770        let mut host_array = Vec::with_capacity(num_rows);
771        let mut value_array = Vec::with_capacity(num_rows);
772        for _ in 0..num_rows {
773            host_array.push(format!("server{}", rng.gen_range(0..20)));
774            value_array.push(rng.gen_range(0..20));
775        }
776        let host_array = StringArray::from(host_array);
777        let value_array = Int64Array::from(value_array);
778        RecordBatch::try_new(schema, vec![Arc::new(host_array), Arc::new(value_array)]).unwrap()
779    }
780
781    #[test]
782    fn test_split_record_batch_by_one_column() {
783        // Create a simple MultiDimPartitionRule
784        let rule = MultiDimPartitionRule::try_new(
785            vec!["host".to_string(), "value".to_string()],
786            vec![0, 1],
787            vec![
788                col("host").lt(Value::String("server1".into())),
789                col("host").gt_eq(Value::String("server1".into())),
790            ],
791            true,
792        )
793        .unwrap();
794
795        let batch = generate_random_record_batch(1000);
796        // Split the batch
797        let result = rule.split_record_batch(&batch).unwrap();
798        let expected = rule.split_record_batch_naive(&batch).unwrap();
799        assert_eq!(result.len(), expected.len());
800        for (region, value) in &result {
801            assert_eq!(
802                value.array(),
803                expected.get(region).unwrap(),
804                "failed on region: {}",
805                region
806            );
807        }
808    }
809
810    #[test]
811    fn test_split_record_batch_empty() {
812        // Create a simple MultiDimPartitionRule
813        let rule = MultiDimPartitionRule::try_new(
814            vec!["host".to_string()],
815            vec![1],
816            vec![
817                col("host").lt(Value::String("server1".into())),
818                col("host").gt_eq(Value::String("server1".into())),
819            ],
820            true,
821        )
822        .unwrap();
823
824        let schema = test_schema();
825        let host_array = StringArray::from(Vec::<&str>::new());
826        let value_array = Int64Array::from(Vec::<i64>::new());
827        let batch = RecordBatch::try_new(schema, vec![Arc::new(host_array), Arc::new(value_array)])
828            .unwrap();
829
830        let result = rule.split_record_batch(&batch).unwrap();
831        assert_eq!(result.len(), 1);
832    }
833
834    #[test]
835    fn test_split_record_batch_by_two_columns() {
836        let rule = MultiDimPartitionRule::try_new(
837            vec!["host".to_string(), "value".to_string()],
838            vec![0, 1, 2, 3],
839            vec![
840                col("host")
841                    .lt(Value::String("server10".into()))
842                    .and(col("value").lt(Value::Int64(10))),
843                col("host")
844                    .lt(Value::String("server10".into()))
845                    .and(col("value").gt_eq(Value::Int64(10))),
846                col("host")
847                    .gt_eq(Value::String("server10".into()))
848                    .and(col("value").lt(Value::Int64(10))),
849                col("host")
850                    .gt_eq(Value::String("server10".into()))
851                    .and(col("value").gt_eq(Value::Int64(10))),
852            ],
853            true,
854        )
855        .unwrap();
856
857        let batch = generate_random_record_batch(1000);
858        let result = rule.split_record_batch(&batch).unwrap();
859        let expected = rule.split_record_batch_naive(&batch).unwrap();
860        assert_eq!(result.len(), expected.len());
861        for (region, value) in &result {
862            assert_eq!(value.array(), expected.get(region).unwrap());
863        }
864    }
865
866    #[test]
867    fn test_all_rows_selected() {
868        // Test the fast path where all rows are selected by some partition
869        let rule = MultiDimPartitionRule::try_new(
870            vec!["value".to_string()],
871            vec![1, 2],
872            vec![
873                col("value").lt(Value::Int64(30)),
874                col("value").gt_eq(Value::Int64(30)),
875            ],
876            true,
877        )
878        .unwrap();
879
880        let schema = test_schema();
881        let host_array = StringArray::from(vec!["server1", "server2", "server3", "server4"]);
882        let value_array = Int64Array::from(vec![10, 20, 30, 40]);
883        let batch = RecordBatch::try_new(schema, vec![Arc::new(host_array), Arc::new(value_array)])
884            .unwrap();
885
886        let result = rule.split_record_batch(&batch).unwrap();
887
888        // Check that we have 2 regions and no default region
889        assert_eq!(result.len(), 2);
890        assert!(result.contains_key(&1));
891        assert!(result.contains_key(&2));
892
893        // Verify each region has the correct number of rows
894        assert_eq!(result.get(&1).unwrap().selected_rows(), 2); // values < 30
895        assert_eq!(result.get(&2).unwrap().selected_rows(), 2); // values >= 30
896    }
897
898    #[test]
899    fn test_split_record_batch_with_scalar_predicate() {
900        // Ensure split handles conjunctive/disjunctive predicates on the same column.
901        let rule = MultiDimPartitionRule::try_new(
902            vec!["host".to_string()],
903            vec![0, 1],
904            vec![
905                PartitionExpr::new(
906                    Operand::Column("host".to_string()),
907                    RestrictedOp::Lt,
908                    Operand::Value(Value::String("never_happen_1".into())),
909                ),
910                PartitionExpr::new(
911                    Operand::Expr(PartitionExpr::new(
912                        Operand::Column("host".to_string()),
913                        RestrictedOp::GtEq,
914                        Operand::Value(Value::String("never_happen_1".into())),
915                    )),
916                    RestrictedOp::And,
917                    Operand::Value(Value::Boolean(false)),
918                ),
919            ],
920            false,
921        )
922        .unwrap();
923
924        let batch = generate_random_record_batch(8);
925        let result = rule.split_record_batch(&batch).unwrap();
926
927        assert_eq!(result.len(), 1);
928        assert!(result.contains_key(&0));
929
930        let total_rows = result.get(&0).unwrap().selected_rows();
931        assert_eq!(total_rows, batch.num_rows());
932    }
933
934    #[test]
935    fn test_columnar_value_to_boolean_array_scalar_false() {
936        let result = columnar_value_to_boolean_array(
937            ColumnarValue::Scalar(ScalarValue::Boolean(Some(false))),
938            4,
939        )
940        .unwrap();
941        assert_eq!(result.len(), 4);
942        assert_eq!(result.true_count(), 0);
943    }
944
945    #[test]
946    fn test_columnar_value_to_boolean_array_scalar_true() {
947        let result = columnar_value_to_boolean_array(
948            ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))),
949            4,
950        )
951        .unwrap();
952        assert_eq!(result.len(), 4);
953        assert_eq!(result.true_count(), 4);
954    }
955}