operator/statement/
set.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::str::FromStr;
16use std::time::Duration;
17
18use common_time::Timezone;
19use lazy_static::lazy_static;
20use regex::Regex;
21use session::ReadPreference;
22use session::context::Channel::Postgres;
23use session::context::QueryContextRef;
24use session::session_config::{PGByteaOutputValue, PGDateOrder, PGDateTimeStyle, PGIntervalStyle};
25use snafu::{OptionExt, ResultExt, ensure};
26use sql::ast::{Expr, Ident, Value};
27use sql::statements::set_variables::SetVariables;
28use sqlparser::ast::ValueWithSpan;
29
30use crate::error::{InvalidConfigValueSnafu, InvalidSqlSnafu, NotSupportedSnafu, Result};
31
32lazy_static! {
33    // Regex rules:
34    // The string must start with a number (one or more digits).
35    // The number must be followed by one of the valid time units (ms, s, min, h, d).
36    // The string must end immediately after the unit, meaning there can be no extra
37    // characters or spaces after the valid time specification.
38    static ref PG_TIME_INPUT_REGEX: Regex = Regex::new(r"^(\d+)(ms|s|min|h|d)$").unwrap();
39}
40
41pub fn set_read_preference(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
42    let read_preference_expr = exprs.first().context(NotSupportedSnafu {
43        feat: "No read preference find in set variable statement",
44    })?;
45
46    match read_preference_expr {
47        Expr::Value(ValueWithSpan {
48            value: Value::SingleQuotedString(expr),
49            ..
50        })
51        | Expr::Value(ValueWithSpan {
52            value: Value::DoubleQuotedString(expr),
53            ..
54        }) => {
55            match ReadPreference::from_str(expr.as_str().to_lowercase().as_str()) {
56                Ok(read_preference) => ctx.set_read_preference(read_preference),
57                Err(_) => {
58                    return NotSupportedSnafu {
59                        feat: format!(
60                            "Invalid read preference expr {} in set variable statement",
61                            expr,
62                        ),
63                    }
64                    .fail();
65                }
66            }
67            Ok(())
68        }
69        expr => NotSupportedSnafu {
70            feat: format!(
71                "Unsupported read preference expr {} in set variable statement",
72                expr
73            ),
74        }
75        .fail(),
76    }
77}
78
79pub fn set_timezone(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
80    let tz_expr = exprs.first().context(NotSupportedSnafu {
81        feat: "No timezone find in set variable statement",
82    })?;
83    match tz_expr {
84        Expr::Value(ValueWithSpan {
85            value: Value::SingleQuotedString(tz),
86            ..
87        })
88        | Expr::Value(ValueWithSpan {
89            value: Value::DoubleQuotedString(tz),
90            ..
91        }) => {
92            match Timezone::from_tz_string(tz.as_str()) {
93                Ok(timezone) => ctx.set_timezone(timezone),
94                Err(_) => {
95                    return NotSupportedSnafu {
96                        feat: format!("Invalid timezone expr {} in set variable statement", tz),
97                    }
98                    .fail();
99                }
100            }
101            Ok(())
102        }
103        expr => NotSupportedSnafu {
104            feat: format!(
105                "Unsupported timezone expr {} in set variable statement",
106                expr
107            ),
108        }
109        .fail(),
110    }
111}
112
113pub fn set_bytea_output(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
114    let Some((var_value, [])) = exprs.split_first() else {
115        return (NotSupportedSnafu {
116            feat: "Set variable value must have one and only one value for bytea_output",
117        })
118        .fail();
119    };
120    let Expr::Value(value) = var_value else {
121        return (NotSupportedSnafu {
122            feat: "Set variable value must be a value",
123        })
124        .fail();
125    };
126    ctx.configuration_parameter().set_postgres_bytea_output(
127        PGByteaOutputValue::try_from(value.value.clone()).context(InvalidConfigValueSnafu)?,
128    );
129    Ok(())
130}
131
132pub fn set_search_path(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
133    let search_expr = exprs.first().context(NotSupportedSnafu {
134        feat: "No search path find in set variable statement",
135    })?;
136    match search_expr {
137        Expr::Value(ValueWithSpan {
138            value: Value::SingleQuotedString(search_path),
139            ..
140        })
141        | Expr::Value(ValueWithSpan {
142            value: Value::DoubleQuotedString(search_path),
143            ..
144        }) => {
145            ctx.set_current_schema(search_path);
146            Ok(())
147        }
148        Expr::Identifier(Ident { value, .. }) => {
149            ctx.set_current_schema(value);
150            Ok(())
151        }
152        expr => NotSupportedSnafu {
153            feat: format!(
154                "Unsupported search path expr {} in set variable statement",
155                expr
156            ),
157        }
158        .fail(),
159    }
160}
161
162pub fn validate_client_encoding(set: SetVariables) -> Result<()> {
163    let Some((encoding, [])) = set.value.split_first() else {
164        return InvalidSqlSnafu {
165            err_msg: "must provide one and only one client encoding value",
166        }
167        .fail();
168    };
169    let encoding = match encoding {
170        Expr::Value(ValueWithSpan {
171            value: Value::SingleQuotedString(x),
172            ..
173        })
174        | Expr::Identifier(Ident {
175            value: x,
176            quote_style: _,
177            span: _,
178        }) => x.to_uppercase(),
179        _ => {
180            return InvalidSqlSnafu {
181                err_msg: format!("client encoding must be a string, actual: {:?}", encoding),
182            }
183            .fail();
184        }
185    };
186    // For the sake of simplicity, we only support "UTF8" ("UNICODE" is the alias for it,
187    // see https://www.postgresql.org/docs/current/multibyte.html#MULTIBYTE-CHARSET-SUPPORTED).
188    // "UTF8" is universal and sufficient for almost all cases.
189    // GreptimeDB itself is always using "UTF8" as the internal encoding.
190    ensure!(
191        encoding == "UTF8" || encoding == "UNICODE",
192        NotSupportedSnafu {
193            feat: format!("client encoding of '{}'", encoding)
194        }
195    );
196    Ok(())
197}
198
199// if one of original value and new value is none, return the other one
200// returns new values only when it equals to original one else return error.
201// This is only used for handling datestyle
202fn merge_datestyle_value<T>(value: Option<T>, new_value: Option<T>) -> Result<Option<T>>
203where
204    T: PartialEq,
205{
206    match (&value, &new_value) {
207        (None, _) => Ok(new_value),
208        (_, None) => Ok(value),
209        (Some(v1), Some(v2)) if v1 == v2 => Ok(new_value),
210        _ => InvalidSqlSnafu {
211            err_msg: "Conflicting \"datestyle\" specifications.",
212        }
213        .fail(),
214    }
215}
216
217fn try_parse_datestyle(expr: &Expr) -> Result<(Option<PGDateTimeStyle>, Option<PGDateOrder>)> {
218    enum ParsedDateStyle {
219        Order(PGDateOrder),
220        Style(PGDateTimeStyle),
221    }
222    fn try_parse_str(s: &str) -> Result<ParsedDateStyle> {
223        PGDateTimeStyle::try_from(s)
224            .map_or_else(
225                |_| PGDateOrder::try_from(s).map(ParsedDateStyle::Order),
226                |style| Ok(ParsedDateStyle::Style(style)),
227            )
228            .context(InvalidConfigValueSnafu)
229    }
230    match expr {
231        Expr::Identifier(Ident {
232            value: s,
233            quote_style: _,
234            span: _,
235        })
236        | Expr::Value(ValueWithSpan {
237            value: Value::SingleQuotedString(s),
238            ..
239        })
240        | Expr::Value(ValueWithSpan {
241            value: Value::DoubleQuotedString(s),
242            ..
243        }) => s
244            .split(',')
245            .map(|s| s.trim())
246            .try_fold((None, None), |(style, order), s| match try_parse_str(s)? {
247                ParsedDateStyle::Order(o) => Ok((style, merge_datestyle_value(order, Some(o))?)),
248                ParsedDateStyle::Style(s) => Ok((merge_datestyle_value(style, Some(s))?, order)),
249            }),
250        _ => NotSupportedSnafu {
251            feat: "Not supported expression for datestyle",
252        }
253        .fail(),
254    }
255}
256
257/// Set the allow query fallback configuration parameter to true or false based on the provided expressions.
258///
259pub fn set_allow_query_fallback(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
260    let allow_fallback_expr = exprs.first().context(NotSupportedSnafu {
261        feat: "No allow query fallback value find in set variable statement",
262    })?;
263    match allow_fallback_expr {
264        Expr::Value(ValueWithSpan {
265            value: Value::Boolean(allow),
266            span: _,
267        }) => {
268            ctx.configuration_parameter()
269                .set_allow_query_fallback(*allow);
270            Ok(())
271        }
272        expr => NotSupportedSnafu {
273            feat: format!(
274                "Unsupported allow query fallback expr {} in set variable statement",
275                expr
276            ),
277        }
278        .fail(),
279    }
280}
281
282pub fn set_intervalstyle(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
283    let Some((var_value, [])) = exprs.split_first() else {
284        return NotSupportedSnafu {
285            feat: "Set variable value must have one and only one value for intervalstyle",
286        }
287        .fail();
288    };
289    let Expr::Value(value) = var_value else {
290        return NotSupportedSnafu {
291            feat: "Set variable value must be a value",
292        }
293        .fail();
294    };
295    ctx.configuration_parameter().set_pg_intervalstyle_format(
296        PGIntervalStyle::try_from(&value.value).context(InvalidConfigValueSnafu)?,
297    );
298    Ok(())
299}
300
301pub fn set_datestyle(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
302    // ORDER,
303    // STYLE,
304    // ORDER,ORDER
305    // ORDER,STYLE
306    // STYLE,ORDER
307    let (style, order) = exprs
308        .iter()
309        .try_fold((None, None), |(style, order), expr| {
310            let (new_style, new_order) = try_parse_datestyle(expr)?;
311            Ok((
312                merge_datestyle_value(style, new_style)?,
313                merge_datestyle_value(order, new_order)?,
314            ))
315        })?;
316
317    let (old_style, older_order) = *ctx.configuration_parameter().pg_datetime_style();
318    ctx.configuration_parameter()
319        .set_pg_datetime_style(style.unwrap_or(old_style), order.unwrap_or(older_order));
320    Ok(())
321}
322
323pub fn set_query_timeout(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
324    let timeout_expr = exprs.first().context(NotSupportedSnafu {
325        feat: "No timeout value find in set query timeout statement",
326    })?;
327    match timeout_expr {
328        Expr::Value(ValueWithSpan {
329            value: Value::Number(timeout, _),
330            ..
331        }) => {
332            match timeout.parse::<u64>() {
333                Ok(timeout) => ctx.set_query_timeout(Duration::from_millis(timeout)),
334                Err(_) => {
335                    return NotSupportedSnafu {
336                        feat: format!("Invalid timeout expr {} in set variable statement", timeout),
337                    }
338                    .fail();
339                }
340            }
341            Ok(())
342        }
343        // postgres support time units i.e. SET STATEMENT_TIMEOUT = '50ms';
344        Expr::Value(ValueWithSpan {
345            value: Value::SingleQuotedString(timeout),
346            ..
347        })
348        | Expr::Value(ValueWithSpan {
349            value: Value::DoubleQuotedString(timeout),
350            ..
351        }) => {
352            if ctx.channel() != Postgres {
353                return NotSupportedSnafu {
354                    feat: format!("Invalid timeout expr {} in set variable statement", timeout),
355                }
356                .fail();
357            }
358            let timeout = parse_pg_query_timeout_input(timeout)?;
359            ctx.set_query_timeout(Duration::from_millis(timeout));
360            Ok(())
361        }
362        expr => NotSupportedSnafu {
363            feat: format!(
364                "Unsupported timeout expr {} in set variable statement",
365                expr
366            ),
367        }
368        .fail(),
369    }
370}
371
372// support time units in ms, s, min, h, d for postgres protocol.
373// https://www.postgresql.org/docs/8.4/config-setting.html#:~:text=Valid%20memory%20units%20are%20kB,%2C%20and%20d%20(days).
374fn parse_pg_query_timeout_input(input: &str) -> Result<u64> {
375    match input.parse::<u64>() {
376        Ok(timeout) => Ok(timeout),
377        Err(_) => {
378            if let Some(captures) = PG_TIME_INPUT_REGEX.captures(input) {
379                let value = captures[1].parse::<u64>().expect("regex failed");
380                let unit = &captures[2];
381
382                match unit {
383                    "ms" => Ok(value),
384                    "s" => Ok(value * 1000),
385                    "min" => Ok(value * 60 * 1000),
386                    "h" => Ok(value * 60 * 60 * 1000),
387                    "d" => Ok(value * 24 * 60 * 60 * 1000),
388                    _ => unreachable!("regex failed"),
389                }
390            } else {
391                NotSupportedSnafu {
392                    feat: format!(
393                        "Unsupported timeout expr {} in set variable statement",
394                        input
395                    ),
396                }
397                .fail()
398            }
399        }
400    }
401}
402
403#[cfg(test)]
404mod test {
405    use crate::statement::set::parse_pg_query_timeout_input;
406
407    #[test]
408    fn test_parse_pg_query_timeout_input() {
409        assert!(parse_pg_query_timeout_input("").is_err());
410        assert!(parse_pg_query_timeout_input(" 50 ms").is_err());
411        assert!(parse_pg_query_timeout_input("5s 1ms").is_err());
412        assert!(parse_pg_query_timeout_input("3a").is_err());
413        assert!(parse_pg_query_timeout_input("1.5min").is_err());
414        assert!(parse_pg_query_timeout_input("ms").is_err());
415        assert!(parse_pg_query_timeout_input("a").is_err());
416        assert!(parse_pg_query_timeout_input("-1").is_err());
417
418        assert_eq!(50, parse_pg_query_timeout_input("50").unwrap());
419        assert_eq!(12, parse_pg_query_timeout_input("12ms").unwrap());
420        assert_eq!(2000, parse_pg_query_timeout_input("2s").unwrap());
421        assert_eq!(60000, parse_pg_query_timeout_input("1min").unwrap());
422    }
423}