flow/expr/
utils.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! This module contains utility functions for expressions.
16
17use std::cmp::Ordering;
18use std::collections::BTreeMap;
19
20use datatypes::value::Value;
21use snafu::{OptionExt, ensure};
22
23use crate::Result;
24use crate::error::UnexpectedSnafu;
25use crate::expr::ScalarExpr;
26use crate::plan::TypedPlan;
27
28/// Find lower bound for time `current` in given `plan` for the time window expr.
29///
30/// i.e. for time window expr being `date_bin(INTERVAL '5 minutes', ts) as time_window` and `current="2021-07-01 00:01:01.000"`,
31/// return `Some("2021-07-01 00:00:00.000")`
32///
33/// if `plan` doesn't contain a `TIME INDEX` column, return `None`
34pub fn find_plan_time_window_expr_lower_bound(
35    plan: &TypedPlan,
36    current: common_time::Timestamp,
37) -> Result<Option<common_time::Timestamp>> {
38    let typ = plan.schema.typ();
39    let Some(mut time_index) = typ.time_index else {
40        return Ok(None);
41    };
42
43    let mut cur_plan = plan;
44    let mut expr_time_index;
45
46    loop {
47        // follow upward and find deepest time index expr that is not a column ref
48        expr_time_index = Some(cur_plan.plan.get_nth_expr(time_index).context(
49            UnexpectedSnafu {
50                reason: "Failed to find time index expr",
51            },
52        )?);
53
54        if let Some(ScalarExpr::Column(i)) = expr_time_index {
55            time_index = i;
56        } else {
57            break;
58        }
59        if let Some(input) = cur_plan.plan.get_first_input_plan() {
60            cur_plan = input;
61        } else {
62            break;
63        }
64    }
65
66    let expr_time_index = expr_time_index.context(UnexpectedSnafu {
67        reason: "Failed to find time index expr",
68    })?;
69
70    let ts_col = expr_time_index
71        .get_all_ref_columns()
72        .first()
73        .cloned()
74        .context(UnexpectedSnafu {
75            reason: "Failed to find time index column",
76        })?;
77
78    find_time_window_lower_bound(&expr_time_index, ts_col, current)
79}
80
81/// Find the lower bound of time window in given `expr` and `current` timestamp.
82///
83/// i.e. for `current="2021-07-01 00:01:01.000"` and `expr=date_bin(INTERVAL '5 minutes', ts) as time_window` and `ts_col=ts`,
84/// return `Some("2021-07-01 00:00:00.000")` since it's the lower bound
85/// of current time window given the current timestamp
86///
87/// if return None, meaning this time window have no lower bound
88pub fn find_time_window_lower_bound(
89    expr: &ScalarExpr,
90    ts_col_idx: usize,
91    current: common_time::Timestamp,
92) -> Result<Option<common_time::Timestamp>> {
93    let all_ref_columns = expr.get_all_ref_columns();
94
95    ensure!(
96        all_ref_columns.contains(&ts_col_idx),
97        UnexpectedSnafu {
98            reason: format!(
99                "Expected column {} to be referenced in expression {expr:?}",
100                ts_col_idx
101            ),
102        }
103    );
104
105    ensure!(
106        all_ref_columns.len() == 1,
107        UnexpectedSnafu {
108            reason: format!(
109                "Expect only one column to be referenced in expression {expr:?}, found {all_ref_columns:?}"
110            ),
111        }
112    );
113
114    let permute_map = BTreeMap::from([(ts_col_idx, 0usize)]);
115
116    let mut rewrote_expr = expr.clone();
117
118    rewrote_expr.permute_map(&permute_map)?;
119
120    fn eval_to_timestamp(expr: &ScalarExpr, values: &[Value]) -> Result<common_time::Timestamp> {
121        let val = expr.eval(values)?;
122        if let Value::Timestamp(ts) = val {
123            Ok(ts)
124        } else {
125            UnexpectedSnafu {
126                reason: format!("Expected timestamp in expression {expr:?} but got {val:?}"),
127            }
128            .fail()?
129        }
130    }
131
132    let cur_time_window = eval_to_timestamp(&rewrote_expr, &[current.into()])?;
133
134    // search to find the lower bound
135    let mut offset: i64 = 1;
136    let lower_bound;
137    let mut upper_bound = Some(current);
138    // first expontial probe to found a range for binary search
139    loop {
140        let Some(next_val) = current.value().checked_sub(offset) else {
141            // no lower bound
142            return Ok(None);
143        };
144
145        let prev_time_probe = common_time::Timestamp::new(next_val, current.unit());
146
147        let prev_time_window = eval_to_timestamp(&rewrote_expr, &[prev_time_probe.into()])?;
148
149        match prev_time_window.cmp(&cur_time_window) {
150            Ordering::Less => {
151                lower_bound = Some(prev_time_probe);
152                break;
153            }
154            Ordering::Equal => {
155                upper_bound = Some(prev_time_probe);
156            }
157            Ordering::Greater => {
158                UnexpectedSnafu {
159                    reason: format!(
160                        "Unsupported time window expression {rewrote_expr:?}, expect monotonic increasing for time window expression {expr:?}"
161                    ),
162                }
163                .fail()?
164            }
165        }
166
167        let Some(new_offset) = offset.checked_mul(2) else {
168            // no lower bound
169            return Ok(None);
170        };
171        offset = new_offset;
172    }
173
174    // binary search for the lower bound
175
176    ensure!(
177        lower_bound.map(|v| v.unit()) == upper_bound.map(|v| v.unit()),
178        UnexpectedSnafu {
179            reason: format!(
180                " unit mismatch for time window expression {expr:?}, found {lower_bound:?} and {upper_bound:?}"
181            ),
182        }
183    );
184
185    let output_unit = lower_bound.expect("should have lower bound").unit();
186
187    let mut low = lower_bound.expect("should have lower bound").value();
188    let mut high = upper_bound.expect("should have upper bound").value();
189    while low < high {
190        let mid = (low + high) / 2;
191        let mid_probe = common_time::Timestamp::new(mid, output_unit);
192        let mid_time_window = eval_to_timestamp(&rewrote_expr, &[mid_probe.into()])?;
193
194        match mid_time_window.cmp(&cur_time_window) {
195            Ordering::Less => low = mid + 1,
196            Ordering::Equal => high = mid,
197            Ordering::Greater => UnexpectedSnafu {
198                reason: format!("Binary search failed for time window expression {expr:?}"),
199            }
200            .fail()?,
201        }
202    }
203
204    let final_lower_bound_for_time_window = common_time::Timestamp::new(low, output_unit);
205
206    Ok(Some(final_lower_bound_for_time_window))
207}
208
#[cfg(test)]
mod test {
    use pretty_assertions::assert_eq;

    use super::*;
    use crate::plan::{Plan, TypedPlan};
    use crate::test_utils::{create_test_ctx, create_test_query_engine, sql_to_substrait};

    /// Parse a timestamp literal from the test tables below (no explicit timezone).
    fn parse_ts(literal: &str) -> common_time::Timestamp {
        common_time::Timestamp::from_str(literal, None).unwrap()
    }

    #[tokio::test]
    async fn test_plan_time_window_lower_bound() {
        // (sql, current time, expected lower bound)
        let cases = [
            // no time index
            (
                "SELECT date_bin('5 minutes', ts) FROM numbers_with_ts;",
                "2021-07-01 00:01:01.000",
                None,
            ),
            // time index
            (
                "SELECT date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window;",
                "2021-07-01 00:01:01.000",
                Some("2021-07-01 00:00:00.000"),
            ),
            // time index with other fields
            (
                "SELECT sum(number) as sum_up, date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window;",
                "2021-07-01 00:01:01.000",
                Some("2021-07-01 00:00:00.000"),
            ),
            // time index with other pks
            (
                "SELECT number, date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window, number;",
                "2021-07-01 00:01:01.000",
                Some("2021-07-01 00:00:00.000"),
            ),
        ];
        let engine = create_test_query_engine();

        for &(sql, now, want) in cases.iter() {
            let substrait_plan = sql_to_substrait(engine.clone(), sql).await;
            let mut ctx = create_test_ctx();
            let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &substrait_plan)
                .await
                .unwrap();

            let got = find_plan_time_window_expr_lower_bound(&flow_plan, parse_ts(now)).unwrap();
            assert_eq!(got, want.map(parse_ts));
        }
    }

    #[tokio::test]
    async fn test_timewindow_lower_bound() {
        // ((interval, time column, optional origin), current time, expected lower bound)
        let cases = [
            (
                ("'5 minutes'", "ts", Some("2021-07-01 00:00:00.000")),
                "2021-07-01 00:01:01.000",
                "2021-07-01 00:00:00.000",
            ),
            (
                ("'5 minutes'", "ts", None),
                "2021-07-01 00:01:01.000",
                "2021-07-01 00:00:00.000",
            ),
            (
                ("'5 minutes'", "ts", None),
                "2021-07-01 00:00:00.000",
                "2021-07-01 00:00:00.000",
            ),
            // test edge cases
            (
                ("'5 minutes'", "ts", None),
                "2021-07-01 00:05:00.000",
                "2021-07-01 00:05:00.000",
            ),
            (
                ("'5 minutes'", "ts", None),
                "2021-07-01 00:04:59.999",
                "2021-07-01 00:00:00.000",
            ),
            (
                ("'5 minutes'", "ts", None),
                "2021-07-01 00:04:59.999999999",
                "2021-07-01 00:00:00.000",
            ),
            (
                ("'5 minutes'", "ts", None),
                "2021-07-01 00:04:59.999999999999",
                "2021-07-01 00:00:00.000",
            ),
            (
                ("'5 minutes'", "ts", None),
                "2021-07-01 00:04:59.999999999999999",
                "2021-07-01 00:00:00.000",
            ),
        ];
        let engine = create_test_query_engine();

        for ((interval, column, origin), now, want) in cases {
            let sql = match origin {
                Some(origin) => format!(
                    "SELECT date_bin({interval}, {column}, '{origin}') FROM numbers_with_ts;"
                ),
                None => format!("SELECT date_bin({interval}, {column}) FROM numbers_with_ts;"),
            };
            let substrait_plan = sql_to_substrait(engine.clone(), &sql).await;
            let mut ctx = create_test_ctx();
            let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &substrait_plan)
                .await
                .unwrap();

            let Plan::Mfp { mfp, .. } = flow_plan.plan else {
                unreachable!()
            };
            let expr = mfp.expressions[0].clone();

            // NOTE: assumes `ts` is column 1 of `numbers_with_ts`.
            let got = find_time_window_lower_bound(&expr, 1, parse_ts(now)).unwrap();
            assert_eq!(got, Some(parse_ts(want)));
        }
    }
}