sql/statements/transform/
expand_interval.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::ops::ControlFlow;
17use std::time::Duration as StdDuration;
18
19use itertools::Itertools;
20use lazy_static::lazy_static;
21use regex::Regex;
22use sqlparser::ast::{DataType, Expr, Interval, Value, ValueWithSpan};
23
24use crate::statements::transform::TransformRule;
25
26lazy_static! {
27    /// Matches either one or more digits `(\d+)` or one or more ASCII characters `[a-zA-Z]` or plus/minus signs
28    static ref INTERVAL_ABBREVIATION_PATTERN: Regex = Regex::new(r"([+-]?\d+|[a-zA-Z]+|\+|-)").unwrap();
29
30    /// Checks if the provided string starts as ISO_8601 format string (case/sign independent)
31    static ref IS_VALID_ISO_8601_PREFIX_PATTERN: Regex = Regex::new(r"^[-]?[Pp]").unwrap();
32
33    static ref INTERVAL_ABBREVIATION_MAPPING: HashMap<&'static str, &'static str> = HashMap::from([
34        ("y","years"),
35        ("mon","months"),
36        ("w","weeks"),
37        ("d","days"),
38        ("h","hours"),
39        ("m","minutes"),
40        ("s","seconds"),
41        ("millis","milliseconds"),
42        ("ms","milliseconds"),
43        ("us","microseconds"),
44        ("ns","nanoseconds"),
45    ]);
46}
47
/// Transform rule that rewrites shortened `INTERVAL` literals into their verbose form.
///
/// Supported unit abbreviations:
///
/// | abbreviation | expands to     |
/// |--------------|----------------|
/// | `y`          | `years`        |
/// | `mon`        | `months`       |
/// | `w`          | `weeks`        |
/// | `d`          | `days`         |
/// | `h`          | `hours`        |
/// | `m`          | `minutes`      |
/// | `s`          | `seconds`      |
/// | `millis`     | `milliseconds` |
/// | `ms`         | `milliseconds` |
/// | `us`         | `microseconds` |
/// | `ns`         | `nanoseconds`  |
///
/// Needed for queries that use the shortened spelling of `INTERVAL`,
/// e.g. `SELECT INTERVAL '1h'` or `SELECT INTERVAL '3w2d'`.
pub(crate) struct ExpandIntervalTransformRule;
64
65impl TransformRule for ExpandIntervalTransformRule {
66    /// Applies transform rule for `Interval` type by extending the shortened version (e.g. '1h', '2d') or
67    /// converting ISO 8601 format strings (e.g., "P1Y2M3D")
68    /// In case when `Interval` has `BinaryOp` value (e.g. query like `SELECT INTERVAL '2h' - INTERVAL '1h'`)
69    /// it's AST has `left` part of type `Value::SingleQuotedString` which needs to be handled specifically.
70    /// To handle the `right` part which is `Interval` no extra steps are needed.
71    fn visit_expr(&self, expr: &mut Expr) -> ControlFlow<()> {
72        match expr {
73            Expr::Interval(interval) => match &*interval.value {
74                Expr::Value(ValueWithSpan {
75                    value: Value::SingleQuotedString(value),
76                    ..
77                })
78                | Expr::Value(ValueWithSpan {
79                    value: Value::DoubleQuotedString(value),
80                    ..
81                }) => {
82                    if let Some(normalized_name) = normalize_interval_name(value) {
83                        *expr = update_existing_interval_with_value(
84                            interval,
85                            single_quoted_string_expr(normalized_name),
86                        );
87                    }
88                }
89                Expr::BinaryOp { left, op, right } => match &**left {
90                    Expr::Value(ValueWithSpan {
91                        value: Value::SingleQuotedString(value),
92                        ..
93                    })
94                    | Expr::Value(ValueWithSpan {
95                        value: Value::DoubleQuotedString(value),
96                        ..
97                    }) => {
98                        if let Some(normalized_name) = normalize_interval_name(value) {
99                            let new_expr_value = Box::new(Expr::BinaryOp {
100                                left: single_quoted_string_expr(normalized_name),
101                                op: op.clone(),
102                                right: right.clone(),
103                            });
104                            *expr = update_existing_interval_with_value(interval, new_expr_value);
105                        }
106                    }
107                    _ => {}
108                },
109                _ => {}
110            },
111            Expr::Cast {
112                expr: cast_exp,
113                data_type,
114                array,
115                kind,
116                format,
117            } => {
118                if matches!(data_type, DataType::Interval { .. }) {
119                    match &**cast_exp {
120                        Expr::Value(ValueWithSpan {
121                            value: Value::SingleQuotedString(value),
122                            ..
123                        })
124                        | Expr::Value(ValueWithSpan {
125                            value: Value::DoubleQuotedString(value),
126                            ..
127                        }) => {
128                            let interval_value =
129                                normalize_interval_name(value).unwrap_or_else(|| value.clone());
130                            *expr = Expr::Cast {
131                                kind: kind.clone(),
132                                expr: single_quoted_string_expr(interval_value),
133                                data_type: data_type.clone(),
134                                array: *array,
135                                format: std::mem::take(format),
136                            }
137                        }
138                        _ => {}
139                    }
140                }
141            }
142            _ => {}
143        }
144        ControlFlow::<()>::Continue(())
145    }
146}
147
148fn single_quoted_string_expr(string: String) -> Box<Expr> {
149    Box::new(Expr::Value(Value::SingleQuotedString(string).into()))
150}
151
152fn update_existing_interval_with_value(interval: &Interval, value: Box<Expr>) -> Expr {
153    Expr::Interval(Interval {
154        value,
155        leading_field: interval.leading_field.clone(),
156        leading_precision: interval.leading_precision,
157        last_field: interval.last_field.clone(),
158        fractional_seconds_precision: interval.fractional_seconds_precision,
159    })
160}
161
162/// Normalizes an interval expression string into the sql-compatible format.
163/// This function handles 2 types of input:
164/// 1. Abbreviated interval strings (e.g., "1y2mo3d")
165///    Returns an interval's full name (e.g., "years", "hours", "minutes") according to the `INTERVAL_ABBREVIATION_MAPPING`
166///    If the `interval_str` contains whitespaces, the interval name is considered to be in a full form.
167/// 2. ISO 8601 format strings (e.g., "P1Y2M3D"), case/sign independent
168///    Returns a number of milliseconds corresponding to ISO 8601 (e.g., "36525000 milliseconds")
169///
170/// Note: Hybrid format "1y 2 days 3h" is not supported.
171fn normalize_interval_name(interval_str: &str) -> Option<String> {
172    if interval_str.contains(char::is_whitespace) {
173        return None;
174    }
175
176    if IS_VALID_ISO_8601_PREFIX_PATTERN.is_match(interval_str) {
177        return parse_iso8601_interval(interval_str);
178    }
179
180    expand_interval_abbreviation(interval_str)
181}
182
183fn parse_iso8601_interval(signed_iso: &str) -> Option<String> {
184    let (is_negative, unsigned_iso) = if let Some(stripped) = signed_iso.strip_prefix('-') {
185        (true, stripped)
186    } else {
187        (false, signed_iso)
188    };
189
190    match iso8601::duration(&unsigned_iso.to_uppercase()) {
191        Ok(duration) => {
192            let millis = StdDuration::from(duration).as_millis();
193            let sign = if is_negative { "-" } else { "" };
194            Some(format!("{}{} milliseconds", sign, millis))
195        }
196        Err(_) => None,
197    }
198}
199
200fn expand_interval_abbreviation(interval_str: &str) -> Option<String> {
201    Some(
202        INTERVAL_ABBREVIATION_PATTERN
203            .find_iter(interval_str)
204            .map(|mat| {
205                let mat_str = mat.as_str();
206                *INTERVAL_ABBREVIATION_MAPPING
207                    .get(mat_str)
208                    .unwrap_or(&mat_str)
209            })
210            .join(" "),
211    )
212}
213
#[cfg(test)]
mod tests {
    use std::ops::ControlFlow;

    use sqlparser::ast::{BinaryOperator, CastKind, DataType, Expr, Interval, Value};

    use crate::statements::transform::TransformRule;
    use crate::statements::transform::expand_interval::{
        ExpandIntervalTransformRule, normalize_interval_name, single_quoted_string_expr,
    };

    /// Builds an `INTERVAL` expression around `value` with every qualifier left unset.
    fn create_interval(value: Box<Expr>) -> Expr {
        let interval = Interval {
            value,
            leading_field: None,
            leading_precision: None,
            last_field: None,
            fractional_seconds_precision: None,
        };
        Expr::Interval(interval)
    }

    /// Builds the single-quoted string literal the rule is expected to emit.
    fn expected_literal(text: &str) -> Box<Expr> {
        Box::new(Expr::Value(
            Value::SingleQuotedString(text.to_string()).into(),
        ))
    }

    /// Builds a plain `CAST('<literal>' AS <data_type>)` expression.
    fn cast_of(literal: Box<Expr>, data_type: DataType) -> Expr {
        Expr::Cast {
            kind: CastKind::Cast,
            expr: literal,
            data_type,
            array: false,
            format: None,
        }
    }

    /// The unqualified `INTERVAL` data type used as a cast target.
    fn interval_type() -> DataType {
        DataType::Interval {
            fields: None,
            precision: None,
        }
    }

    /// Applies the rule once to `expr` and checks that traversal is allowed to continue.
    fn apply_rule(expr: &mut Expr) {
        let rule = ExpandIntervalTransformRule {};
        let control_flow = rule.visit_expr(expr);
        assert_eq!(control_flow, ControlFlow::Continue(()));
    }

    /// Checks that every `(input, expected)` pair normalizes as stated.
    fn assert_normalizes(cases: &[(&str, &str)]) {
        for &(input, expected) in cases {
            let normalized = normalize_interval_name(input).unwrap();
            assert_eq!(normalized, expected);
        }
    }

    #[test]
    fn test_transform_interval_basic_conversions() {
        assert_normalizes(&[
            ("1y", "1 years"),
            ("4mon", "4 months"),
            ("-3w", "-3 weeks"),
            ("55h", "55 hours"),
            ("3d", "3 days"),
            ("5s", "5 seconds"),
            ("2m", "2 minutes"),
            ("100millis", "100 milliseconds"),
            ("200ms", "200 milliseconds"),
            ("350us", "350 microseconds"),
            ("400ns", "400 nanoseconds"),
        ]);

        // Inputs containing whitespace are treated as already verbose and are not rewritten.
        for already_verbose in ["1 year 2 months 3 days 4 hours", "-2 months"].iter() {
            assert_eq!(normalize_interval_name(already_verbose), None);
        }
    }

    #[test]
    fn test_transform_interval_compound_conversions() {
        assert_normalizes(&[
            ("2y4mon6w", "2 years 4 months 6 weeks"),
            ("5d3h1m", "5 days 3 hours 1 minutes"),
            (
                "10s312ms789ns",
                "10 seconds 312 milliseconds 789 nanoseconds",
            ),
            (
                "23millis987us754ns",
                "23 milliseconds 987 microseconds 754 nanoseconds",
            ),
            ("-1d-5h", "-1 days -5 hours"),
            ("-2y-4mon-6w", "-2 years -4 months -6 weeks"),
            ("-5d-3h-1m", "-5 days -3 hours -1 minutes"),
            (
                "-10s-312ms-789ns",
                "-10 seconds -312 milliseconds -789 nanoseconds",
            ),
            (
                "-23millis-987us-754ns",
                "-23 milliseconds -987 microseconds -754 nanoseconds",
            ),
        ]);
    }

    #[test]
    fn test_iso8601_format() {
        let cases = [
            ("P1Y2M3DT4H5M6S", Some("36993906000 milliseconds")),
            (
                "p3y3m700dt133h17m36.789s",
                Some("163343856789 milliseconds"),
            ),
            ("-P1Y2M3DT4H5M6S", Some("-36993906000 milliseconds")),
            ("P1_INVALID_ISO8601", None),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(
                normalize_interval_name(input),
                expected.map(|text| text.to_string())
            );
        }
    }

    #[test]
    fn test_visit_expr_when_interval_is_single_quoted_string_abbr_expr() {
        let mut string_expr = create_interval(single_quoted_string_expr("5y".to_string()));

        apply_rule(&mut string_expr);

        assert_eq!(string_expr, create_interval(expected_literal("5 years")));
    }

    #[test]
    fn test_visit_expr_when_interval_is_single_quoted_string_iso8601_expr() {
        let mut string_expr =
            create_interval(single_quoted_string_expr("P1Y2M3DT4H5M6S".to_string()));

        apply_rule(&mut string_expr);

        assert_eq!(
            string_expr,
            create_interval(expected_literal("36993906000 milliseconds"))
        );
    }

    #[test]
    fn test_visit_expr_when_interval_is_binary_op() {
        let mut binary_op_expr = create_interval(Box::new(Expr::BinaryOp {
            left: single_quoted_string_expr("2d".to_string()),
            op: BinaryOperator::Minus,
            right: Box::new(create_interval(single_quoted_string_expr("1d".to_string()))),
        }));

        apply_rule(&mut binary_op_expr);

        // Only the left literal is expanded by this visit; the nested right-hand interval
        // still carries its abbreviated value.
        let expected = create_interval(Box::new(Expr::BinaryOp {
            left: expected_literal("2 days"),
            op: BinaryOperator::Minus,
            right: Box::new(create_interval(expected_literal("1d"))),
        }));
        assert_eq!(binary_op_expr, expected);
    }

    #[test]
    fn test_visit_expr_when_cast_expr() {
        // A cast to an interval type has its literal expanded.
        let mut cast_to_interval_expr = cast_of(
            single_quoted_string_expr("3y2mon".to_string()),
            interval_type(),
        );
        apply_rule(&mut cast_to_interval_expr);
        assert_eq!(
            cast_to_interval_expr,
            cast_of(expected_literal("3 years 2 months"), interval_type())
        );

        // A cast to any other type is left exactly as it was.
        let mut cast_to_i64_expr =
            cast_of(single_quoted_string_expr("5".to_string()), DataType::Int64);
        apply_rule(&mut cast_to_i64_expr);
        assert_eq!(
            cast_to_i64_expr,
            cast_of(single_quoted_string_expr("5".to_string()), DataType::Int64)
        );
    }
}