table/
predicate.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use common_telemetry::{error, warn};
18use common_time::range::TimestampRange;
19use common_time::timestamp::TimeUnit;
20use common_time::Timestamp;
21use datafusion::common::ScalarValue;
22use datafusion::physical_optimizer::pruning::PruningPredicate;
23use datafusion_common::pruning::PruningStatistics;
24use datafusion_common::ToDFSchema;
25use datafusion_expr::expr::{Expr, InList};
26use datafusion_expr::{Between, BinaryExpr, Operator};
27use datafusion_physical_expr::execution_props::ExecutionProps;
28use datafusion_physical_expr::{create_physical_expr, PhysicalExpr};
29use datatypes::arrow;
30use datatypes::value::scalar_value_to_timestamp;
31use snafu::ResultExt;
32
33use crate::error;
34
35#[cfg(test)]
36mod stats;
37
/// Bails out of the enclosing function with `None` when `$lit` is a `ScalarValue::Utf8`.
///
/// String literals compared against the time index should already have been turned into
/// timestamp scalars by `TypeConversionRule`. Seeing one here is logged as a warning, and
/// the enclosing extraction gives up instead of guessing how to interpret the string.
macro_rules! return_none_if_utf8 {
    ($lit: ident) => {
        if let ScalarValue::Utf8(_) = $lit {
            warn!(
                "Unexpected ScalarValue::Utf8 in time range predicate: {:?}. Maybe it's an implicit bug, please report it to https://github.com/GreptimeTeam/greptimedb/issues",
                $lit
            );
            // Returning `None` makes this predicate ineffective for time range pruning.
            return None;
        }
    };
}
53
54/// Reference-counted pointer to a list of logical exprs.
55#[derive(Debug, Clone)]
56pub struct Predicate {
57    /// logical exprs
58    exprs: Arc<Vec<Expr>>,
59}
60
61impl Predicate {
62    /// Creates a new `Predicate` by converting logical exprs to physical exprs that can be
63    /// evaluated against record batches.
64    /// Returns error when failed to convert exprs.
65    pub fn new(exprs: Vec<Expr>) -> Self {
66        Self {
67            exprs: Arc::new(exprs),
68        }
69    }
70
71    /// Returns the logical exprs.
72    pub fn exprs(&self) -> &[Expr] {
73        &self.exprs
74    }
75
76    /// Builds physical exprs according to provided schema.
77    pub fn to_physical_exprs(
78        &self,
79        schema: &arrow::datatypes::SchemaRef,
80    ) -> error::Result<Vec<Arc<dyn PhysicalExpr>>> {
81        let df_schema = schema
82            .clone()
83            .to_dfschema_ref()
84            .context(error::DatafusionSnafu)?;
85
86        // TODO(hl): `execution_props` provides variables required by evaluation.
87        // we may reuse the `execution_props` from `SessionState` once we support
88        // registering variables.
89        let execution_props = &ExecutionProps::new();
90
91        Ok(self
92            .exprs
93            .iter()
94            .filter_map(|expr| create_physical_expr(expr, df_schema.as_ref(), execution_props).ok())
95            .collect::<Vec<_>>())
96    }
97
98    /// Evaluates the predicate against the `stats`.
99    /// Returns a vector of boolean values, among which `false` means the row group can be skipped.
100    pub fn prune_with_stats<S: PruningStatistics>(
101        &self,
102        stats: &S,
103        schema: &arrow::datatypes::SchemaRef,
104    ) -> Vec<bool> {
105        let mut res = vec![true; stats.num_containers()];
106        let physical_exprs = match self.to_physical_exprs(schema) {
107            Ok(expr) => expr,
108            Err(e) => {
109                warn!(e; "Failed to build physical expr from predicates: {:?}", &self.exprs);
110                return res;
111            }
112        };
113
114        for expr in &physical_exprs {
115            match PruningPredicate::try_new(expr.clone(), schema.clone()) {
116                Ok(p) => match p.prune(stats) {
117                    Ok(r) => {
118                        for (curr_val, res) in r.into_iter().zip(res.iter_mut()) {
119                            *res &= curr_val
120                        }
121                    }
122                    Err(e) => {
123                        warn!(e; "Failed to prune row groups");
124                    }
125                },
126                Err(e) => {
127                    error!(e; "Failed to create predicate for expr");
128                }
129            }
130        }
131        res
132    }
133}
134
135// tests for `build_time_range_predicate` locates in src/query/tests/time_range_filter_test.rs
136// since it requires query engine to convert sql to filters.
137/// `build_time_range_predicate` extracts time range from logical exprs to facilitate fast
138/// time range pruning.
139pub fn build_time_range_predicate(
140    ts_col_name: &str,
141    ts_col_unit: TimeUnit,
142    filters: &[Expr],
143) -> TimestampRange {
144    let mut res = TimestampRange::min_to_max();
145    for expr in filters {
146        if let Some(range) = extract_time_range_from_expr(ts_col_name, ts_col_unit, expr) {
147            res = res.and(&range);
148        }
149    }
150    res
151}
152
153/// Extract time range filter from `WHERE`/`IN (...)`/`BETWEEN` clauses.
154/// Return None if no time range can be found in expr.
155fn extract_time_range_from_expr(
156    ts_col_name: &str,
157    ts_col_unit: TimeUnit,
158    expr: &Expr,
159) -> Option<TimestampRange> {
160    match expr {
161        Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
162            extract_from_binary_expr(ts_col_name, ts_col_unit, left, op, right)
163        }
164        Expr::Between(Between {
165            expr,
166            negated,
167            low,
168            high,
169        }) => extract_from_between_expr(ts_col_name, ts_col_unit, expr, negated, low, high),
170        Expr::InList(InList {
171            expr,
172            list,
173            negated,
174        }) => extract_from_in_list_expr(ts_col_name, expr, *negated, list),
175        _ => None,
176    }
177}
178
179fn extract_from_binary_expr(
180    ts_col_name: &str,
181    ts_col_unit: TimeUnit,
182    left: &Expr,
183    op: &Operator,
184    right: &Expr,
185) -> Option<TimestampRange> {
186    match op {
187        Operator::Eq => get_timestamp_filter(ts_col_name, left, right)
188            .and_then(|(ts, _)| ts.convert_to(ts_col_unit))
189            .map(TimestampRange::single),
190        Operator::Lt => {
191            let (ts, reverse) = get_timestamp_filter(ts_col_name, left, right)?;
192            if reverse {
193                // [lit] < ts_col
194                let ts_val = ts.convert_to(ts_col_unit)?.value();
195                Some(TimestampRange::from_start(Timestamp::new(
196                    ts_val + 1,
197                    ts_col_unit,
198                )))
199            } else {
200                // ts_col < [lit]
201                ts.convert_to_ceil(ts_col_unit)
202                    .map(|ts| TimestampRange::until_end(ts, false))
203            }
204        }
205        Operator::LtEq => {
206            let (ts, reverse) = get_timestamp_filter(ts_col_name, left, right)?;
207            if reverse {
208                // [lit] <= ts_col
209                ts.convert_to_ceil(ts_col_unit)
210                    .map(TimestampRange::from_start)
211            } else {
212                // ts_col <= [lit]
213                ts.convert_to(ts_col_unit)
214                    .map(|ts| TimestampRange::until_end(ts, true))
215            }
216        }
217        Operator::Gt => {
218            let (ts, reverse) = get_timestamp_filter(ts_col_name, left, right)?;
219            if reverse {
220                // [lit] > ts_col
221                ts.convert_to_ceil(ts_col_unit)
222                    .map(|t| TimestampRange::until_end(t, false))
223            } else {
224                // ts_col > [lit]
225                let ts_val = ts.convert_to(ts_col_unit)?.value();
226                Some(TimestampRange::from_start(Timestamp::new(
227                    ts_val + 1,
228                    ts_col_unit,
229                )))
230            }
231        }
232        Operator::GtEq => {
233            let (ts, reverse) = get_timestamp_filter(ts_col_name, left, right)?;
234            if reverse {
235                // [lit] >= ts_col
236                ts.convert_to(ts_col_unit)
237                    .map(|t| TimestampRange::until_end(t, true))
238            } else {
239                // ts_col >= [lit]
240                ts.convert_to_ceil(ts_col_unit)
241                    .map(TimestampRange::from_start)
242            }
243        }
244        Operator::And => {
245            // instead of return none when failed to extract time range from left/right, we unwrap the none into
246            // `TimestampRange::min_to_max`.
247            let left = extract_time_range_from_expr(ts_col_name, ts_col_unit, left)
248                .unwrap_or_else(TimestampRange::min_to_max);
249            let right = extract_time_range_from_expr(ts_col_name, ts_col_unit, right)
250                .unwrap_or_else(TimestampRange::min_to_max);
251            Some(left.and(&right))
252        }
253        Operator::Or => {
254            let left = extract_time_range_from_expr(ts_col_name, ts_col_unit, left)?;
255            let right = extract_time_range_from_expr(ts_col_name, ts_col_unit, right)?;
256            Some(left.or(&right))
257        }
258        _ => None,
259    }
260}
261
262fn get_timestamp_filter(ts_col_name: &str, left: &Expr, right: &Expr) -> Option<(Timestamp, bool)> {
263    let (col, lit, reverse) = match (left, right) {
264        (Expr::Column(column), Expr::Literal(scalar, _)) => (column, scalar, false),
265        (Expr::Literal(scalar, _), Expr::Column(column)) => (column, scalar, true),
266        _ => {
267            return None;
268        }
269    };
270    if col.name != ts_col_name {
271        return None;
272    }
273
274    return_none_if_utf8!(lit);
275    scalar_value_to_timestamp(lit, None).map(|t| (t, reverse))
276}
277
278fn extract_from_between_expr(
279    ts_col_name: &str,
280    ts_col_unit: TimeUnit,
281    expr: &Expr,
282    negated: &bool,
283    low: &Expr,
284    high: &Expr,
285) -> Option<TimestampRange> {
286    let Expr::Column(col) = expr else {
287        return None;
288    };
289    if col.name != ts_col_name {
290        return None;
291    }
292
293    if *negated {
294        return None;
295    }
296
297    match (low, high) {
298        (Expr::Literal(low, _), Expr::Literal(high, _)) => {
299            return_none_if_utf8!(low);
300            return_none_if_utf8!(high);
301
302            let low_opt =
303                scalar_value_to_timestamp(low, None).and_then(|ts| ts.convert_to(ts_col_unit));
304            let high_opt = scalar_value_to_timestamp(high, None)
305                .and_then(|ts| ts.convert_to_ceil(ts_col_unit));
306            Some(TimestampRange::new_inclusive(low_opt, high_opt))
307        }
308        _ => None,
309    }
310}
311
312/// Extract time range filter from `IN (...)` expr.
313fn extract_from_in_list_expr(
314    ts_col_name: &str,
315    expr: &Expr,
316    negated: bool,
317    list: &[Expr],
318) -> Option<TimestampRange> {
319    if negated {
320        return None;
321    }
322    let Expr::Column(col) = expr else {
323        return None;
324    };
325    if col.name != ts_col_name {
326        return None;
327    }
328
329    if list.is_empty() {
330        return Some(TimestampRange::empty());
331    }
332    let mut init_range = TimestampRange::empty();
333    for expr in list {
334        if let Expr::Literal(scalar, _) = expr {
335            return_none_if_utf8!(scalar);
336            if let Some(timestamp) = scalar_value_to_timestamp(scalar, None) {
337                init_range = init_range.or(&TimestampRange::single(timestamp))
338            } else {
339                // TODO(hl): maybe we should raise an error here since cannot parse
340                // timestamp value from in list expr
341                return None;
342            }
343        }
344    }
345    Some(init_range)
346}
347
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_test_util::temp_dir::{create_temp_dir, TempDir};
    use datafusion::parquet::arrow::ArrowWriter;
    use datafusion_common::{Column, ScalarValue};
    use datafusion_expr::{col, lit, BinaryExpr, Literal, Operator};
    use datatypes::arrow::array::Int32Array;
    use datatypes::arrow::datatypes::{DataType, Field, Schema};
    use datatypes::arrow::record_batch::RecordBatch;
    use datatypes::arrow_array::StringArray;
    use parquet::arrow::ParquetRecordBatchStreamBuilder;
    use parquet::file::properties::WriterProperties;

    use super::*;
    use crate::predicate::stats::RowGroupPruningStatistics;

    /// Asserts that `expr` on a millisecond `ts` column yields exactly `expect`.
    fn check_build_predicate(expr: Expr, expect: TimestampRange) {
        assert_eq!(
            expect,
            build_time_range_predicate("ts", TimeUnit::Millisecond, &[expr])
        );
    }

    #[test]
    fn test_gt() {
        // ts > 1ms
        check_build_predicate(
            col("ts").gt(lit(ScalarValue::TimestampMillisecond(Some(1), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // 1ms > ts
        check_build_predicate(
            lit(ScalarValue::TimestampMillisecond(Some(1), None)).gt(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1), false),
        );

        // 1001us > ts
        check_build_predicate(
            lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).gt(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // ts > 1001us
        check_build_predicate(
            col("ts").gt(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // 1s > ts
        check_build_predicate(
            lit(ScalarValue::TimestampSecond(Some(1), None)).gt(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1000), false),
        );

        // ts > 1s
        check_build_predicate(
            col("ts").gt(lit(ScalarValue::TimestampSecond(Some(1), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(1001)),
        );
    }

    #[test]
    fn test_gt_eq() {
        // ts >= 1ms
        check_build_predicate(
            col("ts").gt_eq(lit(ScalarValue::TimestampMillisecond(Some(1), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(1)),
        );

        // 1ms >= ts
        check_build_predicate(
            lit(ScalarValue::TimestampMillisecond(Some(1), None)).gt_eq(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // 1001us >= ts
        check_build_predicate(
            lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).gt_eq(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // ts >= 1001us
        check_build_predicate(
            col("ts").gt_eq(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // 1s >= ts
        check_build_predicate(
            lit(ScalarValue::TimestampSecond(Some(1), None)).gt_eq(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1000), true),
        );

        // ts >= 1s
        check_build_predicate(
            col("ts").gt_eq(lit(ScalarValue::TimestampSecond(Some(1), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(1000)),
        );
    }

    #[test]
    fn test_lt() {
        // ts < 1ms
        check_build_predicate(
            col("ts").lt(lit(ScalarValue::TimestampMillisecond(Some(1), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1), false),
        );

        // 1ms < ts
        check_build_predicate(
            lit(ScalarValue::TimestampMillisecond(Some(1), None)).lt(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // 1001us < ts
        check_build_predicate(
            lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).lt(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // ts < 1001us
        check_build_predicate(
            col("ts").lt(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // 1s < ts
        check_build_predicate(
            lit(ScalarValue::TimestampSecond(Some(1), None)).lt(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(1001)),
        );

        // ts < 1s
        check_build_predicate(
            col("ts").lt(lit(ScalarValue::TimestampSecond(Some(1), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1000), false),
        );
    }

    #[test]
    fn test_lt_eq() {
        // ts <= 1ms
        check_build_predicate(
            col("ts").lt_eq(lit(ScalarValue::TimestampMillisecond(Some(1), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // 1ms <= ts
        check_build_predicate(
            lit(ScalarValue::TimestampMillisecond(Some(1), None)).lt_eq(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(1)),
        );

        // 1001us <= ts
        check_build_predicate(
            lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).lt_eq(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // ts <= 1001us
        check_build_predicate(
            col("ts").lt_eq(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // 1s <= ts
        check_build_predicate(
            lit(ScalarValue::TimestampSecond(Some(1), None)).lt_eq(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(1000)),
        );

        // ts <= 1s
        check_build_predicate(
            col("ts").lt_eq(lit(ScalarValue::TimestampSecond(Some(1), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1000), true),
        );
    }

    #[test]
    fn test_between() {
        // ts BETWEEN 1ms AND 3ms
        check_build_predicate(
            col("ts").between(
                lit(ScalarValue::TimestampMillisecond(Some(1), None)),
                lit(ScalarValue::TimestampMillisecond(Some(3), None)),
            ),
            TimestampRange::new_inclusive(
                Some(Timestamp::new_millisecond(1)),
                Some(Timestamp::new_millisecond(3)),
            ),
        );

        // ts NOT BETWEEN 1ms AND 3ms cannot narrow the range.
        check_build_predicate(
            col("ts").not_between(
                lit(ScalarValue::TimestampMillisecond(Some(1), None)),
                lit(ScalarValue::TimestampMillisecond(Some(3), None)),
            ),
            TimestampRange::min_to_max(),
        );
    }

    /// Writes `cnt` rows (`name` = i as string, `cnt` = i) into a parquet file with row
    /// groups of at most 10 rows, returning the file path and its arrow schema.
    async fn gen_test_parquet_file(dir: &TempDir, cnt: usize) -> (String, Arc<Schema>) {
        let path = dir
            .path()
            .join("test-prune.parquet")
            .to_string_lossy()
            .to_string();

        let name_field = Field::new("name", DataType::Utf8, true);
        let count_field = Field::new("cnt", DataType::Int32, true);
        let schema = Arc::new(Schema::new(vec![name_field, count_field]));

        let file = std::fs::OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .open(path.clone())
            .unwrap();

        let write_props = WriterProperties::builder()
            .set_max_row_group_size(10)
            .build();
        let mut writer = ArrowWriter::try_new(file, schema.clone(), Some(write_props)).unwrap();

        for i in (0..cnt).step_by(10) {
            let name_array = Arc::new(StringArray::from(
                (i..(i + 10).min(cnt))
                    .map(|i| i.to_string())
                    .collect::<Vec<_>>(),
            )) as Arc<_>;
            let count_array = Arc::new(Int32Array::from(
                (i..(i + 10).min(cnt)).map(|i| i as i32).collect::<Vec<_>>(),
            )) as Arc<_>;
            let rb = RecordBatch::try_new(schema.clone(), vec![name_array, count_array]).unwrap();
            writer.write(&rb).unwrap();
        }
        let _ = writer.close().unwrap();
        (path, schema)
    }

    /// Generates a parquet file with `array_cnt` rows and asserts the per-row-group pruning
    /// result of `filters` equals `expect`.
    async fn assert_prune(array_cnt: usize, filters: Vec<Expr>, expect: Vec<bool>) {
        let dir = create_temp_dir("prune_parquet");
        let (path, arrow_schema) = gen_test_parquet_file(&dir, array_cnt).await;
        let schema = Arc::new(datatypes::schema::Schema::try_from(arrow_schema.clone()).unwrap());
        let arrow_predicate = Predicate::new(filters);
        let builder = ParquetRecordBatchStreamBuilder::new(
            tokio::fs::OpenOptions::new()
                .read(true)
                .open(path)
                .await
                .unwrap(),
        )
        .await
        .unwrap();
        let metadata = builder.metadata().clone();
        let row_groups = metadata.row_groups();

        let stats = RowGroupPruningStatistics::new(row_groups, &schema);
        let res = arrow_predicate.prune_with_stats(&stats, &arrow_schema);
        assert_eq!(expect, res);
    }

    /// Builds the single filter `cnt <op> max_val`.
    fn gen_predicate(max_val: i32, op: Operator) -> Vec<Expr> {
        vec![datafusion_expr::Expr::BinaryExpr(BinaryExpr {
            left: Box::new(datafusion_expr::Expr::Column(Column::from_name("cnt"))),
            op,
            right: Box::new(max_val.lit()),
        })]
    }

    #[tokio::test]
    async fn test_prune_empty() {
        assert_prune(3, vec![], vec![true]).await;
    }

    #[tokio::test]
    async fn test_prune_no_match() {
        // cnt > 3 on values 0..=1: the only row group can be pruned.
        let p = gen_predicate(3, Operator::Gt);
        assert_prune(2, p, vec![false]).await;
    }

    #[tokio::test]
    async fn test_prune_gt() {
        let p = gen_predicate(29, Operator::Gt);
        assert_prune(
            100,
            p,
            vec![
                false, false, false, true, true, true, true, true, true, true,
            ],
        )
        .await;
    }

    #[tokio::test]
    async fn test_prune_eq_expr() {
        let p = gen_predicate(30, Operator::Eq);
        assert_prune(40, p, vec![false, false, false, true]).await;
    }

    #[tokio::test]
    async fn test_prune_neq_expr() {
        let p = gen_predicate(30, Operator::NotEq);
        assert_prune(40, p, vec![true, true, true, true]).await;
    }

    #[tokio::test]
    async fn test_prune_gteq_expr() {
        let p = gen_predicate(29, Operator::GtEq);
        assert_prune(40, p, vec![false, false, true, true]).await;
    }

    #[tokio::test]
    async fn test_prune_lt_expr() {
        let p = gen_predicate(30, Operator::Lt);
        assert_prune(40, p, vec![true, true, true, false]).await;
    }

    #[tokio::test]
    async fn test_prune_lteq_expr() {
        let p = gen_predicate(30, Operator::LtEq);
        assert_prune(40, p, vec![true, true, true, true]).await;
    }

    #[tokio::test]
    async fn test_prune_between_expr() {
        // cnt BETWEEN 10 AND 25: row groups [0, 9] and [30, 39] can be pruned.
        let p = vec![col("cnt").between(lit(10), lit(25))];
        assert_prune(40, p, vec![false, true, true, false]).await;
    }

    #[tokio::test]
    async fn test_or() {
        // cnt > 30 or cnt < 20
        let e = datafusion_expr::Expr::Column(Column::from_name("cnt"))
            .gt(30.lit())
            .or(datafusion_expr::Expr::Column(Column::from_name("cnt")).lt(20.lit()));
        assert_prune(40, vec![e], vec![true, true, false, true]).await;
    }

    #[test]
    fn test_to_physical_expr() {
        // `ts` is not in the schema, so only the `host` predicate can be converted.
        let predicate = Predicate::new(vec![
            col("host").eq(lit("host_a")),
            col("ts").gt(lit(ScalarValue::TimestampMicrosecond(Some(123), None))),
        ]);

        let schema = Arc::new(arrow::datatypes::Schema::new(vec![Field::new(
            "host",
            arrow::datatypes::DataType::Utf8,
            false,
        )]));

        let predicates = predicate.to_physical_exprs(&schema).unwrap();
        assert!(!predicates.is_empty());
    }
}