flow/batching_mode/
time_window.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Time window expr and helper functions
16//!
17
18use std::collections::BTreeSet;
19use std::sync::Arc;
20
21use api::helper::pb_value_to_value_ref;
22use arrow::array::{
23    TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
24    TimestampSecondArray,
25};
26use catalog::CatalogManagerRef;
27use common_error::ext::BoxedError;
28use common_recordbatch::DfRecordBatch;
29use common_telemetry::warn;
30use common_time::Timestamp;
31use common_time::timestamp::TimeUnit;
32use datafusion::error::Result as DfResult;
33use datafusion::execution::SessionState;
34use datafusion::logical_expr::Expr;
35use datafusion::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner};
36use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter};
37use datafusion_common::{DFSchema, TableReference};
38use datafusion_expr::{ColumnarValue, LogicalPlan};
39use datafusion_physical_expr::PhysicalExprRef;
40use datatypes::prelude::{ConcreteDataType, DataType};
41use datatypes::schema::TIME_INDEX_KEY;
42use datatypes::value::Value;
43use datatypes::vectors::{
44    TimestampMicrosecondVector, TimestampMillisecondVector, TimestampNanosecondVector,
45    TimestampSecondVector, Vector,
46};
47use itertools::Itertools;
48use session::context::QueryContextRef;
49use snafu::{OptionExt, ResultExt, ensure};
50
51use crate::Error;
52use crate::adapter::util::from_proto_to_data_type;
53use crate::error::{
54    ArrowSnafu, DatafusionSnafu, DatatypesSnafu, ExternalSnafu, PlanSnafu, TimeSnafu,
55    UnexpectedSnafu,
56};
57use crate::expr::error::DataTypeSnafu;
58
59/// Represents a test timestamp in seconds since the Unix epoch.
60const DEFAULT_TEST_TIMESTAMP: Timestamp = Timestamp::new_second(17_0000_0000);
61
/// Time window expr like `date_bin(INTERVAL '1' MINUTE, ts)`; this type helps with
/// evaluating the expr using a given timestamp.
///
/// The time window expr must satisfy the following conditions:
/// 1. The expr must be monotonic non-decreasing
/// 2. The expr must have one and only one input column with timestamp type, and the output column must be timestamp type
/// 3. The expr must be deterministic
///
/// An example of time window expr is `date_bin(INTERVAL '1' MINUTE, ts)`
#[derive(Debug, Clone)]
pub struct TimeWindowExpr {
    /// Physical expr planned from `logical_expr` against `df_schema`.
    phy_expr: PhysicalExprRef,
    /// Name of the column that is looked up in incoming rows and fed to the expr.
    pub column_name: String,
    /// The logical expr this time window was built from.
    logical_expr: Expr,
    /// Schema the expr was planned with; also used when evaluating it on a single timestamp.
    df_schema: DFSchema,
    /// Window size measured by evaluating the expr once at construction; `None` if it could
    /// not be determined.
    eval_time_window_size: Option<std::time::Duration>,
    /// Lower bound returned by that construction-time evaluation; `None` if there was none.
    eval_time_original: Option<Timestamp>,
}
80
81impl std::fmt::Display for TimeWindowExpr {
82    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83        f.debug_struct("TimeWindowExpr")
84            .field("phy_expr", &self.phy_expr.to_string())
85            .field("column_name", &self.column_name)
86            .field("logical_expr", &self.logical_expr.to_string())
87            .field("df_schema", &self.df_schema)
88            .finish()
89    }
90}
91
92impl TimeWindowExpr {
93    /// The time window size of the expr, get from calling `eval` with a test timestamp
94    pub fn time_window_size(&self) -> &Option<std::time::Duration> {
95        &self.eval_time_window_size
96    }
97
98    pub fn from_expr(
99        expr: &Expr,
100        column_name: &str,
101        df_schema: &DFSchema,
102        session: &SessionState,
103    ) -> Result<Self, Error> {
104        let phy_expr: PhysicalExprRef = to_phy_expr(expr, df_schema, session)?;
105        let mut zelf = Self {
106            phy_expr,
107            column_name: column_name.to_string(),
108            logical_expr: expr.clone(),
109            df_schema: df_schema.clone(),
110            eval_time_window_size: None,
111            eval_time_original: None,
112        };
113        let test_ts = DEFAULT_TEST_TIMESTAMP;
114        let (lower, upper) = zelf.eval(test_ts)?;
115        let time_window_size = match (lower, upper) {
116            (Some(l), Some(u)) => u.sub(&l).map(|r| r.to_std()).transpose().map_err(|_| {
117                UnexpectedSnafu {
118                    reason: format!(
119                        "Expect upper bound older than lower bound, found upper={u:?} and lower={l:?}"
120                    ),
121                }
122                .build()
123            })?,
124            _ => None,
125        };
126        zelf.eval_time_window_size = time_window_size;
127        zelf.eval_time_original = lower;
128
129        Ok(zelf)
130    }
131
132    /// TODO(discord9): add `eval_batch` too
133    pub fn eval(
134        &self,
135        current: Timestamp,
136    ) -> Result<(Option<Timestamp>, Option<Timestamp>), Error> {
137        fn compute_distance(time_diff_ns: i64, stride_ns: i64) -> i64 {
138            if stride_ns == 0 {
139                return time_diff_ns;
140            }
141            // a - (a % n) impl ceil to nearest n * stride
142            let time_delta = time_diff_ns - (time_diff_ns % stride_ns);
143
144            if time_diff_ns < 0 && time_delta != time_diff_ns {
145                // The origin is later than the source timestamp, round down to the previous bin
146
147                time_delta - stride_ns
148            } else {
149                time_delta
150            }
151        }
152
153        // FAST PATH: if we have eval_time_original and eval_time_window_size,
154        // we can compute the bounds directly
155        if let (Some(original), Some(window_size)) =
156            (self.eval_time_original, self.eval_time_window_size)
157        {
158            // date_bin align current to lower bound
159            let time_diff_ns = current.sub(&original).and_then(|s|s.num_nanoseconds()).with_context(||UnexpectedSnafu {
160                reason: format!(
161                    "Failed to compute time difference between current {current:?} and original {original:?}"
162                ),
163            })?;
164
165            let window_size_ns = window_size.as_nanos() as i64;
166
167            let distance_ns = compute_distance(time_diff_ns, window_size_ns);
168
169            let lower_bound = if distance_ns >= 0 {
170                original.add_duration(std::time::Duration::from_nanos(distance_ns as u64))
171            } else {
172                original.sub_duration(std::time::Duration::from_nanos((-distance_ns) as u64))
173            }
174            .context(TimeSnafu)?;
175            let upper_bound = lower_bound.add_duration(window_size).context(TimeSnafu)?;
176
177            return Ok((Some(lower_bound), Some(upper_bound)));
178        }
179
180        let lower_bound =
181            calc_expr_time_window_lower_bound(&self.phy_expr, &self.df_schema, current)?;
182        let upper_bound =
183            probe_expr_time_window_upper_bound(&self.phy_expr, &self.df_schema, current)?;
184        Ok((lower_bound, upper_bound))
185    }
186
187    /// Find timestamps from rows using time window expr
188    ///
189    /// use column of name `self.column_name` from input rows list as input to time window expr
190    pub async fn handle_rows(
191        &self,
192        rows_list: Vec<api::v1::Rows>,
193    ) -> Result<BTreeSet<Timestamp>, Error> {
194        let mut time_windows = BTreeSet::new();
195
196        for rows in rows_list {
197            // pick the time index column and use it to eval on `self.expr`
198            // TODO(discord9): handle case where time index column is not present(i.e. DEFAULT constant value)
199            let ts_col_index = rows
200                .schema
201                .iter()
202                .map(|col| col.column_name.clone())
203                .position(|name| name == self.column_name);
204            let Some(ts_col_index) = ts_col_index else {
205                warn!("can't found time index column in schema: {:?}", rows.schema);
206                continue;
207            };
208            let col_schema = &rows.schema[ts_col_index];
209            let cdt = from_proto_to_data_type(col_schema)?;
210
211            let mut vector = cdt.create_mutable_vector(rows.rows.len());
212            for row in rows.rows {
213                let value = pb_value_to_value_ref(&row.values[ts_col_index], &None);
214                vector.try_push_value_ref(value).context(DataTypeSnafu {
215                    msg: "Failed to convert rows to columns",
216                })?;
217            }
218            let vector = vector.to_vector();
219
220            let df_schema = create_df_schema_for_ts_column(&self.column_name, cdt)?;
221
222            let rb =
223                DfRecordBatch::try_new(df_schema.inner().clone(), vec![vector.to_arrow_array()])
224                    .with_context(|_e| ArrowSnafu {
225                        context: format!(
226                            "Failed to create record batch from {df_schema:?} and {vector:?}"
227                        ),
228                    })?;
229
230            let eval_res = self
231                .phy_expr
232                .evaluate(&rb)
233                .with_context(|_| DatafusionSnafu {
234                    context: format!(
235                        "Failed to evaluate physical expression {:?} on {rb:?}",
236                        self.phy_expr
237                    ),
238                })?;
239
240            let res = columnar_to_ts_vector(&eval_res)?;
241
242            for ts in res.into_iter().flatten() {
243                time_windows.insert(ts);
244            }
245        }
246
247        Ok(time_windows)
248    }
249}
250
251fn create_df_schema_for_ts_column(name: &str, cdt: ConcreteDataType) -> Result<DFSchema, Error> {
252    let arrow_schema = Arc::new(arrow_schema::Schema::new(vec![arrow_schema::Field::new(
253        name,
254        cdt.as_arrow_type(),
255        false,
256    )]));
257
258    let df_schema = DFSchema::from_field_specific_qualified_schema(
259        vec![Some(TableReference::bare("TimeIndexOnlyTable"))],
260        &arrow_schema,
261    )
262    .with_context(|_e| DatafusionSnafu {
263        context: format!("Failed to create DFSchema from arrow schema {arrow_schema:?}"),
264    })?;
265
266    Ok(df_schema)
267}
268
269/// Convert `ColumnarValue` to `Vec<Option<Timestamp>>`
270fn columnar_to_ts_vector(columnar: &ColumnarValue) -> Result<Vec<Option<Timestamp>>, Error> {
271    let val = match columnar {
272        datafusion_expr::ColumnarValue::Array(array) => {
273            let ty = array.data_type();
274            let ty = ConcreteDataType::from_arrow_type(ty);
275            let time_unit = if let ConcreteDataType::Timestamp(ty) = ty {
276                ty.unit()
277            } else {
278                return UnexpectedSnafu {
279                    reason: format!("Non-timestamp type: {ty:?}"),
280                }
281                .fail();
282            };
283
284            match time_unit {
285                TimeUnit::Second => array
286                    .as_ref()
287                    .as_any()
288                    .downcast_ref::<TimestampSecondArray>()
289                    .with_context(|| PlanSnafu {
290                        reason: format!("Failed to create vector from arrow array {array:?}"),
291                    })?
292                    .values()
293                    .iter()
294                    .map(|d| Some(Timestamp::new(*d, time_unit)))
295                    .collect_vec(),
296                TimeUnit::Millisecond => array
297                    .as_ref()
298                    .as_any()
299                    .downcast_ref::<TimestampMillisecondArray>()
300                    .with_context(|| PlanSnafu {
301                        reason: format!("Failed to create vector from arrow array {array:?}"),
302                    })?
303                    .values()
304                    .iter()
305                    .map(|d| Some(Timestamp::new(*d, time_unit)))
306                    .collect_vec(),
307                TimeUnit::Microsecond => array
308                    .as_ref()
309                    .as_any()
310                    .downcast_ref::<TimestampMicrosecondArray>()
311                    .with_context(|| PlanSnafu {
312                        reason: format!("Failed to create vector from arrow array {array:?}"),
313                    })?
314                    .values()
315                    .iter()
316                    .map(|d| Some(Timestamp::new(*d, time_unit)))
317                    .collect_vec(),
318                TimeUnit::Nanosecond => array
319                    .as_ref()
320                    .as_any()
321                    .downcast_ref::<TimestampNanosecondArray>()
322                    .with_context(|| PlanSnafu {
323                        reason: format!("Failed to create vector from arrow array {array:?}"),
324                    })?
325                    .values()
326                    .iter()
327                    .map(|d| Some(Timestamp::new(*d, time_unit)))
328                    .collect_vec(),
329            }
330        }
331        datafusion_expr::ColumnarValue::Scalar(scalar) => {
332            let value = Value::try_from(scalar.clone()).with_context(|_| DatatypesSnafu {
333                extra: format!("Failed to convert scalar {scalar:?} to value"),
334            })?;
335            let ts = value.as_timestamp().context(UnexpectedSnafu {
336                reason: format!("Expect Timestamp, found {:?}", value),
337            })?;
338            vec![Some(ts)]
339        }
340    };
341    Ok(val)
342}
343
344/// Return (`the column name of time index column`, `the time window expr`, `the expected time unit of time index column`, `the expr's schema for evaluating the time window`)
345///
346/// The time window expr is expected to have one input column with Timestamp type, and also return Timestamp type, the time window expr is expected
347/// to be monotonic increasing and appears in the innermost GROUP BY clause
348///
349/// note this plan should only contain one TableScan
350pub async fn find_time_window_expr(
351    plan: &LogicalPlan,
352    catalog_man: CatalogManagerRef,
353    query_ctx: QueryContextRef,
354) -> Result<(String, Option<datafusion_expr::Expr>, TimeUnit, DFSchema), Error> {
355    // TODO(discord9): find the expr that do time window
356
357    let mut table_name = None;
358
359    // first find the table source in the logical plan
360    plan.apply(|plan| {
361        let LogicalPlan::TableScan(table_scan) = plan else {
362            return Ok(TreeNodeRecursion::Continue);
363        };
364        table_name = Some(table_scan.table_name.clone());
365        Ok(TreeNodeRecursion::Stop)
366    })
367    .with_context(|_| DatafusionSnafu {
368        context: format!("Can't find table source in plan {plan:?}"),
369    })?;
370    let Some(table_name) = table_name else {
371        UnexpectedSnafu {
372            reason: format!("Can't find table source in plan {plan:?}"),
373        }
374        .fail()?
375    };
376
377    let current_schema = query_ctx.current_schema();
378
379    let catalog_name = table_name.catalog().unwrap_or(query_ctx.current_catalog());
380    let schema_name = table_name.schema().unwrap_or(&current_schema);
381    let table_name = table_name.table();
382
383    let Some(table_ref) = catalog_man
384        .table(catalog_name, schema_name, table_name, Some(&query_ctx))
385        .await
386        .map_err(BoxedError::new)
387        .context(ExternalSnafu)?
388    else {
389        UnexpectedSnafu {
390            reason: format!(
391                "Can't find table {table_name:?} in catalog {catalog_name:?}/{schema_name:?}"
392            ),
393        }
394        .fail()?
395    };
396
397    let schema = &table_ref.table_info().meta.schema;
398
399    let ts_index = schema.timestamp_column().with_context(|| UnexpectedSnafu {
400        reason: format!("Can't find timestamp column in table {table_name:?}"),
401    })?;
402
403    let ts_col_name = ts_index.name.clone();
404
405    let expected_time_unit = ts_index.data_type.as_timestamp().with_context(|| UnexpectedSnafu {
406        reason: format!(
407            "Expected timestamp column {ts_col_name:?} in table {table_name:?} to be timestamp, but got {ts_index:?}"
408        ),
409    })?.unit();
410
411    let arrow_schema = Arc::new(arrow_schema::Schema::new(vec![arrow_schema::Field::new(
412        ts_col_name.clone(),
413        ts_index.data_type.as_arrow_type(),
414        false,
415    )]));
416
417    let df_schema = DFSchema::from_field_specific_qualified_schema(
418        vec![Some(TableReference::bare(table_name))],
419        &arrow_schema,
420    )
421    .with_context(|_e| DatafusionSnafu {
422        context: format!("Failed to create DFSchema from arrow schema {arrow_schema:?}"),
423    })?;
424
425    // find the time window expr which refers to the time index column
426    let mut aggr_expr = None;
427    let mut time_window_expr: Option<Expr> = None;
428
429    let find_inner_aggr_expr = |plan: &LogicalPlan| {
430        if let LogicalPlan::Aggregate(aggregate) = plan {
431            aggr_expr = Some(aggregate.clone());
432        };
433
434        Ok(TreeNodeRecursion::Continue)
435    };
436    plan.apply(find_inner_aggr_expr)
437        .with_context(|_| DatafusionSnafu {
438            context: format!("Can't find aggr expr in plan {plan:?}"),
439        })?;
440
441    if let Some(aggregate) = aggr_expr {
442        for group_expr in &aggregate.group_expr {
443            let refs = group_expr.column_refs();
444            if refs.len() != 1 {
445                continue;
446            }
447            let ref_col = refs.iter().next().unwrap();
448
449            let index = aggregate.input.schema().maybe_index_of_column(ref_col);
450            let Some(index) = index else {
451                continue;
452            };
453            let field = aggregate.input.schema().field(index);
454
455            // TODO(discord9): need to ensure the field has the meta key for the time index
456            let is_time_index =
457                field.metadata().get(TIME_INDEX_KEY).map(|s| s.as_str()) == Some("true");
458
459            if is_time_index {
460                let rewrite_column = group_expr.clone();
461                let rewritten = rewrite_column
462                    .rewrite(&mut RewriteColumn {
463                        table_name: table_name.to_string(),
464                    })
465                    .with_context(|_| DatafusionSnafu {
466                        context: format!("Rewrite expr failed, expr={:?}", group_expr),
467                    })?
468                    .data;
469                struct RewriteColumn {
470                    table_name: String,
471                }
472
473                impl TreeNodeRewriter for RewriteColumn {
474                    type Node = Expr;
475                    fn f_down(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
476                        let Expr::Column(mut column) = node else {
477                            return Ok(Transformed::no(node));
478                        };
479
480                        column.relation = Some(TableReference::bare(self.table_name.clone()));
481
482                        Ok(Transformed::yes(Expr::Column(column)))
483                    }
484                }
485
486                time_window_expr = Some(rewritten);
487                break;
488            }
489        }
490        Ok((ts_col_name, time_window_expr, expected_time_unit, df_schema))
491    } else {
492        // can't found time window expr, return None
493        Ok((ts_col_name, None, expected_time_unit, df_schema))
494    }
495}
496
/// Find the bounds of the time window that `current` falls into for the time window expr
/// of the given `plan`.
///
/// i.e. for time window expr being `date_bin(INTERVAL '5 minutes', ts) as time_window` and `current="2021-07-01 00:01:01.000"`,
/// the lower bound is `Some("2021-07-01 00:00:00.000")`.
/// If no time window expr is found in `plan`, both bounds are `None`.
///
/// Returns (`name of the time index column`, `lower bound`, `upper bound`).
///
/// Time window expr is a expr that:
/// 1. ref only to a time index column
/// 2. is monotonic increasing
/// 3. show up in GROUP BY clause
///
/// note this plan should only contain one TableScan
#[cfg(test)]
pub async fn find_plan_time_window_bound(
    plan: &LogicalPlan,
    current: Timestamp,
    query_ctx: QueryContextRef,
    engine: query::QueryEngineRef,
) -> Result<(String, Option<Timestamp>, Option<Timestamp>), Error> {
    let catalog_man = engine.engine_state().catalog_manager();

    let (ts_col_name, time_window_expr, expected_time_unit, df_schema) =
        find_time_window_expr(plan, catalog_man.clone(), query_ctx).await?;

    // bring `current` to the unit of the time index column before evaluating
    let new_current = current
        .convert_to(expected_time_unit)
        .with_context(|| UnexpectedSnafu {
            reason: format!("Failed to cast current timestamp {current:?} to {expected_time_unit}"),
        })?;

    // without a time window expr there are no bounds to report
    let Some(time_window_expr) = time_window_expr else {
        return Ok((ts_col_name, None, None));
    };

    let phy_expr = to_phy_expr(
        &time_window_expr,
        &df_schema,
        &engine.engine_state().session_state(),
    )?;
    let lower_bound = calc_expr_time_window_lower_bound(&phy_expr, &df_schema, new_current)?;
    let upper_bound = probe_expr_time_window_upper_bound(&phy_expr, &df_schema, new_current)?;

    Ok((ts_col_name, lower_bound, upper_bound))
}
541
542/// Find the lower bound of time window in given `expr` and `current` timestamp.
543///
544/// i.e. for `current="2021-07-01 00:01:01.000"` and `expr=date_bin(INTERVAL '5 minutes', ts) as time_window` and `ts_col=ts`,
545/// return `Some("2021-07-01 00:00:00.000")` since it's the lower bound
546/// return `Some("2021-07-01 00:00:00.000")` since it's the lower bound
547/// of current time window given the current timestamp
548///
549/// if return None, meaning this time window have no lower bound
550fn calc_expr_time_window_lower_bound(
551    phy_expr: &PhysicalExprRef,
552    df_schema: &DFSchema,
553    current: Timestamp,
554) -> Result<Option<Timestamp>, Error> {
555    let cur_time_window = eval_phy_time_window_expr(phy_expr, df_schema, current)?;
556    let input_time_unit = cur_time_window.unit();
557    Ok(cur_time_window.convert_to(input_time_unit))
558}
559
560/// Probe for the upper bound for time window expression
561fn probe_expr_time_window_upper_bound(
562    phy_expr: &PhysicalExprRef,
563    df_schema: &DFSchema,
564    current: Timestamp,
565) -> Result<Option<Timestamp>, Error> {
566    // TODO(discord9): special handling `date_bin` for faster path
567    use std::cmp::Ordering;
568
569    let cur_time_window = eval_phy_time_window_expr(phy_expr, df_schema, current)?;
570
571    // search to find the lower bound
572    let mut offset: i64 = 1;
573    let mut lower_bound = Some(current);
574    let upper_bound;
575    // first expontial probe to found a range for binary search
576    loop {
577        let Some(next_val) = current.value().checked_add(offset) else {
578            // no upper bound if overflow, which is ok
579            return Ok(None);
580        };
581
582        let next_time_probe = common_time::Timestamp::new(next_val, current.unit());
583
584        let next_time_window = eval_phy_time_window_expr(phy_expr, df_schema, next_time_probe)?;
585
586        match next_time_window.cmp(&cur_time_window) {
587            Ordering::Less => UnexpectedSnafu {
588                    reason: format!(
589                        "Unsupported time window expression, expect monotonic increasing for time window expression {phy_expr:?}"
590                    ),
591                }
592                .fail()?,
593            Ordering::Equal => {
594                lower_bound = Some(next_time_probe);
595            }
596            Ordering::Greater => {
597                upper_bound = Some(next_time_probe);
598                break
599            }
600        }
601
602        let Some(new_offset) = offset.checked_mul(2) else {
603            // no upper bound if overflow
604            return Ok(None);
605        };
606        offset = new_offset;
607    }
608
609    // binary search for the exact upper bound
610
611    binary_search_expr(
612        lower_bound,
613        upper_bound,
614        cur_time_window,
615        phy_expr,
616        df_schema,
617    )
618    .map(Some)
619}
620
621fn binary_search_expr(
622    lower_bound: Option<Timestamp>,
623    upper_bound: Option<Timestamp>,
624    cur_time_window: Timestamp,
625    phy_expr: &PhysicalExprRef,
626    df_schema: &DFSchema,
627) -> Result<Timestamp, Error> {
628    ensure!(
629        lower_bound.map(|v| v.unit()) == upper_bound.map(|v| v.unit()),
630        UnexpectedSnafu {
631            reason: format!(
632                " unit mismatch for time window expression {phy_expr:?}, found {lower_bound:?} and {upper_bound:?}"
633            ),
634        }
635    );
636
637    let output_unit = upper_bound
638        .context(UnexpectedSnafu {
639            reason: "should have lower bound",
640        })?
641        .unit();
642
643    let mut low = lower_bound
644        .context(UnexpectedSnafu {
645            reason: "should have lower bound",
646        })?
647        .value();
648    let mut high = upper_bound
649        .context(UnexpectedSnafu {
650            reason: "should have upper bound",
651        })?
652        .value();
653    while low < high {
654        let mid = (low + high) / 2;
655        let mid_probe = common_time::Timestamp::new(mid, output_unit);
656        let mid_time_window = eval_phy_time_window_expr(phy_expr, df_schema, mid_probe)?;
657
658        match mid_time_window.cmp(&cur_time_window) {
659            std::cmp::Ordering::Less => UnexpectedSnafu {
660                reason: format!("Binary search failed for time window expression {phy_expr:?}"),
661            }
662            .fail()?,
663            std::cmp::Ordering::Equal => low = mid + 1,
664            std::cmp::Ordering::Greater => high = mid,
665        }
666    }
667
668    let final_upper_bound_for_time_window = common_time::Timestamp::new(high, output_unit);
669    Ok(final_upper_bound_for_time_window)
670}
671
672/// Expect the `phy` expression only have one input column with Timestamp type, and also return Timestamp type
673fn eval_phy_time_window_expr(
674    phy: &PhysicalExprRef,
675    df_schema: &DFSchema,
676    input_value: Timestamp,
677) -> Result<Timestamp, Error> {
678    let schema_ty = df_schema.field(0).data_type();
679    let schema_cdt = ConcreteDataType::from_arrow_type(schema_ty);
680    let schema_unit = if let ConcreteDataType::Timestamp(ts) = schema_cdt {
681        ts.unit()
682    } else {
683        return UnexpectedSnafu {
684            reason: format!("Expect Timestamp, found {:?}", schema_cdt),
685        }
686        .fail();
687    };
688    let input_value = input_value
689        .convert_to(schema_unit)
690        .with_context(|| UnexpectedSnafu {
691            reason: format!("Failed to convert timestamp {input_value:?} to {schema_unit}"),
692        })?;
693    let ts_vector = match schema_unit {
694        TimeUnit::Second => {
695            TimestampSecondVector::from_vec(vec![input_value.value()]).to_arrow_array()
696        }
697        TimeUnit::Millisecond => {
698            TimestampMillisecondVector::from_vec(vec![input_value.value()]).to_arrow_array()
699        }
700        TimeUnit::Microsecond => {
701            TimestampMicrosecondVector::from_vec(vec![input_value.value()]).to_arrow_array()
702        }
703        TimeUnit::Nanosecond => {
704            TimestampNanosecondVector::from_vec(vec![input_value.value()]).to_arrow_array()
705        }
706    };
707
708    let rb = DfRecordBatch::try_new(df_schema.inner().clone(), vec![ts_vector.clone()])
709        .with_context(|_| ArrowSnafu {
710            context: format!("Failed to create record batch from {df_schema:?} and {ts_vector:?}"),
711        })?;
712
713    let eval_res = phy.evaluate(&rb).with_context(|_| DatafusionSnafu {
714        context: format!("Failed to evaluate physical expression {phy:?} on {rb:?}"),
715    })?;
716
717    if let Some(Some(ts)) = columnar_to_ts_vector(&eval_res)?.first() {
718        Ok(*ts)
719    } else {
720        UnexpectedSnafu {
721            reason: format!(
722                "Expected timestamp in expression {phy:?} but got {:?}",
723                eval_res
724            ),
725        }
726        .fail()?
727    }
728}
729
730fn to_phy_expr(
731    expr: &Expr,
732    df_schema: &DFSchema,
733    session: &SessionState,
734) -> Result<PhysicalExprRef, Error> {
735    let phy_planner = DefaultPhysicalPlanner::default();
736
737    let phy_expr: PhysicalExprRef = phy_planner
738        .create_physical_expr(expr, df_schema, session)
739        .with_context(|_e| DatafusionSnafu {
740            context: format!(
741                "Failed to create physical expression from {expr:?} using {df_schema:?}"
742            ),
743        })?;
744    Ok(phy_expr)
745}
746
#[cfg(test)]
mod test {
    use datafusion_common::tree_node::TreeNode;
    use pretty_assertions::assert_eq;
    use session::context::QueryContext;

    use super::*;
    use crate::batching_mode::utils::{AddFilterRewriter, df_plan_to_sql, sql_to_df_plan};
    use crate::test_utils::create_test_query_engine;

    /// Table-driven test for [`find_plan_time_window_bound`].
    ///
    /// Each case is a tuple of:
    /// 1. the input SQL,
    /// 2. the "current" timestamp passed to `find_plan_time_window_bound`,
    /// 3. the expected result `(column name, lower bound, upper bound)`,
    /// 4. the expected SQL after a `col >= lower AND col <= upper` filter built from that
    ///    result is added with [`AddFilterRewriter`] and the plan is unparsed again. When no
    ///    lower bound is found, the input SQL is expected back unchanged.
    #[tokio::test]
    async fn test_plan_time_window_lower_bound() {
        use datafusion_expr::{col, lit};
        let query_engine = create_test_query_engine();
        let ctx = QueryContext::arc();

        let testcases = [
            // same alias is not same column
            (
                "SELECT arrow_cast(date_bin(INTERVAL '1 MINS', numbers_with_ts.ts), 'Timestamp(Second, None)') AS ts FROM numbers_with_ts GROUP BY ts;",
                Timestamp::new(1740394109, TimeUnit::Second),
                (
                    "ts".to_string(),
                    Some(Timestamp::new(1740394109000, TimeUnit::Millisecond)),
                    Some(Timestamp::new(1740394109001, TimeUnit::Millisecond)),
                ),
                r#"SELECT arrow_cast(date_bin(INTERVAL '1 MINS', numbers_with_ts.ts), 'Timestamp(Second, None)') AS ts FROM numbers_with_ts WHERE ((ts >= CAST('2025-02-24 10:48:29' AS TIMESTAMP)) AND (ts <= CAST('2025-02-24 10:48:29.001' AS TIMESTAMP))) GROUP BY numbers_with_ts.ts"#,
            ),
            // complex time window index
            (
                "SELECT arrow_cast(date_bin(INTERVAL '1 MINS', numbers_with_ts.ts), 'Timestamp(Second, None)') AS time_window FROM numbers_with_ts GROUP BY time_window;",
                Timestamp::new(1740394109, TimeUnit::Second),
                (
                    "ts".to_string(),
                    Some(Timestamp::new(1740394080, TimeUnit::Second)),
                    Some(Timestamp::new(1740394140, TimeUnit::Second)),
                ),
                "SELECT arrow_cast(date_bin(INTERVAL '1 MINS', numbers_with_ts.ts), 'Timestamp(Second, None)') AS time_window FROM numbers_with_ts WHERE ((ts >= CAST('2025-02-24 10:48:00' AS TIMESTAMP)) AND (ts <= CAST('2025-02-24 10:49:00' AS TIMESTAMP))) GROUP BY arrow_cast(date_bin(INTERVAL '1 MINS', numbers_with_ts.ts), 'Timestamp(Second, None)')",
            ),
            // complex time window index with where
            (
                "SELECT arrow_cast(date_bin(INTERVAL '1 MINS', numbers_with_ts.ts), 'Timestamp(Second, None)') AS time_window FROM numbers_with_ts WHERE number in (2, 3, 4) GROUP BY time_window;",
                Timestamp::new(1740394109, TimeUnit::Second),
                (
                    "ts".to_string(),
                    Some(Timestamp::new(1740394080, TimeUnit::Second)),
                    Some(Timestamp::new(1740394140, TimeUnit::Second)),
                ),
                "SELECT arrow_cast(date_bin(INTERVAL '1 MINS', numbers_with_ts.ts), 'Timestamp(Second, None)') AS time_window FROM numbers_with_ts WHERE numbers_with_ts.number IN (2, 3, 4) AND ((ts >= CAST('2025-02-24 10:48:00' AS TIMESTAMP)) AND (ts <= CAST('2025-02-24 10:49:00' AS TIMESTAMP))) GROUP BY arrow_cast(date_bin(INTERVAL '1 MINS', numbers_with_ts.ts), 'Timestamp(Second, None)')",
            ),
            // complex time window index with between and
            (
                "SELECT arrow_cast(date_bin(INTERVAL '1 MINS', numbers_with_ts.ts), 'Timestamp(Second, None)') AS time_window FROM numbers_with_ts WHERE number BETWEEN 2 AND 4 GROUP BY time_window;",
                Timestamp::new(1740394109, TimeUnit::Second),
                (
                    "ts".to_string(),
                    Some(Timestamp::new(1740394080, TimeUnit::Second)),
                    Some(Timestamp::new(1740394140, TimeUnit::Second)),
                ),
                "SELECT arrow_cast(date_bin(INTERVAL '1 MINS', numbers_with_ts.ts), 'Timestamp(Second, None)') AS time_window FROM numbers_with_ts WHERE (numbers_with_ts.number BETWEEN 2 AND 4) AND ((ts >= CAST('2025-02-24 10:48:00' AS TIMESTAMP)) AND (ts <= CAST('2025-02-24 10:49:00' AS TIMESTAMP))) GROUP BY arrow_cast(date_bin(INTERVAL '1 MINS', numbers_with_ts.ts), 'Timestamp(Second, None)')",
            ),
            // no time index
            (
                "SELECT date_bin('5 minutes', ts) FROM numbers_with_ts;",
                Timestamp::new(23, TimeUnit::Millisecond),
                ("ts".to_string(), None, None),
                "SELECT date_bin('5 minutes', ts) FROM numbers_with_ts;",
            ),
            // time index
            (
                "SELECT date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window;",
                Timestamp::new(23, TimeUnit::Nanosecond),
                (
                    "ts".to_string(),
                    Some(Timestamp::new(0, TimeUnit::Millisecond)),
                    Some(Timestamp::new(300000, TimeUnit::Millisecond)),
                ),
                "SELECT date_bin('5 minutes', numbers_with_ts.ts) AS time_window FROM numbers_with_ts WHERE ((ts >= CAST('1970-01-01 00:00:00' AS TIMESTAMP)) AND (ts <= CAST('1970-01-01 00:05:00' AS TIMESTAMP))) GROUP BY date_bin('5 minutes', numbers_with_ts.ts)",
            ),
            // on spot
            (
                "SELECT date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window;",
                Timestamp::new(0, TimeUnit::Nanosecond),
                (
                    "ts".to_string(),
                    Some(Timestamp::new(0, TimeUnit::Millisecond)),
                    Some(Timestamp::new(300000, TimeUnit::Millisecond)),
                ),
                "SELECT date_bin('5 minutes', numbers_with_ts.ts) AS time_window FROM numbers_with_ts WHERE ((ts >= CAST('1970-01-01 00:00:00' AS TIMESTAMP)) AND (ts <= CAST('1970-01-01 00:05:00' AS TIMESTAMP))) GROUP BY date_bin('5 minutes', numbers_with_ts.ts)",
            ),
            // different time unit
            (
                "SELECT date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window;",
                Timestamp::new(23_000_000, TimeUnit::Nanosecond),
                (
                    "ts".to_string(),
                    Some(Timestamp::new(0, TimeUnit::Millisecond)),
                    Some(Timestamp::new(300000, TimeUnit::Millisecond)),
                ),
                "SELECT date_bin('5 minutes', numbers_with_ts.ts) AS time_window FROM numbers_with_ts WHERE ((ts >= CAST('1970-01-01 00:00:00' AS TIMESTAMP)) AND (ts <= CAST('1970-01-01 00:05:00' AS TIMESTAMP))) GROUP BY date_bin('5 minutes', numbers_with_ts.ts)",
            ),
            // time index with other fields
            (
                "SELECT sum(number) as sum_up, date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window;",
                Timestamp::new(23, TimeUnit::Millisecond),
                (
                    "ts".to_string(),
                    Some(Timestamp::new(0, TimeUnit::Millisecond)),
                    Some(Timestamp::new(300000, TimeUnit::Millisecond)),
                ),
                "SELECT sum(numbers_with_ts.number) AS sum_up, date_bin('5 minutes', numbers_with_ts.ts) AS time_window FROM numbers_with_ts WHERE ((ts >= CAST('1970-01-01 00:00:00' AS TIMESTAMP)) AND (ts <= CAST('1970-01-01 00:05:00' AS TIMESTAMP))) GROUP BY date_bin('5 minutes', numbers_with_ts.ts)",
            ),
            // time index with other pks
            (
                "SELECT number, date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window, number;",
                Timestamp::new(23, TimeUnit::Millisecond),
                (
                    "ts".to_string(),
                    Some(Timestamp::new(0, TimeUnit::Millisecond)),
                    Some(Timestamp::new(300000, TimeUnit::Millisecond)),
                ),
                "SELECT numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts) AS time_window FROM numbers_with_ts WHERE ((ts >= CAST('1970-01-01 00:00:00' AS TIMESTAMP)) AND (ts <= CAST('1970-01-01 00:05:00' AS TIMESTAMP))) GROUP BY date_bin('5 minutes', numbers_with_ts.ts), numbers_with_ts.number",
            ),
            // subquery
            (
                "SELECT number, time_window FROM (SELECT number, date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window, number);",
                Timestamp::new(23, TimeUnit::Millisecond),
                (
                    "ts".to_string(),
                    Some(Timestamp::new(0, TimeUnit::Millisecond)),
                    Some(Timestamp::new(300000, TimeUnit::Millisecond)),
                ),
                "SELECT numbers_with_ts.number, time_window FROM (SELECT numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts) AS time_window FROM numbers_with_ts WHERE ((ts >= CAST('1970-01-01 00:00:00' AS TIMESTAMP)) AND (ts <= CAST('1970-01-01 00:05:00' AS TIMESTAMP))) GROUP BY date_bin('5 minutes', numbers_with_ts.ts), numbers_with_ts.number)",
            ),
            // cte
            (
                "with cte as (select number, date_bin('5 minutes', ts) as time_window from numbers_with_ts GROUP BY time_window, number) select number, time_window from cte;",
                Timestamp::new(23, TimeUnit::Millisecond),
                (
                    "ts".to_string(),
                    Some(Timestamp::new(0, TimeUnit::Millisecond)),
                    Some(Timestamp::new(300000, TimeUnit::Millisecond)),
                ),
                "SELECT cte.number, cte.time_window FROM (SELECT numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts) AS time_window FROM numbers_with_ts WHERE ((ts >= CAST('1970-01-01 00:00:00' AS TIMESTAMP)) AND (ts <= CAST('1970-01-01 00:05:00' AS TIMESTAMP))) GROUP BY date_bin('5 minutes', numbers_with_ts.ts), numbers_with_ts.number) AS cte",
            ),
            // complex subquery without alias
            (
                "SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) GROUP BY number, time_window, bucket_name;",
                Timestamp::new(23, TimeUnit::Millisecond),
                (
                    "ts".to_string(),
                    Some(Timestamp::new(0, TimeUnit::Millisecond)),
                    Some(Timestamp::new(300000, TimeUnit::Millisecond)),
                ),
                "SELECT sum(numbers_with_ts.number), numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts) AS time_window, bucket_name FROM (SELECT numbers_with_ts.number, numbers_with_ts.ts, CASE WHEN (numbers_with_ts.number < 5) THEN 'bucket_0_5' WHEN (numbers_with_ts.number >= 5) THEN 'bucket_5_inf' END AS bucket_name FROM numbers_with_ts WHERE ((ts >= CAST('1970-01-01 00:00:00' AS TIMESTAMP)) AND (ts <= CAST('1970-01-01 00:05:00' AS TIMESTAMP)))) GROUP BY numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts), bucket_name",
            ),
            // complex subquery alias
            (
                "SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) as cte GROUP BY number, time_window, bucket_name;",
                Timestamp::new(23, TimeUnit::Millisecond),
                (
                    "ts".to_string(),
                    Some(Timestamp::new(0, TimeUnit::Millisecond)),
                    Some(Timestamp::new(300000, TimeUnit::Millisecond)),
                ),
                "SELECT sum(cte.number), cte.number, date_bin('5 minutes', cte.ts) AS time_window, cte.bucket_name FROM (SELECT numbers_with_ts.number, numbers_with_ts.ts, CASE WHEN (numbers_with_ts.number < 5) THEN 'bucket_0_5' WHEN (numbers_with_ts.number >= 5) THEN 'bucket_5_inf' END AS bucket_name FROM numbers_with_ts WHERE ((ts >= CAST('1970-01-01 00:00:00' AS TIMESTAMP)) AND (ts <= CAST('1970-01-01 00:05:00' AS TIMESTAMP)))) AS cte GROUP BY cte.number, date_bin('5 minutes', cte.ts), cte.bucket_name",
            ),
        ];

        for (sql, current, expected, expected_unparsed) in testcases {
            let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, true)
                .await
                .unwrap();

            let real =
                find_plan_time_window_bound(&plan, current, ctx.clone(), query_engine.clone())
                    .await
                    .unwrap();
            // Name the failing case: all cases share this single assertion site.
            assert_eq!(
                expected, real,
                "unexpected time window bound for query: {}",
                sql
            );

            // Plan the query again (this time with the last flag set to `false`) and rewrite it
            // with the bounds found above.
            let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, false)
                .await
                .unwrap();
            let (col_name, lower, upper) = real;
            let new_sql = if let Some(lower) = lower {
                // Every case with a lower bound also has an upper bound; fail loudly otherwise.
                let upper =
                    upper.expect("upper bound must be present when lower bound is present");
                // Convert a timestamp into a DataFusion scalar usable in `lit(..)`.
                let to_df_literal = |value: Timestamp| {
                    let value = Value::from(value);

                    value.try_to_scalar_value(&value.data_type()).unwrap()
                };
                let lower = to_df_literal(lower);
                let upper = to_df_literal(upper);
                // `col >= lower AND col <= upper`
                let expr = col(&col_name)
                    .gt_eq(lit(lower))
                    .and(col(&col_name).lt_eq(lit(upper)));
                let mut add_filter = AddFilterRewriter::new(expr);
                let plan = plan.rewrite(&mut add_filter).unwrap().data;
                df_plan_to_sql(&plan).unwrap()
            } else {
                // No bound found: the query is expected to stay as written.
                sql.to_string()
            };
            assert_eq!(
                expected_unparsed, new_sql,
                "unexpected rewritten SQL for query: {}",
                sql
            );
        }
    }
}