1use std::any::Any;
16use std::cmp::Ordering;
17use std::collections::HashMap;
18use std::sync::{Arc, RwLock};
19
20use datafusion_expr::ColumnarValue;
21use datafusion_physical_expr::PhysicalExpr;
22use datatypes::arrow;
23use datatypes::arrow::array::{BooleanArray, BooleanBufferBuilder, RecordBatch};
24use datatypes::arrow::buffer::BooleanBuffer;
25use datatypes::arrow::datatypes::Schema;
26use datatypes::prelude::Value;
27use datatypes::vectors::{Helper, VectorRef};
28use serde::{Deserialize, Serialize};
29use snafu::{ensure, OptionExt, ResultExt};
30use store_api::storage::RegionNumber;
31
32use crate::error::{
33 self, ConjunctExprWithNonExprSnafu, InvalidExprSnafu, Result, UnclosedValueSnafu,
34 UndefinedColumnSnafu,
35};
36use crate::expr::{Operand, PartitionExpr, RestrictedOp};
37use crate::PartitionRule;
38
/// Region returned for rows that match none of the partition expressions.
const DEFAULT_REGION: RegionNumber = 0;

/// Physical expressions built from the rule's partition expressions, paired with the
/// schema they were built against. `None` until a record batch has been split.
type PhysicalExprCache = Option<(Vec<Arc<dyn PhysicalExpr>>, Arc<Schema>)>;
43
/// Partition rule that assigns rows to regions by evaluating one [`PartitionExpr`]
/// per region against the values of the partition columns.
///
/// `exprs` and `regions` are paired by index: a row that satisfies `exprs[i]` is
/// routed to `regions[i]`.
#[derive(Debug, Serialize, Deserialize)]
pub struct MultiDimPartitionRule {
    /// Names of the columns the rule partitions on.
    partition_columns: Vec<String>,
    /// Maps each partition column name to its position in `partition_columns`.
    name_to_index: HashMap<String, usize>,
    /// Target regions, paired by index with `exprs`.
    regions: Vec<RegionNumber>,
    /// Partition expressions, paired by index with `regions`.
    exprs: Vec<PartitionExpr>,
    /// Physical expressions cached together with the schema they were built for.
    /// Skipped by serde, so a deserialized rule starts without a cache.
    #[serde(skip)]
    physical_expr_cache: RwLock<PhysicalExprCache>,
}
64
65impl MultiDimPartitionRule {
66 pub fn try_new(
67 partition_columns: Vec<String>,
68 regions: Vec<RegionNumber>,
69 exprs: Vec<PartitionExpr>,
70 ) -> Result<Self> {
71 let name_to_index = partition_columns
72 .iter()
73 .enumerate()
74 .map(|(i, name)| (name.clone(), i))
75 .collect::<HashMap<_, _>>();
76
77 let rule = Self {
78 partition_columns,
79 name_to_index,
80 regions,
81 exprs,
82 physical_expr_cache: RwLock::new(None),
83 };
84
85 let mut checker = RuleChecker::new(&rule);
86 checker.check()?;
87
88 Ok(rule)
89 }
90
91 fn find_region(&self, values: &[Value]) -> Result<RegionNumber> {
92 ensure!(
93 values.len() == self.partition_columns.len(),
94 error::RegionKeysSizeSnafu {
95 expect: self.partition_columns.len(),
96 actual: values.len(),
97 }
98 );
99
100 for (region_index, expr) in self.exprs.iter().enumerate() {
101 if self.evaluate_expr(expr, values)? {
102 return Ok(self.regions[region_index]);
103 }
104 }
105
106 Ok(DEFAULT_REGION)
108 }
109
110 fn evaluate_expr(&self, expr: &PartitionExpr, values: &[Value]) -> Result<bool> {
111 match (expr.lhs.as_ref(), expr.rhs.as_ref()) {
112 (Operand::Column(name), Operand::Value(r)) => {
113 let index = self.name_to_index.get(name).unwrap();
114 let l = &values[*index];
115 Self::perform_op(l, &expr.op, r)
116 }
117 (Operand::Value(l), Operand::Column(name)) => {
118 let index = self.name_to_index.get(name).unwrap();
119 let r = &values[*index];
120 Self::perform_op(l, &expr.op, r)
121 }
122 (Operand::Expr(lhs), Operand::Expr(rhs)) => {
123 let lhs = self.evaluate_expr(lhs, values)?;
124 let rhs = self.evaluate_expr(rhs, values)?;
125 match expr.op {
126 RestrictedOp::And => Ok(lhs && rhs),
127 RestrictedOp::Or => Ok(lhs || rhs),
128 _ => unreachable!(),
129 }
130 }
131 _ => unreachable!(),
132 }
133 }
134
135 fn perform_op(lhs: &Value, op: &RestrictedOp, rhs: &Value) -> Result<bool> {
136 let result = match op {
137 RestrictedOp::Eq => lhs.eq(rhs),
138 RestrictedOp::NotEq => lhs.ne(rhs),
139 RestrictedOp::Lt => lhs.partial_cmp(rhs) == Some(Ordering::Less),
140 RestrictedOp::LtEq => {
141 let result = lhs.partial_cmp(rhs);
142 result == Some(Ordering::Less) || result == Some(Ordering::Equal)
143 }
144 RestrictedOp::Gt => lhs.partial_cmp(rhs) == Some(Ordering::Greater),
145 RestrictedOp::GtEq => {
146 let result = lhs.partial_cmp(rhs);
147 result == Some(Ordering::Greater) || result == Some(Ordering::Equal)
148 }
149 RestrictedOp::And | RestrictedOp::Or => unreachable!(),
150 };
151
152 Ok(result)
153 }
154
155 pub fn row_at(&self, cols: &[VectorRef], index: usize, row: &mut [Value]) -> Result<()> {
156 for (col_idx, col) in cols.iter().enumerate() {
157 row[col_idx] = col.get(index);
158 }
159 Ok(())
160 }
161
162 pub fn record_batch_to_cols(&self, record_batch: &RecordBatch) -> Result<Vec<VectorRef>> {
163 self.partition_columns
164 .iter()
165 .map(|col_name| {
166 record_batch
167 .column_by_name(col_name)
168 .context(error::UndefinedColumnSnafu { column: col_name })
169 .and_then(|array| {
170 Helper::try_into_vector(array).context(error::ConvertToVectorSnafu)
171 })
172 })
173 .collect::<Result<Vec<_>>>()
174 }
175
176 pub fn split_record_batch_naive(
177 &self,
178 record_batch: &RecordBatch,
179 ) -> Result<HashMap<RegionNumber, BooleanArray>> {
180 let num_rows = record_batch.num_rows();
181
182 let mut result = self
183 .regions
184 .iter()
185 .map(|region| {
186 let mut builder = BooleanBufferBuilder::new(num_rows);
187 builder.append_n(num_rows, false);
188 (*region, builder)
189 })
190 .collect::<HashMap<_, _>>();
191
192 let cols = self.record_batch_to_cols(record_batch)?;
193 let mut current_row = vec![Value::Null; self.partition_columns.len()];
194 for row_idx in 0..num_rows {
195 self.row_at(&cols, row_idx, &mut current_row)?;
196 let current_region = self.find_region(¤t_row)?;
197 let region_mask = result
198 .get_mut(¤t_region)
199 .unwrap_or_else(|| panic!("Region {} must be initialized", current_region));
200 region_mask.set_bit(row_idx, true);
201 }
202
203 Ok(result
204 .into_iter()
205 .map(|(region, mut mask)| (region, BooleanArray::new(mask.finish(), None)))
206 .collect())
207 }
208
209 pub fn split_record_batch(
210 &self,
211 record_batch: &RecordBatch,
212 ) -> Result<HashMap<RegionNumber, BooleanArray>> {
213 let num_rows = record_batch.num_rows();
214 let physical_exprs = {
215 let cache_read_guard = self.physical_expr_cache.read().unwrap();
216 if let Some((cached_exprs, schema)) = cache_read_guard.as_ref()
217 && schema == record_batch.schema_ref()
218 {
219 cached_exprs.clone()
220 } else {
221 drop(cache_read_guard); let schema = record_batch.schema();
224 let new_cache = self
225 .exprs
226 .iter()
227 .map(|e| e.try_as_physical_expr(&schema))
228 .collect::<Result<Vec<_>>>()?;
229
230 let mut cache_write_guard = self.physical_expr_cache.write().unwrap();
231 cache_write_guard.replace((new_cache.clone(), schema));
232 new_cache
233 }
234 };
235
236 let mut result: HashMap<u32, BooleanArray> = physical_exprs
237 .iter()
238 .zip(self.regions.iter())
239 .map(|(expr, region_num)| {
240 let ColumnarValue::Array(column) = expr
241 .evaluate(record_batch)
242 .context(error::EvaluateRecordBatchSnafu)?
243 else {
244 unreachable!("Expected an array")
245 };
246 Ok((
247 *region_num,
248 column
249 .as_any()
250 .downcast_ref::<BooleanArray>()
251 .with_context(|| error::UnexpectedColumnTypeSnafu {
252 data_type: column.data_type().clone(),
253 })?
254 .clone(),
255 ))
256 })
257 .collect::<error::Result<_>>()?;
258
259 let mut selected = BooleanArray::new(BooleanBuffer::new_unset(num_rows), None);
260 for region_selection in result.values() {
261 selected = arrow::compute::kernels::boolean::or(&selected, region_selection)
262 .context(error::ComputeArrowKernelSnafu)?;
263 }
264
265 if selected.true_count() == num_rows {
267 return Ok(result);
268 }
269
270 let unselected = arrow::compute::kernels::boolean::not(&selected)
272 .context(error::ComputeArrowKernelSnafu)?;
273 let default_region_selection = result
274 .entry(DEFAULT_REGION)
275 .or_insert_with(|| unselected.clone());
276 *default_region_selection =
277 arrow::compute::kernels::boolean::or(default_region_selection, &unselected)
278 .context(error::ComputeArrowKernelSnafu)?;
279 Ok(result)
280 }
281}
282
283impl PartitionRule for MultiDimPartitionRule {
284 fn as_any(&self) -> &dyn Any {
285 self
286 }
287
288 fn partition_columns(&self) -> Vec<String> {
289 self.partition_columns.clone()
290 }
291
292 fn find_region(&self, values: &[Value]) -> Result<RegionNumber> {
293 self.find_region(values)
294 }
295
296 fn split_record_batch(
297 &self,
298 record_batch: &RecordBatch,
299 ) -> Result<HashMap<RegionNumber, BooleanArray>> {
300 self.split_record_batch(record_batch)
301 }
302}
303
/// Split points collected for a single partition column, keyed by the literal
/// value the column is compared against.
type Axis = HashMap<Value, SplitPoint>;

/// Bookkeeping for one literal value on an [`Axis`], filled in by
/// `RuleChecker::walk_expr`.
struct SplitPoint {
    /// Set once an `Eq`, `LtEq` or `GtEq` comparison references the value.
    is_equal: bool,
    /// Incremented by `Lt`/`LtEq` comparisons and decremented by `Gt`/`GtEq`
    /// comparisons; `check_axis` requires it to be zero.
    less_than_counter: isize,
}
312
/// Validates the expressions of a [`MultiDimPartitionRule`] by recording, per
/// partition column, how each literal value is used in comparisons.
struct RuleChecker<'a> {
    /// One [`Axis`] per partition column, indexed like `rule.partition_columns`.
    axis: Vec<Axis>,
    /// The rule under inspection.
    rule: &'a MultiDimPartitionRule,
}
324
325impl<'a> RuleChecker<'a> {
326 pub fn new(rule: &'a MultiDimPartitionRule) -> Self {
327 let mut projections = Vec::with_capacity(rule.partition_columns.len());
328 projections.resize_with(rule.partition_columns.len(), Default::default);
329
330 Self {
331 axis: projections,
332 rule,
333 }
334 }
335
336 pub fn check(&mut self) -> Result<()> {
337 for expr in &self.rule.exprs {
338 self.walk_expr(expr)?
339 }
340
341 self.check_axis()
342 }
343
344 #[allow(clippy::mutable_key_type)]
345 fn walk_expr(&mut self, expr: &PartitionExpr) -> Result<()> {
346 match expr.op {
348 RestrictedOp::And | RestrictedOp::Or => {
349 match (expr.lhs.as_ref(), expr.rhs.as_ref()) {
350 (Operand::Expr(lhs), Operand::Expr(rhs)) => {
351 self.walk_expr(lhs)?;
352 self.walk_expr(rhs)?
353 }
354 _ => ConjunctExprWithNonExprSnafu { expr: expr.clone() }.fail()?,
355 }
356
357 return Ok(());
358 }
359 _ => {}
361 }
362
363 let (col, val) = match (expr.lhs.as_ref(), expr.rhs.as_ref()) {
364 (Operand::Expr(_), _)
365 | (_, Operand::Expr(_))
366 | (Operand::Column(_), Operand::Column(_))
367 | (Operand::Value(_), Operand::Value(_)) => {
368 InvalidExprSnafu { expr: expr.clone() }.fail()?
369 }
370
371 (Operand::Column(col), Operand::Value(val))
372 | (Operand::Value(val), Operand::Column(col)) => (col, val),
373 };
374
375 let col_index =
376 *self
377 .rule
378 .name_to_index
379 .get(col)
380 .with_context(|| UndefinedColumnSnafu {
381 column: col.clone(),
382 })?;
383 let axis = &mut self.axis[col_index];
384 let split_point = axis.entry(val.clone()).or_insert(SplitPoint {
385 is_equal: false,
386 less_than_counter: 0,
387 });
388 match expr.op {
389 RestrictedOp::Eq => {
390 split_point.is_equal = true;
391 }
392 RestrictedOp::NotEq => {
393 }
395 RestrictedOp::Lt => {
396 split_point.less_than_counter += 1;
397 }
398 RestrictedOp::LtEq => {
399 split_point.less_than_counter += 1;
400 split_point.is_equal = true;
401 }
402 RestrictedOp::Gt => {
403 split_point.less_than_counter -= 1;
404 }
405 RestrictedOp::GtEq => {
406 split_point.less_than_counter -= 1;
407 split_point.is_equal = true;
408 }
409 RestrictedOp::And | RestrictedOp::Or => {
410 unreachable!("conjunct expr should be handled above")
411 }
412 }
413
414 Ok(())
415 }
416
417 fn check_axis(&self) -> Result<()> {
419 for (col_index, axis) in self.axis.iter().enumerate() {
420 for (val, split_point) in axis {
421 if split_point.less_than_counter != 0 || !split_point.is_equal {
422 UnclosedValueSnafu {
423 value: format!("{val:?}"),
424 column: self.rule.partition_columns[col_index].clone(),
425 }
426 .fail()?;
427 }
428 }
429 }
430 Ok(())
431 }
432}
433
434#[cfg(test)]
435mod tests {
436 use std::assert_matches::assert_matches;
437
438 use super::*;
439 use crate::error::{self, Error};
440
/// Routes single-column string keys through a three-way range rule.
#[test]
fn test_find_region() {
    // Builds `b <op> '<bound>'`.
    let b_cmp = |op: RestrictedOp, bound: &'static str| {
        PartitionExpr::new(
            Operand::Column("b".to_string()),
            op,
            Operand::Value(datatypes::value::Value::String(bound.into())),
        )
    };

    // Region 1: b < 'hz'
    // Region 2: b >= 'hz' AND b < 'sh'
    // Region 3: b >= 'sh'
    let exprs = vec![
        b_cmp(RestrictedOp::Lt, "hz"),
        PartitionExpr::new(
            Operand::Expr(b_cmp(RestrictedOp::GtEq, "hz")),
            RestrictedOp::And,
            Operand::Expr(b_cmp(RestrictedOp::Lt, "sh")),
        ),
        b_cmp(RestrictedOp::GtEq, "sh"),
    ];
    let rule =
        MultiDimPartitionRule::try_new(vec!["b".to_string()], vec![1, 2, 3], exprs).unwrap();

    // Two keys for a single-column rule must be rejected.
    assert_matches!(
        rule.find_region(&["foo".into(), 1000_i32.into()]),
        Err(error::Error::RegionKeysSize {
            expect: 1,
            actual: 2,
            ..
        })
    );

    let cases = [
        ("foo", 1),
        ("bar", 1),
        ("hz", 2),
        ("hzz", 2),
        ("sh", 3),
        ("zzzz", 3),
    ];
    for (key, region) in cases {
        assert_matches!(
            rule.find_region(&[key.into()]),
            Ok(found) if found == region,
            "key: {}",
            key
        );
    }
}
493
/// A comparison whose right operand is itself an expression is invalid.
#[test]
fn invalid_expr_case_1() {
    // Builds `b <op> '<bound>'`.
    let b_cmp = |op: RestrictedOp, bound: &'static str| {
        PartitionExpr::new(
            Operand::Column("b".to_string()),
            op,
            Operand::Value(datatypes::value::Value::String(bound.into())),
        )
    };

    // b <= (b >= 'hz' AND b < 'sh')
    let range = PartitionExpr::new(
        Operand::Expr(b_cmp(RestrictedOp::GtEq, "hz")),
        RestrictedOp::And,
        Operand::Expr(b_cmp(RestrictedOp::Lt, "sh")),
    );
    let malformed = PartitionExpr::new(
        Operand::Column("b".to_string()),
        RestrictedOp::LtEq,
        Operand::Expr(range),
    );

    let rule = MultiDimPartitionRule::try_new(
        vec!["a".to_string(), "b".to_string()],
        vec![1],
        vec![malformed],
    );

    assert_matches!(rule.unwrap_err(), Error::InvalidExpr { .. });
}
524
/// `AND` with a plain value on one side is rejected.
#[test]
fn invalid_expr_case_2() {
    // (b >= 'hz') AND 'sh'
    let comparison = PartitionExpr::new(
        Operand::Column("b".to_string()),
        RestrictedOp::GtEq,
        Operand::Value(datatypes::value::Value::String("hz".into())),
    );
    let malformed = PartitionExpr::new(
        Operand::Expr(comparison),
        RestrictedOp::And,
        Operand::Value(datatypes::value::Value::String("sh".into())),
    );

    let rule = MultiDimPartitionRule::try_new(
        vec!["a".to_string(), "b".to_string()],
        vec![1],
        vec![malformed],
    );

    assert_matches!(rule.unwrap_err(), Error::ConjunctExprWithNonExpr { .. });
}
547
/// Two ranges that leave a gap between their bounds are reported as unclosed.
#[test]
fn empty_expr_case_1() {
    // Builds `b <op> '<bound>'`.
    let b_cmp = |op: RestrictedOp, bound: &'static str| {
        PartitionExpr::new(
            Operand::Column("b".to_string()),
            op,
            Operand::Value(datatypes::value::Value::String(bound.into())),
        )
    };

    // b <= 'h' and b >= 's'
    let exprs = vec![
        b_cmp(RestrictedOp::LtEq, "h"),
        b_cmp(RestrictedOp::GtEq, "s"),
    ];
    let rule = MultiDimPartitionRule::try_new(
        vec!["a".to_string(), "b".to_string()],
        vec![1, 2],
        exprs,
    );

    assert_matches!(rule.unwrap_err(), Error::UnclosedValue { .. });
}
582
/// A two-column rule built from nested `AND`/`OR` clauses whose split points do
/// not all close is rejected.
#[test]
fn empty_expr_case_2() {
    // Builds `<column> <op> <bound>` with an Int64 literal.
    fn cmp(column: &str, op: RestrictedOp, bound: i64) -> PartitionExpr {
        PartitionExpr::new(
            Operand::Column(column.to_string()),
            op,
            Operand::Value(datatypes::value::Value::Int64(bound)),
        )
    }

    // Builds `(<lhs>) <op> (<rhs>)`.
    fn join(lhs: PartitionExpr, op: RestrictedOp, rhs: PartitionExpr) -> PartitionExpr {
        PartitionExpr::new(Operand::Expr(lhs), op, Operand::Expr(rhs))
    }

    // a >= 100 AND b <= 10
    let clause_1 = join(
        cmp("a", RestrictedOp::GtEq, 100),
        RestrictedOp::And,
        cmp("b", RestrictedOp::LtEq, 10),
    );
    // (a > 100 AND a <= 200) AND b <= 10
    let clause_2 = join(
        join(
            cmp("a", RestrictedOp::Gt, 100),
            RestrictedOp::And,
            cmp("a", RestrictedOp::LtEq, 200),
        ),
        RestrictedOp::And,
        cmp("b", RestrictedOp::LtEq, 10),
    );
    // (a >= 200 AND b > 10) AND b <= 20
    let clause_3 = join(
        join(
            cmp("a", RestrictedOp::GtEq, 200),
            RestrictedOp::And,
            cmp("b", RestrictedOp::Gt, 10),
        ),
        RestrictedOp::And,
        cmp("b", RestrictedOp::LtEq, 20),
    );
    // a > 200 AND b <= 20
    let clause_4 = join(
        cmp("a", RestrictedOp::Gt, 200),
        RestrictedOp::And,
        cmp("b", RestrictedOp::LtEq, 20),
    );

    // ((clause_1 OR clause_2) OR clause_3) OR clause_4
    let first = join(
        join(
            join(clause_1, RestrictedOp::Or, clause_2),
            RestrictedOp::Or,
            clause_3,
        ),
        RestrictedOp::Or,
        clause_4,
    );

    // (a < 100 AND b <= 20) OR (a >= 100 AND b >= 20)
    let second = join(
        join(
            cmp("a", RestrictedOp::Lt, 100),
            RestrictedOp::And,
            cmp("b", RestrictedOp::LtEq, 20),
        ),
        RestrictedOp::Or,
        join(
            cmp("a", RestrictedOp::GtEq, 100),
            RestrictedOp::And,
            cmp("b", RestrictedOp::GtEq, 20),
        ),
    );

    let rule = MultiDimPartitionRule::try_new(
        vec!["a".to_string(), "b".to_string()],
        vec![1, 2],
        vec![first, second],
    );

    assert_matches!(rule.unwrap_err(), Error::UnclosedValue { .. });
}
725
/// Overlapping ranges on the same column (`a <= 20`, `a >= 10`) are rejected.
#[test]
fn duplicate_expr_case_1() {
    // Builds `a <op> <bound>` with an Int64 literal.
    let a_cmp = |op: RestrictedOp, bound: i64| {
        PartitionExpr::new(
            Operand::Column("a".to_string()),
            op,
            Operand::Value(datatypes::value::Value::Int64(bound)),
        )
    };

    let exprs = vec![
        a_cmp(RestrictedOp::LtEq, 20),
        a_cmp(RestrictedOp::GtEq, 10),
    ];
    let rule = MultiDimPartitionRule::try_new(
        vec!["a".to_string(), "b".to_string()],
        vec![1, 2],
        exprs,
    );

    assert_matches!(rule.unwrap_err(), Error::UnclosedValue { .. });
}
752
/// `a != 20` overlaps with both `a <= 20` and `a > 20`; the checker does not
/// detect this yet, so the test is ignored.
#[test]
#[ignore = "checker cannot detect this kind of duplicate for now"]
fn duplicate_expr_case_2() {
    // Builds `a <op> 20`.
    let a_vs_20 = |op: RestrictedOp| {
        PartitionExpr::new(
            Operand::Column("a".to_string()),
            op,
            Operand::Value(datatypes::value::Value::Int64(20)),
        )
    };

    let exprs = vec![
        a_vs_20(RestrictedOp::NotEq),
        a_vs_20(RestrictedOp::LtEq),
        a_vs_20(RestrictedOp::Gt),
    ];
    let rule = MultiDimPartitionRule::try_new(
        vec!["a".to_string(), "b".to_string()],
        vec![1, 2],
        exprs,
    );

    assert!(rule.is_err());
}
786}
787
788#[cfg(test)]
789mod test_split_record_batch {
790 use std::sync::Arc;
791
792 use datatypes::arrow::array::{Int64Array, StringArray};
793 use datatypes::arrow::datatypes::{DataType, Field, Schema};
794 use datatypes::arrow::record_batch::RecordBatch;
795 use rand::Rng;
796
797 use super::*;
798 use crate::expr::col;
799
800 fn test_schema() -> Arc<Schema> {
801 Arc::new(Schema::new(vec![
802 Field::new("host", DataType::Utf8, false),
803 Field::new("value", DataType::Int64, false),
804 ]))
805 }
806
807 fn generate_random_record_batch(num_rows: usize) -> RecordBatch {
808 let schema = test_schema();
809 let mut rng = rand::thread_rng();
810 let mut host_array = Vec::with_capacity(num_rows);
811 let mut value_array = Vec::with_capacity(num_rows);
812 for _ in 0..num_rows {
813 host_array.push(format!("server{}", rng.gen_range(0..20)));
814 value_array.push(rng.gen_range(0..20));
815 }
816 let host_array = StringArray::from(host_array);
817 let value_array = Int64Array::from(value_array);
818 RecordBatch::try_new(schema, vec![Arc::new(host_array), Arc::new(value_array)]).unwrap()
819 }
820
/// The vectorized split must agree with the row-by-row split for a rule that
/// partitions on `host` only.
#[test]
fn test_split_record_batch_by_one_column() {
    let pivot = || Value::String("server1".into());

    // host < 'server1' -> region 0, host >= 'server1' -> region 1
    let exprs = vec![col("host").lt(pivot()), col("host").gt_eq(pivot())];
    let rule = MultiDimPartitionRule::try_new(
        vec!["host".to_string(), "value".to_string()],
        vec![0, 1],
        exprs,
    )
    .unwrap();

    let batch = generate_random_record_batch(1000);
    let actual = rule.split_record_batch(&batch).unwrap();
    let expected = rule.split_record_batch_naive(&batch).unwrap();

    assert_eq!(actual.len(), expected.len());
    for (region, mask) in &actual {
        let expected_mask = expected.get(region).unwrap();
        assert_eq!(mask, expected_mask, "failed on region: {}", region);
    }
}
848
/// Splitting an empty batch still yields one mask for the rule's single region.
#[test]
fn test_split_record_batch_empty() {
    // host = 'server1' -> region 1
    let only_expr = PartitionExpr::new(
        Operand::Column("host".to_string()),
        RestrictedOp::Eq,
        Operand::Value(Value::String("server1".into())),
    );
    let rule =
        MultiDimPartitionRule::try_new(vec!["host".to_string()], vec![1], vec![only_expr])
            .unwrap();

    let no_hosts = StringArray::from(Vec::<&str>::new());
    let no_values = Int64Array::from(Vec::<i64>::new());
    let batch = RecordBatch::try_new(
        test_schema(),
        vec![Arc::new(no_hosts), Arc::new(no_values)],
    )
    .unwrap();

    let result = rule.split_record_batch(&batch).unwrap();
    assert_eq!(result.len(), 1);
}
872
/// The vectorized split must agree with the row-by-row split for a 2x2 grid
/// over `host` and `value`.
#[test]
fn test_split_record_batch_by_two_columns() {
    let host_below = || col("host").lt(Value::String("server10".into()));
    let host_from = || col("host").gt_eq(Value::String("server10".into()));
    let value_below = || col("value").lt(Value::Int64(10));
    let value_from = || col("value").gt_eq(Value::Int64(10));

    let exprs = vec![
        host_below().and(value_below()),
        host_below().and(value_from()),
        host_from().and(value_below()),
        host_from().and(value_from()),
    ];
    let rule = MultiDimPartitionRule::try_new(
        vec!["host".to_string(), "value".to_string()],
        vec![0, 1, 2, 3],
        exprs,
    )
    .unwrap();

    let batch = generate_random_record_batch(1000);
    let actual = rule.split_record_batch(&batch).unwrap();
    let expected = rule.split_record_batch_naive(&batch).unwrap();

    assert_eq!(actual.len(), expected.len());
    for (region, mask) in &actual {
        assert_eq!(mask, expected.get(region).unwrap());
    }
}
903
/// Rows matching no expression (here `value = 30`) must land in the default
/// region in both the vectorized and the row-by-row split.
#[test]
fn test_default_region() {
    let host_below = || col("host").lt(Value::String("server10".into()));
    let host_from = || col("host").gt_eq(Value::String("server10".into()));
    let value_is = |v: i64| col("value").eq(Value::Int64(v));

    let exprs = vec![
        host_below().and(value_is(10)),
        host_below().and(value_is(20)),
        host_from().and(value_is(10)),
        host_from().and(value_is(20)),
    ];
    let rule = MultiDimPartitionRule::try_new(
        vec!["host".to_string(), "value".to_string()],
        vec![0, 1, 2, 3],
        exprs,
    )
    .unwrap();

    let hosts = StringArray::from(vec!["server1", "server1", "server1", "server100"]);
    let values = Int64Array::from(vec![10, 20, 30, 10]);
    let batch =
        RecordBatch::try_new(test_schema(), vec![Arc::new(hosts), Arc::new(values)]).unwrap();

    let actual = rule.split_record_batch(&batch).unwrap();
    let expected = rule.split_record_batch_naive(&batch).unwrap();

    assert_eq!(actual.len(), expected.len());
    for (region, mask) in &actual {
        assert_eq!(mask, expected.get(region).unwrap());
    }
}
938}