1use std::str::FromStr;
16use std::time::Duration;
17
18use common_time::Timezone;
19use lazy_static::lazy_static;
20use regex::Regex;
21use session::context::Channel::Postgres;
22use session::context::QueryContextRef;
23use session::session_config::{PGByteaOutputValue, PGDateOrder, PGDateTimeStyle};
24use session::ReadPreference;
25use snafu::{ensure, OptionExt, ResultExt};
26use sql::ast::{Expr, Ident, Value};
27use sql::statements::set_variables::SetVariables;
28
29use crate::error::{InvalidConfigValueSnafu, InvalidSqlSnafu, NotSupportedSnafu, Result};
30
lazy_static! {
    /// Matches a PostgreSQL-style duration such as `50ms`, `2s`, `1min`, `3h` or `1d`.
    ///
    /// The whole input must be an integer (capture group 1) immediately followed
    /// by a unit (capture group 2); whitespace, signs and fractions are rejected by
    /// the `^`/`$` anchors.
    ///
    /// NOTE: `\d` is Unicode-aware in the `regex` crate, so group 1 may contain
    /// non-ASCII decimal digits (or more digits than fit in a `u64`). Callers must
    /// not assume group 1 always parses as a `u64`.
    static ref PG_TIME_INPUT_REGEX: Regex = Regex::new(r"^(\d+)(ms|s|min|h|d)$").unwrap();
}
39
40pub fn set_read_preference(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
41 let read_preference_expr = exprs.first().context(NotSupportedSnafu {
42 feat: "No read preference find in set variable statement",
43 })?;
44
45 match read_preference_expr {
46 Expr::Value(Value::SingleQuotedString(expr))
47 | Expr::Value(Value::DoubleQuotedString(expr)) => {
48 match ReadPreference::from_str(expr.as_str().to_lowercase().as_str()) {
49 Ok(read_preference) => ctx.set_read_preference(read_preference),
50 Err(_) => {
51 return NotSupportedSnafu {
52 feat: format!(
53 "Invalid read preference expr {} in set variable statement",
54 expr,
55 ),
56 }
57 .fail()
58 }
59 }
60 Ok(())
61 }
62 expr => NotSupportedSnafu {
63 feat: format!(
64 "Unsupported read preference expr {} in set variable statement",
65 expr
66 ),
67 }
68 .fail(),
69 }
70}
71
72pub fn set_timezone(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
73 let tz_expr = exprs.first().context(NotSupportedSnafu {
74 feat: "No timezone find in set variable statement",
75 })?;
76 match tz_expr {
77 Expr::Value(Value::SingleQuotedString(tz)) | Expr::Value(Value::DoubleQuotedString(tz)) => {
78 match Timezone::from_tz_string(tz.as_str()) {
79 Ok(timezone) => ctx.set_timezone(timezone),
80 Err(_) => {
81 return NotSupportedSnafu {
82 feat: format!("Invalid timezone expr {} in set variable statement", tz),
83 }
84 .fail()
85 }
86 }
87 Ok(())
88 }
89 expr => NotSupportedSnafu {
90 feat: format!(
91 "Unsupported timezone expr {} in set variable statement",
92 expr
93 ),
94 }
95 .fail(),
96 }
97}
98
99pub fn set_bytea_output(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
100 let Some((var_value, [])) = exprs.split_first() else {
101 return (NotSupportedSnafu {
102 feat: "Set variable value must have one and only one value for bytea_output",
103 })
104 .fail();
105 };
106 let Expr::Value(value) = var_value else {
107 return (NotSupportedSnafu {
108 feat: "Set variable value must be a value",
109 })
110 .fail();
111 };
112 ctx.configuration_parameter().set_postgres_bytea_output(
113 PGByteaOutputValue::try_from(value.clone()).context(InvalidConfigValueSnafu)?,
114 );
115 Ok(())
116}
117
118pub fn set_search_path(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
119 let search_expr = exprs.first().context(NotSupportedSnafu {
120 feat: "No search path find in set variable statement",
121 })?;
122 match search_expr {
123 Expr::Value(Value::SingleQuotedString(search_path))
124 | Expr::Value(Value::DoubleQuotedString(search_path)) => {
125 ctx.set_current_schema(search_path);
126 Ok(())
127 }
128 Expr::Identifier(Ident { value, .. }) => {
129 ctx.set_current_schema(value);
130 Ok(())
131 }
132 expr => NotSupportedSnafu {
133 feat: format!(
134 "Unsupported search path expr {} in set variable statement",
135 expr
136 ),
137 }
138 .fail(),
139 }
140}
141
142pub fn validate_client_encoding(set: SetVariables) -> Result<()> {
143 let Some((encoding, [])) = set.value.split_first() else {
144 return InvalidSqlSnafu {
145 err_msg: "must provide one and only one client encoding value",
146 }
147 .fail();
148 };
149 let encoding = match encoding {
150 Expr::Value(Value::SingleQuotedString(x))
151 | Expr::Identifier(Ident {
152 value: x,
153 quote_style: _,
154 span: _,
155 }) => x.to_uppercase(),
156 _ => {
157 return InvalidSqlSnafu {
158 err_msg: format!("client encoding must be a string, actual: {:?}", encoding),
159 }
160 .fail();
161 }
162 };
163 ensure!(
168 encoding == "UTF8" || encoding == "UNICODE",
169 NotSupportedSnafu {
170 feat: format!("client encoding of '{}'", encoding)
171 }
172 );
173 Ok(())
174}
175
176fn merge_datestyle_value<T>(value: Option<T>, new_value: Option<T>) -> Result<Option<T>>
180where
181 T: PartialEq,
182{
183 match (&value, &new_value) {
184 (None, _) => Ok(new_value),
185 (_, None) => Ok(value),
186 (Some(v1), Some(v2)) if v1 == v2 => Ok(new_value),
187 _ => InvalidSqlSnafu {
188 err_msg: "Conflicting \"datestyle\" specifications.",
189 }
190 .fail(),
191 }
192}
193
194fn try_parse_datestyle(expr: &Expr) -> Result<(Option<PGDateTimeStyle>, Option<PGDateOrder>)> {
195 enum ParsedDateStyle {
196 Order(PGDateOrder),
197 Style(PGDateTimeStyle),
198 }
199 fn try_parse_str(s: &str) -> Result<ParsedDateStyle> {
200 PGDateTimeStyle::try_from(s)
201 .map_or_else(
202 |_| PGDateOrder::try_from(s).map(ParsedDateStyle::Order),
203 |style| Ok(ParsedDateStyle::Style(style)),
204 )
205 .context(InvalidConfigValueSnafu)
206 }
207 match expr {
208 Expr::Identifier(Ident {
209 value: s,
210 quote_style: _,
211 span: _,
212 })
213 | Expr::Value(Value::SingleQuotedString(s))
214 | Expr::Value(Value::DoubleQuotedString(s)) => {
215 s.split(',')
216 .map(|s| s.trim())
217 .try_fold((None, None), |(style, order), s| match try_parse_str(s)? {
218 ParsedDateStyle::Order(o) => {
219 Ok((style, merge_datestyle_value(order, Some(o))?))
220 }
221 ParsedDateStyle::Style(s) => {
222 Ok((merge_datestyle_value(style, Some(s))?, order))
223 }
224 })
225 }
226 _ => NotSupportedSnafu {
227 feat: "Not supported expression for datestyle",
228 }
229 .fail(),
230 }
231}
232
233pub fn set_allow_query_fallback(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
236 let allow_fallback_expr = exprs.first().context(NotSupportedSnafu {
237 feat: "No allow query fallback value find in set variable statement",
238 })?;
239 match allow_fallback_expr {
240 Expr::Value(Value::Boolean(allow)) => {
241 ctx.configuration_parameter()
242 .set_allow_query_fallback(*allow);
243 Ok(())
244 }
245 expr => NotSupportedSnafu {
246 feat: format!(
247 "Unsupported allow query fallback expr {} in set variable statement",
248 expr
249 ),
250 }
251 .fail(),
252 }
253}
254
255pub fn set_datestyle(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
256 let (style, order) = exprs
262 .iter()
263 .try_fold((None, None), |(style, order), expr| {
264 let (new_style, new_order) = try_parse_datestyle(expr)?;
265 Ok((
266 merge_datestyle_value(style, new_style)?,
267 merge_datestyle_value(order, new_order)?,
268 ))
269 })?;
270
271 let (old_style, older_order) = *ctx.configuration_parameter().pg_datetime_style();
272 ctx.configuration_parameter()
273 .set_pg_datetime_style(style.unwrap_or(old_style), order.unwrap_or(older_order));
274 Ok(())
275}
276
277pub fn set_query_timeout(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
278 let timeout_expr = exprs.first().context(NotSupportedSnafu {
279 feat: "No timeout value find in set query timeout statement",
280 })?;
281 match timeout_expr {
282 Expr::Value(Value::Number(timeout, _)) => {
283 match timeout.parse::<u64>() {
284 Ok(timeout) => ctx.set_query_timeout(Duration::from_millis(timeout)),
285 Err(_) => {
286 return NotSupportedSnafu {
287 feat: format!("Invalid timeout expr {} in set variable statement", timeout),
288 }
289 .fail()
290 }
291 }
292 Ok(())
293 }
294 Expr::Value(Value::SingleQuotedString(timeout))
296 | Expr::Value(Value::DoubleQuotedString(timeout)) => {
297 if ctx.channel() != Postgres {
298 return NotSupportedSnafu {
299 feat: format!("Invalid timeout expr {} in set variable statement", timeout),
300 }
301 .fail();
302 }
303 let timeout = parse_pg_query_timeout_input(timeout)?;
304 ctx.set_query_timeout(Duration::from_millis(timeout));
305 Ok(())
306 }
307 expr => NotSupportedSnafu {
308 feat: format!(
309 "Unsupported timeout expr {} in set variable statement",
310 expr
311 ),
312 }
313 .fail(),
314 }
315}
316
317fn parse_pg_query_timeout_input(input: &str) -> Result<u64> {
320 match input.parse::<u64>() {
321 Ok(timeout) => Ok(timeout),
322 Err(_) => {
323 if let Some(captures) = PG_TIME_INPUT_REGEX.captures(input) {
324 let value = captures[1].parse::<u64>().expect("regex failed");
325 let unit = &captures[2];
326
327 match unit {
328 "ms" => Ok(value),
329 "s" => Ok(value * 1000),
330 "min" => Ok(value * 60 * 1000),
331 "h" => Ok(value * 60 * 60 * 1000),
332 "d" => Ok(value * 24 * 60 * 60 * 1000),
333 _ => unreachable!("regex failed"),
334 }
335 } else {
336 NotSupportedSnafu {
337 feat: format!(
338 "Unsupported timeout expr {} in set variable statement",
339 input
340 ),
341 }
342 .fail()
343 }
344 }
345 }
346}
347
348#[cfg(test)]
349mod test {
350 use crate::statement::set::parse_pg_query_timeout_input;
351
352 #[test]
353 fn test_parse_pg_query_timeout_input() {
354 assert!(parse_pg_query_timeout_input("").is_err());
355 assert!(parse_pg_query_timeout_input(" 50 ms").is_err());
356 assert!(parse_pg_query_timeout_input("5s 1ms").is_err());
357 assert!(parse_pg_query_timeout_input("3a").is_err());
358 assert!(parse_pg_query_timeout_input("1.5min").is_err());
359 assert!(parse_pg_query_timeout_input("ms").is_err());
360 assert!(parse_pg_query_timeout_input("a").is_err());
361 assert!(parse_pg_query_timeout_input("-1").is_err());
362
363 assert_eq!(50, parse_pg_query_timeout_input("50").unwrap());
364 assert_eq!(12, parse_pg_query_timeout_input("12ms").unwrap());
365 assert_eq!(2000, parse_pg_query_timeout_input("2s").unwrap());
366 assert_eq!(60000, parse_pg_query_timeout_input("1min").unwrap());
367 }
368}