table/
predicate.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use common_telemetry::{error, warn};
18use common_time::range::TimestampRange;
19use common_time::timestamp::TimeUnit;
20use common_time::Timestamp;
21use datafusion::common::ScalarValue;
22use datafusion::physical_optimizer::pruning::{PruningPredicate, PruningStatistics};
23use datafusion_common::ToDFSchema;
24use datafusion_expr::expr::{Expr, InList};
25use datafusion_expr::{Between, BinaryExpr, Operator};
26use datafusion_physical_expr::execution_props::ExecutionProps;
27use datafusion_physical_expr::{create_physical_expr, PhysicalExpr};
28use datatypes::arrow;
29use datatypes::value::scalar_value_to_timestamp;
30use snafu::ResultExt;
31
32use crate::error;
33
34#[cfg(test)]
35mod stats;
36
/// Makes the enclosing function return `None` when `$lit` is a `ScalarValue::Utf8`.
///
/// String literals compared against the time index are expected to have been converted to
/// timestamp scalars by `TypeConversionRule` before reaching here, so encountering one is
/// logged as unexpected and the time range filter is abandoned. Only usable inside functions
/// (or closures) returning `Option<_>`.
macro_rules! return_none_if_utf8 {
    ($lit: ident) => {
        if let ScalarValue::Utf8(_) = $lit {
            warn!(
                "Unexpected ScalarValue::Utf8 in time range predicate: {:?}. Maybe it's an implicit bug, please report it to https://github.com/GreptimeTeam/greptimedb/issues",
                $lit
            );

            // Returning `None` makes this predicate ineffective for time range pruning.
            return None;
        }
    };
}
52
53/// Reference-counted pointer to a list of logical exprs.
54#[derive(Debug, Clone)]
55pub struct Predicate {
56    /// logical exprs
57    exprs: Arc<Vec<Expr>>,
58}
59
60impl Predicate {
61    /// Creates a new `Predicate` by converting logical exprs to physical exprs that can be
62    /// evaluated against record batches.
63    /// Returns error when failed to convert exprs.
64    pub fn new(exprs: Vec<Expr>) -> Self {
65        Self {
66            exprs: Arc::new(exprs),
67        }
68    }
69
70    /// Returns the logical exprs.
71    pub fn exprs(&self) -> &[Expr] {
72        &self.exprs
73    }
74
75    /// Builds physical exprs according to provided schema.
76    pub fn to_physical_exprs(
77        &self,
78        schema: &arrow::datatypes::SchemaRef,
79    ) -> error::Result<Vec<Arc<dyn PhysicalExpr>>> {
80        let df_schema = schema
81            .clone()
82            .to_dfschema_ref()
83            .context(error::DatafusionSnafu)?;
84
85        // TODO(hl): `execution_props` provides variables required by evaluation.
86        // we may reuse the `execution_props` from `SessionState` once we support
87        // registering variables.
88        let execution_props = &ExecutionProps::new();
89
90        Ok(self
91            .exprs
92            .iter()
93            .filter_map(|expr| create_physical_expr(expr, df_schema.as_ref(), execution_props).ok())
94            .collect::<Vec<_>>())
95    }
96
97    /// Evaluates the predicate against the `stats`.
98    /// Returns a vector of boolean values, among which `false` means the row group can be skipped.
99    pub fn prune_with_stats<S: PruningStatistics>(
100        &self,
101        stats: &S,
102        schema: &arrow::datatypes::SchemaRef,
103    ) -> Vec<bool> {
104        let mut res = vec![true; stats.num_containers()];
105        let physical_exprs = match self.to_physical_exprs(schema) {
106            Ok(expr) => expr,
107            Err(e) => {
108                warn!(e; "Failed to build physical expr from predicates: {:?}", &self.exprs);
109                return res;
110            }
111        };
112
113        for expr in &physical_exprs {
114            match PruningPredicate::try_new(expr.clone(), schema.clone()) {
115                Ok(p) => match p.prune(stats) {
116                    Ok(r) => {
117                        for (curr_val, res) in r.into_iter().zip(res.iter_mut()) {
118                            *res &= curr_val
119                        }
120                    }
121                    Err(e) => {
122                        warn!(e; "Failed to prune row groups");
123                    }
124                },
125                Err(e) => {
126                    error!(e; "Failed to create predicate for expr");
127                }
128            }
129        }
130        res
131    }
132}
133
134// tests for `build_time_range_predicate` locates in src/query/tests/time_range_filter_test.rs
135// since it requires query engine to convert sql to filters.
136/// `build_time_range_predicate` extracts time range from logical exprs to facilitate fast
137/// time range pruning.
138pub fn build_time_range_predicate(
139    ts_col_name: &str,
140    ts_col_unit: TimeUnit,
141    filters: &[Expr],
142) -> TimestampRange {
143    let mut res = TimestampRange::min_to_max();
144    for expr in filters {
145        if let Some(range) = extract_time_range_from_expr(ts_col_name, ts_col_unit, expr) {
146            res = res.and(&range);
147        }
148    }
149    res
150}
151
152/// Extract time range filter from `WHERE`/`IN (...)`/`BETWEEN` clauses.
153/// Return None if no time range can be found in expr.
154fn extract_time_range_from_expr(
155    ts_col_name: &str,
156    ts_col_unit: TimeUnit,
157    expr: &Expr,
158) -> Option<TimestampRange> {
159    match expr {
160        Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
161            extract_from_binary_expr(ts_col_name, ts_col_unit, left, op, right)
162        }
163        Expr::Between(Between {
164            expr,
165            negated,
166            low,
167            high,
168        }) => extract_from_between_expr(ts_col_name, ts_col_unit, expr, negated, low, high),
169        Expr::InList(InList {
170            expr,
171            list,
172            negated,
173        }) => extract_from_in_list_expr(ts_col_name, expr, *negated, list),
174        _ => None,
175    }
176}
177
178fn extract_from_binary_expr(
179    ts_col_name: &str,
180    ts_col_unit: TimeUnit,
181    left: &Expr,
182    op: &Operator,
183    right: &Expr,
184) -> Option<TimestampRange> {
185    match op {
186        Operator::Eq => get_timestamp_filter(ts_col_name, left, right)
187            .and_then(|(ts, _)| ts.convert_to(ts_col_unit))
188            .map(TimestampRange::single),
189        Operator::Lt => {
190            let (ts, reverse) = get_timestamp_filter(ts_col_name, left, right)?;
191            if reverse {
192                // [lit] < ts_col
193                let ts_val = ts.convert_to(ts_col_unit)?.value();
194                Some(TimestampRange::from_start(Timestamp::new(
195                    ts_val + 1,
196                    ts_col_unit,
197                )))
198            } else {
199                // ts_col < [lit]
200                ts.convert_to_ceil(ts_col_unit)
201                    .map(|ts| TimestampRange::until_end(ts, false))
202            }
203        }
204        Operator::LtEq => {
205            let (ts, reverse) = get_timestamp_filter(ts_col_name, left, right)?;
206            if reverse {
207                // [lit] <= ts_col
208                ts.convert_to_ceil(ts_col_unit)
209                    .map(TimestampRange::from_start)
210            } else {
211                // ts_col <= [lit]
212                ts.convert_to(ts_col_unit)
213                    .map(|ts| TimestampRange::until_end(ts, true))
214            }
215        }
216        Operator::Gt => {
217            let (ts, reverse) = get_timestamp_filter(ts_col_name, left, right)?;
218            if reverse {
219                // [lit] > ts_col
220                ts.convert_to_ceil(ts_col_unit)
221                    .map(|t| TimestampRange::until_end(t, false))
222            } else {
223                // ts_col > [lit]
224                let ts_val = ts.convert_to(ts_col_unit)?.value();
225                Some(TimestampRange::from_start(Timestamp::new(
226                    ts_val + 1,
227                    ts_col_unit,
228                )))
229            }
230        }
231        Operator::GtEq => {
232            let (ts, reverse) = get_timestamp_filter(ts_col_name, left, right)?;
233            if reverse {
234                // [lit] >= ts_col
235                ts.convert_to(ts_col_unit)
236                    .map(|t| TimestampRange::until_end(t, true))
237            } else {
238                // ts_col >= [lit]
239                ts.convert_to_ceil(ts_col_unit)
240                    .map(TimestampRange::from_start)
241            }
242        }
243        Operator::And => {
244            // instead of return none when failed to extract time range from left/right, we unwrap the none into
245            // `TimestampRange::min_to_max`.
246            let left = extract_time_range_from_expr(ts_col_name, ts_col_unit, left)
247                .unwrap_or_else(TimestampRange::min_to_max);
248            let right = extract_time_range_from_expr(ts_col_name, ts_col_unit, right)
249                .unwrap_or_else(TimestampRange::min_to_max);
250            Some(left.and(&right))
251        }
252        Operator::Or => {
253            let left = extract_time_range_from_expr(ts_col_name, ts_col_unit, left)?;
254            let right = extract_time_range_from_expr(ts_col_name, ts_col_unit, right)?;
255            Some(left.or(&right))
256        }
257        _ => None,
258    }
259}
260
261fn get_timestamp_filter(ts_col_name: &str, left: &Expr, right: &Expr) -> Option<(Timestamp, bool)> {
262    let (col, lit, reverse) = match (left, right) {
263        (Expr::Column(column), Expr::Literal(scalar)) => (column, scalar, false),
264        (Expr::Literal(scalar), Expr::Column(column)) => (column, scalar, true),
265        _ => {
266            return None;
267        }
268    };
269    if col.name != ts_col_name {
270        return None;
271    }
272
273    return_none_if_utf8!(lit);
274    scalar_value_to_timestamp(lit, None).map(|t| (t, reverse))
275}
276
277fn extract_from_between_expr(
278    ts_col_name: &str,
279    ts_col_unit: TimeUnit,
280    expr: &Expr,
281    negated: &bool,
282    low: &Expr,
283    high: &Expr,
284) -> Option<TimestampRange> {
285    let Expr::Column(col) = expr else {
286        return None;
287    };
288    if col.name != ts_col_name {
289        return None;
290    }
291
292    if *negated {
293        return None;
294    }
295
296    match (low, high) {
297        (Expr::Literal(low), Expr::Literal(high)) => {
298            return_none_if_utf8!(low);
299            return_none_if_utf8!(high);
300
301            let low_opt =
302                scalar_value_to_timestamp(low, None).and_then(|ts| ts.convert_to(ts_col_unit));
303            let high_opt = scalar_value_to_timestamp(high, None)
304                .and_then(|ts| ts.convert_to_ceil(ts_col_unit));
305            Some(TimestampRange::new_inclusive(low_opt, high_opt))
306        }
307        _ => None,
308    }
309}
310
311/// Extract time range filter from `IN (...)` expr.
312fn extract_from_in_list_expr(
313    ts_col_name: &str,
314    expr: &Expr,
315    negated: bool,
316    list: &[Expr],
317) -> Option<TimestampRange> {
318    if negated {
319        return None;
320    }
321    let Expr::Column(col) = expr else {
322        return None;
323    };
324    if col.name != ts_col_name {
325        return None;
326    }
327
328    if list.is_empty() {
329        return Some(TimestampRange::empty());
330    }
331    let mut init_range = TimestampRange::empty();
332    for expr in list {
333        if let Expr::Literal(scalar) = expr {
334            return_none_if_utf8!(scalar);
335            if let Some(timestamp) = scalar_value_to_timestamp(scalar, None) {
336                init_range = init_range.or(&TimestampRange::single(timestamp))
337            } else {
338                // TODO(hl): maybe we should raise an error here since cannot parse
339                // timestamp value from in list expr
340                return None;
341            }
342        }
343    }
344    Some(init_range)
345}
346
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_test_util::temp_dir::{create_temp_dir, TempDir};
    use datafusion::parquet::arrow::ArrowWriter;
    use datafusion_common::{Column, ScalarValue};
    use datafusion_expr::{col, lit, BinaryExpr, Literal, Operator};
    use datatypes::arrow::array::Int32Array;
    use datatypes::arrow::datatypes::{DataType, Field, Schema};
    use datatypes::arrow::record_batch::RecordBatch;
    use datatypes::arrow_array::StringArray;
    use parquet::arrow::ParquetRecordBatchStreamBuilder;
    use parquet::file::properties::WriterProperties;

    use super::*;
    use crate::predicate::stats::RowGroupPruningStatistics;

    /// Asserts that `expr` on a millisecond time index column named `ts` yields `expect`.
    fn check_build_predicate(expr: Expr, expect: TimestampRange) {
        assert_eq!(
            expect,
            build_time_range_predicate("ts", TimeUnit::Millisecond, &[expr])
        );
    }

    #[test]
    fn test_eq() {
        // ts = 1ms
        check_build_predicate(
            col("ts").eq(lit(ScalarValue::TimestampMillisecond(Some(1), None))),
            TimestampRange::single(Timestamp::new_millisecond(1)),
        );

        // 1ms = ts
        check_build_predicate(
            lit(ScalarValue::TimestampMillisecond(Some(1), None)).eq(col("ts")),
            TimestampRange::single(Timestamp::new_millisecond(1)),
        );
    }

    #[test]
    fn test_gt() {
        // ts > 1ms
        check_build_predicate(
            col("ts").gt(lit(ScalarValue::TimestampMillisecond(Some(1), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // 1ms > ts
        check_build_predicate(
            lit(ScalarValue::TimestampMillisecond(Some(1), None)).gt(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1), false),
        );

        // 1001us > ts
        check_build_predicate(
            lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).gt(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // ts > 1001us
        check_build_predicate(
            col("ts").gt(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // 1s > ts
        check_build_predicate(
            lit(ScalarValue::TimestampSecond(Some(1), None)).gt(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1000), false),
        );

        // ts > 1s
        check_build_predicate(
            col("ts").gt(lit(ScalarValue::TimestampSecond(Some(1), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(1001)),
        );
    }

    #[test]
    fn test_gt_eq() {
        // ts >= 1ms
        check_build_predicate(
            col("ts").gt_eq(lit(ScalarValue::TimestampMillisecond(Some(1), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(1)),
        );

        // 1ms >= ts
        check_build_predicate(
            lit(ScalarValue::TimestampMillisecond(Some(1), None)).gt_eq(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // 1001us >= ts
        check_build_predicate(
            lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).gt_eq(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // ts >= 1001us
        check_build_predicate(
            col("ts").gt_eq(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // 1s >= ts
        check_build_predicate(
            lit(ScalarValue::TimestampSecond(Some(1), None)).gt_eq(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1000), true),
        );

        // ts >= 1s
        check_build_predicate(
            col("ts").gt_eq(lit(ScalarValue::TimestampSecond(Some(1), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(1000)),
        );
    }

    #[test]
    fn test_lt() {
        // ts < 1ms
        check_build_predicate(
            col("ts").lt(lit(ScalarValue::TimestampMillisecond(Some(1), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1), false),
        );

        // 1ms < ts
        check_build_predicate(
            lit(ScalarValue::TimestampMillisecond(Some(1), None)).lt(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // 1001us < ts
        check_build_predicate(
            lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).lt(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // ts < 1001us
        check_build_predicate(
            col("ts").lt(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // 1s < ts
        check_build_predicate(
            lit(ScalarValue::TimestampSecond(Some(1), None)).lt(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(1001)),
        );

        // ts < 1s
        check_build_predicate(
            col("ts").lt(lit(ScalarValue::TimestampSecond(Some(1), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1000), false),
        );
    }

    #[test]
    fn test_lt_eq() {
        // ts <= 1ms
        check_build_predicate(
            col("ts").lt_eq(lit(ScalarValue::TimestampMillisecond(Some(1), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // 1ms <= ts
        check_build_predicate(
            lit(ScalarValue::TimestampMillisecond(Some(1), None)).lt_eq(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(1)),
        );

        // 1001us <= ts
        check_build_predicate(
            lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).lt_eq(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // ts <= 1001us
        check_build_predicate(
            col("ts").lt_eq(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // 1s <= ts
        check_build_predicate(
            lit(ScalarValue::TimestampSecond(Some(1), None)).lt_eq(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(1000)),
        );

        // ts <= 1s
        check_build_predicate(
            col("ts").lt_eq(lit(ScalarValue::TimestampSecond(Some(1), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1000), true),
        );
    }

    /// Writes `cnt` rows (`name` = i as string, `cnt` = i) into a parquet file with row groups
    /// of at most 10 rows, returning the file path and its arrow schema.
    async fn gen_test_parquet_file(dir: &TempDir, cnt: usize) -> (String, Arc<Schema>) {
        let path = dir
            .path()
            .join("test-prune.parquet")
            .to_string_lossy()
            .to_string();

        let name_field = Field::new("name", DataType::Utf8, true);
        let count_field = Field::new("cnt", DataType::Int32, true);
        let schema = Arc::new(Schema::new(vec![name_field, count_field]));

        let file = std::fs::OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .open(path.clone())
            .unwrap();

        let write_props = WriterProperties::builder()
            .set_max_row_group_size(10)
            .build();
        let mut writer = ArrowWriter::try_new(file, schema.clone(), Some(write_props)).unwrap();

        for i in (0..cnt).step_by(10) {
            let name_array = Arc::new(StringArray::from(
                (i..(i + 10).min(cnt))
                    .map(|i| i.to_string())
                    .collect::<Vec<_>>(),
            )) as Arc<_>;
            let count_array = Arc::new(Int32Array::from(
                (i..(i + 10).min(cnt)).map(|i| i as i32).collect::<Vec<_>>(),
            )) as Arc<_>;
            let rb = RecordBatch::try_new(schema.clone(), vec![name_array, count_array]).unwrap();
            writer.write(&rb).unwrap();
        }
        let _ = writer.close().unwrap();
        (path, schema)
    }

    /// Prunes the row groups of a generated file with `array_cnt` rows using `filters` and
    /// asserts the per-row-group result equals `expect`.
    async fn assert_prune(array_cnt: usize, filters: Vec<Expr>, expect: Vec<bool>) {
        let dir = create_temp_dir("prune_parquet");
        let (path, arrow_schema) = gen_test_parquet_file(&dir, array_cnt).await;
        let schema = Arc::new(datatypes::schema::Schema::try_from(arrow_schema.clone()).unwrap());
        let arrow_predicate = Predicate::new(filters);
        let builder = ParquetRecordBatchStreamBuilder::new(
            tokio::fs::OpenOptions::new()
                .read(true)
                .open(path)
                .await
                .unwrap(),
        )
        .await
        .unwrap();
        let metadata = builder.metadata().clone();
        let row_groups = metadata.row_groups();

        let stats = RowGroupPruningStatistics::new(row_groups, &schema);
        let res = arrow_predicate.prune_with_stats(&stats, &arrow_schema);
        assert_eq!(expect, res);
    }

    /// Builds the single filter `cnt <op> max_val`.
    fn gen_predicate(max_val: i32, op: Operator) -> Vec<Expr> {
        vec![datafusion_expr::Expr::BinaryExpr(BinaryExpr {
            left: Box::new(datafusion_expr::Expr::Column(Column::from_name("cnt"))),
            op,
            right: Box::new(datafusion_expr::Expr::Literal(ScalarValue::Int32(Some(
                max_val,
            )))),
        })]
    }

    #[tokio::test]
    async fn test_prune_empty() {
        assert_prune(3, vec![], vec![true]).await;
    }

    #[tokio::test]
    async fn test_prune_all_match() {
        // `cnt > 3` matches none of the rows {0, 1}, so the only row group is pruned.
        let p = gen_predicate(3, Operator::Gt);
        assert_prune(2, p, vec![false]).await;
    }

    #[tokio::test]
    async fn test_prune_gt() {
        let p = gen_predicate(29, Operator::Gt);
        assert_prune(
            100,
            p,
            vec![
                false, false, false, true, true, true, true, true, true, true,
            ],
        )
        .await;
    }

    #[tokio::test]
    async fn test_prune_eq_expr() {
        let p = gen_predicate(30, Operator::Eq);
        assert_prune(40, p, vec![false, false, false, true]).await;
    }

    #[tokio::test]
    async fn test_prune_neq_expr() {
        let p = gen_predicate(30, Operator::NotEq);
        assert_prune(40, p, vec![true, true, true, true]).await;
    }

    #[tokio::test]
    async fn test_prune_gteq_expr() {
        let p = gen_predicate(29, Operator::GtEq);
        assert_prune(40, p, vec![false, false, true, true]).await;
    }

    #[tokio::test]
    async fn test_prune_lt_expr() {
        let p = gen_predicate(30, Operator::Lt);
        assert_prune(40, p, vec![true, true, true, false]).await;
    }

    #[tokio::test]
    async fn test_prune_lteq_expr() {
        let p = gen_predicate(30, Operator::LtEq);
        assert_prune(40, p, vec![true, true, true, true]).await;
    }

    #[tokio::test]
    async fn test_prune_between_expr() {
        // 15 <= cnt <= 25 only overlaps the row groups [10, 19] and [20, 29].
        let p = vec![col("cnt").between(lit(15), lit(25))];
        assert_prune(40, p, vec![false, true, true, false]).await;
    }

    #[tokio::test]
    async fn test_or() {
        // cnt > 30 or cnt < 20
        let e = datafusion_expr::Expr::Column(Column::from_name("cnt"))
            .gt(30.lit())
            .or(datafusion_expr::Expr::Column(Column::from_name("cnt")).lt(20.lit()));
        assert_prune(40, vec![e], vec![true, true, false, true]).await;
    }

    #[tokio::test]
    async fn test_to_physical_expr() {
        let predicate = Predicate::new(vec![
            col("host").eq(lit("host_a")),
            col("ts").gt(lit(ScalarValue::TimestampMicrosecond(Some(123), None))),
        ]);

        let schema = Arc::new(arrow::datatypes::Schema::new(vec![Field::new(
            "host",
            arrow::datatypes::DataType::Utf8,
            false,
        )]));

        let predicates = predicate.to_physical_exprs(&schema).unwrap();
        assert!(!predicates.is_empty());
    }
}