operator/statement/
set.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::str::FromStr;
16use std::time::Duration;
17
18use common_time::Timezone;
19use lazy_static::lazy_static;
20use regex::Regex;
21use session::context::Channel::Postgres;
22use session::context::QueryContextRef;
23use session::session_config::{PGByteaOutputValue, PGDateOrder, PGDateTimeStyle};
24use session::ReadPreference;
25use snafu::{ensure, OptionExt, ResultExt};
26use sql::ast::{Expr, Ident, Value};
27use sql::statements::set_variables::SetVariables;
28
29use crate::error::{InvalidConfigValueSnafu, InvalidSqlSnafu, NotSupportedSnafu, Result};
30
lazy_static! {
    // Matches a PostgreSQL-style duration literal such as `50ms` or `2min`.
    //
    // Regex rules:
    // - The string must start with a number (one or more digits), captured as group 1.
    // - The number must be immediately followed by one of the valid time units
    //   (ms, s, min, h, d), captured as group 2.
    // - The string must end right after the unit: no extra characters or spaces
    //   are allowed before the number, between number and unit, or after the unit.
    //
    // The pattern is a constant, so the `unwrap` only fails if the literal itself is invalid.
    static ref PG_TIME_INPUT_REGEX: Regex = Regex::new(r"^(\d+)(ms|s|min|h|d)$").unwrap();
}
39
40pub fn set_read_preference(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
41    let read_preference_expr = exprs.first().context(NotSupportedSnafu {
42        feat: "No read preference find in set variable statement",
43    })?;
44
45    match read_preference_expr {
46        Expr::Value(Value::SingleQuotedString(expr))
47        | Expr::Value(Value::DoubleQuotedString(expr)) => {
48            match ReadPreference::from_str(expr.as_str().to_lowercase().as_str()) {
49                Ok(read_preference) => ctx.set_read_preference(read_preference),
50                Err(_) => {
51                    return NotSupportedSnafu {
52                        feat: format!(
53                            "Invalid read preference expr {} in set variable statement",
54                            expr,
55                        ),
56                    }
57                    .fail()
58                }
59            }
60            Ok(())
61        }
62        expr => NotSupportedSnafu {
63            feat: format!(
64                "Unsupported read preference expr {} in set variable statement",
65                expr
66            ),
67        }
68        .fail(),
69    }
70}
71
72pub fn set_timezone(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
73    let tz_expr = exprs.first().context(NotSupportedSnafu {
74        feat: "No timezone find in set variable statement",
75    })?;
76    match tz_expr {
77        Expr::Value(Value::SingleQuotedString(tz)) | Expr::Value(Value::DoubleQuotedString(tz)) => {
78            match Timezone::from_tz_string(tz.as_str()) {
79                Ok(timezone) => ctx.set_timezone(timezone),
80                Err(_) => {
81                    return NotSupportedSnafu {
82                        feat: format!("Invalid timezone expr {} in set variable statement", tz),
83                    }
84                    .fail()
85                }
86            }
87            Ok(())
88        }
89        expr => NotSupportedSnafu {
90            feat: format!(
91                "Unsupported timezone expr {} in set variable statement",
92                expr
93            ),
94        }
95        .fail(),
96    }
97}
98
99pub fn set_bytea_output(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
100    let Some((var_value, [])) = exprs.split_first() else {
101        return (NotSupportedSnafu {
102            feat: "Set variable value must have one and only one value for bytea_output",
103        })
104        .fail();
105    };
106    let Expr::Value(value) = var_value else {
107        return (NotSupportedSnafu {
108            feat: "Set variable value must be a value",
109        })
110        .fail();
111    };
112    ctx.configuration_parameter().set_postgres_bytea_output(
113        PGByteaOutputValue::try_from(value.clone()).context(InvalidConfigValueSnafu)?,
114    );
115    Ok(())
116}
117
118pub fn set_search_path(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
119    let search_expr = exprs.first().context(NotSupportedSnafu {
120        feat: "No search path find in set variable statement",
121    })?;
122    match search_expr {
123        Expr::Value(Value::SingleQuotedString(search_path))
124        | Expr::Value(Value::DoubleQuotedString(search_path)) => {
125            ctx.set_current_schema(&search_path.clone());
126            Ok(())
127        }
128        expr => NotSupportedSnafu {
129            feat: format!(
130                "Unsupported search path expr {} in set variable statement",
131                expr
132            ),
133        }
134        .fail(),
135    }
136}
137
138pub fn validate_client_encoding(set: SetVariables) -> Result<()> {
139    let Some((encoding, [])) = set.value.split_first() else {
140        return InvalidSqlSnafu {
141            err_msg: "must provide one and only one client encoding value",
142        }
143        .fail();
144    };
145    let encoding = match encoding {
146        Expr::Value(Value::SingleQuotedString(x))
147        | Expr::Identifier(Ident {
148            value: x,
149            quote_style: _,
150            span: _,
151        }) => x.to_uppercase(),
152        _ => {
153            return InvalidSqlSnafu {
154                err_msg: format!("client encoding must be a string, actual: {:?}", encoding),
155            }
156            .fail();
157        }
158    };
159    // For the sake of simplicity, we only support "UTF8" ("UNICODE" is the alias for it,
160    // see https://www.postgresql.org/docs/current/multibyte.html#MULTIBYTE-CHARSET-SUPPORTED).
161    // "UTF8" is universal and sufficient for almost all cases.
162    // GreptimeDB itself is always using "UTF8" as the internal encoding.
163    ensure!(
164        encoding == "UTF8" || encoding == "UNICODE",
165        NotSupportedSnafu {
166            feat: format!("client encoding of '{}'", encoding)
167        }
168    );
169    Ok(())
170}
171
172// if one of original value and new value is none, return the other one
173// returns new values only when it equals to original one else return error.
174// This is only used for handling datestyle
175fn merge_datestyle_value<T>(value: Option<T>, new_value: Option<T>) -> Result<Option<T>>
176where
177    T: PartialEq,
178{
179    match (&value, &new_value) {
180        (None, _) => Ok(new_value),
181        (_, None) => Ok(value),
182        (Some(v1), Some(v2)) if v1 == v2 => Ok(new_value),
183        _ => InvalidSqlSnafu {
184            err_msg: "Conflicting \"datestyle\" specifications.",
185        }
186        .fail(),
187    }
188}
189
190fn try_parse_datestyle(expr: &Expr) -> Result<(Option<PGDateTimeStyle>, Option<PGDateOrder>)> {
191    enum ParsedDateStyle {
192        Order(PGDateOrder),
193        Style(PGDateTimeStyle),
194    }
195    fn try_parse_str(s: &str) -> Result<ParsedDateStyle> {
196        PGDateTimeStyle::try_from(s)
197            .map_or_else(
198                |_| PGDateOrder::try_from(s).map(ParsedDateStyle::Order),
199                |style| Ok(ParsedDateStyle::Style(style)),
200            )
201            .context(InvalidConfigValueSnafu)
202    }
203    match expr {
204        Expr::Identifier(Ident {
205            value: s,
206            quote_style: _,
207            span: _,
208        })
209        | Expr::Value(Value::SingleQuotedString(s))
210        | Expr::Value(Value::DoubleQuotedString(s)) => {
211            s.split(',')
212                .map(|s| s.trim())
213                .try_fold((None, None), |(style, order), s| match try_parse_str(s)? {
214                    ParsedDateStyle::Order(o) => {
215                        Ok((style, merge_datestyle_value(order, Some(o))?))
216                    }
217                    ParsedDateStyle::Style(s) => {
218                        Ok((merge_datestyle_value(style, Some(s))?, order))
219                    }
220                })
221        }
222        _ => NotSupportedSnafu {
223            feat: "Not supported expression for datestyle",
224        }
225        .fail(),
226    }
227}
228
229pub fn set_datestyle(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
230    // ORDER,
231    // STYLE,
232    // ORDER,ORDER
233    // ORDER,STYLE
234    // STYLE,ORDER
235    let (style, order) = exprs
236        .iter()
237        .try_fold((None, None), |(style, order), expr| {
238            let (new_style, new_order) = try_parse_datestyle(expr)?;
239            Ok((
240                merge_datestyle_value(style, new_style)?,
241                merge_datestyle_value(order, new_order)?,
242            ))
243        })?;
244
245    let (old_style, older_order) = *ctx.configuration_parameter().pg_datetime_style();
246    ctx.configuration_parameter()
247        .set_pg_datetime_style(style.unwrap_or(old_style), order.unwrap_or(older_order));
248    Ok(())
249}
250
251pub fn set_query_timeout(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
252    let timeout_expr = exprs.first().context(NotSupportedSnafu {
253        feat: "No timeout value find in set query timeout statement",
254    })?;
255    match timeout_expr {
256        Expr::Value(Value::Number(timeout, _)) => {
257            match timeout.parse::<u64>() {
258                Ok(timeout) => ctx.set_query_timeout(Duration::from_millis(timeout)),
259                Err(_) => {
260                    return NotSupportedSnafu {
261                        feat: format!("Invalid timeout expr {} in set variable statement", timeout),
262                    }
263                    .fail()
264                }
265            }
266            Ok(())
267        }
268        // postgres support time units i.e. SET STATEMENT_TIMEOUT = '50ms';
269        Expr::Value(Value::SingleQuotedString(timeout))
270        | Expr::Value(Value::DoubleQuotedString(timeout)) => {
271            if ctx.channel() != Postgres {
272                return NotSupportedSnafu {
273                    feat: format!("Invalid timeout expr {} in set variable statement", timeout),
274                }
275                .fail();
276            }
277            let timeout = parse_pg_query_timeout_input(timeout)?;
278            ctx.set_query_timeout(Duration::from_millis(timeout));
279            Ok(())
280        }
281        expr => NotSupportedSnafu {
282            feat: format!(
283                "Unsupported timeout expr {} in set variable statement",
284                expr
285            ),
286        }
287        .fail(),
288    }
289}
290
291// support time units in ms, s, min, h, d for postgres protocol.
292// https://www.postgresql.org/docs/8.4/config-setting.html#:~:text=Valid%20memory%20units%20are%20kB,%2C%20and%20d%20(days).
293fn parse_pg_query_timeout_input(input: &str) -> Result<u64> {
294    match input.parse::<u64>() {
295        Ok(timeout) => Ok(timeout),
296        Err(_) => {
297            if let Some(captures) = PG_TIME_INPUT_REGEX.captures(input) {
298                let value = captures[1].parse::<u64>().expect("regex failed");
299                let unit = &captures[2];
300
301                match unit {
302                    "ms" => Ok(value),
303                    "s" => Ok(value * 1000),
304                    "min" => Ok(value * 60 * 1000),
305                    "h" => Ok(value * 60 * 60 * 1000),
306                    "d" => Ok(value * 24 * 60 * 60 * 1000),
307                    _ => unreachable!("regex failed"),
308                }
309            } else {
310                NotSupportedSnafu {
311                    feat: format!(
312                        "Unsupported timeout expr {} in set variable statement",
313                        input
314                    ),
315                }
316                .fail()
317            }
318        }
319    }
320}
321
#[cfg(test)]
mod test {
    use crate::statement::set::parse_pg_query_timeout_input;

    #[test]
    fn test_parse_pg_query_timeout_input() {
        // Inputs that are neither a bare unsigned integer nor `<digits><unit>`.
        let rejected = ["", " 50 ms", "5s 1ms", "3a", "1.5min", "ms", "a", "-1"];
        for input in rejected {
            assert!(
                parse_pg_query_timeout_input(input).is_err(),
                "expected an error for {input:?}"
            );
        }

        // Valid inputs paired with the expected number of milliseconds.
        let accepted = [("50", 50), ("12ms", 12), ("2s", 2000), ("1min", 60000)];
        for (input, expected) in accepted {
            assert_eq!(
                expected,
                parse_pg_query_timeout_input(input).unwrap(),
                "unexpected result for {input:?}"
            );
        }
    }
}