1use std::sync::Arc;
18
19use datafusion::error::Result as DfResult;
20use datafusion::logical_expr::{Expr, Literal, Operator};
21use datafusion::physical_plan::PhysicalExpr;
22use datafusion_common::arrow::array::{ArrayRef, Datum, Scalar};
23use datafusion_common::arrow::buffer::BooleanBuffer;
24use datafusion_common::arrow::compute::kernels::cmp;
25use datafusion_common::cast::{as_boolean_array, as_null_array, as_string_array};
26use datafusion_common::{internal_err, DataFusionError, ScalarValue};
27use datatypes::arrow::array::{
28 Array, ArrayAccessor, ArrayData, BooleanArray, BooleanBufferBuilder, RecordBatch,
29 StringArrayType,
30};
31use datatypes::arrow::compute::filter_record_batch;
32use datatypes::arrow::datatypes::DataType;
33use datatypes::arrow::error::ArrowError;
34use datatypes::compute::or_kleene;
35use datatypes::vectors::VectorRef;
36use regex::Regex;
37use snafu::ResultExt;
38
39use crate::error::{ArrowComputeSnafu, Result, ToArrowScalarSnafu, UnsupportedOperationSnafu};
40
41#[derive(Debug)]
51pub struct SimpleFilterEvaluator {
52 column_name: String,
54 literal: Scalar<ArrayRef>,
56 op: Operator,
58 literal_list: Vec<Scalar<ArrayRef>>,
60 regex: Option<Regex>,
64 regex_negative: bool,
66}
67
68impl SimpleFilterEvaluator {
69 pub fn new<T: Literal>(column_name: String, lit: T, op: Operator) -> Option<Self> {
70 match op {
71 Operator::Eq
72 | Operator::NotEq
73 | Operator::Lt
74 | Operator::LtEq
75 | Operator::Gt
76 | Operator::GtEq => {}
77 _ => return None,
78 }
79
80 let Expr::Literal(val) = lit.lit() else {
81 return None;
82 };
83
84 Some(Self {
85 column_name,
86 literal: val.to_scalar().ok()?,
87 op,
88 literal_list: vec![],
89 regex: None,
90 regex_negative: false,
91 })
92 }
93
94 pub fn try_new(predicate: &Expr) -> Option<Self> {
95 match predicate {
96 Expr::BinaryExpr(binary) => {
97 match binary.op {
99 Operator::Eq
100 | Operator::NotEq
101 | Operator::Lt
102 | Operator::LtEq
103 | Operator::Gt
104 | Operator::GtEq
105 | Operator::RegexMatch
106 | Operator::RegexIMatch
107 | Operator::RegexNotMatch
108 | Operator::RegexNotIMatch => {}
109 Operator::Or => {
110 let lhs = Self::try_new(&binary.left)?;
111 let rhs = Self::try_new(&binary.right)?;
112 if lhs.column_name != rhs.column_name
113 || !matches!(lhs.op, Operator::Eq | Operator::Or)
114 || !matches!(rhs.op, Operator::Eq | Operator::Or)
115 {
116 return None;
117 }
118 let mut list = vec![];
119 let placeholder_literal = lhs.literal.clone();
120 if matches!(lhs.op, Operator::Or) {
122 list.extend(lhs.literal_list);
123 } else {
124 list.push(lhs.literal);
125 }
126 if matches!(rhs.op, Operator::Or) {
127 list.extend(rhs.literal_list);
128 } else {
129 list.push(rhs.literal);
130 }
131 return Some(Self {
132 column_name: lhs.column_name,
133 literal: placeholder_literal,
134 op: Operator::Or,
135 literal_list: list,
136 regex: None,
137 regex_negative: false,
138 });
139 }
140 _ => return None,
141 }
142
143 let mut op = binary.op;
145 let (lhs, rhs) = match (&*binary.left, &*binary.right) {
146 (Expr::Column(ref col), Expr::Literal(ref lit)) => (col, lit),
147 (Expr::Literal(ref lit), Expr::Column(ref col)) => {
148 op = op.swap().unwrap();
150 (col, lit)
151 }
152 _ => return None,
153 };
154
155 let (regex, regex_negative) = Self::maybe_build_regex(op, rhs).ok()?;
156 let literal = rhs.to_scalar().ok()?;
157 Some(Self {
158 column_name: lhs.name.clone(),
159 literal,
160 op,
161 literal_list: vec![],
162 regex,
163 regex_negative,
164 })
165 }
166 _ => None,
167 }
168 }
169
170 pub fn column_name(&self) -> &str {
172 &self.column_name
173 }
174
175 pub fn evaluate_scalar(&self, input: &ScalarValue) -> Result<bool> {
176 let input = input
177 .to_scalar()
178 .with_context(|_| ToArrowScalarSnafu { v: input.clone() })?;
179 let result = self.evaluate_datum(&input, 1)?;
180 Ok(result.value(0))
181 }
182
183 pub fn evaluate_array(&self, input: &ArrayRef) -> Result<BooleanBuffer> {
184 self.evaluate_datum(input, input.len())
185 }
186
187 pub fn evaluate_vector(&self, input: &VectorRef) -> Result<BooleanBuffer> {
188 self.evaluate_datum(&input.to_arrow_array(), input.len())
189 }
190
191 fn evaluate_datum(&self, input: &impl Datum, input_len: usize) -> Result<BooleanBuffer> {
192 let result = match self.op {
193 Operator::Eq => cmp::eq(input, &self.literal),
194 Operator::NotEq => cmp::neq(input, &self.literal),
195 Operator::Lt => cmp::lt(input, &self.literal),
196 Operator::LtEq => cmp::lt_eq(input, &self.literal),
197 Operator::Gt => cmp::gt(input, &self.literal),
198 Operator::GtEq => cmp::gt_eq(input, &self.literal),
199 Operator::RegexMatch => self.regex_match(input),
200 Operator::RegexIMatch => self.regex_match(input),
201 Operator::RegexNotMatch => self.regex_match(input),
202 Operator::RegexNotIMatch => self.regex_match(input),
203 Operator::Or => {
204 let mut result: BooleanArray = vec![false; input_len].into();
206 for literal in &self.literal_list {
207 let rhs = cmp::eq(input, literal).context(ArrowComputeSnafu)?;
208 result = or_kleene(&result, &rhs).context(ArrowComputeSnafu)?;
209 }
210 Ok(result)
211 }
212 _ => {
213 return UnsupportedOperationSnafu {
214 reason: format!("{:?}", self.op),
215 }
216 .fail()
217 }
218 };
219 result
220 .context(ArrowComputeSnafu)
221 .map(|array| array.values().clone())
222 }
223
224 fn maybe_build_regex(
235 operator: Operator,
236 value: &ScalarValue,
237 ) -> Result<(Option<Regex>, bool), ArrowError> {
238 let (ignore_case, negative) = match operator {
239 Operator::RegexMatch => (false, false),
240 Operator::RegexIMatch => (true, false),
241 Operator::RegexNotMatch => (false, true),
242 Operator::RegexNotIMatch => (true, true),
243 _ => return Ok((None, false)),
244 };
245 let flag = if ignore_case { Some("i") } else { None };
246 let regex = value
247 .try_as_str()
248 .ok_or_else(|| ArrowError::CastError(format!("Cannot cast {:?} to str", value)))?
249 .ok_or_else(|| ArrowError::CastError("Regex should not be null".to_string()))?;
250 let pattern = match flag {
251 Some(flag) => format!("(?{flag}){regex}"),
252 None => regex.to_string(),
253 };
254 if pattern.is_empty() {
255 Ok((None, negative))
256 } else {
257 Regex::new(pattern.as_str())
258 .map_err(|e| {
259 ArrowError::ComputeError(format!("Regular expression did not compile: {e:?}"))
260 })
261 .map(|regex| (Some(regex), negative))
262 }
263 }
264
265 fn regex_match(&self, input: &impl Datum) -> std::result::Result<BooleanArray, ArrowError> {
266 let array = input.get().0;
267 let string_array = as_string_array(array).map_err(|_| {
268 ArrowError::CastError(format!("Cannot cast {:?} to StringArray", array))
269 })?;
270 let mut result = regexp_is_match_scalar(string_array, self.regex.as_ref())?;
271 if self.regex_negative {
272 result = datatypes::compute::not(&result)?;
273 }
274 Ok(result)
275 }
276}
277
278pub fn batch_filter(
281 batch: &RecordBatch,
282 predicate: &Arc<dyn PhysicalExpr>,
283) -> DfResult<RecordBatch> {
284 predicate
285 .evaluate(batch)
286 .and_then(|v| v.into_array(batch.num_rows()))
287 .and_then(|array| {
288 let filter_array = match as_boolean_array(&array) {
289 Ok(boolean_array) => Ok(boolean_array.clone()),
290 Err(_) => {
291 let Ok(null_array) = as_null_array(&array) else {
292 return internal_err!(
293 "Cannot create filter_array from non-boolean predicates"
294 );
295 };
296
297 Ok::<BooleanArray, DataFusionError>(BooleanArray::new_null(null_array.len()))
299 }
300 }?;
301 Ok(filter_record_batch(batch, &filter_array)?)
302 })
303}
304
305pub fn regexp_is_match_scalar<'a, S>(
309 array: &'a S,
310 regex: Option<&Regex>,
311) -> Result<BooleanArray, ArrowError>
312where
313 &'a S: StringArrayType<'a>,
314{
315 let null_bit_buffer = array.nulls().map(|x| x.inner().sliced());
316 let mut result = BooleanBufferBuilder::new(array.len());
317
318 if let Some(re) = regex {
319 for i in 0..array.len() {
320 let value = array.value(i);
321 result.append(re.is_match(value));
322 }
323 } else {
324 result.append_n(array.len(), true);
325 }
326
327 let buffer = result.into();
328 let data = unsafe {
329 ArrayData::new_unchecked(
330 DataType::Boolean,
331 array.len(),
332 None,
333 null_bit_buffer,
334 0,
335 vec![buffer],
336 vec![],
337 )
338 };
339
340 Ok(BooleanArray::from(data))
341}
342
#[cfg(test)]
mod test {

    use std::sync::Arc;

    use datafusion::execution::context::ExecutionProps;
    use datafusion::logical_expr::{col, lit, BinaryExpr};
    use datafusion::physical_expr::create_physical_expr;
    use datafusion_common::{Column, DFSchema};
    use datatypes::arrow::datatypes::{DataType, Field, Schema};

    use super::*;

    /// Builds an evaluator for `foo <op> '<pattern>'`; panics if the predicate is unsupported.
    fn regex_evaluator(op: Operator, pattern: &str) -> SimpleFilterEvaluator {
        let expr = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(Expr::Column(Column::from_name("foo"))),
            op,
            right: Box::new(Expr::Literal(ScalarValue::Utf8(Some(pattern.to_string())))),
        });
        SimpleFilterEvaluator::try_new(&expr).unwrap()
    }

    #[test]
    fn unsupported_filter_op() {
        // Non-comparison operator.
        let expr = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(Expr::Column(Column::from_name("foo"))),
            op: Operator::Plus,
            right: Box::new(Expr::Literal(ScalarValue::Int64(Some(1)))),
        });
        assert!(SimpleFilterEvaluator::try_new(&expr).is_none());

        // Literal compared with literal.
        let expr = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(Expr::Literal(ScalarValue::Int64(Some(1)))),
            op: Operator::Eq,
            right: Box::new(Expr::Literal(ScalarValue::Int64(Some(1)))),
        });
        assert!(SimpleFilterEvaluator::try_new(&expr).is_none());

        // Column compared with column.
        let expr = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(Expr::Column(Column::from_name("foo"))),
            op: Operator::Eq,
            right: Box::new(Expr::Column(Column::from_name("bar"))),
        });
        assert!(SimpleFilterEvaluator::try_new(&expr).is_none());

        // Nested expression compared with literal.
        let expr = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(Expr::BinaryExpr(BinaryExpr {
                left: Box::new(Expr::Column(Column::from_name("foo"))),
                op: Operator::Eq,
                right: Box::new(Expr::Literal(ScalarValue::Int64(Some(1)))),
            })),
            op: Operator::Eq,
            right: Box::new(Expr::Literal(ScalarValue::Int64(Some(1)))),
        });
        assert!(SimpleFilterEvaluator::try_new(&expr).is_none());
    }

    #[test]
    fn supported_filter_op() {
        // Column on the left.
        let expr = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(Expr::Column(Column::from_name("foo"))),
            op: Operator::Eq,
            right: Box::new(Expr::Literal(ScalarValue::Int64(Some(1)))),
        });
        let _ = SimpleFilterEvaluator::try_new(&expr).unwrap();

        // Literal on the left: the operator is mirrored.
        let expr = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(Expr::Literal(ScalarValue::Int64(Some(1)))),
            op: Operator::Lt,
            right: Box::new(Expr::Column(Column::from_name("foo"))),
        });
        let evaluator = SimpleFilterEvaluator::try_new(&expr).unwrap();
        assert_eq!(evaluator.op, Operator::Gt);
        assert_eq!(evaluator.column_name, "foo".to_string());
    }

    #[test]
    fn run_on_array() {
        let expr = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(Expr::Column(Column::from_name("foo"))),
            op: Operator::Eq,
            right: Box::new(Expr::Literal(ScalarValue::Int64(Some(1)))),
        });
        let evaluator = SimpleFilterEvaluator::try_new(&expr).unwrap();

        let input_1 = Arc::new(datatypes::arrow::array::Int64Array::from(vec![1, 2, 3])) as _;
        let result = evaluator.evaluate_array(&input_1).unwrap();
        assert_eq!(result, BooleanBuffer::from(vec![true, false, false]));

        let input_2 = Arc::new(datatypes::arrow::array::Int64Array::from(vec![1, 1, 1])) as _;
        let result = evaluator.evaluate_array(&input_2).unwrap();
        assert_eq!(result, BooleanBuffer::from(vec![true, true, true]));

        let input_3 = Arc::new(datatypes::arrow::array::Int64Array::new_null(0)) as _;
        let result = evaluator.evaluate_array(&input_3).unwrap();
        assert_eq!(result, BooleanBuffer::from(vec![]));
    }

    #[test]
    fn run_on_scalar() {
        let expr = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(Expr::Column(Column::from_name("foo"))),
            op: Operator::Lt,
            right: Box::new(Expr::Literal(ScalarValue::Int64(Some(1)))),
        });
        let evaluator = SimpleFilterEvaluator::try_new(&expr).unwrap();

        let input_1 = ScalarValue::Int64(Some(1));
        let result = evaluator.evaluate_scalar(&input_1).unwrap();
        assert!(!result);

        let input_2 = ScalarValue::Int64(Some(0));
        let result = evaluator.evaluate_scalar(&input_2).unwrap();
        assert!(result);

        let input_3 = ScalarValue::Int64(None);
        let result = evaluator.evaluate_scalar(&input_3).unwrap();
        assert!(!result);
    }

    #[test]
    fn new_with_literal() {
        let evaluator =
            SimpleFilterEvaluator::new("foo".to_string(), 5i64, Operator::GtEq).unwrap();
        assert_eq!(evaluator.column_name(), "foo");

        let input: ArrayRef =
            Arc::new(datatypes::arrow::array::Int64Array::from(vec![4, 5, 6]));
        let result = evaluator.evaluate_array(&input).unwrap();
        assert_eq!(result, BooleanBuffer::from(vec![false, true, true]));

        // Only plain comparisons are accepted by `new`.
        assert!(
            SimpleFilterEvaluator::new("foo".to_string(), 5i64, Operator::RegexMatch).is_none()
        );
    }

    #[test]
    fn run_regex_on_array() {
        let input: ArrayRef = Arc::new(datatypes::arrow::array::StringArray::from(vec![
            "axxb", "ab", "ba", "xyz",
        ]));

        let evaluator = regex_evaluator(Operator::RegexMatch, "a.*b");
        let result = evaluator.evaluate_array(&input).unwrap();
        assert_eq!(result, BooleanBuffer::from(vec![true, true, false, false]));

        let evaluator = regex_evaluator(Operator::RegexNotIMatch, "A.*B");
        let result = evaluator.evaluate_array(&input).unwrap();
        assert_eq!(result, BooleanBuffer::from(vec![false, false, true, true]));

        // An empty pattern matches every row.
        let evaluator = regex_evaluator(Operator::RegexMatch, "");
        let result = evaluator.evaluate_array(&input).unwrap();
        assert_eq!(result, BooleanBuffer::from(vec![true, true, true, true]));
    }

    #[test]
    fn run_regex_on_scalar() {
        let evaluator = regex_evaluator(Operator::RegexIMatch, "^A");
        assert!(evaluator
            .evaluate_scalar(&ScalarValue::Utf8(Some("abc".to_string())))
            .unwrap());
        assert!(!evaluator
            .evaluate_scalar(&ScalarValue::Utf8(Some("bca".to_string())))
            .unwrap());
    }

    #[test]
    fn run_or_on_array() {
        let expr = col("col").eq(lit("B")).or(col("col").eq(lit("D")));
        let evaluator = SimpleFilterEvaluator::try_new(&expr).unwrap();

        let input: ArrayRef = Arc::new(datatypes::arrow::array::StringArray::from(vec![
            "B", "C", "D",
        ]));
        let result = evaluator.evaluate_array(&input).unwrap();
        assert_eq!(result, BooleanBuffer::from(vec![true, false, true]));
    }

    #[test]
    fn batch_filter_test() {
        let expr = col("ts").gt(lit(123456u64));
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("ts", DataType::UInt64, false),
        ]);
        let df_schema = DFSchema::try_from(schema.clone()).unwrap();
        let props = ExecutionProps::new();
        let physical_expr = create_physical_expr(&expr, &df_schema, &props).unwrap();
        let batch = RecordBatch::try_new(
            Arc::new(schema),
            vec![
                Arc::new(datatypes::arrow::array::Int32Array::from(vec![4, 5, 6])),
                Arc::new(datatypes::arrow::array::UInt64Array::from(vec![
                    123456, 123457, 123458,
                ])),
            ],
        )
        .unwrap();
        let new_batch = batch_filter(&batch, &physical_expr).unwrap();
        assert_eq!(new_batch.num_rows(), 2);
        let first_column_values = new_batch
            .column(0)
            .as_any()
            .downcast_ref::<datatypes::arrow::array::Int32Array>()
            .unwrap();
        let expected = datatypes::arrow::array::Int32Array::from(vec![5, 6]);
        assert_eq!(first_column_values, &expected);
    }

    #[test]
    fn test_complex_filter_expression() {
        let col_eq_b = col("col").eq(lit("B"));
        let col_eq_c = col("col").eq(lit("C"));
        let col_eq_d = col("col").eq(lit("D"));

        // col = 'B' OR col = 'C' OR col = 'D'
        let col_or_expr = col_eq_b.or(col_eq_c).or(col_eq_d);

        let or_evaluator = SimpleFilterEvaluator::try_new(&col_or_expr).unwrap();
        assert_eq!(or_evaluator.column_name, "col");
        assert_eq!(or_evaluator.op, Operator::Or);
        assert_eq!(or_evaluator.literal_list.len(), 3);
        assert_eq!(format!("{:?}", or_evaluator.literal_list), "[Scalar(StringArray\n[\n  \"B\",\n]), Scalar(StringArray\n[\n  \"C\",\n]), Scalar(StringArray\n[\n  \"D\",\n])]");

        let schema = Schema::new(vec![Field::new("col", DataType::Utf8, false)]);
        let df_schema = DFSchema::try_from(schema.clone()).unwrap();
        let props = ExecutionProps::new();
        let physical_expr = create_physical_expr(&col_or_expr, &df_schema, &props).unwrap();

        let col_data = Arc::new(datatypes::arrow::array::StringArray::from(vec![
            "B", "C", "E", "B", "C", "D", "F",
        ]));
        let batch = RecordBatch::try_new(Arc::new(schema), vec![col_data]).unwrap();
        let expected = datatypes::arrow::array::StringArray::from(vec!["B", "C", "B", "C", "D"]);

        let filtered_batch = batch_filter(&batch, &physical_expr).unwrap();

        assert_eq!(filtered_batch.num_rows(), 5);

        let col_filtered = filtered_batch
            .column(0)
            .as_any()
            .downcast_ref::<datatypes::arrow::array::StringArray>()
            .unwrap();
        assert_eq!(col_filtered, &expected);
    }

    #[test]
    fn test_maybe_build_regex() {
        // Case-sensitive match.
        let (regex, negative) = SimpleFilterEvaluator::maybe_build_regex(
            Operator::RegexMatch,
            &ScalarValue::Utf8(Some("a.*b".to_string())),
        )
        .unwrap();
        assert!(regex.is_some());
        assert!(!negative);
        assert!(regex.unwrap().is_match("axxb"));

        // Case-insensitive match.
        let (regex, negative) = SimpleFilterEvaluator::maybe_build_regex(
            Operator::RegexIMatch,
            &ScalarValue::Utf8(Some("a.*b".to_string())),
        )
        .unwrap();
        assert!(regex.is_some());
        assert!(!negative);
        assert!(regex.unwrap().is_match("AxxB"));

        // Negated match.
        let (regex, negative) = SimpleFilterEvaluator::maybe_build_regex(
            Operator::RegexNotMatch,
            &ScalarValue::Utf8(Some("a.*b".to_string())),
        )
        .unwrap();
        assert!(regex.is_some());
        assert!(negative);

        // Negated case-insensitive match.
        let (regex, negative) = SimpleFilterEvaluator::maybe_build_regex(
            Operator::RegexNotIMatch,
            &ScalarValue::Utf8(Some("a.*b".to_string())),
        )
        .unwrap();
        assert!(regex.is_some());
        assert!(negative);

        // Empty pattern.
        let (regex, negative) = SimpleFilterEvaluator::maybe_build_regex(
            Operator::RegexMatch,
            &ScalarValue::Utf8(Some("".to_string())),
        )
        .unwrap();
        assert!(regex.is_none());
        assert!(!negative);

        // Non-regex operator.
        let (regex, negative) = SimpleFilterEvaluator::maybe_build_regex(
            Operator::Eq,
            &ScalarValue::Utf8(Some("a.*b".to_string())),
        )
        .unwrap();
        assert!(regex.is_none());
        assert!(!negative);

        // Invalid pattern.
        let result = SimpleFilterEvaluator::maybe_build_regex(
            Operator::RegexMatch,
            &ScalarValue::Utf8(Some("a(b".to_string())),
        );
        assert!(result.is_err());

        // Non-string literal.
        let result = SimpleFilterEvaluator::maybe_build_regex(
            Operator::RegexMatch,
            &ScalarValue::Int64(Some(123)),
        );
        assert!(result.is_err());

        // Null literal.
        let result = SimpleFilterEvaluator::maybe_build_regex(
            Operator::RegexMatch,
            &ScalarValue::Utf8(None),
        );
        assert!(result.is_err());
    }
}