servers/mysql/
server.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::any::Any;
16use std::future::Future;
17use std::net::SocketAddr;
18use std::sync::Arc;
19
20use async_trait::async_trait;
21use auth::UserProviderRef;
22use catalog::process_manager::ProcessManagerRef;
23use common_runtime::Runtime;
24use common_runtime::runtime::RuntimeTrait;
25use common_telemetry::{debug, warn};
26use futures::StreamExt;
27use opensrv_mysql::{
28    AsyncMysqlIntermediary, IntermediaryOptions, plain_run_with_options, secure_run_with_options,
29};
30use snafu::ensure;
31use tokio;
32use tokio::io::BufWriter;
33use tokio::net::TcpStream;
34use tokio_rustls::rustls::ServerConfig;
35
36use crate::error::{Error, Result, TlsRequiredSnafu};
37use crate::mysql::handler::MysqlInstanceShim;
38use crate::query_handler::sql::ServerSqlQueryHandlerRef;
39use crate::server::{AbortableStream, BaseTcpServer, Server};
40use crate::tls::ReloadableTlsServerConfig;
41
/// Capacity, in bytes, of the `BufWriter` wrapped around each connection's
/// write half when streaming result sets: 100 KiB (100 * 1024 bytes).
const DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE: usize = 1024 * 100;
44
45/// [`MysqlSpawnRef`] stores arc refs
46/// that should be passed to new [`MysqlInstanceShim`]s.
47pub struct MysqlSpawnRef {
48    query_handler: ServerSqlQueryHandlerRef,
49    user_provider: Option<UserProviderRef>,
50}
51
52impl MysqlSpawnRef {
53    pub fn new(
54        query_handler: ServerSqlQueryHandlerRef,
55        user_provider: Option<UserProviderRef>,
56    ) -> MysqlSpawnRef {
57        MysqlSpawnRef {
58            query_handler,
59            user_provider,
60        }
61    }
62
63    fn query_handler(&self) -> ServerSqlQueryHandlerRef {
64        self.query_handler.clone()
65    }
66    fn user_provider(&self) -> Option<UserProviderRef> {
67        self.user_provider.clone()
68    }
69}
70
71/// [`MysqlSpawnConfig`] stores config values
72/// which are used to initialize [`MysqlInstanceShim`]s.
73pub struct MysqlSpawnConfig {
74    // tls config
75    force_tls: bool,
76    tls: Arc<ReloadableTlsServerConfig>,
77    // keep-alive config
78    keep_alive_secs: u64,
79    // other shim config
80    reject_no_database: bool,
81    // prepared statement cache capacity
82    prepared_stmt_cache_size: usize,
83}
84
85impl MysqlSpawnConfig {
86    pub fn new(
87        force_tls: bool,
88        tls: Arc<ReloadableTlsServerConfig>,
89        keep_alive_secs: u64,
90        reject_no_database: bool,
91        prepared_stmt_cache_size: usize,
92    ) -> MysqlSpawnConfig {
93        MysqlSpawnConfig {
94            force_tls,
95            tls,
96            keep_alive_secs,
97            reject_no_database,
98            prepared_stmt_cache_size,
99        }
100    }
101
102    fn tls(&self) -> Option<Arc<ServerConfig>> {
103        self.tls.get_config()
104    }
105}
106
107impl From<&MysqlSpawnConfig> for IntermediaryOptions {
108    fn from(value: &MysqlSpawnConfig) -> Self {
109        IntermediaryOptions {
110            reject_connection_on_dbname_absence: value.reject_no_database,
111            ..Default::default()
112        }
113    }
114}
115
116pub struct MysqlServer {
117    base_server: BaseTcpServer,
118    spawn_ref: Arc<MysqlSpawnRef>,
119    spawn_config: Arc<MysqlSpawnConfig>,
120    bind_addr: Option<SocketAddr>,
121    process_manager: Option<ProcessManagerRef>,
122}
123
124impl MysqlServer {
125    pub fn create_server(
126        io_runtime: Runtime,
127        spawn_ref: Arc<MysqlSpawnRef>,
128        spawn_config: Arc<MysqlSpawnConfig>,
129        process_manager: Option<ProcessManagerRef>,
130    ) -> Box<dyn Server> {
131        Box::new(MysqlServer {
132            base_server: BaseTcpServer::create_server("MySQL", io_runtime),
133            spawn_ref,
134            spawn_config,
135            bind_addr: None,
136            process_manager,
137        })
138    }
139
140    fn accept(
141        &self,
142        io_runtime: Runtime,
143        stream: AbortableStream,
144        process_manager: Option<ProcessManagerRef>,
145    ) -> impl Future<Output = ()> + use<> {
146        let spawn_ref = self.spawn_ref.clone();
147        let spawn_config = self.spawn_config.clone();
148
149        stream.for_each(move |tcp_stream| {
150            let spawn_ref = spawn_ref.clone();
151            let spawn_config = spawn_config.clone();
152            let io_runtime = io_runtime.clone();
153            let process_id = process_manager.as_ref().map(|p| p.next_id()).unwrap_or(8);
154            async move {
155                match tcp_stream {
156                    Err(e) => warn!(e; "Broken pipe"), // IoError doesn't impl ErrorExt.
157                    Ok(io_stream) => {
158                        if let Err(e) = io_stream.set_nodelay(true) {
159                            warn!(e; "Failed to set TCP nodelay");
160                        }
161                        io_runtime.spawn(async move {
162                            if let Err(error) =
163                                Self::handle(io_stream, spawn_ref, spawn_config, process_id).await
164                            {
165                                warn!(error; "Unexpected error when handling TcpStream");
166                            };
167                        });
168                    }
169                };
170            }
171        })
172    }
173
174    async fn handle(
175        stream: TcpStream,
176        spawn_ref: Arc<MysqlSpawnRef>,
177        spawn_config: Arc<MysqlSpawnConfig>,
178        process_id: u32,
179    ) -> Result<()> {
180        debug!("MySQL connection coming from: {}", stream.peer_addr()?);
181        crate::metrics::METRIC_MYSQL_CONNECTIONS.inc();
182        if let Err(e) = Self::do_handle(stream, spawn_ref, spawn_config, process_id).await {
183            if let Error::InternalIo { error } = &e
184                && error.kind() == std::io::ErrorKind::ConnectionAborted
185            {
186                // This is a client-side error, we don't need to log it.
187            } else {
188                // TODO(LFC): Write this error to client as well, in MySQL text protocol.
189                // Looks like we have to expose opensrv-mysql's `PacketWriter`?
190                warn!(e; "Internal error occurred during query exec, server actively close the channel to let client try next time");
191            }
192        }
193        crate::metrics::METRIC_MYSQL_CONNECTIONS.dec();
194
195        Ok(())
196    }
197
198    async fn do_handle(
199        stream: TcpStream,
200        spawn_ref: Arc<MysqlSpawnRef>,
201        spawn_config: Arc<MysqlSpawnConfig>,
202        process_id: u32,
203    ) -> Result<()> {
204        let mut shim = MysqlInstanceShim::create(
205            spawn_ref.query_handler(),
206            spawn_ref.user_provider(),
207            stream.peer_addr()?,
208            process_id,
209            spawn_config.prepared_stmt_cache_size,
210        );
211        let (mut r, w) = stream.into_split();
212        let mut w = BufWriter::with_capacity(DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE, w);
213
214        let ops = spawn_config.as_ref().into();
215
216        let (client_tls, init_params) =
217            AsyncMysqlIntermediary::init_before_ssl(&mut shim, &mut r, &mut w, &spawn_config.tls())
218                .await?;
219
220        ensure!(
221            !spawn_config.force_tls || client_tls,
222            TlsRequiredSnafu {
223                server: "mysql".to_owned()
224            }
225        );
226
227        match spawn_config.tls() {
228            Some(tls_conf) if client_tls => {
229                secure_run_with_options(shim, w, ops, tls_conf, init_params).await
230            }
231            _ => plain_run_with_options(shim, w, ops, init_params).await,
232        }
233    }
234}
235
/// Name reported by [`MysqlServer`] through [`Server::name`].
pub const MYSQL_SERVER: &str = "MYSQL_SERVER";
237
238#[async_trait]
239impl Server for MysqlServer {
240    async fn shutdown(&self) -> Result<()> {
241        self.base_server.shutdown().await
242    }
243
244    async fn start(&mut self, listening: SocketAddr) -> Result<()> {
245        let (stream, addr) = self
246            .base_server
247            .bind(listening, self.spawn_config.keep_alive_secs)
248            .await?;
249        let io_runtime = self.base_server.io_runtime();
250
251        let join_handle = common_runtime::spawn_global(self.accept(
252            io_runtime,
253            stream,
254            self.process_manager.clone(),
255        ));
256        self.base_server.start_with(join_handle).await?;
257
258        self.bind_addr = Some(addr);
259        Ok(())
260    }
261
262    fn name(&self) -> &str {
263        MYSQL_SERVER
264    }
265
266    fn bind_addr(&self) -> Option<SocketAddr> {
267        self.bind_addr
268    }
269
270    fn as_any(&self) -> &dyn Any {
271        self
272    }
273}