servers/
tls.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fs::File;
16use std::io::{BufReader, Error as IoError, ErrorKind};
17use std::path::Path;
18use std::sync::atomic::{AtomicUsize, Ordering};
19use std::sync::mpsc::channel;
20use std::sync::{Arc, RwLock};
21
22use common_telemetry::{error, info};
23use notify::{EventKind, RecursiveMode, Watcher};
24use rustls::ServerConfig;
25use rustls_pemfile::{certs, read_one, Item};
26use rustls_pki_types::{CertificateDer, PrivateKeyDer};
27use serde::{Deserialize, Serialize};
28use snafu::ResultExt;
29use strum::EnumString;
30
31use crate::error::{FileWatchSnafu, InternalIoSnafu, Result};
32
/// TlsMode is used for Mysql and Postgres server start up.
///
/// serde (de)serializes the variants with snake_case names (`disable`, `prefer`,
/// `require`, `verify_ca`, `verify_full`), whereas the `strum`-derived `FromStr`
/// parses the strings given in each `to_string` attribute (for example `verify-ca`).
#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq, Eq, EnumString)]
#[serde(rename_all = "snake_case")]
pub enum TlsMode {
    /// TLS is not used; this is the default mode.
    #[default]
    #[strum(to_string = "disable")]
    Disable,

    /// TLS is set up but not enforced for clients.
    #[strum(to_string = "prefer")]
    Prefer,

    /// TLS is set up and enforced for clients.
    #[strum(to_string = "require")]
    Require,

    // TODO(SSebo): Implement the following 2 TLS modes described in
    // ["34.19.3. Protection Provided in Different Modes"](https://www.postgresql.org/docs/current/libpq-ssl.html)
    // Until then both enforce TLS exactly like `Require`: no client certificate
    // verification is configured for any mode.
    #[strum(to_string = "verify-ca")]
    VerifyCa,

    #[strum(to_string = "verify-full")]
    VerifyFull,
}
55
/// TLS configuration: the mode plus the PEM certificate and key files used to
/// build a rustls [`ServerConfig`].
#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub struct TlsOption {
    /// TLS mode. It has no `#[serde(default)]`, so it must be present when
    /// deserializing.
    pub mode: TlsMode,
    /// Path of the PEM file holding the server certificate chain (empty if omitted).
    #[serde(default)]
    pub cert_path: String,
    /// Path of the PEM file holding the server private key (empty if omitted).
    #[serde(default)]
    pub key_path: String,
    /// Reload the certificate and key when their files change. Ignored when
    /// `mode` is `Disable`.
    #[serde(default)]
    pub watch: bool,
}
67
68impl TlsOption {
69    pub fn new(mode: Option<TlsMode>, cert_path: Option<String>, key_path: Option<String>) -> Self {
70        let mut tls_option = TlsOption::default();
71
72        if let Some(mode) = mode {
73            tls_option.mode = mode
74        };
75
76        if let Some(cert_path) = cert_path {
77            tls_option.cert_path = cert_path
78        };
79
80        if let Some(key_path) = key_path {
81            tls_option.key_path = key_path
82        };
83
84        tls_option
85    }
86
87    pub fn setup(&self) -> Result<Option<ServerConfig>> {
88        if let TlsMode::Disable = self.mode {
89            return Ok(None);
90        }
91        let cert = certs(&mut BufReader::new(
92            File::open(&self.cert_path)
93                .inspect_err(|e| error!(e; "Failed to open {}", self.cert_path))
94                .context(InternalIoSnafu)?,
95        ))
96        .collect::<std::result::Result<Vec<CertificateDer>, IoError>>()
97        .context(InternalIoSnafu)?;
98
99        let mut key_reader = BufReader::new(
100            File::open(&self.key_path)
101                .inspect_err(|e| error!(e; "Failed to open {}", self.key_path))
102                .context(InternalIoSnafu)?,
103        );
104        let key = match read_one(&mut key_reader)
105            .inspect_err(|e| error!(e; "Failed to read {}", self.key_path))
106            .context(InternalIoSnafu)?
107        {
108            Some(Item::Pkcs1Key(key)) => PrivateKeyDer::from(key),
109            Some(Item::Pkcs8Key(key)) => PrivateKeyDer::from(key),
110            Some(Item::Sec1Key(key)) => PrivateKeyDer::from(key),
111            _ => {
112                return Err(IoError::new(ErrorKind::InvalidInput, "invalid key"))
113                    .context(InternalIoSnafu);
114            }
115        };
116
117        // TODO(SSebo): with_client_cert_verifier if TlsMode is Required.
118        let config = ServerConfig::builder()
119            .with_no_client_auth()
120            .with_single_cert(cert, key)
121            .map_err(|err| std::io::Error::new(ErrorKind::InvalidInput, err))?;
122
123        Ok(Some(config))
124    }
125
126    pub fn should_force_tls(&self) -> bool {
127        !matches!(self.mode, TlsMode::Disable | TlsMode::Prefer)
128    }
129
130    pub fn cert_path(&self) -> &Path {
131        Path::new(&self.cert_path)
132    }
133
134    pub fn key_path(&self) -> &Path {
135        Path::new(&self.key_path)
136    }
137
138    pub fn watch_enabled(&self) -> bool {
139        self.mode != TlsMode::Disable && self.watch
140    }
141}
142
/// A mutable container for TLS server config
///
/// This struct allows dynamic reloading of server certificates and keys
pub struct ReloadableTlsServerConfig {
    /// Options used to build, and later rebuild, the server config.
    tls_option: TlsOption,
    /// Currently active server config; `None` when TLS is disabled.
    config: RwLock<Option<Arc<ServerConfig>>>,
    /// Counter incremented after every successful reload (starts at 0).
    version: AtomicUsize,
}
151
152impl ReloadableTlsServerConfig {
153    /// Create server config by loading configuration from `TlsOption`
154    pub fn try_new(tls_option: TlsOption) -> Result<ReloadableTlsServerConfig> {
155        let server_config = tls_option.setup()?;
156        Ok(Self {
157            tls_option,
158            config: RwLock::new(server_config.map(Arc::new)),
159            version: AtomicUsize::new(0),
160        })
161    }
162
163    /// Reread server certificates and keys from file system.
164    pub fn reload(&self) -> Result<()> {
165        let server_config = self.tls_option.setup()?;
166        *self.config.write().unwrap() = server_config.map(Arc::new);
167        self.version.fetch_add(1, Ordering::Relaxed);
168        Ok(())
169    }
170
171    /// Get the server config hold by this container
172    pub fn get_server_config(&self) -> Option<Arc<ServerConfig>> {
173        self.config.read().unwrap().clone()
174    }
175
176    /// Get associated `TlsOption`
177    pub fn get_tls_option(&self) -> &TlsOption {
178        &self.tls_option
179    }
180
181    /// Get version of current config
182    ///
183    /// this version will auto increase when server config get reloaded.
184    pub fn get_version(&self) -> usize {
185        self.version.load(Ordering::Relaxed)
186    }
187}
188
189pub fn maybe_watch_tls_config(tls_server_config: Arc<ReloadableTlsServerConfig>) -> Result<()> {
190    if !tls_server_config.get_tls_option().watch_enabled() {
191        return Ok(());
192    }
193
194    let tls_server_config_for_watcher = tls_server_config.clone();
195
196    let (tx, rx) = channel::<notify::Result<notify::Event>>();
197    let mut watcher = notify::recommended_watcher(tx).context(FileWatchSnafu { path: "<none>" })?;
198
199    let cert_path = tls_server_config.get_tls_option().cert_path();
200    watcher
201        .watch(cert_path, RecursiveMode::NonRecursive)
202        .with_context(|_| FileWatchSnafu {
203            path: cert_path.display().to_string(),
204        })?;
205
206    let key_path = tls_server_config.get_tls_option().key_path();
207    watcher
208        .watch(key_path, RecursiveMode::NonRecursive)
209        .with_context(|_| FileWatchSnafu {
210            path: key_path.display().to_string(),
211        })?;
212
213    std::thread::spawn(move || {
214        let _watcher = watcher;
215        while let Ok(res) = rx.recv() {
216            if let Ok(event) = res {
217                match event.kind {
218                    EventKind::Modify(_) | EventKind::Create(_) => {
219                        info!("Detected TLS cert/key file change: {:?}", event);
220                        if let Err(err) = tls_server_config_for_watcher.reload() {
221                            error!(err; "Failed to reload TLS server config");
222                        } else {
223                            info!("Reloaded TLS cert/key file successfully.");
224                        }
225                    }
226                    _ => {}
227                }
228            }
229        }
230    });
231
232    Ok(())
233}
234
#[cfg(test)]
mod tests {
    use super::*;
    use crate::install_ring_crypto_provider;
    use crate::tls::TlsMode::Disable;

    // `TlsOption::new` only overrides the fields passed as `Some`; everything
    // else, including `watch`, keeps its default value.
    #[test]
    fn test_new_tls_option() {
        assert_eq!(TlsOption::default(), TlsOption::new(None, None, None));
        assert_eq!(
            TlsOption {
                mode: Disable,
                ..Default::default()
            },
            TlsOption::new(Some(Disable), None, None)
        );
        assert_eq!(
            TlsOption {
                mode: Disable,
                cert_path: "/path/to/cert_path".to_string(),
                key_path: "/path/to/key_path".to_string(),
                watch: false
            },
            TlsOption::new(
                Some(Disable),
                Some("/path/to/cert_path".to_string()),
                Some("/path/to/key_path".to_string())
            )
        );
    }

    // With `disable`, the paths may be omitted and `setup` yields no server config
    // without touching the file system.
    #[test]
    fn test_tls_option_disable() {
        let s = r#"
        {
            "mode": "disable"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(!t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::Disable));
        assert!(t.key_path.is_empty());
        assert!(t.cert_path.is_empty());
        assert!(!t.watch_enabled());

        let setup = t.setup();
        let setup = setup.unwrap();
        assert!(setup.is_none());
    }

    // `prefer` does not force TLS; `watch` defaults to off.
    #[test]
    fn test_tls_option_prefer() {
        let s = r#"
        {
            "mode": "prefer",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(!t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::Prefer));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(!t.watch_enabled());
    }

    // `require` forces TLS; `watch` defaults to off.
    #[test]
    fn test_tls_option_require() {
        let s = r#"
        {
            "mode": "require",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::Require));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(!t.watch_enabled());
    }

    // serde spells this mode `verify_ca` (snake_case); it forces TLS.
    #[test]
    fn test_tls_option_verify_ca() {
        let s = r#"
        {
            "mode": "verify_ca",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::VerifyCa));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(!t.watch_enabled());
    }

    // serde spells this mode `verify_full` (snake_case); it forces TLS.
    #[test]
    fn test_tls_option_verify_full() {
        let s = r#"
        {
            "mode": "verify_full",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::VerifyFull));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(!t.watch_enabled());
    }

    // An explicit `"watch": true` with a non-disabled mode enables watching.
    #[test]
    fn test_tls_option_watch_enabled() {
        let s = r#"
        {
            "mode": "verify_full",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key",
            "watch": true
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::VerifyFull));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(t.watch_enabled());
    }

    // End-to-end reload check: copy the fixture cert and RSA key into a temp dir,
    // start the watcher, then replace the key with a PKCS#8 key (copy to a temp
    // file, then rename over the watched path) and poll up to 30 x 100 ms for the
    // config version to be bumped by a reload.
    #[test]
    fn test_tls_file_change_watch() {
        common_telemetry::init_default_ut_logging();
        let _ = install_ring_crypto_provider();

        let dir = tempfile::tempdir().unwrap();
        let cert_path = dir.path().join("server.crt");
        let key_path = dir.path().join("server.key");

        std::fs::copy("tests/ssl/server.crt", &cert_path).expect("failed to copy cert to tmpdir");
        std::fs::copy("tests/ssl/server-rsa.key", &key_path).expect("failed to copy key to tmpdir");

        assert!(std::fs::exists(&cert_path).unwrap());
        assert!(std::fs::exists(&key_path).unwrap());

        let server_tls = TlsOption {
            mode: TlsMode::Require,
            cert_path: cert_path
                .clone()
                .into_os_string()
                .into_string()
                .expect("failed to convert path to string"),
            key_path: key_path
                .clone()
                .into_os_string()
                .into_string()
                .expect("failed to convert path to string"),
            watch: true,
        };

        let server_config = Arc::new(
            ReloadableTlsServerConfig::try_new(server_tls).expect("failed to create server config"),
        );
        maybe_watch_tls_config(server_config.clone()).expect("failed to watch server config");

        assert_eq!(0, server_config.get_version());
        assert!(server_config.get_server_config().is_some());

        // Replace the key via rename so the watched path never holds a partial file.
        let tmp_file = key_path.with_extension("tmp");
        std::fs::copy("tests/ssl/server-pkcs8.key", &tmp_file)
            .expect("Failed to copy temp key file");
        std::fs::rename(&tmp_file, &key_path).expect("Failed to rename temp key file");

        const MAX_RETRIES: usize = 30;
        let mut retries = 0;
        let mut version_updated = false;

        // File-change notifications are asynchronous, so poll instead of asserting immediately.
        while retries < MAX_RETRIES {
            if server_config.get_version() > 0 {
                version_updated = true;
                break;
            }
            std::thread::sleep(std::time::Duration::from_millis(100));
            retries += 1;
        }

        assert!(version_updated, "TLS config did not reload in time");
        assert!(server_config.get_version() > 0);
        assert!(server_config.get_server_config().is_some());
    }
}