1use std::str::FromStr;
16use std::time::Duration;
17
18use common_time::Timezone;
19use lazy_static::lazy_static;
20use regex::Regex;
21use session::context::Channel::Postgres;
22use session::context::QueryContextRef;
23use session::session_config::{PGByteaOutputValue, PGDateOrder, PGDateTimeStyle};
24use session::ReadPreference;
25use snafu::{ensure, OptionExt, ResultExt};
26use sql::ast::{Expr, Ident, Value};
27use sql::statements::set_variables::SetVariables;
28
29use crate::error::{InvalidConfigValueSnafu, InvalidSqlSnafu, NotSupportedSnafu, Result};
30
31lazy_static! {
32 static ref PG_TIME_INPUT_REGEX: Regex = Regex::new(r"^(\d+)(ms|s|min|h|d)$").unwrap();
38}
39
40pub fn set_read_preference(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
41 let read_preference_expr = exprs.first().context(NotSupportedSnafu {
42 feat: "No read preference find in set variable statement",
43 })?;
44
45 match read_preference_expr {
46 Expr::Value(Value::SingleQuotedString(expr))
47 | Expr::Value(Value::DoubleQuotedString(expr)) => {
48 match ReadPreference::from_str(expr.as_str().to_lowercase().as_str()) {
49 Ok(read_preference) => ctx.set_read_preference(read_preference),
50 Err(_) => {
51 return NotSupportedSnafu {
52 feat: format!(
53 "Invalid read preference expr {} in set variable statement",
54 expr,
55 ),
56 }
57 .fail()
58 }
59 }
60 Ok(())
61 }
62 expr => NotSupportedSnafu {
63 feat: format!(
64 "Unsupported read preference expr {} in set variable statement",
65 expr
66 ),
67 }
68 .fail(),
69 }
70}
71
72pub fn set_timezone(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
73 let tz_expr = exprs.first().context(NotSupportedSnafu {
74 feat: "No timezone find in set variable statement",
75 })?;
76 match tz_expr {
77 Expr::Value(Value::SingleQuotedString(tz)) | Expr::Value(Value::DoubleQuotedString(tz)) => {
78 match Timezone::from_tz_string(tz.as_str()) {
79 Ok(timezone) => ctx.set_timezone(timezone),
80 Err(_) => {
81 return NotSupportedSnafu {
82 feat: format!("Invalid timezone expr {} in set variable statement", tz),
83 }
84 .fail()
85 }
86 }
87 Ok(())
88 }
89 expr => NotSupportedSnafu {
90 feat: format!(
91 "Unsupported timezone expr {} in set variable statement",
92 expr
93 ),
94 }
95 .fail(),
96 }
97}
98
99pub fn set_bytea_output(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
100 let Some((var_value, [])) = exprs.split_first() else {
101 return (NotSupportedSnafu {
102 feat: "Set variable value must have one and only one value for bytea_output",
103 })
104 .fail();
105 };
106 let Expr::Value(value) = var_value else {
107 return (NotSupportedSnafu {
108 feat: "Set variable value must be a value",
109 })
110 .fail();
111 };
112 ctx.configuration_parameter().set_postgres_bytea_output(
113 PGByteaOutputValue::try_from(value.clone()).context(InvalidConfigValueSnafu)?,
114 );
115 Ok(())
116}
117
118pub fn set_search_path(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
119 let search_expr = exprs.first().context(NotSupportedSnafu {
120 feat: "No search path find in set variable statement",
121 })?;
122 match search_expr {
123 Expr::Value(Value::SingleQuotedString(search_path))
124 | Expr::Value(Value::DoubleQuotedString(search_path)) => {
125 ctx.set_current_schema(search_path);
126 Ok(())
127 }
128 Expr::Identifier(Ident { value, .. }) => {
129 ctx.set_current_schema(value);
130 Ok(())
131 }
132 expr => NotSupportedSnafu {
133 feat: format!(
134 "Unsupported search path expr {} in set variable statement",
135 expr
136 ),
137 }
138 .fail(),
139 }
140}
141
142pub fn validate_client_encoding(set: SetVariables) -> Result<()> {
143 let Some((encoding, [])) = set.value.split_first() else {
144 return InvalidSqlSnafu {
145 err_msg: "must provide one and only one client encoding value",
146 }
147 .fail();
148 };
149 let encoding = match encoding {
150 Expr::Value(Value::SingleQuotedString(x))
151 | Expr::Identifier(Ident {
152 value: x,
153 quote_style: _,
154 span: _,
155 }) => x.to_uppercase(),
156 _ => {
157 return InvalidSqlSnafu {
158 err_msg: format!("client encoding must be a string, actual: {:?}", encoding),
159 }
160 .fail();
161 }
162 };
163 ensure!(
168 encoding == "UTF8" || encoding == "UNICODE",
169 NotSupportedSnafu {
170 feat: format!("client encoding of '{}'", encoding)
171 }
172 );
173 Ok(())
174}
175
176fn merge_datestyle_value<T>(value: Option<T>, new_value: Option<T>) -> Result<Option<T>>
180where
181 T: PartialEq,
182{
183 match (&value, &new_value) {
184 (None, _) => Ok(new_value),
185 (_, None) => Ok(value),
186 (Some(v1), Some(v2)) if v1 == v2 => Ok(new_value),
187 _ => InvalidSqlSnafu {
188 err_msg: "Conflicting \"datestyle\" specifications.",
189 }
190 .fail(),
191 }
192}
193
194fn try_parse_datestyle(expr: &Expr) -> Result<(Option<PGDateTimeStyle>, Option<PGDateOrder>)> {
195 enum ParsedDateStyle {
196 Order(PGDateOrder),
197 Style(PGDateTimeStyle),
198 }
199 fn try_parse_str(s: &str) -> Result<ParsedDateStyle> {
200 PGDateTimeStyle::try_from(s)
201 .map_or_else(
202 |_| PGDateOrder::try_from(s).map(ParsedDateStyle::Order),
203 |style| Ok(ParsedDateStyle::Style(style)),
204 )
205 .context(InvalidConfigValueSnafu)
206 }
207 match expr {
208 Expr::Identifier(Ident {
209 value: s,
210 quote_style: _,
211 span: _,
212 })
213 | Expr::Value(Value::SingleQuotedString(s))
214 | Expr::Value(Value::DoubleQuotedString(s)) => {
215 s.split(',')
216 .map(|s| s.trim())
217 .try_fold((None, None), |(style, order), s| match try_parse_str(s)? {
218 ParsedDateStyle::Order(o) => {
219 Ok((style, merge_datestyle_value(order, Some(o))?))
220 }
221 ParsedDateStyle::Style(s) => {
222 Ok((merge_datestyle_value(style, Some(s))?, order))
223 }
224 })
225 }
226 _ => NotSupportedSnafu {
227 feat: "Not supported expression for datestyle",
228 }
229 .fail(),
230 }
231}
232
233pub fn set_datestyle(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
234 let (style, order) = exprs
240 .iter()
241 .try_fold((None, None), |(style, order), expr| {
242 let (new_style, new_order) = try_parse_datestyle(expr)?;
243 Ok((
244 merge_datestyle_value(style, new_style)?,
245 merge_datestyle_value(order, new_order)?,
246 ))
247 })?;
248
249 let (old_style, older_order) = *ctx.configuration_parameter().pg_datetime_style();
250 ctx.configuration_parameter()
251 .set_pg_datetime_style(style.unwrap_or(old_style), order.unwrap_or(older_order));
252 Ok(())
253}
254
255pub fn set_query_timeout(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
256 let timeout_expr = exprs.first().context(NotSupportedSnafu {
257 feat: "No timeout value find in set query timeout statement",
258 })?;
259 match timeout_expr {
260 Expr::Value(Value::Number(timeout, _)) => {
261 match timeout.parse::<u64>() {
262 Ok(timeout) => ctx.set_query_timeout(Duration::from_millis(timeout)),
263 Err(_) => {
264 return NotSupportedSnafu {
265 feat: format!("Invalid timeout expr {} in set variable statement", timeout),
266 }
267 .fail()
268 }
269 }
270 Ok(())
271 }
272 Expr::Value(Value::SingleQuotedString(timeout))
274 | Expr::Value(Value::DoubleQuotedString(timeout)) => {
275 if ctx.channel() != Postgres {
276 return NotSupportedSnafu {
277 feat: format!("Invalid timeout expr {} in set variable statement", timeout),
278 }
279 .fail();
280 }
281 let timeout = parse_pg_query_timeout_input(timeout)?;
282 ctx.set_query_timeout(Duration::from_millis(timeout));
283 Ok(())
284 }
285 expr => NotSupportedSnafu {
286 feat: format!(
287 "Unsupported timeout expr {} in set variable statement",
288 expr
289 ),
290 }
291 .fail(),
292 }
293}
294
295fn parse_pg_query_timeout_input(input: &str) -> Result<u64> {
298 match input.parse::<u64>() {
299 Ok(timeout) => Ok(timeout),
300 Err(_) => {
301 if let Some(captures) = PG_TIME_INPUT_REGEX.captures(input) {
302 let value = captures[1].parse::<u64>().expect("regex failed");
303 let unit = &captures[2];
304
305 match unit {
306 "ms" => Ok(value),
307 "s" => Ok(value * 1000),
308 "min" => Ok(value * 60 * 1000),
309 "h" => Ok(value * 60 * 60 * 1000),
310 "d" => Ok(value * 24 * 60 * 60 * 1000),
311 _ => unreachable!("regex failed"),
312 }
313 } else {
314 NotSupportedSnafu {
315 feat: format!(
316 "Unsupported timeout expr {} in set variable statement",
317 input
318 ),
319 }
320 .fail()
321 }
322 }
323 }
324}
325
326#[cfg(test)]
327mod test {
328 use crate::statement::set::parse_pg_query_timeout_input;
329
330 #[test]
331 fn test_parse_pg_query_timeout_input() {
332 assert!(parse_pg_query_timeout_input("").is_err());
333 assert!(parse_pg_query_timeout_input(" 50 ms").is_err());
334 assert!(parse_pg_query_timeout_input("5s 1ms").is_err());
335 assert!(parse_pg_query_timeout_input("3a").is_err());
336 assert!(parse_pg_query_timeout_input("1.5min").is_err());
337 assert!(parse_pg_query_timeout_input("ms").is_err());
338 assert!(parse_pg_query_timeout_input("a").is_err());
339 assert!(parse_pg_query_timeout_input("-1").is_err());
340
341 assert_eq!(50, parse_pg_query_timeout_input("50").unwrap());
342 assert_eq!(12, parse_pg_query_timeout_input("12ms").unwrap());
343 assert_eq!(2000, parse_pg_query_timeout_input("2s").unwrap());
344 assert_eq!(60000, parse_pg_query_timeout_input("1min").unwrap());
345 }
346}