Skip to main content

servers/mysql/
federated.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Use regex to filter out some MySQL federated components' emitted statements.
16//! Inspired by Databend's "[mysql_federated.rs](https://github.com/datafuselabs/databend/blob/ac706bf65845e6895141c96c0a10bad6fdc2d367/src/query/service/src/servers/mysql/mysql_federated.rs)".
17
18use std::collections::HashMap;
19use std::sync::Arc;
20
21use common_query::Output;
22use common_recordbatch::RecordBatches;
23use common_time::timezone::system_timezone_name;
24use common_version;
25use datatypes::prelude::ConcreteDataType;
26use datatypes::schema::{ColumnSchema, Schema};
27use datatypes::vectors::StringVector;
28use once_cell::sync::Lazy;
29use regex::Regex;
30use regex::bytes::RegexSet;
31use session::SessionRef;
32use session::context::QueryContextRef;
33
34static SELECT_VAR_PATTERN: Lazy<Regex> = Lazy::new(|| Regex::new("(?i)^(SELECT @@(.*))").unwrap());
35static MYSQL_CONN_JAVA_PATTERN: Lazy<Regex> =
36    Lazy::new(|| Regex::new("(?i)^(/\\* mysql-connector-j(.*))").unwrap());
37static SHOW_LOWER_CASE_PATTERN: Lazy<Regex> =
38    Lazy::new(|| Regex::new("(?i)^(SHOW VARIABLES LIKE 'lower_case_table_names'(.*))").unwrap());
39static SHOW_VARIABLES_LIKE_PATTERN: Lazy<Regex> =
40    Lazy::new(|| Regex::new("(?i)^(SHOW VARIABLES( LIKE (.*))?)").unwrap());
41static SHOW_WARNINGS_PATTERN: Lazy<Regex> =
42    Lazy::new(|| Regex::new("(?i)^(/\\* ApplicationName=.*)?SHOW WARNINGS").unwrap());
43
44// SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP());
45static SELECT_TIME_DIFF_FUNC_PATTERN: Lazy<Regex> =
46    Lazy::new(|| Regex::new("(?i)^(SELECT TIMEDIFF\\(NOW\\(\\), UTC_TIMESTAMP\\(\\)\\))").unwrap());
47
48// sqlalchemy < 1.4.30
49static SHOW_SQL_MODE_PATTERN: Lazy<Regex> =
50    Lazy::new(|| Regex::new("(?i)^(SHOW VARIABLES LIKE 'sql_mode'(.*))").unwrap());
51
52static OTHER_NOT_SUPPORTED_STMT: Lazy<RegexSet> = Lazy::new(|| {
53    RegexSet::new([
54        // Txn.
55        "(?i)^(ROLLBACK(.*))",
56        "(?i)^(COMMIT(.*))",
57        "(?i)^(START(.*))",
58
59        // Set.
60        "(?i)^(SET NAMES(.*))",
61        "(?i)^(SET character_set_results(.*))",
62        "(?i)^(SET net_write_timeout(.*))",
63        "(?i)^(SET FOREIGN_KEY_CHECKS(.*))",
64        "(?i)^(SET AUTOCOMMIT(.*))",
65        "(?i)^(SET SQL_LOG_BIN(.*))",
66        "(?i)^(SET SESSION TRANSACTION(.*))",
67        "(?i)^(SET TRANSACTION(.*))",
68        "(?i)^(SET sql_mode(.*))",
69        "(?i)^(SET SQL_SELECT_LIMIT(.*))",
70        "(?i)^(SET PROFILING(.*))",
71
72        // mysqlclient.
73        "(?i)^(SELECT \\$\\$)",
74
75        // mysqldump.
76        "(?i)^(SET SQL_QUOTE_SHOW_CREATE(.*))",
77        "(?i)^(LOCK TABLES(.*))",
78        "(?i)^(UNLOCK TABLES(.*))",
79        "(?i)^(SELECT LOGFILE_GROUP_NAME, FILE_NAME, TOTAL_EXTENTS, INITIAL_SIZE, ENGINE, EXTRA FROM INFORMATION_SCHEMA.FILES(.*))",
80
81        // mydumper.
82        "(?i)^(/\\*!80003 SET(.*) \\*/)$",
83        "(?i)^(SHOW MASTER STATUS)",
84        "(?i)^(SHOW ALL SLAVES STATUS)",
85        "(?i)^(LOCK BINLOG FOR BACKUP)",
86        "(?i)^(LOCK TABLES FOR BACKUP)",
87        "(?i)^(UNLOCK BINLOG(.*))",
88        "(?i)^(/\\*!40101 SET(.*) \\*/)$",
89
90        // DBeaver.
91        "(?i)^(/\\* ApplicationName=(.*)SHOW PLUGINS)",
92        "(?i)^(/\\* ApplicationName=(.*)SHOW ENGINES)",
93        "(?i)^(/\\* ApplicationName=(.*)SELECT @@(.*))",
94        "(?i)^(/\\* ApplicationName=(.*)SHOW @@(.*))",
95        "(?i)^(/\\* ApplicationName=(.*)SET net_write_timeout(.*))",
96        "(?i)^(/\\* ApplicationName=(.*)SET SQL_SELECT_LIMIT(.*))",
97        "(?i)^(/\\* ApplicationName=(.*)SHOW VARIABLES(.*))",
98
99        // pt-toolkit
100        "(?i)^(/\\*!40101 SET(.*) \\*/)$",
101
102        // mysqldump 5.7.16
103        "(?i)^(/\\*!40100 SET(.*) \\*/)$",
104        "(?i)^(/\\*!40103 SET(.*) \\*/)$",
105        "(?i)^(/\\*!40111 SET(.*) \\*/)$",
106        "(?i)^(/\\*!40101 SET(.*) \\*/)$",
107        "(?i)^(/\\*!40014 SET(.*) \\*/)$",
108        "(?i)^(/\\*!40000 SET(.*) \\*/)$",
109    ]).unwrap()
110});
111
112static VAR_VALUES: Lazy<HashMap<&str, &str>> = Lazy::new(|| {
113    HashMap::from([
114        ("tx_isolation", "REPEATABLE-READ"),
115        ("session.tx_isolation", "REPEATABLE-READ"),
116        ("transaction_isolation", "REPEATABLE-READ"),
117        ("session.transaction_isolation", "REPEATABLE-READ"),
118        ("session.transaction_read_only", "0"),
119        ("max_allowed_packet", "134217728"),
120        ("interactive_timeout", "31536000"),
121        ("wait_timeout", "31536000"),
122        ("net_write_timeout", "31536000"),
123        ("version_comment", common_version::product_name()),
124    ])
125});
126
127// Recordbatches for select function.
128// Format:
129// |function_name|
130// |value|
131fn select_function(name: &str, value: &str) -> RecordBatches {
132    let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
133        name,
134        ConcreteDataType::string_datatype(),
135        true,
136    )]));
137    let columns = vec![Arc::new(StringVector::from(vec![value])) as _];
138    RecordBatches::try_from_columns(schema, columns)
139        // unwrap is safe because the schema and data are definitely able to form a recordbatch, they are all string type
140        .unwrap()
141}
142
143// Recordbatches for show variable statement.
144// Format is:
145// | Variable_name | Value |
146// | xx            | yy    |
147fn show_variables(name: &str, value: &str) -> RecordBatches {
148    let schema = Arc::new(Schema::new(vec![
149        ColumnSchema::new("Variable_name", ConcreteDataType::string_datatype(), true),
150        ColumnSchema::new("Value", ConcreteDataType::string_datatype(), true),
151    ]));
152    let columns = vec![
153        Arc::new(StringVector::from(vec![name])) as _,
154        Arc::new(StringVector::from(vec![value])) as _,
155    ];
156    RecordBatches::try_from_columns(schema, columns)
157        // unwrap is safe because the schema and data are definitely able to form a recordbatch, they are all string type
158        .unwrap()
159}
160
161fn select_variable(query: &str, query_context: QueryContextRef) -> Option<Output> {
162    let mut fields = vec![];
163    let mut values = vec![];
164
165    // query like "SELECT @@aa, @@bb as cc, @dd..."
166    let query = query.to_lowercase();
167    let vars: Vec<&str> = query.split("@@").collect();
168    if vars.len() <= 1 {
169        return None;
170    }
171
172    // skip the first "select"
173    for var in vars.iter().skip(1) {
174        let var = var.trim_matches(|c| c == ' ' || c == ',' || c == ';');
175        let var_as: Vec<&str> = var
176            .split(" as ")
177            .map(|x| {
178                x.trim_matches(|c| c == ' ')
179                    .split_whitespace()
180                    .next()
181                    .unwrap_or("")
182            })
183            .collect();
184
185        // get value of variables from known sources or fallback to defaults
186        let value = match var_as[0] {
187            "session.time_zone" | "time_zone" => query_context.timezone().to_string(),
188            "system_time_zone" => system_timezone_name(),
189            "max_execution_time" | "session.max_execution_time" => {
190                query_context.query_timeout_as_millis().to_string()
191            }
192            _ => VAR_VALUES
193                .get(var_as[0])
194                .map(|v| v.to_string())
195                .unwrap_or_else(|| "0".to_owned()),
196        };
197
198        values.push(Arc::new(StringVector::from(vec![value])) as _);
199        match var_as.len() {
200            1 => {
201                // @@aa
202                // field is '@@aa'
203                fields.push(ColumnSchema::new(
204                    format!("@@{}", var_as[0]),
205                    ConcreteDataType::string_datatype(),
206                    true,
207                ));
208            }
209            2 => {
210                // @@bb as cc:
211                // var is 'bb'.
212                // field is 'cc'.
213                fields.push(ColumnSchema::new(
214                    var_as[1],
215                    ConcreteDataType::string_datatype(),
216                    true,
217                ));
218            }
219            _ => return None,
220        }
221    }
222
223    let schema = Arc::new(Schema::new(fields));
224    // unwrap is safe because the schema and data are definitely able to form a recordbatch, they are all string type
225    let batches = RecordBatches::try_from_columns(schema, values).unwrap();
226    Some(Output::new_with_record_batches(batches))
227}
228
229fn check_select_variable(query: &str, query_context: QueryContextRef) -> Option<Output> {
230    if [&SELECT_VAR_PATTERN, &MYSQL_CONN_JAVA_PATTERN]
231        .iter()
232        .any(|r| r.is_match(query))
233    {
234        select_variable(query, query_context)
235    } else {
236        None
237    }
238}
239
240fn check_show_variables(query: &str) -> Option<Output> {
241    let recordbatches = if SHOW_SQL_MODE_PATTERN.is_match(query) {
242        Some(show_variables(
243            "sql_mode",
244            "ONLY_FULL_GROUP_BY STRICT_TRANS_TABLES NO_ZERO_IN_DATE NO_ZERO_DATE ERROR_FOR_DIVISION_BY_ZERO NO_ENGINE_SUBSTITUTION",
245        ))
246    } else if SHOW_LOWER_CASE_PATTERN.is_match(query) {
247        Some(show_variables("lower_case_table_names", "0"))
248    } else if SHOW_VARIABLES_LIKE_PATTERN.is_match(query) {
249        Some(show_variables("", ""))
250    } else {
251        None
252    };
253    recordbatches.map(Output::new_with_record_batches)
254}
255
256/// Build SHOW WARNINGS result from session's warnings
257fn show_warnings(session: &SessionRef) -> RecordBatches {
258    let schema = Arc::new(Schema::new(vec![
259        ColumnSchema::new("Level", ConcreteDataType::string_datatype(), false),
260        ColumnSchema::new("Code", ConcreteDataType::uint16_datatype(), false),
261        ColumnSchema::new("Message", ConcreteDataType::string_datatype(), false),
262    ]));
263
264    let warnings = session.warnings();
265    let count = warnings.len();
266
267    let columns = if count > 0 {
268        vec![
269            Arc::new(StringVector::from(vec!["Warning"; count])) as _,
270            Arc::new(datatypes::vectors::UInt16Vector::from(vec![
271                Some(1000u16);
272                count
273            ])) as _,
274            Arc::new(StringVector::from(warnings)) as _,
275        ]
276    } else {
277        vec![
278            Arc::new(StringVector::from(Vec::<String>::new())) as _,
279            Arc::new(datatypes::vectors::UInt16Vector::from(
280                Vec::<Option<u16>>::new(),
281            )) as _,
282            Arc::new(StringVector::from(Vec::<String>::new())) as _,
283        ]
284    };
285
286    RecordBatches::try_from_columns(schema, columns).unwrap()
287}
288
289fn check_show_warnings(query: &str, session: &SessionRef) -> Option<Output> {
290    if SHOW_WARNINGS_PATTERN.is_match(query) {
291        Some(Output::new_with_record_batches(show_warnings(session)))
292    } else {
293        None
294    }
295}
296
297// Check for SET or others query, this is the final check of the federated query.
298fn check_others(query: &str, _query_ctx: QueryContextRef) -> Option<Output> {
299    if OTHER_NOT_SUPPORTED_STMT.is_match(query.as_bytes()) {
300        return Some(Output::new_with_record_batches(RecordBatches::empty()));
301    }
302
303    let recordbatches = if SELECT_TIME_DIFF_FUNC_PATTERN.is_match(query) {
304        Some(select_function(
305            "TIMEDIFF(NOW(), UTC_TIMESTAMP())",
306            "00:00:00",
307        ))
308    } else {
309        None
310    };
311    recordbatches.map(Output::new_with_record_batches)
312}
313
314// Check whether the query is a federated or driver setup command,
315// and return some faked results if there are any.
316pub(crate) fn check(
317    query: &str,
318    query_ctx: QueryContextRef,
319    session: SessionRef,
320) -> Option<Output> {
321    // INSERT don't need MySQL federated check. We assume the query doesn't contain
322    // federated or driver setup command if it starts with a 'INSERT' statement.
323    let the_6th_index = query.char_indices().nth(6).map(|(i, _)| i);
324    if let Some(index) = the_6th_index
325        && query[..index].eq_ignore_ascii_case("INSERT")
326    {
327        return None;
328    }
329
330    // First to check the query is like "select @@variables".
331    check_select_variable(query, query_ctx.clone())
332        .or_else(|| check_show_variables(query))
333        .or_else(|| check_show_warnings(query, &session))
334        // Last check
335        .or_else(|| check_others(query, query_ctx))
336}
337
#[cfg(test)]
mod test {

    use common_query::OutputData;
    use common_time::timezone::set_default_timezone;
    use session::Session;
    use session::context::{Channel, QueryContext};

    use super::*;

    // Multi-byte, non-ASCII input: emoji around Chinese text meaning roughly
    // "something a bit abnormal". `check` must not panic while looking for the
    // INSERT prefix, and must not treat it as a federated query.
    #[test]
    fn test_check_abnormal() {
        let session = Arc::new(Session::new(None, Channel::Mysql, Default::default(), 0));
        let query = "🫣一点不正常的东西🫣";
        let output = check(query, QueryContext::arc(), session.clone());

        assert!(output.is_none());
    }

    #[test]
    fn test_check() {
        let session = Arc::new(Session::new(None, Channel::Mysql, Default::default(), 0));
        // Ordinary queries are not intercepted.
        let query = "select 1";
        let result = check(query, QueryContext::arc(), session.clone());
        assert!(result.is_none());

        // `version` without `@@` is not a system-variable reference.
        let query = "select version";
        let output = check(query, QueryContext::arc(), session.clone());
        assert!(output.is_none());

        // Runs `check` with a fresh session and asserts that the pretty-printed
        // record batches equal `expected`. Panics if `check` returns `None` or
        // a non-record-batch output.
        fn test(query: &str, expected: &str) {
            let session = Arc::new(Session::new(None, Channel::Mysql, Default::default(), 0));
            let output = check(query, QueryContext::arc(), session.clone());
            match output.unwrap().data {
                OutputData::RecordBatches(r) => {
                    assert_eq!(&r.pretty_print().unwrap(), expected)
                }
                _ => unreachable!(),
            }
        }

        let query = "SELECT @@version_comment LIMIT 1";
        let expected = "\
+-------------------+
| @@version_comment |
+-------------------+
| GreptimeDB        |
+-------------------+";
        test(query, expected);

        // variables
        let query = "select @@tx_isolation, @@session.tx_isolation";
        let expected = "\
+-----------------+------------------------+
| @@tx_isolation  | @@session.tx_isolation |
+-----------------+------------------------+
| REPEATABLE-READ | REPEATABLE-READ        |
+-----------------+------------------------+";
        test(query, expected);

        // set system timezone
        // (the expected output below shows it reported by both
        // `@@system_time_zone` and `@@time_zone`)
        set_default_timezone(Some("Asia/Shanghai")).unwrap();
        // complex variables
        // (a MySQL Connector/J setup query; variables with no known value are
        // reported as "0")
        let query = "/* mysql-connector-java-8.0.17 (Revision: 16a712ddb3f826a1933ab42b0039f7fb9eebc6ec) */SELECT  @@session.auto_increment_increment AS auto_increment_increment, @@character_set_client AS character_set_client, @@character_set_connection AS character_set_connection, @@character_set_results AS character_set_results, @@character_set_server AS character_set_server, @@collation_server AS collation_server, @@collation_connection AS collation_connection, @@init_connect AS init_connect, @@interactive_timeout AS interactive_timeout, @@license AS license, @@lower_case_table_names AS lower_case_table_names, @@max_allowed_packet AS max_allowed_packet, @@net_write_timeout AS net_write_timeout, @@performance_schema AS performance_schema, @@sql_mode AS sql_mode, @@system_time_zone AS system_time_zone, @@time_zone AS time_zone, @@transaction_isolation AS transaction_isolation, @@wait_timeout AS wait_timeout;";
        let expected = "\
+--------------------------+----------------------+--------------------------+-----------------------+----------------------+------------------+----------------------+--------------+---------------------+---------+------------------------+--------------------+-------------------+--------------------+----------+------------------+---------------+-----------------------+--------------+
| auto_increment_increment | character_set_client | character_set_connection | character_set_results | character_set_server | collation_server | collation_connection | init_connect | interactive_timeout | license | lower_case_table_names | max_allowed_packet | net_write_timeout | performance_schema | sql_mode | system_time_zone | time_zone     | transaction_isolation | wait_timeout |
+--------------------------+----------------------+--------------------------+-----------------------+----------------------+------------------+----------------------+--------------+---------------------+---------+------------------------+--------------------+-------------------+--------------------+----------+------------------+---------------+-----------------------+--------------+
| 0                        | 0                    | 0                        | 0                     | 0                    | 0                | 0                    | 0            | 31536000            | 0       | 0                      | 134217728          | 31536000          | 0                  | 0        | Asia/Shanghai    | Asia/Shanghai | REPEATABLE-READ       | 31536000     |
+--------------------------+----------------------+--------------------------+-----------------------+----------------------+------------------+----------------------+--------------+---------------------+---------+------------------------+--------------------+-------------------+--------------------+----------+------------------+---------------+-----------------------+--------------+";
        test(query, expected);

        // A bare `SHOW VARIABLES` yields one row with an empty name and value.
        let query = "show variables";
        let expected = "\
+---------------+-------+
| Variable_name | Value |
+---------------+-------+
|               |       |
+---------------+-------+";
        test(query, expected);

        let query = "show variables like 'lower_case_table_names'";
        let expected = "\
+------------------------+-------+
| Variable_name          | Value |
+------------------------+-------+
| lower_case_table_names | 0     |
+------------------------+-------+";
        test(query, expected);

        // The faked TIMEDIFF result is always 00:00:00.
        let query = "SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP())";
        let expected = "\
+----------------------------------+
| TIMEDIFF(NOW(), UTC_TIMESTAMP()) |
+----------------------------------+
| 00:00:00                         |
+----------------------------------+";
        test(query, expected);
    }

    #[test]
    fn test_show_warnings() {
        // Test SHOW WARNINGS with no warnings
        let session = Arc::new(Session::new(None, Channel::Mysql, Default::default(), 0));
        let output = check("SHOW WARNINGS", QueryContext::arc(), session.clone());
        match output.unwrap().data {
            OutputData::RecordBatches(r) => {
                assert_eq!(r.iter().map(|b| b.num_rows()).sum::<usize>(), 0);
            }
            _ => unreachable!(),
        }

        // Test SHOW WARNINGS with a single warning
        session.add_warning("Test warning message".to_string());
        let output = check("SHOW WARNINGS", QueryContext::arc(), session.clone());
        match output.unwrap().data {
            OutputData::RecordBatches(r) => {
                let expected = "\
+---------+------+----------------------+
| Level   | Code | Message              |
+---------+------+----------------------+
| Warning | 1000 | Test warning message |
+---------+------+----------------------+";
                assert_eq!(&r.pretty_print().unwrap(), expected);
            }
            _ => unreachable!(),
        }

        // Test SHOW WARNINGS with multiple warnings
        session.clear_warnings();
        session.add_warning("First warning".to_string());
        session.add_warning("Second warning".to_string());
        let output = check("SHOW WARNINGS", QueryContext::arc(), session.clone());
        match output.unwrap().data {
            OutputData::RecordBatches(r) => {
                let expected = "\
+---------+------+----------------+
| Level   | Code | Message        |
+---------+------+----------------+
| Warning | 1000 | First warning  |
| Warning | 1000 | Second warning |
+---------+------+----------------+";
                assert_eq!(&r.pretty_print().unwrap(), expected);
            }
            _ => unreachable!(),
        }

        // Test case insensitivity
        let output = check("show warnings", QueryContext::arc(), session.clone());
        assert!(output.is_some());

        // Test with DBeaver-style comment prefix
        let output = check(
            "/* ApplicationName=DBeaver */SHOW WARNINGS",
            QueryContext::arc(),
            session.clone(),
        );
        assert!(output.is_some());
    }
}