Skip to main content

table/
predicate.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use arc_swap::ArcSwap;
18use common_telemetry::{debug, warn};
19use common_time::Timestamp;
20use common_time::range::TimestampRange;
21use common_time::timestamp::TimeUnit;
22use datafusion::common::ScalarValue;
23use datafusion::physical_optimizer::pruning::PruningPredicate;
24use datafusion_common::ToDFSchema;
25use datafusion_common::pruning::PruningStatistics;
26use datafusion_expr::expr::{Expr, InList};
27use datafusion_expr::{Between, BinaryExpr, Operator};
28use datafusion_physical_expr::execution_props::ExecutionProps;
29use datafusion_physical_expr::expressions::DynamicFilterPhysicalExpr;
30use datafusion_physical_expr::{PhysicalExpr, create_physical_expr};
31use datatypes::arrow;
32use datatypes::value::scalar_value_to_timestamp;
33use snafu::ResultExt;
34
35use crate::error;
36
37#[cfg(test)]
38mod stats;
39
/// Returns `None` from the enclosing function (or closure) when the given scalar literal is
/// still a `ScalarValue::Utf8`.
///
/// String literals are expected to have been rewritten into timestamp scalars by
/// `TypeConversionRule` before time range extraction runs. If one slips through, a warning is
/// logged and the predicate is treated as having no extractable time range.
macro_rules! return_none_if_utf8 {
    ($lit: ident) => {
        if let ScalarValue::Utf8(_) = $lit {
            warn!(
                "Unexpected ScalarValue::Utf8 in time range predicate: {:?}. Maybe it's an implicit bug, please report it to https://github.com/GreptimeTeam/greptimedb/issues",
                $lit
            );

            // Give up on this predicate instead of guessing what the string means.
            return None;
        }
    };
}
55
56/// Reference-counted pointer to a list of logical exprs and a list of dynamic filter physical exprs.
57#[derive(Debug, Clone, Default)]
58pub struct Predicate {
59    /// logical exprs
60    exprs: Arc<Vec<Expr>>,
61    /// dynamic filter physical exprs, only useful if dynamic filtering is enabled
62    ///
63    /// They are usually from `TopK` or `Join` operators, and can dynamically filter data during query execution by using current runtime information to further reduce data scanning
64    dyn_filters: Arc<ArcSwap<Vec<Arc<DynamicFilterPhysicalExpr>>>>,
65}
66
67impl Predicate {
68    /// Creates a new `Predicate` by converting logical exprs to physical exprs that can be
69    /// evaluated against record batches.
70    /// Returns error when failed to convert exprs.
71    pub fn new(exprs: Vec<Expr>) -> Self {
72        Self {
73            exprs: Arc::new(exprs),
74            dyn_filters: Arc::new(ArcSwap::new(Arc::new(vec![]))),
75        }
76    }
77
78    pub fn with_dyn_filters(
79        exprs: Vec<Expr>,
80        dyn_filters: Vec<Arc<DynamicFilterPhysicalExpr>>,
81    ) -> Self {
82        Self {
83            exprs: Arc::new(exprs),
84            dyn_filters: Arc::new(ArcSwap::new(Arc::new(dyn_filters))),
85        }
86    }
87
88    pub fn is_empty(&self) -> bool {
89        self.exprs.is_empty() && self.dyn_filters.load().is_empty()
90    }
91
92    /// Adds dynamic filter physical exprs to the existing list.
93    pub fn add_dyn_filters(&self, dyn_filters: Vec<Arc<DynamicFilterPhysicalExpr>>) {
94        self.dyn_filters.rcu(|existing| {
95            let mut new_filters = existing.as_ref().clone();
96            new_filters.extend(dyn_filters.clone());
97            Arc::new(new_filters)
98        });
99    }
100
101    /// Returns the logical exprs.
102    pub fn exprs(&self) -> &[Expr] {
103        &self.exprs
104    }
105
106    /// Returns the dynamic filter physical exprs. Notice this return a live dynamic filters which
107    /// can change during query execution.
108    pub fn dyn_filters(&self) -> Arc<Vec<Arc<DynamicFilterPhysicalExpr>>> {
109        self.dyn_filters.load_full()
110    }
111
112    /// Returns the dynamic filter as physical exprs. Notice this return a "snapshot" of
113    /// dynamic filters at the time of calling this method.
114    pub fn dyn_filter_phy_exprs(&self) -> error::Result<Vec<Arc<dyn PhysicalExpr>>> {
115        self.dyn_filters
116            .load()
117            .iter()
118            .map(|e| e.current())
119            .collect::<Result<Vec<_>, _>>()
120            .context(error::DatafusionSnafu)
121    }
122
123    /// Builds a single physical expr according to provided schema.
124    pub fn to_physical_expr(
125        expr: &Expr,
126        schema: &arrow::datatypes::SchemaRef,
127    ) -> error::Result<Arc<dyn PhysicalExpr>> {
128        let df_schema = schema
129            .clone()
130            .to_dfschema_ref()
131            .context(error::DatafusionSnafu)?;
132
133        // TODO(hl): `execution_props` provides variables required by evaluation.
134        // we may reuse the `execution_props` from `SessionState` once we support
135        // registering variables.
136        let execution_props = &ExecutionProps::new();
137
138        create_physical_expr(expr, df_schema.as_ref(), execution_props)
139            .context(error::DatafusionSnafu)
140    }
141
142    /// Builds physical exprs according to provided schema.
143    pub fn to_physical_exprs(
144        &self,
145        schema: &arrow::datatypes::SchemaRef,
146    ) -> error::Result<Vec<Arc<dyn PhysicalExpr>>> {
147        let dyn_filters = self.dyn_filter_phy_exprs()?;
148
149        Ok(self
150            .exprs
151            .iter()
152            .filter_map(|expr| Self::to_physical_expr(expr, schema).ok())
153            .chain(dyn_filters)
154            .collect::<Vec<_>>())
155    }
156
157    /// Evaluates the predicate against the `stats`.
158    /// Returns a vector of boolean values, among which `false` means the row group can be skipped.
159    pub fn prune_with_stats<S: PruningStatistics>(
160        &self,
161        stats: &S,
162        schema: &arrow::datatypes::SchemaRef,
163    ) -> Vec<bool> {
164        let mut res = vec![true; stats.num_containers()];
165        let physical_exprs = match self.to_physical_exprs(schema) {
166            Ok(expr) => expr,
167            Err(e) => {
168                warn!(e; "Failed to build physical expr from predicates: {:?}", &self.exprs);
169                return res;
170            }
171        };
172
173        for expr in &physical_exprs {
174            match PruningPredicate::try_new(expr.clone(), schema.clone()) {
175                Ok(p) => match p.prune(stats) {
176                    Ok(r) => {
177                        for (curr_val, res) in r.into_iter().zip(res.iter_mut()) {
178                            *res &= curr_val
179                        }
180                    }
181                    Err(e) => {
182                        warn!(e; "Failed to prune row groups");
183                    }
184                },
185                Err(e) => {
186                    // since dynamic filter exprs could be complex, it's possible that `PruningPredicate::try_new` fails to prove anything from it. In that case, we just log it and skip pruning with this expr.
187                    debug!("Failed to create pruning predicate for expr: {e:?}");
188                }
189            }
190        }
191        res
192    }
193}
194
195// tests for `build_time_range_predicate` locates in src/query/tests/time_range_filter_test.rs
196// since it requires query engine to convert sql to filters.
197/// `build_time_range_predicate` extracts time range from logical exprs to facilitate fast
198/// time range pruning.
199pub fn build_time_range_predicate(
200    ts_col_name: &str,
201    ts_col_unit: TimeUnit,
202    filters: &[Expr],
203) -> TimestampRange {
204    let mut res = TimestampRange::min_to_max();
205    for expr in filters {
206        if let Some(range) = extract_time_range_from_expr(ts_col_name, ts_col_unit, expr) {
207            res = res.and(&range);
208        }
209    }
210    res
211}
212
213/// Extract time range filter from `WHERE`/`IN (...)`/`BETWEEN` clauses.
214/// Return None if no time range can be found in expr.
215pub fn extract_time_range_from_expr(
216    ts_col_name: &str,
217    ts_col_unit: TimeUnit,
218    expr: &Expr,
219) -> Option<TimestampRange> {
220    match expr {
221        Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
222            extract_from_binary_expr(ts_col_name, ts_col_unit, left, op, right)
223        }
224        Expr::Between(Between {
225            expr,
226            negated,
227            low,
228            high,
229        }) => extract_from_between_expr(ts_col_name, ts_col_unit, expr, negated, low, high),
230        Expr::InList(InList {
231            expr,
232            list,
233            negated,
234        }) => extract_from_in_list_expr(ts_col_name, expr, *negated, list),
235        _ => None,
236    }
237}
238
239fn extract_from_binary_expr(
240    ts_col_name: &str,
241    ts_col_unit: TimeUnit,
242    left: &Expr,
243    op: &Operator,
244    right: &Expr,
245) -> Option<TimestampRange> {
246    match op {
247        Operator::Eq => get_timestamp_filter(ts_col_name, left, right)
248            .and_then(|(ts, _)| ts.convert_to(ts_col_unit))
249            .map(TimestampRange::single),
250        Operator::Lt => {
251            let (ts, reverse) = get_timestamp_filter(ts_col_name, left, right)?;
252            if reverse {
253                // [lit] < ts_col
254                let ts_val = ts.convert_to(ts_col_unit)?.value();
255                Some(TimestampRange::from_start(Timestamp::new(
256                    ts_val + 1,
257                    ts_col_unit,
258                )))
259            } else {
260                // ts_col < [lit]
261                ts.convert_to_ceil(ts_col_unit)
262                    .map(|ts| TimestampRange::until_end(ts, false))
263            }
264        }
265        Operator::LtEq => {
266            let (ts, reverse) = get_timestamp_filter(ts_col_name, left, right)?;
267            if reverse {
268                // [lit] <= ts_col
269                ts.convert_to_ceil(ts_col_unit)
270                    .map(TimestampRange::from_start)
271            } else {
272                // ts_col <= [lit]
273                ts.convert_to(ts_col_unit)
274                    .map(|ts| TimestampRange::until_end(ts, true))
275            }
276        }
277        Operator::Gt => {
278            let (ts, reverse) = get_timestamp_filter(ts_col_name, left, right)?;
279            if reverse {
280                // [lit] > ts_col
281                ts.convert_to_ceil(ts_col_unit)
282                    .map(|t| TimestampRange::until_end(t, false))
283            } else {
284                // ts_col > [lit]
285                let ts_val = ts.convert_to(ts_col_unit)?.value();
286                Some(TimestampRange::from_start(Timestamp::new(
287                    ts_val + 1,
288                    ts_col_unit,
289                )))
290            }
291        }
292        Operator::GtEq => {
293            let (ts, reverse) = get_timestamp_filter(ts_col_name, left, right)?;
294            if reverse {
295                // [lit] >= ts_col
296                ts.convert_to(ts_col_unit)
297                    .map(|t| TimestampRange::until_end(t, true))
298            } else {
299                // ts_col >= [lit]
300                ts.convert_to_ceil(ts_col_unit)
301                    .map(TimestampRange::from_start)
302            }
303        }
304        Operator::And => {
305            // instead of return none when failed to extract time range from left/right, we unwrap the none into
306            // `TimestampRange::min_to_max`.
307            let left = extract_time_range_from_expr(ts_col_name, ts_col_unit, left)
308                .unwrap_or_else(TimestampRange::min_to_max);
309            let right = extract_time_range_from_expr(ts_col_name, ts_col_unit, right)
310                .unwrap_or_else(TimestampRange::min_to_max);
311            Some(left.and(&right))
312        }
313        Operator::Or => {
314            let left = extract_time_range_from_expr(ts_col_name, ts_col_unit, left)?;
315            let right = extract_time_range_from_expr(ts_col_name, ts_col_unit, right)?;
316            Some(left.or(&right))
317        }
318        _ => None,
319    }
320}
321
322fn get_timestamp_filter(ts_col_name: &str, left: &Expr, right: &Expr) -> Option<(Timestamp, bool)> {
323    let (col, lit, reverse) = match (left, right) {
324        (Expr::Column(column), Expr::Literal(scalar, _)) => (column, scalar, false),
325        (Expr::Literal(scalar, _), Expr::Column(column)) => (column, scalar, true),
326        _ => {
327            return None;
328        }
329    };
330    if col.name != ts_col_name {
331        return None;
332    }
333
334    return_none_if_utf8!(lit);
335    scalar_value_to_timestamp(lit, None).map(|t| (t, reverse))
336}
337
338fn extract_from_between_expr(
339    ts_col_name: &str,
340    ts_col_unit: TimeUnit,
341    expr: &Expr,
342    negated: &bool,
343    low: &Expr,
344    high: &Expr,
345) -> Option<TimestampRange> {
346    let Expr::Column(col) = expr else {
347        return None;
348    };
349    if col.name != ts_col_name {
350        return None;
351    }
352
353    if *negated {
354        return None;
355    }
356
357    match (low, high) {
358        (Expr::Literal(low, _), Expr::Literal(high, _)) => {
359            return_none_if_utf8!(low);
360            return_none_if_utf8!(high);
361
362            let low_opt =
363                scalar_value_to_timestamp(low, None).and_then(|ts| ts.convert_to(ts_col_unit));
364            let high_opt = scalar_value_to_timestamp(high, None)
365                .and_then(|ts| ts.convert_to_ceil(ts_col_unit));
366            Some(TimestampRange::new_inclusive(low_opt, high_opt))
367        }
368        _ => None,
369    }
370}
371
372/// Extract time range filter from `IN (...)` expr.
373fn extract_from_in_list_expr(
374    ts_col_name: &str,
375    expr: &Expr,
376    negated: bool,
377    list: &[Expr],
378) -> Option<TimestampRange> {
379    if negated {
380        return None;
381    }
382    let Expr::Column(col) = expr else {
383        return None;
384    };
385    if col.name != ts_col_name {
386        return None;
387    }
388
389    if list.is_empty() {
390        return Some(TimestampRange::empty());
391    }
392    let mut init_range = TimestampRange::empty();
393    for expr in list {
394        if let Expr::Literal(scalar, _) = expr {
395            return_none_if_utf8!(scalar);
396            if let Some(timestamp) = scalar_value_to_timestamp(scalar, None) {
397                init_range = init_range.or(&TimestampRange::single(timestamp))
398            } else {
399                // TODO(hl): maybe we should raise an error here since cannot parse
400                // timestamp value from in list expr
401                return None;
402            }
403        }
404    }
405    Some(init_range)
406}
407
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_test_util::temp_dir::{TempDir, create_temp_dir};
    use datafusion::parquet::arrow::ArrowWriter;
    use datafusion_common::{Column, ScalarValue};
    use datafusion_expr::{BinaryExpr, Literal, Operator, col, lit};
    use datatypes::arrow::array::Int32Array;
    use datatypes::arrow::datatypes::{DataType, Field, Schema};
    use datatypes::arrow::record_batch::RecordBatch;
    use datatypes::arrow_array::StringArray;
    use parquet::arrow::ParquetRecordBatchStreamBuilder;
    use parquet::file::properties::WriterProperties;

    use super::*;
    use crate::predicate::stats::RowGroupPruningStatistics;

    fn check_build_predicate(expr: Expr, expect: TimestampRange) {
        assert_eq!(
            expect,
            build_time_range_predicate("ts", TimeUnit::Millisecond, &[expr])
        );
    }

    #[test]
    fn test_gt() {
        // ts > 1ms
        check_build_predicate(
            col("ts").gt(lit(ScalarValue::TimestampMillisecond(Some(1), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // 1ms > ts
        check_build_predicate(
            lit(ScalarValue::TimestampMillisecond(Some(1), None)).gt(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1), false),
        );

        // 1001us > ts
        check_build_predicate(
            lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).gt(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // ts > 1001us
        check_build_predicate(
            col("ts").gt(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // 1s > ts
        check_build_predicate(
            lit(ScalarValue::TimestampSecond(Some(1), None)).gt(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1000), false),
        );

        // ts > 1s
        check_build_predicate(
            col("ts").gt(lit(ScalarValue::TimestampSecond(Some(1), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(1001)),
        );
    }

    #[test]
    fn test_gt_eq() {
        // ts >= 1ms
        check_build_predicate(
            col("ts").gt_eq(lit(ScalarValue::TimestampMillisecond(Some(1), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(1)),
        );

        // 1ms >= ts
        check_build_predicate(
            lit(ScalarValue::TimestampMillisecond(Some(1), None)).gt_eq(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // 1001us >= ts
        check_build_predicate(
            lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).gt_eq(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // ts >= 1001us
        check_build_predicate(
            col("ts").gt_eq(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // 1s >= ts
        check_build_predicate(
            lit(ScalarValue::TimestampSecond(Some(1), None)).gt_eq(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1000), true),
        );

        // ts >= 1s
        check_build_predicate(
            col("ts").gt_eq(lit(ScalarValue::TimestampSecond(Some(1), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(1000)),
        );
    }

    #[test]
    fn test_lt() {
        // ts < 1ms
        check_build_predicate(
            col("ts").lt(lit(ScalarValue::TimestampMillisecond(Some(1), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1), false),
        );

        // 1ms < ts
        check_build_predicate(
            lit(ScalarValue::TimestampMillisecond(Some(1), None)).lt(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // 1001us < ts
        check_build_predicate(
            lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).lt(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // ts < 1001us
        check_build_predicate(
            col("ts").lt(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // 1s < ts
        check_build_predicate(
            lit(ScalarValue::TimestampSecond(Some(1), None)).lt(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(1001)),
        );

        // ts < 1s
        check_build_predicate(
            col("ts").lt(lit(ScalarValue::TimestampSecond(Some(1), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1000), false),
        );
    }

    #[test]
    fn test_lt_eq() {
        // ts <= 1ms
        check_build_predicate(
            col("ts").lt_eq(lit(ScalarValue::TimestampMillisecond(Some(1), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // 1ms <= ts
        check_build_predicate(
            lit(ScalarValue::TimestampMillisecond(Some(1), None)).lt_eq(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(1)),
        );

        // 1001us <= ts
        check_build_predicate(
            lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).lt_eq(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // ts <= 1001us
        check_build_predicate(
            col("ts").lt_eq(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // 1s <= ts
        check_build_predicate(
            lit(ScalarValue::TimestampSecond(Some(1), None)).lt_eq(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(1000)),
        );

        // ts <= 1s
        check_build_predicate(
            col("ts").lt_eq(lit(ScalarValue::TimestampSecond(Some(1), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1000), true),
        );
    }

    async fn gen_test_parquet_file(dir: &TempDir, cnt: usize) -> (String, Arc<Schema>) {
        let path = dir
            .path()
            .join("test-prune.parquet")
            .to_string_lossy()
            .to_string();

        let name_field = Field::new("name", DataType::Utf8, true);
        let count_field = Field::new("cnt", DataType::Int32, true);
        let schema = Arc::new(Schema::new(vec![name_field, count_field]));

        let file = std::fs::OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .open(path.clone())
            .unwrap();

        let write_props = WriterProperties::builder()
            .set_max_row_group_size(10)
            .build();
        let mut writer = ArrowWriter::try_new(file, schema.clone(), Some(write_props)).unwrap();

        for i in (0..cnt).step_by(10) {
            let name_array = Arc::new(StringArray::from(
                (i..(i + 10).min(cnt))
                    .map(|i| i.to_string())
                    .collect::<Vec<_>>(),
            )) as Arc<_>;
            let count_array = Arc::new(Int32Array::from(
                (i..(i + 10).min(cnt)).map(|i| i as i32).collect::<Vec<_>>(),
            )) as Arc<_>;
            let rb = RecordBatch::try_new(schema.clone(), vec![name_array, count_array]).unwrap();
            writer.write(&rb).unwrap();
        }
        let _ = writer.close().unwrap();
        (path, schema)
    }

    async fn assert_prune(array_cnt: usize, filters: Vec<Expr>, expect: Vec<bool>) {
        let dir = create_temp_dir("prune_parquet");
        let (path, arrow_schema) = gen_test_parquet_file(&dir, array_cnt).await;
        let schema = Arc::new(datatypes::schema::Schema::try_from(arrow_schema.clone()).unwrap());
        let arrow_predicate = Predicate::new(filters);
        let builder = ParquetRecordBatchStreamBuilder::new(
            tokio::fs::OpenOptions::new()
                .read(true)
                .open(path)
                .await
                .unwrap(),
        )
        .await
        .unwrap();
        let metadata = builder.metadata().clone();
        let row_groups = metadata.row_groups();

        let stats = RowGroupPruningStatistics::new(row_groups, &schema);
        let res = arrow_predicate.prune_with_stats(&stats, &arrow_schema);
        assert_eq!(expect, res);
    }

    fn gen_predicate(max_val: i32, op: Operator) -> Vec<Expr> {
        vec![datafusion_expr::Expr::BinaryExpr(BinaryExpr {
            left: Box::new(datafusion_expr::Expr::Column(Column::from_name("cnt"))),
            op,
            right: Box::new(max_val.lit()),
        })]
    }

    #[tokio::test]
    async fn test_prune_empty() {
        assert_prune(3, vec![], vec![true]).await;
    }

    #[tokio::test]
    async fn test_prune_all_match() {
        // cnt > 3 over rows [0, 1]: no row can match, so the only row group is pruned.
        let p = gen_predicate(3, Operator::Gt);
        assert_prune(2, p, vec![false]).await;
    }

    #[tokio::test]
    async fn test_prune_gt() {
        let p = gen_predicate(29, Operator::Gt);
        assert_prune(
            100,
            p,
            vec![
                false, false, false, true, true, true, true, true, true, true,
            ],
        )
        .await;
    }

    #[tokio::test]
    async fn test_prune_eq_expr() {
        let p = gen_predicate(30, Operator::Eq);
        assert_prune(40, p, vec![false, false, false, true]).await;
    }

    #[tokio::test]
    async fn test_prune_neq_expr() {
        let p = gen_predicate(30, Operator::NotEq);
        assert_prune(40, p, vec![true, true, true, true]).await;
    }

    #[tokio::test]
    async fn test_prune_gteq_expr() {
        let p = gen_predicate(29, Operator::GtEq);
        assert_prune(40, p, vec![false, false, true, true]).await;
    }

    #[tokio::test]
    async fn test_prune_lt_expr() {
        let p = gen_predicate(30, Operator::Lt);
        assert_prune(40, p, vec![true, true, true, false]).await;
    }

    #[tokio::test]
    async fn test_prune_lteq_expr() {
        let p = gen_predicate(30, Operator::LtEq);
        assert_prune(40, p, vec![true, true, true, true]).await;
    }

    #[tokio::test]
    async fn test_prune_between_expr() {
        // cnt BETWEEN 15 AND 25 over row groups [0..9], [10..19], [20..29], [30..39]:
        // only the two middle row groups can contain matching rows.
        let p = vec![col("cnt").between(lit(15), lit(25))];
        assert_prune(40, p, vec![false, true, true, false]).await;
    }

    #[tokio::test]
    async fn test_or() {
        // cnt > 30 or cnt < 20
        let e = datafusion_expr::Expr::Column(Column::from_name("cnt"))
            .gt(30.lit())
            .or(datafusion_expr::Expr::Column(Column::from_name("cnt")).lt(20.lit()));
        assert_prune(40, vec![e], vec![true, true, false, true]).await;
    }

    #[tokio::test]
    async fn test_to_physical_expr() {
        let predicate = Predicate::new(vec![
            col("host").eq(lit("host_a")),
            col("ts").gt(lit(ScalarValue::TimestampMicrosecond(Some(123), None))),
        ]);

        let schema = Arc::new(arrow::datatypes::Schema::new(vec![Field::new(
            "host",
            arrow::datatypes::DataType::Utf8,
            false,
        )]));

        let predicates = predicate.to_physical_exprs(&schema).unwrap();
        assert!(!predicates.is_empty());

        let physical_expr = Predicate::to_physical_expr(&col("host").eq(lit("host_a")), &schema);
        assert!(physical_expr.is_ok());
    }
}