common_recordbatch/
filter.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Util record batch stream wrapper that can perform precise filter.
16
17use std::sync::Arc;
18
19use datafusion::error::Result as DfResult;
20use datafusion::logical_expr::{Expr, Literal, Operator};
21use datafusion::physical_plan::PhysicalExpr;
22use datafusion_common::arrow::array::{ArrayRef, Datum, Scalar};
23use datafusion_common::arrow::buffer::BooleanBuffer;
24use datafusion_common::arrow::compute::kernels::cmp;
25use datafusion_common::cast::{as_boolean_array, as_null_array, as_string_array};
26use datafusion_common::{internal_err, DataFusionError, ScalarValue};
27use datatypes::arrow::array::{
28    Array, ArrayAccessor, ArrayData, BooleanArray, BooleanBufferBuilder, RecordBatch,
29    StringArrayType,
30};
31use datatypes::arrow::compute::filter_record_batch;
32use datatypes::arrow::datatypes::DataType;
33use datatypes::arrow::error::ArrowError;
34use datatypes::compute::or_kleene;
35use datatypes::vectors::VectorRef;
36use regex::Regex;
37use snafu::ResultExt;
38
39use crate::error::{ArrowComputeSnafu, Result, ToArrowScalarSnafu, UnsupportedOperationSnafu};
40
41/// An inplace expr evaluator for simple filter. Only support
42/// - `col` `op` `literal`
43/// - `literal` `op` `col`
44///
45/// And the `op` is one of `=`, `!=`, `>`, `>=`, `<`, `<=`,
46/// or regex operators: `~`, `~*`, `!~`, `!~*`.
47///
48/// This struct contains normalized predicate expr. In the form of
49/// `col` `op` `literal` where the `col` is provided from input.
50#[derive(Debug)]
51pub struct SimpleFilterEvaluator {
52    /// Name of the referenced column.
53    column_name: String,
54    /// The literal value.
55    literal: Scalar<ArrayRef>,
56    /// The operator.
57    op: Operator,
58    /// Only used when the operator is `Or`-chain.
59    literal_list: Vec<Scalar<ArrayRef>>,
60    /// Pre-compiled regex.
61    /// Only used when the operator is regex operators.
62    /// If the regex is empty, it is also `None`.
63    regex: Option<Regex>,
64    /// Whether the regex is negative.
65    regex_negative: bool,
66}
67
68impl SimpleFilterEvaluator {
69    pub fn new<T: Literal>(column_name: String, lit: T, op: Operator) -> Option<Self> {
70        match op {
71            Operator::Eq
72            | Operator::NotEq
73            | Operator::Lt
74            | Operator::LtEq
75            | Operator::Gt
76            | Operator::GtEq => {}
77            _ => return None,
78        }
79
80        let Expr::Literal(val) = lit.lit() else {
81            return None;
82        };
83
84        Some(Self {
85            column_name,
86            literal: val.to_scalar().ok()?,
87            op,
88            literal_list: vec![],
89            regex: None,
90            regex_negative: false,
91        })
92    }
93
94    pub fn try_new(predicate: &Expr) -> Option<Self> {
95        match predicate {
96            Expr::BinaryExpr(binary) => {
97                // check if the expr is in the supported form
98                match binary.op {
99                    Operator::Eq
100                    | Operator::NotEq
101                    | Operator::Lt
102                    | Operator::LtEq
103                    | Operator::Gt
104                    | Operator::GtEq
105                    | Operator::RegexMatch
106                    | Operator::RegexIMatch
107                    | Operator::RegexNotMatch
108                    | Operator::RegexNotIMatch => {}
109                    Operator::Or => {
110                        let lhs = Self::try_new(&binary.left)?;
111                        let rhs = Self::try_new(&binary.right)?;
112                        if lhs.column_name != rhs.column_name
113                            || !matches!(lhs.op, Operator::Eq | Operator::Or)
114                            || !matches!(rhs.op, Operator::Eq | Operator::Or)
115                        {
116                            return None;
117                        }
118                        let mut list = vec![];
119                        let placeholder_literal = lhs.literal.clone();
120                        // above check guarantees the op is either `Eq` or `Or`
121                        if matches!(lhs.op, Operator::Or) {
122                            list.extend(lhs.literal_list);
123                        } else {
124                            list.push(lhs.literal);
125                        }
126                        if matches!(rhs.op, Operator::Or) {
127                            list.extend(rhs.literal_list);
128                        } else {
129                            list.push(rhs.literal);
130                        }
131                        return Some(Self {
132                            column_name: lhs.column_name,
133                            literal: placeholder_literal,
134                            op: Operator::Or,
135                            literal_list: list,
136                            regex: None,
137                            regex_negative: false,
138                        });
139                    }
140                    _ => return None,
141                }
142
143                // swap the expr if it is in the form of `literal` `op` `col`
144                let mut op = binary.op;
145                let (lhs, rhs) = match (&*binary.left, &*binary.right) {
146                    (Expr::Column(ref col), Expr::Literal(ref lit)) => (col, lit),
147                    (Expr::Literal(ref lit), Expr::Column(ref col)) => {
148                        // safety: The previous check ensures the operator is able to swap.
149                        op = op.swap().unwrap();
150                        (col, lit)
151                    }
152                    _ => return None,
153                };
154
155                let (regex, regex_negative) = Self::maybe_build_regex(op, rhs).ok()?;
156                let literal = rhs.to_scalar().ok()?;
157                Some(Self {
158                    column_name: lhs.name.clone(),
159                    literal,
160                    op,
161                    literal_list: vec![],
162                    regex,
163                    regex_negative,
164                })
165            }
166            _ => None,
167        }
168    }
169
170    /// Get the name of the referenced column.
171    pub fn column_name(&self) -> &str {
172        &self.column_name
173    }
174
175    pub fn evaluate_scalar(&self, input: &ScalarValue) -> Result<bool> {
176        let input = input
177            .to_scalar()
178            .with_context(|_| ToArrowScalarSnafu { v: input.clone() })?;
179        let result = self.evaluate_datum(&input, 1)?;
180        Ok(result.value(0))
181    }
182
183    pub fn evaluate_array(&self, input: &ArrayRef) -> Result<BooleanBuffer> {
184        self.evaluate_datum(input, input.len())
185    }
186
187    pub fn evaluate_vector(&self, input: &VectorRef) -> Result<BooleanBuffer> {
188        self.evaluate_datum(&input.to_arrow_array(), input.len())
189    }
190
191    fn evaluate_datum(&self, input: &impl Datum, input_len: usize) -> Result<BooleanBuffer> {
192        let result = match self.op {
193            Operator::Eq => cmp::eq(input, &self.literal),
194            Operator::NotEq => cmp::neq(input, &self.literal),
195            Operator::Lt => cmp::lt(input, &self.literal),
196            Operator::LtEq => cmp::lt_eq(input, &self.literal),
197            Operator::Gt => cmp::gt(input, &self.literal),
198            Operator::GtEq => cmp::gt_eq(input, &self.literal),
199            Operator::RegexMatch => self.regex_match(input),
200            Operator::RegexIMatch => self.regex_match(input),
201            Operator::RegexNotMatch => self.regex_match(input),
202            Operator::RegexNotIMatch => self.regex_match(input),
203            Operator::Or => {
204                // OR operator stands for OR-chained EQs (or INLIST in other words)
205                let mut result: BooleanArray = vec![false; input_len].into();
206                for literal in &self.literal_list {
207                    let rhs = cmp::eq(input, literal).context(ArrowComputeSnafu)?;
208                    result = or_kleene(&result, &rhs).context(ArrowComputeSnafu)?;
209                }
210                Ok(result)
211            }
212            _ => {
213                return UnsupportedOperationSnafu {
214                    reason: format!("{:?}", self.op),
215                }
216                .fail()
217            }
218        };
219        result
220            .context(ArrowComputeSnafu)
221            .map(|array| array.values().clone())
222    }
223
224    /// Builds a regex pattern from a scalar value and operator.
225    /// Returns the `(regex, negative)` and if successful.
226    ///
227    /// Returns `Err` if
228    /// - the value is not a string
229    /// - the regex pattern is invalid
230    ///
231    /// The regex is `None` if
232    /// - the operator is not a regex operator
233    /// - the pattern is empty
234    fn maybe_build_regex(
235        operator: Operator,
236        value: &ScalarValue,
237    ) -> Result<(Option<Regex>, bool), ArrowError> {
238        let (ignore_case, negative) = match operator {
239            Operator::RegexMatch => (false, false),
240            Operator::RegexIMatch => (true, false),
241            Operator::RegexNotMatch => (false, true),
242            Operator::RegexNotIMatch => (true, true),
243            _ => return Ok((None, false)),
244        };
245        let flag = if ignore_case { Some("i") } else { None };
246        let regex = value
247            .try_as_str()
248            .ok_or_else(|| ArrowError::CastError(format!("Cannot cast {:?} to str", value)))?
249            .ok_or_else(|| ArrowError::CastError("Regex should not be null".to_string()))?;
250        let pattern = match flag {
251            Some(flag) => format!("(?{flag}){regex}"),
252            None => regex.to_string(),
253        };
254        if pattern.is_empty() {
255            Ok((None, negative))
256        } else {
257            Regex::new(pattern.as_str())
258                .map_err(|e| {
259                    ArrowError::ComputeError(format!("Regular expression did not compile: {e:?}"))
260                })
261                .map(|regex| (Some(regex), negative))
262        }
263    }
264
265    fn regex_match(&self, input: &impl Datum) -> std::result::Result<BooleanArray, ArrowError> {
266        let array = input.get().0;
267        let string_array = as_string_array(array).map_err(|_| {
268            ArrowError::CastError(format!("Cannot cast {:?} to StringArray", array))
269        })?;
270        let mut result = regexp_is_match_scalar(string_array, self.regex.as_ref())?;
271        if self.regex_negative {
272            result = datatypes::compute::not(&result)?;
273        }
274        Ok(result)
275    }
276}
277
278/// Evaluate the predicate on the input [RecordBatch], and return a new [RecordBatch].
279/// Copy from datafusion::physical_plan::src::filter.rs
280pub fn batch_filter(
281    batch: &RecordBatch,
282    predicate: &Arc<dyn PhysicalExpr>,
283) -> DfResult<RecordBatch> {
284    predicate
285        .evaluate(batch)
286        .and_then(|v| v.into_array(batch.num_rows()))
287        .and_then(|array| {
288            let filter_array = match as_boolean_array(&array) {
289                Ok(boolean_array) => Ok(boolean_array.clone()),
290                Err(_) => {
291                    let Ok(null_array) = as_null_array(&array) else {
292                        return internal_err!(
293                            "Cannot create filter_array from non-boolean predicates"
294                        );
295                    };
296
297                    // if the predicate is null, then the result is also null
298                    Ok::<BooleanArray, DataFusionError>(BooleanArray::new_null(null_array.len()))
299                }
300            }?;
301            Ok(filter_record_batch(batch, &filter_array)?)
302        })
303}
304
305/// The same as arrow [regexp_is_match_scalar()](datatypes::compute::kernels::regexp::regexp_is_match_scalar())
306/// with pre-compiled regex.
307/// See <https://github.com/apache/arrow-rs/blob/54.2.0/arrow-string/src/regexp.rs#L204-L246> for the implementation details.
308pub fn regexp_is_match_scalar<'a, S>(
309    array: &'a S,
310    regex: Option<&Regex>,
311) -> Result<BooleanArray, ArrowError>
312where
313    &'a S: StringArrayType<'a>,
314{
315    let null_bit_buffer = array.nulls().map(|x| x.inner().sliced());
316    let mut result = BooleanBufferBuilder::new(array.len());
317
318    if let Some(re) = regex {
319        for i in 0..array.len() {
320            let value = array.value(i);
321            result.append(re.is_match(value));
322        }
323    } else {
324        result.append_n(array.len(), true);
325    }
326
327    let buffer = result.into();
328    let data = unsafe {
329        ArrayData::new_unchecked(
330            DataType::Boolean,
331            array.len(),
332            None,
333            null_bit_buffer,
334            0,
335            vec![buffer],
336            vec![],
337        )
338    };
339
340    Ok(BooleanArray::from(data))
341}
342
343#[cfg(test)]
344mod test {
345
346    use std::sync::Arc;
347
348    use datafusion::execution::context::ExecutionProps;
349    use datafusion::logical_expr::{col, lit, BinaryExpr};
350    use datafusion::physical_expr::create_physical_expr;
351    use datafusion_common::{Column, DFSchema};
352    use datatypes::arrow::datatypes::{DataType, Field, Schema};
353
354    use super::*;
355
356    #[test]
357    fn unsupported_filter_op() {
358        // `+` is not supported
359        let expr = Expr::BinaryExpr(BinaryExpr {
360            left: Box::new(Expr::Column(Column::from_name("foo"))),
361            op: Operator::Plus,
362            right: Box::new(Expr::Literal(ScalarValue::Int64(Some(1)))),
363        });
364        assert!(SimpleFilterEvaluator::try_new(&expr).is_none());
365
366        // two literal is not supported
367        let expr = Expr::BinaryExpr(BinaryExpr {
368            left: Box::new(Expr::Literal(ScalarValue::Int64(Some(1)))),
369            op: Operator::Eq,
370            right: Box::new(Expr::Literal(ScalarValue::Int64(Some(1)))),
371        });
372        assert!(SimpleFilterEvaluator::try_new(&expr).is_none());
373
374        // two column is not supported
375        let expr = Expr::BinaryExpr(BinaryExpr {
376            left: Box::new(Expr::Column(Column::from_name("foo"))),
377            op: Operator::Eq,
378            right: Box::new(Expr::Column(Column::from_name("bar"))),
379        });
380        assert!(SimpleFilterEvaluator::try_new(&expr).is_none());
381
382        // compound expr is not supported
383        let expr = Expr::BinaryExpr(BinaryExpr {
384            left: Box::new(Expr::BinaryExpr(BinaryExpr {
385                left: Box::new(Expr::Column(Column::from_name("foo"))),
386                op: Operator::Eq,
387                right: Box::new(Expr::Literal(ScalarValue::Int64(Some(1)))),
388            })),
389            op: Operator::Eq,
390            right: Box::new(Expr::Literal(ScalarValue::Int64(Some(1)))),
391        });
392        assert!(SimpleFilterEvaluator::try_new(&expr).is_none());
393    }
394
395    #[test]
396    fn supported_filter_op() {
397        // equal
398        let expr = Expr::BinaryExpr(BinaryExpr {
399            left: Box::new(Expr::Column(Column::from_name("foo"))),
400            op: Operator::Eq,
401            right: Box::new(Expr::Literal(ScalarValue::Int64(Some(1)))),
402        });
403        let _ = SimpleFilterEvaluator::try_new(&expr).unwrap();
404
405        // swap operands
406        let expr = Expr::BinaryExpr(BinaryExpr {
407            left: Box::new(Expr::Literal(ScalarValue::Int64(Some(1)))),
408            op: Operator::Lt,
409            right: Box::new(Expr::Column(Column::from_name("foo"))),
410        });
411        let evaluator = SimpleFilterEvaluator::try_new(&expr).unwrap();
412        assert_eq!(evaluator.op, Operator::Gt);
413        assert_eq!(evaluator.column_name, "foo".to_string());
414    }
415
416    #[test]
417    fn run_on_array() {
418        let expr = Expr::BinaryExpr(BinaryExpr {
419            left: Box::new(Expr::Column(Column::from_name("foo"))),
420            op: Operator::Eq,
421            right: Box::new(Expr::Literal(ScalarValue::Int64(Some(1)))),
422        });
423        let evaluator = SimpleFilterEvaluator::try_new(&expr).unwrap();
424
425        let input_1 = Arc::new(datatypes::arrow::array::Int64Array::from(vec![1, 2, 3])) as _;
426        let result = evaluator.evaluate_array(&input_1).unwrap();
427        assert_eq!(result, BooleanBuffer::from(vec![true, false, false]));
428
429        let input_2 = Arc::new(datatypes::arrow::array::Int64Array::from(vec![1, 1, 1])) as _;
430        let result = evaluator.evaluate_array(&input_2).unwrap();
431        assert_eq!(result, BooleanBuffer::from(vec![true, true, true]));
432
433        let input_3 = Arc::new(datatypes::arrow::array::Int64Array::new_null(0)) as _;
434        let result = evaluator.evaluate_array(&input_3).unwrap();
435        assert_eq!(result, BooleanBuffer::from(vec![]));
436    }
437
438    #[test]
439    fn run_on_scalar() {
440        let expr = Expr::BinaryExpr(BinaryExpr {
441            left: Box::new(Expr::Column(Column::from_name("foo"))),
442            op: Operator::Lt,
443            right: Box::new(Expr::Literal(ScalarValue::Int64(Some(1)))),
444        });
445        let evaluator = SimpleFilterEvaluator::try_new(&expr).unwrap();
446
447        let input_1 = ScalarValue::Int64(Some(1));
448        let result = evaluator.evaluate_scalar(&input_1).unwrap();
449        assert!(!result);
450
451        let input_2 = ScalarValue::Int64(Some(0));
452        let result = evaluator.evaluate_scalar(&input_2).unwrap();
453        assert!(result);
454
455        let input_3 = ScalarValue::Int64(None);
456        let result = evaluator.evaluate_scalar(&input_3).unwrap();
457        assert!(!result);
458    }
459
460    #[test]
461    fn batch_filter_test() {
462        let expr = col("ts").gt(lit(123456u64));
463        let schema = Schema::new(vec![
464            Field::new("a", DataType::Int32, true),
465            Field::new("ts", DataType::UInt64, false),
466        ]);
467        let df_schema = DFSchema::try_from(schema.clone()).unwrap();
468        let props = ExecutionProps::new();
469        let physical_expr = create_physical_expr(&expr, &df_schema, &props).unwrap();
470        let batch = RecordBatch::try_new(
471            Arc::new(schema),
472            vec![
473                Arc::new(datatypes::arrow::array::Int32Array::from(vec![4, 5, 6])),
474                Arc::new(datatypes::arrow::array::UInt64Array::from(vec![
475                    123456, 123457, 123458,
476                ])),
477            ],
478        )
479        .unwrap();
480        let new_batch = batch_filter(&batch, &physical_expr).unwrap();
481        assert_eq!(new_batch.num_rows(), 2);
482        let first_column_values = new_batch
483            .column(0)
484            .as_any()
485            .downcast_ref::<datatypes::arrow::array::Int32Array>()
486            .unwrap();
487        let expected = datatypes::arrow::array::Int32Array::from(vec![5, 6]);
488        assert_eq!(first_column_values, &expected);
489    }
490
491    #[test]
492    fn test_complex_filter_expression() {
493        // Create an expression tree for: col = 'B' OR col = 'C' OR col = 'D'
494        let col_eq_b = col("col").eq(lit("B"));
495        let col_eq_c = col("col").eq(lit("C"));
496        let col_eq_d = col("col").eq(lit("D"));
497
498        // Build the OR chain
499        let col_or_expr = col_eq_b.or(col_eq_c).or(col_eq_d);
500
501        // Check that SimpleFilterEvaluator can handle OR chain
502        let or_evaluator = SimpleFilterEvaluator::try_new(&col_or_expr).unwrap();
503        assert_eq!(or_evaluator.column_name, "col");
504        assert_eq!(or_evaluator.op, Operator::Or);
505        assert_eq!(or_evaluator.literal_list.len(), 3);
506        assert_eq!(format!("{:?}", or_evaluator.literal_list), "[Scalar(StringArray\n[\n  \"B\",\n]), Scalar(StringArray\n[\n  \"C\",\n]), Scalar(StringArray\n[\n  \"D\",\n])]");
507
508        // Create a schema and batch for testing
509        let schema = Schema::new(vec![Field::new("col", DataType::Utf8, false)]);
510        let df_schema = DFSchema::try_from(schema.clone()).unwrap();
511        let props = ExecutionProps::new();
512        let physical_expr = create_physical_expr(&col_or_expr, &df_schema, &props).unwrap();
513
514        // Create test data
515        let col_data = Arc::new(datatypes::arrow::array::StringArray::from(vec![
516            "B", "C", "E", "B", "C", "D", "F",
517        ]));
518        let batch = RecordBatch::try_new(Arc::new(schema), vec![col_data]).unwrap();
519        let expected = datatypes::arrow::array::StringArray::from(vec!["B", "C", "B", "C", "D"]);
520
521        // Filter the batch
522        let filtered_batch = batch_filter(&batch, &physical_expr).unwrap();
523
524        // Expected: rows with col in ("B", "C", "D")
525        // That would be rows 0, 1, 3, 4, 5
526        assert_eq!(filtered_batch.num_rows(), 5);
527
528        let col_filtered = filtered_batch
529            .column(0)
530            .as_any()
531            .downcast_ref::<datatypes::arrow::array::StringArray>()
532            .unwrap();
533        assert_eq!(col_filtered, &expected);
534    }
535
536    #[test]
537    fn test_maybe_build_regex() {
538        // Test case for RegexMatch (case sensitive, non-negative)
539        let (regex, negative) = SimpleFilterEvaluator::maybe_build_regex(
540            Operator::RegexMatch,
541            &ScalarValue::Utf8(Some("a.*b".to_string())),
542        )
543        .unwrap();
544        assert!(regex.is_some());
545        assert!(!negative);
546        assert!(regex.unwrap().is_match("axxb"));
547
548        // Test case for RegexIMatch (case insensitive, non-negative)
549        let (regex, negative) = SimpleFilterEvaluator::maybe_build_regex(
550            Operator::RegexIMatch,
551            &ScalarValue::Utf8(Some("a.*b".to_string())),
552        )
553        .unwrap();
554        assert!(regex.is_some());
555        assert!(!negative);
556        assert!(regex.unwrap().is_match("AxxB"));
557
558        // Test case for RegexNotMatch (case sensitive, negative)
559        let (regex, negative) = SimpleFilterEvaluator::maybe_build_regex(
560            Operator::RegexNotMatch,
561            &ScalarValue::Utf8(Some("a.*b".to_string())),
562        )
563        .unwrap();
564        assert!(regex.is_some());
565        assert!(negative);
566
567        // Test case for RegexNotIMatch (case insensitive, negative)
568        let (regex, negative) = SimpleFilterEvaluator::maybe_build_regex(
569            Operator::RegexNotIMatch,
570            &ScalarValue::Utf8(Some("a.*b".to_string())),
571        )
572        .unwrap();
573        assert!(regex.is_some());
574        assert!(negative);
575
576        // Test with empty regex pattern
577        let (regex, negative) = SimpleFilterEvaluator::maybe_build_regex(
578            Operator::RegexMatch,
579            &ScalarValue::Utf8(Some("".to_string())),
580        )
581        .unwrap();
582        assert!(regex.is_none());
583        assert!(!negative);
584
585        // Test with non-regex operator
586        let (regex, negative) = SimpleFilterEvaluator::maybe_build_regex(
587            Operator::Eq,
588            &ScalarValue::Utf8(Some("a.*b".to_string())),
589        )
590        .unwrap();
591        assert!(regex.is_none());
592        assert!(!negative);
593
594        // Test with invalid regex pattern
595        let result = SimpleFilterEvaluator::maybe_build_regex(
596            Operator::RegexMatch,
597            &ScalarValue::Utf8(Some("a(b".to_string())),
598        );
599        assert!(result.is_err());
600
601        // Test with non-string value
602        let result = SimpleFilterEvaluator::maybe_build_regex(
603            Operator::RegexMatch,
604            &ScalarValue::Int64(Some(123)),
605        );
606        assert!(result.is_err());
607
608        // Test with null value
609        let result = SimpleFilterEvaluator::maybe_build_regex(
610            Operator::RegexMatch,
611            &ScalarValue::Utf8(None),
612        );
613        assert!(result.is_err());
614    }
615}