1use std::sync::Arc;
16
17use arc_swap::ArcSwap;
18use common_telemetry::{debug, warn};
19use common_time::Timestamp;
20use common_time::range::TimestampRange;
21use common_time::timestamp::TimeUnit;
22use datafusion::common::ScalarValue;
23use datafusion::physical_optimizer::pruning::PruningPredicate;
24use datafusion_common::ToDFSchema;
25use datafusion_common::pruning::PruningStatistics;
26use datafusion_expr::expr::{Expr, InList};
27use datafusion_expr::{Between, BinaryExpr, Operator};
28use datafusion_physical_expr::execution_props::ExecutionProps;
29use datafusion_physical_expr::expressions::DynamicFilterPhysicalExpr;
30use datafusion_physical_expr::{PhysicalExpr, create_physical_expr};
31use datatypes::arrow;
32use datatypes::value::scalar_value_to_timestamp;
33use snafu::ResultExt;
34
35use crate::error;
36
37#[cfg(test)]
38mod stats;
39
/// Makes the enclosing function return `None` when `$lit` is a `ScalarValue::Utf8`.
///
/// A string literal showing up where a timestamp literal is expected is logged as a
/// warning before bailing out, so that the occurrence can be reported.
macro_rules! return_none_if_utf8 {
    ($lit:ident) => {
        if let ScalarValue::Utf8(_) = $lit {
            warn!(
                "Unexpected ScalarValue::Utf8 in time range predicate: {:?}. Maybe it's an implicit bug, please report it to https://github.com/GreptimeTeam/greptimedb/issues",
                $lit
            );

            return None;
        }
    };
}
55
/// A collection of filters that can be turned into physical expressions and used to prune
/// containers by their statistics.
///
/// Both fields live behind `Arc`, so cloning a `Predicate` is cheap and the clones share
/// the same underlying data.
#[derive(Debug, Clone, Default)]
pub struct Predicate {
    /// Logical filter expressions.
    exprs: Arc<Vec<Expr>>,
    /// Dynamic physical filters. The list is stored in an `ArcSwap` so it can be replaced
    /// through a shared reference; since the swap cell itself is wrapped in an `Arc`, every
    /// clone of this predicate observes the same list.
    dyn_filters: Arc<ArcSwap<Vec<Arc<DynamicFilterPhysicalExpr>>>>,
}
66
67impl Predicate {
68 pub fn new(exprs: Vec<Expr>) -> Self {
72 Self {
73 exprs: Arc::new(exprs),
74 dyn_filters: Arc::new(ArcSwap::new(Arc::new(vec![]))),
75 }
76 }
77
78 pub fn with_dyn_filters(
79 exprs: Vec<Expr>,
80 dyn_filters: Vec<Arc<DynamicFilterPhysicalExpr>>,
81 ) -> Self {
82 Self {
83 exprs: Arc::new(exprs),
84 dyn_filters: Arc::new(ArcSwap::new(Arc::new(dyn_filters))),
85 }
86 }
87
88 pub fn is_empty(&self) -> bool {
89 self.exprs.is_empty() && self.dyn_filters.load().is_empty()
90 }
91
92 pub fn add_dyn_filters(&self, dyn_filters: Vec<Arc<DynamicFilterPhysicalExpr>>) {
94 self.dyn_filters.rcu(|existing| {
95 let mut new_filters = existing.as_ref().clone();
96 new_filters.extend(dyn_filters.clone());
97 Arc::new(new_filters)
98 });
99 }
100
101 pub fn exprs(&self) -> &[Expr] {
103 &self.exprs
104 }
105
106 pub fn dyn_filters(&self) -> Arc<Vec<Arc<DynamicFilterPhysicalExpr>>> {
109 self.dyn_filters.load_full()
110 }
111
112 pub fn dyn_filter_phy_exprs(&self) -> error::Result<Vec<Arc<dyn PhysicalExpr>>> {
115 self.dyn_filters
116 .load()
117 .iter()
118 .map(|e| e.current())
119 .collect::<Result<Vec<_>, _>>()
120 .context(error::DatafusionSnafu)
121 }
122
123 pub fn to_physical_exprs(
125 &self,
126 schema: &arrow::datatypes::SchemaRef,
127 ) -> error::Result<Vec<Arc<dyn PhysicalExpr>>> {
128 let df_schema = schema
129 .clone()
130 .to_dfschema_ref()
131 .context(error::DatafusionSnafu)?;
132
133 let execution_props = &ExecutionProps::new();
137
138 let dyn_filters = self.dyn_filter_phy_exprs()?;
139
140 Ok(self
141 .exprs
142 .iter()
143 .filter_map(|expr| create_physical_expr(expr, df_schema.as_ref(), execution_props).ok())
144 .chain(dyn_filters)
145 .collect::<Vec<_>>())
146 }
147
148 pub fn prune_with_stats<S: PruningStatistics>(
151 &self,
152 stats: &S,
153 schema: &arrow::datatypes::SchemaRef,
154 ) -> Vec<bool> {
155 let mut res = vec![true; stats.num_containers()];
156 let physical_exprs = match self.to_physical_exprs(schema) {
157 Ok(expr) => expr,
158 Err(e) => {
159 warn!(e; "Failed to build physical expr from predicates: {:?}", &self.exprs);
160 return res;
161 }
162 };
163
164 for expr in &physical_exprs {
165 match PruningPredicate::try_new(expr.clone(), schema.clone()) {
166 Ok(p) => match p.prune(stats) {
167 Ok(r) => {
168 for (curr_val, res) in r.into_iter().zip(res.iter_mut()) {
169 *res &= curr_val
170 }
171 }
172 Err(e) => {
173 warn!(e; "Failed to prune row groups");
174 }
175 },
176 Err(e) => {
177 debug!("Failed to create pruning predicate for expr: {e:?}");
179 }
180 }
181 }
182 res
183 }
184}
185
186pub fn build_time_range_predicate(
191 ts_col_name: &str,
192 ts_col_unit: TimeUnit,
193 filters: &[Expr],
194) -> TimestampRange {
195 let mut res = TimestampRange::min_to_max();
196 for expr in filters {
197 if let Some(range) = extract_time_range_from_expr(ts_col_name, ts_col_unit, expr) {
198 res = res.and(&range);
199 }
200 }
201 res
202}
203
204fn extract_time_range_from_expr(
207 ts_col_name: &str,
208 ts_col_unit: TimeUnit,
209 expr: &Expr,
210) -> Option<TimestampRange> {
211 match expr {
212 Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
213 extract_from_binary_expr(ts_col_name, ts_col_unit, left, op, right)
214 }
215 Expr::Between(Between {
216 expr,
217 negated,
218 low,
219 high,
220 }) => extract_from_between_expr(ts_col_name, ts_col_unit, expr, negated, low, high),
221 Expr::InList(InList {
222 expr,
223 list,
224 negated,
225 }) => extract_from_in_list_expr(ts_col_name, expr, *negated, list),
226 _ => None,
227 }
228}
229
230fn extract_from_binary_expr(
231 ts_col_name: &str,
232 ts_col_unit: TimeUnit,
233 left: &Expr,
234 op: &Operator,
235 right: &Expr,
236) -> Option<TimestampRange> {
237 match op {
238 Operator::Eq => get_timestamp_filter(ts_col_name, left, right)
239 .and_then(|(ts, _)| ts.convert_to(ts_col_unit))
240 .map(TimestampRange::single),
241 Operator::Lt => {
242 let (ts, reverse) = get_timestamp_filter(ts_col_name, left, right)?;
243 if reverse {
244 let ts_val = ts.convert_to(ts_col_unit)?.value();
246 Some(TimestampRange::from_start(Timestamp::new(
247 ts_val + 1,
248 ts_col_unit,
249 )))
250 } else {
251 ts.convert_to_ceil(ts_col_unit)
253 .map(|ts| TimestampRange::until_end(ts, false))
254 }
255 }
256 Operator::LtEq => {
257 let (ts, reverse) = get_timestamp_filter(ts_col_name, left, right)?;
258 if reverse {
259 ts.convert_to_ceil(ts_col_unit)
261 .map(TimestampRange::from_start)
262 } else {
263 ts.convert_to(ts_col_unit)
265 .map(|ts| TimestampRange::until_end(ts, true))
266 }
267 }
268 Operator::Gt => {
269 let (ts, reverse) = get_timestamp_filter(ts_col_name, left, right)?;
270 if reverse {
271 ts.convert_to_ceil(ts_col_unit)
273 .map(|t| TimestampRange::until_end(t, false))
274 } else {
275 let ts_val = ts.convert_to(ts_col_unit)?.value();
277 Some(TimestampRange::from_start(Timestamp::new(
278 ts_val + 1,
279 ts_col_unit,
280 )))
281 }
282 }
283 Operator::GtEq => {
284 let (ts, reverse) = get_timestamp_filter(ts_col_name, left, right)?;
285 if reverse {
286 ts.convert_to(ts_col_unit)
288 .map(|t| TimestampRange::until_end(t, true))
289 } else {
290 ts.convert_to_ceil(ts_col_unit)
292 .map(TimestampRange::from_start)
293 }
294 }
295 Operator::And => {
296 let left = extract_time_range_from_expr(ts_col_name, ts_col_unit, left)
299 .unwrap_or_else(TimestampRange::min_to_max);
300 let right = extract_time_range_from_expr(ts_col_name, ts_col_unit, right)
301 .unwrap_or_else(TimestampRange::min_to_max);
302 Some(left.and(&right))
303 }
304 Operator::Or => {
305 let left = extract_time_range_from_expr(ts_col_name, ts_col_unit, left)?;
306 let right = extract_time_range_from_expr(ts_col_name, ts_col_unit, right)?;
307 Some(left.or(&right))
308 }
309 _ => None,
310 }
311}
312
313fn get_timestamp_filter(ts_col_name: &str, left: &Expr, right: &Expr) -> Option<(Timestamp, bool)> {
314 let (col, lit, reverse) = match (left, right) {
315 (Expr::Column(column), Expr::Literal(scalar, _)) => (column, scalar, false),
316 (Expr::Literal(scalar, _), Expr::Column(column)) => (column, scalar, true),
317 _ => {
318 return None;
319 }
320 };
321 if col.name != ts_col_name {
322 return None;
323 }
324
325 return_none_if_utf8!(lit);
326 scalar_value_to_timestamp(lit, None).map(|t| (t, reverse))
327}
328
329fn extract_from_between_expr(
330 ts_col_name: &str,
331 ts_col_unit: TimeUnit,
332 expr: &Expr,
333 negated: &bool,
334 low: &Expr,
335 high: &Expr,
336) -> Option<TimestampRange> {
337 let Expr::Column(col) = expr else {
338 return None;
339 };
340 if col.name != ts_col_name {
341 return None;
342 }
343
344 if *negated {
345 return None;
346 }
347
348 match (low, high) {
349 (Expr::Literal(low, _), Expr::Literal(high, _)) => {
350 return_none_if_utf8!(low);
351 return_none_if_utf8!(high);
352
353 let low_opt =
354 scalar_value_to_timestamp(low, None).and_then(|ts| ts.convert_to(ts_col_unit));
355 let high_opt = scalar_value_to_timestamp(high, None)
356 .and_then(|ts| ts.convert_to_ceil(ts_col_unit));
357 Some(TimestampRange::new_inclusive(low_opt, high_opt))
358 }
359 _ => None,
360 }
361}
362
363fn extract_from_in_list_expr(
365 ts_col_name: &str,
366 expr: &Expr,
367 negated: bool,
368 list: &[Expr],
369) -> Option<TimestampRange> {
370 if negated {
371 return None;
372 }
373 let Expr::Column(col) = expr else {
374 return None;
375 };
376 if col.name != ts_col_name {
377 return None;
378 }
379
380 if list.is_empty() {
381 return Some(TimestampRange::empty());
382 }
383 let mut init_range = TimestampRange::empty();
384 for expr in list {
385 if let Expr::Literal(scalar, _) = expr {
386 return_none_if_utf8!(scalar);
387 if let Some(timestamp) = scalar_value_to_timestamp(scalar, None) {
388 init_range = init_range.or(&TimestampRange::single(timestamp))
389 } else {
390 return None;
393 }
394 }
395 }
396 Some(init_range)
397}
398
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_test_util::temp_dir::{TempDir, create_temp_dir};
    use datafusion::parquet::arrow::ArrowWriter;
    use datafusion_common::{Column, ScalarValue};
    use datafusion_expr::{BinaryExpr, Literal, Operator, col, lit};
    use datatypes::arrow::array::Int32Array;
    use datatypes::arrow::datatypes::{DataType, Field, Schema};
    use datatypes::arrow::record_batch::RecordBatch;
    use datatypes::arrow_array::StringArray;
    use parquet::arrow::ParquetRecordBatchStreamBuilder;
    use parquet::file::properties::WriterProperties;

    use super::*;
    use crate::predicate::stats::RowGroupPruningStatistics;

    /// Asserts that `expr` yields `expect` for a millisecond time index column named `ts`.
    fn check_build_predicate(expr: Expr, expect: TimestampRange) {
        assert_eq!(
            expect,
            build_time_range_predicate("ts", TimeUnit::Millisecond, &[expr])
        );
    }

    #[test]
    fn test_gt() {
        // ts > 1ms
        check_build_predicate(
            col("ts").gt(lit(ScalarValue::TimestampMillisecond(Some(1), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // 1ms > ts
        check_build_predicate(
            lit(ScalarValue::TimestampMillisecond(Some(1), None)).gt(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1), false),
        );

        // 1001us > ts
        check_build_predicate(
            lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).gt(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // ts > 1001us
        check_build_predicate(
            col("ts").gt(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // 1s > ts
        check_build_predicate(
            lit(ScalarValue::TimestampSecond(Some(1), None)).gt(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1000), false),
        );

        // ts > 1s
        check_build_predicate(
            col("ts").gt(lit(ScalarValue::TimestampSecond(Some(1), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(1001)),
        );
    }

    #[test]
    fn test_gt_eq() {
        // ts >= 1ms
        check_build_predicate(
            col("ts").gt_eq(lit(ScalarValue::TimestampMillisecond(Some(1), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(1)),
        );

        // 1ms >= ts
        check_build_predicate(
            lit(ScalarValue::TimestampMillisecond(Some(1), None)).gt_eq(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // 1001us >= ts
        check_build_predicate(
            lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).gt_eq(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // ts >= 1001us
        check_build_predicate(
            col("ts").gt_eq(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // 1s >= ts
        check_build_predicate(
            lit(ScalarValue::TimestampSecond(Some(1), None)).gt_eq(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1000), true),
        );

        // ts >= 1s
        check_build_predicate(
            col("ts").gt_eq(lit(ScalarValue::TimestampSecond(Some(1), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(1000)),
        );
    }

    #[test]
    fn test_lt() {
        // ts < 1ms
        check_build_predicate(
            col("ts").lt(lit(ScalarValue::TimestampMillisecond(Some(1), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1), false),
        );

        // 1ms < ts
        check_build_predicate(
            lit(ScalarValue::TimestampMillisecond(Some(1), None)).lt(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // 1001us < ts
        check_build_predicate(
            lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).lt(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // ts < 1001us
        check_build_predicate(
            col("ts").lt(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // 1s < ts
        check_build_predicate(
            lit(ScalarValue::TimestampSecond(Some(1), None)).lt(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(1001)),
        );

        // ts < 1s
        check_build_predicate(
            col("ts").lt(lit(ScalarValue::TimestampSecond(Some(1), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1000), false),
        );
    }

    #[test]
    fn test_lt_eq() {
        // ts <= 1ms
        check_build_predicate(
            col("ts").lt_eq(lit(ScalarValue::TimestampMillisecond(Some(1), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // 1ms <= ts
        check_build_predicate(
            lit(ScalarValue::TimestampMillisecond(Some(1), None)).lt_eq(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(1)),
        );

        // 1001us <= ts
        check_build_predicate(
            lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).lt_eq(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // ts <= 1001us
        check_build_predicate(
            col("ts").lt_eq(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // 1s <= ts
        check_build_predicate(
            lit(ScalarValue::TimestampSecond(Some(1), None)).lt_eq(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(1000)),
        );

        // ts <= 1s
        check_build_predicate(
            col("ts").lt_eq(lit(ScalarValue::TimestampSecond(Some(1), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1000), true),
        );
    }

    /// Writes a parquet file with `cnt` rows into `dir` and returns its path and schema.
    ///
    /// The file has a string column `name` and an int column `cnt`, both holding the row
    /// index. Rows are written in batches of 10 with a maximum row group size of 10.
    async fn gen_test_parquet_file(dir: &TempDir, cnt: usize) -> (String, Arc<Schema>) {
        let path = dir
            .path()
            .join("test-prune.parquet")
            .to_string_lossy()
            .to_string();

        let name_field = Field::new("name", DataType::Utf8, true);
        let count_field = Field::new("cnt", DataType::Int32, true);
        let schema = Arc::new(Schema::new(vec![name_field, count_field]));

        let file = std::fs::OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .open(path.clone())
            .unwrap();

        let write_props = WriterProperties::builder()
            .set_max_row_group_size(10)
            .build();
        let mut writer = ArrowWriter::try_new(file, schema.clone(), Some(write_props)).unwrap();

        for i in (0..cnt).step_by(10) {
            let name_array = Arc::new(StringArray::from(
                (i..(i + 10).min(cnt))
                    .map(|i| i.to_string())
                    .collect::<Vec<_>>(),
            )) as Arc<_>;
            let count_array = Arc::new(Int32Array::from(
                (i..(i + 10).min(cnt)).map(|i| i as i32).collect::<Vec<_>>(),
            )) as Arc<_>;
            let rb = RecordBatch::try_new(schema.clone(), vec![name_array, count_array]).unwrap();
            writer.write(&rb).unwrap();
        }
        let _ = writer.close().unwrap();
        (path, schema)
    }

    /// Generates a parquet file with `array_cnt` rows, prunes its row groups with
    /// `filters` and asserts that the per-row-group result equals `expect`.
    async fn assert_prune(array_cnt: usize, filters: Vec<Expr>, expect: Vec<bool>) {
        let dir = create_temp_dir("prune_parquet");
        let (path, arrow_schema) = gen_test_parquet_file(&dir, array_cnt).await;
        let schema = Arc::new(datatypes::schema::Schema::try_from(arrow_schema.clone()).unwrap());
        let arrow_predicate = Predicate::new(filters);
        let builder = ParquetRecordBatchStreamBuilder::new(
            tokio::fs::OpenOptions::new()
                .read(true)
                .open(path)
                .await
                .unwrap(),
        )
        .await
        .unwrap();
        let metadata = builder.metadata().clone();
        let row_groups = metadata.row_groups();

        let stats = RowGroupPruningStatistics::new(row_groups, &schema);
        let res = arrow_predicate.prune_with_stats(&stats, &arrow_schema);
        assert_eq!(expect, res);
    }

    /// Builds the single filter `cnt <op> max_val`.
    fn gen_predicate(max_val: i32, op: Operator) -> Vec<Expr> {
        vec![datafusion_expr::Expr::BinaryExpr(BinaryExpr {
            left: Box::new(datafusion_expr::Expr::Column(Column::from_name("cnt"))),
            op,
            right: Box::new(max_val.lit()),
        })]
    }

    #[tokio::test]
    async fn test_prune_empty() {
        // Without filters nothing is pruned.
        assert_prune(3, vec![], vec![true]).await;
    }

    #[tokio::test]
    async fn test_prune_all_match() {
        // Despite the name, `cnt > 3` matches none of the two rows, so the only row group
        // is pruned.
        let p = gen_predicate(3, Operator::Gt);
        assert_prune(2, p, vec![false]).await;
    }

    #[tokio::test]
    async fn test_prune_gt() {
        let p = gen_predicate(29, Operator::Gt);
        assert_prune(
            100,
            p,
            vec![
                false, false, false, true, true, true, true, true, true, true,
            ],
        )
        .await;
    }

    #[tokio::test]
    async fn test_prune_eq_expr() {
        let p = gen_predicate(30, Operator::Eq);
        assert_prune(40, p, vec![false, false, false, true]).await;
    }

    #[tokio::test]
    async fn test_prune_neq_expr() {
        let p = gen_predicate(30, Operator::NotEq);
        assert_prune(40, p, vec![true, true, true, true]).await;
    }

    #[tokio::test]
    async fn test_prune_gteq_expr() {
        let p = gen_predicate(29, Operator::GtEq);
        assert_prune(40, p, vec![false, false, true, true]).await;
    }

    #[tokio::test]
    async fn test_prune_lt_expr() {
        let p = gen_predicate(30, Operator::Lt);
        assert_prune(40, p, vec![true, true, true, false]).await;
    }

    #[tokio::test]
    async fn test_prune_lteq_expr() {
        let p = gen_predicate(30, Operator::LtEq);
        assert_prune(40, p, vec![true, true, true, true]).await;
    }

    #[tokio::test]
    async fn test_prune_between_expr() {
        // `cnt BETWEEN 10 AND 29` only overlaps the second and third row groups
        // (rows 10..=19 and 20..=29).
        let e = datafusion_expr::Expr::Column(Column::from_name("cnt"))
            .between(10.lit(), 29.lit());
        assert_prune(40, vec![e], vec![false, true, true, false]).await;
    }

    #[tokio::test]
    async fn test_or() {
        // cnt > 30 or cnt < 20
        let e = datafusion_expr::Expr::Column(Column::from_name("cnt"))
            .gt(30.lit())
            .or(datafusion_expr::Expr::Column(Column::from_name("cnt")).lt(20.lit()));
        assert_prune(40, vec![e], vec![true, true, false, true]).await;
    }

    #[tokio::test]
    async fn test_to_physical_expr() {
        let predicate = Predicate::new(vec![
            col("host").eq(lit("host_a")),
            col("ts").gt(lit(ScalarValue::TimestampMicrosecond(Some(123), None))),
        ]);

        // The schema lacks the `ts` column, so only the `host` filter can be converted;
        // the conversion as a whole must still succeed.
        let schema = Arc::new(arrow::datatypes::Schema::new(vec![Field::new(
            "host",
            arrow::datatypes::DataType::Utf8,
            false,
        )]));

        let predicates = predicate.to_physical_exprs(&schema).unwrap();
        assert!(!predicates.is_empty());
    }
}