servers/mysql/
server.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::future::Future;
16use std::net::SocketAddr;
17use std::sync::Arc;
18
19use async_trait::async_trait;
20use auth::UserProviderRef;
21use common_runtime::runtime::RuntimeTrait;
22use common_runtime::Runtime;
23use common_telemetry::{debug, warn};
24use futures::StreamExt;
25use opensrv_mysql::{
26    plain_run_with_options, secure_run_with_options, AsyncMysqlIntermediary, IntermediaryOptions,
27};
28use snafu::ensure;
29use tokio;
30use tokio::io::BufWriter;
31use tokio::net::TcpStream;
32use tokio_rustls::rustls::ServerConfig;
33
34use crate::error::{Error, Result, TlsRequiredSnafu};
35use crate::mysql::handler::MysqlInstanceShim;
36use crate::query_handler::sql::ServerSqlQueryHandlerRef;
37use crate::server::{AbortableStream, BaseTcpServer, Server};
38use crate::tls::ReloadableTlsServerConfig;
39
/// Capacity, in bytes, of the `BufWriter` wrapped around each connection's write half
/// (100 KiB).
const DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE: usize = 100 << 10;
42
43/// [`MysqlSpawnRef`] stores arc refs
44/// that should be passed to new [`MysqlInstanceShim`]s.
45pub struct MysqlSpawnRef {
46    query_handler: ServerSqlQueryHandlerRef,
47    user_provider: Option<UserProviderRef>,
48}
49
50impl MysqlSpawnRef {
51    pub fn new(
52        query_handler: ServerSqlQueryHandlerRef,
53        user_provider: Option<UserProviderRef>,
54    ) -> MysqlSpawnRef {
55        MysqlSpawnRef {
56            query_handler,
57            user_provider,
58        }
59    }
60
61    fn query_handler(&self) -> ServerSqlQueryHandlerRef {
62        self.query_handler.clone()
63    }
64    fn user_provider(&self) -> Option<UserProviderRef> {
65        self.user_provider.clone()
66    }
67}
68
69/// [`MysqlSpawnConfig`] stores config values
70/// which are used to initialize [`MysqlInstanceShim`]s.
71pub struct MysqlSpawnConfig {
72    // tls config
73    force_tls: bool,
74    tls: Arc<ReloadableTlsServerConfig>,
75    // keep-alive config
76    keep_alive_secs: u64,
77    // other shim config
78    reject_no_database: bool,
79}
80
81impl MysqlSpawnConfig {
82    pub fn new(
83        force_tls: bool,
84        tls: Arc<ReloadableTlsServerConfig>,
85        keep_alive_secs: u64,
86        reject_no_database: bool,
87    ) -> MysqlSpawnConfig {
88        MysqlSpawnConfig {
89            force_tls,
90            tls,
91            keep_alive_secs,
92            reject_no_database,
93        }
94    }
95
96    fn tls(&self) -> Option<Arc<ServerConfig>> {
97        self.tls.get_server_config()
98    }
99}
100
101impl From<&MysqlSpawnConfig> for IntermediaryOptions {
102    fn from(value: &MysqlSpawnConfig) -> Self {
103        IntermediaryOptions {
104            reject_connection_on_dbname_absence: value.reject_no_database,
105            ..Default::default()
106        }
107    }
108}
109
110pub struct MysqlServer {
111    base_server: BaseTcpServer,
112    spawn_ref: Arc<MysqlSpawnRef>,
113    spawn_config: Arc<MysqlSpawnConfig>,
114    bind_addr: Option<SocketAddr>,
115}
116
117impl MysqlServer {
118    pub fn create_server(
119        io_runtime: Runtime,
120        spawn_ref: Arc<MysqlSpawnRef>,
121        spawn_config: Arc<MysqlSpawnConfig>,
122    ) -> Box<dyn Server> {
123        Box::new(MysqlServer {
124            base_server: BaseTcpServer::create_server("MySQL", io_runtime),
125            spawn_ref,
126            spawn_config,
127            bind_addr: None,
128        })
129    }
130
131    fn accept(&self, io_runtime: Runtime, stream: AbortableStream) -> impl Future<Output = ()> {
132        let spawn_ref = self.spawn_ref.clone();
133        let spawn_config = self.spawn_config.clone();
134
135        stream.for_each(move |tcp_stream| {
136            let spawn_ref = spawn_ref.clone();
137            let spawn_config = spawn_config.clone();
138            let io_runtime = io_runtime.clone();
139
140            async move {
141                match tcp_stream {
142                    Err(e) => warn!(e; "Broken pipe"), // IoError doesn't impl ErrorExt.
143                    Ok(io_stream) => {
144                        if let Err(e) = io_stream.set_nodelay(true) {
145                            warn!(e; "Failed to set TCP nodelay");
146                        }
147                        io_runtime.spawn(async move {
148                            if let Err(error) =
149                                Self::handle(io_stream, spawn_ref, spawn_config).await
150                            {
151                                warn!(error; "Unexpected error when handling TcpStream");
152                            };
153                        });
154                    }
155                };
156            }
157        })
158    }
159
160    async fn handle(
161        stream: TcpStream,
162        spawn_ref: Arc<MysqlSpawnRef>,
163        spawn_config: Arc<MysqlSpawnConfig>,
164    ) -> Result<()> {
165        debug!("MySQL connection coming from: {}", stream.peer_addr()?);
166        crate::metrics::METRIC_MYSQL_CONNECTIONS.inc();
167        if let Err(e) = Self::do_handle(stream, spawn_ref, spawn_config).await {
168            if let Error::InternalIo { error } = &e
169                && error.kind() == std::io::ErrorKind::ConnectionAborted
170            {
171                // This is a client-side error, we don't need to log it.
172            } else {
173                // TODO(LFC): Write this error to client as well, in MySQL text protocol.
174                // Looks like we have to expose opensrv-mysql's `PacketWriter`?
175                warn!(e; "Internal error occurred during query exec, server actively close the channel to let client try next time");
176            }
177        }
178        crate::metrics::METRIC_MYSQL_CONNECTIONS.dec();
179
180        Ok(())
181    }
182
183    async fn do_handle(
184        stream: TcpStream,
185        spawn_ref: Arc<MysqlSpawnRef>,
186        spawn_config: Arc<MysqlSpawnConfig>,
187    ) -> Result<()> {
188        let mut shim = MysqlInstanceShim::create(
189            spawn_ref.query_handler(),
190            spawn_ref.user_provider(),
191            stream.peer_addr()?,
192        );
193        let (mut r, w) = stream.into_split();
194        let mut w = BufWriter::with_capacity(DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE, w);
195
196        let ops = spawn_config.as_ref().into();
197
198        let (client_tls, init_params) =
199            AsyncMysqlIntermediary::init_before_ssl(&mut shim, &mut r, &mut w, &spawn_config.tls())
200                .await?;
201
202        ensure!(
203            !spawn_config.force_tls || client_tls,
204            TlsRequiredSnafu {
205                server: "mysql".to_owned()
206            }
207        );
208
209        match spawn_config.tls() {
210            Some(tls_conf) if client_tls => {
211                secure_run_with_options(shim, w, ops, tls_conf, init_params).await
212            }
213            _ => plain_run_with_options(shim, w, ops, init_params).await,
214        }
215    }
216}
217
/// Name reported by [`MysqlServer`] through `Server::name`.
pub const MYSQL_SERVER: &str = "MYSQL_SERVER";
219
220#[async_trait]
221impl Server for MysqlServer {
222    async fn shutdown(&self) -> Result<()> {
223        self.base_server.shutdown().await
224    }
225
226    async fn start(&mut self, listening: SocketAddr) -> Result<()> {
227        let (stream, addr) = self
228            .base_server
229            .bind(listening, self.spawn_config.keep_alive_secs)
230            .await?;
231        let io_runtime = self.base_server.io_runtime();
232
233        let join_handle = common_runtime::spawn_global(self.accept(io_runtime, stream));
234        self.base_server.start_with(join_handle).await?;
235
236        self.bind_addr = Some(addr);
237        Ok(())
238    }
239
240    fn name(&self) -> &str {
241        MYSQL_SERVER
242    }
243
244    fn bind_addr(&self) -> Option<SocketAddr> {
245        self.bind_addr
246    }
247}