1use std::borrow::Cow;
16use std::collections::HashMap;
17use std::sync::Arc;
18
19use futures::stream;
20use once_cell::sync::Lazy;
21use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag};
22use pgwire::api::Type;
23use pgwire::error::PgWireResult;
24use pgwire::messages::data::DataRow;
25use regex::Regex;
26use session::context::QueryContextRef;
27
28fn build_string_data_rows(
29 schema: Arc<Vec<FieldInfo>>,
30 rows: Vec<Vec<String>>,
31) -> Vec<PgWireResult<DataRow>> {
32 rows.iter()
33 .map(|row| {
34 let mut encoder = DataRowEncoder::new(schema.clone());
35 for value in row {
36 encoder.encode_field(&Some(value))?;
37 }
38 encoder.finish()
39 })
40 .collect()
41}
42
43static VAR_VALUES: Lazy<HashMap<&str, &str>> = Lazy::new(|| {
44 HashMap::from([
45 ("default_transaction_isolation", "read committed"),
46 ("transaction isolation level", "read committed"),
47 ("standard_conforming_strings", "on"),
48 ("client_encoding", "UTF8"),
49 ])
50});
51
52static SHOW_PATTERN: Lazy<Regex> = Lazy::new(|| Regex::new("(?i)^SHOW (.*?);?$").unwrap());
53static SET_TRANSACTION_PATTERN: Lazy<Regex> =
54 Lazy::new(|| Regex::new("(?i)^SET TRANSACTION (.*?);?$").unwrap());
55static START_TRANSACTION_PATTERN: Lazy<Regex> =
56 Lazy::new(|| Regex::new("(?i)^(START TRANSACTION.*|BEGIN);?").unwrap());
57static COMMIT_TRANSACTION_PATTERN: Lazy<Regex> =
58 Lazy::new(|| Regex::new("(?i)^(COMMIT TRANSACTION|COMMIT);?").unwrap());
59static ABORT_TRANSACTION_PATTERN: Lazy<Regex> =
60 Lazy::new(|| Regex::new("(?i)^(ABORT TRANSACTION|ROLLBACK);?").unwrap());
61
62pub(crate) fn matches(query: &str) -> bool {
64 START_TRANSACTION_PATTERN.is_match(query)
65 || COMMIT_TRANSACTION_PATTERN.is_match(query)
66 || ABORT_TRANSACTION_PATTERN.is_match(query)
67 || SHOW_PATTERN.captures(query).is_some()
68 || SET_TRANSACTION_PATTERN.is_match(query)
69}
70
71fn set_transaction_warning(query_ctx: QueryContextRef) {
72 query_ctx.set_warning("Please note transaction is not supported in GreptimeDB.".to_string());
73}
74
75pub(crate) fn process<'a>(query: &str, query_ctx: QueryContextRef) -> Option<Vec<Response<'a>>> {
77 if START_TRANSACTION_PATTERN.is_match(query) {
79 set_transaction_warning(query_ctx);
80 if query.to_lowercase().starts_with("begin") {
81 Some(vec![Response::TransactionStart(Tag::new("BEGIN"))])
82 } else {
83 Some(vec![Response::TransactionStart(Tag::new(
84 "START TRANSACTION",
85 ))])
86 }
87 } else if ABORT_TRANSACTION_PATTERN.is_match(query) {
88 Some(vec![Response::TransactionEnd(Tag::new("ROLLBACK"))])
89 } else if COMMIT_TRANSACTION_PATTERN.is_match(query) {
90 Some(vec![Response::TransactionEnd(Tag::new("COMMIT"))])
91 } else if let Some(show_var) = SHOW_PATTERN.captures(query) {
92 let show_var = show_var[1].to_lowercase();
93 if let Some(value) = VAR_VALUES.get(&show_var.as_ref()) {
94 let f1 = FieldInfo::new(
95 show_var.clone(),
96 None,
97 None,
98 Type::VARCHAR,
99 FieldFormat::Text,
100 );
101 let schema = Arc::new(vec![f1]);
102 let data = stream::iter(build_string_data_rows(
103 schema.clone(),
104 vec![vec![value.to_string()]],
105 ));
106
107 Some(vec![Response::Query(QueryResponse::new(schema, data))])
108 } else {
109 None
110 }
111 } else if SET_TRANSACTION_PATTERN.is_match(query) {
112 Some(vec![Response::Execution(Tag::new("SET"))])
113 } else {
114 None
115 }
116}
117
/// Rewrites SQL fragments that should not reach the query engine verbatim.
///
/// Currently this aliases `db.oid` to `_oid` in
/// `SELECT db.oid,db.* FROM pg_catalog.pg_database db`. Presumably this avoids a duplicate
/// `oid` column, since `db.*` also yields one; confirm. When the fragment is absent, the
/// input is returned borrowed and nothing is allocated.
pub(crate) fn rewrite_sql(query: &str) -> Cow<'_, str> {
    const PG_DATABASE_QUERY: &str = "SELECT db.oid,db.* FROM pg_catalog.pg_database db";
    const PG_DATABASE_QUERY_REWRITTEN: &str =
        "SELECT db.oid as _oid,db.* FROM pg_catalog.pg_database db";

    if query.contains(PG_DATABASE_QUERY) {
        Cow::Owned(query.replace(PG_DATABASE_QUERY, PG_DATABASE_QUERY_REWRITTEN))
    } else {
        Cow::Borrowed(query)
    }
}
128
#[cfg(test)]
mod test {
    use session::context::{QueryContext, QueryContextRef};

    use super::*;

    /// Runs `q` through [`process`] and returns its first response.
    /// Panics when `process` does not handle the query.
    fn first_response<'a>(q: &str, query_context: QueryContextRef) -> Response<'a> {
        let mut responses =
            process(q, query_context).unwrap_or_else(|| panic!("fail to match {}", q));
        responses.remove(0)
    }

    /// Asserts that `q` produces a command-tag response whose tag equals `t`.
    fn assert_tag(q: &str, t: &str, query_context: QueryContextRef) {
        match first_response(q, query_context) {
            Response::Execution(tag)
            | Response::TransactionStart(tag)
            | Response::TransactionEnd(tag) => assert_eq!(Tag::new(t), tag),
            _ => panic!("Invalid response"),
        }
    }

    /// Returns the query-result response for `q`, panicking on any other response kind.
    fn get_data<'a>(q: &str, query_context: QueryContextRef) -> QueryResponse<'a> {
        match first_response(q, query_context) {
            Response::Query(resp) => resp,
            _ => panic!("Invalid response"),
        }
    }

    #[test]
    fn test_process() {
        let query_context = QueryContext::arc();

        // (statement, expected command tag)
        let tag_cases = [
            ("BEGIN", "BEGIN"),
            ("BEGIN;", "BEGIN"),
            ("begin;", "BEGIN"),
            ("ROLLBACK", "ROLLBACK"),
            ("ROLLBACK;", "ROLLBACK"),
            ("rollback;", "ROLLBACK"),
            ("COMMIT", "COMMIT"),
            ("COMMIT;", "COMMIT"),
            ("commit;", "COMMIT"),
            ("SET TRANSACTION ISOLATION LEVEL READ COMMITTED", "SET"),
            ("SET TRANSACTION ISOLATION LEVEL READ COMMITTED;", "SET"),
            ("SET transaction isolation level READ COMMITTED;", "SET"),
            (
                "START TRANSACTION isolation level READ COMMITTED;",
                "START TRANSACTION",
            ),
            (
                "start transaction isolation level READ COMMITTED;",
                "START TRANSACTION",
            ),
            ("abort transaction;", "ROLLBACK"),
            ("commit transaction;", "COMMIT"),
            ("COMMIT transaction;", "COMMIT"),
        ];
        for (query, expected_tag) in tag_cases {
            assert_tag(query, expected_tag, query_context.clone());
        }

        // Each supported SHOW returns exactly one column.
        for query in [
            "SHOW transaction isolation level",
            "show client_encoding;",
            "show standard_conforming_strings;",
            "show default_transaction_isolation",
        ] {
            let resp = get_data(query, query_context.clone());
            assert_eq!(1, resp.row_schema().len());
        }

        // Statements that must fall through to the regular query pipeline.
        for query in ["SELECT 1", "SHOW TABLES ", "SET TIME_ZONE=utc "] {
            assert!(process(query, query_context.clone()).is_none());
        }
    }

    #[test]
    fn test_rewrite() {
        let rewritten = rewrite_sql("SELECT db.oid,db.* FROM pg_catalog.pg_database db");
        assert_eq!(
            "SELECT db.oid as _oid,db.* FROM pg_catalog.pg_database db",
            rewritten
        );
    }
}