sqlness_runner/
client.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::net::SocketAddr;
16use std::time::Duration;
17
18use client::error::ServerSnafu;
19use client::{
20    Client, DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, Database, OutputData, RecordBatches,
21};
22use common_error::ext::ErrorExt;
23use common_query::Output;
24use mysql::prelude::Queryable;
25use mysql::{Conn as MySqlClient, Row as MySqlRow};
26use tokio_postgres::{Client as PgClient, SimpleQueryMessage as PgRow};
27
28use crate::util::retry_with_backoff;
29
/// A client that can connect to GreptimeDB using multiple protocols.
///
/// Bundles a gRPC [`Database`] handle, a Postgres client and a MySQL connection so the
/// same statement can be issued over any of the three protocols.
pub struct MultiProtocolClient {
    /// gRPC handle; its schema and timezone are updated client-side by `grpc_query`.
    grpc_client: Database,
    /// Postgres-protocol client (`tokio_postgres`).
    pg_client: PgClient,
    /// MySQL-protocol connection (synchronous `mysql::Conn`).
    mysql_client: MySqlClient,
}
36
37/// The result of a MySQL query.
38pub enum MysqlSqlResult {
39    AffectedRows(u64),
40    Rows(Vec<MySqlRow>),
41}
42
43impl MultiProtocolClient {
44    /// Connect to the GreptimeDB server using multiple protocols.
45    ///
46    /// # Arguments
47    ///
48    /// * `grpc_server_addr` - The address of the GreptimeDB server.
49    /// * `pg_server_addr` - The address of the Postgres server.
50    /// * `mysql_server_addr` - The address of the MySQL server.
51    ///
52    /// # Returns
53    ///
54    /// A `MultiProtocolClient` instance.
55    ///
56    /// # Panics
57    ///
58    /// Panics if the server addresses are invalid or the connection fails.
59    pub async fn connect(
60        grpc_server_addr: &str,
61        pg_server_addr: &str,
62        mysql_server_addr: &str,
63    ) -> MultiProtocolClient {
64        let grpc_client = Database::new(
65            DEFAULT_CATALOG_NAME,
66            DEFAULT_SCHEMA_NAME,
67            Client::with_urls(vec![grpc_server_addr]),
68        );
69        let pg_client = create_postgres_client(pg_server_addr).await;
70        let mysql_client = create_mysql_client(mysql_server_addr).await;
71        MultiProtocolClient {
72            grpc_client,
73            pg_client,
74            mysql_client,
75        }
76    }
77
78    /// Reconnect the MySQL client.
79    pub async fn reconnect_mysql_client(&mut self, mysql_server_addr: &str) {
80        self.mysql_client = create_mysql_client(mysql_server_addr).await;
81    }
82
83    /// Reconnect the Postgres client.
84    pub async fn reconnect_pg_client(&mut self, pg_server_addr: &str) {
85        self.pg_client = create_postgres_client(pg_server_addr).await;
86    }
87
88    /// Execute a query on the Postgres server.
89    pub async fn postgres_query(&mut self, query: &str) -> Result<Vec<PgRow>, String> {
90        match self.pg_client.simple_query(query).await {
91            Ok(rows) => Ok(rows),
92            Err(e) => Err(format!("Failed to execute query, encountered: {:?}", e)),
93        }
94    }
95
96    /// Execute a query on the MySQL server.
97    pub async fn mysql_query(&mut self, query: &str) -> Result<MysqlSqlResult, String> {
98        let result = self.mysql_client.query_iter(query);
99        match result {
100            Ok(result) => {
101                let mut rows = vec![];
102                let affected_rows = result.affected_rows();
103                for row in result {
104                    match row {
105                        Ok(r) => rows.push(r),
106                        Err(e) => {
107                            return Err(format!("Failed to parse query result, err: {:?}", e));
108                        }
109                    }
110                }
111
112                if rows.is_empty() {
113                    Ok(MysqlSqlResult::AffectedRows(affected_rows))
114                } else {
115                    Ok(MysqlSqlResult::Rows(rows))
116                }
117            }
118            Err(e) => Err(format!("Failed to execute query, err: {:?}", e)),
119        }
120    }
121
122    /// Execute a query on the GreptimeDB server.
123    pub async fn grpc_query(&mut self, query: &str) -> Result<Output, client::Error> {
124        let query_str = query.trim().to_lowercase();
125        if query_str.starts_with("use ") {
126            // use [db]
127            let database = query
128                .split_ascii_whitespace()
129                .nth(1)
130                .expect("Illegal `USE` statement: expecting a database.")
131                .trim_end_matches(';');
132            self.grpc_client.set_schema(database);
133            Ok(Output::new_with_affected_rows(0))
134        } else if query_str.starts_with("set time_zone")
135            || query_str.starts_with("set session time_zone")
136            || query_str.starts_with("set local time_zone")
137        {
138            // set time_zone='xxx'
139            let timezone = query
140                .split('=')
141                .nth(1)
142                .expect("Illegal `SET TIMEZONE` statement: expecting a timezone expr.")
143                .trim()
144                .strip_prefix('\'')
145                .unwrap()
146                .strip_suffix("';")
147                .unwrap();
148
149            self.grpc_client.set_timezone(timezone);
150            Ok(Output::new_with_affected_rows(0))
151        } else {
152            let mut result = self.grpc_client.sql(&query).await;
153            if let Ok(Output {
154                data: OutputData::Stream(stream),
155                ..
156            }) = result
157            {
158                match RecordBatches::try_collect(stream).await {
159                    Ok(recordbatches) => {
160                        result = Ok(Output::new_with_record_batches(recordbatches));
161                    }
162                    Err(e) => {
163                        let status_code = e.status_code();
164                        let msg = e.output_msg();
165                        result = ServerSnafu {
166                            code: status_code,
167                            msg,
168                        }
169                        .fail();
170                    }
171                }
172            }
173
174            result
175        }
176    }
177}
178
179/// Create a Postgres client with retry.
180///
181/// # Panics
182///
183/// Panics if the Postgres server address is invalid or the connection fails.
184async fn create_postgres_client(pg_server_addr: &str) -> PgClient {
185    let sockaddr: SocketAddr = pg_server_addr.parse().unwrap_or_else(|_| {
186        panic!("Failed to parse the Postgres server address {pg_server_addr}. Please check if the address is in the format of `ip:port`.")
187    });
188    let mut config = tokio_postgres::config::Config::new();
189    config.host(sockaddr.ip().to_string());
190    config.port(sockaddr.port());
191    config.dbname(DEFAULT_SCHEMA_NAME);
192
193    retry_with_backoff(
194        || async {
195            config
196                .connect(tokio_postgres::NoTls)
197                .await
198                .map(|(client, conn)| {
199                    tokio::spawn(conn);
200                    client
201                })
202        },
203        3,
204        Duration::from_millis(500),
205    )
206    .await
207    .unwrap_or_else(|_| {
208        panic!("Failed to connect to Postgres server. Please check if the server is running.")
209    })
210}
211
212/// Create a MySQL client with retry.
213///
214/// # Panics
215///
216/// Panics if the MySQL server address is invalid or the connection fails.
217async fn create_mysql_client(mysql_server_addr: &str) -> MySqlClient {
218    let sockaddr: SocketAddr = mysql_server_addr.parse().unwrap_or_else(|_| {
219        panic!("Failed to parse the MySQL server address {mysql_server_addr}. Please check if the address is in the format of `ip:port`.")
220    });
221    let ops = mysql::OptsBuilder::new()
222        .ip_or_hostname(Some(sockaddr.ip().to_string()))
223        .tcp_port(sockaddr.port())
224        .db_name(Some(DEFAULT_SCHEMA_NAME));
225
226    retry_with_backoff(
227        || async { mysql::Conn::new(ops.clone()) },
228        3,
229        Duration::from_millis(500),
230    )
231    .await
232    .unwrap_or_else(|_| {
233        panic!("Failed to connect to MySQL server. Please check if the server is running.")
234    })
235}