use std::time::Duration;
use common_time::Timezone;
use lazy_static::lazy_static;
use regex::Regex;
use session::context::Channel::Postgres;
use session::context::QueryContextRef;
use session::session_config::{PGByteaOutputValue, PGDateOrder, PGDateTimeStyle};
use snafu::{ensure, OptionExt, ResultExt};
use sql::ast::{Expr, Ident, Value};
use sql::statements::set_variables::SetVariables;
use crate::error::{InvalidConfigValueSnafu, InvalidSqlSnafu, NotSupportedSnafu, Result};
lazy_static! {
    /// Matches a PostgreSQL-style timeout such as `500ms`, `5s`, `2min`, `1h` or `1d`:
    /// an ASCII decimal integer immediately followed by a unit suffix, with nothing else.
    ///
    /// `[0-9]` is used instead of `\d` because the regex crate's `\d` is Unicode-aware and
    /// would also accept digits (e.g. Arabic-Indic numerals) that `u64::from_str` rejects.
    static ref PG_TIME_INPUT_REGEX: Regex =
        Regex::new(r"^([0-9]+)(ms|s|min|h|d)$").expect("PG_TIME_INPUT_REGEX is a valid regex");
}
/// Applies a timezone `SET` value to the session context.
///
/// Only the first expression is inspected. It must be a single- or double-quoted
/// string that `Timezone::from_tz_string` accepts.
///
/// # Errors
/// Returns a `NotSupported` error when no expression is supplied, when the first
/// expression is not a quoted string, or when the string is not a valid timezone.
pub fn set_timezone(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
    let first_expr = exprs.first().context(NotSupportedSnafu {
        feat: "No timezone find in set variable statement",
    })?;

    // Pull the raw timezone text out of the expression before validating it.
    let tz_text = match first_expr {
        Expr::Value(Value::SingleQuotedString(text))
        | Expr::Value(Value::DoubleQuotedString(text)) => text,
        other => {
            return NotSupportedSnafu {
                feat: format!(
                    "Unsupported timezone expr {} in set variable statement",
                    other
                ),
            }
            .fail()
        }
    };

    let Ok(timezone) = Timezone::from_tz_string(tz_text.as_str()) else {
        return NotSupportedSnafu {
            feat: format!("Invalid timezone expr {} in set variable statement", tz_text),
        }
        .fail();
    };
    ctx.set_timezone(timezone);
    Ok(())
}
/// Sets the session's PostgreSQL `bytea_output` configuration parameter.
///
/// Exactly one expression is required, and it must be a literal value. The value is
/// converted with `PGByteaOutputValue::try_from`.
///
/// # Errors
/// Returns `NotSupported` for a wrong number of expressions or a non-literal
/// expression. Returns `InvalidConfigValue` when the literal is not a recognised
/// bytea output format.
pub fn set_bytea_output(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
    let only_expr = match exprs.as_slice() {
        [single] => single,
        _ => {
            return NotSupportedSnafu {
                feat: "Set variable value must have one and only one value for bytea_output",
            }
            .fail()
        }
    };

    let literal = match only_expr {
        Expr::Value(v) => v,
        _ => {
            return NotSupportedSnafu {
                feat: "Set variable value must be a value",
            }
            .fail()
        }
    };

    let params = ctx.configuration_parameter();
    let output =
        PGByteaOutputValue::try_from(literal.clone()).context(InvalidConfigValueSnafu)?;
    params.set_postgres_bytea_output(output);
    Ok(())
}
/// Validates that a `SET client_encoding` statement requests a supported encoding.
///
/// The statement must carry exactly one value, given either as a single-quoted
/// string or as an identifier. The comparison is case-insensitive, and only `UTF8`
/// (or its alias `UNICODE`) is accepted.
///
/// # Errors
/// Returns `InvalidSql` when the value count or expression kind is wrong. Returns
/// `NotSupported` for any other encoding name.
pub fn validate_client_encoding(set: SetVariables) -> Result<()> {
    let [encoding_expr] = &set.value[..] else {
        return InvalidSqlSnafu {
            err_msg: "must provide one and only one client encoding value",
        }
        .fail();
    };

    let raw_name = match encoding_expr {
        Expr::Value(Value::SingleQuotedString(name)) => name,
        Expr::Identifier(Ident { value: name, .. }) => name,
        other => {
            return InvalidSqlSnafu {
                err_msg: format!("client encoding must be a string, actual: {:?}", other),
            }
            .fail();
        }
    };

    let upper = raw_name.to_uppercase();
    ensure!(
        matches!(upper.as_str(), "UTF8" | "UNICODE"),
        NotSupportedSnafu {
            feat: format!("client encoding of '{}'", upper)
        }
    );
    Ok(())
}
/// Merges an already-collected datestyle component with a newly parsed one.
///
/// If either side is absent, the other side wins. If both are present, they must
/// compare equal, and the new value is returned. Differing values mean the
/// specification contradicts itself, and an `InvalidSql` error is returned.
fn merge_datestyle_value<T>(value: Option<T>, new_value: Option<T>) -> Result<Option<T>>
where
    T: PartialEq,
{
    match (value, new_value) {
        (Some(current), Some(next)) => {
            if current == next {
                Ok(Some(next))
            } else {
                InvalidSqlSnafu {
                    err_msg: "Conflicting \"datestyle\" specifications.",
                }
                .fail()
            }
        }
        (current, None) => Ok(current),
        (None, next) => Ok(next),
    }
}
/// Parses one `datestyle` expression into an optional output style and an optional
/// field order.
///
/// The expression may be an identifier or a single/double-quoted string holding a
/// comma-separated list such as `'ISO, MDY'`. Each trimmed item is first tried as a
/// `PGDateTimeStyle`; if that fails, it is tried as a `PGDateOrder`. Repeating the
/// same kind of item is only allowed when the values agree.
///
/// # Errors
/// - `NotSupported` for any other expression kind.
/// - `InvalidConfigValue` for an item that is neither a style nor an order.
/// - `InvalidSql` for conflicting items.
fn try_parse_datestyle(expr: &Expr) -> Result<(Option<PGDateTimeStyle>, Option<PGDateOrder>)> {
    let text = match expr {
        Expr::Identifier(Ident { value: text, .. })
        | Expr::Value(Value::SingleQuotedString(text))
        | Expr::Value(Value::DoubleQuotedString(text)) => text,
        _ => {
            return NotSupportedSnafu {
                feat: "Not supported expression for datestyle",
            }
            .fail()
        }
    };

    let mut style = None;
    let mut order = None;
    for item in text.split(',').map(str::trim) {
        match PGDateTimeStyle::try_from(item) {
            Ok(parsed_style) => {
                style = merge_datestyle_value(style, Some(parsed_style))?;
            }
            Err(_) => {
                // Not a style keyword; it must then be a field order, otherwise the
                // order-parsing error is reported.
                let parsed_order = PGDateOrder::try_from(item).context(InvalidConfigValueSnafu)?;
                order = merge_datestyle_value(order, Some(parsed_order))?;
            }
        }
    }
    Ok((style, order))
}
/// Handles `SET datestyle`, whose value may be several expressions, each holding one or
/// more comma-separated items.
///
/// All items are merged into at most one output style and one field order, and
/// conflicts are rejected. A component that is not specified keeps the session's
/// current value.
pub fn set_datestyle(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
    let mut style = None;
    let mut order = None;
    for expr in &exprs {
        let (parsed_style, parsed_order) = try_parse_datestyle(expr)?;
        style = merge_datestyle_value(style, parsed_style)?;
        order = merge_datestyle_value(order, parsed_order)?;
    }

    let (current_style, current_order) = *ctx.configuration_parameter().pg_datetime_style();
    ctx.configuration_parameter().set_pg_datetime_style(
        style.unwrap_or(current_style),
        order.unwrap_or(current_order),
    );
    Ok(())
}
/// Applies a query-timeout `SET` value to the session.
///
/// A numeric literal is interpreted as milliseconds on every channel. A quoted string
/// is only accepted on the PostgreSQL channel, where it is parsed by
/// `parse_pg_query_timeout_input` (e.g. `'5s'`, `'1min'`).
///
/// # Errors
/// Returns `NotSupported` when no value is given, when the value cannot be parsed,
/// when a string is used on a non-PostgreSQL channel, or for any other expression kind.
pub fn set_query_timeout(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
    let timeout_expr = exprs.first().context(NotSupportedSnafu {
        feat: "No timeout value find in set query timeout statement",
    })?;

    let millis = match timeout_expr {
        Expr::Value(Value::Number(raw, _)) => {
            let Ok(parsed) = raw.parse::<u64>() else {
                return NotSupportedSnafu {
                    feat: format!("Invalid timeout expr {} in set variable statement", raw),
                }
                .fail();
            };
            parsed
        }
        Expr::Value(Value::SingleQuotedString(raw))
        | Expr::Value(Value::DoubleQuotedString(raw)) => {
            // Unit-suffixed strings are a PostgreSQL convention only.
            if ctx.channel() != Postgres {
                return NotSupportedSnafu {
                    feat: format!("Invalid timeout expr {} in set variable statement", raw),
                }
                .fail();
            }
            parse_pg_query_timeout_input(raw)?
        }
        other => {
            return NotSupportedSnafu {
                feat: format!(
                    "Unsupported timeout expr {} in set variable statement",
                    other
                ),
            }
            .fail()
        }
    };

    ctx.set_query_timeout(Duration::from_millis(millis));
    Ok(())
}
/// Parses a PostgreSQL-style timeout value into milliseconds.
///
/// Accepts either a bare non-negative integer, which is already in milliseconds, or an
/// integer immediately followed by one of the units `ms`, `s`, `min`, `h`, `d`
/// (e.g. `"1500"`, `"5s"`, `"2min"`).
///
/// # Errors
/// Returns a `NotSupported` error in any of these cases:
/// - the input matches neither form;
/// - the numeric part does not fit in a `u64`;
/// - converting the value to milliseconds would overflow a `u64`.
fn parse_pg_query_timeout_input(input: &str) -> Result<u64> {
    // Plain integers are taken as milliseconds.
    if let Ok(millis) = input.parse::<u64>() {
        return Ok(millis);
    }

    let unsupported = || -> Result<u64> {
        NotSupportedSnafu {
            feat: format!(
                "Unsupported timeout expr {} in set variable statement",
                input
            ),
        }
        .fail()
    };

    let Some(captures) = PG_TIME_INPUT_REGEX.captures(input) else {
        return unsupported();
    };
    // The digit group can still fail to parse: it may exceed `u64::MAX`, or contain
    // non-ASCII digits if the pattern uses Unicode-aware `\d`. These are user errors,
    // not invariant violations, so they must not panic.
    let Ok(value) = captures[1].parse::<u64>() else {
        return unsupported();
    };
    let millis_per_unit: u64 = match &captures[2] {
        "ms" => 1,
        "s" => 1_000,
        "min" => 60 * 1_000,
        "h" => 60 * 60 * 1_000,
        "d" => 24 * 60 * 60 * 1_000,
        _ => unreachable!("PG_TIME_INPUT_REGEX only captures ms|s|min|h|d"),
    };
    // Large inputs such as `99999999999999999d` would overflow u64 milliseconds
    // (a panic in debug builds, silent wrap-around in release builds).
    match value.checked_mul(millis_per_unit) {
        Some(millis) => Ok(millis),
        None => unsupported(),
    }
}
#[cfg(test)]
mod test {
    use super::parse_pg_query_timeout_input;

    #[test]
    fn test_parse_pg_query_timeout_input() {
        // Malformed inputs: empty, embedded whitespace, compound values, unknown units,
        // fractions, a missing number, and negative numbers.
        assert!(parse_pg_query_timeout_input("").is_err());
        assert!(parse_pg_query_timeout_input(" 50 ms").is_err());
        assert!(parse_pg_query_timeout_input("5s 1ms").is_err());
        assert!(parse_pg_query_timeout_input("3a").is_err());
        assert!(parse_pg_query_timeout_input("1.5min").is_err());
        assert!(parse_pg_query_timeout_input("ms").is_err());
        assert!(parse_pg_query_timeout_input("a").is_err());
        assert!(parse_pg_query_timeout_input("-1").is_err());

        // Bare integers are milliseconds.
        assert_eq!(0, parse_pg_query_timeout_input("0").unwrap());
        assert_eq!(50, parse_pg_query_timeout_input("50").unwrap());

        // Every supported unit suffix.
        assert_eq!(12, parse_pg_query_timeout_input("12ms").unwrap());
        assert_eq!(2000, parse_pg_query_timeout_input("2s").unwrap());
        assert_eq!(60000, parse_pg_query_timeout_input("1min").unwrap());
        assert_eq!(3_600_000, parse_pg_query_timeout_input("1h").unwrap());
        assert_eq!(86_400_000, parse_pg_query_timeout_input("1d").unwrap());
    }
}