1use std::str::FromStr;
16use std::time::Duration;
17
18use common_time::Timezone;
19use lazy_static::lazy_static;
20use regex::Regex;
21use session::context::Channel::Postgres;
22use session::context::QueryContextRef;
23use session::session_config::{PGByteaOutputValue, PGDateOrder, PGDateTimeStyle};
24use session::ReadPreference;
25use snafu::{ensure, OptionExt, ResultExt};
26use sql::ast::{Expr, Ident, Value};
27use sql::statements::set_variables::SetVariables;
28
29use crate::error::{InvalidConfigValueSnafu, InvalidSqlSnafu, NotSupportedSnafu, Result};
30
lazy_static! {
    // Matches a PostgreSQL-style duration used by `parse_pg_query_timeout_input`: a run of
    // digits immediately followed by one of the units `ms`, `s`, `min`, `h` or `d`
    // (e.g. `500ms`, `2min`). The pattern is anchored at both ends, so surrounding text
    // or whitespace makes it fail to match.
    // NOTE(review): `\d` is Unicode-aware in the `regex` crate, so non-ASCII decimal
    // digits also match here even though `u64::from_str` will not accept them.
    static ref PG_TIME_INPUT_REGEX: Regex = Regex::new(r"^(\d+)(ms|s|min|h|d)$").unwrap();
}
39
40pub fn set_read_preference(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
41 let read_preference_expr = exprs.first().context(NotSupportedSnafu {
42 feat: "No read preference find in set variable statement",
43 })?;
44
45 match read_preference_expr {
46 Expr::Value(Value::SingleQuotedString(expr))
47 | Expr::Value(Value::DoubleQuotedString(expr)) => {
48 match ReadPreference::from_str(expr.as_str().to_lowercase().as_str()) {
49 Ok(read_preference) => ctx.set_read_preference(read_preference),
50 Err(_) => {
51 return NotSupportedSnafu {
52 feat: format!(
53 "Invalid read preference expr {} in set variable statement",
54 expr,
55 ),
56 }
57 .fail()
58 }
59 }
60 Ok(())
61 }
62 expr => NotSupportedSnafu {
63 feat: format!(
64 "Unsupported read preference expr {} in set variable statement",
65 expr
66 ),
67 }
68 .fail(),
69 }
70}
71
72pub fn set_timezone(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
73 let tz_expr = exprs.first().context(NotSupportedSnafu {
74 feat: "No timezone find in set variable statement",
75 })?;
76 match tz_expr {
77 Expr::Value(Value::SingleQuotedString(tz)) | Expr::Value(Value::DoubleQuotedString(tz)) => {
78 match Timezone::from_tz_string(tz.as_str()) {
79 Ok(timezone) => ctx.set_timezone(timezone),
80 Err(_) => {
81 return NotSupportedSnafu {
82 feat: format!("Invalid timezone expr {} in set variable statement", tz),
83 }
84 .fail()
85 }
86 }
87 Ok(())
88 }
89 expr => NotSupportedSnafu {
90 feat: format!(
91 "Unsupported timezone expr {} in set variable statement",
92 expr
93 ),
94 }
95 .fail(),
96 }
97}
98
99pub fn set_bytea_output(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
100 let Some((var_value, [])) = exprs.split_first() else {
101 return (NotSupportedSnafu {
102 feat: "Set variable value must have one and only one value for bytea_output",
103 })
104 .fail();
105 };
106 let Expr::Value(value) = var_value else {
107 return (NotSupportedSnafu {
108 feat: "Set variable value must be a value",
109 })
110 .fail();
111 };
112 ctx.configuration_parameter().set_postgres_bytea_output(
113 PGByteaOutputValue::try_from(value.clone()).context(InvalidConfigValueSnafu)?,
114 );
115 Ok(())
116}
117
118pub fn set_search_path(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
119 let search_expr = exprs.first().context(NotSupportedSnafu {
120 feat: "No search path find in set variable statement",
121 })?;
122 match search_expr {
123 Expr::Value(Value::SingleQuotedString(search_path))
124 | Expr::Value(Value::DoubleQuotedString(search_path)) => {
125 ctx.set_current_schema(&search_path.clone());
126 Ok(())
127 }
128 expr => NotSupportedSnafu {
129 feat: format!(
130 "Unsupported search path expr {} in set variable statement",
131 expr
132 ),
133 }
134 .fail(),
135 }
136}
137
138pub fn validate_client_encoding(set: SetVariables) -> Result<()> {
139 let Some((encoding, [])) = set.value.split_first() else {
140 return InvalidSqlSnafu {
141 err_msg: "must provide one and only one client encoding value",
142 }
143 .fail();
144 };
145 let encoding = match encoding {
146 Expr::Value(Value::SingleQuotedString(x))
147 | Expr::Identifier(Ident {
148 value: x,
149 quote_style: _,
150 span: _,
151 }) => x.to_uppercase(),
152 _ => {
153 return InvalidSqlSnafu {
154 err_msg: format!("client encoding must be a string, actual: {:?}", encoding),
155 }
156 .fail();
157 }
158 };
159 ensure!(
164 encoding == "UTF8" || encoding == "UNICODE",
165 NotSupportedSnafu {
166 feat: format!("client encoding of '{}'", encoding)
167 }
168 );
169 Ok(())
170}
171
172fn merge_datestyle_value<T>(value: Option<T>, new_value: Option<T>) -> Result<Option<T>>
176where
177 T: PartialEq,
178{
179 match (&value, &new_value) {
180 (None, _) => Ok(new_value),
181 (_, None) => Ok(value),
182 (Some(v1), Some(v2)) if v1 == v2 => Ok(new_value),
183 _ => InvalidSqlSnafu {
184 err_msg: "Conflicting \"datestyle\" specifications.",
185 }
186 .fail(),
187 }
188}
189
190fn try_parse_datestyle(expr: &Expr) -> Result<(Option<PGDateTimeStyle>, Option<PGDateOrder>)> {
191 enum ParsedDateStyle {
192 Order(PGDateOrder),
193 Style(PGDateTimeStyle),
194 }
195 fn try_parse_str(s: &str) -> Result<ParsedDateStyle> {
196 PGDateTimeStyle::try_from(s)
197 .map_or_else(
198 |_| PGDateOrder::try_from(s).map(ParsedDateStyle::Order),
199 |style| Ok(ParsedDateStyle::Style(style)),
200 )
201 .context(InvalidConfigValueSnafu)
202 }
203 match expr {
204 Expr::Identifier(Ident {
205 value: s,
206 quote_style: _,
207 span: _,
208 })
209 | Expr::Value(Value::SingleQuotedString(s))
210 | Expr::Value(Value::DoubleQuotedString(s)) => {
211 s.split(',')
212 .map(|s| s.trim())
213 .try_fold((None, None), |(style, order), s| match try_parse_str(s)? {
214 ParsedDateStyle::Order(o) => {
215 Ok((style, merge_datestyle_value(order, Some(o))?))
216 }
217 ParsedDateStyle::Style(s) => {
218 Ok((merge_datestyle_value(style, Some(s))?, order))
219 }
220 })
221 }
222 _ => NotSupportedSnafu {
223 feat: "Not supported expression for datestyle",
224 }
225 .fail(),
226 }
227}
228
229pub fn set_datestyle(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
230 let (style, order) = exprs
236 .iter()
237 .try_fold((None, None), |(style, order), expr| {
238 let (new_style, new_order) = try_parse_datestyle(expr)?;
239 Ok((
240 merge_datestyle_value(style, new_style)?,
241 merge_datestyle_value(order, new_order)?,
242 ))
243 })?;
244
245 let (old_style, older_order) = *ctx.configuration_parameter().pg_datetime_style();
246 ctx.configuration_parameter()
247 .set_pg_datetime_style(style.unwrap_or(old_style), order.unwrap_or(older_order));
248 Ok(())
249}
250
251pub fn set_query_timeout(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
252 let timeout_expr = exprs.first().context(NotSupportedSnafu {
253 feat: "No timeout value find in set query timeout statement",
254 })?;
255 match timeout_expr {
256 Expr::Value(Value::Number(timeout, _)) => {
257 match timeout.parse::<u64>() {
258 Ok(timeout) => ctx.set_query_timeout(Duration::from_millis(timeout)),
259 Err(_) => {
260 return NotSupportedSnafu {
261 feat: format!("Invalid timeout expr {} in set variable statement", timeout),
262 }
263 .fail()
264 }
265 }
266 Ok(())
267 }
268 Expr::Value(Value::SingleQuotedString(timeout))
270 | Expr::Value(Value::DoubleQuotedString(timeout)) => {
271 if ctx.channel() != Postgres {
272 return NotSupportedSnafu {
273 feat: format!("Invalid timeout expr {} in set variable statement", timeout),
274 }
275 .fail();
276 }
277 let timeout = parse_pg_query_timeout_input(timeout)?;
278 ctx.set_query_timeout(Duration::from_millis(timeout));
279 Ok(())
280 }
281 expr => NotSupportedSnafu {
282 feat: format!(
283 "Unsupported timeout expr {} in set variable statement",
284 expr
285 ),
286 }
287 .fail(),
288 }
289}
290
291fn parse_pg_query_timeout_input(input: &str) -> Result<u64> {
294 match input.parse::<u64>() {
295 Ok(timeout) => Ok(timeout),
296 Err(_) => {
297 if let Some(captures) = PG_TIME_INPUT_REGEX.captures(input) {
298 let value = captures[1].parse::<u64>().expect("regex failed");
299 let unit = &captures[2];
300
301 match unit {
302 "ms" => Ok(value),
303 "s" => Ok(value * 1000),
304 "min" => Ok(value * 60 * 1000),
305 "h" => Ok(value * 60 * 60 * 1000),
306 "d" => Ok(value * 24 * 60 * 60 * 1000),
307 _ => unreachable!("regex failed"),
308 }
309 } else {
310 NotSupportedSnafu {
311 feat: format!(
312 "Unsupported timeout expr {} in set variable statement",
313 input
314 ),
315 }
316 .fail()
317 }
318 }
319 }
320}
321
#[cfg(test)]
mod test {
    use crate::statement::set::parse_pg_query_timeout_input;

    #[test]
    fn test_parse_pg_query_timeout_input() {
        // Rejected inputs: empty, embedded whitespace, compound values, unknown units,
        // fractions, a bare unit, non-numeric text and a negative number.
        let rejected = ["", " 50 ms", "5s 1ms", "3a", "1.5min", "ms", "a", "-1"];
        for input in rejected {
            assert!(
                parse_pg_query_timeout_input(input).is_err(),
                "expected {input:?} to be rejected"
            );
        }

        // Accepted inputs paired with their value in milliseconds; a bare number is ms.
        let accepted = [("50", 50), ("12ms", 12), ("2s", 2000), ("1min", 60000)];
        for (input, expected) in accepted {
            assert_eq!(
                expected,
                parse_pg_query_timeout_input(input).unwrap(),
                "input {input:?}"
            );
        }
    }
}