operator/statement/
set.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::str::FromStr;
16use std::time::Duration;
17
18use common_time::Timezone;
19use lazy_static::lazy_static;
20use regex::Regex;
21use session::context::Channel::Postgres;
22use session::context::QueryContextRef;
23use session::session_config::{PGByteaOutputValue, PGDateOrder, PGDateTimeStyle};
24use session::ReadPreference;
25use snafu::{ensure, OptionExt, ResultExt};
26use sql::ast::{Expr, Ident, Value};
27use sql::statements::set_variables::SetVariables;
28use sqlparser::ast::ValueWithSpan;
29
30use crate::error::{InvalidConfigValueSnafu, InvalidSqlSnafu, NotSupportedSnafu, Result};
31
32lazy_static! {
33    // Regex rules:
34    // The string must start with a number (one or more digits).
35    // The number must be followed by one of the valid time units (ms, s, min, h, d).
36    // The string must end immediately after the unit, meaning there can be no extra
37    // characters or spaces after the valid time specification.
38    static ref PG_TIME_INPUT_REGEX: Regex = Regex::new(r"^(\d+)(ms|s|min|h|d)$").unwrap();
39}
40
41pub fn set_read_preference(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
42    let read_preference_expr = exprs.first().context(NotSupportedSnafu {
43        feat: "No read preference find in set variable statement",
44    })?;
45
46    match read_preference_expr {
47        Expr::Value(ValueWithSpan {
48            value: Value::SingleQuotedString(expr),
49            ..
50        })
51        | Expr::Value(ValueWithSpan {
52            value: Value::DoubleQuotedString(expr),
53            ..
54        }) => {
55            match ReadPreference::from_str(expr.as_str().to_lowercase().as_str()) {
56                Ok(read_preference) => ctx.set_read_preference(read_preference),
57                Err(_) => {
58                    return NotSupportedSnafu {
59                        feat: format!(
60                            "Invalid read preference expr {} in set variable statement",
61                            expr,
62                        ),
63                    }
64                    .fail()
65                }
66            }
67            Ok(())
68        }
69        expr => NotSupportedSnafu {
70            feat: format!(
71                "Unsupported read preference expr {} in set variable statement",
72                expr
73            ),
74        }
75        .fail(),
76    }
77}
78
79pub fn set_timezone(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
80    let tz_expr = exprs.first().context(NotSupportedSnafu {
81        feat: "No timezone find in set variable statement",
82    })?;
83    match tz_expr {
84        Expr::Value(ValueWithSpan {
85            value: Value::SingleQuotedString(tz),
86            ..
87        })
88        | Expr::Value(ValueWithSpan {
89            value: Value::DoubleQuotedString(tz),
90            ..
91        }) => {
92            match Timezone::from_tz_string(tz.as_str()) {
93                Ok(timezone) => ctx.set_timezone(timezone),
94                Err(_) => {
95                    return NotSupportedSnafu {
96                        feat: format!("Invalid timezone expr {} in set variable statement", tz),
97                    }
98                    .fail()
99                }
100            }
101            Ok(())
102        }
103        expr => NotSupportedSnafu {
104            feat: format!(
105                "Unsupported timezone expr {} in set variable statement",
106                expr
107            ),
108        }
109        .fail(),
110    }
111}
112
113pub fn set_bytea_output(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
114    let Some((var_value, [])) = exprs.split_first() else {
115        return (NotSupportedSnafu {
116            feat: "Set variable value must have one and only one value for bytea_output",
117        })
118        .fail();
119    };
120    let Expr::Value(value) = var_value else {
121        return (NotSupportedSnafu {
122            feat: "Set variable value must be a value",
123        })
124        .fail();
125    };
126    ctx.configuration_parameter().set_postgres_bytea_output(
127        PGByteaOutputValue::try_from(value.value.clone()).context(InvalidConfigValueSnafu)?,
128    );
129    Ok(())
130}
131
132pub fn set_search_path(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
133    let search_expr = exprs.first().context(NotSupportedSnafu {
134        feat: "No search path find in set variable statement",
135    })?;
136    match search_expr {
137        Expr::Value(ValueWithSpan {
138            value: Value::SingleQuotedString(search_path),
139            ..
140        })
141        | Expr::Value(ValueWithSpan {
142            value: Value::DoubleQuotedString(search_path),
143            ..
144        }) => {
145            ctx.set_current_schema(search_path);
146            Ok(())
147        }
148        Expr::Identifier(Ident { value, .. }) => {
149            ctx.set_current_schema(value);
150            Ok(())
151        }
152        expr => NotSupportedSnafu {
153            feat: format!(
154                "Unsupported search path expr {} in set variable statement",
155                expr
156            ),
157        }
158        .fail(),
159    }
160}
161
162pub fn validate_client_encoding(set: SetVariables) -> Result<()> {
163    let Some((encoding, [])) = set.value.split_first() else {
164        return InvalidSqlSnafu {
165            err_msg: "must provide one and only one client encoding value",
166        }
167        .fail();
168    };
169    let encoding = match encoding {
170        Expr::Value(ValueWithSpan {
171            value: Value::SingleQuotedString(x),
172            ..
173        })
174        | Expr::Identifier(Ident {
175            value: x,
176            quote_style: _,
177            span: _,
178        }) => x.to_uppercase(),
179        _ => {
180            return InvalidSqlSnafu {
181                err_msg: format!("client encoding must be a string, actual: {:?}", encoding),
182            }
183            .fail();
184        }
185    };
186    // For the sake of simplicity, we only support "UTF8" ("UNICODE" is the alias for it,
187    // see https://www.postgresql.org/docs/current/multibyte.html#MULTIBYTE-CHARSET-SUPPORTED).
188    // "UTF8" is universal and sufficient for almost all cases.
189    // GreptimeDB itself is always using "UTF8" as the internal encoding.
190    ensure!(
191        encoding == "UTF8" || encoding == "UNICODE",
192        NotSupportedSnafu {
193            feat: format!("client encoding of '{}'", encoding)
194        }
195    );
196    Ok(())
197}
198
199// if one of original value and new value is none, return the other one
200// returns new values only when it equals to original one else return error.
201// This is only used for handling datestyle
202fn merge_datestyle_value<T>(value: Option<T>, new_value: Option<T>) -> Result<Option<T>>
203where
204    T: PartialEq,
205{
206    match (&value, &new_value) {
207        (None, _) => Ok(new_value),
208        (_, None) => Ok(value),
209        (Some(v1), Some(v2)) if v1 == v2 => Ok(new_value),
210        _ => InvalidSqlSnafu {
211            err_msg: "Conflicting \"datestyle\" specifications.",
212        }
213        .fail(),
214    }
215}
216
217fn try_parse_datestyle(expr: &Expr) -> Result<(Option<PGDateTimeStyle>, Option<PGDateOrder>)> {
218    enum ParsedDateStyle {
219        Order(PGDateOrder),
220        Style(PGDateTimeStyle),
221    }
222    fn try_parse_str(s: &str) -> Result<ParsedDateStyle> {
223        PGDateTimeStyle::try_from(s)
224            .map_or_else(
225                |_| PGDateOrder::try_from(s).map(ParsedDateStyle::Order),
226                |style| Ok(ParsedDateStyle::Style(style)),
227            )
228            .context(InvalidConfigValueSnafu)
229    }
230    match expr {
231        Expr::Identifier(Ident {
232            value: s,
233            quote_style: _,
234            span: _,
235        })
236        | Expr::Value(ValueWithSpan {
237            value: Value::SingleQuotedString(s),
238            ..
239        })
240        | Expr::Value(ValueWithSpan {
241            value: Value::DoubleQuotedString(s),
242            ..
243        }) => s
244            .split(',')
245            .map(|s| s.trim())
246            .try_fold((None, None), |(style, order), s| match try_parse_str(s)? {
247                ParsedDateStyle::Order(o) => Ok((style, merge_datestyle_value(order, Some(o))?)),
248                ParsedDateStyle::Style(s) => Ok((merge_datestyle_value(style, Some(s))?, order)),
249            }),
250        _ => NotSupportedSnafu {
251            feat: "Not supported expression for datestyle",
252        }
253        .fail(),
254    }
255}
256
257/// Set the allow query fallback configuration parameter to true or false based on the provided expressions.
258///
259pub fn set_allow_query_fallback(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
260    let allow_fallback_expr = exprs.first().context(NotSupportedSnafu {
261        feat: "No allow query fallback value find in set variable statement",
262    })?;
263    match allow_fallback_expr {
264        Expr::Value(ValueWithSpan {
265            value: Value::Boolean(allow),
266            span: _,
267        }) => {
268            ctx.configuration_parameter()
269                .set_allow_query_fallback(*allow);
270            Ok(())
271        }
272        expr => NotSupportedSnafu {
273            feat: format!(
274                "Unsupported allow query fallback expr {} in set variable statement",
275                expr
276            ),
277        }
278        .fail(),
279    }
280}
281
282pub fn set_datestyle(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
283    // ORDER,
284    // STYLE,
285    // ORDER,ORDER
286    // ORDER,STYLE
287    // STYLE,ORDER
288    let (style, order) = exprs
289        .iter()
290        .try_fold((None, None), |(style, order), expr| {
291            let (new_style, new_order) = try_parse_datestyle(expr)?;
292            Ok((
293                merge_datestyle_value(style, new_style)?,
294                merge_datestyle_value(order, new_order)?,
295            ))
296        })?;
297
298    let (old_style, older_order) = *ctx.configuration_parameter().pg_datetime_style();
299    ctx.configuration_parameter()
300        .set_pg_datetime_style(style.unwrap_or(old_style), order.unwrap_or(older_order));
301    Ok(())
302}
303
304pub fn set_query_timeout(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
305    let timeout_expr = exprs.first().context(NotSupportedSnafu {
306        feat: "No timeout value find in set query timeout statement",
307    })?;
308    match timeout_expr {
309        Expr::Value(ValueWithSpan {
310            value: Value::Number(timeout, _),
311            ..
312        }) => {
313            match timeout.parse::<u64>() {
314                Ok(timeout) => ctx.set_query_timeout(Duration::from_millis(timeout)),
315                Err(_) => {
316                    return NotSupportedSnafu {
317                        feat: format!("Invalid timeout expr {} in set variable statement", timeout),
318                    }
319                    .fail()
320                }
321            }
322            Ok(())
323        }
324        // postgres support time units i.e. SET STATEMENT_TIMEOUT = '50ms';
325        Expr::Value(ValueWithSpan {
326            value: Value::SingleQuotedString(timeout),
327            ..
328        })
329        | Expr::Value(ValueWithSpan {
330            value: Value::DoubleQuotedString(timeout),
331            ..
332        }) => {
333            if ctx.channel() != Postgres {
334                return NotSupportedSnafu {
335                    feat: format!("Invalid timeout expr {} in set variable statement", timeout),
336                }
337                .fail();
338            }
339            let timeout = parse_pg_query_timeout_input(timeout)?;
340            ctx.set_query_timeout(Duration::from_millis(timeout));
341            Ok(())
342        }
343        expr => NotSupportedSnafu {
344            feat: format!(
345                "Unsupported timeout expr {} in set variable statement",
346                expr
347            ),
348        }
349        .fail(),
350    }
351}
352
353// support time units in ms, s, min, h, d for postgres protocol.
354// https://www.postgresql.org/docs/8.4/config-setting.html#:~:text=Valid%20memory%20units%20are%20kB,%2C%20and%20d%20(days).
355fn parse_pg_query_timeout_input(input: &str) -> Result<u64> {
356    match input.parse::<u64>() {
357        Ok(timeout) => Ok(timeout),
358        Err(_) => {
359            if let Some(captures) = PG_TIME_INPUT_REGEX.captures(input) {
360                let value = captures[1].parse::<u64>().expect("regex failed");
361                let unit = &captures[2];
362
363                match unit {
364                    "ms" => Ok(value),
365                    "s" => Ok(value * 1000),
366                    "min" => Ok(value * 60 * 1000),
367                    "h" => Ok(value * 60 * 60 * 1000),
368                    "d" => Ok(value * 24 * 60 * 60 * 1000),
369                    _ => unreachable!("regex failed"),
370                }
371            } else {
372                NotSupportedSnafu {
373                    feat: format!(
374                        "Unsupported timeout expr {} in set variable statement",
375                        input
376                    ),
377                }
378                .fail()
379            }
380        }
381    }
382}
383
#[cfg(test)]
mod test {
    use crate::statement::set::parse_pg_query_timeout_input;

    #[test]
    fn test_parse_pg_query_timeout_input() {
        // Inputs that are neither a bare integer nor `<digits><unit>` with no spaces.
        let rejected = ["", " 50 ms", "5s 1ms", "3a", "1.5min", "ms", "a", "-1"];
        for input in rejected {
            assert!(
                parse_pg_query_timeout_input(input).is_err(),
                "{:?} should be rejected",
                input
            );
        }

        // (input, expected milliseconds)
        let accepted = [("50", 50), ("12ms", 12), ("2s", 2000), ("1min", 60000)];
        for (input, expected) in accepted {
            assert_eq!(expected, parse_pg_query_timeout_input(input).unwrap());
        }
    }
}