operator/statement/
set.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::str::FromStr;
16use std::time::Duration;
17
18use common_time::Timezone;
19use lazy_static::lazy_static;
20use regex::Regex;
21use session::context::Channel::Postgres;
22use session::context::QueryContextRef;
23use session::session_config::{PGByteaOutputValue, PGDateOrder, PGDateTimeStyle};
24use session::ReadPreference;
25use snafu::{ensure, OptionExt, ResultExt};
26use sql::ast::{Expr, Ident, Value};
27use sql::statements::set_variables::SetVariables;
28
29use crate::error::{InvalidConfigValueSnafu, InvalidSqlSnafu, NotSupportedSnafu, Result};
30
31lazy_static! {
32    // Regex rules:
33    // The string must start with a number (one or more digits).
34    // The number must be followed by one of the valid time units (ms, s, min, h, d).
35    // The string must end immediately after the unit, meaning there can be no extra
36    // characters or spaces after the valid time specification.
37    static ref PG_TIME_INPUT_REGEX: Regex = Regex::new(r"^(\d+)(ms|s|min|h|d)$").unwrap();
38}
39
40pub fn set_read_preference(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
41    let read_preference_expr = exprs.first().context(NotSupportedSnafu {
42        feat: "No read preference find in set variable statement",
43    })?;
44
45    match read_preference_expr {
46        Expr::Value(Value::SingleQuotedString(expr))
47        | Expr::Value(Value::DoubleQuotedString(expr)) => {
48            match ReadPreference::from_str(expr.as_str().to_lowercase().as_str()) {
49                Ok(read_preference) => ctx.set_read_preference(read_preference),
50                Err(_) => {
51                    return NotSupportedSnafu {
52                        feat: format!(
53                            "Invalid read preference expr {} in set variable statement",
54                            expr,
55                        ),
56                    }
57                    .fail()
58                }
59            }
60            Ok(())
61        }
62        expr => NotSupportedSnafu {
63            feat: format!(
64                "Unsupported read preference expr {} in set variable statement",
65                expr
66            ),
67        }
68        .fail(),
69    }
70}
71
72pub fn set_timezone(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
73    let tz_expr = exprs.first().context(NotSupportedSnafu {
74        feat: "No timezone find in set variable statement",
75    })?;
76    match tz_expr {
77        Expr::Value(Value::SingleQuotedString(tz)) | Expr::Value(Value::DoubleQuotedString(tz)) => {
78            match Timezone::from_tz_string(tz.as_str()) {
79                Ok(timezone) => ctx.set_timezone(timezone),
80                Err(_) => {
81                    return NotSupportedSnafu {
82                        feat: format!("Invalid timezone expr {} in set variable statement", tz),
83                    }
84                    .fail()
85                }
86            }
87            Ok(())
88        }
89        expr => NotSupportedSnafu {
90            feat: format!(
91                "Unsupported timezone expr {} in set variable statement",
92                expr
93            ),
94        }
95        .fail(),
96    }
97}
98
99pub fn set_bytea_output(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
100    let Some((var_value, [])) = exprs.split_first() else {
101        return (NotSupportedSnafu {
102            feat: "Set variable value must have one and only one value for bytea_output",
103        })
104        .fail();
105    };
106    let Expr::Value(value) = var_value else {
107        return (NotSupportedSnafu {
108            feat: "Set variable value must be a value",
109        })
110        .fail();
111    };
112    ctx.configuration_parameter().set_postgres_bytea_output(
113        PGByteaOutputValue::try_from(value.clone()).context(InvalidConfigValueSnafu)?,
114    );
115    Ok(())
116}
117
118pub fn set_search_path(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
119    let search_expr = exprs.first().context(NotSupportedSnafu {
120        feat: "No search path find in set variable statement",
121    })?;
122    match search_expr {
123        Expr::Value(Value::SingleQuotedString(search_path))
124        | Expr::Value(Value::DoubleQuotedString(search_path)) => {
125            ctx.set_current_schema(search_path);
126            Ok(())
127        }
128        Expr::Identifier(Ident { value, .. }) => {
129            ctx.set_current_schema(value);
130            Ok(())
131        }
132        expr => NotSupportedSnafu {
133            feat: format!(
134                "Unsupported search path expr {} in set variable statement",
135                expr
136            ),
137        }
138        .fail(),
139    }
140}
141
142pub fn validate_client_encoding(set: SetVariables) -> Result<()> {
143    let Some((encoding, [])) = set.value.split_first() else {
144        return InvalidSqlSnafu {
145            err_msg: "must provide one and only one client encoding value",
146        }
147        .fail();
148    };
149    let encoding = match encoding {
150        Expr::Value(Value::SingleQuotedString(x))
151        | Expr::Identifier(Ident {
152            value: x,
153            quote_style: _,
154            span: _,
155        }) => x.to_uppercase(),
156        _ => {
157            return InvalidSqlSnafu {
158                err_msg: format!("client encoding must be a string, actual: {:?}", encoding),
159            }
160            .fail();
161        }
162    };
163    // For the sake of simplicity, we only support "UTF8" ("UNICODE" is the alias for it,
164    // see https://www.postgresql.org/docs/current/multibyte.html#MULTIBYTE-CHARSET-SUPPORTED).
165    // "UTF8" is universal and sufficient for almost all cases.
166    // GreptimeDB itself is always using "UTF8" as the internal encoding.
167    ensure!(
168        encoding == "UTF8" || encoding == "UNICODE",
169        NotSupportedSnafu {
170            feat: format!("client encoding of '{}'", encoding)
171        }
172    );
173    Ok(())
174}
175
176// if one of original value and new value is none, return the other one
177// returns new values only when it equals to original one else return error.
178// This is only used for handling datestyle
179fn merge_datestyle_value<T>(value: Option<T>, new_value: Option<T>) -> Result<Option<T>>
180where
181    T: PartialEq,
182{
183    match (&value, &new_value) {
184        (None, _) => Ok(new_value),
185        (_, None) => Ok(value),
186        (Some(v1), Some(v2)) if v1 == v2 => Ok(new_value),
187        _ => InvalidSqlSnafu {
188            err_msg: "Conflicting \"datestyle\" specifications.",
189        }
190        .fail(),
191    }
192}
193
194fn try_parse_datestyle(expr: &Expr) -> Result<(Option<PGDateTimeStyle>, Option<PGDateOrder>)> {
195    enum ParsedDateStyle {
196        Order(PGDateOrder),
197        Style(PGDateTimeStyle),
198    }
199    fn try_parse_str(s: &str) -> Result<ParsedDateStyle> {
200        PGDateTimeStyle::try_from(s)
201            .map_or_else(
202                |_| PGDateOrder::try_from(s).map(ParsedDateStyle::Order),
203                |style| Ok(ParsedDateStyle::Style(style)),
204            )
205            .context(InvalidConfigValueSnafu)
206    }
207    match expr {
208        Expr::Identifier(Ident {
209            value: s,
210            quote_style: _,
211            span: _,
212        })
213        | Expr::Value(Value::SingleQuotedString(s))
214        | Expr::Value(Value::DoubleQuotedString(s)) => {
215            s.split(',')
216                .map(|s| s.trim())
217                .try_fold((None, None), |(style, order), s| match try_parse_str(s)? {
218                    ParsedDateStyle::Order(o) => {
219                        Ok((style, merge_datestyle_value(order, Some(o))?))
220                    }
221                    ParsedDateStyle::Style(s) => {
222                        Ok((merge_datestyle_value(style, Some(s))?, order))
223                    }
224                })
225        }
226        _ => NotSupportedSnafu {
227            feat: "Not supported expression for datestyle",
228        }
229        .fail(),
230    }
231}
232
233pub fn set_datestyle(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
234    // ORDER,
235    // STYLE,
236    // ORDER,ORDER
237    // ORDER,STYLE
238    // STYLE,ORDER
239    let (style, order) = exprs
240        .iter()
241        .try_fold((None, None), |(style, order), expr| {
242            let (new_style, new_order) = try_parse_datestyle(expr)?;
243            Ok((
244                merge_datestyle_value(style, new_style)?,
245                merge_datestyle_value(order, new_order)?,
246            ))
247        })?;
248
249    let (old_style, older_order) = *ctx.configuration_parameter().pg_datetime_style();
250    ctx.configuration_parameter()
251        .set_pg_datetime_style(style.unwrap_or(old_style), order.unwrap_or(older_order));
252    Ok(())
253}
254
255pub fn set_query_timeout(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
256    let timeout_expr = exprs.first().context(NotSupportedSnafu {
257        feat: "No timeout value find in set query timeout statement",
258    })?;
259    match timeout_expr {
260        Expr::Value(Value::Number(timeout, _)) => {
261            match timeout.parse::<u64>() {
262                Ok(timeout) => ctx.set_query_timeout(Duration::from_millis(timeout)),
263                Err(_) => {
264                    return NotSupportedSnafu {
265                        feat: format!("Invalid timeout expr {} in set variable statement", timeout),
266                    }
267                    .fail()
268                }
269            }
270            Ok(())
271        }
272        // postgres support time units i.e. SET STATEMENT_TIMEOUT = '50ms';
273        Expr::Value(Value::SingleQuotedString(timeout))
274        | Expr::Value(Value::DoubleQuotedString(timeout)) => {
275            if ctx.channel() != Postgres {
276                return NotSupportedSnafu {
277                    feat: format!("Invalid timeout expr {} in set variable statement", timeout),
278                }
279                .fail();
280            }
281            let timeout = parse_pg_query_timeout_input(timeout)?;
282            ctx.set_query_timeout(Duration::from_millis(timeout));
283            Ok(())
284        }
285        expr => NotSupportedSnafu {
286            feat: format!(
287                "Unsupported timeout expr {} in set variable statement",
288                expr
289            ),
290        }
291        .fail(),
292    }
293}
294
295// support time units in ms, s, min, h, d for postgres protocol.
296// https://www.postgresql.org/docs/8.4/config-setting.html#:~:text=Valid%20memory%20units%20are%20kB,%2C%20and%20d%20(days).
297fn parse_pg_query_timeout_input(input: &str) -> Result<u64> {
298    match input.parse::<u64>() {
299        Ok(timeout) => Ok(timeout),
300        Err(_) => {
301            if let Some(captures) = PG_TIME_INPUT_REGEX.captures(input) {
302                let value = captures[1].parse::<u64>().expect("regex failed");
303                let unit = &captures[2];
304
305                match unit {
306                    "ms" => Ok(value),
307                    "s" => Ok(value * 1000),
308                    "min" => Ok(value * 60 * 1000),
309                    "h" => Ok(value * 60 * 60 * 1000),
310                    "d" => Ok(value * 24 * 60 * 60 * 1000),
311                    _ => unreachable!("regex failed"),
312                }
313            } else {
314                NotSupportedSnafu {
315                    feat: format!(
316                        "Unsupported timeout expr {} in set variable statement",
317                        input
318                    ),
319                }
320                .fail()
321            }
322        }
323    }
324}
325
#[cfg(test)]
mod test {
    use crate::statement::set::parse_pg_query_timeout_input;

    /// Covers malformed inputs as well as every supported unit (`ms`, `s`, `min`, `h`, `d`)
    /// and the unit-less millisecond form.
    #[test]
    fn test_parse_pg_query_timeout_input() {
        let rejected = ["", " 50 ms", "5s 1ms", "3a", "1.5min", "ms", "a", "-1"];
        for input in rejected {
            assert!(
                parse_pg_query_timeout_input(input).is_err(),
                "input {input:?} should be rejected"
            );
        }

        let accepted: [(&str, u64); 6] = [
            ("50", 50),
            ("12ms", 12),
            ("2s", 2000),
            ("1min", 60000),
            ("3h", 10_800_000),
            ("1d", 86_400_000),
        ];
        for (input, expected) in accepted {
            assert_eq!(
                expected,
                parse_pg_query_timeout_input(input).unwrap(),
                "unexpected result for input {input:?}"
            );
        }
    }
}