1use std::collections::HashMap;
19use std::sync::Arc;
20
21use common_query::Output;
22use common_recordbatch::RecordBatches;
23use common_time::timezone::system_timezone_name;
24use datatypes::prelude::ConcreteDataType;
25use datatypes::schema::{ColumnSchema, Schema};
26use datatypes::vectors::StringVector;
27use once_cell::sync::Lazy;
28use regex::Regex;
29use regex::bytes::RegexSet;
30use session::SessionRef;
31use session::context::QueryContextRef;
32
33static SELECT_VAR_PATTERN: Lazy<Regex> = Lazy::new(|| Regex::new("(?i)^(SELECT @@(.*))").unwrap());
34static MYSQL_CONN_JAVA_PATTERN: Lazy<Regex> =
35 Lazy::new(|| Regex::new("(?i)^(/\\* mysql-connector-j(.*))").unwrap());
36static SHOW_LOWER_CASE_PATTERN: Lazy<Regex> =
37 Lazy::new(|| Regex::new("(?i)^(SHOW VARIABLES LIKE 'lower_case_table_names'(.*))").unwrap());
38static SHOW_VARIABLES_LIKE_PATTERN: Lazy<Regex> =
39 Lazy::new(|| Regex::new("(?i)^(SHOW VARIABLES( LIKE (.*))?)").unwrap());
40static SHOW_WARNINGS_PATTERN: Lazy<Regex> =
41 Lazy::new(|| Regex::new("(?i)^(/\\* ApplicationName=.*)?SHOW WARNINGS").unwrap());
42
43static SELECT_TIME_DIFF_FUNC_PATTERN: Lazy<Regex> =
45 Lazy::new(|| Regex::new("(?i)^(SELECT TIMEDIFF\\(NOW\\(\\), UTC_TIMESTAMP\\(\\)\\))").unwrap());
46
47static SHOW_SQL_MODE_PATTERN: Lazy<Regex> =
49 Lazy::new(|| Regex::new("(?i)^(SHOW VARIABLES LIKE 'sql_mode'(.*))").unwrap());
50
51static OTHER_NOT_SUPPORTED_STMT: Lazy<RegexSet> = Lazy::new(|| {
52 RegexSet::new([
53 "(?i)^(ROLLBACK(.*))",
55 "(?i)^(COMMIT(.*))",
56 "(?i)^(START(.*))",
57
58 "(?i)^(SET NAMES(.*))",
60 "(?i)^(SET character_set_results(.*))",
61 "(?i)^(SET net_write_timeout(.*))",
62 "(?i)^(SET FOREIGN_KEY_CHECKS(.*))",
63 "(?i)^(SET AUTOCOMMIT(.*))",
64 "(?i)^(SET SQL_LOG_BIN(.*))",
65 "(?i)^(SET SESSION TRANSACTION(.*))",
66 "(?i)^(SET TRANSACTION(.*))",
67 "(?i)^(SET sql_mode(.*))",
68 "(?i)^(SET SQL_SELECT_LIMIT(.*))",
69 "(?i)^(SET PROFILING(.*))",
70
71 "(?i)^(SELECT \\$\\$)",
73
74 "(?i)^(SET SQL_QUOTE_SHOW_CREATE(.*))",
76 "(?i)^(LOCK TABLES(.*))",
77 "(?i)^(UNLOCK TABLES(.*))",
78 "(?i)^(SELECT LOGFILE_GROUP_NAME, FILE_NAME, TOTAL_EXTENTS, INITIAL_SIZE, ENGINE, EXTRA FROM INFORMATION_SCHEMA.FILES(.*))",
79
80 "(?i)^(/\\*!80003 SET(.*) \\*/)$",
82 "(?i)^(SHOW MASTER STATUS)",
83 "(?i)^(SHOW ALL SLAVES STATUS)",
84 "(?i)^(LOCK BINLOG FOR BACKUP)",
85 "(?i)^(LOCK TABLES FOR BACKUP)",
86 "(?i)^(UNLOCK BINLOG(.*))",
87 "(?i)^(/\\*!40101 SET(.*) \\*/)$",
88
89 "(?i)^(/\\* ApplicationName=(.*)SHOW PLUGINS)",
91 "(?i)^(/\\* ApplicationName=(.*)SHOW ENGINES)",
92 "(?i)^(/\\* ApplicationName=(.*)SELECT @@(.*))",
93 "(?i)^(/\\* ApplicationName=(.*)SHOW @@(.*))",
94 "(?i)^(/\\* ApplicationName=(.*)SET net_write_timeout(.*))",
95 "(?i)^(/\\* ApplicationName=(.*)SET SQL_SELECT_LIMIT(.*))",
96 "(?i)^(/\\* ApplicationName=(.*)SHOW VARIABLES(.*))",
97
98 "(?i)^(/\\*!40101 SET(.*) \\*/)$",
100
101 "(?i)^(/\\*!40100 SET(.*) \\*/)$",
103 "(?i)^(/\\*!40103 SET(.*) \\*/)$",
104 "(?i)^(/\\*!40111 SET(.*) \\*/)$",
105 "(?i)^(/\\*!40101 SET(.*) \\*/)$",
106 "(?i)^(/\\*!40014 SET(.*) \\*/)$",
107 "(?i)^(/\\*!40000 SET(.*) \\*/)$",
108 ]).unwrap()
109});
110
111static VAR_VALUES: Lazy<HashMap<&str, &str>> = Lazy::new(|| {
112 HashMap::from([
113 ("tx_isolation", "REPEATABLE-READ"),
114 ("session.tx_isolation", "REPEATABLE-READ"),
115 ("transaction_isolation", "REPEATABLE-READ"),
116 ("session.transaction_isolation", "REPEATABLE-READ"),
117 ("session.transaction_read_only", "0"),
118 ("max_allowed_packet", "134217728"),
119 ("interactive_timeout", "31536000"),
120 ("wait_timeout", "31536000"),
121 ("net_write_timeout", "31536000"),
122 ("version_comment", "Greptime"),
123 ])
124});
125
126fn select_function(name: &str, value: &str) -> RecordBatches {
131 let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
132 name,
133 ConcreteDataType::string_datatype(),
134 true,
135 )]));
136 let columns = vec![Arc::new(StringVector::from(vec![value])) as _];
137 RecordBatches::try_from_columns(schema, columns)
138 .unwrap()
140}
141
142fn show_variables(name: &str, value: &str) -> RecordBatches {
147 let schema = Arc::new(Schema::new(vec![
148 ColumnSchema::new("Variable_name", ConcreteDataType::string_datatype(), true),
149 ColumnSchema::new("Value", ConcreteDataType::string_datatype(), true),
150 ]));
151 let columns = vec![
152 Arc::new(StringVector::from(vec![name])) as _,
153 Arc::new(StringVector::from(vec![value])) as _,
154 ];
155 RecordBatches::try_from_columns(schema, columns)
156 .unwrap()
158}
159
160fn select_variable(query: &str, query_context: QueryContextRef) -> Option<Output> {
161 let mut fields = vec![];
162 let mut values = vec![];
163
164 let query = query.to_lowercase();
166 let vars: Vec<&str> = query.split("@@").collect();
167 if vars.len() <= 1 {
168 return None;
169 }
170
171 for var in vars.iter().skip(1) {
173 let var = var.trim_matches(|c| c == ' ' || c == ',' || c == ';');
174 let var_as: Vec<&str> = var
175 .split(" as ")
176 .map(|x| {
177 x.trim_matches(|c| c == ' ')
178 .split_whitespace()
179 .next()
180 .unwrap_or("")
181 })
182 .collect();
183
184 let value = match var_as[0] {
186 "session.time_zone" | "time_zone" => query_context.timezone().to_string(),
187 "system_time_zone" => system_timezone_name(),
188 "max_execution_time" | "session.max_execution_time" => {
189 query_context.query_timeout_as_millis().to_string()
190 }
191 _ => VAR_VALUES
192 .get(var_as[0])
193 .map(|v| v.to_string())
194 .unwrap_or_else(|| "0".to_owned()),
195 };
196
197 values.push(Arc::new(StringVector::from(vec![value])) as _);
198 match var_as.len() {
199 1 => {
200 fields.push(ColumnSchema::new(
203 format!("@@{}", var_as[0]),
204 ConcreteDataType::string_datatype(),
205 true,
206 ));
207 }
208 2 => {
209 fields.push(ColumnSchema::new(
213 var_as[1],
214 ConcreteDataType::string_datatype(),
215 true,
216 ));
217 }
218 _ => return None,
219 }
220 }
221
222 let schema = Arc::new(Schema::new(fields));
223 let batches = RecordBatches::try_from_columns(schema, values).unwrap();
225 Some(Output::new_with_record_batches(batches))
226}
227
228fn check_select_variable(query: &str, query_context: QueryContextRef) -> Option<Output> {
229 if [&SELECT_VAR_PATTERN, &MYSQL_CONN_JAVA_PATTERN]
230 .iter()
231 .any(|r| r.is_match(query))
232 {
233 select_variable(query, query_context)
234 } else {
235 None
236 }
237}
238
239fn check_show_variables(query: &str) -> Option<Output> {
240 let recordbatches = if SHOW_SQL_MODE_PATTERN.is_match(query) {
241 Some(show_variables(
242 "sql_mode",
243 "ONLY_FULL_GROUP_BY STRICT_TRANS_TABLES NO_ZERO_IN_DATE NO_ZERO_DATE ERROR_FOR_DIVISION_BY_ZERO NO_ENGINE_SUBSTITUTION",
244 ))
245 } else if SHOW_LOWER_CASE_PATTERN.is_match(query) {
246 Some(show_variables("lower_case_table_names", "0"))
247 } else if SHOW_VARIABLES_LIKE_PATTERN.is_match(query) {
248 Some(show_variables("", ""))
249 } else {
250 None
251 };
252 recordbatches.map(Output::new_with_record_batches)
253}
254
255fn show_warnings(session: &SessionRef) -> RecordBatches {
257 let schema = Arc::new(Schema::new(vec![
258 ColumnSchema::new("Level", ConcreteDataType::string_datatype(), false),
259 ColumnSchema::new("Code", ConcreteDataType::uint16_datatype(), false),
260 ColumnSchema::new("Message", ConcreteDataType::string_datatype(), false),
261 ]));
262
263 let warnings = session.warnings();
264 let count = warnings.len();
265
266 let columns = if count > 0 {
267 vec![
268 Arc::new(StringVector::from(vec!["Warning"; count])) as _,
269 Arc::new(datatypes::vectors::UInt16Vector::from(vec![
270 Some(1000u16);
271 count
272 ])) as _,
273 Arc::new(StringVector::from(warnings)) as _,
274 ]
275 } else {
276 vec![
277 Arc::new(StringVector::from(Vec::<String>::new())) as _,
278 Arc::new(datatypes::vectors::UInt16Vector::from(
279 Vec::<Option<u16>>::new(),
280 )) as _,
281 Arc::new(StringVector::from(Vec::<String>::new())) as _,
282 ]
283 };
284
285 RecordBatches::try_from_columns(schema, columns).unwrap()
286}
287
288fn check_show_warnings(query: &str, session: &SessionRef) -> Option<Output> {
289 if SHOW_WARNINGS_PATTERN.is_match(query) {
290 Some(Output::new_with_record_batches(show_warnings(session)))
291 } else {
292 None
293 }
294}
295
296fn check_others(query: &str, _query_ctx: QueryContextRef) -> Option<Output> {
298 if OTHER_NOT_SUPPORTED_STMT.is_match(query.as_bytes()) {
299 return Some(Output::new_with_record_batches(RecordBatches::empty()));
300 }
301
302 let recordbatches = if SELECT_TIME_DIFF_FUNC_PATTERN.is_match(query) {
303 Some(select_function(
304 "TIMEDIFF(NOW(), UTC_TIMESTAMP())",
305 "00:00:00",
306 ))
307 } else {
308 None
309 };
310 recordbatches.map(Output::new_with_record_batches)
311}
312
313pub(crate) fn check(
316 query: &str,
317 query_ctx: QueryContextRef,
318 session: SessionRef,
319) -> Option<Output> {
320 let the_6th_index = query.char_indices().nth(6).map(|(i, _)| i);
323 if let Some(index) = the_6th_index
324 && query[..index].eq_ignore_ascii_case("INSERT")
325 {
326 return None;
327 }
328
329 check_select_variable(query, query_ctx.clone())
331 .or_else(|| check_show_variables(query))
332 .or_else(|| check_show_warnings(query, &session))
333 .or_else(|| check_others(query, query_ctx))
335}
336
#[cfg(test)]
mod test {

    use common_query::OutputData;
    use common_time::timezone::set_default_timezone;
    use session::Session;
    use session::context::{Channel, QueryContext};

    use super::*;

    // Non-ASCII input (multi-byte characters, more than six of them) must not be
    // intercepted by any handler.
    #[test]
    fn test_check_abnormal() {
        let session = Arc::new(Session::new(None, Channel::Mysql, Default::default(), 0));
        let query = "🫣一点不正常的东西🫣";
        let output = check(query, QueryContext::arc(), session.clone());

        assert!(output.is_none());
    }

    #[test]
    fn test_check() {
        let session = Arc::new(Session::new(None, Channel::Mysql, Default::default(), 0));
        // Ordinary queries fall through to normal execution.
        let query = "select 1";
        let result = check(query, QueryContext::arc(), session.clone());
        assert!(result.is_none());

        let query = "select version";
        let output = check(query, QueryContext::arc(), session.clone());
        assert!(output.is_none());

        // Asserts that `query` is intercepted and its result pretty-prints as `expected`.
        fn test(query: &str, expected: &str) {
            let session = Arc::new(Session::new(None, Channel::Mysql, Default::default(), 0));
            let output = check(query, QueryContext::arc(), session.clone());
            match output.unwrap().data {
                OutputData::RecordBatches(r) => {
                    assert_eq!(&r.pretty_print().unwrap(), expected)
                }
                _ => unreachable!(),
            }
        }

        // No alias: the column is named `@@<variable>`; trailing `LIMIT 1` is ignored.
        let query = "SELECT @@version_comment LIMIT 1";
        let expected = "\
+-------------------+
| @@version_comment |
+-------------------+
| Greptime          |
+-------------------+";
        test(query, expected);

        // Several variables, with and without the `session.` prefix.
        let query = "select @@tx_isolation, @@session.tx_isolation";
        let expected = "\
+-----------------+------------------------+
| @@tx_isolation  | @@session.tx_isolation |
+-----------------+------------------------+
| REPEATABLE-READ | REPEATABLE-READ        |
+-----------------+------------------------+";
        test(query, expected);

        // The time-zone columns below are expected to reflect this default.
        set_default_timezone(Some("Asia/Shanghai")).unwrap();
        // Comment-prefixed variable probe with `AS` aliases; unknown variables read "0".
        let query = "/* mysql-connector-java-8.0.17 (Revision: 16a712ddb3f826a1933ab42b0039f7fb9eebc6ec) */SELECT @@session.auto_increment_increment AS auto_increment_increment, @@character_set_client AS character_set_client, @@character_set_connection AS character_set_connection, @@character_set_results AS character_set_results, @@character_set_server AS character_set_server, @@collation_server AS collation_server, @@collation_connection AS collation_connection, @@init_connect AS init_connect, @@interactive_timeout AS interactive_timeout, @@license AS license, @@lower_case_table_names AS lower_case_table_names, @@max_allowed_packet AS max_allowed_packet, @@net_write_timeout AS net_write_timeout, @@performance_schema AS performance_schema, @@sql_mode AS sql_mode, @@system_time_zone AS system_time_zone, @@time_zone AS time_zone, @@transaction_isolation AS transaction_isolation, @@wait_timeout AS wait_timeout;";
        let expected = "\
+--------------------------+----------------------+--------------------------+-----------------------+----------------------+------------------+----------------------+--------------+---------------------+---------+------------------------+--------------------+-------------------+--------------------+----------+------------------+---------------+-----------------------+--------------+
| auto_increment_increment | character_set_client | character_set_connection | character_set_results | character_set_server | collation_server | collation_connection | init_connect | interactive_timeout | license | lower_case_table_names | max_allowed_packet | net_write_timeout | performance_schema | sql_mode | system_time_zone | time_zone     | transaction_isolation | wait_timeout |
+--------------------------+----------------------+--------------------------+-----------------------+----------------------+------------------+----------------------+--------------+---------------------+---------+------------------------+--------------------+-------------------+--------------------+----------+------------------+---------------+-----------------------+--------------+
| 0                        | 0                    | 0                        | 0                     | 0                    | 0                | 0                    | 0            | 31536000            | 0       | 0                      | 134217728          | 31536000          | 0                  | 0        | Asia/Shanghai    | Asia/Shanghai | REPEATABLE-READ       | 31536000     |
+--------------------------+----------------------+--------------------------+-----------------------+----------------------+------------------+----------------------+--------------+---------------------+---------+------------------------+--------------------+-------------------+--------------------+----------+------------------+---------------+-----------------------+--------------+";
        test(query, expected);

        // Generic SHOW VARIABLES yields a single empty row.
        let query = "show variables";
        let expected = "\
+---------------+-------+
| Variable_name | Value |
+---------------+-------+
|               |       |
+---------------+-------+";
        test(query, expected);

        let query = "show variables like 'lower_case_table_names'";
        let expected = "\
+------------------------+-------+
| Variable_name          | Value |
+------------------------+-------+
| lower_case_table_names | 0     |
+------------------------+-------+";
        test(query, expected);

        let query = "SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP())";
        let expected = "\
+----------------------------------+
| TIMEDIFF(NOW(), UTC_TIMESTAMP()) |
+----------------------------------+
| 00:00:00                         |
+----------------------------------+";
        test(query, expected);
    }

    #[test]
    fn test_show_warnings() {
        let session = Arc::new(Session::new(None, Channel::Mysql, Default::default(), 0));
        // No warnings collected yet: zero rows.
        let output = check("SHOW WARNINGS", QueryContext::arc(), session.clone());
        match output.unwrap().data {
            OutputData::RecordBatches(r) => {
                assert_eq!(r.iter().map(|b| b.num_rows()).sum::<usize>(), 0);
            }
            _ => unreachable!(),
        }

        // A single warning is reported with level "Warning" and code 1000.
        session.add_warning("Test warning message".to_string());
        let output = check("SHOW WARNINGS", QueryContext::arc(), session.clone());
        match output.unwrap().data {
            OutputData::RecordBatches(r) => {
                let expected = "\
+---------+------+----------------------+
| Level   | Code | Message              |
+---------+------+----------------------+
| Warning | 1000 | Test warning message |
+---------+------+----------------------+";
                assert_eq!(&r.pretty_print().unwrap(), expected);
            }
            _ => unreachable!(),
        }

        // After clearing, only the newly added warnings are expected, in insertion order.
        session.clear_warnings();
        session.add_warning("First warning".to_string());
        session.add_warning("Second warning".to_string());
        let output = check("SHOW WARNINGS", QueryContext::arc(), session.clone());
        match output.unwrap().data {
            OutputData::RecordBatches(r) => {
                let expected = "\
+---------+------+----------------+
| Level   | Code | Message        |
+---------+------+----------------+
| Warning | 1000 | First warning  |
| Warning | 1000 | Second warning |
+---------+------+----------------+";
                assert_eq!(&r.pretty_print().unwrap(), expected);
            }
            _ => unreachable!(),
        }

        // Matching is case-insensitive.
        let output = check("show warnings", QueryContext::arc(), session.clone());
        assert!(output.is_some());

        // A `/* ApplicationName=... */` prefix is accepted.
        let output = check(
            "/* ApplicationName=DBeaver */SHOW WARNINGS",
            QueryContext::arc(),
            session.clone(),
        );
        assert!(output.is_some());
    }
}