servers/postgres/
fixtures.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::borrow::Cow;
16use std::collections::HashMap;
17use std::sync::Arc;
18
19use futures::stream;
20use once_cell::sync::Lazy;
21use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag};
22use pgwire::api::Type;
23use pgwire::error::PgWireResult;
24use pgwire::messages::data::DataRow;
25use regex::Regex;
26use session::context::QueryContextRef;
27
28fn build_string_data_rows(
29    schema: Arc<Vec<FieldInfo>>,
30    rows: Vec<Vec<String>>,
31) -> Vec<PgWireResult<DataRow>> {
32    rows.iter()
33        .map(|row| {
34            let mut encoder = DataRowEncoder::new(schema.clone());
35            for value in row {
36                encoder.encode_field(&Some(value))?;
37            }
38            encoder.finish()
39        })
40        .collect()
41}
42
43static VAR_VALUES: Lazy<HashMap<&str, &str>> = Lazy::new(|| {
44    HashMap::from([
45        ("default_transaction_isolation", "read committed"),
46        ("transaction isolation level", "read committed"),
47        ("standard_conforming_strings", "on"),
48        ("client_encoding", "UTF8"),
49    ])
50});
51
52static SHOW_PATTERN: Lazy<Regex> = Lazy::new(|| Regex::new("(?i)^SHOW (.*?);?$").unwrap());
53static SET_TRANSACTION_PATTERN: Lazy<Regex> =
54    Lazy::new(|| Regex::new("(?i)^SET TRANSACTION (.*?);?$").unwrap());
55static START_TRANSACTION_PATTERN: Lazy<Regex> =
56    Lazy::new(|| Regex::new("(?i)^(START TRANSACTION.*|BEGIN);?").unwrap());
57static COMMIT_TRANSACTION_PATTERN: Lazy<Regex> =
58    Lazy::new(|| Regex::new("(?i)^(COMMIT TRANSACTION|COMMIT);?").unwrap());
59static ABORT_TRANSACTION_PATTERN: Lazy<Regex> =
60    Lazy::new(|| Regex::new("(?i)^(ABORT TRANSACTION|ROLLBACK);?").unwrap());
61
62/// Test if given query statement matches the patterns
63pub(crate) fn matches(query: &str) -> bool {
64    START_TRANSACTION_PATTERN.is_match(query)
65        || COMMIT_TRANSACTION_PATTERN.is_match(query)
66        || ABORT_TRANSACTION_PATTERN.is_match(query)
67        || SHOW_PATTERN.captures(query).is_some()
68        || SET_TRANSACTION_PATTERN.is_match(query)
69}
70
71fn set_transaction_warning(query_ctx: QueryContextRef) {
72    query_ctx.set_warning("Please note transaction is not supported in GreptimeDB.".to_string());
73}
74
75/// Process unsupported SQL and return fixed result as a compatibility solution
76pub(crate) fn process<'a>(query: &str, query_ctx: QueryContextRef) -> Option<Vec<Response<'a>>> {
77    // Transaction directives:
78    if START_TRANSACTION_PATTERN.is_match(query) {
79        set_transaction_warning(query_ctx);
80        if query.to_lowercase().starts_with("begin") {
81            Some(vec![Response::TransactionStart(Tag::new("BEGIN"))])
82        } else {
83            Some(vec![Response::TransactionStart(Tag::new(
84                "START TRANSACTION",
85            ))])
86        }
87    } else if ABORT_TRANSACTION_PATTERN.is_match(query) {
88        Some(vec![Response::TransactionEnd(Tag::new("ROLLBACK"))])
89    } else if COMMIT_TRANSACTION_PATTERN.is_match(query) {
90        Some(vec![Response::TransactionEnd(Tag::new("COMMIT"))])
91    } else if let Some(show_var) = SHOW_PATTERN.captures(query) {
92        let show_var = show_var[1].to_lowercase();
93        if let Some(value) = VAR_VALUES.get(&show_var.as_ref()) {
94            let f1 = FieldInfo::new(
95                show_var.clone(),
96                None,
97                None,
98                Type::VARCHAR,
99                FieldFormat::Text,
100            );
101            let schema = Arc::new(vec![f1]);
102            let data = stream::iter(build_string_data_rows(
103                schema.clone(),
104                vec![vec![value.to_string()]],
105            ));
106
107            Some(vec![Response::Query(QueryResponse::new(schema, data))])
108        } else {
109            None
110        }
111    } else if SET_TRANSACTION_PATTERN.is_match(query) {
112        Some(vec![Response::Execution(Tag::new("SET"))])
113    } else {
114        None
115    }
116}
117
/// Rewrites known client-generated SQL that DataFusion does not support into an
/// equivalent form it can run (DBeaver tricky replacement).
///
/// Currently this covers only DBeaver's `pg_database` query: its `db.oid,db.*`
/// projection is rewritten to alias the explicit `oid` column as `_oid`.
/// NOTE(review): presumably this avoids a duplicated `oid` column name — confirm.
///
/// Every occurrence of each pattern is replaced. When no pattern occurs, the input is
/// returned borrowed and nothing is allocated.
pub(crate) fn rewrite_sql(query: &str) -> Cow<'_, str> {
    // (pattern, replacement) pairs, applied in order.
    // TODO: add more here
    const REWRITES: &[(&str, &str)] = &[(
        "SELECT db.oid,db.* FROM pg_catalog.pg_database db",
        "SELECT db.oid as _oid,db.* FROM pg_catalog.pg_database db",
    )];

    let mut rewritten = Cow::Borrowed(query);
    for &(pattern, replacement) in REWRITES {
        // Allocate only when a pattern is actually present.
        if rewritten.contains(pattern) {
            rewritten = Cow::Owned(rewritten.replace(pattern, replacement));
        }
    }
    rewritten
}
128
#[cfg(test)]
mod test {
    use session::context::{QueryContext, QueryContextRef};

    use super::*;

    /// Runs `q` through `process` and returns the first response, panicking when the
    /// fixtures do not handle `q`.
    fn first_response<'a>(q: &str, query_context: QueryContextRef) -> Response<'a> {
        let mut responses =
            process(q, query_context).unwrap_or_else(|| panic!("fail to match {}", q));
        responses.remove(0)
    }

    /// Asserts that `q` is answered with a command tag equal to `t`.
    fn assert_tag(q: &str, t: &str, query_context: QueryContextRef) {
        match first_response(q, query_context) {
            Response::Execution(tag)
            | Response::TransactionStart(tag)
            | Response::TransactionEnd(tag) => assert_eq!(Tag::new(t), tag),
            _ => panic!("Invalid response"),
        }
    }

    /// Returns the row-returning response produced for `q`.
    fn get_data<'a>(q: &str, query_context: QueryContextRef) -> QueryResponse<'a> {
        match first_response(q, query_context) {
            Response::Query(resp) => resp,
            _ => panic!("Invalid response"),
        }
    }

    #[test]
    fn test_process() {
        let query_context = QueryContext::arc();

        // (query, expected command tag)
        let tagged = [
            ("BEGIN", "BEGIN"),
            ("BEGIN;", "BEGIN"),
            ("begin;", "BEGIN"),
            ("ROLLBACK", "ROLLBACK"),
            ("ROLLBACK;", "ROLLBACK"),
            ("rollback;", "ROLLBACK"),
            ("COMMIT", "COMMIT"),
            ("COMMIT;", "COMMIT"),
            ("commit;", "COMMIT"),
            ("SET TRANSACTION ISOLATION LEVEL READ COMMITTED", "SET"),
            ("SET TRANSACTION ISOLATION LEVEL READ COMMITTED;", "SET"),
            ("SET transaction isolation level READ COMMITTED;", "SET"),
            (
                "START TRANSACTION isolation level READ COMMITTED;",
                "START TRANSACTION",
            ),
            (
                "start transaction isolation level READ COMMITTED;",
                "START TRANSACTION",
            ),
            ("abort transaction;", "ROLLBACK"),
            ("commit transaction;", "COMMIT"),
            ("COMMIT transaction;", "COMMIT"),
        ];
        for &(query, expected) in &tagged {
            assert_tag(query, expected, query_context.clone());
        }

        // Each supported `SHOW` yields a single-column result.
        let shows = [
            "SHOW transaction isolation level",
            "show client_encoding;",
            "show standard_conforming_strings;",
            "show default_transaction_isolation",
        ];
        for &query in &shows {
            let resp = get_data(query, query_context.clone());
            assert_eq!(1, resp.row_schema().len());
        }

        // Statements outside the fixtures are not handled.
        for &query in &["SELECT 1", "SHOW TABLES ", "SET TIME_ZONE=utc "] {
            assert!(process(query, query_context.clone()).is_none());
        }
    }

    #[test]
    fn test_rewrite() {
        let rewritten = rewrite_sql("SELECT db.oid,db.* FROM pg_catalog.pg_database db");
        assert_eq!(
            "SELECT db.oid as _oid,db.* FROM pg_catalog.pg_database db",
            rewritten
        );
    }
}
222}