operator/statement/
set.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::str::FromStr;
16use std::time::Duration;
17
18use common_time::Timezone;
19use lazy_static::lazy_static;
20use regex::Regex;
21use session::context::Channel::Postgres;
22use session::context::QueryContextRef;
23use session::session_config::{PGByteaOutputValue, PGDateOrder, PGDateTimeStyle};
24use session::ReadPreference;
25use snafu::{ensure, OptionExt, ResultExt};
26use sql::ast::{Expr, Ident, Value};
27use sql::statements::set_variables::SetVariables;
28
29use crate::error::{InvalidConfigValueSnafu, InvalidSqlSnafu, NotSupportedSnafu, Result};
30
31lazy_static! {
32    // Regex rules:
33    // The string must start with a number (one or more digits).
34    // The number must be followed by one of the valid time units (ms, s, min, h, d).
35    // The string must end immediately after the unit, meaning there can be no extra
36    // characters or spaces after the valid time specification.
37    static ref PG_TIME_INPUT_REGEX: Regex = Regex::new(r"^(\d+)(ms|s|min|h|d)$").unwrap();
38}
39
40pub fn set_read_preference(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
41    let read_preference_expr = exprs.first().context(NotSupportedSnafu {
42        feat: "No read preference find in set variable statement",
43    })?;
44
45    match read_preference_expr {
46        Expr::Value(Value::SingleQuotedString(expr))
47        | Expr::Value(Value::DoubleQuotedString(expr)) => {
48            match ReadPreference::from_str(expr.as_str().to_lowercase().as_str()) {
49                Ok(read_preference) => ctx.set_read_preference(read_preference),
50                Err(_) => {
51                    return NotSupportedSnafu {
52                        feat: format!(
53                            "Invalid read preference expr {} in set variable statement",
54                            expr,
55                        ),
56                    }
57                    .fail()
58                }
59            }
60            Ok(())
61        }
62        expr => NotSupportedSnafu {
63            feat: format!(
64                "Unsupported read preference expr {} in set variable statement",
65                expr
66            ),
67        }
68        .fail(),
69    }
70}
71
72pub fn set_timezone(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
73    let tz_expr = exprs.first().context(NotSupportedSnafu {
74        feat: "No timezone find in set variable statement",
75    })?;
76    match tz_expr {
77        Expr::Value(Value::SingleQuotedString(tz)) | Expr::Value(Value::DoubleQuotedString(tz)) => {
78            match Timezone::from_tz_string(tz.as_str()) {
79                Ok(timezone) => ctx.set_timezone(timezone),
80                Err(_) => {
81                    return NotSupportedSnafu {
82                        feat: format!("Invalid timezone expr {} in set variable statement", tz),
83                    }
84                    .fail()
85                }
86            }
87            Ok(())
88        }
89        expr => NotSupportedSnafu {
90            feat: format!(
91                "Unsupported timezone expr {} in set variable statement",
92                expr
93            ),
94        }
95        .fail(),
96    }
97}
98
99pub fn set_bytea_output(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
100    let Some((var_value, [])) = exprs.split_first() else {
101        return (NotSupportedSnafu {
102            feat: "Set variable value must have one and only one value for bytea_output",
103        })
104        .fail();
105    };
106    let Expr::Value(value) = var_value else {
107        return (NotSupportedSnafu {
108            feat: "Set variable value must be a value",
109        })
110        .fail();
111    };
112    ctx.configuration_parameter().set_postgres_bytea_output(
113        PGByteaOutputValue::try_from(value.clone()).context(InvalidConfigValueSnafu)?,
114    );
115    Ok(())
116}
117
118pub fn set_search_path(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
119    let search_expr = exprs.first().context(NotSupportedSnafu {
120        feat: "No search path find in set variable statement",
121    })?;
122    match search_expr {
123        Expr::Value(Value::SingleQuotedString(search_path))
124        | Expr::Value(Value::DoubleQuotedString(search_path)) => {
125            ctx.set_current_schema(search_path);
126            Ok(())
127        }
128        Expr::Identifier(Ident { value, .. }) => {
129            ctx.set_current_schema(value);
130            Ok(())
131        }
132        expr => NotSupportedSnafu {
133            feat: format!(
134                "Unsupported search path expr {} in set variable statement",
135                expr
136            ),
137        }
138        .fail(),
139    }
140}
141
142pub fn validate_client_encoding(set: SetVariables) -> Result<()> {
143    let Some((encoding, [])) = set.value.split_first() else {
144        return InvalidSqlSnafu {
145            err_msg: "must provide one and only one client encoding value",
146        }
147        .fail();
148    };
149    let encoding = match encoding {
150        Expr::Value(Value::SingleQuotedString(x))
151        | Expr::Identifier(Ident {
152            value: x,
153            quote_style: _,
154            span: _,
155        }) => x.to_uppercase(),
156        _ => {
157            return InvalidSqlSnafu {
158                err_msg: format!("client encoding must be a string, actual: {:?}", encoding),
159            }
160            .fail();
161        }
162    };
163    // For the sake of simplicity, we only support "UTF8" ("UNICODE" is the alias for it,
164    // see https://www.postgresql.org/docs/current/multibyte.html#MULTIBYTE-CHARSET-SUPPORTED).
165    // "UTF8" is universal and sufficient for almost all cases.
166    // GreptimeDB itself is always using "UTF8" as the internal encoding.
167    ensure!(
168        encoding == "UTF8" || encoding == "UNICODE",
169        NotSupportedSnafu {
170            feat: format!("client encoding of '{}'", encoding)
171        }
172    );
173    Ok(())
174}
175
176// if one of original value and new value is none, return the other one
177// returns new values only when it equals to original one else return error.
178// This is only used for handling datestyle
179fn merge_datestyle_value<T>(value: Option<T>, new_value: Option<T>) -> Result<Option<T>>
180where
181    T: PartialEq,
182{
183    match (&value, &new_value) {
184        (None, _) => Ok(new_value),
185        (_, None) => Ok(value),
186        (Some(v1), Some(v2)) if v1 == v2 => Ok(new_value),
187        _ => InvalidSqlSnafu {
188            err_msg: "Conflicting \"datestyle\" specifications.",
189        }
190        .fail(),
191    }
192}
193
194fn try_parse_datestyle(expr: &Expr) -> Result<(Option<PGDateTimeStyle>, Option<PGDateOrder>)> {
195    enum ParsedDateStyle {
196        Order(PGDateOrder),
197        Style(PGDateTimeStyle),
198    }
199    fn try_parse_str(s: &str) -> Result<ParsedDateStyle> {
200        PGDateTimeStyle::try_from(s)
201            .map_or_else(
202                |_| PGDateOrder::try_from(s).map(ParsedDateStyle::Order),
203                |style| Ok(ParsedDateStyle::Style(style)),
204            )
205            .context(InvalidConfigValueSnafu)
206    }
207    match expr {
208        Expr::Identifier(Ident {
209            value: s,
210            quote_style: _,
211            span: _,
212        })
213        | Expr::Value(Value::SingleQuotedString(s))
214        | Expr::Value(Value::DoubleQuotedString(s)) => {
215            s.split(',')
216                .map(|s| s.trim())
217                .try_fold((None, None), |(style, order), s| match try_parse_str(s)? {
218                    ParsedDateStyle::Order(o) => {
219                        Ok((style, merge_datestyle_value(order, Some(o))?))
220                    }
221                    ParsedDateStyle::Style(s) => {
222                        Ok((merge_datestyle_value(style, Some(s))?, order))
223                    }
224                })
225        }
226        _ => NotSupportedSnafu {
227            feat: "Not supported expression for datestyle",
228        }
229        .fail(),
230    }
231}
232
233/// Set the allow query fallback configuration parameter to true or false based on the provided expressions.
234///
235pub fn set_allow_query_fallback(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
236    let allow_fallback_expr = exprs.first().context(NotSupportedSnafu {
237        feat: "No allow query fallback value find in set variable statement",
238    })?;
239    match allow_fallback_expr {
240        Expr::Value(Value::Boolean(allow)) => {
241            ctx.configuration_parameter()
242                .set_allow_query_fallback(*allow);
243            Ok(())
244        }
245        expr => NotSupportedSnafu {
246            feat: format!(
247                "Unsupported allow query fallback expr {} in set variable statement",
248                expr
249            ),
250        }
251        .fail(),
252    }
253}
254
255pub fn set_datestyle(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
256    // ORDER,
257    // STYLE,
258    // ORDER,ORDER
259    // ORDER,STYLE
260    // STYLE,ORDER
261    let (style, order) = exprs
262        .iter()
263        .try_fold((None, None), |(style, order), expr| {
264            let (new_style, new_order) = try_parse_datestyle(expr)?;
265            Ok((
266                merge_datestyle_value(style, new_style)?,
267                merge_datestyle_value(order, new_order)?,
268            ))
269        })?;
270
271    let (old_style, older_order) = *ctx.configuration_parameter().pg_datetime_style();
272    ctx.configuration_parameter()
273        .set_pg_datetime_style(style.unwrap_or(old_style), order.unwrap_or(older_order));
274    Ok(())
275}
276
277pub fn set_query_timeout(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
278    let timeout_expr = exprs.first().context(NotSupportedSnafu {
279        feat: "No timeout value find in set query timeout statement",
280    })?;
281    match timeout_expr {
282        Expr::Value(Value::Number(timeout, _)) => {
283            match timeout.parse::<u64>() {
284                Ok(timeout) => ctx.set_query_timeout(Duration::from_millis(timeout)),
285                Err(_) => {
286                    return NotSupportedSnafu {
287                        feat: format!("Invalid timeout expr {} in set variable statement", timeout),
288                    }
289                    .fail()
290                }
291            }
292            Ok(())
293        }
294        // postgres support time units i.e. SET STATEMENT_TIMEOUT = '50ms';
295        Expr::Value(Value::SingleQuotedString(timeout))
296        | Expr::Value(Value::DoubleQuotedString(timeout)) => {
297            if ctx.channel() != Postgres {
298                return NotSupportedSnafu {
299                    feat: format!("Invalid timeout expr {} in set variable statement", timeout),
300                }
301                .fail();
302            }
303            let timeout = parse_pg_query_timeout_input(timeout)?;
304            ctx.set_query_timeout(Duration::from_millis(timeout));
305            Ok(())
306        }
307        expr => NotSupportedSnafu {
308            feat: format!(
309                "Unsupported timeout expr {} in set variable statement",
310                expr
311            ),
312        }
313        .fail(),
314    }
315}
316
317// support time units in ms, s, min, h, d for postgres protocol.
318// https://www.postgresql.org/docs/8.4/config-setting.html#:~:text=Valid%20memory%20units%20are%20kB,%2C%20and%20d%20(days).
319fn parse_pg_query_timeout_input(input: &str) -> Result<u64> {
320    match input.parse::<u64>() {
321        Ok(timeout) => Ok(timeout),
322        Err(_) => {
323            if let Some(captures) = PG_TIME_INPUT_REGEX.captures(input) {
324                let value = captures[1].parse::<u64>().expect("regex failed");
325                let unit = &captures[2];
326
327                match unit {
328                    "ms" => Ok(value),
329                    "s" => Ok(value * 1000),
330                    "min" => Ok(value * 60 * 1000),
331                    "h" => Ok(value * 60 * 60 * 1000),
332                    "d" => Ok(value * 24 * 60 * 60 * 1000),
333                    _ => unreachable!("regex failed"),
334                }
335            } else {
336                NotSupportedSnafu {
337                    feat: format!(
338                        "Unsupported timeout expr {} in set variable statement",
339                        input
340                    ),
341                }
342                .fail()
343            }
344        }
345    }
346}
347
#[cfg(test)]
mod test {
    use crate::statement::set::parse_pg_query_timeout_input;

    /// Covers rejected inputs, bare millisecond values and every supported unit.
    #[test]
    fn test_parse_pg_query_timeout_input() {
        // Malformed inputs: empty, whitespace, multiple parts, unknown units,
        // fractional or negative numbers, and a unit without a number.
        assert!(parse_pg_query_timeout_input("").is_err());
        assert!(parse_pg_query_timeout_input(" 50 ms").is_err());
        assert!(parse_pg_query_timeout_input("5s 1ms").is_err());
        assert!(parse_pg_query_timeout_input("3a").is_err());
        assert!(parse_pg_query_timeout_input("1.5min").is_err());
        assert!(parse_pg_query_timeout_input("ms").is_err());
        assert!(parse_pg_query_timeout_input("a").is_err());
        assert!(parse_pg_query_timeout_input("-1").is_err());
        // Units are matched case-sensitively and only the listed ones are known.
        assert!(parse_pg_query_timeout_input("1S").is_err());
        assert!(parse_pg_query_timeout_input("1us").is_err());

        // A bare number is taken as milliseconds.
        assert_eq!(0, parse_pg_query_timeout_input("0").unwrap());
        assert_eq!(50, parse_pg_query_timeout_input("50").unwrap());

        // Every supported unit is converted to milliseconds.
        assert_eq!(12, parse_pg_query_timeout_input("12ms").unwrap());
        assert_eq!(0, parse_pg_query_timeout_input("0s").unwrap());
        assert_eq!(2000, parse_pg_query_timeout_input("2s").unwrap());
        assert_eq!(60000, parse_pg_query_timeout_input("1min").unwrap());
        assert_eq!(3_600_000, parse_pg_query_timeout_input("1h").unwrap());
        assert_eq!(86_400_000, parse_pg_query_timeout_input("1d").unwrap());
    }
}