1use std::future::Future;
16use std::net::SocketAddr;
17use std::sync::Arc;
18
19use async_trait::async_trait;
20use auth::UserProviderRef;
21use catalog::process_manager::ProcessManagerRef;
22use common_runtime::runtime::RuntimeTrait;
23use common_runtime::Runtime;
24use common_telemetry::{debug, warn};
25use futures::StreamExt;
26use opensrv_mysql::{
27 plain_run_with_options, secure_run_with_options, AsyncMysqlIntermediary, IntermediaryOptions,
28};
29use snafu::ensure;
30use tokio;
31use tokio::io::BufWriter;
32use tokio::net::TcpStream;
33use tokio_rustls::rustls::ServerConfig;
34
35use crate::error::{Error, Result, TlsRequiredSnafu};
36use crate::mysql::handler::MysqlInstanceShim;
37use crate::query_handler::sql::ServerSqlQueryHandlerRef;
38use crate::server::{AbortableStream, BaseTcpServer, Server};
39use crate::tls::ReloadableTlsServerConfig;
40
/// Capacity (100 KiB) of the `BufWriter` wrapped around each connection's write half.
const DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE: usize = 100 * 1024;
43
44pub struct MysqlSpawnRef {
47 query_handler: ServerSqlQueryHandlerRef,
48 user_provider: Option<UserProviderRef>,
49}
50
51impl MysqlSpawnRef {
52 pub fn new(
53 query_handler: ServerSqlQueryHandlerRef,
54 user_provider: Option<UserProviderRef>,
55 ) -> MysqlSpawnRef {
56 MysqlSpawnRef {
57 query_handler,
58 user_provider,
59 }
60 }
61
62 fn query_handler(&self) -> ServerSqlQueryHandlerRef {
63 self.query_handler.clone()
64 }
65 fn user_provider(&self) -> Option<UserProviderRef> {
66 self.user_provider.clone()
67 }
68}
69
70pub struct MysqlSpawnConfig {
73 force_tls: bool,
75 tls: Arc<ReloadableTlsServerConfig>,
76 keep_alive_secs: u64,
78 reject_no_database: bool,
80}
81
82impl MysqlSpawnConfig {
83 pub fn new(
84 force_tls: bool,
85 tls: Arc<ReloadableTlsServerConfig>,
86 keep_alive_secs: u64,
87 reject_no_database: bool,
88 ) -> MysqlSpawnConfig {
89 MysqlSpawnConfig {
90 force_tls,
91 tls,
92 keep_alive_secs,
93 reject_no_database,
94 }
95 }
96
97 fn tls(&self) -> Option<Arc<ServerConfig>> {
98 self.tls.get_server_config()
99 }
100}
101
102impl From<&MysqlSpawnConfig> for IntermediaryOptions {
103 fn from(value: &MysqlSpawnConfig) -> Self {
104 IntermediaryOptions {
105 reject_connection_on_dbname_absence: value.reject_no_database,
106 ..Default::default()
107 }
108 }
109}
110
/// MySQL protocol server built on top of [`BaseTcpServer`].
pub struct MysqlServer {
    base_server: BaseTcpServer,
    /// Query handler and user provider shared with every connection.
    spawn_ref: Arc<MysqlSpawnRef>,
    /// Per-connection settings (TLS, keep-alive, database requirement).
    spawn_config: Arc<MysqlSpawnConfig>,
    /// Address returned by `bind` in `start`; `None` until the server has started.
    bind_addr: Option<SocketAddr>,
    /// Optional source of per-connection process ids.
    process_manager: Option<ProcessManagerRef>,
}
118
119impl MysqlServer {
120 pub fn create_server(
121 io_runtime: Runtime,
122 spawn_ref: Arc<MysqlSpawnRef>,
123 spawn_config: Arc<MysqlSpawnConfig>,
124 process_manager: Option<ProcessManagerRef>,
125 ) -> Box<dyn Server> {
126 Box::new(MysqlServer {
127 base_server: BaseTcpServer::create_server("MySQL", io_runtime),
128 spawn_ref,
129 spawn_config,
130 bind_addr: None,
131 process_manager,
132 })
133 }
134
135 fn accept(
136 &self,
137 io_runtime: Runtime,
138 stream: AbortableStream,
139 process_manager: Option<ProcessManagerRef>,
140 ) -> impl Future<Output = ()> {
141 let spawn_ref = self.spawn_ref.clone();
142 let spawn_config = self.spawn_config.clone();
143
144 stream.for_each(move |tcp_stream| {
145 let spawn_ref = spawn_ref.clone();
146 let spawn_config = spawn_config.clone();
147 let io_runtime = io_runtime.clone();
148 let process_id = process_manager.as_ref().map(|p| p.next_id()).unwrap_or(8);
149 async move {
150 match tcp_stream {
151 Err(e) => warn!(e; "Broken pipe"), Ok(io_stream) => {
153 if let Err(e) = io_stream.set_nodelay(true) {
154 warn!(e; "Failed to set TCP nodelay");
155 }
156 io_runtime.spawn(async move {
157 if let Err(error) =
158 Self::handle(io_stream, spawn_ref, spawn_config, process_id).await
159 {
160 warn!(error; "Unexpected error when handling TcpStream");
161 };
162 });
163 }
164 };
165 }
166 })
167 }
168
169 async fn handle(
170 stream: TcpStream,
171 spawn_ref: Arc<MysqlSpawnRef>,
172 spawn_config: Arc<MysqlSpawnConfig>,
173 process_id: u32,
174 ) -> Result<()> {
175 debug!("MySQL connection coming from: {}", stream.peer_addr()?);
176 crate::metrics::METRIC_MYSQL_CONNECTIONS.inc();
177 if let Err(e) = Self::do_handle(stream, spawn_ref, spawn_config, process_id).await {
178 if let Error::InternalIo { error } = &e
179 && error.kind() == std::io::ErrorKind::ConnectionAborted
180 {
181 } else {
183 warn!(e; "Internal error occurred during query exec, server actively close the channel to let client try next time");
186 }
187 }
188 crate::metrics::METRIC_MYSQL_CONNECTIONS.dec();
189
190 Ok(())
191 }
192
193 async fn do_handle(
194 stream: TcpStream,
195 spawn_ref: Arc<MysqlSpawnRef>,
196 spawn_config: Arc<MysqlSpawnConfig>,
197 process_id: u32,
198 ) -> Result<()> {
199 let mut shim = MysqlInstanceShim::create(
200 spawn_ref.query_handler(),
201 spawn_ref.user_provider(),
202 stream.peer_addr()?,
203 process_id,
204 );
205 let (mut r, w) = stream.into_split();
206 let mut w = BufWriter::with_capacity(DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE, w);
207
208 let ops = spawn_config.as_ref().into();
209
210 let (client_tls, init_params) =
211 AsyncMysqlIntermediary::init_before_ssl(&mut shim, &mut r, &mut w, &spawn_config.tls())
212 .await?;
213
214 ensure!(
215 !spawn_config.force_tls || client_tls,
216 TlsRequiredSnafu {
217 server: "mysql".to_owned()
218 }
219 );
220
221 match spawn_config.tls() {
222 Some(tls_conf) if client_tls => {
223 secure_run_with_options(shim, w, ops, tls_conf, init_params).await
224 }
225 _ => plain_run_with_options(shim, w, ops, init_params).await,
226 }
227 }
228}
229
/// Name returned by `MysqlServer`'s [`Server::name`] implementation.
pub const MYSQL_SERVER: &str = "MYSQL_SERVER";
231
232#[async_trait]
233impl Server for MysqlServer {
234 async fn shutdown(&self) -> Result<()> {
235 self.base_server.shutdown().await
236 }
237
238 async fn start(&mut self, listening: SocketAddr) -> Result<()> {
239 let (stream, addr) = self
240 .base_server
241 .bind(listening, self.spawn_config.keep_alive_secs)
242 .await?;
243 let io_runtime = self.base_server.io_runtime();
244
245 let join_handle = common_runtime::spawn_global(self.accept(
246 io_runtime,
247 stream,
248 self.process_manager.clone(),
249 ));
250 self.base_server.start_with(join_handle).await?;
251
252 self.bind_addr = Some(addr);
253 Ok(())
254 }
255
256 fn name(&self) -> &str {
257 MYSQL_SERVER
258 }
259
260 fn bind_addr(&self) -> Option<SocketAddr> {
261 self.bind_addr
262 }
263}