1use std::future::Future;
16use std::net::SocketAddr;
17use std::sync::Arc;
18
19use async_trait::async_trait;
20use auth::UserProviderRef;
21use common_runtime::runtime::RuntimeTrait;
22use common_runtime::Runtime;
23use common_telemetry::{debug, warn};
24use futures::StreamExt;
25use opensrv_mysql::{
26 plain_run_with_options, secure_run_with_options, AsyncMysqlIntermediary, IntermediaryOptions,
27};
28use snafu::ensure;
29use tokio;
30use tokio::io::BufWriter;
31use tokio::net::TcpStream;
32use tokio_rustls::rustls::ServerConfig;
33
34use crate::error::{Error, Result, TlsRequiredSnafu};
35use crate::mysql::handler::MysqlInstanceShim;
36use crate::query_handler::sql::ServerSqlQueryHandlerRef;
37use crate::server::{AbortableStream, BaseTcpServer, Server};
38use crate::tls::ReloadableTlsServerConfig;
39
/// Capacity, in bytes (100 KiB), of the `BufWriter` wrapped around each connection's write half.
const DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE: usize = 102_400;
42
43pub struct MysqlSpawnRef {
46 query_handler: ServerSqlQueryHandlerRef,
47 user_provider: Option<UserProviderRef>,
48}
49
50impl MysqlSpawnRef {
51 pub fn new(
52 query_handler: ServerSqlQueryHandlerRef,
53 user_provider: Option<UserProviderRef>,
54 ) -> MysqlSpawnRef {
55 MysqlSpawnRef {
56 query_handler,
57 user_provider,
58 }
59 }
60
61 fn query_handler(&self) -> ServerSqlQueryHandlerRef {
62 self.query_handler.clone()
63 }
64 fn user_provider(&self) -> Option<UserProviderRef> {
65 self.user_provider.clone()
66 }
67}
68
69pub struct MysqlSpawnConfig {
72 force_tls: bool,
74 tls: Arc<ReloadableTlsServerConfig>,
75 keep_alive_secs: u64,
77 reject_no_database: bool,
79}
80
81impl MysqlSpawnConfig {
82 pub fn new(
83 force_tls: bool,
84 tls: Arc<ReloadableTlsServerConfig>,
85 keep_alive_secs: u64,
86 reject_no_database: bool,
87 ) -> MysqlSpawnConfig {
88 MysqlSpawnConfig {
89 force_tls,
90 tls,
91 keep_alive_secs,
92 reject_no_database,
93 }
94 }
95
96 fn tls(&self) -> Option<Arc<ServerConfig>> {
97 self.tls.get_server_config()
98 }
99}
100
101impl From<&MysqlSpawnConfig> for IntermediaryOptions {
102 fn from(value: &MysqlSpawnConfig) -> Self {
103 IntermediaryOptions {
104 reject_connection_on_dbname_absence: value.reject_no_database,
105 ..Default::default()
106 }
107 }
108}
109
110pub struct MysqlServer {
111 base_server: BaseTcpServer,
112 spawn_ref: Arc<MysqlSpawnRef>,
113 spawn_config: Arc<MysqlSpawnConfig>,
114 bind_addr: Option<SocketAddr>,
115}
116
117impl MysqlServer {
118 pub fn create_server(
119 io_runtime: Runtime,
120 spawn_ref: Arc<MysqlSpawnRef>,
121 spawn_config: Arc<MysqlSpawnConfig>,
122 ) -> Box<dyn Server> {
123 Box::new(MysqlServer {
124 base_server: BaseTcpServer::create_server("MySQL", io_runtime),
125 spawn_ref,
126 spawn_config,
127 bind_addr: None,
128 })
129 }
130
131 fn accept(&self, io_runtime: Runtime, stream: AbortableStream) -> impl Future<Output = ()> {
132 let spawn_ref = self.spawn_ref.clone();
133 let spawn_config = self.spawn_config.clone();
134
135 stream.for_each(move |tcp_stream| {
136 let spawn_ref = spawn_ref.clone();
137 let spawn_config = spawn_config.clone();
138 let io_runtime = io_runtime.clone();
139
140 async move {
141 match tcp_stream {
142 Err(e) => warn!(e; "Broken pipe"), Ok(io_stream) => {
144 if let Err(e) = io_stream.set_nodelay(true) {
145 warn!(e; "Failed to set TCP nodelay");
146 }
147 io_runtime.spawn(async move {
148 if let Err(error) =
149 Self::handle(io_stream, spawn_ref, spawn_config).await
150 {
151 warn!(error; "Unexpected error when handling TcpStream");
152 };
153 });
154 }
155 };
156 }
157 })
158 }
159
160 async fn handle(
161 stream: TcpStream,
162 spawn_ref: Arc<MysqlSpawnRef>,
163 spawn_config: Arc<MysqlSpawnConfig>,
164 ) -> Result<()> {
165 debug!("MySQL connection coming from: {}", stream.peer_addr()?);
166 crate::metrics::METRIC_MYSQL_CONNECTIONS.inc();
167 if let Err(e) = Self::do_handle(stream, spawn_ref, spawn_config).await {
168 if let Error::InternalIo { error } = &e
169 && error.kind() == std::io::ErrorKind::ConnectionAborted
170 {
171 } else {
173 warn!(e; "Internal error occurred during query exec, server actively close the channel to let client try next time");
176 }
177 }
178 crate::metrics::METRIC_MYSQL_CONNECTIONS.dec();
179
180 Ok(())
181 }
182
183 async fn do_handle(
184 stream: TcpStream,
185 spawn_ref: Arc<MysqlSpawnRef>,
186 spawn_config: Arc<MysqlSpawnConfig>,
187 ) -> Result<()> {
188 let mut shim = MysqlInstanceShim::create(
189 spawn_ref.query_handler(),
190 spawn_ref.user_provider(),
191 stream.peer_addr()?,
192 );
193 let (mut r, w) = stream.into_split();
194 let mut w = BufWriter::with_capacity(DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE, w);
195
196 let ops = spawn_config.as_ref().into();
197
198 let (client_tls, init_params) =
199 AsyncMysqlIntermediary::init_before_ssl(&mut shim, &mut r, &mut w, &spawn_config.tls())
200 .await?;
201
202 ensure!(
203 !spawn_config.force_tls || client_tls,
204 TlsRequiredSnafu {
205 server: "mysql".to_owned()
206 }
207 );
208
209 match spawn_config.tls() {
210 Some(tls_conf) if client_tls => {
211 secure_run_with_options(shim, w, ops, tls_conf, init_params).await
212 }
213 _ => plain_run_with_options(shim, w, ops, init_params).await,
214 }
215 }
216}
217
218pub const MYSQL_SERVER: &str = "MYSQL_SERVER";
219
220#[async_trait]
221impl Server for MysqlServer {
222 async fn shutdown(&self) -> Result<()> {
223 self.base_server.shutdown().await
224 }
225
226 async fn start(&mut self, listening: SocketAddr) -> Result<()> {
227 let (stream, addr) = self
228 .base_server
229 .bind(listening, self.spawn_config.keep_alive_secs)
230 .await?;
231 let io_runtime = self.base_server.io_runtime();
232
233 let join_handle = common_runtime::spawn_global(self.accept(io_runtime, stream));
234 self.base_server.start_with(join_handle).await?;
235
236 self.bind_addr = Some(addr);
237 Ok(())
238 }
239
240 fn name(&self) -> &str {
241 MYSQL_SERVER
242 }
243
244 fn bind_addr(&self) -> Option<SocketAddr> {
245 self.bind_addr
246 }
247}