1use std::sync::Arc;
16
17use arc_swap::ArcSwap;
18use common_telemetry::{debug, warn};
19use common_time::Timestamp;
20use common_time::range::TimestampRange;
21use common_time::timestamp::TimeUnit;
22use datafusion::common::ScalarValue;
23use datafusion::physical_optimizer::pruning::PruningPredicate;
24use datafusion_common::ToDFSchema;
25use datafusion_common::pruning::PruningStatistics;
26use datafusion_expr::expr::{Expr, InList};
27use datafusion_expr::{Between, BinaryExpr, Operator};
28use datafusion_physical_expr::execution_props::ExecutionProps;
29use datafusion_physical_expr::expressions::DynamicFilterPhysicalExpr;
30use datafusion_physical_expr::{PhysicalExpr, create_physical_expr};
31use datatypes::arrow;
32use datatypes::value::scalar_value_to_timestamp;
33use snafu::ResultExt;
34
35use crate::error;
36
37#[cfg(test)]
38mod stats;
39
/// Returns `None` from the enclosing function when `$lit` is a `ScalarValue::Utf8`.
///
/// A string literal is not expected inside a time-range predicate here; when one shows up
/// a warning is logged (it may point at a bug elsewhere) and range extraction is abandoned.
/// Can only be used inside a function whose return type is `Option<_>`.
macro_rules! return_none_if_utf8 {
    ($lit: ident) => {
        if let ScalarValue::Utf8(_) = $lit {
            warn!(
                "Unexpected ScalarValue::Utf8 in time range predicate: {:?}. Maybe it's an implicit bug, please report it to https://github.com/GreptimeTeam/greptimedb/issues",
                $lit
            );
            return None;
        }
    };
}
55
/// Filter predicates used for pruning: static logical expressions plus dynamic physical
/// filters that can be appended after construction.
#[derive(Debug, Clone, Default)]
pub struct Predicate {
    /// Logical filter expressions; held in an `Arc`, so cloning the predicate is cheap.
    exprs: Arc<Vec<Expr>>,
    /// Dynamic physical filters, swappable via `add_dyn_filters`.
    ///
    /// The list is wrapped in `Arc<ArcSwap<..>>`, so every clone of a `Predicate` shares
    /// the same slot: filters added through one clone are visible through all clones.
    dyn_filters: Arc<ArcSwap<Vec<Arc<DynamicFilterPhysicalExpr>>>>,
}
66
67impl Predicate {
68 pub fn new(exprs: Vec<Expr>) -> Self {
72 Self {
73 exprs: Arc::new(exprs),
74 dyn_filters: Arc::new(ArcSwap::new(Arc::new(vec![]))),
75 }
76 }
77
78 pub fn with_dyn_filters(
79 exprs: Vec<Expr>,
80 dyn_filters: Vec<Arc<DynamicFilterPhysicalExpr>>,
81 ) -> Self {
82 Self {
83 exprs: Arc::new(exprs),
84 dyn_filters: Arc::new(ArcSwap::new(Arc::new(dyn_filters))),
85 }
86 }
87
88 pub fn is_empty(&self) -> bool {
89 self.exprs.is_empty() && self.dyn_filters.load().is_empty()
90 }
91
92 pub fn add_dyn_filters(&self, dyn_filters: Vec<Arc<DynamicFilterPhysicalExpr>>) {
94 self.dyn_filters.rcu(|existing| {
95 let mut new_filters = existing.as_ref().clone();
96 new_filters.extend(dyn_filters.clone());
97 Arc::new(new_filters)
98 });
99 }
100
101 pub fn exprs(&self) -> &[Expr] {
103 &self.exprs
104 }
105
106 pub fn dyn_filters(&self) -> Arc<Vec<Arc<DynamicFilterPhysicalExpr>>> {
109 self.dyn_filters.load_full()
110 }
111
112 pub fn dyn_filter_phy_exprs(&self) -> error::Result<Vec<Arc<dyn PhysicalExpr>>> {
115 self.dyn_filters
116 .load()
117 .iter()
118 .map(|e| e.current())
119 .collect::<Result<Vec<_>, _>>()
120 .context(error::DatafusionSnafu)
121 }
122
123 pub fn to_physical_expr(
125 expr: &Expr,
126 schema: &arrow::datatypes::SchemaRef,
127 ) -> error::Result<Arc<dyn PhysicalExpr>> {
128 let df_schema = schema
129 .clone()
130 .to_dfschema_ref()
131 .context(error::DatafusionSnafu)?;
132
133 let execution_props = &ExecutionProps::new();
137
138 create_physical_expr(expr, df_schema.as_ref(), execution_props)
139 .context(error::DatafusionSnafu)
140 }
141
142 pub fn to_physical_exprs(
144 &self,
145 schema: &arrow::datatypes::SchemaRef,
146 ) -> error::Result<Vec<Arc<dyn PhysicalExpr>>> {
147 let dyn_filters = self.dyn_filter_phy_exprs()?;
148
149 Ok(self
150 .exprs
151 .iter()
152 .filter_map(|expr| Self::to_physical_expr(expr, schema).ok())
153 .chain(dyn_filters)
154 .collect::<Vec<_>>())
155 }
156
157 pub fn prune_with_stats<S: PruningStatistics>(
160 &self,
161 stats: &S,
162 schema: &arrow::datatypes::SchemaRef,
163 ) -> Vec<bool> {
164 let mut res = vec![true; stats.num_containers()];
165 let physical_exprs = match self.to_physical_exprs(schema) {
166 Ok(expr) => expr,
167 Err(e) => {
168 warn!(e; "Failed to build physical expr from predicates: {:?}", &self.exprs);
169 return res;
170 }
171 };
172
173 for expr in &physical_exprs {
174 match PruningPredicate::try_new(expr.clone(), schema.clone()) {
175 Ok(p) => match p.prune(stats) {
176 Ok(r) => {
177 for (curr_val, res) in r.into_iter().zip(res.iter_mut()) {
178 *res &= curr_val
179 }
180 }
181 Err(e) => {
182 warn!(e; "Failed to prune row groups");
183 }
184 },
185 Err(e) => {
186 debug!("Failed to create pruning predicate for expr: {e:?}");
188 }
189 }
190 }
191 res
192 }
193}
194
195pub fn build_time_range_predicate(
200 ts_col_name: &str,
201 ts_col_unit: TimeUnit,
202 filters: &[Expr],
203) -> TimestampRange {
204 let mut res = TimestampRange::min_to_max();
205 for expr in filters {
206 if let Some(range) = extract_time_range_from_expr(ts_col_name, ts_col_unit, expr) {
207 res = res.and(&range);
208 }
209 }
210 res
211}
212
213pub fn extract_time_range_from_expr(
216 ts_col_name: &str,
217 ts_col_unit: TimeUnit,
218 expr: &Expr,
219) -> Option<TimestampRange> {
220 match expr {
221 Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
222 extract_from_binary_expr(ts_col_name, ts_col_unit, left, op, right)
223 }
224 Expr::Between(Between {
225 expr,
226 negated,
227 low,
228 high,
229 }) => extract_from_between_expr(ts_col_name, ts_col_unit, expr, negated, low, high),
230 Expr::InList(InList {
231 expr,
232 list,
233 negated,
234 }) => extract_from_in_list_expr(ts_col_name, expr, *negated, list),
235 _ => None,
236 }
237}
238
239fn extract_from_binary_expr(
240 ts_col_name: &str,
241 ts_col_unit: TimeUnit,
242 left: &Expr,
243 op: &Operator,
244 right: &Expr,
245) -> Option<TimestampRange> {
246 match op {
247 Operator::Eq => get_timestamp_filter(ts_col_name, left, right)
248 .and_then(|(ts, _)| ts.convert_to(ts_col_unit))
249 .map(TimestampRange::single),
250 Operator::Lt => {
251 let (ts, reverse) = get_timestamp_filter(ts_col_name, left, right)?;
252 if reverse {
253 let ts_val = ts.convert_to(ts_col_unit)?.value();
255 Some(TimestampRange::from_start(Timestamp::new(
256 ts_val + 1,
257 ts_col_unit,
258 )))
259 } else {
260 ts.convert_to_ceil(ts_col_unit)
262 .map(|ts| TimestampRange::until_end(ts, false))
263 }
264 }
265 Operator::LtEq => {
266 let (ts, reverse) = get_timestamp_filter(ts_col_name, left, right)?;
267 if reverse {
268 ts.convert_to_ceil(ts_col_unit)
270 .map(TimestampRange::from_start)
271 } else {
272 ts.convert_to(ts_col_unit)
274 .map(|ts| TimestampRange::until_end(ts, true))
275 }
276 }
277 Operator::Gt => {
278 let (ts, reverse) = get_timestamp_filter(ts_col_name, left, right)?;
279 if reverse {
280 ts.convert_to_ceil(ts_col_unit)
282 .map(|t| TimestampRange::until_end(t, false))
283 } else {
284 let ts_val = ts.convert_to(ts_col_unit)?.value();
286 Some(TimestampRange::from_start(Timestamp::new(
287 ts_val + 1,
288 ts_col_unit,
289 )))
290 }
291 }
292 Operator::GtEq => {
293 let (ts, reverse) = get_timestamp_filter(ts_col_name, left, right)?;
294 if reverse {
295 ts.convert_to(ts_col_unit)
297 .map(|t| TimestampRange::until_end(t, true))
298 } else {
299 ts.convert_to_ceil(ts_col_unit)
301 .map(TimestampRange::from_start)
302 }
303 }
304 Operator::And => {
305 let left = extract_time_range_from_expr(ts_col_name, ts_col_unit, left)
308 .unwrap_or_else(TimestampRange::min_to_max);
309 let right = extract_time_range_from_expr(ts_col_name, ts_col_unit, right)
310 .unwrap_or_else(TimestampRange::min_to_max);
311 Some(left.and(&right))
312 }
313 Operator::Or => {
314 let left = extract_time_range_from_expr(ts_col_name, ts_col_unit, left)?;
315 let right = extract_time_range_from_expr(ts_col_name, ts_col_unit, right)?;
316 Some(left.or(&right))
317 }
318 _ => None,
319 }
320}
321
322fn get_timestamp_filter(ts_col_name: &str, left: &Expr, right: &Expr) -> Option<(Timestamp, bool)> {
323 let (col, lit, reverse) = match (left, right) {
324 (Expr::Column(column), Expr::Literal(scalar, _)) => (column, scalar, false),
325 (Expr::Literal(scalar, _), Expr::Column(column)) => (column, scalar, true),
326 _ => {
327 return None;
328 }
329 };
330 if col.name != ts_col_name {
331 return None;
332 }
333
334 return_none_if_utf8!(lit);
335 scalar_value_to_timestamp(lit, None).map(|t| (t, reverse))
336}
337
338fn extract_from_between_expr(
339 ts_col_name: &str,
340 ts_col_unit: TimeUnit,
341 expr: &Expr,
342 negated: &bool,
343 low: &Expr,
344 high: &Expr,
345) -> Option<TimestampRange> {
346 let Expr::Column(col) = expr else {
347 return None;
348 };
349 if col.name != ts_col_name {
350 return None;
351 }
352
353 if *negated {
354 return None;
355 }
356
357 match (low, high) {
358 (Expr::Literal(low, _), Expr::Literal(high, _)) => {
359 return_none_if_utf8!(low);
360 return_none_if_utf8!(high);
361
362 let low_opt =
363 scalar_value_to_timestamp(low, None).and_then(|ts| ts.convert_to(ts_col_unit));
364 let high_opt = scalar_value_to_timestamp(high, None)
365 .and_then(|ts| ts.convert_to_ceil(ts_col_unit));
366 Some(TimestampRange::new_inclusive(low_opt, high_opt))
367 }
368 _ => None,
369 }
370}
371
372fn extract_from_in_list_expr(
374 ts_col_name: &str,
375 expr: &Expr,
376 negated: bool,
377 list: &[Expr],
378) -> Option<TimestampRange> {
379 if negated {
380 return None;
381 }
382 let Expr::Column(col) = expr else {
383 return None;
384 };
385 if col.name != ts_col_name {
386 return None;
387 }
388
389 if list.is_empty() {
390 return Some(TimestampRange::empty());
391 }
392 let mut init_range = TimestampRange::empty();
393 for expr in list {
394 if let Expr::Literal(scalar, _) = expr {
395 return_none_if_utf8!(scalar);
396 if let Some(timestamp) = scalar_value_to_timestamp(scalar, None) {
397 init_range = init_range.or(&TimestampRange::single(timestamp))
398 } else {
399 return None;
402 }
403 }
404 }
405 Some(init_range)
406}
407
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_test_util::temp_dir::{TempDir, create_temp_dir};
    use datafusion::parquet::arrow::ArrowWriter;
    use datafusion_common::{Column, ScalarValue};
    use datafusion_expr::{BinaryExpr, Literal, Operator, col, lit};
    use datatypes::arrow::array::Int32Array;
    use datatypes::arrow::datatypes::{DataType, Field, Schema};
    use datatypes::arrow::record_batch::RecordBatch;
    use datatypes::arrow_array::StringArray;
    use parquet::arrow::ParquetRecordBatchStreamBuilder;
    use parquet::file::properties::WriterProperties;

    use super::*;
    use crate::predicate::stats::RowGroupPruningStatistics;

    /// Asserts that `expr` alone yields `expect` for a millisecond time index named `ts`.
    fn check_build_predicate(expr: Expr, expect: TimestampRange) {
        assert_eq!(
            expect,
            build_time_range_predicate("ts", TimeUnit::Millisecond, &[expr])
        );
    }

    #[test]
    fn test_gt() {
        check_build_predicate(
            col("ts").gt(lit(ScalarValue::TimestampMillisecond(Some(1), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        check_build_predicate(
            lit(ScalarValue::TimestampMillisecond(Some(1), None)).gt(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1), false),
        );

        check_build_predicate(
            lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).gt(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        check_build_predicate(
            col("ts").gt(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        check_build_predicate(
            lit(ScalarValue::TimestampSecond(Some(1), None)).gt(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1000), false),
        );

        check_build_predicate(
            col("ts").gt(lit(ScalarValue::TimestampSecond(Some(1), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(1001)),
        );
    }

    #[test]
    fn test_gt_eq() {
        check_build_predicate(
            col("ts").gt_eq(lit(ScalarValue::TimestampMillisecond(Some(1), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(1)),
        );

        check_build_predicate(
            lit(ScalarValue::TimestampMillisecond(Some(1), None)).gt_eq(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        check_build_predicate(
            lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).gt_eq(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        check_build_predicate(
            col("ts").gt_eq(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        check_build_predicate(
            lit(ScalarValue::TimestampSecond(Some(1), None)).gt_eq(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1000), true),
        );

        check_build_predicate(
            col("ts").gt_eq(lit(ScalarValue::TimestampSecond(Some(1), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(1000)),
        );
    }

    #[test]
    fn test_lt() {
        check_build_predicate(
            col("ts").lt(lit(ScalarValue::TimestampMillisecond(Some(1), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1), false),
        );

        check_build_predicate(
            lit(ScalarValue::TimestampMillisecond(Some(1), None)).lt(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        check_build_predicate(
            lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).lt(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        check_build_predicate(
            col("ts").lt(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        check_build_predicate(
            lit(ScalarValue::TimestampSecond(Some(1), None)).lt(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(1001)),
        );

        check_build_predicate(
            col("ts").lt(lit(ScalarValue::TimestampSecond(Some(1), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1000), false),
        );
    }

    #[test]
    fn test_lt_eq() {
        check_build_predicate(
            col("ts").lt_eq(lit(ScalarValue::TimestampMillisecond(Some(1), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        check_build_predicate(
            lit(ScalarValue::TimestampMillisecond(Some(1), None)).lt_eq(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(1)),
        );

        check_build_predicate(
            lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).lt_eq(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        check_build_predicate(
            col("ts").lt_eq(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        check_build_predicate(
            lit(ScalarValue::TimestampSecond(Some(1), None)).lt_eq(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(1000)),
        );

        check_build_predicate(
            col("ts").lt_eq(lit(ScalarValue::TimestampSecond(Some(1), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1000), true),
        );
    }

    #[test]
    fn test_between() {
        // Same-unit bounds pass through unchanged into an inclusive range.
        check_build_predicate(
            col("ts").between(
                lit(ScalarValue::TimestampMillisecond(Some(1), None)),
                lit(ScalarValue::TimestampMillisecond(Some(5), None)),
            ),
            TimestampRange::new_inclusive(
                Some(Timestamp::new_millisecond(1)),
                Some(Timestamp::new_millisecond(5)),
            ),
        );
    }

    /// Writes `cnt` rows (`name` = i as string, `cnt` = i) into a parquet file whose row
    /// groups hold at most 10 rows each.
    async fn gen_test_parquet_file(dir: &TempDir, cnt: usize) -> (String, Arc<Schema>) {
        let path = dir
            .path()
            .join("test-prune.parquet")
            .to_string_lossy()
            .to_string();

        let name_field = Field::new("name", DataType::Utf8, true);
        let count_field = Field::new("cnt", DataType::Int32, true);
        let schema = Arc::new(Schema::new(vec![name_field, count_field]));

        let file = std::fs::OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .open(path.clone())
            .unwrap();

        let write_props = WriterProperties::builder()
            .set_max_row_group_size(10)
            .build();
        let mut writer = ArrowWriter::try_new(file, schema.clone(), Some(write_props)).unwrap();

        for i in (0..cnt).step_by(10) {
            let name_array = Arc::new(StringArray::from(
                (i..(i + 10).min(cnt))
                    .map(|i| i.to_string())
                    .collect::<Vec<_>>(),
            )) as Arc<_>;
            let count_array = Arc::new(Int32Array::from(
                (i..(i + 10).min(cnt)).map(|i| i as i32).collect::<Vec<_>>(),
            )) as Arc<_>;
            let rb = RecordBatch::try_new(schema.clone(), vec![name_array, count_array]).unwrap();
            writer.write(&rb).unwrap();
        }
        let _ = writer.close().unwrap();
        (path, schema)
    }

    /// Prunes the row groups of a generated file with `array_cnt` rows using `filters` and
    /// asserts the per-row-group result equals `expect`.
    async fn assert_prune(array_cnt: usize, filters: Vec<Expr>, expect: Vec<bool>) {
        let dir = create_temp_dir("prune_parquet");
        let (path, arrow_schema) = gen_test_parquet_file(&dir, array_cnt).await;
        let schema = Arc::new(datatypes::schema::Schema::try_from(arrow_schema.clone()).unwrap());
        let arrow_predicate = Predicate::new(filters);
        let builder = ParquetRecordBatchStreamBuilder::new(
            tokio::fs::OpenOptions::new()
                .read(true)
                .open(path)
                .await
                .unwrap(),
        )
        .await
        .unwrap();
        let metadata = builder.metadata().clone();
        let row_groups = metadata.row_groups();

        let stats = RowGroupPruningStatistics::new(row_groups, &schema);
        let res = arrow_predicate.prune_with_stats(&stats, &arrow_schema);
        assert_eq!(expect, res);
    }

    /// Builds the single predicate `cnt <op> max_val`.
    fn gen_predicate(max_val: i32, op: Operator) -> Vec<Expr> {
        vec![datafusion_expr::Expr::BinaryExpr(BinaryExpr {
            left: Box::new(datafusion_expr::Expr::Column(Column::from_name("cnt"))),
            op,
            right: Box::new(max_val.lit()),
        })]
    }

    #[tokio::test]
    async fn test_prune_empty() {
        assert_prune(3, vec![], vec![true]).await;
    }

    #[tokio::test]
    async fn test_prune_all_match() {
        // The only row group holds cnt in [0, 1], so `cnt > 3` prunes it.
        let p = gen_predicate(3, Operator::Gt);
        assert_prune(2, p, vec![false]).await;
    }

    #[tokio::test]
    async fn test_prune_gt() {
        let p = gen_predicate(29, Operator::Gt);
        assert_prune(
            100,
            p,
            vec![
                false, false, false, true, true, true, true, true, true, true,
            ],
        )
        .await;
    }

    #[tokio::test]
    async fn test_prune_eq_expr() {
        let p = gen_predicate(30, Operator::Eq);
        assert_prune(40, p, vec![false, false, false, true]).await;
    }

    #[tokio::test]
    async fn test_prune_neq_expr() {
        let p = gen_predicate(30, Operator::NotEq);
        assert_prune(40, p, vec![true, true, true, true]).await;
    }

    #[tokio::test]
    async fn test_prune_gteq_expr() {
        let p = gen_predicate(29, Operator::GtEq);
        assert_prune(40, p, vec![false, false, true, true]).await;
    }

    #[tokio::test]
    async fn test_prune_lt_expr() {
        let p = gen_predicate(30, Operator::Lt);
        assert_prune(40, p, vec![true, true, true, false]).await;
    }

    #[tokio::test]
    async fn test_prune_lteq_expr() {
        let p = gen_predicate(30, Operator::LtEq);
        assert_prune(40, p, vec![true, true, true, true]).await;
    }

    #[tokio::test]
    async fn test_prune_between_expr() {
        // Row groups hold cnt in [0, 9], [10, 19], [20, 29] and [30, 39]; only the middle
        // two overlap [11, 25].
        let p = vec![col("cnt").between(lit(11), lit(25))];
        assert_prune(40, p, vec![false, true, true, false]).await;
    }

    #[tokio::test]
    async fn test_or() {
        let e = datafusion_expr::Expr::Column(Column::from_name("cnt"))
            .gt(30.lit())
            .or(datafusion_expr::Expr::Column(Column::from_name("cnt")).lt(20.lit()));
        assert_prune(40, vec![e], vec![true, true, false, true]).await;
    }

    #[tokio::test]
    async fn test_to_physical_expr() {
        let predicate = Predicate::new(vec![
            col("host").eq(lit("host_a")),
            col("ts").gt(lit(ScalarValue::TimestampMicrosecond(Some(123), None))),
        ]);

        let schema = Arc::new(arrow::datatypes::Schema::new(vec![Field::new(
            "host",
            arrow::datatypes::DataType::Utf8,
            false,
        )]));

        let predicates = predicate.to_physical_exprs(&schema).unwrap();
        assert!(!predicates.is_empty());

        let physical_expr = Predicate::to_physical_expr(&col("host").eq(lit("host_a")), &schema);
        assert!(physical_expr.is_ok());
    }
}