1use std::net::SocketAddr;
16use std::time::Duration;
17
18use client::error::ServerSnafu;
19use client::{
20 Client, DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, Database, OutputData, RecordBatches,
21};
22use common_error::ext::ErrorExt;
23use common_query::Output;
24use mysql::prelude::Queryable;
25use mysql::{Conn as MySqlClient, Row as MySqlRow};
26use tokio_postgres::{Client as PgClient, SimpleQueryMessage as PgRow};
27
28use crate::util::retry_with_backoff;
29
30pub struct MultiProtocolClient {
32 grpc_client: Database,
33 pg_client: PgClient,
34 mysql_client: MySqlClient,
35}
36
37pub enum MysqlSqlResult {
39 AffectedRows(u64),
40 Rows(Vec<MySqlRow>),
41}
42
43impl MultiProtocolClient {
44 pub async fn connect(
60 grpc_server_addr: &str,
61 pg_server_addr: &str,
62 mysql_server_addr: &str,
63 ) -> MultiProtocolClient {
64 let grpc_client = Database::new(
65 DEFAULT_CATALOG_NAME,
66 DEFAULT_SCHEMA_NAME,
67 Client::with_urls(vec![grpc_server_addr]),
68 );
69 let pg_client = create_postgres_client(pg_server_addr).await;
70 let mysql_client = create_mysql_client(mysql_server_addr).await;
71 MultiProtocolClient {
72 grpc_client,
73 pg_client,
74 mysql_client,
75 }
76 }
77
78 pub async fn reconnect_mysql_client(&mut self, mysql_server_addr: &str) {
80 self.mysql_client = create_mysql_client(mysql_server_addr).await;
81 }
82
83 pub async fn reconnect_pg_client(&mut self, pg_server_addr: &str) {
85 self.pg_client = create_postgres_client(pg_server_addr).await;
86 }
87
88 pub async fn postgres_query(&mut self, query: &str) -> Result<Vec<PgRow>, String> {
90 match self.pg_client.simple_query(query).await {
91 Ok(rows) => Ok(rows),
92 Err(e) => Err(format!("Failed to execute query, encountered: {:?}", e)),
93 }
94 }
95
96 pub async fn mysql_query(&mut self, query: &str) -> Result<MysqlSqlResult, String> {
98 let result = self.mysql_client.query_iter(query);
99 match result {
100 Ok(result) => {
101 let mut rows = vec![];
102 let affected_rows = result.affected_rows();
103 for row in result {
104 match row {
105 Ok(r) => rows.push(r),
106 Err(e) => {
107 return Err(format!("Failed to parse query result, err: {:?}", e));
108 }
109 }
110 }
111
112 if rows.is_empty() {
113 Ok(MysqlSqlResult::AffectedRows(affected_rows))
114 } else {
115 Ok(MysqlSqlResult::Rows(rows))
116 }
117 }
118 Err(e) => Err(format!("Failed to execute query, err: {:?}", e)),
119 }
120 }
121
122 pub async fn grpc_query(&mut self, query: &str) -> Result<Output, client::Error> {
124 let query_str = query.trim().to_lowercase();
125 if query_str.starts_with("use ") {
126 let database = query
128 .split_ascii_whitespace()
129 .nth(1)
130 .expect("Illegal `USE` statement: expecting a database.")
131 .trim_end_matches(';');
132 self.grpc_client.set_schema(database);
133 Ok(Output::new_with_affected_rows(0))
134 } else if query_str.starts_with("set time_zone")
135 || query_str.starts_with("set session time_zone")
136 || query_str.starts_with("set local time_zone")
137 {
138 let timezone = query
140 .split('=')
141 .nth(1)
142 .expect("Illegal `SET TIMEZONE` statement: expecting a timezone expr.")
143 .trim()
144 .strip_prefix('\'')
145 .unwrap()
146 .strip_suffix("';")
147 .unwrap();
148
149 self.grpc_client.set_timezone(timezone);
150 Ok(Output::new_with_affected_rows(0))
151 } else {
152 let mut result = self.grpc_client.sql(&query).await;
153 if let Ok(Output {
154 data: OutputData::Stream(stream),
155 ..
156 }) = result
157 {
158 match RecordBatches::try_collect(stream).await {
159 Ok(recordbatches) => {
160 result = Ok(Output::new_with_record_batches(recordbatches));
161 }
162 Err(e) => {
163 let status_code = e.status_code();
164 let msg = e.output_msg();
165 result = ServerSnafu {
166 code: status_code,
167 msg,
168 }
169 .fail();
170 }
171 }
172 }
173
174 result
175 }
176 }
177}
178
179async fn create_postgres_client(pg_server_addr: &str) -> PgClient {
185 let sockaddr: SocketAddr = pg_server_addr.parse().unwrap_or_else(|_| {
186 panic!("Failed to parse the Postgres server address {pg_server_addr}. Please check if the address is in the format of `ip:port`.")
187 });
188 let mut config = tokio_postgres::config::Config::new();
189 config.host(sockaddr.ip().to_string());
190 config.port(sockaddr.port());
191 config.dbname(DEFAULT_SCHEMA_NAME);
192
193 retry_with_backoff(
194 || async {
195 config
196 .connect(tokio_postgres::NoTls)
197 .await
198 .map(|(client, conn)| {
199 tokio::spawn(conn);
200 client
201 })
202 },
203 3,
204 Duration::from_millis(500),
205 )
206 .await
207 .unwrap_or_else(|_| {
208 panic!("Failed to connect to Postgres server. Please check if the server is running.")
209 })
210}
211
212async fn create_mysql_client(mysql_server_addr: &str) -> MySqlClient {
218 let sockaddr: SocketAddr = mysql_server_addr.parse().unwrap_or_else(|_| {
219 panic!("Failed to parse the MySQL server address {mysql_server_addr}. Please check if the address is in the format of `ip:port`.")
220 });
221 let ops = mysql::OptsBuilder::new()
222 .ip_or_hostname(Some(sockaddr.ip().to_string()))
223 .tcp_port(sockaddr.port())
224 .db_name(Some(DEFAULT_SCHEMA_NAME));
225
226 retry_with_backoff(
227 || async { mysql::Conn::new(ops.clone()) },
228 3,
229 Duration::from_millis(500),
230 )
231 .await
232 .unwrap_or_else(|_| {
233 panic!("Failed to connect to MySQL server. Please check if the server is running.")
234 })
235}