1use std::any::Any;
16use std::future::Future;
17use std::net::SocketAddr;
18use std::sync::Arc;
19
20use async_trait::async_trait;
21use auth::UserProviderRef;
22use catalog::process_manager::ProcessManagerRef;
23use common_runtime::Runtime;
24use common_runtime::runtime::RuntimeTrait;
25use common_telemetry::{debug, warn};
26use futures::StreamExt;
27use opensrv_mysql::{
28 AsyncMysqlIntermediary, IntermediaryOptions, plain_run_with_options, secure_run_with_options,
29};
30use snafu::ensure;
31use tokio;
32use tokio::io::BufWriter;
33use tokio::net::TcpStream;
34use tokio_rustls::rustls::ServerConfig;
35
36use crate::error::{Error, Result, TlsRequiredSnafu};
37use crate::mysql::handler::MysqlInstanceShim;
38use crate::query_handler::sql::ServerSqlQueryHandlerRef;
39use crate::server::{AbortableStream, BaseTcpServer, Server};
40use crate::tls::ReloadableTlsServerConfig;
41
/// Capacity, in bytes, of the `BufWriter` wrapped around each connection's write half
/// (100 KiB).
const DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE: usize = 100 << 10;
44
45pub struct MysqlSpawnRef {
48 query_handler: ServerSqlQueryHandlerRef,
49 user_provider: Option<UserProviderRef>,
50}
51
52impl MysqlSpawnRef {
53 pub fn new(
54 query_handler: ServerSqlQueryHandlerRef,
55 user_provider: Option<UserProviderRef>,
56 ) -> MysqlSpawnRef {
57 MysqlSpawnRef {
58 query_handler,
59 user_provider,
60 }
61 }
62
63 fn query_handler(&self) -> ServerSqlQueryHandlerRef {
64 self.query_handler.clone()
65 }
66 fn user_provider(&self) -> Option<UserProviderRef> {
67 self.user_provider.clone()
68 }
69}
70
71pub struct MysqlSpawnConfig {
74 force_tls: bool,
76 tls: Arc<ReloadableTlsServerConfig>,
77 keep_alive_secs: u64,
79 reject_no_database: bool,
81 prepared_stmt_cache_size: usize,
83}
84
85impl MysqlSpawnConfig {
86 pub fn new(
87 force_tls: bool,
88 tls: Arc<ReloadableTlsServerConfig>,
89 keep_alive_secs: u64,
90 reject_no_database: bool,
91 prepared_stmt_cache_size: usize,
92 ) -> MysqlSpawnConfig {
93 MysqlSpawnConfig {
94 force_tls,
95 tls,
96 keep_alive_secs,
97 reject_no_database,
98 prepared_stmt_cache_size,
99 }
100 }
101
102 fn tls(&self) -> Option<Arc<ServerConfig>> {
103 self.tls.get_config()
104 }
105}
106
107impl From<&MysqlSpawnConfig> for IntermediaryOptions {
108 fn from(value: &MysqlSpawnConfig) -> Self {
109 IntermediaryOptions {
110 reject_connection_on_dbname_absence: value.reject_no_database,
111 ..Default::default()
112 }
113 }
114}
115
116pub struct MysqlServer {
117 base_server: BaseTcpServer,
118 spawn_ref: Arc<MysqlSpawnRef>,
119 spawn_config: Arc<MysqlSpawnConfig>,
120 bind_addr: Option<SocketAddr>,
121 process_manager: Option<ProcessManagerRef>,
122}
123
124impl MysqlServer {
125 pub fn create_server(
126 io_runtime: Runtime,
127 spawn_ref: Arc<MysqlSpawnRef>,
128 spawn_config: Arc<MysqlSpawnConfig>,
129 process_manager: Option<ProcessManagerRef>,
130 ) -> Box<dyn Server> {
131 Box::new(MysqlServer {
132 base_server: BaseTcpServer::create_server("MySQL", io_runtime),
133 spawn_ref,
134 spawn_config,
135 bind_addr: None,
136 process_manager,
137 })
138 }
139
140 fn accept(
141 &self,
142 io_runtime: Runtime,
143 stream: AbortableStream,
144 process_manager: Option<ProcessManagerRef>,
145 ) -> impl Future<Output = ()> + use<> {
146 let spawn_ref = self.spawn_ref.clone();
147 let spawn_config = self.spawn_config.clone();
148
149 stream.for_each(move |tcp_stream| {
150 let spawn_ref = spawn_ref.clone();
151 let spawn_config = spawn_config.clone();
152 let io_runtime = io_runtime.clone();
153 let process_id = process_manager.as_ref().map(|p| p.next_id()).unwrap_or(8);
154 async move {
155 match tcp_stream {
156 Err(e) => warn!(e; "Broken pipe"), Ok(io_stream) => {
158 if let Err(e) = io_stream.set_nodelay(true) {
159 warn!(e; "Failed to set TCP nodelay");
160 }
161 io_runtime.spawn(async move {
162 if let Err(error) =
163 Self::handle(io_stream, spawn_ref, spawn_config, process_id).await
164 {
165 warn!(error; "Unexpected error when handling TcpStream");
166 };
167 });
168 }
169 };
170 }
171 })
172 }
173
174 async fn handle(
175 stream: TcpStream,
176 spawn_ref: Arc<MysqlSpawnRef>,
177 spawn_config: Arc<MysqlSpawnConfig>,
178 process_id: u32,
179 ) -> Result<()> {
180 debug!("MySQL connection coming from: {}", stream.peer_addr()?);
181 crate::metrics::METRIC_MYSQL_CONNECTIONS.inc();
182 if let Err(e) = Self::do_handle(stream, spawn_ref, spawn_config, process_id).await {
183 if let Error::InternalIo { error } = &e
184 && error.kind() == std::io::ErrorKind::ConnectionAborted
185 {
186 } else {
188 warn!(e; "Internal error occurred during query exec, server actively close the channel to let client try next time");
191 }
192 }
193 crate::metrics::METRIC_MYSQL_CONNECTIONS.dec();
194
195 Ok(())
196 }
197
198 async fn do_handle(
199 stream: TcpStream,
200 spawn_ref: Arc<MysqlSpawnRef>,
201 spawn_config: Arc<MysqlSpawnConfig>,
202 process_id: u32,
203 ) -> Result<()> {
204 let mut shim = MysqlInstanceShim::create(
205 spawn_ref.query_handler(),
206 spawn_ref.user_provider(),
207 stream.peer_addr()?,
208 process_id,
209 spawn_config.prepared_stmt_cache_size,
210 );
211 let (mut r, w) = stream.into_split();
212 let mut w = BufWriter::with_capacity(DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE, w);
213
214 let ops = spawn_config.as_ref().into();
215
216 let (client_tls, init_params) =
217 AsyncMysqlIntermediary::init_before_ssl(&mut shim, &mut r, &mut w, &spawn_config.tls())
218 .await?;
219
220 ensure!(
221 !spawn_config.force_tls || client_tls,
222 TlsRequiredSnafu {
223 server: "mysql".to_owned()
224 }
225 );
226
227 match spawn_config.tls() {
228 Some(tls_conf) if client_tls => {
229 secure_run_with_options(shim, w, ops, tls_conf, init_params).await
230 }
231 _ => plain_run_with_options(shim, w, ops, init_params).await,
232 }
233 }
234}
235
/// Identifier returned by [`Server::name`] for the MySQL server.
pub const MYSQL_SERVER: &str = "MYSQL_SERVER";
237
238#[async_trait]
239impl Server for MysqlServer {
240 async fn shutdown(&self) -> Result<()> {
241 self.base_server.shutdown().await
242 }
243
244 async fn start(&mut self, listening: SocketAddr) -> Result<()> {
245 let (stream, addr) = self
246 .base_server
247 .bind(listening, self.spawn_config.keep_alive_secs)
248 .await?;
249 let io_runtime = self.base_server.io_runtime();
250
251 let join_handle = common_runtime::spawn_global(self.accept(
252 io_runtime,
253 stream,
254 self.process_manager.clone(),
255 ));
256 self.base_server.start_with(join_handle).await?;
257
258 self.bind_addr = Some(addr);
259 Ok(())
260 }
261
262 fn name(&self) -> &str {
263 MYSQL_SERVER
264 }
265
266 fn bind_addr(&self) -> Option<SocketAddr> {
267 self.bind_addr
268 }
269
270 fn as_any(&self) -> &dyn Any {
271 self
272 }
273}