servers/mysql/
server.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::future::Future;
16use std::net::SocketAddr;
17use std::sync::Arc;
18
19use async_trait::async_trait;
20use auth::UserProviderRef;
21use catalog::process_manager::ProcessManagerRef;
22use common_runtime::runtime::RuntimeTrait;
23use common_runtime::Runtime;
24use common_telemetry::{debug, warn};
25use futures::StreamExt;
26use opensrv_mysql::{
27    plain_run_with_options, secure_run_with_options, AsyncMysqlIntermediary, IntermediaryOptions,
28};
29use snafu::ensure;
30use tokio;
31use tokio::io::BufWriter;
32use tokio::net::TcpStream;
33use tokio_rustls::rustls::ServerConfig;
34
35use crate::error::{Error, Result, TlsRequiredSnafu};
36use crate::mysql::handler::MysqlInstanceShim;
37use crate::query_handler::sql::ServerSqlQueryHandlerRef;
38use crate::server::{AbortableStream, BaseTcpServer, Server};
39use crate::tls::ReloadableTlsServerConfig;
40
/// Capacity, in bytes, of the `BufWriter` that buffers result-set writes on
/// each MySQL connection's write half: 100 KiB.
const DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE: usize = 100 << 10;
43
44/// [`MysqlSpawnRef`] stores arc refs
45/// that should be passed to new [`MysqlInstanceShim`]s.
46pub struct MysqlSpawnRef {
47    query_handler: ServerSqlQueryHandlerRef,
48    user_provider: Option<UserProviderRef>,
49}
50
51impl MysqlSpawnRef {
52    pub fn new(
53        query_handler: ServerSqlQueryHandlerRef,
54        user_provider: Option<UserProviderRef>,
55    ) -> MysqlSpawnRef {
56        MysqlSpawnRef {
57            query_handler,
58            user_provider,
59        }
60    }
61
62    fn query_handler(&self) -> ServerSqlQueryHandlerRef {
63        self.query_handler.clone()
64    }
65    fn user_provider(&self) -> Option<UserProviderRef> {
66        self.user_provider.clone()
67    }
68}
69
70/// [`MysqlSpawnConfig`] stores config values
71/// which are used to initialize [`MysqlInstanceShim`]s.
72pub struct MysqlSpawnConfig {
73    // tls config
74    force_tls: bool,
75    tls: Arc<ReloadableTlsServerConfig>,
76    // keep-alive config
77    keep_alive_secs: u64,
78    // other shim config
79    reject_no_database: bool,
80    // prepared statement cache capacity
81    prepared_stmt_cache_size: usize,
82}
83
84impl MysqlSpawnConfig {
85    pub fn new(
86        force_tls: bool,
87        tls: Arc<ReloadableTlsServerConfig>,
88        keep_alive_secs: u64,
89        reject_no_database: bool,
90        prepared_stmt_cache_size: usize,
91    ) -> MysqlSpawnConfig {
92        MysqlSpawnConfig {
93            force_tls,
94            tls,
95            keep_alive_secs,
96            reject_no_database,
97            prepared_stmt_cache_size,
98        }
99    }
100
101    fn tls(&self) -> Option<Arc<ServerConfig>> {
102        self.tls.get_server_config()
103    }
104}
105
106impl From<&MysqlSpawnConfig> for IntermediaryOptions {
107    fn from(value: &MysqlSpawnConfig) -> Self {
108        IntermediaryOptions {
109            reject_connection_on_dbname_absence: value.reject_no_database,
110            ..Default::default()
111        }
112    }
113}
114
115pub struct MysqlServer {
116    base_server: BaseTcpServer,
117    spawn_ref: Arc<MysqlSpawnRef>,
118    spawn_config: Arc<MysqlSpawnConfig>,
119    bind_addr: Option<SocketAddr>,
120    process_manager: Option<ProcessManagerRef>,
121}
122
123impl MysqlServer {
124    pub fn create_server(
125        io_runtime: Runtime,
126        spawn_ref: Arc<MysqlSpawnRef>,
127        spawn_config: Arc<MysqlSpawnConfig>,
128        process_manager: Option<ProcessManagerRef>,
129    ) -> Box<dyn Server> {
130        Box::new(MysqlServer {
131            base_server: BaseTcpServer::create_server("MySQL", io_runtime),
132            spawn_ref,
133            spawn_config,
134            bind_addr: None,
135            process_manager,
136        })
137    }
138
139    fn accept(
140        &self,
141        io_runtime: Runtime,
142        stream: AbortableStream,
143        process_manager: Option<ProcessManagerRef>,
144    ) -> impl Future<Output = ()> {
145        let spawn_ref = self.spawn_ref.clone();
146        let spawn_config = self.spawn_config.clone();
147
148        stream.for_each(move |tcp_stream| {
149            let spawn_ref = spawn_ref.clone();
150            let spawn_config = spawn_config.clone();
151            let io_runtime = io_runtime.clone();
152            let process_id = process_manager.as_ref().map(|p| p.next_id()).unwrap_or(8);
153            async move {
154                match tcp_stream {
155                    Err(e) => warn!(e; "Broken pipe"), // IoError doesn't impl ErrorExt.
156                    Ok(io_stream) => {
157                        if let Err(e) = io_stream.set_nodelay(true) {
158                            warn!(e; "Failed to set TCP nodelay");
159                        }
160                        io_runtime.spawn(async move {
161                            if let Err(error) =
162                                Self::handle(io_stream, spawn_ref, spawn_config, process_id).await
163                            {
164                                warn!(error; "Unexpected error when handling TcpStream");
165                            };
166                        });
167                    }
168                };
169            }
170        })
171    }
172
173    async fn handle(
174        stream: TcpStream,
175        spawn_ref: Arc<MysqlSpawnRef>,
176        spawn_config: Arc<MysqlSpawnConfig>,
177        process_id: u32,
178    ) -> Result<()> {
179        debug!("MySQL connection coming from: {}", stream.peer_addr()?);
180        crate::metrics::METRIC_MYSQL_CONNECTIONS.inc();
181        if let Err(e) = Self::do_handle(stream, spawn_ref, spawn_config, process_id).await {
182            if let Error::InternalIo { error } = &e
183                && error.kind() == std::io::ErrorKind::ConnectionAborted
184            {
185                // This is a client-side error, we don't need to log it.
186            } else {
187                // TODO(LFC): Write this error to client as well, in MySQL text protocol.
188                // Looks like we have to expose opensrv-mysql's `PacketWriter`?
189                warn!(e; "Internal error occurred during query exec, server actively close the channel to let client try next time");
190            }
191        }
192        crate::metrics::METRIC_MYSQL_CONNECTIONS.dec();
193
194        Ok(())
195    }
196
197    async fn do_handle(
198        stream: TcpStream,
199        spawn_ref: Arc<MysqlSpawnRef>,
200        spawn_config: Arc<MysqlSpawnConfig>,
201        process_id: u32,
202    ) -> Result<()> {
203        let mut shim = MysqlInstanceShim::create(
204            spawn_ref.query_handler(),
205            spawn_ref.user_provider(),
206            stream.peer_addr()?,
207            process_id,
208            spawn_config.prepared_stmt_cache_size,
209        );
210        let (mut r, w) = stream.into_split();
211        let mut w = BufWriter::with_capacity(DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE, w);
212
213        let ops = spawn_config.as_ref().into();
214
215        let (client_tls, init_params) =
216            AsyncMysqlIntermediary::init_before_ssl(&mut shim, &mut r, &mut w, &spawn_config.tls())
217                .await?;
218
219        ensure!(
220            !spawn_config.force_tls || client_tls,
221            TlsRequiredSnafu {
222                server: "mysql".to_owned()
223            }
224        );
225
226        match spawn_config.tls() {
227            Some(tls_conf) if client_tls => {
228                secure_run_with_options(shim, w, ops, tls_conf, init_params).await
229            }
230            _ => plain_run_with_options(shim, w, ops, init_params).await,
231        }
232    }
233}
234
/// Name reported by [`MysqlServer`] through `Server::name`.
pub const MYSQL_SERVER: &str = "MYSQL_SERVER";
236
237#[async_trait]
238impl Server for MysqlServer {
239    async fn shutdown(&self) -> Result<()> {
240        self.base_server.shutdown().await
241    }
242
243    async fn start(&mut self, listening: SocketAddr) -> Result<()> {
244        let (stream, addr) = self
245            .base_server
246            .bind(listening, self.spawn_config.keep_alive_secs)
247            .await?;
248        let io_runtime = self.base_server.io_runtime();
249
250        let join_handle = common_runtime::spawn_global(self.accept(
251            io_runtime,
252            stream,
253            self.process_manager.clone(),
254        ));
255        self.base_server.start_with(join_handle).await?;
256
257        self.bind_addr = Some(addr);
258        Ok(())
259    }
260
261    fn name(&self) -> &str {
262        MYSQL_SERVER
263    }
264
265    fn bind_addr(&self) -> Option<SocketAddr> {
266        self.bind_addr
267    }
268}