1use std::sync::Arc;
16
17use common_telemetry::{error, warn};
18use common_time::range::TimestampRange;
19use common_time::timestamp::TimeUnit;
20use common_time::Timestamp;
21use datafusion::common::ScalarValue;
22use datafusion::physical_optimizer::pruning::{PruningPredicate, PruningStatistics};
23use datafusion_common::ToDFSchema;
24use datafusion_expr::expr::{Expr, InList};
25use datafusion_expr::{Between, BinaryExpr, Operator};
26use datafusion_physical_expr::execution_props::ExecutionProps;
27use datafusion_physical_expr::{create_physical_expr, PhysicalExpr};
28use datatypes::arrow;
29use datatypes::value::scalar_value_to_timestamp;
30use snafu::ResultExt;
31
32use crate::error;
33
34#[cfg(test)]
35mod stats;
36
/// Makes the enclosing function return `None` when the given literal is a
/// `ScalarValue::Utf8`, after logging a warning that includes the literal.
///
/// The argument must be an identifier bound to a `ScalarValue` (or a reference
/// to one), and the enclosing function must return an `Option`.
macro_rules! return_none_if_utf8 {
    ($lit: ident) => {
        if let ScalarValue::Utf8(_) = $lit {
            warn!(
                "Unexpected ScalarValue::Utf8 in time range predicate: {:?}. Maybe it's an implicit bug, please report it to https://github.com/GreptimeTeam/greptimedb/issues",
                $lit
            );

            return None;
        }
    };
}
52
/// A collection of logical filter expressions.
///
/// The expressions are stored behind an [`Arc`], so cloning a `Predicate`
/// shares the expression list instead of copying it.
#[derive(Debug, Clone)]
pub struct Predicate {
    /// The filter expressions. When pruning, the per-expression results are
    /// combined with a logical AND.
    exprs: Arc<Vec<Expr>>,
}
59
60impl Predicate {
61 pub fn new(exprs: Vec<Expr>) -> Self {
65 Self {
66 exprs: Arc::new(exprs),
67 }
68 }
69
70 pub fn exprs(&self) -> &[Expr] {
72 &self.exprs
73 }
74
75 pub fn to_physical_exprs(
77 &self,
78 schema: &arrow::datatypes::SchemaRef,
79 ) -> error::Result<Vec<Arc<dyn PhysicalExpr>>> {
80 let df_schema = schema
81 .clone()
82 .to_dfschema_ref()
83 .context(error::DatafusionSnafu)?;
84
85 let execution_props = &ExecutionProps::new();
89
90 Ok(self
91 .exprs
92 .iter()
93 .filter_map(|expr| create_physical_expr(expr, df_schema.as_ref(), execution_props).ok())
94 .collect::<Vec<_>>())
95 }
96
97 pub fn prune_with_stats<S: PruningStatistics>(
100 &self,
101 stats: &S,
102 schema: &arrow::datatypes::SchemaRef,
103 ) -> Vec<bool> {
104 let mut res = vec![true; stats.num_containers()];
105 let physical_exprs = match self.to_physical_exprs(schema) {
106 Ok(expr) => expr,
107 Err(e) => {
108 warn!(e; "Failed to build physical expr from predicates: {:?}", &self.exprs);
109 return res;
110 }
111 };
112
113 for expr in &physical_exprs {
114 match PruningPredicate::try_new(expr.clone(), schema.clone()) {
115 Ok(p) => match p.prune(stats) {
116 Ok(r) => {
117 for (curr_val, res) in r.into_iter().zip(res.iter_mut()) {
118 *res &= curr_val
119 }
120 }
121 Err(e) => {
122 warn!(e; "Failed to prune row groups");
123 }
124 },
125 Err(e) => {
126 error!(e; "Failed to create predicate for expr");
127 }
128 }
129 }
130 res
131 }
132}
133
134pub fn build_time_range_predicate(
139 ts_col_name: &str,
140 ts_col_unit: TimeUnit,
141 filters: &[Expr],
142) -> TimestampRange {
143 let mut res = TimestampRange::min_to_max();
144 for expr in filters {
145 if let Some(range) = extract_time_range_from_expr(ts_col_name, ts_col_unit, expr) {
146 res = res.and(&range);
147 }
148 }
149 res
150}
151
152fn extract_time_range_from_expr(
155 ts_col_name: &str,
156 ts_col_unit: TimeUnit,
157 expr: &Expr,
158) -> Option<TimestampRange> {
159 match expr {
160 Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
161 extract_from_binary_expr(ts_col_name, ts_col_unit, left, op, right)
162 }
163 Expr::Between(Between {
164 expr,
165 negated,
166 low,
167 high,
168 }) => extract_from_between_expr(ts_col_name, ts_col_unit, expr, negated, low, high),
169 Expr::InList(InList {
170 expr,
171 list,
172 negated,
173 }) => extract_from_in_list_expr(ts_col_name, expr, *negated, list),
174 _ => None,
175 }
176}
177
178fn extract_from_binary_expr(
179 ts_col_name: &str,
180 ts_col_unit: TimeUnit,
181 left: &Expr,
182 op: &Operator,
183 right: &Expr,
184) -> Option<TimestampRange> {
185 match op {
186 Operator::Eq => get_timestamp_filter(ts_col_name, left, right)
187 .and_then(|(ts, _)| ts.convert_to(ts_col_unit))
188 .map(TimestampRange::single),
189 Operator::Lt => {
190 let (ts, reverse) = get_timestamp_filter(ts_col_name, left, right)?;
191 if reverse {
192 let ts_val = ts.convert_to(ts_col_unit)?.value();
194 Some(TimestampRange::from_start(Timestamp::new(
195 ts_val + 1,
196 ts_col_unit,
197 )))
198 } else {
199 ts.convert_to_ceil(ts_col_unit)
201 .map(|ts| TimestampRange::until_end(ts, false))
202 }
203 }
204 Operator::LtEq => {
205 let (ts, reverse) = get_timestamp_filter(ts_col_name, left, right)?;
206 if reverse {
207 ts.convert_to_ceil(ts_col_unit)
209 .map(TimestampRange::from_start)
210 } else {
211 ts.convert_to(ts_col_unit)
213 .map(|ts| TimestampRange::until_end(ts, true))
214 }
215 }
216 Operator::Gt => {
217 let (ts, reverse) = get_timestamp_filter(ts_col_name, left, right)?;
218 if reverse {
219 ts.convert_to_ceil(ts_col_unit)
221 .map(|t| TimestampRange::until_end(t, false))
222 } else {
223 let ts_val = ts.convert_to(ts_col_unit)?.value();
225 Some(TimestampRange::from_start(Timestamp::new(
226 ts_val + 1,
227 ts_col_unit,
228 )))
229 }
230 }
231 Operator::GtEq => {
232 let (ts, reverse) = get_timestamp_filter(ts_col_name, left, right)?;
233 if reverse {
234 ts.convert_to(ts_col_unit)
236 .map(|t| TimestampRange::until_end(t, true))
237 } else {
238 ts.convert_to_ceil(ts_col_unit)
240 .map(TimestampRange::from_start)
241 }
242 }
243 Operator::And => {
244 let left = extract_time_range_from_expr(ts_col_name, ts_col_unit, left)
247 .unwrap_or_else(TimestampRange::min_to_max);
248 let right = extract_time_range_from_expr(ts_col_name, ts_col_unit, right)
249 .unwrap_or_else(TimestampRange::min_to_max);
250 Some(left.and(&right))
251 }
252 Operator::Or => {
253 let left = extract_time_range_from_expr(ts_col_name, ts_col_unit, left)?;
254 let right = extract_time_range_from_expr(ts_col_name, ts_col_unit, right)?;
255 Some(left.or(&right))
256 }
257 _ => None,
258 }
259}
260
261fn get_timestamp_filter(ts_col_name: &str, left: &Expr, right: &Expr) -> Option<(Timestamp, bool)> {
262 let (col, lit, reverse) = match (left, right) {
263 (Expr::Column(column), Expr::Literal(scalar)) => (column, scalar, false),
264 (Expr::Literal(scalar), Expr::Column(column)) => (column, scalar, true),
265 _ => {
266 return None;
267 }
268 };
269 if col.name != ts_col_name {
270 return None;
271 }
272
273 return_none_if_utf8!(lit);
274 scalar_value_to_timestamp(lit, None).map(|t| (t, reverse))
275}
276
277fn extract_from_between_expr(
278 ts_col_name: &str,
279 ts_col_unit: TimeUnit,
280 expr: &Expr,
281 negated: &bool,
282 low: &Expr,
283 high: &Expr,
284) -> Option<TimestampRange> {
285 let Expr::Column(col) = expr else {
286 return None;
287 };
288 if col.name != ts_col_name {
289 return None;
290 }
291
292 if *negated {
293 return None;
294 }
295
296 match (low, high) {
297 (Expr::Literal(low), Expr::Literal(high)) => {
298 return_none_if_utf8!(low);
299 return_none_if_utf8!(high);
300
301 let low_opt =
302 scalar_value_to_timestamp(low, None).and_then(|ts| ts.convert_to(ts_col_unit));
303 let high_opt = scalar_value_to_timestamp(high, None)
304 .and_then(|ts| ts.convert_to_ceil(ts_col_unit));
305 Some(TimestampRange::new_inclusive(low_opt, high_opt))
306 }
307 _ => None,
308 }
309}
310
311fn extract_from_in_list_expr(
313 ts_col_name: &str,
314 expr: &Expr,
315 negated: bool,
316 list: &[Expr],
317) -> Option<TimestampRange> {
318 if negated {
319 return None;
320 }
321 let Expr::Column(col) = expr else {
322 return None;
323 };
324 if col.name != ts_col_name {
325 return None;
326 }
327
328 if list.is_empty() {
329 return Some(TimestampRange::empty());
330 }
331 let mut init_range = TimestampRange::empty();
332 for expr in list {
333 if let Expr::Literal(scalar) = expr {
334 return_none_if_utf8!(scalar);
335 if let Some(timestamp) = scalar_value_to_timestamp(scalar, None) {
336 init_range = init_range.or(&TimestampRange::single(timestamp))
337 } else {
338 return None;
341 }
342 }
343 }
344 Some(init_range)
345}
346
/// Tests for time range extraction and statistics-based pruning.
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_test_util::temp_dir::{create_temp_dir, TempDir};
    use datafusion::parquet::arrow::ArrowWriter;
    use datafusion_common::{Column, ScalarValue};
    use datafusion_expr::{col, lit, BinaryExpr, Literal, Operator};
    use datatypes::arrow::array::Int32Array;
    use datatypes::arrow::datatypes::{DataType, Field, Schema};
    use datatypes::arrow::record_batch::RecordBatch;
    use datatypes::arrow_array::StringArray;
    use parquet::arrow::ParquetRecordBatchStreamBuilder;
    use parquet::file::properties::WriterProperties;

    use super::*;
    use crate::predicate::stats::RowGroupPruningStatistics;

    /// Asserts that the single filter `expr` on a millisecond column named
    /// `ts` produces the time range `expect`.
    fn check_build_predicate(expr: Expr, expect: TimestampRange) {
        assert_eq!(
            expect,
            build_time_range_predicate("ts", TimeUnit::Millisecond, &[expr])
        );
    }

    #[test]
    fn test_gt() {
        // ts > 1ms
        check_build_predicate(
            col("ts").gt(lit(ScalarValue::TimestampMillisecond(Some(1), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // 1ms > ts
        check_build_predicate(
            lit(ScalarValue::TimestampMillisecond(Some(1), None)).gt(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1), false),
        );

        // 1001us > ts
        check_build_predicate(
            lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).gt(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // ts > 1001us
        check_build_predicate(
            col("ts").gt(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // 1s > ts
        check_build_predicate(
            lit(ScalarValue::TimestampSecond(Some(1), None)).gt(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1000), false),
        );

        // ts > 1s
        check_build_predicate(
            col("ts").gt(lit(ScalarValue::TimestampSecond(Some(1), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(1001)),
        );
    }

    #[test]
    fn test_gt_eq() {
        // ts >= 1ms
        check_build_predicate(
            col("ts").gt_eq(lit(ScalarValue::TimestampMillisecond(Some(1), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(1)),
        );

        // 1ms >= ts
        check_build_predicate(
            lit(ScalarValue::TimestampMillisecond(Some(1), None)).gt_eq(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // 1001us >= ts
        check_build_predicate(
            lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).gt_eq(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // ts >= 1001us
        check_build_predicate(
            col("ts").gt_eq(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // 1s >= ts
        check_build_predicate(
            lit(ScalarValue::TimestampSecond(Some(1), None)).gt_eq(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1000), true),
        );

        // ts >= 1s
        check_build_predicate(
            col("ts").gt_eq(lit(ScalarValue::TimestampSecond(Some(1), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(1000)),
        );
    }

    #[test]
    fn test_lt() {
        // ts < 1ms
        check_build_predicate(
            col("ts").lt(lit(ScalarValue::TimestampMillisecond(Some(1), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1), false),
        );

        // 1ms < ts
        check_build_predicate(
            lit(ScalarValue::TimestampMillisecond(Some(1), None)).lt(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // 1001us < ts
        check_build_predicate(
            lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).lt(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // ts < 1001us
        check_build_predicate(
            col("ts").lt(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // 1s < ts
        check_build_predicate(
            lit(ScalarValue::TimestampSecond(Some(1), None)).lt(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(1001)),
        );

        // ts < 1s
        check_build_predicate(
            col("ts").lt(lit(ScalarValue::TimestampSecond(Some(1), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1000), false),
        );
    }

    #[test]
    fn test_lt_eq() {
        // ts <= 1ms
        check_build_predicate(
            col("ts").lt_eq(lit(ScalarValue::TimestampMillisecond(Some(1), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // 1ms <= ts
        check_build_predicate(
            lit(ScalarValue::TimestampMillisecond(Some(1), None)).lt_eq(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(1)),
        );

        // 1001us <= ts
        check_build_predicate(
            lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).lt_eq(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // ts <= 1001us
        check_build_predicate(
            col("ts").lt_eq(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // 1s <= ts
        check_build_predicate(
            lit(ScalarValue::TimestampSecond(Some(1), None)).lt_eq(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(1000)),
        );

        // ts <= 1s
        check_build_predicate(
            col("ts").lt_eq(lit(ScalarValue::TimestampSecond(Some(1), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1000), true),
        );
    }

    /// Writes a parquet file named `test-prune.parquet` into `dir` and returns
    /// its path together with the arrow schema used.
    ///
    /// The file has two nullable columns, `name` (Utf8) and `cnt` (Int32). Row
    /// `i` (for `i` in `0..cnt`) stores `i.to_string()` and `i` respectively.
    /// The writer's maximum row group size is set to 10 and rows are written in
    /// batches of at most 10.
    async fn gen_test_parquet_file(dir: &TempDir, cnt: usize) -> (String, Arc<Schema>) {
        let path = dir
            .path()
            .join("test-prune.parquet")
            .to_string_lossy()
            .to_string();

        let name_field = Field::new("name", DataType::Utf8, true);
        let count_field = Field::new("cnt", DataType::Int32, true);
        let schema = Arc::new(Schema::new(vec![name_field, count_field]));

        let file = std::fs::OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .open(path.clone())
            .unwrap();

        let write_props = WriterProperties::builder()
            .set_max_row_group_size(10)
            .build();
        let mut writer = ArrowWriter::try_new(file, schema.clone(), Some(write_props)).unwrap();

        // One record batch per 10 rows; the last batch may be shorter.
        for i in (0..cnt).step_by(10) {
            let name_array = Arc::new(StringArray::from(
                (i..(i + 10).min(cnt))
                    .map(|i| i.to_string())
                    .collect::<Vec<_>>(),
            )) as Arc<_>;
            let count_array = Arc::new(Int32Array::from(
                (i..(i + 10).min(cnt)).map(|i| i as i32).collect::<Vec<_>>(),
            )) as Arc<_>;
            let rb = RecordBatch::try_new(schema.clone(), vec![name_array, count_array]).unwrap();
            writer.write(&rb).unwrap();
        }
        let _ = writer.close().unwrap();
        (path, schema)
    }

    /// Generates a parquet file with `array_cnt` rows, prunes its row groups
    /// with `filters`, and asserts that the per-row-group result equals
    /// `expect`.
    async fn assert_prune(array_cnt: usize, filters: Vec<Expr>, expect: Vec<bool>) {
        let dir = create_temp_dir("prune_parquet");
        let (path, arrow_schema) = gen_test_parquet_file(&dir, array_cnt).await;
        let schema = Arc::new(datatypes::schema::Schema::try_from(arrow_schema.clone()).unwrap());
        let arrow_predicate = Predicate::new(filters);
        let builder = ParquetRecordBatchStreamBuilder::new(
            tokio::fs::OpenOptions::new()
                .read(true)
                .open(path)
                .await
                .unwrap(),
        )
        .await
        .unwrap();
        let metadata = builder.metadata().clone();
        let row_groups = metadata.row_groups();

        let stats = RowGroupPruningStatistics::new(row_groups, &schema);
        let res = arrow_predicate.prune_with_stats(&stats, &arrow_schema);
        assert_eq!(expect, res);
    }

    /// Builds the single-filter list `cnt <op> max_val`.
    fn gen_predicate(max_val: i32, op: Operator) -> Vec<Expr> {
        vec![datafusion_expr::Expr::BinaryExpr(BinaryExpr {
            left: Box::new(datafusion_expr::Expr::Column(Column::from_name("cnt"))),
            op,
            right: Box::new(datafusion_expr::Expr::Literal(ScalarValue::Int32(Some(
                max_val,
            )))),
        })]
    }

    #[tokio::test]
    async fn test_prune_empty() {
        // No filters: the only row group is kept.
        assert_prune(3, vec![], vec![true]).await;
    }

    #[tokio::test]
    async fn test_prune_all_match() {
        // cnt > 3 over rows 0..2: the only row group is pruned.
        let p = gen_predicate(3, Operator::Gt);
        assert_prune(2, p, vec![false]).await;
    }

    #[tokio::test]
    async fn test_prune_gt() {
        // cnt > 29 over rows 0..100: the first three row groups are pruned.
        let p = gen_predicate(29, Operator::Gt);
        assert_prune(
            100,
            p,
            vec![
                false, false, false, true, true, true, true, true, true, true,
            ],
        )
        .await;
    }

    #[tokio::test]
    async fn test_prune_eq_expr() {
        // cnt = 30 over rows 0..40: only the last row group is kept.
        let p = gen_predicate(30, Operator::Eq);
        assert_prune(40, p, vec![false, false, false, true]).await;
    }

    #[tokio::test]
    async fn test_prune_neq_expr() {
        // cnt != 30 over rows 0..40: no row group is pruned.
        let p = gen_predicate(30, Operator::NotEq);
        assert_prune(40, p, vec![true, true, true, true]).await;
    }

    #[tokio::test]
    async fn test_prune_gteq_expr() {
        // cnt >= 29 over rows 0..40: the first two row groups are pruned.
        let p = gen_predicate(29, Operator::GtEq);
        assert_prune(40, p, vec![false, false, true, true]).await;
    }

    #[tokio::test]
    async fn test_prune_lt_expr() {
        // cnt < 30 over rows 0..40: the last row group is pruned.
        let p = gen_predicate(30, Operator::Lt);
        assert_prune(40, p, vec![true, true, true, false]).await;
    }

    #[tokio::test]
    async fn test_prune_lteq_expr() {
        // cnt <= 30 over rows 0..40: no row group is pruned.
        let p = gen_predicate(30, Operator::LtEq);
        assert_prune(40, p, vec![true, true, true, true]).await;
    }

    #[tokio::test]
    async fn test_prune_between_expr() {
        // NOTE(review): despite its name, this test builds the same `cnt <= 30`
        // predicate as `test_prune_lteq_expr` and does not exercise BETWEEN.
        let p = gen_predicate(30, Operator::LtEq);
        assert_prune(40, p, vec![true, true, true, true]).await;
    }

    #[tokio::test]
    async fn test_or() {
        // cnt > 30 OR cnt < 20
        let e = datafusion_expr::Expr::Column(Column::from_name("cnt"))
            .gt(30.lit())
            .or(datafusion_expr::Expr::Column(Column::from_name("cnt")).lt(20.lit()));
        assert_prune(40, vec![e], vec![true, true, false, true]).await;
    }

    #[tokio::test]
    async fn test_to_physical_expr() {
        let predicate = Predicate::new(vec![
            col("host").eq(lit("host_a")),
            col("ts").gt(lit(ScalarValue::TimestampMicrosecond(Some(123), None))),
        ]);

        // The schema only contains `host`; `ts` is not part of it.
        let schema = Arc::new(arrow::datatypes::Schema::new(vec![Field::new(
            "host",
            arrow::datatypes::DataType::Utf8,
            false,
        )]));

        let predicates = predicate.to_physical_exprs(&schema).unwrap();
        assert!(!predicates.is_empty());
    }
}