1use std::sync::Arc;
16
17use common_telemetry::{error, warn};
18use common_time::range::TimestampRange;
19use common_time::timestamp::TimeUnit;
20use common_time::Timestamp;
21use datafusion::common::ScalarValue;
22use datafusion::physical_optimizer::pruning::PruningPredicate;
23use datafusion_common::pruning::PruningStatistics;
24use datafusion_common::ToDFSchema;
25use datafusion_expr::expr::{Expr, InList};
26use datafusion_expr::{Between, BinaryExpr, Operator};
27use datafusion_physical_expr::execution_props::ExecutionProps;
28use datafusion_physical_expr::{create_physical_expr, PhysicalExpr};
29use datatypes::arrow;
30use datatypes::value::scalar_value_to_timestamp;
31use snafu::ResultExt;
32
33use crate::error;
34
35#[cfg(test)]
36mod stats;
37
/// Makes the enclosing function return `None` when the given scalar is a
/// `ScalarValue::Utf8`, after logging a warning.
///
/// A string literal reaching the time-range extraction is treated as
/// unexpected (the warning asks users to report it). The filter is then simply
/// not used to narrow the time range, instead of being parsed as a timestamp.
macro_rules! return_none_if_utf8 {
    ($scalar: ident) => {
        if let ScalarValue::Utf8(_) = $scalar {
            warn!(
                "Unexpected ScalarValue::Utf8 in time range predicate: {:?}. Maybe it's an implicit bug, please report it to https://github.com/GreptimeTeam/greptimedb/issues",
                $scalar
            );
            return None;
        }
    };
}
53
54#[derive(Debug, Clone)]
56pub struct Predicate {
57 exprs: Arc<Vec<Expr>>,
59}
60
61impl Predicate {
62 pub fn new(exprs: Vec<Expr>) -> Self {
66 Self {
67 exprs: Arc::new(exprs),
68 }
69 }
70
71 pub fn exprs(&self) -> &[Expr] {
73 &self.exprs
74 }
75
76 pub fn to_physical_exprs(
78 &self,
79 schema: &arrow::datatypes::SchemaRef,
80 ) -> error::Result<Vec<Arc<dyn PhysicalExpr>>> {
81 let df_schema = schema
82 .clone()
83 .to_dfschema_ref()
84 .context(error::DatafusionSnafu)?;
85
86 let execution_props = &ExecutionProps::new();
90
91 Ok(self
92 .exprs
93 .iter()
94 .filter_map(|expr| create_physical_expr(expr, df_schema.as_ref(), execution_props).ok())
95 .collect::<Vec<_>>())
96 }
97
98 pub fn prune_with_stats<S: PruningStatistics>(
101 &self,
102 stats: &S,
103 schema: &arrow::datatypes::SchemaRef,
104 ) -> Vec<bool> {
105 let mut res = vec![true; stats.num_containers()];
106 let physical_exprs = match self.to_physical_exprs(schema) {
107 Ok(expr) => expr,
108 Err(e) => {
109 warn!(e; "Failed to build physical expr from predicates: {:?}", &self.exprs);
110 return res;
111 }
112 };
113
114 for expr in &physical_exprs {
115 match PruningPredicate::try_new(expr.clone(), schema.clone()) {
116 Ok(p) => match p.prune(stats) {
117 Ok(r) => {
118 for (curr_val, res) in r.into_iter().zip(res.iter_mut()) {
119 *res &= curr_val
120 }
121 }
122 Err(e) => {
123 warn!(e; "Failed to prune row groups");
124 }
125 },
126 Err(e) => {
127 error!(e; "Failed to create predicate for expr");
128 }
129 }
130 }
131 res
132 }
133}
134
135pub fn build_time_range_predicate(
140 ts_col_name: &str,
141 ts_col_unit: TimeUnit,
142 filters: &[Expr],
143) -> TimestampRange {
144 let mut res = TimestampRange::min_to_max();
145 for expr in filters {
146 if let Some(range) = extract_time_range_from_expr(ts_col_name, ts_col_unit, expr) {
147 res = res.and(&range);
148 }
149 }
150 res
151}
152
153fn extract_time_range_from_expr(
156 ts_col_name: &str,
157 ts_col_unit: TimeUnit,
158 expr: &Expr,
159) -> Option<TimestampRange> {
160 match expr {
161 Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
162 extract_from_binary_expr(ts_col_name, ts_col_unit, left, op, right)
163 }
164 Expr::Between(Between {
165 expr,
166 negated,
167 low,
168 high,
169 }) => extract_from_between_expr(ts_col_name, ts_col_unit, expr, negated, low, high),
170 Expr::InList(InList {
171 expr,
172 list,
173 negated,
174 }) => extract_from_in_list_expr(ts_col_name, expr, *negated, list),
175 _ => None,
176 }
177}
178
179fn extract_from_binary_expr(
180 ts_col_name: &str,
181 ts_col_unit: TimeUnit,
182 left: &Expr,
183 op: &Operator,
184 right: &Expr,
185) -> Option<TimestampRange> {
186 match op {
187 Operator::Eq => get_timestamp_filter(ts_col_name, left, right)
188 .and_then(|(ts, _)| ts.convert_to(ts_col_unit))
189 .map(TimestampRange::single),
190 Operator::Lt => {
191 let (ts, reverse) = get_timestamp_filter(ts_col_name, left, right)?;
192 if reverse {
193 let ts_val = ts.convert_to(ts_col_unit)?.value();
195 Some(TimestampRange::from_start(Timestamp::new(
196 ts_val + 1,
197 ts_col_unit,
198 )))
199 } else {
200 ts.convert_to_ceil(ts_col_unit)
202 .map(|ts| TimestampRange::until_end(ts, false))
203 }
204 }
205 Operator::LtEq => {
206 let (ts, reverse) = get_timestamp_filter(ts_col_name, left, right)?;
207 if reverse {
208 ts.convert_to_ceil(ts_col_unit)
210 .map(TimestampRange::from_start)
211 } else {
212 ts.convert_to(ts_col_unit)
214 .map(|ts| TimestampRange::until_end(ts, true))
215 }
216 }
217 Operator::Gt => {
218 let (ts, reverse) = get_timestamp_filter(ts_col_name, left, right)?;
219 if reverse {
220 ts.convert_to_ceil(ts_col_unit)
222 .map(|t| TimestampRange::until_end(t, false))
223 } else {
224 let ts_val = ts.convert_to(ts_col_unit)?.value();
226 Some(TimestampRange::from_start(Timestamp::new(
227 ts_val + 1,
228 ts_col_unit,
229 )))
230 }
231 }
232 Operator::GtEq => {
233 let (ts, reverse) = get_timestamp_filter(ts_col_name, left, right)?;
234 if reverse {
235 ts.convert_to(ts_col_unit)
237 .map(|t| TimestampRange::until_end(t, true))
238 } else {
239 ts.convert_to_ceil(ts_col_unit)
241 .map(TimestampRange::from_start)
242 }
243 }
244 Operator::And => {
245 let left = extract_time_range_from_expr(ts_col_name, ts_col_unit, left)
248 .unwrap_or_else(TimestampRange::min_to_max);
249 let right = extract_time_range_from_expr(ts_col_name, ts_col_unit, right)
250 .unwrap_or_else(TimestampRange::min_to_max);
251 Some(left.and(&right))
252 }
253 Operator::Or => {
254 let left = extract_time_range_from_expr(ts_col_name, ts_col_unit, left)?;
255 let right = extract_time_range_from_expr(ts_col_name, ts_col_unit, right)?;
256 Some(left.or(&right))
257 }
258 _ => None,
259 }
260}
261
262fn get_timestamp_filter(ts_col_name: &str, left: &Expr, right: &Expr) -> Option<(Timestamp, bool)> {
263 let (col, lit, reverse) = match (left, right) {
264 (Expr::Column(column), Expr::Literal(scalar, _)) => (column, scalar, false),
265 (Expr::Literal(scalar, _), Expr::Column(column)) => (column, scalar, true),
266 _ => {
267 return None;
268 }
269 };
270 if col.name != ts_col_name {
271 return None;
272 }
273
274 return_none_if_utf8!(lit);
275 scalar_value_to_timestamp(lit, None).map(|t| (t, reverse))
276}
277
278fn extract_from_between_expr(
279 ts_col_name: &str,
280 ts_col_unit: TimeUnit,
281 expr: &Expr,
282 negated: &bool,
283 low: &Expr,
284 high: &Expr,
285) -> Option<TimestampRange> {
286 let Expr::Column(col) = expr else {
287 return None;
288 };
289 if col.name != ts_col_name {
290 return None;
291 }
292
293 if *negated {
294 return None;
295 }
296
297 match (low, high) {
298 (Expr::Literal(low, _), Expr::Literal(high, _)) => {
299 return_none_if_utf8!(low);
300 return_none_if_utf8!(high);
301
302 let low_opt =
303 scalar_value_to_timestamp(low, None).and_then(|ts| ts.convert_to(ts_col_unit));
304 let high_opt = scalar_value_to_timestamp(high, None)
305 .and_then(|ts| ts.convert_to_ceil(ts_col_unit));
306 Some(TimestampRange::new_inclusive(low_opt, high_opt))
307 }
308 _ => None,
309 }
310}
311
312fn extract_from_in_list_expr(
314 ts_col_name: &str,
315 expr: &Expr,
316 negated: bool,
317 list: &[Expr],
318) -> Option<TimestampRange> {
319 if negated {
320 return None;
321 }
322 let Expr::Column(col) = expr else {
323 return None;
324 };
325 if col.name != ts_col_name {
326 return None;
327 }
328
329 if list.is_empty() {
330 return Some(TimestampRange::empty());
331 }
332 let mut init_range = TimestampRange::empty();
333 for expr in list {
334 if let Expr::Literal(scalar, _) = expr {
335 return_none_if_utf8!(scalar);
336 if let Some(timestamp) = scalar_value_to_timestamp(scalar, None) {
337 init_range = init_range.or(&TimestampRange::single(timestamp))
338 } else {
339 return None;
342 }
343 }
344 }
345 Some(init_range)
346}
347
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_test_util::temp_dir::{create_temp_dir, TempDir};
    use datafusion::parquet::arrow::ArrowWriter;
    use datafusion_common::{Column, ScalarValue};
    use datafusion_expr::{col, lit, BinaryExpr, Literal, Operator};
    use datatypes::arrow::array::Int32Array;
    use datatypes::arrow::datatypes::{DataType, Field, Schema};
    use datatypes::arrow::record_batch::RecordBatch;
    use datatypes::arrow_array::StringArray;
    use parquet::arrow::ParquetRecordBatchStreamBuilder;
    use parquet::file::properties::WriterProperties;

    use super::*;
    use crate::predicate::stats::RowGroupPruningStatistics;

    /// Asserts that `expr` on a millisecond `ts` column yields exactly `expect`.
    fn check_build_predicate(expr: Expr, expect: TimestampRange) {
        assert_eq!(
            expect,
            build_time_range_predicate("ts", TimeUnit::Millisecond, &[expr])
        );
    }

    #[test]
    fn test_gt() {
        // ts > 1ms
        check_build_predicate(
            col("ts").gt(lit(ScalarValue::TimestampMillisecond(Some(1), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // 1ms > ts
        check_build_predicate(
            lit(ScalarValue::TimestampMillisecond(Some(1), None)).gt(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1), false),
        );

        // 1001us > ts
        check_build_predicate(
            lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).gt(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // ts > 1001us
        check_build_predicate(
            col("ts").gt(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // 1s > ts
        check_build_predicate(
            lit(ScalarValue::TimestampSecond(Some(1), None)).gt(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1000), false),
        );

        // ts > 1s
        check_build_predicate(
            col("ts").gt(lit(ScalarValue::TimestampSecond(Some(1), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(1001)),
        );
    }

    #[test]
    fn test_gt_eq() {
        // ts >= 1ms
        check_build_predicate(
            col("ts").gt_eq(lit(ScalarValue::TimestampMillisecond(Some(1), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(1)),
        );

        // 1ms >= ts
        check_build_predicate(
            lit(ScalarValue::TimestampMillisecond(Some(1), None)).gt_eq(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // 1001us >= ts
        check_build_predicate(
            lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).gt_eq(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // ts >= 1001us
        check_build_predicate(
            col("ts").gt_eq(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // 1s >= ts
        check_build_predicate(
            lit(ScalarValue::TimestampSecond(Some(1), None)).gt_eq(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1000), true),
        );

        // ts >= 1s
        check_build_predicate(
            col("ts").gt_eq(lit(ScalarValue::TimestampSecond(Some(1), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(1000)),
        );
    }

    #[test]
    fn test_lt() {
        // ts < 1ms
        check_build_predicate(
            col("ts").lt(lit(ScalarValue::TimestampMillisecond(Some(1), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1), false),
        );

        // 1ms < ts
        check_build_predicate(
            lit(ScalarValue::TimestampMillisecond(Some(1), None)).lt(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // 1001us < ts
        check_build_predicate(
            lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).lt(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // ts < 1001us
        check_build_predicate(
            col("ts").lt(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // 1s < ts
        check_build_predicate(
            lit(ScalarValue::TimestampSecond(Some(1), None)).lt(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(1001)),
        );

        // ts < 1s
        check_build_predicate(
            col("ts").lt(lit(ScalarValue::TimestampSecond(Some(1), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1000), false),
        );
    }

    #[test]
    fn test_lt_eq() {
        // ts <= 1ms
        check_build_predicate(
            col("ts").lt_eq(lit(ScalarValue::TimestampMillisecond(Some(1), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // 1ms <= ts
        check_build_predicate(
            lit(ScalarValue::TimestampMillisecond(Some(1), None)).lt_eq(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(1)),
        );

        // 1001us <= ts
        check_build_predicate(
            lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).lt_eq(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // ts <= 1001us
        check_build_predicate(
            col("ts").lt_eq(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // 1s <= ts
        check_build_predicate(
            lit(ScalarValue::TimestampSecond(Some(1), None)).lt_eq(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(1000)),
        );

        // ts <= 1s
        check_build_predicate(
            col("ts").lt_eq(lit(ScalarValue::TimestampSecond(Some(1), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1000), true),
        );
    }

    #[test]
    fn test_between() {
        // ts BETWEEN 1ms AND 10ms
        check_build_predicate(
            col("ts").between(
                lit(ScalarValue::TimestampMillisecond(Some(1), None)),
                lit(ScalarValue::TimestampMillisecond(Some(10), None)),
            ),
            TimestampRange::new_inclusive(
                Some(Timestamp::new_millisecond(1)),
                Some(Timestamp::new_millisecond(10)),
            ),
        );
    }

    /// Writes a parquet file with columns `name` (Utf8) and `cnt` (Int32) holding
    /// the values `0..cnt`, split into row groups of at most 10 rows.
    async fn gen_test_parquet_file(dir: &TempDir, cnt: usize) -> (String, Arc<Schema>) {
        let path = dir
            .path()
            .join("test-prune.parquet")
            .to_string_lossy()
            .to_string();

        let name_field = Field::new("name", DataType::Utf8, true);
        let count_field = Field::new("cnt", DataType::Int32, true);
        let schema = Arc::new(Schema::new(vec![name_field, count_field]));

        let file = std::fs::OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .open(path.clone())
            .unwrap();

        let write_props = WriterProperties::builder()
            .set_max_row_group_size(10)
            .build();
        let mut writer = ArrowWriter::try_new(file, schema.clone(), Some(write_props)).unwrap();

        for i in (0..cnt).step_by(10) {
            let name_array = Arc::new(StringArray::from(
                (i..(i + 10).min(cnt))
                    .map(|i| i.to_string())
                    .collect::<Vec<_>>(),
            )) as Arc<_>;
            let count_array = Arc::new(Int32Array::from(
                (i..(i + 10).min(cnt)).map(|i| i as i32).collect::<Vec<_>>(),
            )) as Arc<_>;
            let rb = RecordBatch::try_new(schema.clone(), vec![name_array, count_array]).unwrap();
            writer.write(&rb).unwrap();
        }
        let _ = writer.close().unwrap();
        (path, schema)
    }

    /// Prunes the row groups of a generated `array_cnt`-row file with `filters`
    /// and asserts the per-row-group result equals `expect`.
    async fn assert_prune(array_cnt: usize, filters: Vec<Expr>, expect: Vec<bool>) {
        let dir = create_temp_dir("prune_parquet");
        let (path, arrow_schema) = gen_test_parquet_file(&dir, array_cnt).await;
        let schema = Arc::new(datatypes::schema::Schema::try_from(arrow_schema.clone()).unwrap());
        let arrow_predicate = Predicate::new(filters);
        let builder = ParquetRecordBatchStreamBuilder::new(
            tokio::fs::OpenOptions::new()
                .read(true)
                .open(path)
                .await
                .unwrap(),
        )
        .await
        .unwrap();
        let metadata = builder.metadata().clone();
        let row_groups = metadata.row_groups();

        let stats = RowGroupPruningStatistics::new(row_groups, &schema);
        let res = arrow_predicate.prune_with_stats(&stats, &arrow_schema);
        assert_eq!(expect, res);
    }

    /// Builds the single filter `cnt <op> max_val`.
    fn gen_predicate(max_val: i32, op: Operator) -> Vec<Expr> {
        vec![datafusion_expr::Expr::BinaryExpr(BinaryExpr {
            left: Box::new(datafusion_expr::Expr::Column(Column::from_name("cnt"))),
            op,
            right: Box::new(max_val.lit()),
        })]
    }

    #[tokio::test]
    async fn test_prune_empty() {
        assert_prune(3, vec![], vec![true]).await;
    }

    #[tokio::test]
    async fn test_prune_none_match() {
        // Values are 0..2, so `cnt > 3` matches nothing and the group is pruned.
        let p = gen_predicate(3, Operator::Gt);
        assert_prune(2, p, vec![false]).await;
    }

    #[tokio::test]
    async fn test_prune_gt() {
        let p = gen_predicate(29, Operator::Gt);
        assert_prune(
            100,
            p,
            vec![
                false, false, false, true, true, true, true, true, true, true,
            ],
        )
        .await;
    }

    #[tokio::test]
    async fn test_prune_eq_expr() {
        let p = gen_predicate(30, Operator::Eq);
        assert_prune(40, p, vec![false, false, false, true]).await;
    }

    #[tokio::test]
    async fn test_prune_neq_expr() {
        let p = gen_predicate(30, Operator::NotEq);
        assert_prune(40, p, vec![true, true, true, true]).await;
    }

    #[tokio::test]
    async fn test_prune_gteq_expr() {
        let p = gen_predicate(29, Operator::GtEq);
        assert_prune(40, p, vec![false, false, true, true]).await;
    }

    #[tokio::test]
    async fn test_prune_lt_expr() {
        let p = gen_predicate(30, Operator::Lt);
        assert_prune(40, p, vec![true, true, true, false]).await;
    }

    #[tokio::test]
    async fn test_prune_lteq_expr() {
        let p = gen_predicate(30, Operator::LtEq);
        assert_prune(40, p, vec![true, true, true, true]).await;
    }

    #[tokio::test]
    async fn test_prune_between_expr() {
        // Row groups hold [0,9], [10,19], [20,29], [30,39].
        let p = vec![col("cnt").between(lit(11), lit(25))];
        assert_prune(40, p, vec![false, true, true, false]).await;
    }

    #[tokio::test]
    async fn test_or() {
        // cnt > 30 or cnt < 20
        let e = datafusion_expr::Expr::Column(Column::from_name("cnt"))
            .gt(30.lit())
            .or(datafusion_expr::Expr::Column(Column::from_name("cnt")).lt(20.lit()));
        assert_prune(40, vec![e], vec![true, true, false, true]).await;
    }

    #[tokio::test]
    async fn test_to_physical_expr() {
        let predicate = Predicate::new(vec![
            col("host").eq(lit("host_a")),
            col("ts").gt(lit(ScalarValue::TimestampMicrosecond(Some(123), None))),
        ]);

        let schema = Arc::new(arrow::datatypes::Schema::new(vec![Field::new(
            "host",
            arrow::datatypes::DataType::Utf8,
            false,
        )]));

        let predicates = predicate.to_physical_exprs(&schema).unwrap();
        assert!(!predicates.is_empty());
    }
}