1use std::any::Any;
16use std::cmp::Ordering;
17use std::collections::hash_map::Entry;
18use std::collections::HashMap;
19use std::sync::{Arc, RwLock};
20
21use datafusion_expr::ColumnarValue;
22use datafusion_physical_expr::PhysicalExpr;
23use datatypes::arrow;
24use datatypes::arrow::array::{BooleanArray, BooleanBufferBuilder, RecordBatch};
25use datatypes::arrow::buffer::BooleanBuffer;
26use datatypes::arrow::datatypes::Schema;
27use datatypes::prelude::Value;
28use datatypes::vectors::{Helper, VectorRef};
29use serde::{Deserialize, Serialize};
30use snafu::{ensure, OptionExt, ResultExt};
31use store_api::storage::RegionNumber;
32
33use crate::error::{
34 self, ConjunctExprWithNonExprSnafu, InvalidExprSnafu, Result, UnclosedValueSnafu,
35 UndefinedColumnSnafu,
36};
37use crate::expr::{Operand, PartitionExpr, RestrictedOp};
38use crate::partition::RegionMask;
39use crate::PartitionRule;
40
/// Region that receives values or rows matched by none of the partition expressions.
const DEFAULT_REGION: RegionNumber = 0;
43
/// Physical expressions compiled from the partition expressions, paired with the schema
/// they were compiled against. `None` until the first compilation.
type PhysicalExprCache = Option<(Vec<Arc<dyn PhysicalExpr>>, Arc<Schema>)>;
45
/// Partition rule that assigns rows to regions by evaluating one [`PartitionExpr`] per
/// region over one or more partition columns.
#[derive(Debug, Serialize, Deserialize)]
pub struct MultiDimPartitionRule {
    /// Names of the columns this rule partitions on.
    partition_columns: Vec<String>,
    /// Position of each partition column name within `partition_columns`.
    name_to_index: HashMap<String, usize>,
    /// Region numbers; `regions[i]` is the target region of `exprs[i]`.
    regions: Vec<RegionNumber>,
    /// Partition expressions, evaluated in order.
    exprs: Vec<PartitionExpr>,
    /// Lazily compiled physical expressions for `exprs`; skipped by serde.
    #[serde(skip)]
    physical_expr_cache: RwLock<PhysicalExprCache>,
}
66
67impl MultiDimPartitionRule {
68 pub fn try_new(
69 partition_columns: Vec<String>,
70 regions: Vec<RegionNumber>,
71 exprs: Vec<PartitionExpr>,
72 ) -> Result<Self> {
73 let name_to_index = partition_columns
74 .iter()
75 .enumerate()
76 .map(|(i, name)| (name.clone(), i))
77 .collect::<HashMap<_, _>>();
78
79 let rule = Self {
80 partition_columns,
81 name_to_index,
82 regions,
83 exprs,
84 physical_expr_cache: RwLock::new(None),
85 };
86
87 let mut checker = RuleChecker::new(&rule);
88 checker.check()?;
89
90 Ok(rule)
91 }
92
93 fn find_region(&self, values: &[Value]) -> Result<RegionNumber> {
94 ensure!(
95 values.len() == self.partition_columns.len(),
96 error::RegionKeysSizeSnafu {
97 expect: self.partition_columns.len(),
98 actual: values.len(),
99 }
100 );
101
102 for (region_index, expr) in self.exprs.iter().enumerate() {
103 if self.evaluate_expr(expr, values)? {
104 return Ok(self.regions[region_index]);
105 }
106 }
107
108 Ok(DEFAULT_REGION)
110 }
111
112 fn evaluate_expr(&self, expr: &PartitionExpr, values: &[Value]) -> Result<bool> {
113 match (expr.lhs.as_ref(), expr.rhs.as_ref()) {
114 (Operand::Column(name), Operand::Value(r)) => {
115 let index = self.name_to_index.get(name).unwrap();
116 let l = &values[*index];
117 Self::perform_op(l, &expr.op, r)
118 }
119 (Operand::Value(l), Operand::Column(name)) => {
120 let index = self.name_to_index.get(name).unwrap();
121 let r = &values[*index];
122 Self::perform_op(l, &expr.op, r)
123 }
124 (Operand::Expr(lhs), Operand::Expr(rhs)) => {
125 let lhs = self.evaluate_expr(lhs, values)?;
126 let rhs = self.evaluate_expr(rhs, values)?;
127 match expr.op {
128 RestrictedOp::And => Ok(lhs && rhs),
129 RestrictedOp::Or => Ok(lhs || rhs),
130 _ => unreachable!(),
131 }
132 }
133 _ => unreachable!(),
134 }
135 }
136
137 fn perform_op(lhs: &Value, op: &RestrictedOp, rhs: &Value) -> Result<bool> {
138 let result = match op {
139 RestrictedOp::Eq => lhs.eq(rhs),
140 RestrictedOp::NotEq => lhs.ne(rhs),
141 RestrictedOp::Lt => lhs.partial_cmp(rhs) == Some(Ordering::Less),
142 RestrictedOp::LtEq => {
143 let result = lhs.partial_cmp(rhs);
144 result == Some(Ordering::Less) || result == Some(Ordering::Equal)
145 }
146 RestrictedOp::Gt => lhs.partial_cmp(rhs) == Some(Ordering::Greater),
147 RestrictedOp::GtEq => {
148 let result = lhs.partial_cmp(rhs);
149 result == Some(Ordering::Greater) || result == Some(Ordering::Equal)
150 }
151 RestrictedOp::And | RestrictedOp::Or => unreachable!(),
152 };
153
154 Ok(result)
155 }
156
157 pub fn row_at(&self, cols: &[VectorRef], index: usize, row: &mut [Value]) -> Result<()> {
158 for (col_idx, col) in cols.iter().enumerate() {
159 row[col_idx] = col.get(index);
160 }
161 Ok(())
162 }
163
164 pub fn record_batch_to_cols(&self, record_batch: &RecordBatch) -> Result<Vec<VectorRef>> {
165 self.partition_columns
166 .iter()
167 .map(|col_name| {
168 record_batch
169 .column_by_name(col_name)
170 .context(error::UndefinedColumnSnafu { column: col_name })
171 .and_then(|array| {
172 Helper::try_into_vector(array).context(error::ConvertToVectorSnafu)
173 })
174 })
175 .collect::<Result<Vec<_>>>()
176 }
177
178 pub fn split_record_batch_naive(
179 &self,
180 record_batch: &RecordBatch,
181 ) -> Result<HashMap<RegionNumber, BooleanArray>> {
182 let num_rows = record_batch.num_rows();
183
184 let mut result = self
185 .regions
186 .iter()
187 .map(|region| {
188 let mut builder = BooleanBufferBuilder::new(num_rows);
189 builder.append_n(num_rows, false);
190 (*region, builder)
191 })
192 .collect::<HashMap<_, _>>();
193
194 let cols = self.record_batch_to_cols(record_batch)?;
195 let mut current_row = vec![Value::Null; self.partition_columns.len()];
196 for row_idx in 0..num_rows {
197 self.row_at(&cols, row_idx, &mut current_row)?;
198 let current_region = self.find_region(¤t_row)?;
199 let region_mask = result
200 .get_mut(¤t_region)
201 .unwrap_or_else(|| panic!("Region {} must be initialized", current_region));
202 region_mask.set_bit(row_idx, true);
203 }
204
205 Ok(result
206 .into_iter()
207 .map(|(region, mut mask)| (region, BooleanArray::new(mask.finish(), None)))
208 .collect())
209 }
210
211 pub fn split_record_batch(
212 &self,
213 record_batch: &RecordBatch,
214 ) -> Result<HashMap<RegionNumber, RegionMask>> {
215 let num_rows = record_batch.num_rows();
216 if self.regions.len() == 1 {
217 return Ok([(
218 self.regions[0],
219 RegionMask::from(BooleanArray::from(vec![true; num_rows])),
220 )]
221 .into_iter()
222 .collect());
223 }
224 let physical_exprs = {
225 let cache_read_guard = self.physical_expr_cache.read().unwrap();
226 if let Some((cached_exprs, schema)) = cache_read_guard.as_ref()
227 && schema == record_batch.schema_ref()
228 {
229 cached_exprs.clone()
230 } else {
231 drop(cache_read_guard); let schema = record_batch.schema();
234 let new_cache = self
235 .exprs
236 .iter()
237 .map(|e| e.try_as_physical_expr(&schema))
238 .collect::<Result<Vec<_>>>()?;
239
240 let mut cache_write_guard = self.physical_expr_cache.write().unwrap();
241 cache_write_guard.replace((new_cache.clone(), schema));
242 new_cache
243 }
244 };
245
246 let mut result: HashMap<u32, RegionMask> = physical_exprs
247 .iter()
248 .zip(self.regions.iter())
249 .filter_map(|(expr, region_num)| {
250 let col_val = match expr
251 .evaluate(record_batch)
252 .context(error::EvaluateRecordBatchSnafu)
253 {
254 Ok(array) => array,
255 Err(e) => {
256 return Some(Err(e));
257 }
258 };
259 let ColumnarValue::Array(column) = col_val else {
260 unreachable!("Expected an array")
261 };
262 let array =
263 match column
264 .as_any()
265 .downcast_ref::<BooleanArray>()
266 .with_context(|| error::UnexpectedColumnTypeSnafu {
267 data_type: column.data_type().clone(),
268 }) {
269 Ok(array) => array,
270 Err(e) => {
271 return Some(Err(e));
272 }
273 };
274 let selected_rows = array.true_count();
275 if selected_rows == 0 {
276 return None;
278 }
279 Some(Ok((
280 *region_num,
281 RegionMask::new(array.clone(), selected_rows),
282 )))
283 })
284 .collect::<error::Result<_>>()?;
285
286 let selected = if result.len() == 1 {
287 result.values().next().unwrap().array().clone()
288 } else {
289 let mut selected = BooleanArray::new(BooleanBuffer::new_unset(num_rows), None);
290 for region_mask in result.values() {
291 selected = arrow::compute::kernels::boolean::or(&selected, region_mask.array())
292 .context(error::ComputeArrowKernelSnafu)?;
293 }
294 selected
295 };
296
297 if selected.true_count() == num_rows {
299 return Ok(result);
300 }
301
302 let unselected = arrow::compute::kernels::boolean::not(&selected)
304 .context(error::ComputeArrowKernelSnafu)?;
305 match result.entry(DEFAULT_REGION) {
306 Entry::Occupied(mut o) => {
307 let default_region_mask = RegionMask::from(
309 arrow::compute::kernels::boolean::or(o.get().array(), &unselected)
310 .context(error::ComputeArrowKernelSnafu)?,
311 );
312 o.insert(default_region_mask);
313 }
314 Entry::Vacant(v) => {
315 v.insert(RegionMask::from(unselected));
317 }
318 }
319 Ok(result)
320 }
321}
322
323impl PartitionRule for MultiDimPartitionRule {
324 fn as_any(&self) -> &dyn Any {
325 self
326 }
327
328 fn partition_columns(&self) -> Vec<String> {
329 self.partition_columns.clone()
330 }
331
332 fn find_region(&self, values: &[Value]) -> Result<RegionNumber> {
333 self.find_region(values)
334 }
335
336 fn split_record_batch(
337 &self,
338 record_batch: &RecordBatch,
339 ) -> Result<HashMap<RegionNumber, RegionMask>> {
340 self.split_record_batch(record_batch)
341 }
342}
343
/// Split points recorded for one partition column, keyed by the compared value.
type Axis = HashMap<Value, SplitPoint>;
346
/// Bookkeeping for one value that appears in comparisons on a partition column.
struct SplitPoint {
    /// Set once an `=`, `<=` or `>=` comparison mentions this value.
    is_equal: bool,
    /// Incremented by `<` / `<=` and decremented by `>` / `>=` comparisons on this value.
    less_than_counter: isize,
}
352
/// Validates the expressions of a [`MultiDimPartitionRule`].
///
/// Every comparison is recorded on the axis of its column. The check passes only when
/// each recorded value has a balanced `less_than_counter` and has `is_equal` set.
struct RuleChecker<'a> {
    /// One axis per partition column, in the order of `rule.partition_columns`.
    axis: Vec<Axis>,
    /// The rule being validated.
    rule: &'a MultiDimPartitionRule,
}
364
365impl<'a> RuleChecker<'a> {
366 pub fn new(rule: &'a MultiDimPartitionRule) -> Self {
367 let mut projections = Vec::with_capacity(rule.partition_columns.len());
368 projections.resize_with(rule.partition_columns.len(), Default::default);
369
370 Self {
371 axis: projections,
372 rule,
373 }
374 }
375
376 pub fn check(&mut self) -> Result<()> {
377 for expr in &self.rule.exprs {
378 self.walk_expr(expr)?
379 }
380
381 self.check_axis()
382 }
383
384 #[allow(clippy::mutable_key_type)]
385 fn walk_expr(&mut self, expr: &PartitionExpr) -> Result<()> {
386 match expr.op {
388 RestrictedOp::And | RestrictedOp::Or => {
389 match (expr.lhs.as_ref(), expr.rhs.as_ref()) {
390 (Operand::Expr(lhs), Operand::Expr(rhs)) => {
391 self.walk_expr(lhs)?;
392 self.walk_expr(rhs)?
393 }
394 _ => ConjunctExprWithNonExprSnafu { expr: expr.clone() }.fail()?,
395 }
396
397 return Ok(());
398 }
399 _ => {}
401 }
402
403 let (col, val) = match (expr.lhs.as_ref(), expr.rhs.as_ref()) {
404 (Operand::Expr(_), _)
405 | (_, Operand::Expr(_))
406 | (Operand::Column(_), Operand::Column(_))
407 | (Operand::Value(_), Operand::Value(_)) => {
408 InvalidExprSnafu { expr: expr.clone() }.fail()?
409 }
410
411 (Operand::Column(col), Operand::Value(val))
412 | (Operand::Value(val), Operand::Column(col)) => (col, val),
413 };
414
415 let col_index =
416 *self
417 .rule
418 .name_to_index
419 .get(col)
420 .with_context(|| UndefinedColumnSnafu {
421 column: col.clone(),
422 })?;
423 let axis = &mut self.axis[col_index];
424 let split_point = axis.entry(val.clone()).or_insert(SplitPoint {
425 is_equal: false,
426 less_than_counter: 0,
427 });
428 match expr.op {
429 RestrictedOp::Eq => {
430 split_point.is_equal = true;
431 }
432 RestrictedOp::NotEq => {
433 }
435 RestrictedOp::Lt => {
436 split_point.less_than_counter += 1;
437 }
438 RestrictedOp::LtEq => {
439 split_point.less_than_counter += 1;
440 split_point.is_equal = true;
441 }
442 RestrictedOp::Gt => {
443 split_point.less_than_counter -= 1;
444 }
445 RestrictedOp::GtEq => {
446 split_point.less_than_counter -= 1;
447 split_point.is_equal = true;
448 }
449 RestrictedOp::And | RestrictedOp::Or => {
450 unreachable!("conjunct expr should be handled above")
451 }
452 }
453
454 Ok(())
455 }
456
457 fn check_axis(&self) -> Result<()> {
459 for (col_index, axis) in self.axis.iter().enumerate() {
460 for (val, split_point) in axis {
461 if split_point.less_than_counter != 0 || !split_point.is_equal {
462 UnclosedValueSnafu {
463 value: format!("{val:?}"),
464 column: self.rule.partition_columns[col_index].clone(),
465 }
466 .fail()?;
467 }
468 }
469 }
470 Ok(())
471 }
472}
473
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use super::*;
    use crate::error::{self, Error};

    /// Builds a string literal value.
    fn str_value(text: &'static str) -> datatypes::value::Value {
        datatypes::value::Value::String(text.into())
    }

    /// Builds a 64-bit integer literal value.
    fn int_value(number: i64) -> datatypes::value::Value {
        datatypes::value::Value::Int64(number)
    }

    /// Builds the comparison `column <op> value`.
    fn compare(column: &str, op: RestrictedOp, value: datatypes::value::Value) -> PartitionExpr {
        PartitionExpr::new(
            Operand::Column(column.to_string()),
            op,
            Operand::Value(value),
        )
    }

    /// Builds `lhs AND rhs`.
    fn and(lhs: PartitionExpr, rhs: PartitionExpr) -> PartitionExpr {
        PartitionExpr::new(Operand::Expr(lhs), RestrictedOp::And, Operand::Expr(rhs))
    }

    /// Builds `lhs OR rhs`.
    fn or(lhs: PartitionExpr, rhs: PartitionExpr) -> PartitionExpr {
        PartitionExpr::new(Operand::Expr(lhs), RestrictedOp::Or, Operand::Expr(rhs))
    }

    #[test]
    fn test_find_region() {
        // Columns (b):
        //   region 1: b < 'hz'
        //   region 2: b >= 'hz' AND b < 'sh'
        //   region 3: b >= 'sh'
        let rule = MultiDimPartitionRule::try_new(
            vec!["b".to_string()],
            vec![1, 2, 3],
            vec![
                compare("b", RestrictedOp::Lt, str_value("hz")),
                and(
                    compare("b", RestrictedOp::GtEq, str_value("hz")),
                    compare("b", RestrictedOp::Lt, str_value("sh")),
                ),
                compare("b", RestrictedOp::GtEq, str_value("sh")),
            ],
        )
        .unwrap();

        // Two values for a single partition column must be rejected.
        assert_matches!(
            rule.find_region(&["foo".into(), 1000_i32.into()]),
            Err(error::Error::RegionKeysSize {
                expect: 1,
                actual: 2,
                ..
            })
        );

        let cases = [
            ("foo", 1),
            ("bar", 1),
            ("hz", 2),
            ("hzz", 2),
            ("sh", 3),
            ("zzzz", 3),
        ];
        for (input, expected_region) in cases {
            assert_eq!(
                rule.find_region(&[input.into()]).unwrap(),
                expected_region,
                "input: {input}"
            );
        }
    }

    #[test]
    fn invalid_expr_case_1() {
        // b <= (b >= 'hz' AND b < 'sh'): a comparison against an expression.
        let rule = MultiDimPartitionRule::try_new(
            vec!["a".to_string(), "b".to_string()],
            vec![1],
            vec![PartitionExpr::new(
                Operand::Column("b".to_string()),
                RestrictedOp::LtEq,
                Operand::Expr(and(
                    compare("b", RestrictedOp::GtEq, str_value("hz")),
                    compare("b", RestrictedOp::Lt, str_value("sh")),
                )),
            )],
        );

        assert_matches!(rule.unwrap_err(), Error::InvalidExpr { .. });
    }

    #[test]
    fn invalid_expr_case_2() {
        // (b >= 'hz') AND 'sh': a conjunction with a plain value on one side.
        let rule = MultiDimPartitionRule::try_new(
            vec!["a".to_string(), "b".to_string()],
            vec![1],
            vec![PartitionExpr::new(
                Operand::Expr(compare("b", RestrictedOp::GtEq, str_value("hz"))),
                RestrictedOp::And,
                Operand::Value(str_value("sh")),
            )],
        );

        assert_matches!(rule.unwrap_err(), Error::ConjunctExprWithNonExpr { .. });
    }

    #[test]
    fn empty_expr_case_1() {
        // b <= 'h', b >= 's': neither split point is closed by a counterpart.
        let rule = MultiDimPartitionRule::try_new(
            vec!["a".to_string(), "b".to_string()],
            vec![1, 2],
            vec![
                compare("b", RestrictedOp::LtEq, str_value("h")),
                compare("b", RestrictedOp::GtEq, str_value("s")),
            ],
        );

        assert_matches!(rule.unwrap_err(), Error::UnclosedValue { .. });
    }

    #[test]
    fn empty_expr_case_2() {
        // Region 1:
        //      (a >= 100 AND b <= 10)
        //   OR ((a > 100 AND a <= 200) AND b <= 10)
        //   OR ((a >= 200 AND b > 10) AND b <= 20)
        //   OR (a > 200 AND b <= 20)
        let first = or(
            or(
                or(
                    and(
                        compare("a", RestrictedOp::GtEq, int_value(100)),
                        compare("b", RestrictedOp::LtEq, int_value(10)),
                    ),
                    and(
                        and(
                            compare("a", RestrictedOp::Gt, int_value(100)),
                            compare("a", RestrictedOp::LtEq, int_value(200)),
                        ),
                        compare("b", RestrictedOp::LtEq, int_value(10)),
                    ),
                ),
                and(
                    and(
                        compare("a", RestrictedOp::GtEq, int_value(200)),
                        compare("b", RestrictedOp::Gt, int_value(10)),
                    ),
                    compare("b", RestrictedOp::LtEq, int_value(20)),
                ),
            ),
            and(
                compare("a", RestrictedOp::Gt, int_value(200)),
                compare("b", RestrictedOp::LtEq, int_value(20)),
            ),
        );

        // Region 2:
        //      (a < 100 AND b <= 20)
        //   OR (a >= 100 AND b >= 20)
        let second = or(
            and(
                compare("a", RestrictedOp::Lt, int_value(100)),
                compare("b", RestrictedOp::LtEq, int_value(20)),
            ),
            and(
                compare("a", RestrictedOp::GtEq, int_value(100)),
                compare("b", RestrictedOp::GtEq, int_value(20)),
            ),
        );

        let rule = MultiDimPartitionRule::try_new(
            vec!["a".to_string(), "b".to_string()],
            vec![1, 2],
            vec![first, second],
        );

        assert_matches!(rule.unwrap_err(), Error::UnclosedValue { .. });
    }

    #[test]
    fn duplicate_expr_case_1() {
        // a <= 20, a >= 10: the two ranges use different split points.
        let rule = MultiDimPartitionRule::try_new(
            vec!["a".to_string(), "b".to_string()],
            vec![1, 2],
            vec![
                compare("a", RestrictedOp::LtEq, int_value(20)),
                compare("a", RestrictedOp::GtEq, int_value(10)),
            ],
        );

        assert_matches!(rule.unwrap_err(), Error::UnclosedValue { .. });
    }

    #[test]
    #[ignore = "checker cannot detect this kind of duplicate for now"]
    fn duplicate_expr_case_2() {
        // a != 20, a <= 20, a > 20
        let rule = MultiDimPartitionRule::try_new(
            vec!["a".to_string(), "b".to_string()],
            vec![1, 2],
            vec![
                compare("a", RestrictedOp::NotEq, int_value(20)),
                compare("a", RestrictedOp::LtEq, int_value(20)),
                compare("a", RestrictedOp::Gt, int_value(20)),
            ],
        );

        assert!(rule.is_err());
    }
}
827
#[cfg(test)]
mod test_split_record_batch {
    use std::sync::Arc;

    use datatypes::arrow::array::{Int64Array, StringArray};
    use datatypes::arrow::datatypes::{DataType, Field, Schema};
    use datatypes::arrow::record_batch::RecordBatch;
    use rand::Rng;

    use super::*;
    use crate::expr::col;

    /// Schema shared by all tests: a string `host` column and an integer `value` column.
    fn test_schema() -> Arc<Schema> {
        let fields = vec![
            Field::new("host", DataType::Utf8, false),
            Field::new("value", DataType::Int64, false),
        ];
        Arc::new(Schema::new(fields))
    }

    /// Builds a batch over [`test_schema`] from explicit column contents.
    fn batch_of(hosts: Vec<&str>, values: Vec<i64>) -> RecordBatch {
        RecordBatch::try_new(
            test_schema(),
            vec![
                Arc::new(StringArray::from(hosts)),
                Arc::new(Int64Array::from(values)),
            ],
        )
        .unwrap()
    }

    /// Builds a batch of `num_rows` rows with hosts `server0..server19` and values `0..19`.
    fn generate_random_record_batch(num_rows: usize) -> RecordBatch {
        let mut rng = rand::thread_rng();
        let (hosts, values): (Vec<String>, Vec<i64>) = (0..num_rows)
            .map(|_| {
                let host = format!("server{}", rng.gen_range(0..20));
                let value = rng.gen_range(0..20);
                (host, value)
            })
            .unzip();

        RecordBatch::try_new(
            test_schema(),
            vec![
                Arc::new(StringArray::from(hosts)),
                Arc::new(Int64Array::from(values)),
            ],
        )
        .unwrap()
    }

    #[test]
    fn test_split_record_batch_by_one_column() {
        // host < 'server1' -> region 0, host >= 'server1' -> region 1
        let rule = MultiDimPartitionRule::try_new(
            vec!["host".to_string(), "value".to_string()],
            vec![0, 1],
            vec![
                col("host").lt(Value::String("server1".into())),
                col("host").gt_eq(Value::String("server1".into())),
            ],
        )
        .unwrap();

        let batch = generate_random_record_batch(1000);
        let actual = rule.split_record_batch(&batch).unwrap();
        let expected = rule.split_record_batch_naive(&batch).unwrap();

        // The vectorized split must agree with the row-by-row reference.
        assert_eq!(actual.len(), expected.len());
        for (region, mask) in &actual {
            assert_eq!(
                mask.array(),
                &expected[region],
                "failed on region: {}",
                region
            );
        }
    }

    #[test]
    fn test_split_record_batch_empty() {
        // host = 'server1' -> region 1 (single region)
        let rule = MultiDimPartitionRule::try_new(
            vec!["host".to_string()],
            vec![1],
            vec![PartitionExpr::new(
                Operand::Column("host".to_string()),
                RestrictedOp::Eq,
                Operand::Value(Value::String("server1".into())),
            )],
        )
        .unwrap();

        let batch = batch_of(Vec::new(), Vec::new());

        let result = rule.split_record_batch(&batch).unwrap();
        assert_eq!(result.len(), 1);
    }

    #[test]
    fn test_split_record_batch_by_two_columns() {
        // A 2x2 grid split on host at 'server10' and on value at 10.
        let rule = MultiDimPartitionRule::try_new(
            vec!["host".to_string(), "value".to_string()],
            vec![0, 1, 2, 3],
            vec![
                col("host")
                    .lt(Value::String("server10".into()))
                    .and(col("value").lt(Value::Int64(10))),
                col("host")
                    .lt(Value::String("server10".into()))
                    .and(col("value").gt_eq(Value::Int64(10))),
                col("host")
                    .gt_eq(Value::String("server10".into()))
                    .and(col("value").lt(Value::Int64(10))),
                col("host")
                    .gt_eq(Value::String("server10".into()))
                    .and(col("value").gt_eq(Value::Int64(10))),
            ],
        )
        .unwrap();

        let batch = generate_random_record_batch(1000);
        let actual = rule.split_record_batch(&batch).unwrap();
        let expected = rule.split_record_batch_naive(&batch).unwrap();

        assert_eq!(actual.len(), expected.len());
        for (region, mask) in &actual {
            assert_eq!(mask.array(), &expected[region]);
        }
    }

    #[test]
    fn test_default_region() {
        // Only value = 10 and value = 20 are covered on each side of 'server10'.
        let rule = MultiDimPartitionRule::try_new(
            vec!["host".to_string(), "value".to_string()],
            vec![0, 1, 2, 3],
            vec![
                col("host")
                    .lt(Value::String("server10".into()))
                    .and(col("value").eq(Value::Int64(10))),
                col("host")
                    .lt(Value::String("server10".into()))
                    .and(col("value").eq(Value::Int64(20))),
                col("host")
                    .gt_eq(Value::String("server10".into()))
                    .and(col("value").eq(Value::Int64(10))),
                col("host")
                    .gt_eq(Value::String("server10".into()))
                    .and(col("value").eq(Value::Int64(20))),
            ],
        )
        .unwrap();

        let batch = batch_of(
            vec!["server1", "server1", "server1", "server100"],
            vec![10, 20, 30, 10],
        );

        let actual = rule.split_record_batch(&batch).unwrap();
        let expected = rule.split_record_batch_naive(&batch).unwrap();
        for (region, mask) in &actual {
            assert_eq!(mask.array(), &expected[region]);
        }
    }

    #[test]
    fn test_default_region_with_unselected_rows() {
        // value = 10 / 20 / 30 -> regions 1 / 2 / 3; the default region is not listed.
        let rule = MultiDimPartitionRule::try_new(
            vec!["host".to_string(), "value".to_string()],
            vec![1, 2, 3],
            vec![
                col("value").eq(Value::Int64(10)),
                col("value").eq(Value::Int64(20)),
                col("value").eq(Value::Int64(30)),
            ],
        )
        .unwrap();

        let batch = batch_of(
            vec!["server1", "server2", "server3", "server4", "server5"],
            vec![10, 20, 30, 40, 50],
        );

        let result = rule.split_record_batch(&batch).unwrap();

        // Three listed regions plus the default region.
        assert_eq!(result.len(), 4);

        // Values 40 and 50 match nothing and fall into the default region.
        assert!(result.contains_key(&DEFAULT_REGION));
        assert_eq!(result[&DEFAULT_REGION].selected_rows(), 2);

        for region in [1, 2, 3] {
            assert_eq!(result[&region].selected_rows(), 1);
        }
    }

    #[test]
    fn test_default_region_with_existing_default() {
        // value = 10 / 20 / 30 -> regions 0 / 1 / 2; region 0 is also the default region.
        let rule = MultiDimPartitionRule::try_new(
            vec!["host".to_string(), "value".to_string()],
            vec![0, 1, 2],
            vec![
                col("value").eq(Value::Int64(10)),
                col("value").eq(Value::Int64(20)),
                col("value").eq(Value::Int64(30)),
            ],
        )
        .unwrap();

        let batch = batch_of(
            vec!["server1", "server2", "server3", "server4", "server5"],
            vec![10, 20, 30, 40, 50],
        );

        let result = rule.split_record_batch(&batch).unwrap();

        assert_eq!(result.len(), 3);

        // Region 0 holds its own row (value 10) plus the two unmatched rows (40 and 50).
        assert!(result.contains_key(&DEFAULT_REGION));
        assert_eq!(result[&DEFAULT_REGION].selected_rows(), 3);

        for region in [1, 2] {
            assert_eq!(result[&region].selected_rows(), 1);
        }
    }

    #[test]
    fn test_all_rows_selected() {
        // value < 30 -> region 1, value >= 30 -> region 2: every row is covered.
        let rule = MultiDimPartitionRule::try_new(
            vec!["value".to_string()],
            vec![1, 2],
            vec![
                col("value").lt(Value::Int64(30)),
                col("value").gt_eq(Value::Int64(30)),
            ],
        )
        .unwrap();

        let batch = batch_of(
            vec!["server1", "server2", "server3", "server4"],
            vec![10, 20, 30, 40],
        );

        let result = rule.split_record_batch(&batch).unwrap();

        // No default region is added when nothing is left over.
        assert_eq!(result.len(), 2);
        for region in [1, 2] {
            assert!(result.contains_key(&region));
            assert_eq!(result[&region].selected_rows(), 2);
        }
    }
}