use std::collections::HashMap;
use std::sync::Arc;
use common_query::Output;
use common_recordbatch::RecordBatches;
use common_time::timezone::system_timezone_name;
use datatypes::prelude::ConcreteDataType;
use datatypes::schema::{ColumnSchema, Schema};
use datatypes::vectors::StringVector;
use once_cell::sync::Lazy;
use regex::bytes::RegexSet;
use regex::Regex;
use session::context::QueryContextRef;
use session::SessionRef;
// All patterns below are case-insensitive (`(?i)`) and anchored at the start of
// the query (`^`); raw strings keep the regex escapes readable.

/// Queries starting with `SELECT @@`, answered by `select_variable`.
static SELECT_VAR_PATTERN: Lazy<Regex> =
    Lazy::new(|| Regex::new(r"(?i)^(SELECT @@(.*))").unwrap());
/// Queries starting with a `/* mysql-connector-j` comment; also routed to `select_variable`.
static MYSQL_CONN_JAVA_PATTERN: Lazy<Regex> =
    Lazy::new(|| Regex::new(r"(?i)^(/\* mysql-connector-j(.*))").unwrap());
/// `SHOW VARIABLES LIKE 'lower_case_table_names'` (anything may follow).
static SHOW_LOWER_CASE_PATTERN: Lazy<Regex> = Lazy::new(|| {
    Regex::new(r"(?i)^(SHOW VARIABLES LIKE 'lower_case_table_names'(.*))").unwrap()
});
/// Any statement starting with `SHOW VARIABLES`, with or without a `LIKE` filter.
static SHOW_VARIABLES_LIKE_PATTERN: Lazy<Regex> =
    Lazy::new(|| Regex::new(r"(?i)^(SHOW VARIABLES( LIKE (.*))?)").unwrap());
/// `SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP())`; the parentheses are escaped literals.
static SELECT_TIME_DIFF_FUNC_PATTERN: Lazy<Regex> = Lazy::new(|| {
    Regex::new(r"(?i)^(SELECT TIMEDIFF\(NOW\(\), UTC_TIMESTAMP\(\)\))").unwrap()
});
/// `SHOW VARIABLES LIKE 'sql_mode'` (anything may follow).
static SHOW_SQL_MODE_PATTERN: Lazy<Regex> =
    Lazy::new(|| Regex::new(r"(?i)^(SHOW VARIABLES LIKE 'sql_mode'(.*))").unwrap());
/// Statements that are acknowledged with an empty result set instead of being executed.
///
/// Every pattern is case-insensitive and anchored at the start of the query. The set is
/// only consulted through `is_match`, so each pattern needs to appear once; the previous
/// list repeated `/*!40101 SET ... */` three times.
static OTHER_NOT_SUPPORTED_STMT: Lazy<RegexSet> = Lazy::new(|| {
    RegexSet::new([
        // Transaction control.
        r"(?i)^(ROLLBACK(.*))",
        r"(?i)^(COMMIT(.*))",
        r"(?i)^(START(.*))",
        // Session settings sent by drivers and tools.
        r"(?i)^(SET NAMES(.*))",
        r"(?i)^(SET character_set_results(.*))",
        r"(?i)^(SET net_write_timeout(.*))",
        r"(?i)^(SET FOREIGN_KEY_CHECKS(.*))",
        r"(?i)^(SET AUTOCOMMIT(.*))",
        r"(?i)^(SET SQL_LOG_BIN(.*))",
        r"(?i)^(SET SESSION TRANSACTION(.*))",
        r"(?i)^(SET TRANSACTION(.*))",
        r"(?i)^(SET sql_mode(.*))",
        r"(?i)^(SET SQL_SELECT_LIMIT(.*))",
        r"(?i)^(SET @@(.*))",
        r"(?i)^(SET PROFILING(.*))",
        r"(?i)^(SELECT \$\$)",
        r"(?i)^(SET SQL_QUOTE_SHOW_CREATE(.*))",
        // Locking and replication statements.
        r"(?i)^(LOCK TABLES(.*))",
        r"(?i)^(UNLOCK TABLES(.*))",
        r"(?i)^(SELECT LOGFILE_GROUP_NAME, FILE_NAME, TOTAL_EXTENTS, INITIAL_SIZE, ENGINE, EXTRA FROM INFORMATION_SCHEMA.FILES(.*))",
        r"(?i)^(SHOW MASTER STATUS)",
        r"(?i)^(SHOW ALL SLAVES STATUS)",
        r"(?i)^(LOCK BINLOG FOR BACKUP)",
        r"(?i)^(LOCK TABLES FOR BACKUP)",
        r"(?i)^(UNLOCK BINLOG(.*))",
        r"(?i)^(SHOW WARNINGS)",
        // Statements prefixed with an `ApplicationName=` comment.
        r"(?i)^(/\* ApplicationName=(.*)SHOW WARNINGS)",
        r"(?i)^(/\* ApplicationName=(.*)SHOW PLUGINS)",
        r"(?i)^(/\* ApplicationName=(.*)SHOW ENGINES)",
        r"(?i)^(/\* ApplicationName=(.*)SELECT @@(.*))",
        r"(?i)^(/\* ApplicationName=(.*)SHOW @@(.*))",
        r"(?i)^(/\* ApplicationName=(.*)SET net_write_timeout(.*))",
        r"(?i)^(/\* ApplicationName=(.*)SET SQL_SELECT_LIMIT(.*))",
        r"(?i)^(/\* ApplicationName=(.*)SHOW VARIABLES(.*))",
        // Version-gated `/*!NNNNN SET ... */` comments; these must match the whole query (`$`).
        r"(?i)^(/\*!40000 SET(.*) \*/)$",
        r"(?i)^(/\*!40014 SET(.*) \*/)$",
        r"(?i)^(/\*!40100 SET(.*) \*/)$",
        r"(?i)^(/\*!40101 SET(.*) \*/)$",
        r"(?i)^(/\*!40103 SET(.*) \*/)$",
        r"(?i)^(/\*!40111 SET(.*) \*/)$",
        r"(?i)^(/\*!80003 SET(.*) \*/)$",
    ])
    .expect("OTHER_NOT_SUPPORTED_STMT patterns must be valid regexes")
});
/// Canned values for `SELECT @@<name>` lookups.
///
/// `select_variable` uses this table for names other than `time_zone` and
/// `system_time_zone`, and reports `"0"` for names missing from it.
static VAR_VALUES: Lazy<HashMap<&str, &str>> = Lazy::new(|| {
    [
        ("tx_isolation", "REPEATABLE-READ"),
        ("session.tx_isolation", "REPEATABLE-READ"),
        ("transaction_isolation", "REPEATABLE-READ"),
        ("session.transaction_isolation", "REPEATABLE-READ"),
        ("session.transaction_read_only", "0"),
        ("max_allowed_packet", "134217728"),
        ("interactive_timeout", "31536000"),
        ("wait_timeout", "31536000"),
        ("net_write_timeout", "31536000"),
        ("version_comment", "Greptime"),
    ]
    .iter()
    .copied()
    .collect()
});
/// Builds a one-row result with a single nullable string column named `name`
/// that holds `value`.
///
/// Panics if the record batches cannot be assembled from the schema and column.
fn select_function(name: &str, value: &str) -> RecordBatches {
    let field = ColumnSchema::new(name, ConcreteDataType::string_datatype(), true);
    let cell = Arc::new(StringVector::from(vec![value]));
    RecordBatches::try_from_columns(Arc::new(Schema::new(vec![field])), vec![cell as _])
        .unwrap()
}
/// Builds a one-row `SHOW VARIABLES`-shaped result with nullable string columns
/// `Variable_name` (set to `name`) and `Value` (set to `value`).
///
/// Panics if the record batches cannot be assembled from the schema and columns.
fn show_variables(name: &str, value: &str) -> RecordBatches {
    let string_field =
        |field: &str| ColumnSchema::new(field, ConcreteDataType::string_datatype(), true);
    let schema = Arc::new(Schema::new(vec![
        string_field("Variable_name"),
        string_field("Value"),
    ]));

    let name_column = Arc::new(StringVector::from(vec![name]));
    let value_column = Arc::new(StringVector::from(vec![value]));
    RecordBatches::try_from_columns(schema, vec![name_column as _, value_column as _]).unwrap()
}
/// Answers `SELECT @@var [AS alias], ...` style queries from canned values.
///
/// The query is lowercased and split on `@@`. Every segment after the first
/// describes one variable, optionally followed by ` as <alias>`. Only the first
/// whitespace-separated word of the name and of the alias is kept, so trailing
/// clauses such as `LIMIT 1` are dropped. Trailing punctuation such as `;` is not
/// stripped from the alias.
///
/// Values are resolved as follows:
/// - `time_zone` resolves to the query context's timezone.
/// - `system_time_zone` resolves to `system_timezone_name()`.
/// - Other names are looked up in `VAR_VALUES`, falling back to `"0"`.
///
/// Each column is named after the (lowercased) alias when one is given, and
/// `@@<name>` otherwise.
///
/// Returns `None` when the query contains no `@@`, or when a segment contains
/// more than one ` as ` separator.
fn select_variable(query: &str, query_context: QueryContextRef) -> Option<Output> {
    let lowered = query.to_lowercase();
    if !lowered.contains("@@") {
        return None;
    }

    let mut fields = Vec::new();
    let mut columns = Vec::new();
    for segment in lowered.split("@@").skip(1) {
        let segment = segment.trim_matches(|c| c == ' ' || c == ',');
        // `name` or `name as alias`; keep the first word on each side of ` as `.
        let words: Vec<&str> = segment
            .split(" as ")
            .map(|part| part.split_whitespace().next().unwrap_or(""))
            .collect();

        let name = words[0];
        let value = match name {
            "time_zone" => query_context.timezone().to_string(),
            "system_time_zone" => system_timezone_name(),
            other => VAR_VALUES
                .get(other)
                .map_or_else(|| "0".to_owned(), |v| v.to_string()),
        };
        columns.push(Arc::new(StringVector::from(vec![value])) as _);

        let column_name = match words.as_slice() {
            [_] => format!("@@{}", name),
            [_, alias] => alias.to_string(),
            _ => return None,
        };
        fields.push(ColumnSchema::new(
            column_name,
            ConcreteDataType::string_datatype(),
            true,
        ));
    }

    let schema = Arc::new(Schema::new(fields));
    let batches = RecordBatches::try_from_columns(schema, columns).unwrap();
    Some(Output::new_with_record_batches(batches))
}
/// Hands queries that start with `SELECT @@`, or with a `/* mysql-connector-j`
/// comment, to `select_variable`. Returns `None` for anything else.
fn check_select_variable(query: &str, query_context: QueryContextRef) -> Option<Output> {
    let is_variable_select =
        SELECT_VAR_PATTERN.is_match(query) || MYSQL_CONN_JAVA_PATTERN.is_match(query);
    if !is_variable_select {
        return None;
    }
    select_variable(query, query_context)
}
/// Answers `SHOW VARIABLES ...` statements with a canned single-row result.
///
/// - `LIKE 'sql_mode'` and `LIKE 'lower_case_table_names'` get fixed values.
/// - Any other statement starting with `SHOW VARIABLES` gets one row with an
///   empty name and value.
///
/// Returns `None` when the query does not start with `SHOW VARIABLES`
/// (case-insensitive).
fn check_show_variables(query: &str) -> Option<Output> {
    const SQL_MODE: &str = "ONLY_FULL_GROUP_BY STRICT_TRANS_TABLES NO_ZERO_IN_DATE NO_ZERO_DATE ERROR_FOR_DIVISION_BY_ZERO NO_ENGINE_SUBSTITUTION";

    // The specific patterns must be tried first: the generic `SHOW VARIABLES`
    // pattern is only anchored at the start, so it would match them too.
    let batches = if SHOW_SQL_MODE_PATTERN.is_match(query) {
        show_variables("sql_mode", SQL_MODE)
    } else if SHOW_LOWER_CASE_PATTERN.is_match(query) {
        show_variables("lower_case_table_names", "0")
    } else if SHOW_VARIABLES_LIKE_PATTERN.is_match(query) {
        show_variables("", "")
    } else {
        return None;
    };
    Some(Output::new_with_record_batches(batches))
}
/// Handles the remaining intercepted statements.
///
/// - Anything matched by `OTHER_NOT_SUPPORTED_STMT` gets an empty result set.
/// - `SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP())` is answered with `00:00:00`.
/// - All other queries return `None`.
///
/// The query context is currently unused.
fn check_others(query: &str, _query_ctx: QueryContextRef) -> Option<Output> {
    let batches = if OTHER_NOT_SUPPORTED_STMT.is_match(query.as_bytes()) {
        RecordBatches::empty()
    } else if SELECT_TIME_DIFF_FUNC_PATTERN.is_match(query) {
        select_function("TIMEDIFF(NOW(), UTC_TIMESTAMP())", "00:00:00")
    } else {
        return None;
    };
    Some(Output::new_with_record_batches(batches))
}
/// Intercepts MySQL client "housekeeping" statements and answers them with
/// canned results.
///
/// Returns `Some(output)` when `query` matches one of the federated patterns,
/// or `None` when the query should be executed normally. The checks run in this
/// order:
/// 1. variable selects,
/// 2. `SHOW VARIABLES`,
/// 3. the remaining patterns.
///
/// Queries longer than six bytes that start with `INSERT` (case-insensitive) are
/// skipped up front, presumably so large insert payloads are not scanned by the
/// regexes. The session is currently unused.
pub(crate) fn check(
    query: &str,
    query_ctx: QueryContextRef,
    _session: SessionRef,
) -> Option<Output> {
    // Compare raw bytes: slicing `query[..6]` panics when byte 6 is not a UTF-8
    // char boundary (e.g. a query beginning with multi-byte characters).
    if query.len() > 6 && query.as_bytes()[..6].eq_ignore_ascii_case(b"INSERT") {
        return None;
    }
    check_select_variable(query, query_ctx.clone())
        .or_else(|| check_show_variables(query))
        .or_else(|| check_others(query, query_ctx))
}
// Unit tests for the federated-query interception performed by `check`.
#[cfg(test)]
mod test {
use common_query::OutputData;
use common_time::timezone::set_default_timezone;
use session::context::{Channel, QueryContext};
use session::Session;
use super::*;
#[test]
fn test_check() {
// Ordinary queries are not intercepted.
let session = Arc::new(Session::new(None, Channel::Mysql, Default::default()));
let query = "select 1";
let result = check(query, QueryContext::arc(), session.clone());
assert!(result.is_none());
let query = "select version";
let output = check(query, QueryContext::arc(), session.clone());
assert!(output.is_none());
// Runs `check` on `query`, expects record batches, and compares their
// pretty-printed form with `expected`.
fn test(query: &str, expected: &str) {
let session = Arc::new(Session::new(None, Channel::Mysql, Default::default()));
let output = check(query, QueryContext::arc(), session.clone());
match output.unwrap().data {
OutputData::RecordBatches(r) => {
assert_eq!(&r.pretty_print().unwrap(), expected)
}
_ => unreachable!(),
}
}
// Unaliased variable: column is `@@<name>`; trailing `LIMIT 1` is ignored.
let query = "SELECT @@version_comment LIMIT 1";
let expected = "\
+-------------------+
| @@version_comment |
+-------------------+
| Greptime |
+-------------------+";
test(query, expected);
// Several variables in one SELECT, including a `session.` prefixed name.
let query = "select @@tx_isolation, @@session.tx_isolation";
let expected = "\
+-----------------+------------------------+
| @@tx_isolation | @@session.tx_isolation |
+-----------------+------------------------+
| REPEATABLE-READ | REPEATABLE-READ |
+-----------------+------------------------+";
test(query, expected);
// Pin the default timezone before checking the time-zone columns below.
set_default_timezone(Some("Asia/Shanghai")).unwrap();
// mysql-connector-j handshake: aliased columns; unknown variables report "0".
// Note that the alias of the last variable keeps the trailing ';'.
let query = "/* mysql-connector-java-8.0.17 (Revision: 16a712ddb3f826a1933ab42b0039f7fb9eebc6ec) */SELECT @@session.auto_increment_increment AS auto_increment_increment, @@character_set_client AS character_set_client, @@character_set_connection AS character_set_connection, @@character_set_results AS character_set_results, @@character_set_server AS character_set_server, @@collation_server AS collation_server, @@collation_connection AS collation_connection, @@init_connect AS init_connect, @@interactive_timeout AS interactive_timeout, @@license AS license, @@lower_case_table_names AS lower_case_table_names, @@max_allowed_packet AS max_allowed_packet, @@net_write_timeout AS net_write_timeout, @@performance_schema AS performance_schema, @@sql_mode AS sql_mode, @@system_time_zone AS system_time_zone, @@time_zone AS time_zone, @@transaction_isolation AS transaction_isolation, @@wait_timeout AS wait_timeout;";
let expected = "\
+--------------------------+----------------------+--------------------------+-----------------------+----------------------+------------------+----------------------+--------------+---------------------+---------+------------------------+--------------------+-------------------+--------------------+----------+------------------+---------------+-----------------------+---------------+
| auto_increment_increment | character_set_client | character_set_connection | character_set_results | character_set_server | collation_server | collation_connection | init_connect | interactive_timeout | license | lower_case_table_names | max_allowed_packet | net_write_timeout | performance_schema | sql_mode | system_time_zone | time_zone | transaction_isolation | wait_timeout; |
+--------------------------+----------------------+--------------------------+-----------------------+----------------------+------------------+----------------------+--------------+---------------------+---------+------------------------+--------------------+-------------------+--------------------+----------+------------------+---------------+-----------------------+---------------+
| 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 31536000 | 0 | 0 | 134217728 | 31536000 | 0 | 0 | Asia/Shanghai | Asia/Shanghai | REPEATABLE-READ | 31536000 |
+--------------------------+----------------------+--------------------------+-----------------------+----------------------+------------------+----------------------+--------------+---------------------+---------+------------------------+--------------------+-------------------+--------------------+----------+------------------+---------------+-----------------------+---------------+";
test(query, expected);
// Bare `SHOW VARIABLES` yields one row with empty name and value.
let query = "show variables";
let expected = "\
+---------------+-------+
| Variable_name | Value |
+---------------+-------+
| | |
+---------------+-------+";
test(query, expected);
// `SHOW VARIABLES LIKE 'lower_case_table_names'` reports 0.
let query = "show variables like 'lower_case_table_names'";
let expected = "\
+------------------------+-------+
| Variable_name | Value |
+------------------------+-------+
| lower_case_table_names | 0 |
+------------------------+-------+";
test(query, expected);
// TIMEDIFF between NOW() and UTC_TIMESTAMP() is reported as zero.
let query = "SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP())";
let expected = "\
+----------------------------------+
| TIMEDIFF(NOW(), UTC_TIMESTAMP()) |
+----------------------------------+
| 00:00:00 |
+----------------------------------+";
test(query, expected);
}
}