1use std::future::Future;
16use std::net::SocketAddr;
17use std::sync::Arc;
18
19use async_trait::async_trait;
20use auth::UserProviderRef;
21use catalog::process_manager::ProcessManagerRef;
22use common_runtime::runtime::RuntimeTrait;
23use common_runtime::Runtime;
24use common_telemetry::{debug, warn};
25use futures::StreamExt;
26use opensrv_mysql::{
27 plain_run_with_options, secure_run_with_options, AsyncMysqlIntermediary, IntermediaryOptions,
28};
29use snafu::ensure;
30use tokio;
31use tokio::io::BufWriter;
32use tokio::net::TcpStream;
33use tokio_rustls::rustls::ServerConfig;
34
35use crate::error::{Error, Result, TlsRequiredSnafu};
36use crate::mysql::handler::MysqlInstanceShim;
37use crate::query_handler::sql::ServerSqlQueryHandlerRef;
38use crate::server::{AbortableStream, BaseTcpServer, Server};
39use crate::tls::ReloadableTlsServerConfig;
40
/// Capacity, in bytes (100 KiB), of the `BufWriter` that wraps each connection's write half.
const DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE: usize = 102_400;
43
44pub struct MysqlSpawnRef {
47 query_handler: ServerSqlQueryHandlerRef,
48 user_provider: Option<UserProviderRef>,
49}
50
51impl MysqlSpawnRef {
52 pub fn new(
53 query_handler: ServerSqlQueryHandlerRef,
54 user_provider: Option<UserProviderRef>,
55 ) -> MysqlSpawnRef {
56 MysqlSpawnRef {
57 query_handler,
58 user_provider,
59 }
60 }
61
62 fn query_handler(&self) -> ServerSqlQueryHandlerRef {
63 self.query_handler.clone()
64 }
65 fn user_provider(&self) -> Option<UserProviderRef> {
66 self.user_provider.clone()
67 }
68}
69
70pub struct MysqlSpawnConfig {
73 force_tls: bool,
75 tls: Arc<ReloadableTlsServerConfig>,
76 keep_alive_secs: u64,
78 reject_no_database: bool,
80 prepared_stmt_cache_size: usize,
82}
83
84impl MysqlSpawnConfig {
85 pub fn new(
86 force_tls: bool,
87 tls: Arc<ReloadableTlsServerConfig>,
88 keep_alive_secs: u64,
89 reject_no_database: bool,
90 prepared_stmt_cache_size: usize,
91 ) -> MysqlSpawnConfig {
92 MysqlSpawnConfig {
93 force_tls,
94 tls,
95 keep_alive_secs,
96 reject_no_database,
97 prepared_stmt_cache_size,
98 }
99 }
100
101 fn tls(&self) -> Option<Arc<ServerConfig>> {
102 self.tls.get_server_config()
103 }
104}
105
106impl From<&MysqlSpawnConfig> for IntermediaryOptions {
107 fn from(value: &MysqlSpawnConfig) -> Self {
108 IntermediaryOptions {
109 reject_connection_on_dbname_absence: value.reject_no_database,
110 ..Default::default()
111 }
112 }
113}
114
115pub struct MysqlServer {
116 base_server: BaseTcpServer,
117 spawn_ref: Arc<MysqlSpawnRef>,
118 spawn_config: Arc<MysqlSpawnConfig>,
119 bind_addr: Option<SocketAddr>,
120 process_manager: Option<ProcessManagerRef>,
121}
122
123impl MysqlServer {
124 pub fn create_server(
125 io_runtime: Runtime,
126 spawn_ref: Arc<MysqlSpawnRef>,
127 spawn_config: Arc<MysqlSpawnConfig>,
128 process_manager: Option<ProcessManagerRef>,
129 ) -> Box<dyn Server> {
130 Box::new(MysqlServer {
131 base_server: BaseTcpServer::create_server("MySQL", io_runtime),
132 spawn_ref,
133 spawn_config,
134 bind_addr: None,
135 process_manager,
136 })
137 }
138
139 fn accept(
140 &self,
141 io_runtime: Runtime,
142 stream: AbortableStream,
143 process_manager: Option<ProcessManagerRef>,
144 ) -> impl Future<Output = ()> {
145 let spawn_ref = self.spawn_ref.clone();
146 let spawn_config = self.spawn_config.clone();
147
148 stream.for_each(move |tcp_stream| {
149 let spawn_ref = spawn_ref.clone();
150 let spawn_config = spawn_config.clone();
151 let io_runtime = io_runtime.clone();
152 let process_id = process_manager.as_ref().map(|p| p.next_id()).unwrap_or(8);
153 async move {
154 match tcp_stream {
155 Err(e) => warn!(e; "Broken pipe"), Ok(io_stream) => {
157 if let Err(e) = io_stream.set_nodelay(true) {
158 warn!(e; "Failed to set TCP nodelay");
159 }
160 io_runtime.spawn(async move {
161 if let Err(error) =
162 Self::handle(io_stream, spawn_ref, spawn_config, process_id).await
163 {
164 warn!(error; "Unexpected error when handling TcpStream");
165 };
166 });
167 }
168 };
169 }
170 })
171 }
172
173 async fn handle(
174 stream: TcpStream,
175 spawn_ref: Arc<MysqlSpawnRef>,
176 spawn_config: Arc<MysqlSpawnConfig>,
177 process_id: u32,
178 ) -> Result<()> {
179 debug!("MySQL connection coming from: {}", stream.peer_addr()?);
180 crate::metrics::METRIC_MYSQL_CONNECTIONS.inc();
181 if let Err(e) = Self::do_handle(stream, spawn_ref, spawn_config, process_id).await {
182 if let Error::InternalIo { error } = &e
183 && error.kind() == std::io::ErrorKind::ConnectionAborted
184 {
185 } else {
187 warn!(e; "Internal error occurred during query exec, server actively close the channel to let client try next time");
190 }
191 }
192 crate::metrics::METRIC_MYSQL_CONNECTIONS.dec();
193
194 Ok(())
195 }
196
197 async fn do_handle(
198 stream: TcpStream,
199 spawn_ref: Arc<MysqlSpawnRef>,
200 spawn_config: Arc<MysqlSpawnConfig>,
201 process_id: u32,
202 ) -> Result<()> {
203 let mut shim = MysqlInstanceShim::create(
204 spawn_ref.query_handler(),
205 spawn_ref.user_provider(),
206 stream.peer_addr()?,
207 process_id,
208 spawn_config.prepared_stmt_cache_size,
209 );
210 let (mut r, w) = stream.into_split();
211 let mut w = BufWriter::with_capacity(DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE, w);
212
213 let ops = spawn_config.as_ref().into();
214
215 let (client_tls, init_params) =
216 AsyncMysqlIntermediary::init_before_ssl(&mut shim, &mut r, &mut w, &spawn_config.tls())
217 .await?;
218
219 ensure!(
220 !spawn_config.force_tls || client_tls,
221 TlsRequiredSnafu {
222 server: "mysql".to_owned()
223 }
224 );
225
226 match spawn_config.tls() {
227 Some(tls_conf) if client_tls => {
228 secure_run_with_options(shim, w, ops, tls_conf, init_params).await
229 }
230 _ => plain_run_with_options(shim, w, ops, init_params).await,
231 }
232 }
233}
234
235pub const MYSQL_SERVER: &str = "MYSQL_SERVER";
236
237#[async_trait]
238impl Server for MysqlServer {
239 async fn shutdown(&self) -> Result<()> {
240 self.base_server.shutdown().await
241 }
242
243 async fn start(&mut self, listening: SocketAddr) -> Result<()> {
244 let (stream, addr) = self
245 .base_server
246 .bind(listening, self.spawn_config.keep_alive_secs)
247 .await?;
248 let io_runtime = self.base_server.io_runtime();
249
250 let join_handle = common_runtime::spawn_global(self.accept(
251 io_runtime,
252 stream,
253 self.process_manager.clone(),
254 ));
255 self.base_server.start_with(join_handle).await?;
256
257 self.bind_addr = Some(addr);
258 Ok(())
259 }
260
261 fn name(&self) -> &str {
262 MYSQL_SERVER
263 }
264
265 fn bind_addr(&self) -> Option<SocketAddr> {
266 self.bind_addr
267 }
268}