1use std::collections::HashMap;
19use std::sync::Arc;
20
21use common_query::Output;
22use common_recordbatch::RecordBatches;
23use common_time::timezone::system_timezone_name;
24use common_version;
25use datatypes::prelude::ConcreteDataType;
26use datatypes::schema::{ColumnSchema, Schema};
27use datatypes::vectors::StringVector;
28use once_cell::sync::Lazy;
29use regex::Regex;
30use regex::bytes::RegexSet;
31use session::SessionRef;
32use session::context::QueryContextRef;
33
// All patterns below are case-insensitive (`(?i)`) and anchored at the start
// of the query (`^`); leading whitespace in the query is not tolerated.

/// Matches queries that start with `SELECT @@` (variable lookups).
static SELECT_VAR_PATTERN: Lazy<Regex> = Lazy::new(|| Regex::new("(?i)^(SELECT @@(.*))").unwrap());
/// Matches queries that start with a `/* mysql-connector-j...` comment, e.g.
/// `/* mysql-connector-java-8.0.17 (Revision: ...) */SELECT @@...`.
static MYSQL_CONN_JAVA_PATTERN: Lazy<Regex> =
    Lazy::new(|| Regex::new("(?i)^(/\\* mysql-connector-j(.*))").unwrap());
/// Matches queries that start with `SHOW VARIABLES LIKE 'lower_case_table_names'`.
static SHOW_LOWER_CASE_PATTERN: Lazy<Regex> =
    Lazy::new(|| Regex::new("(?i)^(SHOW VARIABLES LIKE 'lower_case_table_names'(.*))").unwrap());
/// Matches any query that starts with `SHOW VARIABLES`. The optional
/// ` LIKE ...` group does not narrow the match because the pattern is not
/// anchored at the end, so more specific patterns must be tested first.
static SHOW_VARIABLES_LIKE_PATTERN: Lazy<Regex> =
    Lazy::new(|| Regex::new("(?i)^(SHOW VARIABLES( LIKE (.*))?)").unwrap());
/// Matches `SHOW WARNINGS`, optionally preceded by a
/// `/* ApplicationName=... */` comment (e.g. `/* ApplicationName=DBeaver */`).
static SHOW_WARNINGS_PATTERN: Lazy<Regex> =
    Lazy::new(|| Regex::new("(?i)^(/\\* ApplicationName=.*)?SHOW WARNINGS").unwrap());

/// Matches `SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP())` with exactly this spacing.
static SELECT_TIME_DIFF_FUNC_PATTERN: Lazy<Regex> =
    Lazy::new(|| Regex::new("(?i)^(SELECT TIMEDIFF\\(NOW\\(\\), UTC_TIMESTAMP\\(\\)\\))").unwrap());

/// Matches queries that start with `SHOW VARIABLES LIKE 'sql_mode'`.
static SHOW_SQL_MODE_PATTERN: Lazy<Regex> =
    Lazy::new(|| Regex::new("(?i)^(SHOW VARIABLES LIKE 'sql_mode'(.*))").unwrap());
51
52static OTHER_NOT_SUPPORTED_STMT: Lazy<RegexSet> = Lazy::new(|| {
53 RegexSet::new([
54 "(?i)^(ROLLBACK(.*))",
56 "(?i)^(COMMIT(.*))",
57 "(?i)^(START(.*))",
58
59 "(?i)^(SET NAMES(.*))",
61 "(?i)^(SET character_set_results(.*))",
62 "(?i)^(SET net_write_timeout(.*))",
63 "(?i)^(SET FOREIGN_KEY_CHECKS(.*))",
64 "(?i)^(SET AUTOCOMMIT(.*))",
65 "(?i)^(SET SQL_LOG_BIN(.*))",
66 "(?i)^(SET SESSION TRANSACTION(.*))",
67 "(?i)^(SET TRANSACTION(.*))",
68 "(?i)^(SET sql_mode(.*))",
69 "(?i)^(SET SQL_SELECT_LIMIT(.*))",
70 "(?i)^(SET PROFILING(.*))",
71
72 "(?i)^(SELECT \\$\\$)",
74
75 "(?i)^(SET SQL_QUOTE_SHOW_CREATE(.*))",
77 "(?i)^(LOCK TABLES(.*))",
78 "(?i)^(UNLOCK TABLES(.*))",
79 "(?i)^(SELECT LOGFILE_GROUP_NAME, FILE_NAME, TOTAL_EXTENTS, INITIAL_SIZE, ENGINE, EXTRA FROM INFORMATION_SCHEMA.FILES(.*))",
80
81 "(?i)^(/\\*!80003 SET(.*) \\*/)$",
83 "(?i)^(SHOW MASTER STATUS)",
84 "(?i)^(SHOW ALL SLAVES STATUS)",
85 "(?i)^(LOCK BINLOG FOR BACKUP)",
86 "(?i)^(LOCK TABLES FOR BACKUP)",
87 "(?i)^(UNLOCK BINLOG(.*))",
88 "(?i)^(/\\*!40101 SET(.*) \\*/)$",
89
90 "(?i)^(/\\* ApplicationName=(.*)SHOW PLUGINS)",
92 "(?i)^(/\\* ApplicationName=(.*)SHOW ENGINES)",
93 "(?i)^(/\\* ApplicationName=(.*)SELECT @@(.*))",
94 "(?i)^(/\\* ApplicationName=(.*)SHOW @@(.*))",
95 "(?i)^(/\\* ApplicationName=(.*)SET net_write_timeout(.*))",
96 "(?i)^(/\\* ApplicationName=(.*)SET SQL_SELECT_LIMIT(.*))",
97 "(?i)^(/\\* ApplicationName=(.*)SHOW VARIABLES(.*))",
98
99 "(?i)^(/\\*!40101 SET(.*) \\*/)$",
101
102 "(?i)^(/\\*!40100 SET(.*) \\*/)$",
104 "(?i)^(/\\*!40103 SET(.*) \\*/)$",
105 "(?i)^(/\\*!40111 SET(.*) \\*/)$",
106 "(?i)^(/\\*!40101 SET(.*) \\*/)$",
107 "(?i)^(/\\*!40014 SET(.*) \\*/)$",
108 "(?i)^(/\\*!40000 SET(.*) \\*/)$",
109 ]).unwrap()
110});
111
112static VAR_VALUES: Lazy<HashMap<&str, &str>> = Lazy::new(|| {
113 HashMap::from([
114 ("tx_isolation", "REPEATABLE-READ"),
115 ("session.tx_isolation", "REPEATABLE-READ"),
116 ("transaction_isolation", "REPEATABLE-READ"),
117 ("session.transaction_isolation", "REPEATABLE-READ"),
118 ("session.transaction_read_only", "0"),
119 ("max_allowed_packet", "134217728"),
120 ("interactive_timeout", "31536000"),
121 ("wait_timeout", "31536000"),
122 ("net_write_timeout", "31536000"),
123 ("version_comment", common_version::product_name()),
124 ])
125});
126
127fn select_function(name: &str, value: &str) -> RecordBatches {
132 let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
133 name,
134 ConcreteDataType::string_datatype(),
135 true,
136 )]));
137 let columns = vec![Arc::new(StringVector::from(vec![value])) as _];
138 RecordBatches::try_from_columns(schema, columns)
139 .unwrap()
141}
142
143fn show_variables(name: &str, value: &str) -> RecordBatches {
148 let schema = Arc::new(Schema::new(vec![
149 ColumnSchema::new("Variable_name", ConcreteDataType::string_datatype(), true),
150 ColumnSchema::new("Value", ConcreteDataType::string_datatype(), true),
151 ]));
152 let columns = vec![
153 Arc::new(StringVector::from(vec![name])) as _,
154 Arc::new(StringVector::from(vec![value])) as _,
155 ];
156 RecordBatches::try_from_columns(schema, columns)
157 .unwrap()
159}
160
161fn select_variable(query: &str, query_context: QueryContextRef) -> Option<Output> {
162 let mut fields = vec![];
163 let mut values = vec![];
164
165 let query = query.to_lowercase();
167 let vars: Vec<&str> = query.split("@@").collect();
168 if vars.len() <= 1 {
169 return None;
170 }
171
172 for var in vars.iter().skip(1) {
174 let var = var.trim_matches(|c| c == ' ' || c == ',' || c == ';');
175 let var_as: Vec<&str> = var
176 .split(" as ")
177 .map(|x| {
178 x.trim_matches(|c| c == ' ')
179 .split_whitespace()
180 .next()
181 .unwrap_or("")
182 })
183 .collect();
184
185 let value = match var_as[0] {
187 "session.time_zone" | "time_zone" => query_context.timezone().to_string(),
188 "system_time_zone" => system_timezone_name(),
189 "max_execution_time" | "session.max_execution_time" => {
190 query_context.query_timeout_as_millis().to_string()
191 }
192 _ => VAR_VALUES
193 .get(var_as[0])
194 .map(|v| v.to_string())
195 .unwrap_or_else(|| "0".to_owned()),
196 };
197
198 values.push(Arc::new(StringVector::from(vec![value])) as _);
199 match var_as.len() {
200 1 => {
201 fields.push(ColumnSchema::new(
204 format!("@@{}", var_as[0]),
205 ConcreteDataType::string_datatype(),
206 true,
207 ));
208 }
209 2 => {
210 fields.push(ColumnSchema::new(
214 var_as[1],
215 ConcreteDataType::string_datatype(),
216 true,
217 ));
218 }
219 _ => return None,
220 }
221 }
222
223 let schema = Arc::new(Schema::new(fields));
224 let batches = RecordBatches::try_from_columns(schema, values).unwrap();
226 Some(Output::new_with_record_batches(batches))
227}
228
229fn check_select_variable(query: &str, query_context: QueryContextRef) -> Option<Output> {
230 if [&SELECT_VAR_PATTERN, &MYSQL_CONN_JAVA_PATTERN]
231 .iter()
232 .any(|r| r.is_match(query))
233 {
234 select_variable(query, query_context)
235 } else {
236 None
237 }
238}
239
240fn check_show_variables(query: &str) -> Option<Output> {
241 let recordbatches = if SHOW_SQL_MODE_PATTERN.is_match(query) {
242 Some(show_variables(
243 "sql_mode",
244 "ONLY_FULL_GROUP_BY STRICT_TRANS_TABLES NO_ZERO_IN_DATE NO_ZERO_DATE ERROR_FOR_DIVISION_BY_ZERO NO_ENGINE_SUBSTITUTION",
245 ))
246 } else if SHOW_LOWER_CASE_PATTERN.is_match(query) {
247 Some(show_variables("lower_case_table_names", "0"))
248 } else if SHOW_VARIABLES_LIKE_PATTERN.is_match(query) {
249 Some(show_variables("", ""))
250 } else {
251 None
252 };
253 recordbatches.map(Output::new_with_record_batches)
254}
255
256fn show_warnings(session: &SessionRef) -> RecordBatches {
258 let schema = Arc::new(Schema::new(vec![
259 ColumnSchema::new("Level", ConcreteDataType::string_datatype(), false),
260 ColumnSchema::new("Code", ConcreteDataType::uint16_datatype(), false),
261 ColumnSchema::new("Message", ConcreteDataType::string_datatype(), false),
262 ]));
263
264 let warnings = session.warnings();
265 let count = warnings.len();
266
267 let columns = if count > 0 {
268 vec![
269 Arc::new(StringVector::from(vec!["Warning"; count])) as _,
270 Arc::new(datatypes::vectors::UInt16Vector::from(vec![
271 Some(1000u16);
272 count
273 ])) as _,
274 Arc::new(StringVector::from(warnings)) as _,
275 ]
276 } else {
277 vec![
278 Arc::new(StringVector::from(Vec::<String>::new())) as _,
279 Arc::new(datatypes::vectors::UInt16Vector::from(
280 Vec::<Option<u16>>::new(),
281 )) as _,
282 Arc::new(StringVector::from(Vec::<String>::new())) as _,
283 ]
284 };
285
286 RecordBatches::try_from_columns(schema, columns).unwrap()
287}
288
289fn check_show_warnings(query: &str, session: &SessionRef) -> Option<Output> {
290 if SHOW_WARNINGS_PATTERN.is_match(query) {
291 Some(Output::new_with_record_batches(show_warnings(session)))
292 } else {
293 None
294 }
295}
296
297fn check_others(query: &str, _query_ctx: QueryContextRef) -> Option<Output> {
299 if OTHER_NOT_SUPPORTED_STMT.is_match(query.as_bytes()) {
300 return Some(Output::new_with_record_batches(RecordBatches::empty()));
301 }
302
303 let recordbatches = if SELECT_TIME_DIFF_FUNC_PATTERN.is_match(query) {
304 Some(select_function(
305 "TIMEDIFF(NOW(), UTC_TIMESTAMP())",
306 "00:00:00",
307 ))
308 } else {
309 None
310 };
311 recordbatches.map(Output::new_with_record_batches)
312}
313
314pub(crate) fn check(
317 query: &str,
318 query_ctx: QueryContextRef,
319 session: SessionRef,
320) -> Option<Output> {
321 let the_6th_index = query.char_indices().nth(6).map(|(i, _)| i);
324 if let Some(index) = the_6th_index
325 && query[..index].eq_ignore_ascii_case("INSERT")
326 {
327 return None;
328 }
329
330 check_select_variable(query, query_ctx.clone())
332 .or_else(|| check_show_variables(query))
333 .or_else(|| check_show_warnings(query, &session))
334 .or_else(|| check_others(query, query_ctx))
336}
337
/// Unit tests for the canned-answer interception implemented by `check`.
#[cfg(test)]
mod test {

    use common_query::OutputData;
    use common_time::timezone::set_default_timezone;
    use session::Session;
    use session::context::{Channel, QueryContext};

    use super::*;

    /// A query made only of emoji and CJK characters must not panic (`check`
    /// slices the query for its INSERT test) and must not be intercepted.
    #[test]
    fn test_check_abnormal() {
        let session = Arc::new(Session::new(None, Channel::Mysql, Default::default(), 0));
        let query = "🫣一点不正常的东西🫣";
        let output = check(query, QueryContext::arc(), session.clone());

        assert!(output.is_none());
    }

    /// Checks which queries are intercepted and the exact tables produced.
    #[test]
    fn test_check() {
        let session = Arc::new(Session::new(None, Channel::Mysql, Default::default(), 0));
        // Ordinary queries fall through to normal execution.
        let query = "select 1";
        let result = check(query, QueryContext::arc(), session.clone());
        assert!(result.is_none());

        // No `@@`, so this is not treated as a variable lookup.
        let query = "select version";
        let output = check(query, QueryContext::arc(), session.clone());
        assert!(output.is_none());

        // Runs `query` through `check` with a fresh session and compares the
        // pretty-printed record batches with `expected`.
        fn test(query: &str, expected: &str) {
            let session = Arc::new(Session::new(None, Channel::Mysql, Default::default(), 0));
            let output = check(query, QueryContext::arc(), session.clone());
            match output.unwrap().data {
                OutputData::RecordBatches(r) => {
                    assert_eq!(&r.pretty_print().unwrap(), expected)
                }
                _ => unreachable!(),
            }
        }

        // Without an alias the column is named `@@<var>`; `LIMIT 1` is ignored.
        let query = "SELECT @@version_comment LIMIT 1";
        let expected = "\
+-------------------+
| @@version_comment |
+-------------------+
| GreptimeDB |
+-------------------+";
        test(query, expected);

        // Several comma-separated variables become several columns.
        let query = "select @@tx_isolation, @@session.tx_isolation";
        let expected = "\
+-----------------+------------------------+
| @@tx_isolation | @@session.tx_isolation |
+-----------------+------------------------+
| REPEATABLE-READ | REPEATABLE-READ |
+-----------------+------------------------+";
        test(query, expected);

        // The expected table below shows `Asia/Shanghai` for both
        // `system_time_zone` and `time_zone` after this call.
        // NOTE(review): this mutates a process-wide default; other tests in
        // the same process may observe it — confirm that is acceptable.
        set_default_timezone(Some("Asia/Shanghai")).unwrap();
        // The session-variable probe sent by MySQL Connector/J: aliased
        // variables, a leading comment and a trailing `;`. Unknown variables
        // come back as `0`.
        let query = "/* mysql-connector-java-8.0.17 (Revision: 16a712ddb3f826a1933ab42b0039f7fb9eebc6ec) */SELECT @@session.auto_increment_increment AS auto_increment_increment, @@character_set_client AS character_set_client, @@character_set_connection AS character_set_connection, @@character_set_results AS character_set_results, @@character_set_server AS character_set_server, @@collation_server AS collation_server, @@collation_connection AS collation_connection, @@init_connect AS init_connect, @@interactive_timeout AS interactive_timeout, @@license AS license, @@lower_case_table_names AS lower_case_table_names, @@max_allowed_packet AS max_allowed_packet, @@net_write_timeout AS net_write_timeout, @@performance_schema AS performance_schema, @@sql_mode AS sql_mode, @@system_time_zone AS system_time_zone, @@time_zone AS time_zone, @@transaction_isolation AS transaction_isolation, @@wait_timeout AS wait_timeout;";
        let expected = "\
+--------------------------+----------------------+--------------------------+-----------------------+----------------------+------------------+----------------------+--------------+---------------------+---------+------------------------+--------------------+-------------------+--------------------+----------+------------------+---------------+-----------------------+--------------+
| auto_increment_increment | character_set_client | character_set_connection | character_set_results | character_set_server | collation_server | collation_connection | init_connect | interactive_timeout | license | lower_case_table_names | max_allowed_packet | net_write_timeout | performance_schema | sql_mode | system_time_zone | time_zone | transaction_isolation | wait_timeout |
+--------------------------+----------------------+--------------------------+-----------------------+----------------------+------------------+----------------------+--------------+---------------------+---------+------------------------+--------------------+-------------------+--------------------+----------+------------------+---------------+-----------------------+--------------+
| 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 31536000 | 0 | 0 | 134217728 | 31536000 | 0 | 0 | Asia/Shanghai | Asia/Shanghai | REPEATABLE-READ | 31536000 |
+--------------------------+----------------------+--------------------------+-----------------------+----------------------+------------------+----------------------+--------------+---------------------+---------+------------------------+--------------------+-------------------+--------------------+----------+------------------+---------------+-----------------------+--------------+";
        test(query, expected);

        // A bare `SHOW VARIABLES` gets one row with empty name and value.
        let query = "show variables";
        let expected = "\
+---------------+-------+
| Variable_name | Value |
+---------------+-------+
| | |
+---------------+-------+";
        test(query, expected);

        // `lower_case_table_names` has its own canned answer.
        let query = "show variables like 'lower_case_table_names'";
        let expected = "\
+------------------------+-------+
| Variable_name | Value |
+------------------------+-------+
| lower_case_table_names | 0 |
+------------------------+-------+";
        test(query, expected);

        // The TIMEDIFF probe always reports a zero offset.
        let query = "SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP())";
        let expected = "\
+----------------------------------+
| TIMEDIFF(NOW(), UTC_TIMESTAMP()) |
+----------------------------------+
| 00:00:00 |
+----------------------------------+";
        test(query, expected);
    }

    /// `SHOW WARNINGS` reflects the warnings stored on the session.
    #[test]
    fn test_show_warnings() {
        let session = Arc::new(Session::new(None, Channel::Mysql, Default::default(), 0));
        // No warnings yet: the result has zero rows.
        let output = check("SHOW WARNINGS", QueryContext::arc(), session.clone());
        match output.unwrap().data {
            OutputData::RecordBatches(r) => {
                assert_eq!(r.iter().map(|b| b.num_rows()).sum::<usize>(), 0);
            }
            _ => unreachable!(),
        }

        // One warning becomes one `Warning | 1000 | <message>` row.
        session.add_warning("Test warning message".to_string());
        let output = check("SHOW WARNINGS", QueryContext::arc(), session.clone());
        match output.unwrap().data {
            OutputData::RecordBatches(r) => {
                let expected = "\
+---------+------+----------------------+
| Level | Code | Message |
+---------+------+----------------------+
| Warning | 1000 | Test warning message |
+---------+------+----------------------+";
                assert_eq!(&r.pretty_print().unwrap(), expected);
            }
            _ => unreachable!(),
        }

        // After clearing, only newly added warnings are reported, in order.
        session.clear_warnings();
        session.add_warning("First warning".to_string());
        session.add_warning("Second warning".to_string());
        let output = check("SHOW WARNINGS", QueryContext::arc(), session.clone());
        match output.unwrap().data {
            OutputData::RecordBatches(r) => {
                let expected = "\
+---------+------+----------------+
| Level | Code | Message |
+---------+------+----------------+
| Warning | 1000 | First warning |
| Warning | 1000 | Second warning |
+---------+------+----------------+";
                assert_eq!(&r.pretty_print().unwrap(), expected);
            }
            _ => unreachable!(),
        }

        // Matching is case-insensitive.
        let output = check("show warnings", QueryContext::arc(), session.clone());
        assert!(output.is_some());

        // A leading `/* ApplicationName=... */` comment is accepted.
        let output = check(
            "/* ApplicationName=DBeaver */SHOW WARNINGS",
            QueryContext::arc(),
            session.clone(),
        );
        assert!(output.is_some());
    }
}