servers/
tls.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fs::File;
16use std::io::{BufReader, Error as IoError, ErrorKind};
17use std::path::Path;
18use std::sync::atomic::{AtomicUsize, Ordering};
19use std::sync::mpsc::channel;
20use std::sync::{Arc, RwLock};
21
22use common_telemetry::{error, info};
23use notify::{EventKind, RecursiveMode, Watcher};
24use rustls::ServerConfig;
25use rustls_pemfile::{certs, read_one, Item};
26use rustls_pki_types::{CertificateDer, PrivateKeyDer};
27use serde::{Deserialize, Serialize};
28use snafu::ResultExt;
29use strum::EnumString;
30
31use crate::error::{FileWatchSnafu, InternalIoSnafu, Result};
32
33/// TlsMode is used for Mysql and Postgres server start up.
34#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq, Eq, EnumString)]
35#[serde(rename_all = "snake_case")]
36pub enum TlsMode {
37    #[default]
38    #[strum(to_string = "disable")]
39    Disable,
40
41    #[strum(to_string = "prefer")]
42    Prefer,
43
44    #[strum(to_string = "require")]
45    Require,
46
47    // TODO(SSebo): Implement the following 2 TSL mode described in
48    // ["34.19.3. Protection Provided in Different Modes"](https://www.postgresql.org/docs/current/libpq-ssl.html)
49    #[strum(to_string = "verify-ca")]
50    VerifyCa,
51
52    #[strum(to_string = "verify-full")]
53    VerifyFull,
54}
55
56#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq, Eq)]
57#[serde(rename_all = "snake_case")]
58pub struct TlsOption {
59    pub mode: TlsMode,
60    #[serde(default)]
61    pub cert_path: String,
62    #[serde(default)]
63    pub key_path: String,
64    #[serde(default)]
65    pub watch: bool,
66}
67
68impl TlsOption {
69    pub fn new(mode: Option<TlsMode>, cert_path: Option<String>, key_path: Option<String>) -> Self {
70        let mut tls_option = TlsOption::default();
71
72        if let Some(mode) = mode {
73            tls_option.mode = mode
74        };
75
76        if let Some(cert_path) = cert_path {
77            tls_option.cert_path = cert_path
78        };
79
80        if let Some(key_path) = key_path {
81            tls_option.key_path = key_path
82        };
83
84        tls_option
85    }
86
87    pub fn setup(&self) -> Result<Option<ServerConfig>> {
88        if let TlsMode::Disable = self.mode {
89            return Ok(None);
90        }
91        let cert = certs(&mut BufReader::new(
92            File::open(&self.cert_path).context(InternalIoSnafu)?,
93        ))
94        .collect::<std::result::Result<Vec<CertificateDer>, IoError>>()
95        .context(InternalIoSnafu)?;
96
97        let mut key_reader = BufReader::new(File::open(&self.key_path).context(InternalIoSnafu)?);
98        let key = match read_one(&mut key_reader).context(InternalIoSnafu)? {
99            Some(Item::Pkcs1Key(key)) => PrivateKeyDer::from(key),
100            Some(Item::Pkcs8Key(key)) => PrivateKeyDer::from(key),
101            Some(Item::Sec1Key(key)) => PrivateKeyDer::from(key),
102            _ => {
103                return Err(IoError::new(ErrorKind::InvalidInput, "invalid key"))
104                    .context(InternalIoSnafu);
105            }
106        };
107
108        // TODO(SSebo): with_client_cert_verifier if TlsMode is Required.
109        let config = ServerConfig::builder()
110            .with_no_client_auth()
111            .with_single_cert(cert, key)
112            .map_err(|err| std::io::Error::new(ErrorKind::InvalidInput, err))?;
113
114        Ok(Some(config))
115    }
116
117    pub fn should_force_tls(&self) -> bool {
118        !matches!(self.mode, TlsMode::Disable | TlsMode::Prefer)
119    }
120
121    pub fn cert_path(&self) -> &Path {
122        Path::new(&self.cert_path)
123    }
124
125    pub fn key_path(&self) -> &Path {
126        Path::new(&self.key_path)
127    }
128
129    pub fn watch_enabled(&self) -> bool {
130        self.mode != TlsMode::Disable && self.watch
131    }
132}
133
134/// A mutable container for TLS server config
135///
136/// This struct allows dynamic reloading of server certificates and keys
137pub struct ReloadableTlsServerConfig {
138    tls_option: TlsOption,
139    config: RwLock<Option<Arc<ServerConfig>>>,
140    version: AtomicUsize,
141}
142
143impl ReloadableTlsServerConfig {
144    /// Create server config by loading configuration from `TlsOption`
145    pub fn try_new(tls_option: TlsOption) -> Result<ReloadableTlsServerConfig> {
146        let server_config = tls_option.setup()?;
147        Ok(Self {
148            tls_option,
149            config: RwLock::new(server_config.map(Arc::new)),
150            version: AtomicUsize::new(0),
151        })
152    }
153
154    /// Reread server certificates and keys from file system.
155    pub fn reload(&self) -> Result<()> {
156        let server_config = self.tls_option.setup()?;
157        *self.config.write().unwrap() = server_config.map(Arc::new);
158        self.version.fetch_add(1, Ordering::Relaxed);
159        Ok(())
160    }
161
162    /// Get the server config hold by this container
163    pub fn get_server_config(&self) -> Option<Arc<ServerConfig>> {
164        self.config.read().unwrap().clone()
165    }
166
167    /// Get associated `TlsOption`
168    pub fn get_tls_option(&self) -> &TlsOption {
169        &self.tls_option
170    }
171
172    /// Get version of current config
173    ///
174    /// this version will auto increase when server config get reloaded.
175    pub fn get_version(&self) -> usize {
176        self.version.load(Ordering::Relaxed)
177    }
178}
179
180pub fn maybe_watch_tls_config(tls_server_config: Arc<ReloadableTlsServerConfig>) -> Result<()> {
181    if !tls_server_config.get_tls_option().watch_enabled() {
182        return Ok(());
183    }
184
185    let tls_server_config_for_watcher = tls_server_config.clone();
186
187    let (tx, rx) = channel::<notify::Result<notify::Event>>();
188    let mut watcher = notify::recommended_watcher(tx).context(FileWatchSnafu { path: "<none>" })?;
189
190    let cert_path = tls_server_config.get_tls_option().cert_path();
191    watcher
192        .watch(cert_path, RecursiveMode::NonRecursive)
193        .with_context(|_| FileWatchSnafu {
194            path: cert_path.display().to_string(),
195        })?;
196
197    let key_path = tls_server_config.get_tls_option().key_path();
198    watcher
199        .watch(key_path, RecursiveMode::NonRecursive)
200        .with_context(|_| FileWatchSnafu {
201            path: key_path.display().to_string(),
202        })?;
203
204    std::thread::spawn(move || {
205        let _watcher = watcher;
206        while let Ok(res) = rx.recv() {
207            if let Ok(event) = res {
208                match event.kind {
209                    EventKind::Modify(_) | EventKind::Create(_) => {
210                        info!("Detected TLS cert/key file change: {:?}", event);
211                        if let Err(err) = tls_server_config_for_watcher.reload() {
212                            error!(err; "Failed to reload TLS server config");
213                        }
214                    }
215                    _ => {}
216                }
217            }
218        }
219    });
220
221    Ok(())
222}
223
#[cfg(test)]
mod tests {
    use std::time::{Duration, Instant};

    use super::*;
    use crate::install_ring_crypto_provider;
    use crate::tls::TlsMode::Disable;

    #[test]
    fn test_new_tls_option() {
        assert_eq!(TlsOption::default(), TlsOption::new(None, None, None));
        assert_eq!(
            TlsOption {
                mode: Disable,
                ..Default::default()
            },
            TlsOption::new(Some(Disable), None, None)
        );
        assert_eq!(
            TlsOption {
                mode: Disable,
                cert_path: "/path/to/cert_path".to_string(),
                key_path: "/path/to/key_path".to_string(),
                watch: false
            },
            TlsOption::new(
                Some(Disable),
                Some("/path/to/cert_path".to_string()),
                Some("/path/to/key_path".to_string())
            )
        );
    }

    #[test]
    fn test_tls_option_disable() {
        let s = r#"
        {
            "mode": "disable"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(!t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::Disable));
        assert!(t.key_path.is_empty());
        assert!(t.cert_path.is_empty());
        assert!(!t.watch_enabled());

        let setup = t.setup();
        let setup = setup.unwrap();
        assert!(setup.is_none());
    }

    #[test]
    fn test_tls_option_prefer() {
        let s = r#"
        {
            "mode": "prefer",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(!t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::Prefer));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(!t.watch_enabled());
    }

    #[test]
    fn test_tls_option_require() {
        let s = r#"
        {
            "mode": "require",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::Require));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(!t.watch_enabled());
    }

    #[test]
    fn test_tls_option_verify_ca() {
        let s = r#"
        {
            "mode": "verify_ca",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::VerifyCa));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(!t.watch_enabled());
    }

    #[test]
    fn test_tls_option_verify_full() {
        let s = r#"
        {
            "mode": "verify_full",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::VerifyFull));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(!t.watch_enabled());
    }

    #[test]
    fn test_tls_option_watch_enabled() {
        let s = r#"
        {
            "mode": "verify_full",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key",
            "watch": true
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::VerifyFull));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(t.watch_enabled());
    }

    #[test]
    fn test_tls_file_change_watch() {
        common_telemetry::init_default_ut_logging();
        let _ = install_ring_crypto_provider();

        let dir = tempfile::tempdir().unwrap();
        let cert_path = dir.path().join("server.crt");
        let key_path = dir.path().join("server.key");

        std::fs::copy("tests/ssl/server.crt", &cert_path).expect("failed to copy cert to tmpdir");
        std::fs::copy("tests/ssl/server-rsa.key", &key_path).expect("failed to copy key to tmpdir");

        let server_tls = TlsOption {
            mode: TlsMode::Require,
            cert_path: cert_path
                .clone()
                .into_os_string()
                .into_string()
                .expect("failed to convert path to string"),
            key_path: key_path
                .clone()
                .into_os_string()
                .into_string()
                .expect("failed to convert path to string"),
            watch: true,
        };

        let server_config = Arc::new(
            ReloadableTlsServerConfig::try_new(server_tls).expect("failed to create server config"),
        );
        maybe_watch_tls_config(server_config.clone()).expect("failed to watch server config");

        assert_eq!(0, server_config.get_version());
        assert!(server_config.get_server_config().is_some());

        std::fs::copy("tests/ssl/server-pkcs8.key", &key_path)
            .expect("failed to copy key to tmpdir");

        // Reloads happen asynchronously on the watcher thread. Poll until at least one
        // has completed rather than sleeping a fixed time and relying on the platform
        // emitting a particular number of modify events.
        let deadline = Instant::now() + Duration::from_secs(5);
        while server_config.get_version() == 0 && Instant::now() < deadline {
            std::thread::sleep(Duration::from_millis(10));
        }

        assert!(server_config.get_version() > 0);
        assert!(server_config.get_server_config().is_some());
    }
}