1use std::str::FromStr;
16use std::time::Duration;
17
18use common_time::Timezone;
19use lazy_static::lazy_static;
20use regex::Regex;
21use session::context::Channel::Postgres;
22use session::context::QueryContextRef;
23use session::session_config::{PGByteaOutputValue, PGDateOrder, PGDateTimeStyle};
24use session::ReadPreference;
25use snafu::{ensure, OptionExt, ResultExt};
26use sql::ast::{Expr, Ident, Value};
27use sql::statements::set_variables::SetVariables;
28use sqlparser::ast::ValueWithSpan;
29
30use crate::error::{InvalidConfigValueSnafu, InvalidSqlSnafu, NotSupportedSnafu, Result};
31
32lazy_static! {
33 static ref PG_TIME_INPUT_REGEX: Regex = Regex::new(r"^(\d+)(ms|s|min|h|d)$").unwrap();
39}
40
41pub fn set_read_preference(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
42 let read_preference_expr = exprs.first().context(NotSupportedSnafu {
43 feat: "No read preference find in set variable statement",
44 })?;
45
46 match read_preference_expr {
47 Expr::Value(ValueWithSpan {
48 value: Value::SingleQuotedString(expr),
49 ..
50 })
51 | Expr::Value(ValueWithSpan {
52 value: Value::DoubleQuotedString(expr),
53 ..
54 }) => {
55 match ReadPreference::from_str(expr.as_str().to_lowercase().as_str()) {
56 Ok(read_preference) => ctx.set_read_preference(read_preference),
57 Err(_) => {
58 return NotSupportedSnafu {
59 feat: format!(
60 "Invalid read preference expr {} in set variable statement",
61 expr,
62 ),
63 }
64 .fail()
65 }
66 }
67 Ok(())
68 }
69 expr => NotSupportedSnafu {
70 feat: format!(
71 "Unsupported read preference expr {} in set variable statement",
72 expr
73 ),
74 }
75 .fail(),
76 }
77}
78
79pub fn set_timezone(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
80 let tz_expr = exprs.first().context(NotSupportedSnafu {
81 feat: "No timezone find in set variable statement",
82 })?;
83 match tz_expr {
84 Expr::Value(ValueWithSpan {
85 value: Value::SingleQuotedString(tz),
86 ..
87 })
88 | Expr::Value(ValueWithSpan {
89 value: Value::DoubleQuotedString(tz),
90 ..
91 }) => {
92 match Timezone::from_tz_string(tz.as_str()) {
93 Ok(timezone) => ctx.set_timezone(timezone),
94 Err(_) => {
95 return NotSupportedSnafu {
96 feat: format!("Invalid timezone expr {} in set variable statement", tz),
97 }
98 .fail()
99 }
100 }
101 Ok(())
102 }
103 expr => NotSupportedSnafu {
104 feat: format!(
105 "Unsupported timezone expr {} in set variable statement",
106 expr
107 ),
108 }
109 .fail(),
110 }
111}
112
113pub fn set_bytea_output(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
114 let Some((var_value, [])) = exprs.split_first() else {
115 return (NotSupportedSnafu {
116 feat: "Set variable value must have one and only one value for bytea_output",
117 })
118 .fail();
119 };
120 let Expr::Value(value) = var_value else {
121 return (NotSupportedSnafu {
122 feat: "Set variable value must be a value",
123 })
124 .fail();
125 };
126 ctx.configuration_parameter().set_postgres_bytea_output(
127 PGByteaOutputValue::try_from(value.value.clone()).context(InvalidConfigValueSnafu)?,
128 );
129 Ok(())
130}
131
132pub fn set_search_path(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
133 let search_expr = exprs.first().context(NotSupportedSnafu {
134 feat: "No search path find in set variable statement",
135 })?;
136 match search_expr {
137 Expr::Value(ValueWithSpan {
138 value: Value::SingleQuotedString(search_path),
139 ..
140 })
141 | Expr::Value(ValueWithSpan {
142 value: Value::DoubleQuotedString(search_path),
143 ..
144 }) => {
145 ctx.set_current_schema(search_path);
146 Ok(())
147 }
148 Expr::Identifier(Ident { value, .. }) => {
149 ctx.set_current_schema(value);
150 Ok(())
151 }
152 expr => NotSupportedSnafu {
153 feat: format!(
154 "Unsupported search path expr {} in set variable statement",
155 expr
156 ),
157 }
158 .fail(),
159 }
160}
161
162pub fn validate_client_encoding(set: SetVariables) -> Result<()> {
163 let Some((encoding, [])) = set.value.split_first() else {
164 return InvalidSqlSnafu {
165 err_msg: "must provide one and only one client encoding value",
166 }
167 .fail();
168 };
169 let encoding = match encoding {
170 Expr::Value(ValueWithSpan {
171 value: Value::SingleQuotedString(x),
172 ..
173 })
174 | Expr::Identifier(Ident {
175 value: x,
176 quote_style: _,
177 span: _,
178 }) => x.to_uppercase(),
179 _ => {
180 return InvalidSqlSnafu {
181 err_msg: format!("client encoding must be a string, actual: {:?}", encoding),
182 }
183 .fail();
184 }
185 };
186 ensure!(
191 encoding == "UTF8" || encoding == "UNICODE",
192 NotSupportedSnafu {
193 feat: format!("client encoding of '{}'", encoding)
194 }
195 );
196 Ok(())
197}
198
199fn merge_datestyle_value<T>(value: Option<T>, new_value: Option<T>) -> Result<Option<T>>
203where
204 T: PartialEq,
205{
206 match (&value, &new_value) {
207 (None, _) => Ok(new_value),
208 (_, None) => Ok(value),
209 (Some(v1), Some(v2)) if v1 == v2 => Ok(new_value),
210 _ => InvalidSqlSnafu {
211 err_msg: "Conflicting \"datestyle\" specifications.",
212 }
213 .fail(),
214 }
215}
216
217fn try_parse_datestyle(expr: &Expr) -> Result<(Option<PGDateTimeStyle>, Option<PGDateOrder>)> {
218 enum ParsedDateStyle {
219 Order(PGDateOrder),
220 Style(PGDateTimeStyle),
221 }
222 fn try_parse_str(s: &str) -> Result<ParsedDateStyle> {
223 PGDateTimeStyle::try_from(s)
224 .map_or_else(
225 |_| PGDateOrder::try_from(s).map(ParsedDateStyle::Order),
226 |style| Ok(ParsedDateStyle::Style(style)),
227 )
228 .context(InvalidConfigValueSnafu)
229 }
230 match expr {
231 Expr::Identifier(Ident {
232 value: s,
233 quote_style: _,
234 span: _,
235 })
236 | Expr::Value(ValueWithSpan {
237 value: Value::SingleQuotedString(s),
238 ..
239 })
240 | Expr::Value(ValueWithSpan {
241 value: Value::DoubleQuotedString(s),
242 ..
243 }) => s
244 .split(',')
245 .map(|s| s.trim())
246 .try_fold((None, None), |(style, order), s| match try_parse_str(s)? {
247 ParsedDateStyle::Order(o) => Ok((style, merge_datestyle_value(order, Some(o))?)),
248 ParsedDateStyle::Style(s) => Ok((merge_datestyle_value(style, Some(s))?, order)),
249 }),
250 _ => NotSupportedSnafu {
251 feat: "Not supported expression for datestyle",
252 }
253 .fail(),
254 }
255}
256
257pub fn set_allow_query_fallback(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
260 let allow_fallback_expr = exprs.first().context(NotSupportedSnafu {
261 feat: "No allow query fallback value find in set variable statement",
262 })?;
263 match allow_fallback_expr {
264 Expr::Value(ValueWithSpan {
265 value: Value::Boolean(allow),
266 span: _,
267 }) => {
268 ctx.configuration_parameter()
269 .set_allow_query_fallback(*allow);
270 Ok(())
271 }
272 expr => NotSupportedSnafu {
273 feat: format!(
274 "Unsupported allow query fallback expr {} in set variable statement",
275 expr
276 ),
277 }
278 .fail(),
279 }
280}
281
282pub fn set_datestyle(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
283 let (style, order) = exprs
289 .iter()
290 .try_fold((None, None), |(style, order), expr| {
291 let (new_style, new_order) = try_parse_datestyle(expr)?;
292 Ok((
293 merge_datestyle_value(style, new_style)?,
294 merge_datestyle_value(order, new_order)?,
295 ))
296 })?;
297
298 let (old_style, older_order) = *ctx.configuration_parameter().pg_datetime_style();
299 ctx.configuration_parameter()
300 .set_pg_datetime_style(style.unwrap_or(old_style), order.unwrap_or(older_order));
301 Ok(())
302}
303
304pub fn set_query_timeout(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
305 let timeout_expr = exprs.first().context(NotSupportedSnafu {
306 feat: "No timeout value find in set query timeout statement",
307 })?;
308 match timeout_expr {
309 Expr::Value(ValueWithSpan {
310 value: Value::Number(timeout, _),
311 ..
312 }) => {
313 match timeout.parse::<u64>() {
314 Ok(timeout) => ctx.set_query_timeout(Duration::from_millis(timeout)),
315 Err(_) => {
316 return NotSupportedSnafu {
317 feat: format!("Invalid timeout expr {} in set variable statement", timeout),
318 }
319 .fail()
320 }
321 }
322 Ok(())
323 }
324 Expr::Value(ValueWithSpan {
326 value: Value::SingleQuotedString(timeout),
327 ..
328 })
329 | Expr::Value(ValueWithSpan {
330 value: Value::DoubleQuotedString(timeout),
331 ..
332 }) => {
333 if ctx.channel() != Postgres {
334 return NotSupportedSnafu {
335 feat: format!("Invalid timeout expr {} in set variable statement", timeout),
336 }
337 .fail();
338 }
339 let timeout = parse_pg_query_timeout_input(timeout)?;
340 ctx.set_query_timeout(Duration::from_millis(timeout));
341 Ok(())
342 }
343 expr => NotSupportedSnafu {
344 feat: format!(
345 "Unsupported timeout expr {} in set variable statement",
346 expr
347 ),
348 }
349 .fail(),
350 }
351}
352
353fn parse_pg_query_timeout_input(input: &str) -> Result<u64> {
356 match input.parse::<u64>() {
357 Ok(timeout) => Ok(timeout),
358 Err(_) => {
359 if let Some(captures) = PG_TIME_INPUT_REGEX.captures(input) {
360 let value = captures[1].parse::<u64>().expect("regex failed");
361 let unit = &captures[2];
362
363 match unit {
364 "ms" => Ok(value),
365 "s" => Ok(value * 1000),
366 "min" => Ok(value * 60 * 1000),
367 "h" => Ok(value * 60 * 60 * 1000),
368 "d" => Ok(value * 24 * 60 * 60 * 1000),
369 _ => unreachable!("regex failed"),
370 }
371 } else {
372 NotSupportedSnafu {
373 feat: format!(
374 "Unsupported timeout expr {} in set variable statement",
375 input
376 ),
377 }
378 .fail()
379 }
380 }
381 }
382}
383
384#[cfg(test)]
385mod test {
386 use crate::statement::set::parse_pg_query_timeout_input;
387
388 #[test]
389 fn test_parse_pg_query_timeout_input() {
390 assert!(parse_pg_query_timeout_input("").is_err());
391 assert!(parse_pg_query_timeout_input(" 50 ms").is_err());
392 assert!(parse_pg_query_timeout_input("5s 1ms").is_err());
393 assert!(parse_pg_query_timeout_input("3a").is_err());
394 assert!(parse_pg_query_timeout_input("1.5min").is_err());
395 assert!(parse_pg_query_timeout_input("ms").is_err());
396 assert!(parse_pg_query_timeout_input("a").is_err());
397 assert!(parse_pg_query_timeout_input("-1").is_err());
398
399 assert_eq!(50, parse_pg_query_timeout_input("50").unwrap());
400 assert_eq!(12, parse_pg_query_timeout_input("12ms").unwrap());
401 assert_eq!(2000, parse_pg_query_timeout_input("2s").unwrap());
402 assert_eq!(60000, parse_pg_query_timeout_input("1min").unwrap());
403 }
404}