1use std::sync::Arc;
18
19use datafusion::error::Result as DfResult;
20use datafusion::logical_expr::{Expr, Literal, Operator};
21use datafusion::physical_plan::PhysicalExpr;
22use datafusion_common::arrow::array::{ArrayRef, Datum, Scalar};
23use datafusion_common::arrow::buffer::BooleanBuffer;
24use datafusion_common::arrow::compute::kernels::cmp;
25use datafusion_common::cast::{as_boolean_array, as_null_array, as_string_array};
26use datafusion_common::{DataFusionError, ScalarValue, internal_err};
27use datatypes::arrow::array::{
28 Array, ArrayAccessor, ArrayData, BooleanArray, BooleanBufferBuilder, DictionaryArray,
29 RecordBatch, StringArrayType,
30};
31use datatypes::arrow::compute::filter_record_batch;
32use datatypes::arrow::datatypes::{DataType, UInt32Type};
33use datatypes::arrow::error::ArrowError;
34use datatypes::compute::or_kleene;
35use datatypes::value::Value;
36use datatypes::vectors::VectorRef;
37use regex::Regex;
38use snafu::ResultExt;
39
40use crate::error::{ArrowComputeSnafu, Result, ToArrowScalarSnafu, UnsupportedOperationSnafu};
41
/// A filter that can be evaluated directly against Arrow arrays, vectors or
/// scalars without going through a DataFusion physical plan.
///
/// It represents either `column <op> literal` (comparison or regex operator) or
/// an OR-chain of equalities on one column (`column = a OR column = b OR ...`,
/// stored with `op == Operator::Or`).
#[derive(Debug)]
pub struct SimpleFilterEvaluator {
    /// Name of the column the filter applies to.
    column_name: String,
    /// The literal operand. For an OR-chain this only holds a placeholder (the
    /// left-most equality's literal); the operands actually compared are in
    /// `literal_list`.
    literal: Scalar<ArrayRef>,
    /// The operator. `Operator::Or` marks an OR-chain of equalities.
    op: Operator,
    /// Literals of an OR-chain of equalities; empty for every other operator.
    literal_list: Vec<Scalar<ArrayRef>>,
    /// Compiled pattern for regex operators. `None` for non-regex operators and
    /// when the built pattern string is empty; regex evaluation then treats
    /// every row as matching.
    regex: Option<Regex>,
    /// Whether the regex result is negated (`!~` / `!~*`).
    regex_negative: bool,
}
68
69impl SimpleFilterEvaluator {
70 pub fn new<T: Literal>(column_name: String, lit: T, op: Operator) -> Option<Self> {
71 match op {
72 Operator::Eq
73 | Operator::NotEq
74 | Operator::Lt
75 | Operator::LtEq
76 | Operator::Gt
77 | Operator::GtEq => {}
78 _ => return None,
79 }
80
81 let Expr::Literal(val, _) = lit.lit() else {
82 return None;
83 };
84
85 Some(Self {
86 column_name,
87 literal: val.to_scalar().ok()?,
88 op,
89 literal_list: vec![],
90 regex: None,
91 regex_negative: false,
92 })
93 }
94
95 pub fn try_new(predicate: &Expr) -> Option<Self> {
96 match predicate {
97 Expr::BinaryExpr(binary) => {
98 match binary.op {
100 Operator::Eq
101 | Operator::NotEq
102 | Operator::Lt
103 | Operator::LtEq
104 | Operator::Gt
105 | Operator::GtEq
106 | Operator::RegexMatch
107 | Operator::RegexIMatch
108 | Operator::RegexNotMatch
109 | Operator::RegexNotIMatch => {}
110 Operator::Or => {
111 let lhs = Self::try_new(&binary.left)?;
112 let rhs = Self::try_new(&binary.right)?;
113 if lhs.column_name != rhs.column_name
114 || !matches!(lhs.op, Operator::Eq | Operator::Or)
115 || !matches!(rhs.op, Operator::Eq | Operator::Or)
116 {
117 return None;
118 }
119 let mut list = vec![];
120 let placeholder_literal = lhs.literal.clone();
121 if matches!(lhs.op, Operator::Or) {
123 list.extend(lhs.literal_list);
124 } else {
125 list.push(lhs.literal);
126 }
127 if matches!(rhs.op, Operator::Or) {
128 list.extend(rhs.literal_list);
129 } else {
130 list.push(rhs.literal);
131 }
132 return Some(Self {
133 column_name: lhs.column_name,
134 literal: placeholder_literal,
135 op: Operator::Or,
136 literal_list: list,
137 regex: None,
138 regex_negative: false,
139 });
140 }
141 _ => return None,
142 }
143
144 let mut op = binary.op;
146 let (lhs, rhs) = match (&*binary.left, &*binary.right) {
147 (Expr::Column(col), Expr::Literal(lit, _)) => (col, lit),
148 (Expr::Literal(lit, _), Expr::Column(col)) => {
149 op = op.swap().unwrap();
151 (col, lit)
152 }
153 _ => return None,
154 };
155
156 let (regex, regex_negative) = Self::maybe_build_regex(op, rhs).ok()?;
157 let literal = rhs.to_scalar().ok()?;
158 Some(Self {
159 column_name: lhs.name.clone(),
160 literal,
161 op,
162 literal_list: vec![],
163 regex,
164 regex_negative,
165 })
166 }
167 _ => None,
168 }
169 }
170
171 pub fn column_name(&self) -> &str {
173 &self.column_name
174 }
175
176 pub fn is_eq(&self) -> bool {
177 matches!(self.op, Operator::Eq)
178 }
179
180 pub fn is_not_eq(&self) -> bool {
181 matches!(self.op, Operator::NotEq)
182 }
183
184 pub fn is_lt(&self) -> bool {
185 matches!(self.op, Operator::Lt)
186 }
187
188 pub fn is_lt_eq(&self) -> bool {
189 matches!(self.op, Operator::LtEq)
190 }
191
192 pub fn is_gt(&self) -> bool {
193 matches!(self.op, Operator::Gt)
194 }
195
196 pub fn is_gt_eq(&self) -> bool {
197 matches!(self.op, Operator::GtEq)
198 }
199
200 pub fn is_or_eq_chain(&self) -> bool {
203 matches!(self.op, Operator::Or)
204 }
205
206 pub fn literal_value(&self) -> Option<Value> {
208 let array = self.literal.get().0;
209 let scalar = ScalarValue::try_from_array(array, 0).ok()?;
210 Value::try_from(scalar).ok()
211 }
212
213 pub fn literal_list_values(&self) -> Option<Vec<Value>> {
216 self.literal_list
217 .iter()
218 .map(|scalar| {
219 let array = scalar.get().0;
220 let scalar = ScalarValue::try_from_array(array, 0).ok()?;
221 Value::try_from(scalar).ok()
222 })
223 .collect()
224 }
225
226 pub fn evaluate_scalar(&self, input: &ScalarValue) -> Result<bool> {
227 let input = input
228 .to_scalar()
229 .with_context(|_| ToArrowScalarSnafu { v: input.clone() })?;
230 let result = self.evaluate_datum(&input, 1)?;
231 Ok(result.value(0))
232 }
233
234 pub fn evaluate_array(&self, input: &ArrayRef) -> Result<BooleanBuffer> {
235 self.evaluate_datum(input, input.len())
236 }
237
238 pub fn evaluate_vector(&self, input: &VectorRef) -> Result<BooleanBuffer> {
239 self.evaluate_datum(&input.to_arrow_array(), input.len())
240 }
241
242 fn evaluate_datum(&self, input: &impl Datum, input_len: usize) -> Result<BooleanBuffer> {
243 let result = match self.op {
244 Operator::Eq => cmp::eq(input, &self.literal),
245 Operator::NotEq => cmp::neq(input, &self.literal),
246 Operator::Lt => cmp::lt(input, &self.literal),
247 Operator::LtEq => cmp::lt_eq(input, &self.literal),
248 Operator::Gt => cmp::gt(input, &self.literal),
249 Operator::GtEq => cmp::gt_eq(input, &self.literal),
250 Operator::RegexMatch => self.regex_match(input),
251 Operator::RegexIMatch => self.regex_match(input),
252 Operator::RegexNotMatch => self.regex_match(input),
253 Operator::RegexNotIMatch => self.regex_match(input),
254 Operator::Or => {
255 let mut result: BooleanArray = vec![false; input_len].into();
257 for literal in &self.literal_list {
258 let rhs = cmp::eq(input, literal).context(ArrowComputeSnafu)?;
259 result = or_kleene(&result, &rhs).context(ArrowComputeSnafu)?;
260 }
261 Ok(result)
262 }
263 _ => {
264 return UnsupportedOperationSnafu {
265 reason: format!("{:?}", self.op),
266 }
267 .fail();
268 }
269 };
270 result
271 .context(ArrowComputeSnafu)
272 .map(|array| array.values().clone())
273 }
274
275 fn maybe_build_regex(
286 operator: Operator,
287 value: &ScalarValue,
288 ) -> Result<(Option<Regex>, bool), ArrowError> {
289 let (ignore_case, negative) = match operator {
290 Operator::RegexMatch => (false, false),
291 Operator::RegexIMatch => (true, false),
292 Operator::RegexNotMatch => (false, true),
293 Operator::RegexNotIMatch => (true, true),
294 _ => return Ok((None, false)),
295 };
296 let flag = if ignore_case { Some("i") } else { None };
297 let regex = value
298 .try_as_str()
299 .ok_or_else(|| ArrowError::CastError(format!("Cannot cast {:?} to str", value)))?
300 .ok_or_else(|| ArrowError::CastError("Regex should not be null".to_string()))?;
301 let pattern = match flag {
302 Some(flag) => format!("(?{flag}){regex}"),
303 None => regex.to_string(),
304 };
305 if pattern.is_empty() {
306 Ok((None, negative))
307 } else {
308 Regex::new(pattern.as_str())
309 .map_err(|e| {
310 ArrowError::ComputeError(format!("Regular expression did not compile: {e:?}"))
311 })
312 .map(|regex| (Some(regex), negative))
313 }
314 }
315
316 fn regex_match(&self, input: &impl Datum) -> std::result::Result<BooleanArray, ArrowError> {
317 let array = input.get().0;
318
319 if let Ok(string_array) = as_string_array(array) {
321 let mut result = regexp_is_match_scalar(string_array, self.regex.as_ref())?;
322 if self.regex_negative {
323 result = datatypes::compute::not(&result)?;
324 }
325 return Ok(result);
326 }
327
328 if let Some(dict_array) = array.as_any().downcast_ref::<DictionaryArray<UInt32Type>>() {
330 let mut result = regexp_is_match_dictionary(dict_array, self.regex.as_ref())?;
331 if self.regex_negative {
332 result = datatypes::compute::not(&result)?;
333 }
334 return Ok(result);
335 }
336
337 Err(ArrowError::CastError(format!(
338 "Cannot cast {:?} to StringArray or StringDictionaryArray",
339 array.data_type()
340 )))
341 }
342}
343
344pub fn batch_filter(
347 batch: &RecordBatch,
348 predicate: &Arc<dyn PhysicalExpr>,
349) -> DfResult<RecordBatch> {
350 predicate
351 .evaluate(batch)
352 .and_then(|v| v.into_array(batch.num_rows()))
353 .and_then(|array| {
354 let filter_array = match as_boolean_array(&array) {
355 Ok(boolean_array) => Ok(boolean_array.clone()),
356 Err(_) => {
357 let Ok(null_array) = as_null_array(&array) else {
358 return internal_err!(
359 "Cannot create filter_array from non-boolean predicates"
360 );
361 };
362
363 Ok::<BooleanArray, DataFusionError>(BooleanArray::new_null(null_array.len()))
365 }
366 }?;
367 Ok(filter_record_batch(batch, &filter_array)?)
368 })
369}
370
371pub fn regexp_is_match_scalar<'a, S>(
375 array: &'a S,
376 regex: Option<&Regex>,
377) -> Result<BooleanArray, ArrowError>
378where
379 &'a S: StringArrayType<'a>,
380{
381 let null_bit_buffer = array.nulls().map(|x| x.inner().sliced());
382 let mut result = BooleanBufferBuilder::new(array.len());
383
384 if let Some(re) = regex {
385 for i in 0..array.len() {
386 let value = array.value(i);
387 result.append(re.is_match(value));
388 }
389 } else {
390 result.append_n(array.len(), true);
391 }
392
393 let buffer = result.into();
394 let data = unsafe {
395 ArrayData::new_unchecked(
396 DataType::Boolean,
397 array.len(),
398 None,
399 null_bit_buffer,
400 0,
401 vec![buffer],
402 vec![],
403 )
404 };
405
406 Ok(BooleanArray::from(data))
407}
408
409pub fn regexp_is_match_dictionary(
412 dict_array: &DictionaryArray<UInt32Type>,
413 regex: Option<&Regex>,
414) -> Result<BooleanArray, ArrowError> {
415 let string_values = dict_array
417 .values()
418 .as_any()
419 .downcast_ref::<datatypes::arrow::array::StringArray>()
420 .ok_or_else(|| {
421 ArrowError::CastError("Dictionary values must be StringArray".to_string())
422 })?;
423
424 let null_bit_buffer = dict_array.nulls().map(|x| x.inner().sliced());
425 let mut result = BooleanBufferBuilder::new(dict_array.len());
426
427 if let Some(re) = regex {
428 let keys = dict_array.keys().values();
429 for i in 0..dict_array.len() {
430 if dict_array.is_null(i) {
431 result.append(false);
432 } else {
433 let key = keys[i] as usize;
434 let string_value = string_values.value(key);
435 result.append(re.is_match(string_value));
436 }
437 }
438 } else {
439 result.append_n(dict_array.len(), true);
440 }
441
442 let buffer = result.into();
443 let data = unsafe {
444 ArrayData::new_unchecked(
445 DataType::Boolean,
446 dict_array.len(),
447 None,
448 null_bit_buffer,
449 0,
450 vec![buffer],
451 vec![],
452 )
453 };
454
455 Ok(BooleanArray::from(data))
456}
457
#[cfg(test)]
mod test {

    use std::sync::Arc;

    use datafusion::execution::context::ExecutionProps;
    use datafusion::logical_expr::{BinaryExpr, col, lit};
    use datafusion::physical_expr::create_physical_expr;
    use datafusion_common::{Column, DFSchema};
    use datatypes::arrow::datatypes::{DataType, Field, Schema};

    use super::*;

    #[test]
    fn unsupported_filter_op() {
        // Arithmetic operators are not filters.
        let expr = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(Expr::Column(Column::from_name("foo"))),
            op: Operator::Plus,
            right: Box::new(1.lit()),
        });
        assert!(SimpleFilterEvaluator::try_new(&expr).is_none());

        // Both operands are literals: there is no column to filter on.
        let expr = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(1.lit()),
            op: Operator::Eq,
            right: Box::new(1.lit()),
        });
        assert!(SimpleFilterEvaluator::try_new(&expr).is_none());

        // Both operands are columns: there is no literal to compare against.
        let expr = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(Expr::Column(Column::from_name("foo"))),
            op: Operator::Eq,
            right: Box::new(Expr::Column(Column::from_name("bar"))),
        });
        assert!(SimpleFilterEvaluator::try_new(&expr).is_none());

        // A nested comparison is not a plain column operand.
        let expr = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(Expr::BinaryExpr(BinaryExpr {
                left: Box::new(Expr::Column(Column::from_name("foo"))),
                op: Operator::Eq,
                right: Box::new(1.lit()),
            })),
            op: Operator::Eq,
            right: Box::new(1.lit()),
        });
        assert!(SimpleFilterEvaluator::try_new(&expr).is_none());
    }

    #[test]
    fn supported_filter_op() {
        // col = lit
        let expr = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(Expr::Column(Column::from_name("foo"))),
            op: Operator::Eq,
            right: Box::new(1.lit()),
        });
        let _ = SimpleFilterEvaluator::try_new(&expr).unwrap();

        // lit < col is normalized to col > lit
        let expr = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(1.lit()),
            op: Operator::Lt,
            right: Box::new(Expr::Column(Column::from_name("foo"))),
        });
        let evaluator = SimpleFilterEvaluator::try_new(&expr).unwrap();
        assert_eq!(evaluator.op, Operator::Gt);
        assert_eq!(evaluator.column_name, "foo".to_string());
    }

    #[test]
    fn run_on_array() {
        let expr = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(Expr::Column(Column::from_name("foo"))),
            op: Operator::Eq,
            right: Box::new(1i64.lit()),
        });
        let evaluator = SimpleFilterEvaluator::try_new(&expr).unwrap();

        let input_1 = Arc::new(datatypes::arrow::array::Int64Array::from(vec![1, 2, 3])) as _;
        let result = evaluator.evaluate_array(&input_1).unwrap();
        assert_eq!(result, BooleanBuffer::from(vec![true, false, false]));

        let input_2 = Arc::new(datatypes::arrow::array::Int64Array::from(vec![1, 1, 1])) as _;
        let result = evaluator.evaluate_array(&input_2).unwrap();
        assert_eq!(result, BooleanBuffer::from(vec![true, true, true]));

        let input_3 = Arc::new(datatypes::arrow::array::Int64Array::new_null(0)) as _;
        let result = evaluator.evaluate_array(&input_3).unwrap();
        assert_eq!(result, BooleanBuffer::from(vec![]));
    }

    #[test]
    fn run_on_scalar() {
        let expr = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(Expr::Column(Column::from_name("foo"))),
            op: Operator::Lt,
            right: Box::new(1i64.lit()),
        });
        let evaluator = SimpleFilterEvaluator::try_new(&expr).unwrap();

        let input_1 = ScalarValue::Int64(Some(1));
        let result = evaluator.evaluate_scalar(&input_1).unwrap();
        assert!(!result);

        let input_2 = ScalarValue::Int64(Some(0));
        let result = evaluator.evaluate_scalar(&input_2).unwrap();
        assert!(result);

        // A null input never satisfies the filter.
        let input_3 = ScalarValue::Int64(None);
        let result = evaluator.evaluate_scalar(&input_3).unwrap();
        assert!(!result);
    }

    #[test]
    fn batch_filter_test() {
        let expr = col("ts").gt(lit(123456u64));
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("ts", DataType::UInt64, false),
        ]);
        let df_schema = DFSchema::try_from(schema.clone()).unwrap();
        let props = ExecutionProps::new();
        let physical_expr = create_physical_expr(&expr, &df_schema, &props).unwrap();
        let batch = RecordBatch::try_new(
            Arc::new(schema),
            vec![
                Arc::new(datatypes::arrow::array::Int32Array::from(vec![4, 5, 6])),
                Arc::new(datatypes::arrow::array::UInt64Array::from(vec![
                    123456, 123457, 123458,
                ])),
            ],
        )
        .unwrap();
        let new_batch = batch_filter(&batch, &physical_expr).unwrap();
        assert_eq!(new_batch.num_rows(), 2);
        let first_column_values = new_batch
            .column(0)
            .as_any()
            .downcast_ref::<datatypes::arrow::array::Int32Array>()
            .unwrap();
        let expected = datatypes::arrow::array::Int32Array::from(vec![5, 6]);
        assert_eq!(first_column_values, &expected);
    }

    #[test]
    fn test_complex_filter_expression() {
        // col = 'B' OR col = 'C' OR col = 'D'
        let col_eq_b = col("col").eq(lit("B"));
        let col_eq_c = col("col").eq(lit("C"));
        let col_eq_d = col("col").eq(lit("D"));

        let col_or_expr = col_eq_b.or(col_eq_c).or(col_eq_d);

        // The nested ORs flatten into a single OR-chain with three literals.
        let or_evaluator = SimpleFilterEvaluator::try_new(&col_or_expr).unwrap();
        assert_eq!(or_evaluator.column_name, "col");
        assert_eq!(or_evaluator.op, Operator::Or);
        assert_eq!(or_evaluator.literal_list.len(), 3);
        assert_eq!(
            format!("{:?}", or_evaluator.literal_list),
            "[Scalar(StringArray\n[\n  \"B\",\n]), Scalar(StringArray\n[\n  \"C\",\n]), Scalar(StringArray\n[\n  \"D\",\n])]"
        );

        // The same expression evaluated through DataFusion keeps B, C and D rows.
        let schema = Schema::new(vec![Field::new("col", DataType::Utf8, false)]);
        let df_schema = DFSchema::try_from(schema.clone()).unwrap();
        let props = ExecutionProps::new();
        let physical_expr = create_physical_expr(&col_or_expr, &df_schema, &props).unwrap();

        let col_data = Arc::new(datatypes::arrow::array::StringArray::from(vec![
            "B", "C", "E", "B", "C", "D", "F",
        ]));
        let batch = RecordBatch::try_new(Arc::new(schema), vec![col_data]).unwrap();
        let expected = datatypes::arrow::array::StringArray::from(vec!["B", "C", "B", "C", "D"]);

        let filtered_batch = batch_filter(&batch, &physical_expr).unwrap();

        assert_eq!(filtered_batch.num_rows(), 5);

        let col_filtered = filtered_batch
            .column(0)
            .as_any()
            .downcast_ref::<datatypes::arrow::array::StringArray>()
            .unwrap();
        assert_eq!(col_filtered, &expected);
    }

    #[test]
    fn test_maybe_build_regex() {
        // Case-sensitive match.
        let (regex, negative) = SimpleFilterEvaluator::maybe_build_regex(
            Operator::RegexMatch,
            &ScalarValue::Utf8(Some("a.*b".to_string())),
        )
        .unwrap();
        assert!(regex.is_some());
        assert!(!negative);
        assert!(regex.unwrap().is_match("axxb"));

        // Case-insensitive match.
        let (regex, negative) = SimpleFilterEvaluator::maybe_build_regex(
            Operator::RegexIMatch,
            &ScalarValue::Utf8(Some("a.*b".to_string())),
        )
        .unwrap();
        assert!(regex.is_some());
        assert!(!negative);
        assert!(regex.unwrap().is_match("AxxB"));

        // Negated, case-sensitive.
        let (regex, negative) = SimpleFilterEvaluator::maybe_build_regex(
            Operator::RegexNotMatch,
            &ScalarValue::Utf8(Some("a.*b".to_string())),
        )
        .unwrap();
        assert!(regex.is_some());
        assert!(negative);

        // Negated, case-insensitive.
        let (regex, negative) = SimpleFilterEvaluator::maybe_build_regex(
            Operator::RegexNotIMatch,
            &ScalarValue::Utf8(Some("a.*b".to_string())),
        )
        .unwrap();
        assert!(regex.is_some());
        assert!(negative);

        // Empty pattern yields no regex.
        let (regex, negative) = SimpleFilterEvaluator::maybe_build_regex(
            Operator::RegexMatch,
            &ScalarValue::Utf8(Some("".to_string())),
        )
        .unwrap();
        assert!(regex.is_none());
        assert!(!negative);

        // Non-regex operator yields no regex.
        let (regex, negative) = SimpleFilterEvaluator::maybe_build_regex(
            Operator::Eq,
            &ScalarValue::Utf8(Some("a.*b".to_string())),
        )
        .unwrap();
        assert!(regex.is_none());
        assert!(!negative);

        // Invalid pattern.
        let result = SimpleFilterEvaluator::maybe_build_regex(
            Operator::RegexMatch,
            &ScalarValue::Utf8(Some("a(b".to_string())),
        );
        assert!(result.is_err());

        // Non-string pattern.
        let result = SimpleFilterEvaluator::maybe_build_regex(
            Operator::RegexMatch,
            &ScalarValue::Int64(Some(123)),
        );
        assert!(result.is_err());

        // Null pattern.
        let result = SimpleFilterEvaluator::maybe_build_regex(
            Operator::RegexMatch,
            &ScalarValue::Utf8(None),
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_regex_match_dictionary_array() {
        use datatypes::arrow::array::StringDictionaryBuilder;

        // Rows: apple, banana, apple, cherry.
        let mut builder = StringDictionaryBuilder::<UInt32Type>::new();
        builder.append("apple").unwrap();
        builder.append("banana").unwrap();
        builder.append("apple").unwrap();
        builder.append("cherry").unwrap();
        let dict_array = builder.finish();

        let regex = regex::Regex::new(r"app.*").unwrap();
        let result = regexp_is_match_dictionary(&dict_array, Some(&regex)).unwrap();

        assert_eq!(result.len(), 4);
        assert!(result.value(0)); // apple
        assert!(!result.value(1)); // banana
        assert!(result.value(2)); // apple
        assert!(!result.value(3)); // cherry

        let regex2 = regex::Regex::new(r"ban.*").unwrap();
        let result2 = regexp_is_match_dictionary(&dict_array, Some(&regex2)).unwrap();

        assert!(!result2.value(0)); // apple
        assert!(result2.value(1)); // banana
        assert!(!result2.value(2)); // apple
        assert!(!result2.value(3)); // cherry

        // Without a regex every row matches.
        let result3 = regexp_is_match_dictionary(&dict_array, None).unwrap();
        assert!(result3.value(0));
        assert!(result3.value(1));
        assert!(result3.value(2));
        assert!(result3.value(3));
    }
}