1use std::str::FromStr;
16use std::time::Duration;
17
18use common_time::Timezone;
19use lazy_static::lazy_static;
20use regex::Regex;
21use session::ReadPreference;
22use session::context::Channel::Postgres;
23use session::context::QueryContextRef;
24use session::session_config::{PGByteaOutputValue, PGDateOrder, PGDateTimeStyle, PGIntervalStyle};
25use snafu::{OptionExt, ResultExt, ensure};
26use sql::ast::{Expr, Ident, Value};
27use sql::statements::set_variables::SetVariables;
28use sqlparser::ast::ValueWithSpan;
29
30use crate::error::{InvalidConfigValueSnafu, InvalidSqlSnafu, NotSupportedSnafu, Result};
31
32lazy_static! {
33 static ref PG_TIME_INPUT_REGEX: Regex = Regex::new(r"^(\d+)(ms|s|min|h|d)$").unwrap();
39}
40
41pub fn set_read_preference(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
42 let read_preference_expr = exprs.first().context(NotSupportedSnafu {
43 feat: "No read preference find in set variable statement",
44 })?;
45
46 match read_preference_expr {
47 Expr::Value(ValueWithSpan {
48 value: Value::SingleQuotedString(expr),
49 ..
50 })
51 | Expr::Value(ValueWithSpan {
52 value: Value::DoubleQuotedString(expr),
53 ..
54 }) => {
55 match ReadPreference::from_str(expr.as_str().to_lowercase().as_str()) {
56 Ok(read_preference) => ctx.set_read_preference(read_preference),
57 Err(_) => {
58 return NotSupportedSnafu {
59 feat: format!(
60 "Invalid read preference expr {} in set variable statement",
61 expr,
62 ),
63 }
64 .fail();
65 }
66 }
67 Ok(())
68 }
69 expr => NotSupportedSnafu {
70 feat: format!(
71 "Unsupported read preference expr {} in set variable statement",
72 expr
73 ),
74 }
75 .fail(),
76 }
77}
78
79pub fn set_timezone(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
80 let tz_expr = exprs.first().context(NotSupportedSnafu {
81 feat: "No timezone find in set variable statement",
82 })?;
83 match tz_expr {
84 Expr::Value(ValueWithSpan {
85 value: Value::SingleQuotedString(tz),
86 ..
87 })
88 | Expr::Value(ValueWithSpan {
89 value: Value::DoubleQuotedString(tz),
90 ..
91 }) => {
92 match Timezone::from_tz_string(tz.as_str()) {
93 Ok(timezone) => ctx.set_timezone(timezone),
94 Err(_) => {
95 return NotSupportedSnafu {
96 feat: format!("Invalid timezone expr {} in set variable statement", tz),
97 }
98 .fail();
99 }
100 }
101 Ok(())
102 }
103 expr => NotSupportedSnafu {
104 feat: format!(
105 "Unsupported timezone expr {} in set variable statement",
106 expr
107 ),
108 }
109 .fail(),
110 }
111}
112
113pub fn set_bytea_output(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
114 let Some((var_value, [])) = exprs.split_first() else {
115 return (NotSupportedSnafu {
116 feat: "Set variable value must have one and only one value for bytea_output",
117 })
118 .fail();
119 };
120 let Expr::Value(value) = var_value else {
121 return (NotSupportedSnafu {
122 feat: "Set variable value must be a value",
123 })
124 .fail();
125 };
126 ctx.configuration_parameter().set_postgres_bytea_output(
127 PGByteaOutputValue::try_from(value.value.clone()).context(InvalidConfigValueSnafu)?,
128 );
129 Ok(())
130}
131
132pub fn set_search_path(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
133 let search_expr = exprs.first().context(NotSupportedSnafu {
134 feat: "No search path find in set variable statement",
135 })?;
136 match search_expr {
137 Expr::Value(ValueWithSpan {
138 value: Value::SingleQuotedString(search_path),
139 ..
140 })
141 | Expr::Value(ValueWithSpan {
142 value: Value::DoubleQuotedString(search_path),
143 ..
144 }) => {
145 ctx.set_current_schema(search_path);
146 Ok(())
147 }
148 Expr::Identifier(Ident { value, .. }) => {
149 ctx.set_current_schema(value);
150 Ok(())
151 }
152 expr => NotSupportedSnafu {
153 feat: format!(
154 "Unsupported search path expr {} in set variable statement",
155 expr
156 ),
157 }
158 .fail(),
159 }
160}
161
162pub fn validate_client_encoding(set: SetVariables) -> Result<()> {
163 let Some((encoding, [])) = set.value.split_first() else {
164 return InvalidSqlSnafu {
165 err_msg: "must provide one and only one client encoding value",
166 }
167 .fail();
168 };
169 let encoding = match encoding {
170 Expr::Value(ValueWithSpan {
171 value: Value::SingleQuotedString(x),
172 ..
173 })
174 | Expr::Identifier(Ident {
175 value: x,
176 quote_style: _,
177 span: _,
178 }) => x.to_uppercase(),
179 _ => {
180 return InvalidSqlSnafu {
181 err_msg: format!("client encoding must be a string, actual: {:?}", encoding),
182 }
183 .fail();
184 }
185 };
186 ensure!(
191 encoding == "UTF8" || encoding == "UNICODE",
192 NotSupportedSnafu {
193 feat: format!("client encoding of '{}'", encoding)
194 }
195 );
196 Ok(())
197}
198
199fn merge_datestyle_value<T>(value: Option<T>, new_value: Option<T>) -> Result<Option<T>>
203where
204 T: PartialEq,
205{
206 match (&value, &new_value) {
207 (None, _) => Ok(new_value),
208 (_, None) => Ok(value),
209 (Some(v1), Some(v2)) if v1 == v2 => Ok(new_value),
210 _ => InvalidSqlSnafu {
211 err_msg: "Conflicting \"datestyle\" specifications.",
212 }
213 .fail(),
214 }
215}
216
217fn try_parse_datestyle(expr: &Expr) -> Result<(Option<PGDateTimeStyle>, Option<PGDateOrder>)> {
218 enum ParsedDateStyle {
219 Order(PGDateOrder),
220 Style(PGDateTimeStyle),
221 }
222 fn try_parse_str(s: &str) -> Result<ParsedDateStyle> {
223 PGDateTimeStyle::try_from(s)
224 .map_or_else(
225 |_| PGDateOrder::try_from(s).map(ParsedDateStyle::Order),
226 |style| Ok(ParsedDateStyle::Style(style)),
227 )
228 .context(InvalidConfigValueSnafu)
229 }
230 match expr {
231 Expr::Identifier(Ident {
232 value: s,
233 quote_style: _,
234 span: _,
235 })
236 | Expr::Value(ValueWithSpan {
237 value: Value::SingleQuotedString(s),
238 ..
239 })
240 | Expr::Value(ValueWithSpan {
241 value: Value::DoubleQuotedString(s),
242 ..
243 }) => s
244 .split(',')
245 .map(|s| s.trim())
246 .try_fold((None, None), |(style, order), s| match try_parse_str(s)? {
247 ParsedDateStyle::Order(o) => Ok((style, merge_datestyle_value(order, Some(o))?)),
248 ParsedDateStyle::Style(s) => Ok((merge_datestyle_value(style, Some(s))?, order)),
249 }),
250 _ => NotSupportedSnafu {
251 feat: "Not supported expression for datestyle",
252 }
253 .fail(),
254 }
255}
256
257pub fn set_allow_query_fallback(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
260 let allow_fallback_expr = exprs.first().context(NotSupportedSnafu {
261 feat: "No allow query fallback value find in set variable statement",
262 })?;
263 match allow_fallback_expr {
264 Expr::Value(ValueWithSpan {
265 value: Value::Boolean(allow),
266 span: _,
267 }) => {
268 ctx.configuration_parameter()
269 .set_allow_query_fallback(*allow);
270 Ok(())
271 }
272 expr => NotSupportedSnafu {
273 feat: format!(
274 "Unsupported allow query fallback expr {} in set variable statement",
275 expr
276 ),
277 }
278 .fail(),
279 }
280}
281
282pub fn set_intervalstyle(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
283 let Some((var_value, [])) = exprs.split_first() else {
284 return NotSupportedSnafu {
285 feat: "Set variable value must have one and only one value for intervalstyle",
286 }
287 .fail();
288 };
289 let Expr::Value(value) = var_value else {
290 return NotSupportedSnafu {
291 feat: "Set variable value must be a value",
292 }
293 .fail();
294 };
295 ctx.configuration_parameter().set_pg_intervalstyle_format(
296 PGIntervalStyle::try_from(&value.value).context(InvalidConfigValueSnafu)?,
297 );
298 Ok(())
299}
300
301pub fn set_datestyle(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
302 let (style, order) = exprs
308 .iter()
309 .try_fold((None, None), |(style, order), expr| {
310 let (new_style, new_order) = try_parse_datestyle(expr)?;
311 Ok((
312 merge_datestyle_value(style, new_style)?,
313 merge_datestyle_value(order, new_order)?,
314 ))
315 })?;
316
317 let (old_style, older_order) = *ctx.configuration_parameter().pg_datetime_style();
318 ctx.configuration_parameter()
319 .set_pg_datetime_style(style.unwrap_or(old_style), order.unwrap_or(older_order));
320 Ok(())
321}
322
323pub fn set_query_timeout(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
324 let timeout_expr = exprs.first().context(NotSupportedSnafu {
325 feat: "No timeout value find in set query timeout statement",
326 })?;
327 match timeout_expr {
328 Expr::Value(ValueWithSpan {
329 value: Value::Number(timeout, _),
330 ..
331 }) => {
332 match timeout.parse::<u64>() {
333 Ok(timeout) => ctx.set_query_timeout(Duration::from_millis(timeout)),
334 Err(_) => {
335 return NotSupportedSnafu {
336 feat: format!("Invalid timeout expr {} in set variable statement", timeout),
337 }
338 .fail();
339 }
340 }
341 Ok(())
342 }
343 Expr::Value(ValueWithSpan {
345 value: Value::SingleQuotedString(timeout),
346 ..
347 })
348 | Expr::Value(ValueWithSpan {
349 value: Value::DoubleQuotedString(timeout),
350 ..
351 }) => {
352 if ctx.channel() != Postgres {
353 return NotSupportedSnafu {
354 feat: format!("Invalid timeout expr {} in set variable statement", timeout),
355 }
356 .fail();
357 }
358 let timeout = parse_pg_query_timeout_input(timeout)?;
359 ctx.set_query_timeout(Duration::from_millis(timeout));
360 Ok(())
361 }
362 expr => NotSupportedSnafu {
363 feat: format!(
364 "Unsupported timeout expr {} in set variable statement",
365 expr
366 ),
367 }
368 .fail(),
369 }
370}
371
372fn parse_pg_query_timeout_input(input: &str) -> Result<u64> {
375 match input.parse::<u64>() {
376 Ok(timeout) => Ok(timeout),
377 Err(_) => {
378 if let Some(captures) = PG_TIME_INPUT_REGEX.captures(input) {
379 let value = captures[1].parse::<u64>().expect("regex failed");
380 let unit = &captures[2];
381
382 match unit {
383 "ms" => Ok(value),
384 "s" => Ok(value * 1000),
385 "min" => Ok(value * 60 * 1000),
386 "h" => Ok(value * 60 * 60 * 1000),
387 "d" => Ok(value * 24 * 60 * 60 * 1000),
388 _ => unreachable!("regex failed"),
389 }
390 } else {
391 NotSupportedSnafu {
392 feat: format!(
393 "Unsupported timeout expr {} in set variable statement",
394 input
395 ),
396 }
397 .fail()
398 }
399 }
400 }
401}
402
403#[cfg(test)]
404mod test {
405 use crate::statement::set::parse_pg_query_timeout_input;
406
407 #[test]
408 fn test_parse_pg_query_timeout_input() {
409 assert!(parse_pg_query_timeout_input("").is_err());
410 assert!(parse_pg_query_timeout_input(" 50 ms").is_err());
411 assert!(parse_pg_query_timeout_input("5s 1ms").is_err());
412 assert!(parse_pg_query_timeout_input("3a").is_err());
413 assert!(parse_pg_query_timeout_input("1.5min").is_err());
414 assert!(parse_pg_query_timeout_input("ms").is_err());
415 assert!(parse_pg_query_timeout_input("a").is_err());
416 assert!(parse_pg_query_timeout_input("-1").is_err());
417
418 assert_eq!(50, parse_pg_query_timeout_input("50").unwrap());
419 assert_eq!(12, parse_pg_query_timeout_input("12ms").unwrap());
420 assert_eq!(2000, parse_pg_query_timeout_input("2s").unwrap());
421 assert_eq!(60000, parse_pg_query_timeout_input("1min").unwrap());
422 }
423}