servers/mysql/
server.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::future::Future;
16use std::net::SocketAddr;
17use std::sync::Arc;
18
19use async_trait::async_trait;
20use auth::UserProviderRef;
21use catalog::process_manager::ProcessManagerRef;
22use common_runtime::runtime::RuntimeTrait;
23use common_runtime::Runtime;
24use common_telemetry::{debug, warn};
25use futures::StreamExt;
26use opensrv_mysql::{
27    plain_run_with_options, secure_run_with_options, AsyncMysqlIntermediary, IntermediaryOptions,
28};
29use snafu::ensure;
30use tokio;
31use tokio::io::BufWriter;
32use tokio::net::TcpStream;
33use tokio_rustls::rustls::ServerConfig;
34
35use crate::error::{Error, Result, TlsRequiredSnafu};
36use crate::mysql::handler::MysqlInstanceShim;
37use crate::query_handler::sql::ServerSqlQueryHandlerRef;
38use crate::server::{AbortableStream, BaseTcpServer, Server};
39use crate::tls::ReloadableTlsServerConfig;
40
/// Capacity, in bytes, of the `BufWriter` wrapping each connection's write half
/// (used for ResultSet output): 100 KiB.
const DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE: usize = 100 << 10;
43
44/// [`MysqlSpawnRef`] stores arc refs
45/// that should be passed to new [`MysqlInstanceShim`]s.
46pub struct MysqlSpawnRef {
47    query_handler: ServerSqlQueryHandlerRef,
48    user_provider: Option<UserProviderRef>,
49}
50
51impl MysqlSpawnRef {
52    pub fn new(
53        query_handler: ServerSqlQueryHandlerRef,
54        user_provider: Option<UserProviderRef>,
55    ) -> MysqlSpawnRef {
56        MysqlSpawnRef {
57            query_handler,
58            user_provider,
59        }
60    }
61
62    fn query_handler(&self) -> ServerSqlQueryHandlerRef {
63        self.query_handler.clone()
64    }
65    fn user_provider(&self) -> Option<UserProviderRef> {
66        self.user_provider.clone()
67    }
68}
69
70/// [`MysqlSpawnConfig`] stores config values
71/// which are used to initialize [`MysqlInstanceShim`]s.
72pub struct MysqlSpawnConfig {
73    // tls config
74    force_tls: bool,
75    tls: Arc<ReloadableTlsServerConfig>,
76    // keep-alive config
77    keep_alive_secs: u64,
78    // other shim config
79    reject_no_database: bool,
80}
81
82impl MysqlSpawnConfig {
83    pub fn new(
84        force_tls: bool,
85        tls: Arc<ReloadableTlsServerConfig>,
86        keep_alive_secs: u64,
87        reject_no_database: bool,
88    ) -> MysqlSpawnConfig {
89        MysqlSpawnConfig {
90            force_tls,
91            tls,
92            keep_alive_secs,
93            reject_no_database,
94        }
95    }
96
97    fn tls(&self) -> Option<Arc<ServerConfig>> {
98        self.tls.get_server_config()
99    }
100}
101
102impl From<&MysqlSpawnConfig> for IntermediaryOptions {
103    fn from(value: &MysqlSpawnConfig) -> Self {
104        IntermediaryOptions {
105            reject_connection_on_dbname_absence: value.reject_no_database,
106            ..Default::default()
107        }
108    }
109}
110
111pub struct MysqlServer {
112    base_server: BaseTcpServer,
113    spawn_ref: Arc<MysqlSpawnRef>,
114    spawn_config: Arc<MysqlSpawnConfig>,
115    bind_addr: Option<SocketAddr>,
116    process_manager: Option<ProcessManagerRef>,
117}
118
119impl MysqlServer {
120    pub fn create_server(
121        io_runtime: Runtime,
122        spawn_ref: Arc<MysqlSpawnRef>,
123        spawn_config: Arc<MysqlSpawnConfig>,
124        process_manager: Option<ProcessManagerRef>,
125    ) -> Box<dyn Server> {
126        Box::new(MysqlServer {
127            base_server: BaseTcpServer::create_server("MySQL", io_runtime),
128            spawn_ref,
129            spawn_config,
130            bind_addr: None,
131            process_manager,
132        })
133    }
134
135    fn accept(
136        &self,
137        io_runtime: Runtime,
138        stream: AbortableStream,
139        process_manager: Option<ProcessManagerRef>,
140    ) -> impl Future<Output = ()> {
141        let spawn_ref = self.spawn_ref.clone();
142        let spawn_config = self.spawn_config.clone();
143
144        stream.for_each(move |tcp_stream| {
145            let spawn_ref = spawn_ref.clone();
146            let spawn_config = spawn_config.clone();
147            let io_runtime = io_runtime.clone();
148            let process_id = process_manager.as_ref().map(|p| p.next_id()).unwrap_or(8);
149            async move {
150                match tcp_stream {
151                    Err(e) => warn!(e; "Broken pipe"), // IoError doesn't impl ErrorExt.
152                    Ok(io_stream) => {
153                        if let Err(e) = io_stream.set_nodelay(true) {
154                            warn!(e; "Failed to set TCP nodelay");
155                        }
156                        io_runtime.spawn(async move {
157                            if let Err(error) =
158                                Self::handle(io_stream, spawn_ref, spawn_config, process_id).await
159                            {
160                                warn!(error; "Unexpected error when handling TcpStream");
161                            };
162                        });
163                    }
164                };
165            }
166        })
167    }
168
169    async fn handle(
170        stream: TcpStream,
171        spawn_ref: Arc<MysqlSpawnRef>,
172        spawn_config: Arc<MysqlSpawnConfig>,
173        process_id: u32,
174    ) -> Result<()> {
175        debug!("MySQL connection coming from: {}", stream.peer_addr()?);
176        crate::metrics::METRIC_MYSQL_CONNECTIONS.inc();
177        if let Err(e) = Self::do_handle(stream, spawn_ref, spawn_config, process_id).await {
178            if let Error::InternalIo { error } = &e
179                && error.kind() == std::io::ErrorKind::ConnectionAborted
180            {
181                // This is a client-side error, we don't need to log it.
182            } else {
183                // TODO(LFC): Write this error to client as well, in MySQL text protocol.
184                // Looks like we have to expose opensrv-mysql's `PacketWriter`?
185                warn!(e; "Internal error occurred during query exec, server actively close the channel to let client try next time");
186            }
187        }
188        crate::metrics::METRIC_MYSQL_CONNECTIONS.dec();
189
190        Ok(())
191    }
192
193    async fn do_handle(
194        stream: TcpStream,
195        spawn_ref: Arc<MysqlSpawnRef>,
196        spawn_config: Arc<MysqlSpawnConfig>,
197        process_id: u32,
198    ) -> Result<()> {
199        let mut shim = MysqlInstanceShim::create(
200            spawn_ref.query_handler(),
201            spawn_ref.user_provider(),
202            stream.peer_addr()?,
203            process_id,
204        );
205        let (mut r, w) = stream.into_split();
206        let mut w = BufWriter::with_capacity(DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE, w);
207
208        let ops = spawn_config.as_ref().into();
209
210        let (client_tls, init_params) =
211            AsyncMysqlIntermediary::init_before_ssl(&mut shim, &mut r, &mut w, &spawn_config.tls())
212                .await?;
213
214        ensure!(
215            !spawn_config.force_tls || client_tls,
216            TlsRequiredSnafu {
217                server: "mysql".to_owned()
218            }
219        );
220
221        match spawn_config.tls() {
222            Some(tls_conf) if client_tls => {
223                secure_run_with_options(shim, w, ops, tls_conf, init_params).await
224            }
225            _ => plain_run_with_options(shim, w, ops, init_params).await,
226        }
227    }
228}
229
/// Identifier returned by the MySQL server's `Server::name` implementation.
pub const MYSQL_SERVER: &str = "MYSQL_SERVER";
231
232#[async_trait]
233impl Server for MysqlServer {
234    async fn shutdown(&self) -> Result<()> {
235        self.base_server.shutdown().await
236    }
237
238    async fn start(&mut self, listening: SocketAddr) -> Result<()> {
239        let (stream, addr) = self
240            .base_server
241            .bind(listening, self.spawn_config.keep_alive_secs)
242            .await?;
243        let io_runtime = self.base_server.io_runtime();
244
245        let join_handle = common_runtime::spawn_global(self.accept(
246            io_runtime,
247            stream,
248            self.process_manager.clone(),
249        ));
250        self.base_server.start_with(join_handle).await?;
251
252        self.bind_addr = Some(addr);
253        Ok(())
254    }
255
256    fn name(&self) -> &str {
257        MYSQL_SERVER
258    }
259
260    fn bind_addr(&self) -> Option<SocketAddr> {
261        self.bind_addr
262    }
263}