1use std::sync::Arc;
18
19use datafusion::error::Result as DfResult;
20use datafusion::logical_expr::{Expr, Literal, Operator};
21use datafusion::physical_plan::PhysicalExpr;
22use datafusion_common::arrow::array::{ArrayRef, Datum, Scalar};
23use datafusion_common::arrow::buffer::BooleanBuffer;
24use datafusion_common::arrow::compute::kernels::cmp;
25use datafusion_common::cast::{as_boolean_array, as_null_array, as_string_array};
26use datafusion_common::{DataFusionError, ScalarValue, internal_err};
27use datatypes::arrow::array::{
28 Array, ArrayAccessor, ArrayData, BooleanArray, BooleanBufferBuilder, DictionaryArray,
29 RecordBatch, StringArrayType,
30};
31use datatypes::arrow::compute::filter_record_batch;
32use datatypes::arrow::datatypes::{DataType, UInt32Type};
33use datatypes::arrow::error::ArrowError;
34use datatypes::compute::or_kleene;
35use datatypes::vectors::VectorRef;
36use regex::Regex;
37use snafu::ResultExt;
38
39use crate::error::{ArrowComputeSnafu, Result, ToArrowScalarSnafu, UnsupportedOperationSnafu};
40
/// Evaluator for a simple filter on a single column: the column is compared against one
/// literal (comparison or regex operators) or against a list of literals (an `OR` chain of
/// equality checks on the same column).
#[derive(Debug)]
pub struct SimpleFilterEvaluator {
    /// Name of the column the filter refers to.
    column_name: String,
    /// Literal the column is compared against. For `OR` filters this only holds a
    /// placeholder; the actual candidates live in `literal_list`.
    literal: Scalar<ArrayRef>,
    /// Operator of the filter, oriented as `column <op> literal`.
    op: Operator,
    /// Candidate literals of an `OR`-of-equalities filter. Empty unless `op` is
    /// `Operator::Or`.
    literal_list: Vec<Scalar<ArrayRef>>,
    /// Compiled regex for the regex operators. `None` for other operators, and also for
    /// regex operators whose pattern is empty.
    regex: Option<Regex>,
    /// Whether the regex match result has to be negated (the "not match" operators).
    regex_negative: bool,
}
67
68impl SimpleFilterEvaluator {
69 pub fn new<T: Literal>(column_name: String, lit: T, op: Operator) -> Option<Self> {
70 match op {
71 Operator::Eq
72 | Operator::NotEq
73 | Operator::Lt
74 | Operator::LtEq
75 | Operator::Gt
76 | Operator::GtEq => {}
77 _ => return None,
78 }
79
80 let Expr::Literal(val, _) = lit.lit() else {
81 return None;
82 };
83
84 Some(Self {
85 column_name,
86 literal: val.to_scalar().ok()?,
87 op,
88 literal_list: vec![],
89 regex: None,
90 regex_negative: false,
91 })
92 }
93
94 pub fn try_new(predicate: &Expr) -> Option<Self> {
95 match predicate {
96 Expr::BinaryExpr(binary) => {
97 match binary.op {
99 Operator::Eq
100 | Operator::NotEq
101 | Operator::Lt
102 | Operator::LtEq
103 | Operator::Gt
104 | Operator::GtEq
105 | Operator::RegexMatch
106 | Operator::RegexIMatch
107 | Operator::RegexNotMatch
108 | Operator::RegexNotIMatch => {}
109 Operator::Or => {
110 let lhs = Self::try_new(&binary.left)?;
111 let rhs = Self::try_new(&binary.right)?;
112 if lhs.column_name != rhs.column_name
113 || !matches!(lhs.op, Operator::Eq | Operator::Or)
114 || !matches!(rhs.op, Operator::Eq | Operator::Or)
115 {
116 return None;
117 }
118 let mut list = vec![];
119 let placeholder_literal = lhs.literal.clone();
120 if matches!(lhs.op, Operator::Or) {
122 list.extend(lhs.literal_list);
123 } else {
124 list.push(lhs.literal);
125 }
126 if matches!(rhs.op, Operator::Or) {
127 list.extend(rhs.literal_list);
128 } else {
129 list.push(rhs.literal);
130 }
131 return Some(Self {
132 column_name: lhs.column_name,
133 literal: placeholder_literal,
134 op: Operator::Or,
135 literal_list: list,
136 regex: None,
137 regex_negative: false,
138 });
139 }
140 _ => return None,
141 }
142
143 let mut op = binary.op;
145 let (lhs, rhs) = match (&*binary.left, &*binary.right) {
146 (Expr::Column(col), Expr::Literal(lit, _)) => (col, lit),
147 (Expr::Literal(lit, _), Expr::Column(col)) => {
148 op = op.swap().unwrap();
150 (col, lit)
151 }
152 _ => return None,
153 };
154
155 let (regex, regex_negative) = Self::maybe_build_regex(op, rhs).ok()?;
156 let literal = rhs.to_scalar().ok()?;
157 Some(Self {
158 column_name: lhs.name.clone(),
159 literal,
160 op,
161 literal_list: vec![],
162 regex,
163 regex_negative,
164 })
165 }
166 _ => None,
167 }
168 }
169
170 pub fn column_name(&self) -> &str {
172 &self.column_name
173 }
174
175 pub fn evaluate_scalar(&self, input: &ScalarValue) -> Result<bool> {
176 let input = input
177 .to_scalar()
178 .with_context(|_| ToArrowScalarSnafu { v: input.clone() })?;
179 let result = self.evaluate_datum(&input, 1)?;
180 Ok(result.value(0))
181 }
182
183 pub fn evaluate_array(&self, input: &ArrayRef) -> Result<BooleanBuffer> {
184 self.evaluate_datum(input, input.len())
185 }
186
187 pub fn evaluate_vector(&self, input: &VectorRef) -> Result<BooleanBuffer> {
188 self.evaluate_datum(&input.to_arrow_array(), input.len())
189 }
190
191 fn evaluate_datum(&self, input: &impl Datum, input_len: usize) -> Result<BooleanBuffer> {
192 let result = match self.op {
193 Operator::Eq => cmp::eq(input, &self.literal),
194 Operator::NotEq => cmp::neq(input, &self.literal),
195 Operator::Lt => cmp::lt(input, &self.literal),
196 Operator::LtEq => cmp::lt_eq(input, &self.literal),
197 Operator::Gt => cmp::gt(input, &self.literal),
198 Operator::GtEq => cmp::gt_eq(input, &self.literal),
199 Operator::RegexMatch => self.regex_match(input),
200 Operator::RegexIMatch => self.regex_match(input),
201 Operator::RegexNotMatch => self.regex_match(input),
202 Operator::RegexNotIMatch => self.regex_match(input),
203 Operator::Or => {
204 let mut result: BooleanArray = vec![false; input_len].into();
206 for literal in &self.literal_list {
207 let rhs = cmp::eq(input, literal).context(ArrowComputeSnafu)?;
208 result = or_kleene(&result, &rhs).context(ArrowComputeSnafu)?;
209 }
210 Ok(result)
211 }
212 _ => {
213 return UnsupportedOperationSnafu {
214 reason: format!("{:?}", self.op),
215 }
216 .fail();
217 }
218 };
219 result
220 .context(ArrowComputeSnafu)
221 .map(|array| array.values().clone())
222 }
223
224 fn maybe_build_regex(
235 operator: Operator,
236 value: &ScalarValue,
237 ) -> Result<(Option<Regex>, bool), ArrowError> {
238 let (ignore_case, negative) = match operator {
239 Operator::RegexMatch => (false, false),
240 Operator::RegexIMatch => (true, false),
241 Operator::RegexNotMatch => (false, true),
242 Operator::RegexNotIMatch => (true, true),
243 _ => return Ok((None, false)),
244 };
245 let flag = if ignore_case { Some("i") } else { None };
246 let regex = value
247 .try_as_str()
248 .ok_or_else(|| ArrowError::CastError(format!("Cannot cast {:?} to str", value)))?
249 .ok_or_else(|| ArrowError::CastError("Regex should not be null".to_string()))?;
250 let pattern = match flag {
251 Some(flag) => format!("(?{flag}){regex}"),
252 None => regex.to_string(),
253 };
254 if pattern.is_empty() {
255 Ok((None, negative))
256 } else {
257 Regex::new(pattern.as_str())
258 .map_err(|e| {
259 ArrowError::ComputeError(format!("Regular expression did not compile: {e:?}"))
260 })
261 .map(|regex| (Some(regex), negative))
262 }
263 }
264
265 fn regex_match(&self, input: &impl Datum) -> std::result::Result<BooleanArray, ArrowError> {
266 let array = input.get().0;
267
268 if let Ok(string_array) = as_string_array(array) {
270 let mut result = regexp_is_match_scalar(string_array, self.regex.as_ref())?;
271 if self.regex_negative {
272 result = datatypes::compute::not(&result)?;
273 }
274 return Ok(result);
275 }
276
277 if let Some(dict_array) = array.as_any().downcast_ref::<DictionaryArray<UInt32Type>>() {
279 let mut result = regexp_is_match_dictionary(dict_array, self.regex.as_ref())?;
280 if self.regex_negative {
281 result = datatypes::compute::not(&result)?;
282 }
283 return Ok(result);
284 }
285
286 Err(ArrowError::CastError(format!(
287 "Cannot cast {:?} to StringArray or StringDictionaryArray",
288 array.data_type()
289 )))
290 }
291}
292
293pub fn batch_filter(
296 batch: &RecordBatch,
297 predicate: &Arc<dyn PhysicalExpr>,
298) -> DfResult<RecordBatch> {
299 predicate
300 .evaluate(batch)
301 .and_then(|v| v.into_array(batch.num_rows()))
302 .and_then(|array| {
303 let filter_array = match as_boolean_array(&array) {
304 Ok(boolean_array) => Ok(boolean_array.clone()),
305 Err(_) => {
306 let Ok(null_array) = as_null_array(&array) else {
307 return internal_err!(
308 "Cannot create filter_array from non-boolean predicates"
309 );
310 };
311
312 Ok::<BooleanArray, DataFusionError>(BooleanArray::new_null(null_array.len()))
314 }
315 }?;
316 Ok(filter_record_batch(batch, &filter_array)?)
317 })
318}
319
320pub fn regexp_is_match_scalar<'a, S>(
324 array: &'a S,
325 regex: Option<&Regex>,
326) -> Result<BooleanArray, ArrowError>
327where
328 &'a S: StringArrayType<'a>,
329{
330 let null_bit_buffer = array.nulls().map(|x| x.inner().sliced());
331 let mut result = BooleanBufferBuilder::new(array.len());
332
333 if let Some(re) = regex {
334 for i in 0..array.len() {
335 let value = array.value(i);
336 result.append(re.is_match(value));
337 }
338 } else {
339 result.append_n(array.len(), true);
340 }
341
342 let buffer = result.into();
343 let data = unsafe {
344 ArrayData::new_unchecked(
345 DataType::Boolean,
346 array.len(),
347 None,
348 null_bit_buffer,
349 0,
350 vec![buffer],
351 vec![],
352 )
353 };
354
355 Ok(BooleanArray::from(data))
356}
357
358pub fn regexp_is_match_dictionary(
361 dict_array: &DictionaryArray<UInt32Type>,
362 regex: Option<&Regex>,
363) -> Result<BooleanArray, ArrowError> {
364 let string_values = dict_array
366 .values()
367 .as_any()
368 .downcast_ref::<datatypes::arrow::array::StringArray>()
369 .ok_or_else(|| {
370 ArrowError::CastError("Dictionary values must be StringArray".to_string())
371 })?;
372
373 let null_bit_buffer = dict_array.nulls().map(|x| x.inner().sliced());
374 let mut result = BooleanBufferBuilder::new(dict_array.len());
375
376 if let Some(re) = regex {
377 let keys = dict_array.keys().values();
378 for i in 0..dict_array.len() {
379 if dict_array.is_null(i) {
380 result.append(false);
381 } else {
382 let key = keys[i] as usize;
383 let string_value = string_values.value(key);
384 result.append(re.is_match(string_value));
385 }
386 }
387 } else {
388 result.append_n(dict_array.len(), true);
389 }
390
391 let buffer = result.into();
392 let data = unsafe {
393 ArrayData::new_unchecked(
394 DataType::Boolean,
395 dict_array.len(),
396 None,
397 null_bit_buffer,
398 0,
399 vec![buffer],
400 vec![],
401 )
402 };
403
404 Ok(BooleanArray::from(data))
405}
406
#[cfg(test)]
mod test {

    use std::sync::Arc;

    use datafusion::execution::context::ExecutionProps;
    use datafusion::logical_expr::{BinaryExpr, col, lit};
    use datafusion::physical_expr::create_physical_expr;
    use datafusion_common::{Column, DFSchema};
    use datatypes::arrow::datatypes::{DataType, Field, Schema};

    use super::*;

    /// Expressions that `try_new` must reject.
    #[test]
    fn unsupported_filter_op() {
        // `+` is not a filter operator.
        let expr = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(Expr::Column(Column::from_name("foo"))),
            op: Operator::Plus,
            right: Box::new(1.lit()),
        });
        assert!(SimpleFilterEvaluator::try_new(&expr).is_none());

        // Two literals.
        let expr = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(1.lit()),
            op: Operator::Eq,
            right: Box::new(1.lit()),
        });
        assert!(SimpleFilterEvaluator::try_new(&expr).is_none());

        // Two columns.
        let expr = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(Expr::Column(Column::from_name("foo"))),
            op: Operator::Eq,
            right: Box::new(Expr::Column(Column::from_name("bar"))),
        });
        assert!(SimpleFilterEvaluator::try_new(&expr).is_none());

        // A nested comparison on the left-hand side.
        let expr = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(Expr::BinaryExpr(BinaryExpr {
                left: Box::new(Expr::Column(Column::from_name("foo"))),
                op: Operator::Eq,
                right: Box::new(1.lit()),
            })),
            op: Operator::Eq,
            right: Box::new(1.lit()),
        });
        assert!(SimpleFilterEvaluator::try_new(&expr).is_none());
    }

    /// Expressions that `try_new` must accept, including operator swapping.
    #[test]
    fn supported_filter_op() {
        // `column = literal`
        let expr = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(Expr::Column(Column::from_name("foo"))),
            op: Operator::Eq,
            right: Box::new(1.lit()),
        });
        let _ = SimpleFilterEvaluator::try_new(&expr).unwrap();

        // `literal < column` is normalized to `column > literal`.
        let expr = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(1.lit()),
            op: Operator::Lt,
            right: Box::new(Expr::Column(Column::from_name("foo"))),
        });
        let evaluator = SimpleFilterEvaluator::try_new(&expr).unwrap();
        assert_eq!(evaluator.op, Operator::Gt);
        assert_eq!(evaluator.column_name, "foo".to_string());
    }

    #[test]
    fn run_on_array() {
        let expr = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(Expr::Column(Column::from_name("foo"))),
            op: Operator::Eq,
            right: Box::new(1i64.lit()),
        });
        let evaluator = SimpleFilterEvaluator::try_new(&expr).unwrap();

        let input_1 = Arc::new(datatypes::arrow::array::Int64Array::from(vec![1, 2, 3])) as _;
        let result = evaluator.evaluate_array(&input_1).unwrap();
        assert_eq!(result, BooleanBuffer::from(vec![true, false, false]));

        let input_2 = Arc::new(datatypes::arrow::array::Int64Array::from(vec![1, 1, 1])) as _;
        let result = evaluator.evaluate_array(&input_2).unwrap();
        assert_eq!(result, BooleanBuffer::from(vec![true, true, true]));

        // Empty input yields an empty result.
        let input_3 = Arc::new(datatypes::arrow::array::Int64Array::new_null(0)) as _;
        let result = evaluator.evaluate_array(&input_3).unwrap();
        assert_eq!(result, BooleanBuffer::from(vec![]));
    }

    #[test]
    fn run_on_scalar() {
        let expr = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(Expr::Column(Column::from_name("foo"))),
            op: Operator::Lt,
            right: Box::new(1i64.lit()),
        });
        let evaluator = SimpleFilterEvaluator::try_new(&expr).unwrap();

        let input_1 = ScalarValue::Int64(Some(1));
        let result = evaluator.evaluate_scalar(&input_1).unwrap();
        assert!(!result);

        let input_2 = ScalarValue::Int64(Some(0));
        let result = evaluator.evaluate_scalar(&input_2).unwrap();
        assert!(result);

        // A NULL input never satisfies the filter.
        let input_3 = ScalarValue::Int64(None);
        let result = evaluator.evaluate_scalar(&input_3).unwrap();
        assert!(!result);
    }

    #[test]
    fn batch_filter_test() {
        let expr = col("ts").gt(lit(123456u64));
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("ts", DataType::UInt64, false),
        ]);
        let df_schema = DFSchema::try_from(schema.clone()).unwrap();
        let props = ExecutionProps::new();
        let physical_expr = create_physical_expr(&expr, &df_schema, &props).unwrap();
        let batch = RecordBatch::try_new(
            Arc::new(schema),
            vec![
                Arc::new(datatypes::arrow::array::Int32Array::from(vec![4, 5, 6])),
                Arc::new(datatypes::arrow::array::UInt64Array::from(vec![
                    123456, 123457, 123458,
                ])),
            ],
        )
        .unwrap();
        let new_batch = batch_filter(&batch, &physical_expr).unwrap();
        assert_eq!(new_batch.num_rows(), 2);
        let first_column_values = new_batch
            .column(0)
            .as_any()
            .downcast_ref::<datatypes::arrow::array::Int32Array>()
            .unwrap();
        let expected = datatypes::arrow::array::Int32Array::from(vec![5, 6]);
        assert_eq!(first_column_values, &expected);
    }

    #[test]
    fn test_complex_filter_expression() {
        // Three equality checks on the same column ...
        let col_eq_b = col("col").eq(lit("B"));
        let col_eq_c = col("col").eq(lit("C"));
        let col_eq_d = col("col").eq(lit("D"));

        // ... chained with OR.
        let col_or_expr = col_eq_b.or(col_eq_c).or(col_eq_d);

        // The OR chain is flattened into a single evaluator with a literal list.
        let or_evaluator = SimpleFilterEvaluator::try_new(&col_or_expr).unwrap();
        assert_eq!(or_evaluator.column_name, "col");
        assert_eq!(or_evaluator.op, Operator::Or);
        assert_eq!(or_evaluator.literal_list.len(), 3);
        assert_eq!(
            format!("{:?}", or_evaluator.literal_list),
            "[Scalar(StringArray\n[\n  \"B\",\n]), Scalar(StringArray\n[\n  \"C\",\n]), Scalar(StringArray\n[\n  \"D\",\n])]"
        );

        // Build the physical expression for the same predicate.
        let schema = Schema::new(vec![Field::new("col", DataType::Utf8, false)]);
        let df_schema = DFSchema::try_from(schema.clone()).unwrap();
        let props = ExecutionProps::new();
        let physical_expr = create_physical_expr(&col_or_expr, &df_schema, &props).unwrap();

        // Input data; "E" and "F" must be filtered out.
        let col_data = Arc::new(datatypes::arrow::array::StringArray::from(vec![
            "B", "C", "E", "B", "C", "D", "F",
        ]));
        let batch = RecordBatch::try_new(Arc::new(schema), vec![col_data]).unwrap();
        let expected = datatypes::arrow::array::StringArray::from(vec!["B", "C", "B", "C", "D"]);

        let filtered_batch = batch_filter(&batch, &physical_expr).unwrap();

        assert_eq!(filtered_batch.num_rows(), 5);

        let col_filtered = filtered_batch
            .column(0)
            .as_any()
            .downcast_ref::<datatypes::arrow::array::StringArray>()
            .unwrap();
        assert_eq!(col_filtered, &expected);
    }

    #[test]
    fn test_maybe_build_regex() {
        // Case-sensitive match.
        let (regex, negative) = SimpleFilterEvaluator::maybe_build_regex(
            Operator::RegexMatch,
            &ScalarValue::Utf8(Some("a.*b".to_string())),
        )
        .unwrap();
        assert!(regex.is_some());
        assert!(!negative);
        assert!(regex.unwrap().is_match("axxb"));

        // Case-insensitive match.
        let (regex, negative) = SimpleFilterEvaluator::maybe_build_regex(
            Operator::RegexIMatch,
            &ScalarValue::Utf8(Some("a.*b".to_string())),
        )
        .unwrap();
        assert!(regex.is_some());
        assert!(!negative);
        assert!(regex.unwrap().is_match("AxxB"));

        // Case-sensitive "not match".
        let (regex, negative) = SimpleFilterEvaluator::maybe_build_regex(
            Operator::RegexNotMatch,
            &ScalarValue::Utf8(Some("a.*b".to_string())),
        )
        .unwrap();
        assert!(regex.is_some());
        assert!(negative);

        // Case-insensitive "not match".
        let (regex, negative) = SimpleFilterEvaluator::maybe_build_regex(
            Operator::RegexNotIMatch,
            &ScalarValue::Utf8(Some("a.*b".to_string())),
        )
        .unwrap();
        assert!(regex.is_some());
        assert!(negative);

        // An empty pattern produces no regex.
        let (regex, negative) = SimpleFilterEvaluator::maybe_build_regex(
            Operator::RegexMatch,
            &ScalarValue::Utf8(Some("".to_string())),
        )
        .unwrap();
        assert!(regex.is_none());
        assert!(!negative);

        // Non-regex operators produce no regex.
        let (regex, negative) = SimpleFilterEvaluator::maybe_build_regex(
            Operator::Eq,
            &ScalarValue::Utf8(Some("a.*b".to_string())),
        )
        .unwrap();
        assert!(regex.is_none());
        assert!(!negative);

        // A pattern that does not compile.
        let result = SimpleFilterEvaluator::maybe_build_regex(
            Operator::RegexMatch,
            &ScalarValue::Utf8(Some("a(b".to_string())),
        );
        assert!(result.is_err());

        // A non-string literal.
        let result = SimpleFilterEvaluator::maybe_build_regex(
            Operator::RegexMatch,
            &ScalarValue::Int64(Some(123)),
        );
        assert!(result.is_err());

        // A NULL string literal.
        let result = SimpleFilterEvaluator::maybe_build_regex(
            Operator::RegexMatch,
            &ScalarValue::Utf8(None),
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_regex_match_dictionary_array() {
        use datatypes::arrow::array::StringDictionaryBuilder;

        // Dictionary array with a repeated value.
        let mut builder = StringDictionaryBuilder::<UInt32Type>::new();
        builder.append("apple").unwrap();
        builder.append("banana").unwrap();
        builder.append("apple").unwrap();
        builder.append("cherry").unwrap();
        let dict_array = builder.finish();

        // Only the "apple" rows match.
        let regex = regex::Regex::new(r"app.*").unwrap();
        let result = regexp_is_match_dictionary(&dict_array, Some(&regex)).unwrap();

        assert_eq!(result.len(), 4);
        assert!(result.value(0));
        assert!(!result.value(1));
        assert!(result.value(2));
        assert!(!result.value(3));

        // Only the "banana" row matches.
        let regex2 = regex::Regex::new(r"ban.*").unwrap();
        let result2 = regexp_is_match_dictionary(&dict_array, Some(&regex2)).unwrap();

        assert!(!result2.value(0));
        assert!(result2.value(1));
        assert!(!result2.value(2));
        assert!(!result2.value(3));

        // Without a regex every row matches.
        let result3 = regexp_is_match_dictionary(&dict_array, None).unwrap();
        assert!(result3.value(0));
        assert!(result3.value(1));
        assert!(result3.value(2));
        assert!(result3.value(3));
    }
}