1use std::any::Any;
16use std::cmp::Ordering;
17use std::collections::hash_map::Entry;
18use std::collections::HashMap;
19use std::sync::{Arc, RwLock};
20
21use datafusion_expr::ColumnarValue;
22use datafusion_physical_expr::PhysicalExpr;
23use datatypes::arrow;
24use datatypes::arrow::array::{BooleanArray, BooleanBufferBuilder, RecordBatch};
25use datatypes::arrow::buffer::BooleanBuffer;
26use datatypes::arrow::datatypes::Schema;
27use datatypes::prelude::Value;
28use datatypes::vectors::{Helper, VectorRef};
29use serde::{Deserialize, Serialize};
30use snafu::{ensure, OptionExt, ResultExt};
31use store_api::storage::RegionNumber;
32
33use crate::checker::PartitionChecker;
34use crate::error::{self, Result, UndefinedColumnSnafu};
35use crate::expr::{Operand, PartitionExpr, RestrictedOp};
36use crate::partition::RegionMask;
37use crate::PartitionRule;
38
/// Region that receives rows matched by none of the partition expressions.
const DEFAULT_REGION: RegionNumber = 0;

/// Physical expressions compiled from the rule's partition expressions, paired with the
/// schema they were compiled against (`None` until the first vectorized split).
type PhysicalExprCache = Option<(Vec<Arc<dyn PhysicalExpr>>, Arc<Schema>)>;

/// Partition rule that routes rows to regions by evaluating one [`PartitionExpr`] per
/// region over one or more partition columns.
#[derive(Debug, Serialize, Deserialize)]
pub struct MultiDimPartitionRule {
    /// Partition column names; row values passed to `find_region` follow this order.
    partition_columns: Vec<String>,
    /// Maps each partition column name to its position in `partition_columns`.
    name_to_index: HashMap<String, usize>,
    /// Target regions; `regions[i]` receives the rows matched by `exprs[i]`.
    regions: Vec<RegionNumber>,
    /// Partition expressions, one per region, tried in order.
    exprs: Vec<PartitionExpr>,
    /// Not serialized: after deserialization this starts from its default (`None`)
    /// and is rebuilt on demand.
    #[serde(skip)]
    physical_expr_cache: RwLock<PhysicalExprCache>,
}
64
65impl MultiDimPartitionRule {
66 pub fn try_new(
71 partition_columns: Vec<String>,
72 regions: Vec<RegionNumber>,
73 exprs: Vec<PartitionExpr>,
74 check_exprs: bool,
75 ) -> Result<Self> {
76 let name_to_index = partition_columns
77 .iter()
78 .enumerate()
79 .map(|(i, name)| (name.clone(), i))
80 .collect::<HashMap<_, _>>();
81
82 let rule = Self {
83 partition_columns,
84 name_to_index,
85 regions,
86 exprs,
87 physical_expr_cache: RwLock::new(None),
88 };
89
90 if check_exprs {
91 let checker = PartitionChecker::try_new(&rule)?;
92 checker.check()?;
93 }
94
95 Ok(rule)
96 }
97
98 pub fn exprs(&self) -> &[PartitionExpr] {
99 &self.exprs
100 }
101
102 fn find_region(&self, values: &[Value]) -> Result<RegionNumber> {
103 ensure!(
104 values.len() == self.partition_columns.len(),
105 error::RegionKeysSizeSnafu {
106 expect: self.partition_columns.len(),
107 actual: values.len(),
108 }
109 );
110
111 for (region_index, expr) in self.exprs.iter().enumerate() {
112 if self.evaluate_expr(expr, values)? {
113 return Ok(self.regions[region_index]);
114 }
115 }
116
117 Ok(DEFAULT_REGION)
119 }
120
121 fn evaluate_expr(&self, expr: &PartitionExpr, values: &[Value]) -> Result<bool> {
122 match (expr.lhs.as_ref(), expr.rhs.as_ref()) {
123 (Operand::Column(name), Operand::Value(r)) => {
124 let index = self.name_to_index.get(name).unwrap();
125 let l = &values[*index];
126 Self::perform_op(l, &expr.op, r)
127 }
128 (Operand::Value(l), Operand::Column(name)) => {
129 let index = self.name_to_index.get(name).unwrap();
130 let r = &values[*index];
131 Self::perform_op(l, &expr.op, r)
132 }
133 (Operand::Expr(lhs), Operand::Expr(rhs)) => {
134 let lhs = self.evaluate_expr(lhs, values)?;
135 let rhs = self.evaluate_expr(rhs, values)?;
136 match expr.op {
137 RestrictedOp::And => Ok(lhs && rhs),
138 RestrictedOp::Or => Ok(lhs || rhs),
139 _ => unreachable!(),
140 }
141 }
142 _ => unreachable!(),
143 }
144 }
145
146 fn perform_op(lhs: &Value, op: &RestrictedOp, rhs: &Value) -> Result<bool> {
147 let result = match op {
148 RestrictedOp::Eq => lhs.eq(rhs),
149 RestrictedOp::NotEq => lhs.ne(rhs),
150 RestrictedOp::Lt => lhs.partial_cmp(rhs) == Some(Ordering::Less),
151 RestrictedOp::LtEq => {
152 let result = lhs.partial_cmp(rhs);
153 result == Some(Ordering::Less) || result == Some(Ordering::Equal)
154 }
155 RestrictedOp::Gt => lhs.partial_cmp(rhs) == Some(Ordering::Greater),
156 RestrictedOp::GtEq => {
157 let result = lhs.partial_cmp(rhs);
158 result == Some(Ordering::Greater) || result == Some(Ordering::Equal)
159 }
160 RestrictedOp::And | RestrictedOp::Or => unreachable!(),
161 };
162
163 Ok(result)
164 }
165
166 pub fn row_at(&self, cols: &[VectorRef], index: usize, row: &mut [Value]) -> Result<()> {
167 for (col_idx, col) in cols.iter().enumerate() {
168 row[col_idx] = col.get(index);
169 }
170 Ok(())
171 }
172
173 pub fn record_batch_to_cols(&self, record_batch: &RecordBatch) -> Result<Vec<VectorRef>> {
174 self.partition_columns
175 .iter()
176 .map(|col_name| {
177 record_batch
178 .column_by_name(col_name)
179 .context(UndefinedColumnSnafu { column: col_name })
180 .and_then(|array| {
181 Helper::try_into_vector(array).context(error::ConvertToVectorSnafu)
182 })
183 })
184 .collect::<Result<Vec<_>>>()
185 }
186
187 pub fn split_record_batch_naive(
188 &self,
189 record_batch: &RecordBatch,
190 ) -> Result<HashMap<RegionNumber, BooleanArray>> {
191 let num_rows = record_batch.num_rows();
192
193 let mut result = self
194 .regions
195 .iter()
196 .map(|region| {
197 let mut builder = BooleanBufferBuilder::new(num_rows);
198 builder.append_n(num_rows, false);
199 (*region, builder)
200 })
201 .collect::<HashMap<_, _>>();
202
203 let cols = self.record_batch_to_cols(record_batch)?;
204 let mut current_row = vec![Value::Null; self.partition_columns.len()];
205 for row_idx in 0..num_rows {
206 self.row_at(&cols, row_idx, &mut current_row)?;
207 let current_region = self.find_region(¤t_row)?;
208 let region_mask = result
209 .get_mut(¤t_region)
210 .unwrap_or_else(|| panic!("Region {} must be initialized", current_region));
211 region_mask.set_bit(row_idx, true);
212 }
213
214 Ok(result
215 .into_iter()
216 .map(|(region, mut mask)| (region, BooleanArray::new(mask.finish(), None)))
217 .collect())
218 }
219
220 pub fn split_record_batch(
221 &self,
222 record_batch: &RecordBatch,
223 ) -> Result<HashMap<RegionNumber, RegionMask>> {
224 let num_rows = record_batch.num_rows();
225 if self.regions.len() == 1 {
226 return Ok([(
227 self.regions[0],
228 RegionMask::from(BooleanArray::from(vec![true; num_rows])),
229 )]
230 .into_iter()
231 .collect());
232 }
233 let physical_exprs = {
234 let cache_read_guard = self.physical_expr_cache.read().unwrap();
235 if let Some((cached_exprs, schema)) = cache_read_guard.as_ref()
236 && schema == record_batch.schema_ref()
237 {
238 cached_exprs.clone()
239 } else {
240 drop(cache_read_guard); let schema = record_batch.schema();
243 let new_cache = self
244 .exprs
245 .iter()
246 .map(|e| e.try_as_physical_expr(&schema))
247 .collect::<Result<Vec<_>>>()?;
248
249 let mut cache_write_guard = self.physical_expr_cache.write().unwrap();
250 cache_write_guard.replace((new_cache.clone(), schema));
251 new_cache
252 }
253 };
254
255 let mut result: HashMap<u32, RegionMask> = physical_exprs
256 .iter()
257 .zip(self.regions.iter())
258 .filter_map(|(expr, region_num)| {
259 let col_val = match expr
260 .evaluate(record_batch)
261 .context(error::EvaluateRecordBatchSnafu)
262 {
263 Ok(array) => array,
264 Err(e) => {
265 return Some(Err(e));
266 }
267 };
268 let ColumnarValue::Array(column) = col_val else {
269 unreachable!("Expected an array")
270 };
271 let array =
272 match column
273 .as_any()
274 .downcast_ref::<BooleanArray>()
275 .with_context(|| error::UnexpectedColumnTypeSnafu {
276 data_type: column.data_type().clone(),
277 }) {
278 Ok(array) => array,
279 Err(e) => {
280 return Some(Err(e));
281 }
282 };
283 let selected_rows = array.true_count();
284 if selected_rows == 0 {
285 return None;
287 }
288 Some(Ok((
289 *region_num,
290 RegionMask::new(array.clone(), selected_rows),
291 )))
292 })
293 .collect::<error::Result<_>>()?;
294
295 let selected = if result.len() == 1 {
296 result.values().next().unwrap().array().clone()
297 } else {
298 let mut selected = BooleanArray::new(BooleanBuffer::new_unset(num_rows), None);
299 for region_mask in result.values() {
300 selected = arrow::compute::kernels::boolean::or(&selected, region_mask.array())
301 .context(error::ComputeArrowKernelSnafu)?;
302 }
303 selected
304 };
305
306 if selected.true_count() == num_rows {
308 return Ok(result);
309 }
310
311 let unselected = arrow::compute::kernels::boolean::not(&selected)
313 .context(error::ComputeArrowKernelSnafu)?;
314 match result.entry(DEFAULT_REGION) {
315 Entry::Occupied(mut o) => {
316 let default_region_mask = RegionMask::from(
318 arrow::compute::kernels::boolean::or(o.get().array(), &unselected)
319 .context(error::ComputeArrowKernelSnafu)?,
320 );
321 o.insert(default_region_mask);
322 }
323 Entry::Vacant(v) => {
324 v.insert(RegionMask::from(unselected));
326 }
327 }
328 Ok(result)
329 }
330}
331
332impl PartitionRule for MultiDimPartitionRule {
333 fn as_any(&self) -> &dyn Any {
334 self
335 }
336
337 fn partition_columns(&self) -> Vec<String> {
338 self.partition_columns.clone()
339 }
340
341 fn find_region(&self, values: &[Value]) -> Result<RegionNumber> {
342 self.find_region(values)
343 }
344
345 fn split_record_batch(
346 &self,
347 record_batch: &RecordBatch,
348 ) -> Result<HashMap<RegionNumber, RegionMask>> {
349 self.split_record_batch(record_batch)
350 }
351}
352
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use super::*;
    use crate::error::{self, Error};
    use crate::expr::col;

    /// Single-column rule over `b` with three ranges:
    /// region 1: `b < "hz"`, region 2: `"hz" <= b < "sh"`, region 3: `b >= "sh"`.
    #[test]
    fn test_find_region() {
        let rule = MultiDimPartitionRule::try_new(
            vec!["b".to_string()],
            vec![1, 2, 3],
            vec![
                PartitionExpr::new(
                    Operand::Column("b".to_string()),
                    RestrictedOp::Lt,
                    Operand::Value(datatypes::value::Value::String("hz".into())),
                ),
                PartitionExpr::new(
                    Operand::Expr(PartitionExpr::new(
                        Operand::Column("b".to_string()),
                        RestrictedOp::GtEq,
                        Operand::Value(datatypes::value::Value::String("hz".into())),
                    )),
                    RestrictedOp::And,
                    Operand::Expr(PartitionExpr::new(
                        Operand::Column("b".to_string()),
                        RestrictedOp::Lt,
                        Operand::Value(datatypes::value::Value::String("sh".into())),
                    )),
                ),
                PartitionExpr::new(
                    Operand::Column("b".to_string()),
                    RestrictedOp::GtEq,
                    Operand::Value(datatypes::value::Value::String("sh".into())),
                ),
            ],
            true,
        )
        .unwrap();
        // The rule has a single partition column, so a two-value row is rejected.
        assert_matches!(
            rule.find_region(&["foo".into(), 1000_i32.into()]),
            Err(error::Error::RegionKeysSize {
                expect: 1,
                actual: 2,
                ..
            })
        );
        // Values below "hz" -> region 1.
        assert_matches!(rule.find_region(&["foo".into()]), Ok(1));
        assert_matches!(rule.find_region(&["bar".into()]), Ok(1));
        // Lower bound of region 2 is inclusive.
        assert_matches!(rule.find_region(&["hz".into()]), Ok(2));
        assert_matches!(rule.find_region(&["hzz".into()]), Ok(2));
        // Lower bound of region 3 is inclusive.
        assert_matches!(rule.find_region(&["sh".into()]), Ok(3));
        assert_matches!(rule.find_region(&["zzzz".into()]), Ok(3));
    }

    /// A column compared against a nested boolean expression
    /// (`b <= (b >= "hz" AND b < "sh")`) must be rejected as `InvalidExpr`.
    #[test]
    fn invalid_expr_case_1() {
        let rule = MultiDimPartitionRule::try_new(
            vec!["a".to_string(), "b".to_string()],
            vec![1],
            vec![PartitionExpr::new(
                Operand::Column("b".to_string()),
                RestrictedOp::LtEq,
                Operand::Expr(PartitionExpr::new(
                    Operand::Expr(PartitionExpr::new(
                        Operand::Column("b".to_string()),
                        RestrictedOp::GtEq,
                        Operand::Value(datatypes::value::Value::String("hz".into())),
                    )),
                    RestrictedOp::And,
                    Operand::Expr(PartitionExpr::new(
                        Operand::Column("b".to_string()),
                        RestrictedOp::Lt,
                        Operand::Value(datatypes::value::Value::String("sh".into())),
                    )),
                )),
            )],
            true,
        );

        assert_matches!(rule.unwrap_err(), Error::InvalidExpr { .. });
    }

    /// `AND` with a bare literal operand (`(b >= "hz") AND "sh"`) must be rejected
    /// as `InvalidExpr`.
    #[test]
    fn invalid_expr_case_2() {
        let rule = MultiDimPartitionRule::try_new(
            vec!["a".to_string(), "b".to_string()],
            vec![1],
            vec![PartitionExpr::new(
                Operand::Expr(PartitionExpr::new(
                    Operand::Column("b".to_string()),
                    RestrictedOp::GtEq,
                    Operand::Value(datatypes::value::Value::String("hz".into())),
                )),
                RestrictedOp::And,
                Operand::Value(datatypes::value::Value::String("sh".into())),
            )],
            true,
        );

        assert_matches!(rule.unwrap_err(), Error::InvalidExpr { .. });
    }

    /// `b <= "h"` and `b >= "s"` leave the values strictly between them (e.g. `"m"`)
    /// uncovered, so the checker must report `CheckpointNotCovered`.
    #[test]
    fn empty_expr_case_1() {
        let rule = MultiDimPartitionRule::try_new(
            vec!["a".to_string(), "b".to_string()],
            vec![1, 2],
            vec![
                PartitionExpr::new(
                    Operand::Column("b".to_string()),
                    RestrictedOp::LtEq,
                    Operand::Value(datatypes::value::Value::String("h".into())),
                ),
                PartitionExpr::new(
                    Operand::Column("b".to_string()),
                    RestrictedOp::GtEq,
                    Operand::Value(datatypes::value::Value::String("s".into())),
                ),
            ],
            true,
        );

        assert_matches!(rule.unwrap_err(), Error::CheckpointNotCovered { .. });
    }

    /// Two-dimensional rule with a hole.
    ///
    /// Region 1: `(a >= 100 AND b <= 10) OR (a > 100 AND a <= 200 AND b <= 10)
    ///            OR (a >= 200 AND b > 10 AND b <= 20) OR (a > 200 AND b <= 20)`.
    /// Region 2: `(a < 100 AND b <= 20) OR (a >= 100 AND b >= 20)`.
    ///
    /// Points with `100 <= a < 200` and `10 < b < 20` (e.g. a = 150, b = 15) match
    /// neither region, so the checker is expected to report `CheckpointNotCovered`.
    /// (The two regions also share points with `a >= 200, b = 20`; this test only
    /// asserts the coverage error.)
    #[test]
    fn empty_expr_case_2() {
        let rule = MultiDimPartitionRule::try_new(
            vec!["a".to_string(), "b".to_string()],
            vec![1, 2],
            vec![
                PartitionExpr::new(
                    Operand::Expr(PartitionExpr::new(
                        Operand::Expr(PartitionExpr::new(
                            Operand::Expr(PartitionExpr::new(
                                // a >= 100 AND b <= 10
                                Operand::Expr(PartitionExpr::new(
                                    Operand::Column("a".to_string()),
                                    RestrictedOp::GtEq,
                                    Operand::Value(datatypes::value::Value::Int64(100)),
                                )),
                                RestrictedOp::And,
                                Operand::Expr(PartitionExpr::new(
                                    Operand::Column("b".to_string()),
                                    RestrictedOp::LtEq,
                                    Operand::Value(datatypes::value::Value::Int64(10)),
                                )),
                            )),
                            RestrictedOp::Or,
                            Operand::Expr(PartitionExpr::new(
                                // a > 100 AND a <= 200 AND b <= 10
                                Operand::Expr(PartitionExpr::new(
                                    Operand::Expr(PartitionExpr::new(
                                        Operand::Column("a".to_string()),
                                        RestrictedOp::Gt,
                                        Operand::Value(datatypes::value::Value::Int64(100)),
                                    )),
                                    RestrictedOp::And,
                                    Operand::Expr(PartitionExpr::new(
                                        Operand::Column("a".to_string()),
                                        RestrictedOp::LtEq,
                                        Operand::Value(datatypes::value::Value::Int64(200)),
                                    )),
                                )),
                                RestrictedOp::And,
                                Operand::Expr(PartitionExpr::new(
                                    Operand::Column("b".to_string()),
                                    RestrictedOp::LtEq,
                                    Operand::Value(datatypes::value::Value::Int64(10)),
                                )),
                            )),
                        )),
                        RestrictedOp::Or,
                        Operand::Expr(PartitionExpr::new(
                            // a >= 200 AND b > 10 AND b <= 20
                            Operand::Expr(PartitionExpr::new(
                                Operand::Expr(PartitionExpr::new(
                                    Operand::Column("a".to_string()),
                                    RestrictedOp::GtEq,
                                    Operand::Value(datatypes::value::Value::Int64(200)),
                                )),
                                RestrictedOp::And,
                                Operand::Expr(PartitionExpr::new(
                                    Operand::Column("b".to_string()),
                                    RestrictedOp::Gt,
                                    Operand::Value(datatypes::value::Value::Int64(10)),
                                )),
                            )),
                            RestrictedOp::And,
                            Operand::Expr(PartitionExpr::new(
                                Operand::Column("b".to_string()),
                                RestrictedOp::LtEq,
                                Operand::Value(datatypes::value::Value::Int64(20)),
                            )),
                        )),
                    )),
                    RestrictedOp::Or,
                    Operand::Expr(PartitionExpr::new(
                        // a > 200 AND b <= 20
                        Operand::Expr(PartitionExpr::new(
                            Operand::Column("a".to_string()),
                            RestrictedOp::Gt,
                            Operand::Value(datatypes::value::Value::Int64(200)),
                        )),
                        RestrictedOp::And,
                        Operand::Expr(PartitionExpr::new(
                            Operand::Column("b".to_string()),
                            RestrictedOp::LtEq,
                            Operand::Value(datatypes::value::Value::Int64(20)),
                        )),
                    )),
                ),
                PartitionExpr::new(
                    Operand::Expr(PartitionExpr::new(
                        // a < 100 AND b <= 20
                        Operand::Expr(PartitionExpr::new(
                            Operand::Column("a".to_string()),
                            RestrictedOp::Lt,
                            Operand::Value(datatypes::value::Value::Int64(100)),
                        )),
                        RestrictedOp::And,
                        Operand::Expr(PartitionExpr::new(
                            Operand::Column("b".to_string()),
                            RestrictedOp::LtEq,
                            Operand::Value(datatypes::value::Value::Int64(20)),
                        )),
                    )),
                    RestrictedOp::Or,
                    Operand::Expr(PartitionExpr::new(
                        // a >= 100 AND b >= 20
                        Operand::Expr(PartitionExpr::new(
                            Operand::Column("a".to_string()),
                            RestrictedOp::GtEq,
                            Operand::Value(datatypes::value::Value::Int64(100)),
                        )),
                        RestrictedOp::And,
                        Operand::Expr(PartitionExpr::new(
                            Operand::Column("b".to_string()),
                            RestrictedOp::GtEq,
                            Operand::Value(datatypes::value::Value::Int64(20)),
                        )),
                    )),
                ),
            ],
            true,
        );

        assert_matches!(rule.unwrap_err(), Error::CheckpointNotCovered { .. });
    }

    /// `a <= 20` and `a >= 10` both match every `a` in `[10, 20]`, so the checker
    /// must report `CheckpointOverlapped`.
    #[test]
    fn duplicate_expr_case_1() {
        let rule = MultiDimPartitionRule::try_new(
            vec!["a".to_string(), "b".to_string()],
            vec![1, 2],
            vec![
                PartitionExpr::new(
                    Operand::Column("a".to_string()),
                    RestrictedOp::LtEq,
                    Operand::Value(datatypes::value::Value::Int64(20)),
                ),
                PartitionExpr::new(
                    Operand::Column("a".to_string()),
                    RestrictedOp::GtEq,
                    Operand::Value(datatypes::value::Value::Int64(10)),
                ),
            ],
            true,
        );

        assert_matches!(rule.unwrap_err(), Error::CheckpointOverlapped { .. });
    }

    /// `a != 20` overlaps both `a <= 20` (e.g. a = 10) and `a > 20` (e.g. a = 30),
    /// so the checker must report `CheckpointOverlapped`. Note the rule lists three
    /// expressions but only two regions.
    #[test]
    fn duplicate_expr_case_2() {
        let rule = MultiDimPartitionRule::try_new(
            vec!["a".to_string(), "b".to_string()],
            vec![1, 2],
            vec![
                PartitionExpr::new(
                    Operand::Column("a".to_string()),
                    RestrictedOp::NotEq,
                    Operand::Value(datatypes::value::Value::Int64(20)),
                ),
                PartitionExpr::new(
                    Operand::Column("a".to_string()),
                    RestrictedOp::LtEq,
                    Operand::Value(datatypes::value::Value::Int64(20)),
                ),
                PartitionExpr::new(
                    Operand::Column("a".to_string()),
                    RestrictedOp::Gt,
                    Operand::Value(datatypes::value::Value::Int64(20)),
                ),
            ],
            true,
        );

        assert_matches!(rule.unwrap_err(), Error::CheckpointOverlapped { .. });
    }

    /// Only part of the space is split on the second column: `host < "server10"` is
    /// further divided by `value`, while `host >= "server10"` is a single range. The
    /// rule must pass the checker.
    #[test]
    fn test_partial_divided() {
        let _rule = MultiDimPartitionRule::try_new(
            vec!["host".to_string(), "value".to_string()],
            vec![0, 1, 2, 3],
            vec![
                col("host")
                    .lt(Value::String("server10".into()))
                    .and(col("value").lt(Value::Int64(10))),
                col("host")
                    .lt(Value::String("server10".into()))
                    .and(col("value").gt_eq(Value::Int64(10))),
                col("host").gt_eq(Value::String("server10".into())),
            ],
            true,
        )
        .unwrap();
    }
}
742
#[cfg(test)]
mod test_split_record_batch {
    use std::sync::Arc;

    use datatypes::arrow::array::{Int64Array, StringArray};
    use datatypes::arrow::datatypes::{DataType, Field, Schema};
    use datatypes::arrow::record_batch::RecordBatch;
    use rand::Rng;

    use super::*;
    use crate::expr::col;

    /// Schema for every batch in this module: non-nullable `host` (Utf8) followed by
    /// non-nullable `value` (Int64).
    fn test_schema() -> Arc<Schema> {
        let fields = vec![
            Field::new("host", DataType::Utf8, false),
            Field::new("value", DataType::Int64, false),
        ];
        Arc::new(Schema::new(fields))
    }

    /// Builds `num_rows` random rows: `host` is one of `server0`..`server19`, `value`
    /// lies in `0..20`.
    fn generate_random_record_batch(num_rows: usize) -> RecordBatch {
        let mut rng = rand::thread_rng();
        let (hosts, values): (Vec<String>, Vec<i64>) = (0..num_rows)
            .map(|_| {
                let host = format!("server{}", rng.gen_range(0..20));
                let value = rng.gen_range(0..20);
                (host, value)
            })
            .unzip();
        RecordBatch::try_new(
            test_schema(),
            vec![
                Arc::new(StringArray::from(hosts)),
                Arc::new(Int64Array::from(values)),
            ],
        )
        .unwrap()
    }

    /// Splitting on `host` alone must agree with the row-by-row reference splitter.
    #[test]
    fn test_split_record_batch_by_one_column() {
        let rule = MultiDimPartitionRule::try_new(
            vec!["host".to_string(), "value".to_string()],
            vec![0, 1],
            vec![
                col("host").lt(Value::String("server1".into())),
                col("host").gt_eq(Value::String("server1".into())),
            ],
            true,
        )
        .unwrap();

        let batch = generate_random_record_batch(1000);
        let actual = rule.split_record_batch(&batch).unwrap();
        let reference = rule.split_record_batch_naive(&batch).unwrap();
        assert_eq!(actual.len(), reference.len());
        for (region, mask) in actual.iter() {
            let want = reference.get(region).unwrap();
            assert_eq!(mask.array(), want, "failed on region: {}", region);
        }
    }

    /// A single-region rule yields exactly one entry even for an empty batch.
    #[test]
    fn test_split_record_batch_empty() {
        let rule = MultiDimPartitionRule::try_new(
            vec!["host".to_string()],
            vec![1],
            vec![
                col("host").lt(Value::String("server1".into())),
                col("host").gt_eq(Value::String("server1".into())),
            ],
            true,
        )
        .unwrap();

        let empty_batch = RecordBatch::try_new(
            test_schema(),
            vec![
                Arc::new(StringArray::from(Vec::<&str>::new())),
                Arc::new(Int64Array::from(Vec::<i64>::new())),
            ],
        )
        .unwrap();

        let split = rule.split_record_batch(&empty_batch).unwrap();
        assert_eq!(split.len(), 1);
    }

    /// A 2x2 grid over `host` and `value` must agree with the reference splitter.
    #[test]
    fn test_split_record_batch_by_two_columns() {
        let rule = MultiDimPartitionRule::try_new(
            vec!["host".to_string(), "value".to_string()],
            vec![0, 1, 2, 3],
            vec![
                col("host")
                    .lt(Value::String("server10".into()))
                    .and(col("value").lt(Value::Int64(10))),
                col("host")
                    .lt(Value::String("server10".into()))
                    .and(col("value").gt_eq(Value::Int64(10))),
                col("host")
                    .gt_eq(Value::String("server10".into()))
                    .and(col("value").lt(Value::Int64(10))),
                col("host")
                    .gt_eq(Value::String("server10".into()))
                    .and(col("value").gt_eq(Value::Int64(10))),
            ],
            true,
        )
        .unwrap();

        let batch = generate_random_record_batch(1000);
        let actual = rule.split_record_batch(&batch).unwrap();
        let reference = rule.split_record_batch_naive(&batch).unwrap();
        assert_eq!(actual.len(), reference.len());
        for (region, mask) in actual.iter() {
            assert_eq!(mask.array(), reference.get(region).unwrap());
        }
    }

    /// When the expressions cover every row, no default-region entry is added and each
    /// region reports its own row count.
    #[test]
    fn test_all_rows_selected() {
        let rule = MultiDimPartitionRule::try_new(
            vec!["value".to_string()],
            vec![1, 2],
            vec![
                col("value").lt(Value::Int64(30)),
                col("value").gt_eq(Value::Int64(30)),
            ],
            true,
        )
        .unwrap();

        let batch = RecordBatch::try_new(
            test_schema(),
            vec![
                Arc::new(StringArray::from(vec![
                    "server1", "server2", "server3", "server4",
                ])),
                Arc::new(Int64Array::from(vec![10, 20, 30, 40])),
            ],
        )
        .unwrap();

        let split = rule.split_record_batch(&batch).unwrap();

        assert_eq!(split.len(), 2);
        for region in [1, 2] {
            assert!(split.contains_key(&region));
        }
        for region in [1, 2] {
            assert_eq!(split.get(&region).unwrap().selected_rows(), 2);
        }
    }
}