1use std::any::Any;
16use std::cmp::Ordering;
17use std::collections::HashMap;
18use std::collections::hash_map::Entry;
19use std::sync::{Arc, RwLock};
20
21use datafusion_expr::ColumnarValue;
22use datafusion_physical_expr::PhysicalExpr;
23use datatypes::arrow;
24use datatypes::arrow::array::{BooleanArray, BooleanBufferBuilder, RecordBatch};
25use datatypes::arrow::buffer::BooleanBuffer;
26use datatypes::arrow::datatypes::Schema;
27use datatypes::prelude::Value;
28use datatypes::vectors::{Helper, VectorRef};
29use serde::{Deserialize, Serialize};
30use snafu::{OptionExt, ResultExt, ensure};
31use store_api::storage::RegionNumber;
32
33use crate::PartitionRule;
34use crate::checker::PartitionChecker;
35use crate::error::{self, Result, UndefinedColumnSnafu};
36use crate::expr::{Operand, PartitionExpr, RestrictedOp};
37use crate::partition::RegionMask;
38
/// Region that receives every row not matched by any partition expression.
const DEFAULT_REGION: RegionNumber = 0;

/// Physical expressions compiled from a rule's partition expressions, paired with the schema
/// they were compiled against. `None` until the expressions have been compiled once.
type PhysicalExprCache = Option<(Vec<Arc<dyn PhysicalExpr>>, Arc<Schema>)>;
43
/// Partition rule that assigns rows to regions by evaluating partition expressions over one or
/// more partition columns.
///
/// `exprs` and `regions` are paired by position: a row matching `exprs[i]` belongs to
/// `regions[i]`. Rows matching no expression fall back to `DEFAULT_REGION`.
#[derive(Debug, Serialize, Deserialize)]
pub struct MultiDimPartitionRule {
    /// Names of the partition columns; row values are supplied in this order.
    partition_columns: Vec<String>,
    /// Maps each partition column name to its position in `partition_columns`.
    name_to_index: HashMap<String, usize>,
    /// Target region numbers, positionally paired with `exprs`.
    regions: Vec<RegionNumber>,
    /// Partition expressions, positionally paired with `regions`.
    exprs: Vec<PartitionExpr>,
    /// Cache of the physical expressions compiled from `exprs` and the schema used to compile
    /// them. Skipped by serde, so it is not part of the serialized form.
    #[serde(skip)]
    physical_expr_cache: RwLock<PhysicalExprCache>,
}
64
65impl MultiDimPartitionRule {
66 pub fn try_new(
71 partition_columns: Vec<String>,
72 regions: Vec<RegionNumber>,
73 exprs: Vec<PartitionExpr>,
74 check_exprs: bool,
75 ) -> Result<Self> {
76 let name_to_index = partition_columns
77 .iter()
78 .enumerate()
79 .map(|(i, name)| (name.clone(), i))
80 .collect::<HashMap<_, _>>();
81
82 let rule = Self {
83 partition_columns,
84 name_to_index,
85 regions,
86 exprs,
87 physical_expr_cache: RwLock::new(None),
88 };
89
90 if check_exprs {
91 let checker = PartitionChecker::try_new(&rule)?;
92 checker.check()?;
93 }
94
95 Ok(rule)
96 }
97
98 pub fn exprs(&self) -> &[PartitionExpr] {
99 &self.exprs
100 }
101
102 fn find_region(&self, values: &[Value]) -> Result<RegionNumber> {
103 ensure!(
104 values.len() == self.partition_columns.len(),
105 error::RegionKeysSizeSnafu {
106 expect: self.partition_columns.len(),
107 actual: values.len(),
108 }
109 );
110
111 for (region_index, expr) in self.exprs.iter().enumerate() {
112 if self.evaluate_expr(expr, values)? {
113 return Ok(self.regions[region_index]);
114 }
115 }
116
117 Ok(DEFAULT_REGION)
119 }
120
121 fn evaluate_expr(&self, expr: &PartitionExpr, values: &[Value]) -> Result<bool> {
122 match (expr.lhs.as_ref(), expr.rhs.as_ref()) {
123 (Operand::Column(name), Operand::Value(r)) => {
124 let index = self.name_to_index.get(name).unwrap();
125 let l = &values[*index];
126 Self::perform_op(l, &expr.op, r)
127 }
128 (Operand::Value(l), Operand::Column(name)) => {
129 let index = self.name_to_index.get(name).unwrap();
130 let r = &values[*index];
131 Self::perform_op(l, &expr.op, r)
132 }
133 (Operand::Expr(lhs), Operand::Expr(rhs)) => {
134 let lhs = self.evaluate_expr(lhs, values)?;
135 let rhs = self.evaluate_expr(rhs, values)?;
136 match expr.op {
137 RestrictedOp::And => Ok(lhs && rhs),
138 RestrictedOp::Or => Ok(lhs || rhs),
139 _ => unreachable!(),
140 }
141 }
142 _ => unreachable!(),
143 }
144 }
145
146 fn perform_op(lhs: &Value, op: &RestrictedOp, rhs: &Value) -> Result<bool> {
147 let result = match op {
148 RestrictedOp::Eq => lhs.eq(rhs),
149 RestrictedOp::NotEq => lhs.ne(rhs),
150 RestrictedOp::Lt => lhs.partial_cmp(rhs) == Some(Ordering::Less),
151 RestrictedOp::LtEq => {
152 let result = lhs.partial_cmp(rhs);
153 result == Some(Ordering::Less) || result == Some(Ordering::Equal)
154 }
155 RestrictedOp::Gt => lhs.partial_cmp(rhs) == Some(Ordering::Greater),
156 RestrictedOp::GtEq => {
157 let result = lhs.partial_cmp(rhs);
158 result == Some(Ordering::Greater) || result == Some(Ordering::Equal)
159 }
160 RestrictedOp::And | RestrictedOp::Or => unreachable!(),
161 };
162
163 Ok(result)
164 }
165
166 pub fn row_at(&self, cols: &[VectorRef], index: usize, row: &mut [Value]) -> Result<()> {
167 for (col_idx, col) in cols.iter().enumerate() {
168 row[col_idx] = col.get(index);
169 }
170 Ok(())
171 }
172
173 pub fn record_batch_to_cols(&self, record_batch: &RecordBatch) -> Result<Vec<VectorRef>> {
174 self.partition_columns
175 .iter()
176 .map(|col_name| {
177 record_batch
178 .column_by_name(col_name)
179 .context(UndefinedColumnSnafu { column: col_name })
180 .and_then(|array| {
181 Helper::try_into_vector(array).context(error::ConvertToVectorSnafu)
182 })
183 })
184 .collect::<Result<Vec<_>>>()
185 }
186
187 pub fn split_record_batch_naive(
188 &self,
189 record_batch: &RecordBatch,
190 ) -> Result<HashMap<RegionNumber, BooleanArray>> {
191 let num_rows = record_batch.num_rows();
192
193 let mut result = self
194 .regions
195 .iter()
196 .map(|region| {
197 let mut builder = BooleanBufferBuilder::new(num_rows);
198 builder.append_n(num_rows, false);
199 (*region, builder)
200 })
201 .collect::<HashMap<_, _>>();
202
203 let cols = self.record_batch_to_cols(record_batch)?;
204 let mut current_row = vec![Value::Null; self.partition_columns.len()];
205 for row_idx in 0..num_rows {
206 self.row_at(&cols, row_idx, &mut current_row)?;
207 let current_region = self.find_region(¤t_row)?;
208 let region_mask = result
209 .get_mut(¤t_region)
210 .unwrap_or_else(|| panic!("Region {} must be initialized", current_region));
211 region_mask.set_bit(row_idx, true);
212 }
213
214 Ok(result
215 .into_iter()
216 .map(|(region, mut mask)| (region, BooleanArray::new(mask.finish(), None)))
217 .collect())
218 }
219
220 pub fn split_record_batch(
221 &self,
222 record_batch: &RecordBatch,
223 ) -> Result<HashMap<RegionNumber, RegionMask>> {
224 let num_rows = record_batch.num_rows();
225 if self.regions.len() == 1 {
226 return Ok([(
227 self.regions[0],
228 RegionMask::from(BooleanArray::from(vec![true; num_rows])),
229 )]
230 .into_iter()
231 .collect());
232 }
233 let physical_exprs = {
234 let cache_read_guard = self.physical_expr_cache.read().unwrap();
235 if let Some((cached_exprs, schema)) = cache_read_guard.as_ref()
236 && schema == record_batch.schema_ref()
237 {
238 cached_exprs.clone()
239 } else {
240 drop(cache_read_guard); let schema = record_batch.schema();
243 let new_cache = self
244 .exprs
245 .iter()
246 .map(|e| e.try_as_physical_expr(&schema))
247 .collect::<Result<Vec<_>>>()?;
248
249 let mut cache_write_guard = self.physical_expr_cache.write().unwrap();
250 cache_write_guard.replace((new_cache.clone(), schema));
251 new_cache
252 }
253 };
254
255 let mut result: HashMap<u32, RegionMask> = physical_exprs
256 .iter()
257 .zip(self.regions.iter())
258 .filter_map(|(expr, region_num)| {
259 let col_val = match expr
260 .evaluate(record_batch)
261 .context(error::EvaluateRecordBatchSnafu)
262 {
263 Ok(array) => array,
264 Err(e) => {
265 return Some(Err(e));
266 }
267 };
268 let array = match columnar_value_to_boolean_array(col_val, num_rows) {
269 Ok(array) => array,
270 Err(e) => {
271 return Some(Err(e));
272 }
273 };
274 let selected_rows = array.true_count();
275 if selected_rows == 0 {
276 return None;
278 }
279 Some(Ok((*region_num, RegionMask::new(array, selected_rows))))
280 })
281 .collect::<error::Result<_>>()?;
282
283 let selected = if result.len() == 1 {
284 result.values().next().unwrap().array().clone()
285 } else {
286 let mut selected = BooleanArray::new(BooleanBuffer::new_unset(num_rows), None);
287 for region_mask in result.values() {
288 selected = arrow::compute::kernels::boolean::or(&selected, region_mask.array())
289 .context(error::ComputeArrowKernelSnafu)?;
290 }
291 selected
292 };
293
294 if selected.true_count() == num_rows {
296 return Ok(result);
297 }
298
299 let unselected = arrow::compute::kernels::boolean::not(&selected)
301 .context(error::ComputeArrowKernelSnafu)?;
302 match result.entry(DEFAULT_REGION) {
303 Entry::Occupied(mut o) => {
304 let default_region_mask = RegionMask::from(
306 arrow::compute::kernels::boolean::or(o.get().array(), &unselected)
307 .context(error::ComputeArrowKernelSnafu)?,
308 );
309 o.insert(default_region_mask);
310 }
311 Entry::Vacant(v) => {
312 v.insert(RegionMask::from(unselected));
314 }
315 }
316 Ok(result)
317 }
318}
319
320fn columnar_value_to_boolean_array(
321 col_val: ColumnarValue,
322 num_rows: usize,
323) -> Result<BooleanArray> {
324 let column = col_val
325 .into_array(num_rows)
326 .context(error::EvaluateRecordBatchSnafu)?;
327 let array = column
328 .as_any()
329 .downcast_ref::<BooleanArray>()
330 .with_context(|| error::UnexpectedColumnTypeSnafu {
331 data_type: column.data_type().clone(),
332 })?;
333 Ok(array.clone())
334}
335
336impl PartitionRule for MultiDimPartitionRule {
337 fn as_any(&self) -> &dyn Any {
338 self
339 }
340
341 fn partition_columns(&self) -> Vec<String> {
342 self.partition_columns.clone()
343 }
344
345 fn find_region(&self, values: &[Value]) -> Result<RegionNumber> {
346 self.find_region(values)
347 }
348
349 fn split_record_batch(
350 &self,
351 record_batch: &RecordBatch,
352 ) -> Result<HashMap<RegionNumber, RegionMask>> {
353 self.split_record_batch(record_batch)
354 }
355}
356
357#[cfg(test)]
358mod tests {
359 use std::assert_matches::assert_matches;
360
361 use super::*;
362 use crate::error::{self, Error};
363 use crate::expr::col;
364
// Routes single rows through a rule over one string column `b` with three ranges:
//   b < 'hz'               -> region 1
//   b >= 'hz' AND b < 'sh' -> region 2
//   b >= 'sh'              -> region 3
// Also checks that passing two values for one partition column is rejected.
#[test]
fn test_find_region() {
    let rule = MultiDimPartitionRule::try_new(
        vec!["b".to_string()],
        vec![1, 2, 3],
        vec![
            PartitionExpr::new(
                Operand::Column("b".to_string()),
                RestrictedOp::Lt,
                Operand::Value(datatypes::value::Value::String("hz".into())),
            ),
            PartitionExpr::new(
                Operand::Expr(PartitionExpr::new(
                    Operand::Column("b".to_string()),
                    RestrictedOp::GtEq,
                    Operand::Value(datatypes::value::Value::String("hz".into())),
                )),
                RestrictedOp::And,
                Operand::Expr(PartitionExpr::new(
                    Operand::Column("b".to_string()),
                    RestrictedOp::Lt,
                    Operand::Value(datatypes::value::Value::String("sh".into())),
                )),
            ),
            PartitionExpr::new(
                Operand::Column("b".to_string()),
                RestrictedOp::GtEq,
                Operand::Value(datatypes::value::Value::String("sh".into())),
            ),
        ],
        true,
    )
    .unwrap();
    // Two values for a single partition column: size mismatch.
    assert_matches!(
        rule.find_region(&["foo".into(), 1000_i32.into()]),
        Err(error::Error::RegionKeysSize {
            expect: 1,
            actual: 2,
            ..
        })
    );
    // Values on either side of each boundary, including the boundaries themselves.
    assert_matches!(rule.find_region(&["foo".into()]), Ok(1));
    assert_matches!(rule.find_region(&["bar".into()]), Ok(1));
    assert_matches!(rule.find_region(&["hz".into()]), Ok(2));
    assert_matches!(rule.find_region(&["hzz".into()]), Ok(2));
    assert_matches!(rule.find_region(&["sh".into()]), Ok(3));
    assert_matches!(rule.find_region(&["zzzz".into()]), Ok(3));
}
418
// A column compared against a nested expression,
//   b <= (b >= 'hz' AND b < 'sh'),
// must be rejected as an invalid expression.
#[test]
fn invalid_expr_case_1() {
    let rule = MultiDimPartitionRule::try_new(
        vec!["a".to_string(), "b".to_string()],
        vec![1],
        vec![PartitionExpr::new(
            Operand::Column("b".to_string()),
            RestrictedOp::LtEq,
            Operand::Expr(PartitionExpr::new(
                Operand::Expr(PartitionExpr::new(
                    Operand::Column("b".to_string()),
                    RestrictedOp::GtEq,
                    Operand::Value(datatypes::value::Value::String("hz".into())),
                )),
                RestrictedOp::And,
                Operand::Expr(PartitionExpr::new(
                    Operand::Column("b".to_string()),
                    RestrictedOp::Lt,
                    Operand::Value(datatypes::value::Value::String("sh".into())),
                )),
            )),
        )],
        true,
    );

    assert_matches!(rule.unwrap_err(), Error::InvalidExpr { .. });
}
450
// An expression AND-ed with a bare value,
//   (b >= 'hz') AND 'sh',
// must be rejected as an invalid expression.
#[test]
fn invalid_expr_case_2() {
    let rule = MultiDimPartitionRule::try_new(
        vec!["a".to_string(), "b".to_string()],
        vec![1],
        vec![PartitionExpr::new(
            Operand::Expr(PartitionExpr::new(
                Operand::Column("b".to_string()),
                RestrictedOp::GtEq,
                Operand::Value(datatypes::value::Value::String("hz".into())),
            )),
            RestrictedOp::And,
            Operand::Value(datatypes::value::Value::String("sh".into())),
        )],
        true,
    );

    assert_matches!(rule.unwrap_err(), Error::InvalidExpr { .. });
}
474
// Two ranges that leave a gap on column `b`:
//   b <= 'h'  -> region 1
//   b >= 's'  -> region 2
// Values between 'h' and 's' match neither, so the rule must be rejected as not covering
// the whole value space.
#[test]
fn empty_expr_case_1() {
    let rule = MultiDimPartitionRule::try_new(
        vec!["a".to_string(), "b".to_string()],
        vec![1, 2],
        vec![
            PartitionExpr::new(
                Operand::Column("b".to_string()),
                RestrictedOp::LtEq,
                Operand::Value(datatypes::value::Value::String("h".into())),
            ),
            PartitionExpr::new(
                Operand::Column("b".to_string()),
                RestrictedOp::GtEq,
                Operand::Value(datatypes::value::Value::String("s".into())),
            ),
        ],
        true,
    );

    assert_matches!(rule.unwrap_err(), Error::CheckpointNotCovered { .. });
}
510
// Two-column rule built from nested AND/OR expressions that is expected to be rejected as
// not covering the whole value space.
//
// Region 1:
//   (a >= 100 AND b <= 10)
//   OR ((a > 100 AND a <= 200) AND b <= 10)
//   OR ((a >= 200 AND b > 10) AND b <= 20)
//   OR (a > 200 AND b <= 20)
// Region 2:
//   (a < 100 AND b <= 20)
//   OR (a >= 100 AND b >= 20)
#[test]
fn empty_expr_case_2() {
    let rule = MultiDimPartitionRule::try_new(
        vec!["a".to_string(), "b".to_string()],
        vec![1, 2],
        vec![
            PartitionExpr::new(
                Operand::Expr(PartitionExpr::new(
                    Operand::Expr(PartitionExpr::new(
                        Operand::Expr(PartitionExpr::new(
                            // a >= 100 AND b <= 10
                            Operand::Expr(PartitionExpr::new(
                                Operand::Column("a".to_string()),
                                RestrictedOp::GtEq,
                                Operand::Value(datatypes::value::Value::Int64(100)),
                            )),
                            RestrictedOp::And,
                            Operand::Expr(PartitionExpr::new(
                                Operand::Column("b".to_string()),
                                RestrictedOp::LtEq,
                                Operand::Value(datatypes::value::Value::Int64(10)),
                            )),
                        )),
                        RestrictedOp::Or,
                        Operand::Expr(PartitionExpr::new(
                            // (a > 100 AND a <= 200) AND b <= 10
                            Operand::Expr(PartitionExpr::new(
                                Operand::Expr(PartitionExpr::new(
                                    Operand::Column("a".to_string()),
                                    RestrictedOp::Gt,
                                    Operand::Value(datatypes::value::Value::Int64(100)),
                                )),
                                RestrictedOp::And,
                                Operand::Expr(PartitionExpr::new(
                                    Operand::Column("a".to_string()),
                                    RestrictedOp::LtEq,
                                    Operand::Value(datatypes::value::Value::Int64(200)),
                                )),
                            )),
                            RestrictedOp::And,
                            Operand::Expr(PartitionExpr::new(
                                Operand::Column("b".to_string()),
                                RestrictedOp::LtEq,
                                Operand::Value(datatypes::value::Value::Int64(10)),
                            )),
                        )),
                    )),
                    RestrictedOp::Or,
                    Operand::Expr(PartitionExpr::new(
                        // (a >= 200 AND b > 10) AND b <= 20
                        Operand::Expr(PartitionExpr::new(
                            Operand::Expr(PartitionExpr::new(
                                Operand::Column("a".to_string()),
                                RestrictedOp::GtEq,
                                Operand::Value(datatypes::value::Value::Int64(200)),
                            )),
                            RestrictedOp::And,
                            Operand::Expr(PartitionExpr::new(
                                Operand::Column("b".to_string()),
                                RestrictedOp::Gt,
                                Operand::Value(datatypes::value::Value::Int64(10)),
                            )),
                        )),
                        RestrictedOp::And,
                        Operand::Expr(PartitionExpr::new(
                            Operand::Column("b".to_string()),
                            RestrictedOp::LtEq,
                            Operand::Value(datatypes::value::Value::Int64(20)),
                        )),
                    )),
                )),
                RestrictedOp::Or,
                Operand::Expr(PartitionExpr::new(
                    // a > 200 AND b <= 20
                    Operand::Expr(PartitionExpr::new(
                        Operand::Column("a".to_string()),
                        RestrictedOp::Gt,
                        Operand::Value(datatypes::value::Value::Int64(200)),
                    )),
                    RestrictedOp::And,
                    Operand::Expr(PartitionExpr::new(
                        Operand::Column("b".to_string()),
                        RestrictedOp::LtEq,
                        Operand::Value(datatypes::value::Value::Int64(20)),
                    )),
                )),
            ),
            PartitionExpr::new(
                Operand::Expr(PartitionExpr::new(
                    // a < 100 AND b <= 20
                    Operand::Expr(PartitionExpr::new(
                        Operand::Column("a".to_string()),
                        RestrictedOp::Lt,
                        Operand::Value(datatypes::value::Value::Int64(100)),
                    )),
                    RestrictedOp::And,
                    Operand::Expr(PartitionExpr::new(
                        Operand::Column("b".to_string()),
                        RestrictedOp::LtEq,
                        Operand::Value(datatypes::value::Value::Int64(20)),
                    )),
                )),
                RestrictedOp::Or,
                Operand::Expr(PartitionExpr::new(
                    // a >= 100 AND b >= 20
                    Operand::Expr(PartitionExpr::new(
                        Operand::Column("a".to_string()),
                        RestrictedOp::GtEq,
                        Operand::Value(datatypes::value::Value::Int64(100)),
                    )),
                    RestrictedOp::And,
                    Operand::Expr(PartitionExpr::new(
                        Operand::Column("b".to_string()),
                        RestrictedOp::GtEq,
                        Operand::Value(datatypes::value::Value::Int64(20)),
                    )),
                )),
            ),
        ],
        true,
    );

    assert_matches!(rule.unwrap_err(), Error::CheckpointNotCovered { .. });
}
654
// Two ranges that overlap on column `a`:
//   a <= 20  -> region 1
//   a >= 10  -> region 2
// Values in [10, 20] match both, so the rule must be rejected as overlapping.
#[test]
fn duplicate_expr_case_1() {
    let rule = MultiDimPartitionRule::try_new(
        vec!["a".to_string(), "b".to_string()],
        vec![1, 2],
        vec![
            PartitionExpr::new(
                Operand::Column("a".to_string()),
                RestrictedOp::LtEq,
                Operand::Value(datatypes::value::Value::Int64(20)),
            ),
            PartitionExpr::new(
                Operand::Column("a".to_string()),
                RestrictedOp::GtEq,
                Operand::Value(datatypes::value::Value::Int64(10)),
            ),
        ],
        true,
    );

    assert_matches!(rule.unwrap_err(), Error::CheckpointOverlapped { .. });
}
682
// Three expressions on column `a` where the first overlaps the other two:
//   a != 20
//   a <= 20
//   a > 20
// The rule must be rejected as overlapping. Note that only two regions are supplied for the
// three expressions.
#[test]
fn duplicate_expr_case_2() {
    let rule = MultiDimPartitionRule::try_new(
        vec!["a".to_string(), "b".to_string()],
        vec![1, 2],
        vec![
            PartitionExpr::new(
                Operand::Column("a".to_string()),
                RestrictedOp::NotEq,
                Operand::Value(datatypes::value::Value::Int64(20)),
            ),
            PartitionExpr::new(
                Operand::Column("a".to_string()),
                RestrictedOp::LtEq,
                Operand::Value(datatypes::value::Value::Int64(20)),
            ),
            PartitionExpr::new(
                Operand::Column("a".to_string()),
                RestrictedOp::Gt,
                Operand::Value(datatypes::value::Value::Int64(20)),
            ),
        ],
        true,
    );

    assert_matches!(rule.unwrap_err(), Error::CheckpointOverlapped { .. });
}
716
// A rule where only part of the `host` range is further divided by `value`:
//   host < 'server10' AND value < 10
//   host < 'server10' AND value >= 10
//   host >= 'server10'
// The test only checks that construction with validation enabled succeeds. Four regions are
// supplied for the three expressions.
#[test]
fn test_partial_divided() {
    let _rule = MultiDimPartitionRule::try_new(
        vec!["host".to_string(), "value".to_string()],
        vec![0, 1, 2, 3],
        vec![
            col("host")
                .lt(Value::String("server10".into()))
                .and(col("value").lt(Value::Int64(10))),
            col("host")
                .lt(Value::String("server10".into()))
                .and(col("value").gt_eq(Value::Int64(10))),
            col("host").gt_eq(Value::String("server10".into())),
        ],
        true,
    )
    .unwrap();
}
745}
746
747#[cfg(test)]
748mod test_split_record_batch {
749 use std::sync::Arc;
750
751 use datafusion_common::ScalarValue;
752 use datatypes::arrow::array::{Int64Array, StringArray};
753 use datatypes::arrow::datatypes::{DataType, Field, Schema};
754 use datatypes::arrow::record_batch::RecordBatch;
755 use rand::Rng;
756
757 use super::*;
758 use crate::expr::{Operand, col};
759
760 fn test_schema() -> Arc<Schema> {
761 Arc::new(Schema::new(vec![
762 Field::new("host", DataType::Utf8, false),
763 Field::new("value", DataType::Int64, false),
764 ]))
765 }
766
767 fn generate_random_record_batch(num_rows: usize) -> RecordBatch {
768 let schema = test_schema();
769 let mut rng = rand::thread_rng();
770 let mut host_array = Vec::with_capacity(num_rows);
771 let mut value_array = Vec::with_capacity(num_rows);
772 for _ in 0..num_rows {
773 host_array.push(format!("server{}", rng.gen_range(0..20)));
774 value_array.push(rng.gen_range(0..20));
775 }
776 let host_array = StringArray::from(host_array);
777 let value_array = Int64Array::from(value_array);
778 RecordBatch::try_new(schema, vec![Arc::new(host_array), Arc::new(value_array)]).unwrap()
779 }
780
// Splits a random batch with a rule that only looks at `host`
//   host < 'server1'  -> region 0
//   host >= 'server1' -> region 1
// and checks the vectorized split against the row-by-row reference implementation.
#[test]
fn test_split_record_batch_by_one_column() {
    let rule = MultiDimPartitionRule::try_new(
        vec!["host".to_string(), "value".to_string()],
        vec![0, 1],
        vec![
            col("host").lt(Value::String("server1".into())),
            col("host").gt_eq(Value::String("server1".into())),
        ],
        true,
    )
    .unwrap();

    let batch = generate_random_record_batch(1000);
    let result = rule.split_record_batch(&batch).unwrap();
    // The naive implementation serves as the expected result.
    let expected = rule.split_record_batch_naive(&batch).unwrap();
    assert_eq!(result.len(), expected.len());
    for (region, value) in &result {
        assert_eq!(
            value.array(),
            expected.get(region).unwrap(),
            "failed on region: {}",
            region
        );
    }
}
809
// Splits an empty batch with a rule that has a single region (and two expressions).
// A single-region rule takes the shortcut in `split_record_batch`, so the result is expected
// to contain exactly one entry even though the batch has no rows.
#[test]
fn test_split_record_batch_empty() {
    let rule = MultiDimPartitionRule::try_new(
        vec!["host".to_string()],
        vec![1],
        vec![
            col("host").lt(Value::String("server1".into())),
            col("host").gt_eq(Value::String("server1".into())),
        ],
        true,
    )
    .unwrap();

    let schema = test_schema();
    let host_array = StringArray::from(Vec::<&str>::new());
    let value_array = Int64Array::from(Vec::<i64>::new());
    let batch = RecordBatch::try_new(schema, vec![Arc::new(host_array), Arc::new(value_array)])
        .unwrap();

    let result = rule.split_record_batch(&batch).unwrap();
    assert_eq!(result.len(), 1);
}
833
// Splits a random batch with a rule that divides both columns into four quadrants:
//   host < 'server10'  AND value < 10   -> region 0
//   host < 'server10'  AND value >= 10  -> region 1
//   host >= 'server10' AND value < 10   -> region 2
//   host >= 'server10' AND value >= 10  -> region 3
// and checks the vectorized split against the row-by-row reference implementation.
#[test]
fn test_split_record_batch_by_two_columns() {
    let rule = MultiDimPartitionRule::try_new(
        vec!["host".to_string(), "value".to_string()],
        vec![0, 1, 2, 3],
        vec![
            col("host")
                .lt(Value::String("server10".into()))
                .and(col("value").lt(Value::Int64(10))),
            col("host")
                .lt(Value::String("server10".into()))
                .and(col("value").gt_eq(Value::Int64(10))),
            col("host")
                .gt_eq(Value::String("server10".into()))
                .and(col("value").lt(Value::Int64(10))),
            col("host")
                .gt_eq(Value::String("server10".into()))
                .and(col("value").gt_eq(Value::Int64(10))),
        ],
        true,
    )
    .unwrap();

    let batch = generate_random_record_batch(1000);
    let result = rule.split_record_batch(&batch).unwrap();
    // The naive implementation serves as the expected result.
    let expected = rule.split_record_batch_naive(&batch).unwrap();
    assert_eq!(result.len(), expected.len());
    for (region, value) in &result {
        assert_eq!(value.array(), expected.get(region).unwrap());
    }
}
865
// Splits a fixed four-row batch with
//   value < 30  -> region 1
//   value >= 30 -> region 2
// Every row is claimed by one of the two regions, so the result must contain exactly those
// two regions with two rows each.
#[test]
fn test_all_rows_selected() {
    let rule = MultiDimPartitionRule::try_new(
        vec!["value".to_string()],
        vec![1, 2],
        vec![
            col("value").lt(Value::Int64(30)),
            col("value").gt_eq(Value::Int64(30)),
        ],
        true,
    )
    .unwrap();

    let schema = test_schema();
    let host_array = StringArray::from(vec!["server1", "server2", "server3", "server4"]);
    let value_array = Int64Array::from(vec![10, 20, 30, 40]);
    let batch = RecordBatch::try_new(schema, vec![Arc::new(host_array), Arc::new(value_array)])
        .unwrap();

    let result = rule.split_record_batch(&batch).unwrap();

    assert_eq!(result.len(), 2);
    assert!(result.contains_key(&1));
    assert!(result.contains_key(&2));

    // Values 10 and 20 go to region 1; values 30 and 40 go to region 2.
    assert_eq!(result.get(&1).unwrap().selected_rows(), 2);
    assert_eq!(result.get(&2).unwrap().selected_rows(), 2);
}
897
// Builds, without validation, a rule whose second expression AND-s a comparison with the
// constant `false`:
//   host < 'never_happen_1'                -> region 0
//   (host >= 'never_happen_1') AND false   -> region 1
// The assertions require all rows of the batch to end up in region 0 and no other region to
// appear in the result.
// NOTE(review): presumably neither expression matches the generated `server<k>` hosts, so the
// rows reach region 0 through the default-region fallback — confirm.
#[test]
fn test_split_record_batch_with_scalar_predicate() {
    let rule = MultiDimPartitionRule::try_new(
        vec!["host".to_string()],
        vec![0, 1],
        vec![
            PartitionExpr::new(
                Operand::Column("host".to_string()),
                RestrictedOp::Lt,
                Operand::Value(Value::String("never_happen_1".into())),
            ),
            PartitionExpr::new(
                Operand::Expr(PartitionExpr::new(
                    Operand::Column("host".to_string()),
                    RestrictedOp::GtEq,
                    Operand::Value(Value::String("never_happen_1".into())),
                )),
                RestrictedOp::And,
                Operand::Value(Value::Boolean(false)),
            ),
        ],
        false,
    )
    .unwrap();

    let batch = generate_random_record_batch(8);
    let result = rule.split_record_batch(&batch).unwrap();

    assert_eq!(result.len(), 1);
    assert!(result.contains_key(&0));

    let total_rows = result.get(&0).unwrap().selected_rows();
    assert_eq!(total_rows, batch.num_rows());
}
933
// A scalar `false` must expand to an array of the requested length with no bit set.
#[test]
fn test_columnar_value_to_boolean_array_scalar_false() {
    let result = columnar_value_to_boolean_array(
        ColumnarValue::Scalar(ScalarValue::Boolean(Some(false))),
        4,
    )
    .unwrap();
    assert_eq!(result.len(), 4);
    assert_eq!(result.true_count(), 0);
}
944
// A scalar `true` must expand to an array of the requested length with every bit set.
#[test]
fn test_columnar_value_to_boolean_array_scalar_true() {
    let result = columnar_value_to_boolean_array(
        ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))),
        4,
    )
    .unwrap();
    assert_eq!(result.len(), 4);
    assert_eq!(result.true_count(), 4);
}
955}