flow/batching_mode/
time_window.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Time window expr and helper functions
16//!
17
18use std::collections::BTreeSet;
19use std::sync::Arc;
20
21use api::helper::pb_value_to_value_ref;
22use arrow::array::{
23    TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
24    TimestampSecondArray,
25};
26use catalog::CatalogManagerRef;
27use common_error::ext::BoxedError;
28use common_recordbatch::DfRecordBatch;
29use common_telemetry::warn;
30use common_time::timestamp::TimeUnit;
31use common_time::Timestamp;
32use datafusion::error::Result as DfResult;
33use datafusion::execution::SessionState;
34use datafusion::logical_expr::Expr;
35use datafusion::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner};
36use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter};
37use datafusion_common::{DFSchema, TableReference};
38use datafusion_expr::{ColumnarValue, LogicalPlan};
39use datafusion_physical_expr::PhysicalExprRef;
40use datatypes::prelude::{ConcreteDataType, DataType};
41use datatypes::schema::TIME_INDEX_KEY;
42use datatypes::value::Value;
43use datatypes::vectors::{
44    TimestampMicrosecondVector, TimestampMillisecondVector, TimestampNanosecondVector,
45    TimestampSecondVector, Vector,
46};
47use itertools::Itertools;
48use session::context::QueryContextRef;
49use snafu::{ensure, OptionExt, ResultExt};
50
51use crate::adapter::util::from_proto_to_data_type;
52use crate::error::{
53    ArrowSnafu, DatafusionSnafu, DatatypesSnafu, ExternalSnafu, PlanSnafu, TimeSnafu,
54    UnexpectedSnafu,
55};
56use crate::expr::error::DataTypeSnafu;
57use crate::Error;
58
59/// Represents a test timestamp in seconds since the Unix epoch.
60const DEFAULT_TEST_TIMESTAMP: Timestamp = Timestamp::new_second(17_0000_0000);
61
/// Time window expr like `date_bin(INTERVAL '1' MINUTE, ts)`; this type helps with
/// evaluating the expr using a given timestamp.
///
/// The time window expr must satisfy the following conditions:
/// 1. The expr must be monotonic non-decreasing
/// 2. The expr must have one and only one input column with timestamp type, and the output column must be timestamp type
/// 3. The expr must be deterministic
///
/// An example of time window expr is `date_bin(INTERVAL '1' MINUTE, ts)`
#[derive(Debug, Clone)]
pub struct TimeWindowExpr {
    /// Physical form of `logical_expr`, planned against `df_schema`.
    phy_expr: PhysicalExprRef,
    /// Name of the input column the expr reads; used to locate that column in incoming rows.
    pub column_name: String,
    /// The logical expr this time window was built from.
    logical_expr: Expr,
    /// Schema the physical expr was planned against and is evaluated with.
    df_schema: DFSchema,
    /// `upper - lower` of the window found for the test timestamp at construction time,
    /// `None` if either bound is unknown.
    eval_time_window_size: Option<std::time::Duration>,
    /// Lower bound of the window found for the test timestamp at construction time.
    eval_time_original: Option<Timestamp>,
}
80
81impl std::fmt::Display for TimeWindowExpr {
82    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83        f.debug_struct("TimeWindowExpr")
84            .field("phy_expr", &self.phy_expr.to_string())
85            .field("column_name", &self.column_name)
86            .field("logical_expr", &self.logical_expr.to_string())
87            .field("df_schema", &self.df_schema)
88            .finish()
89    }
90}
91
92impl TimeWindowExpr {
93    /// The time window size of the expr, get from calling `eval` with a test timestamp
94    pub fn time_window_size(&self) -> &Option<std::time::Duration> {
95        &self.eval_time_window_size
96    }
97
98    pub fn from_expr(
99        expr: &Expr,
100        column_name: &str,
101        df_schema: &DFSchema,
102        session: &SessionState,
103    ) -> Result<Self, Error> {
104        let phy_expr: PhysicalExprRef = to_phy_expr(expr, df_schema, session)?;
105        let mut zelf = Self {
106            phy_expr,
107            column_name: column_name.to_string(),
108            logical_expr: expr.clone(),
109            df_schema: df_schema.clone(),
110            eval_time_window_size: None,
111            eval_time_original: None,
112        };
113        let test_ts = DEFAULT_TEST_TIMESTAMP;
114        let (lower, upper) = zelf.eval(test_ts)?;
115        let time_window_size = match (lower, upper) {
116            (Some(l), Some(u)) => u.sub(&l).map(|r| r.to_std()).transpose().map_err(|_| {
117                UnexpectedSnafu {
118                    reason: format!(
119                        "Expect upper bound older than lower bound, found upper={u:?} and lower={l:?}"
120                    ),
121                }
122                .build()
123            })?,
124            _ => None,
125        };
126        zelf.eval_time_window_size = time_window_size;
127        zelf.eval_time_original = lower;
128
129        Ok(zelf)
130    }
131
132    /// TODO(discord9): add `eval_batch` too
133    pub fn eval(
134        &self,
135        current: Timestamp,
136    ) -> Result<(Option<Timestamp>, Option<Timestamp>), Error> {
137        fn compute_distance(time_diff_ns: i64, stride_ns: i64) -> i64 {
138            if stride_ns == 0 {
139                return time_diff_ns;
140            }
141            // a - (a % n) impl ceil to nearest n * stride
142            let time_delta = time_diff_ns - (time_diff_ns % stride_ns);
143
144            if time_diff_ns < 0 && time_delta != time_diff_ns {
145                // The origin is later than the source timestamp, round down to the previous bin
146
147                time_delta - stride_ns
148            } else {
149                time_delta
150            }
151        }
152
153        // FAST PATH: if we have eval_time_original and eval_time_window_size,
154        // we can compute the bounds directly
155        if let (Some(original), Some(window_size)) =
156            (self.eval_time_original, self.eval_time_window_size)
157        {
158            // date_bin align current to lower bound
159            let time_diff_ns = current.sub(&original).and_then(|s|s.num_nanoseconds()).with_context(||UnexpectedSnafu {
160                reason: format!(
161                    "Failed to compute time difference between current {current:?} and original {original:?}"
162                ),
163            })?;
164
165            let window_size_ns = window_size.as_nanos() as i64;
166
167            let distance_ns = compute_distance(time_diff_ns, window_size_ns);
168
169            let lower_bound = if distance_ns >= 0 {
170                original.add_duration(std::time::Duration::from_nanos(distance_ns as u64))
171            } else {
172                original.sub_duration(std::time::Duration::from_nanos((-distance_ns) as u64))
173            }
174            .context(TimeSnafu)?;
175            let upper_bound = lower_bound.add_duration(window_size).context(TimeSnafu)?;
176
177            return Ok((Some(lower_bound), Some(upper_bound)));
178        }
179
180        let lower_bound =
181            calc_expr_time_window_lower_bound(&self.phy_expr, &self.df_schema, current)?;
182        let upper_bound =
183            probe_expr_time_window_upper_bound(&self.phy_expr, &self.df_schema, current)?;
184        Ok((lower_bound, upper_bound))
185    }
186
187    /// Find timestamps from rows using time window expr
188    ///
189    /// use column of name `self.column_name` from input rows list as input to time window expr
190    pub async fn handle_rows(
191        &self,
192        rows_list: Vec<api::v1::Rows>,
193    ) -> Result<BTreeSet<Timestamp>, Error> {
194        let mut time_windows = BTreeSet::new();
195
196        for rows in rows_list {
197            // pick the time index column and use it to eval on `self.expr`
198            // TODO(discord9): handle case where time index column is not present(i.e. DEFAULT constant value)
199            let ts_col_index = rows
200                .schema
201                .iter()
202                .map(|col| col.column_name.clone())
203                .position(|name| name == self.column_name);
204            let Some(ts_col_index) = ts_col_index else {
205                warn!("can't found time index column in schema: {:?}", rows.schema);
206                continue;
207            };
208            let col_schema = &rows.schema[ts_col_index];
209            let cdt = from_proto_to_data_type(col_schema)?;
210
211            let mut vector = cdt.create_mutable_vector(rows.rows.len());
212            for row in rows.rows {
213                let value = pb_value_to_value_ref(&row.values[ts_col_index], &None);
214                vector.try_push_value_ref(value).context(DataTypeSnafu {
215                    msg: "Failed to convert rows to columns",
216                })?;
217            }
218            let vector = vector.to_vector();
219
220            let df_schema = create_df_schema_for_ts_column(&self.column_name, cdt)?;
221
222            let rb =
223                DfRecordBatch::try_new(df_schema.inner().clone(), vec![vector.to_arrow_array()])
224                    .with_context(|_e| ArrowSnafu {
225                        context: format!(
226                            "Failed to create record batch from {df_schema:?} and {vector:?}"
227                        ),
228                    })?;
229
230            let eval_res = self
231                .phy_expr
232                .evaluate(&rb)
233                .with_context(|_| DatafusionSnafu {
234                    context: format!(
235                        "Failed to evaluate physical expression {:?} on {rb:?}",
236                        self.phy_expr
237                    ),
238                })?;
239
240            let res = columnar_to_ts_vector(&eval_res)?;
241
242            for ts in res.into_iter().flatten() {
243                time_windows.insert(ts);
244            }
245        }
246
247        Ok(time_windows)
248    }
249}
250
251fn create_df_schema_for_ts_column(name: &str, cdt: ConcreteDataType) -> Result<DFSchema, Error> {
252    let arrow_schema = Arc::new(arrow_schema::Schema::new(vec![arrow_schema::Field::new(
253        name,
254        cdt.as_arrow_type(),
255        false,
256    )]));
257
258    let df_schema = DFSchema::from_field_specific_qualified_schema(
259        vec![Some(TableReference::bare("TimeIndexOnlyTable"))],
260        &arrow_schema,
261    )
262    .with_context(|_e| DatafusionSnafu {
263        context: format!("Failed to create DFSchema from arrow schema {arrow_schema:?}"),
264    })?;
265
266    Ok(df_schema)
267}
268
269/// Convert `ColumnarValue` to `Vec<Option<Timestamp>>`
270fn columnar_to_ts_vector(columnar: &ColumnarValue) -> Result<Vec<Option<Timestamp>>, Error> {
271    let val = match columnar {
272        datafusion_expr::ColumnarValue::Array(array) => {
273            let ty = array.data_type();
274            let ty = ConcreteDataType::from_arrow_type(ty);
275            let time_unit = if let ConcreteDataType::Timestamp(ty) = ty {
276                ty.unit()
277            } else {
278                return UnexpectedSnafu {
279                    reason: format!("Non-timestamp type: {ty:?}"),
280                }
281                .fail();
282            };
283
284            match time_unit {
285                TimeUnit::Second => array
286                    .as_ref()
287                    .as_any()
288                    .downcast_ref::<TimestampSecondArray>()
289                    .with_context(|| PlanSnafu {
290                        reason: format!("Failed to create vector from arrow array {array:?}"),
291                    })?
292                    .values()
293                    .iter()
294                    .map(|d| Some(Timestamp::new(*d, time_unit)))
295                    .collect_vec(),
296                TimeUnit::Millisecond => array
297                    .as_ref()
298                    .as_any()
299                    .downcast_ref::<TimestampMillisecondArray>()
300                    .with_context(|| PlanSnafu {
301                        reason: format!("Failed to create vector from arrow array {array:?}"),
302                    })?
303                    .values()
304                    .iter()
305                    .map(|d| Some(Timestamp::new(*d, time_unit)))
306                    .collect_vec(),
307                TimeUnit::Microsecond => array
308                    .as_ref()
309                    .as_any()
310                    .downcast_ref::<TimestampMicrosecondArray>()
311                    .with_context(|| PlanSnafu {
312                        reason: format!("Failed to create vector from arrow array {array:?}"),
313                    })?
314                    .values()
315                    .iter()
316                    .map(|d| Some(Timestamp::new(*d, time_unit)))
317                    .collect_vec(),
318                TimeUnit::Nanosecond => array
319                    .as_ref()
320                    .as_any()
321                    .downcast_ref::<TimestampNanosecondArray>()
322                    .with_context(|| PlanSnafu {
323                        reason: format!("Failed to create vector from arrow array {array:?}"),
324                    })?
325                    .values()
326                    .iter()
327                    .map(|d| Some(Timestamp::new(*d, time_unit)))
328                    .collect_vec(),
329            }
330        }
331        datafusion_expr::ColumnarValue::Scalar(scalar) => {
332            let value = Value::try_from(scalar.clone()).with_context(|_| DatatypesSnafu {
333                extra: format!("Failed to convert scalar {scalar:?} to value"),
334            })?;
335            let ts = value.as_timestamp().context(UnexpectedSnafu {
336                reason: format!("Expect Timestamp, found {:?}", value),
337            })?;
338            vec![Some(ts)]
339        }
340    };
341    Ok(val)
342}
343
344/// Return (`the column name of time index column`, `the time window expr`, `the expected time unit of time index column`, `the expr's schema for evaluating the time window`)
345///
346/// The time window expr is expected to have one input column with Timestamp type, and also return Timestamp type, the time window expr is expected
347/// to be monotonic increasing and appears in the innermost GROUP BY clause
348///
349/// note this plan should only contain one TableScan
350pub async fn find_time_window_expr(
351    plan: &LogicalPlan,
352    catalog_man: CatalogManagerRef,
353    query_ctx: QueryContextRef,
354) -> Result<(String, Option<datafusion_expr::Expr>, TimeUnit, DFSchema), Error> {
355    // TODO(discord9): find the expr that do time window
356
357    let mut table_name = None;
358
359    // first find the table source in the logical plan
360    plan.apply(|plan| {
361        let LogicalPlan::TableScan(table_scan) = plan else {
362            return Ok(TreeNodeRecursion::Continue);
363        };
364        table_name = Some(table_scan.table_name.clone());
365        Ok(TreeNodeRecursion::Stop)
366    })
367    .with_context(|_| DatafusionSnafu {
368        context: format!("Can't find table source in plan {plan:?}"),
369    })?;
370    let Some(table_name) = table_name else {
371        UnexpectedSnafu {
372            reason: format!("Can't find table source in plan {plan:?}"),
373        }
374        .fail()?
375    };
376
377    let current_schema = query_ctx.current_schema();
378
379    let catalog_name = table_name.catalog().unwrap_or(query_ctx.current_catalog());
380    let schema_name = table_name.schema().unwrap_or(&current_schema);
381    let table_name = table_name.table();
382
383    let Some(table_ref) = catalog_man
384        .table(catalog_name, schema_name, table_name, Some(&query_ctx))
385        .await
386        .map_err(BoxedError::new)
387        .context(ExternalSnafu)?
388    else {
389        UnexpectedSnafu {
390            reason: format!(
391                "Can't find table {table_name:?} in catalog {catalog_name:?}/{schema_name:?}"
392            ),
393        }
394        .fail()?
395    };
396
397    let schema = &table_ref.table_info().meta.schema;
398
399    let ts_index = schema.timestamp_column().with_context(|| UnexpectedSnafu {
400        reason: format!("Can't find timestamp column in table {table_name:?}"),
401    })?;
402
403    let ts_col_name = ts_index.name.clone();
404
405    let expected_time_unit = ts_index.data_type.as_timestamp().with_context(|| UnexpectedSnafu {
406        reason: format!(
407            "Expected timestamp column {ts_col_name:?} in table {table_name:?} to be timestamp, but got {ts_index:?}"
408        ),
409    })?.unit();
410
411    let arrow_schema = Arc::new(arrow_schema::Schema::new(vec![arrow_schema::Field::new(
412        ts_col_name.clone(),
413        ts_index.data_type.as_arrow_type(),
414        false,
415    )]));
416
417    let df_schema = DFSchema::from_field_specific_qualified_schema(
418        vec![Some(TableReference::bare(table_name))],
419        &arrow_schema,
420    )
421    .with_context(|_e| DatafusionSnafu {
422        context: format!("Failed to create DFSchema from arrow schema {arrow_schema:?}"),
423    })?;
424
425    // find the time window expr which refers to the time index column
426    let mut aggr_expr = None;
427    let mut time_window_expr: Option<Expr> = None;
428
429    let find_inner_aggr_expr = |plan: &LogicalPlan| {
430        if let LogicalPlan::Aggregate(aggregate) = plan {
431            aggr_expr = Some(aggregate.clone());
432        };
433
434        Ok(TreeNodeRecursion::Continue)
435    };
436    plan.apply(find_inner_aggr_expr)
437        .with_context(|_| DatafusionSnafu {
438            context: format!("Can't find aggr expr in plan {plan:?}"),
439        })?;
440
441    if let Some(aggregate) = aggr_expr {
442        for group_expr in &aggregate.group_expr {
443            let refs = group_expr.column_refs();
444            if refs.len() != 1 {
445                continue;
446            }
447            let ref_col = refs.iter().next().unwrap();
448
449            let index = aggregate.input.schema().maybe_index_of_column(ref_col);
450            let Some(index) = index else {
451                continue;
452            };
453            let field = aggregate.input.schema().field(index);
454
455            // TODO(discord9): need to ensure the field has the meta key for the time index
456            let is_time_index =
457                field.metadata().get(TIME_INDEX_KEY).map(|s| s.as_str()) == Some("true");
458
459            if is_time_index {
460                let rewrite_column = group_expr.clone();
461                let rewritten = rewrite_column
462                    .rewrite(&mut RewriteColumn {
463                        table_name: table_name.to_string(),
464                    })
465                    .with_context(|_| DatafusionSnafu {
466                        context: format!("Rewrite expr failed, expr={:?}", group_expr),
467                    })?
468                    .data;
469                struct RewriteColumn {
470                    table_name: String,
471                }
472
473                impl TreeNodeRewriter for RewriteColumn {
474                    type Node = Expr;
475                    fn f_down(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
476                        let Expr::Column(mut column) = node else {
477                            return Ok(Transformed::no(node));
478                        };
479
480                        column.relation = Some(TableReference::bare(self.table_name.clone()));
481
482                        Ok(Transformed::yes(Expr::Column(column)))
483                    }
484                }
485
486                time_window_expr = Some(rewritten);
487                break;
488            }
489        }
490        Ok((ts_col_name, time_window_expr, expected_time_unit, df_schema))
491    } else {
492        // can't found time window expr, return None
493        Ok((ts_col_name, None, expected_time_unit, df_schema))
494    }
495}
496
/// Find the bounds of the time window that `current` falls into, for the time window expr of
/// the given `plan`.
/// i.e. for time window expr being `date_bin(INTERVAL '5 minutes', ts) as time_window` and `current="2021-07-01 00:01:01.000"`,
/// the lower bound is `Some("2021-07-01 00:00:00.000")`.
/// If no time window expr is found in `plan`, both bounds are `None`.
///
/// Returns (`time index column name`, `lower bound`, `upper bound`).
///
/// Time window expr is a expr that:
/// 1. ref only to a time index column
/// 2. is monotonic increasing
/// 3. show up in GROUP BY clause
///
/// note this plan should only contain one TableScan
#[cfg(test)]
pub async fn find_plan_time_window_bound(
    plan: &LogicalPlan,
    current: Timestamp,
    query_ctx: QueryContextRef,
    engine: query::QueryEngineRef,
) -> Result<(String, Option<Timestamp>, Option<Timestamp>), Error> {
    // TODO(discord9): find the expr that do time window
    let catalog_man = engine.engine_state().catalog_manager();

    let (ts_col_name, time_window_expr, expected_time_unit, df_schema) =
        find_time_window_expr(plan, catalog_man.clone(), query_ctx).await?;

    // Bring `current` into the unit of the time index column before evaluating anything.
    let new_current = current
        .convert_to(expected_time_unit)
        .with_context(|| UnexpectedSnafu {
            reason: format!("Failed to cast current timestamp {current:?} to {expected_time_unit}"),
        })?;

    // Without a time window expr there are no bounds to report.
    let Some(time_window_expr) = time_window_expr else {
        return Ok((ts_col_name, None, None));
    };

    let phy_expr = to_phy_expr(
        &time_window_expr,
        &df_schema,
        &engine.engine_state().session_state(),
    )?;
    let lower_bound = calc_expr_time_window_lower_bound(&phy_expr, &df_schema, new_current)?;
    let upper_bound = probe_expr_time_window_upper_bound(&phy_expr, &df_schema, new_current)?;

    Ok((ts_col_name, lower_bound, upper_bound))
}
541
542/// Find the lower bound of time window in given `expr` and `current` timestamp.
543///
544/// i.e. for `current="2021-07-01 00:01:01.000"` and `expr=date_bin(INTERVAL '5 minutes', ts) as time_window` and `ts_col=ts`,
545/// return `Some("2021-07-01 00:00:00.000")` since it's the lower bound
546/// return `Some("2021-07-01 00:00:00.000")` since it's the lower bound
547/// of current time window given the current timestamp
548///
549/// if return None, meaning this time window have no lower bound
550fn calc_expr_time_window_lower_bound(
551    phy_expr: &PhysicalExprRef,
552    df_schema: &DFSchema,
553    current: Timestamp,
554) -> Result<Option<Timestamp>, Error> {
555    let cur_time_window = eval_phy_time_window_expr(phy_expr, df_schema, current)?;
556    let input_time_unit = cur_time_window.unit();
557    Ok(cur_time_window.convert_to(input_time_unit))
558}
559
560/// Probe for the upper bound for time window expression
561fn probe_expr_time_window_upper_bound(
562    phy_expr: &PhysicalExprRef,
563    df_schema: &DFSchema,
564    current: Timestamp,
565) -> Result<Option<Timestamp>, Error> {
566    // TODO(discord9): special handling `date_bin` for faster path
567    use std::cmp::Ordering;
568
569    let cur_time_window = eval_phy_time_window_expr(phy_expr, df_schema, current)?;
570
571    // search to find the lower bound
572    let mut offset: i64 = 1;
573    let mut lower_bound = Some(current);
574    let upper_bound;
575    // first expontial probe to found a range for binary search
576    loop {
577        let Some(next_val) = current.value().checked_add(offset) else {
578            // no upper bound if overflow, which is ok
579            return Ok(None);
580        };
581
582        let next_time_probe = common_time::Timestamp::new(next_val, current.unit());
583
584        let next_time_window = eval_phy_time_window_expr(phy_expr, df_schema, next_time_probe)?;
585
586        match next_time_window.cmp(&cur_time_window) {
587            Ordering::Less => UnexpectedSnafu {
588                    reason: format!(
589                        "Unsupported time window expression, expect monotonic increasing for time window expression {phy_expr:?}"
590                    ),
591                }
592                .fail()?,
593            Ordering::Equal => {
594                lower_bound = Some(next_time_probe);
595            }
596            Ordering::Greater => {
597                upper_bound = Some(next_time_probe);
598                break
599            }
600        }
601
602        let Some(new_offset) = offset.checked_mul(2) else {
603            // no upper bound if overflow
604            return Ok(None);
605        };
606        offset = new_offset;
607    }
608
609    // binary search for the exact upper bound
610
611    binary_search_expr(
612        lower_bound,
613        upper_bound,
614        cur_time_window,
615        phy_expr,
616        df_schema,
617    )
618    .map(Some)
619}
620
621fn binary_search_expr(
622    lower_bound: Option<Timestamp>,
623    upper_bound: Option<Timestamp>,
624    cur_time_window: Timestamp,
625    phy_expr: &PhysicalExprRef,
626    df_schema: &DFSchema,
627) -> Result<Timestamp, Error> {
628    ensure!(lower_bound.map(|v|v.unit()) == upper_bound.map(|v| v.unit()), UnexpectedSnafu {
629        reason: format!(" unit mismatch for time window expression {phy_expr:?}, found {lower_bound:?} and {upper_bound:?}"),
630    });
631
632    let output_unit = upper_bound
633        .context(UnexpectedSnafu {
634            reason: "should have lower bound",
635        })?
636        .unit();
637
638    let mut low = lower_bound
639        .context(UnexpectedSnafu {
640            reason: "should have lower bound",
641        })?
642        .value();
643    let mut high = upper_bound
644        .context(UnexpectedSnafu {
645            reason: "should have upper bound",
646        })?
647        .value();
648    while low < high {
649        let mid = (low + high) / 2;
650        let mid_probe = common_time::Timestamp::new(mid, output_unit);
651        let mid_time_window = eval_phy_time_window_expr(phy_expr, df_schema, mid_probe)?;
652
653        match mid_time_window.cmp(&cur_time_window) {
654            std::cmp::Ordering::Less => UnexpectedSnafu {
655                reason: format!("Binary search failed for time window expression {phy_expr:?}"),
656            }
657            .fail()?,
658            std::cmp::Ordering::Equal => low = mid + 1,
659            std::cmp::Ordering::Greater => high = mid,
660        }
661    }
662
663    let final_upper_bound_for_time_window = common_time::Timestamp::new(high, output_unit);
664    Ok(final_upper_bound_for_time_window)
665}
666
667/// Expect the `phy` expression only have one input column with Timestamp type, and also return Timestamp type
668fn eval_phy_time_window_expr(
669    phy: &PhysicalExprRef,
670    df_schema: &DFSchema,
671    input_value: Timestamp,
672) -> Result<Timestamp, Error> {
673    let schema_ty = df_schema.field(0).data_type();
674    let schema_cdt = ConcreteDataType::from_arrow_type(schema_ty);
675    let schema_unit = if let ConcreteDataType::Timestamp(ts) = schema_cdt {
676        ts.unit()
677    } else {
678        return UnexpectedSnafu {
679            reason: format!("Expect Timestamp, found {:?}", schema_cdt),
680        }
681        .fail();
682    };
683    let input_value = input_value
684        .convert_to(schema_unit)
685        .with_context(|| UnexpectedSnafu {
686            reason: format!("Failed to convert timestamp {input_value:?} to {schema_unit}"),
687        })?;
688    let ts_vector = match schema_unit {
689        TimeUnit::Second => {
690            TimestampSecondVector::from_vec(vec![input_value.value()]).to_arrow_array()
691        }
692        TimeUnit::Millisecond => {
693            TimestampMillisecondVector::from_vec(vec![input_value.value()]).to_arrow_array()
694        }
695        TimeUnit::Microsecond => {
696            TimestampMicrosecondVector::from_vec(vec![input_value.value()]).to_arrow_array()
697        }
698        TimeUnit::Nanosecond => {
699            TimestampNanosecondVector::from_vec(vec![input_value.value()]).to_arrow_array()
700        }
701    };
702
703    let rb = DfRecordBatch::try_new(df_schema.inner().clone(), vec![ts_vector.clone()])
704        .with_context(|_| ArrowSnafu {
705            context: format!("Failed to create record batch from {df_schema:?} and {ts_vector:?}"),
706        })?;
707
708    let eval_res = phy.evaluate(&rb).with_context(|_| DatafusionSnafu {
709        context: format!("Failed to evaluate physical expression {phy:?} on {rb:?}"),
710    })?;
711
712    if let Some(Some(ts)) = columnar_to_ts_vector(&eval_res)?.first() {
713        Ok(*ts)
714    } else {
715        UnexpectedSnafu {
716            reason: format!(
717                "Expected timestamp in expression {phy:?} but got {:?}",
718                eval_res
719            ),
720        }
721        .fail()?
722    }
723}
724
725fn to_phy_expr(
726    expr: &Expr,
727    df_schema: &DFSchema,
728    session: &SessionState,
729) -> Result<PhysicalExprRef, Error> {
730    let phy_planner = DefaultPhysicalPlanner::default();
731
732    let phy_expr: PhysicalExprRef = phy_planner
733        .create_physical_expr(expr, df_schema, session)
734        .with_context(|_e| DatafusionSnafu {
735            context: format!(
736                "Failed to create physical expression from {expr:?} using {df_schema:?}"
737            ),
738        })?;
739    Ok(phy_expr)
740}
741
742#[cfg(test)]
743mod test {
744    use datafusion_common::tree_node::TreeNode;
745    use pretty_assertions::assert_eq;
746    use session::context::QueryContext;
747
748    use super::*;
749    use crate::batching_mode::utils::{df_plan_to_sql, sql_to_df_plan, AddFilterRewriter};
750    use crate::test_utils::create_test_query_engine;
751
752    #[tokio::test]
753    async fn test_plan_time_window_lower_bound() {
754        use datafusion_expr::{col, lit};
755        let query_engine = create_test_query_engine();
756        let ctx = QueryContext::arc();
757
758        let testcases = [
759            // same alias is not same column
760            (
761                "SELECT arrow_cast(date_bin(INTERVAL '1 MINS', numbers_with_ts.ts), 'Timestamp(Second, None)') AS ts FROM numbers_with_ts GROUP BY ts;",
762                Timestamp::new(1740394109, TimeUnit::Second),
763                (
764                    "ts".to_string(),
765                    Some(Timestamp::new(1740394109000, TimeUnit::Millisecond)),
766                    Some(Timestamp::new(1740394109001, TimeUnit::Millisecond)),
767                ),
768                r#"SELECT arrow_cast(date_bin(INTERVAL '1 MINS', numbers_with_ts.ts), 'Timestamp(Second, None)') AS ts FROM numbers_with_ts WHERE ((ts >= CAST('2025-02-24 10:48:29' AS TIMESTAMP)) AND (ts <= CAST('2025-02-24 10:48:29.001' AS TIMESTAMP))) GROUP BY numbers_with_ts.ts"#
769            ),
770            // complex time window index
771            (
772                "SELECT arrow_cast(date_bin(INTERVAL '1 MINS', numbers_with_ts.ts), 'Timestamp(Second, None)') AS time_window FROM numbers_with_ts GROUP BY time_window;",
773                Timestamp::new(1740394109, TimeUnit::Second),
774                (
775                    "ts".to_string(),
776                    Some(Timestamp::new(1740394080, TimeUnit::Second)),
777                    Some(Timestamp::new(1740394140, TimeUnit::Second)),
778                ),
779                "SELECT arrow_cast(date_bin(INTERVAL '1 MINS', numbers_with_ts.ts), 'Timestamp(Second, None)') AS time_window FROM numbers_with_ts WHERE ((ts >= CAST('2025-02-24 10:48:00' AS TIMESTAMP)) AND (ts <= CAST('2025-02-24 10:49:00' AS TIMESTAMP))) GROUP BY arrow_cast(date_bin(INTERVAL '1 MINS', numbers_with_ts.ts), 'Timestamp(Second, None)')"
780            ),
781            // complex time window index with where
782            (
783                "SELECT arrow_cast(date_bin(INTERVAL '1 MINS', numbers_with_ts.ts), 'Timestamp(Second, None)') AS time_window FROM numbers_with_ts WHERE number in (2, 3, 4) GROUP BY time_window;",
784                Timestamp::new(1740394109, TimeUnit::Second),
785                (
786                    "ts".to_string(),
787                    Some(Timestamp::new(1740394080, TimeUnit::Second)),
788                    Some(Timestamp::new(1740394140, TimeUnit::Second)),
789                ),
790                "SELECT arrow_cast(date_bin(INTERVAL '1 MINS', numbers_with_ts.ts), 'Timestamp(Second, None)') AS time_window FROM numbers_with_ts WHERE numbers_with_ts.number IN (2, 3, 4) AND ((ts >= CAST('2025-02-24 10:48:00' AS TIMESTAMP)) AND (ts <= CAST('2025-02-24 10:49:00' AS TIMESTAMP))) GROUP BY arrow_cast(date_bin(INTERVAL '1 MINS', numbers_with_ts.ts), 'Timestamp(Second, None)')"
791            ),
792            // complex time window index with between and
793            (
794                "SELECT arrow_cast(date_bin(INTERVAL '1 MINS', numbers_with_ts.ts), 'Timestamp(Second, None)') AS time_window FROM numbers_with_ts WHERE number BETWEEN 2 AND 4 GROUP BY time_window;",
795                Timestamp::new(1740394109, TimeUnit::Second),
796                (
797                    "ts".to_string(),
798                    Some(Timestamp::new(1740394080, TimeUnit::Second)),
799                    Some(Timestamp::new(1740394140, TimeUnit::Second)),
800                ),
801                "SELECT arrow_cast(date_bin(INTERVAL '1 MINS', numbers_with_ts.ts), 'Timestamp(Second, None)') AS time_window FROM numbers_with_ts WHERE (numbers_with_ts.number BETWEEN 2 AND 4) AND ((ts >= CAST('2025-02-24 10:48:00' AS TIMESTAMP)) AND (ts <= CAST('2025-02-24 10:49:00' AS TIMESTAMP))) GROUP BY arrow_cast(date_bin(INTERVAL '1 MINS', numbers_with_ts.ts), 'Timestamp(Second, None)')"
802            ),
803            // no time index
804            (
805                "SELECT date_bin('5 minutes', ts) FROM numbers_with_ts;",
806                Timestamp::new(23, TimeUnit::Millisecond),
807                ("ts".to_string(), None, None),
808                "SELECT date_bin('5 minutes', ts) FROM numbers_with_ts;"
809            ),
810            // time index
811            (
812                "SELECT date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window;",
813                Timestamp::new(23, TimeUnit::Nanosecond),
814                (
815                    "ts".to_string(),
816                    Some(Timestamp::new(0, TimeUnit::Millisecond)),
817                    Some(Timestamp::new(300000, TimeUnit::Millisecond)),
818                ),
819                "SELECT date_bin('5 minutes', numbers_with_ts.ts) AS time_window FROM numbers_with_ts WHERE ((ts >= CAST('1970-01-01 00:00:00' AS TIMESTAMP)) AND (ts <= CAST('1970-01-01 00:05:00' AS TIMESTAMP))) GROUP BY date_bin('5 minutes', numbers_with_ts.ts)"
820            ),
821            // on spot
822            (
823                "SELECT date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window;",
824                Timestamp::new(0, TimeUnit::Nanosecond),
825                (
826                    "ts".to_string(),
827                    Some(Timestamp::new(0, TimeUnit::Millisecond)),
828                    Some(Timestamp::new(300000, TimeUnit::Millisecond)),
829                ),
830                "SELECT date_bin('5 minutes', numbers_with_ts.ts) AS time_window FROM numbers_with_ts WHERE ((ts >= CAST('1970-01-01 00:00:00' AS TIMESTAMP)) AND (ts <= CAST('1970-01-01 00:05:00' AS TIMESTAMP))) GROUP BY date_bin('5 minutes', numbers_with_ts.ts)"
831            ),
832            // different time unit
833            (
834                "SELECT date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window;",
835                Timestamp::new(23_000_000, TimeUnit::Nanosecond),
836                (
837                    "ts".to_string(),
838                    Some(Timestamp::new(0, TimeUnit::Millisecond)),
839                    Some(Timestamp::new(300000, TimeUnit::Millisecond)),
840                ),
841                "SELECT date_bin('5 minutes', numbers_with_ts.ts) AS time_window FROM numbers_with_ts WHERE ((ts >= CAST('1970-01-01 00:00:00' AS TIMESTAMP)) AND (ts <= CAST('1970-01-01 00:05:00' AS TIMESTAMP))) GROUP BY date_bin('5 minutes', numbers_with_ts.ts)"
842            ),
843            // time index with other fields
844            (
845                "SELECT sum(number) as sum_up, date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window;",
846                Timestamp::new(23, TimeUnit::Millisecond),
847                (
848                    "ts".to_string(),
849                    Some(Timestamp::new(0, TimeUnit::Millisecond)),
850                    Some(Timestamp::new(300000, TimeUnit::Millisecond)),
851                ),
852                "SELECT sum(numbers_with_ts.number) AS sum_up, date_bin('5 minutes', numbers_with_ts.ts) AS time_window FROM numbers_with_ts WHERE ((ts >= CAST('1970-01-01 00:00:00' AS TIMESTAMP)) AND (ts <= CAST('1970-01-01 00:05:00' AS TIMESTAMP))) GROUP BY date_bin('5 minutes', numbers_with_ts.ts)"
853            ),
854            // time index with other pks
855            (
856                "SELECT number, date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window, number;",
857                Timestamp::new(23, TimeUnit::Millisecond),
858                (
859                    "ts".to_string(),
860                    Some(Timestamp::new(0, TimeUnit::Millisecond)),
861                    Some(Timestamp::new(300000, TimeUnit::Millisecond)),
862                ),
863                "SELECT numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts) AS time_window FROM numbers_with_ts WHERE ((ts >= CAST('1970-01-01 00:00:00' AS TIMESTAMP)) AND (ts <= CAST('1970-01-01 00:05:00' AS TIMESTAMP))) GROUP BY date_bin('5 minutes', numbers_with_ts.ts), numbers_with_ts.number"
864            ),
865            // subquery
866            (
867                "SELECT number, time_window FROM (SELECT number, date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window, number);",
868                Timestamp::new(23, TimeUnit::Millisecond),
869                (
870                    "ts".to_string(),
871                    Some(Timestamp::new(0, TimeUnit::Millisecond)),
872                    Some(Timestamp::new(300000, TimeUnit::Millisecond)),
873                ),
874                "SELECT numbers_with_ts.number, time_window FROM (SELECT numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts) AS time_window FROM numbers_with_ts WHERE ((ts >= CAST('1970-01-01 00:00:00' AS TIMESTAMP)) AND (ts <= CAST('1970-01-01 00:05:00' AS TIMESTAMP))) GROUP BY date_bin('5 minutes', numbers_with_ts.ts), numbers_with_ts.number)"
875            ),
876            // cte
877            (
878                "with cte as (select number, date_bin('5 minutes', ts) as time_window from numbers_with_ts GROUP BY time_window, number) select number, time_window from cte;",
879                Timestamp::new(23, TimeUnit::Millisecond),
880                (
881                    "ts".to_string(),
882                    Some(Timestamp::new(0, TimeUnit::Millisecond)),
883                    Some(Timestamp::new(300000, TimeUnit::Millisecond)),
884                ),
885                "SELECT cte.number, cte.time_window FROM (SELECT numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts) AS time_window FROM numbers_with_ts WHERE ((ts >= CAST('1970-01-01 00:00:00' AS TIMESTAMP)) AND (ts <= CAST('1970-01-01 00:05:00' AS TIMESTAMP))) GROUP BY date_bin('5 minutes', numbers_with_ts.ts), numbers_with_ts.number) AS cte"
886            ),
887            // complex subquery without alias
888            (
889                "SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) GROUP BY number, time_window, bucket_name;",
890                Timestamp::new(23, TimeUnit::Millisecond),
891                (
892                    "ts".to_string(),
893                    Some(Timestamp::new(0, TimeUnit::Millisecond)),
894                    Some(Timestamp::new(300000, TimeUnit::Millisecond)),
895                ),
896                "SELECT sum(numbers_with_ts.number), numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts) AS time_window, bucket_name FROM (SELECT numbers_with_ts.number, numbers_with_ts.ts, CASE WHEN (numbers_with_ts.number < 5) THEN 'bucket_0_5' WHEN (numbers_with_ts.number >= 5) THEN 'bucket_5_inf' END AS bucket_name FROM numbers_with_ts WHERE ((ts >= CAST('1970-01-01 00:00:00' AS TIMESTAMP)) AND (ts <= CAST('1970-01-01 00:05:00' AS TIMESTAMP)))) GROUP BY numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts), bucket_name"
897            ),
898            // complex subquery alias
899            (
900                "SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) as cte GROUP BY number, time_window, bucket_name;",
901                Timestamp::new(23, TimeUnit::Millisecond),
902                (
903                    "ts".to_string(),
904                    Some(Timestamp::new(0, TimeUnit::Millisecond)),
905                    Some(Timestamp::new(300000, TimeUnit::Millisecond)),
906                ),
907                "SELECT sum(cte.number), cte.number, date_bin('5 minutes', cte.ts) AS time_window, cte.bucket_name FROM (SELECT numbers_with_ts.number, numbers_with_ts.ts, CASE WHEN (numbers_with_ts.number < 5) THEN 'bucket_0_5' WHEN (numbers_with_ts.number >= 5) THEN 'bucket_5_inf' END AS bucket_name FROM numbers_with_ts WHERE ((ts >= CAST('1970-01-01 00:00:00' AS TIMESTAMP)) AND (ts <= CAST('1970-01-01 00:05:00' AS TIMESTAMP)))) AS cte GROUP BY cte.number, date_bin('5 minutes', cte.ts), cte.bucket_name"
908            ),
909        ];
910
911        for (sql, current, expected, expected_unparsed) in testcases {
912            let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, true)
913                .await
914                .unwrap();
915
916            let real =
917                find_plan_time_window_bound(&plan, current, ctx.clone(), query_engine.clone())
918                    .await
919                    .unwrap();
920            assert_eq!(expected, real);
921
922            let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, false)
923                .await
924                .unwrap();
925            let (col_name, lower, upper) = real;
926            let new_sql = if lower.is_some() {
927                let to_df_literal = |value| {
928                    let value = Value::from(value);
929
930                    value.try_to_scalar_value(&value.data_type()).unwrap()
931                };
932                let lower = to_df_literal(lower.unwrap());
933                let upper = to_df_literal(upper.unwrap());
934                let expr = col(&col_name)
935                    .gt_eq(lit(lower))
936                    .and(col(&col_name).lt_eq(lit(upper)));
937                let mut add_filter = AddFilterRewriter::new(expr);
938                let plan = plan.rewrite(&mut add_filter).unwrap().data;
939                df_plan_to_sql(&plan).unwrap()
940            } else {
941                sql.to_string()
942            };
943            assert_eq!(expected_unparsed, new_sql);
944        }
945    }
946}