servers/
tls.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fs::File;
16use std::io::{BufReader, Error as IoError, ErrorKind};
17use std::path::Path;
18use std::sync::atomic::{AtomicUsize, Ordering};
19use std::sync::mpsc::channel;
20use std::sync::{Arc, RwLock};
21
22use common_telemetry::{error, info};
23use notify::{EventKind, RecursiveMode, Watcher};
24use rustls::ServerConfig;
25use rustls_pemfile::{certs, read_one, Item};
26use rustls_pki_types::{CertificateDer, PrivateKeyDer};
27use serde::{Deserialize, Serialize};
28use snafu::ResultExt;
29use strum::EnumString;
30
31use crate::error::{FileWatchSnafu, InternalIoSnafu, Result};
32
/// TLS mode used when starting up the MySQL and PostgreSQL servers.
///
/// Note that the two derived parsers use different spellings: serde
/// (`rename_all = "snake_case"`) expects names such as `verify_ca`, while the
/// strum-derived `FromStr` expects the hyphenated `verify-ca` / `verify-full`.
#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq, Eq, EnumString)]
#[serde(rename_all = "snake_case")]
pub enum TlsMode {
    /// TLS is turned off; this is the default mode.
    #[default]
    #[strum(to_string = "disable")]
    Disable,

    /// A server TLS config is built, but `TlsOption::should_force_tls`
    /// returns `false` for this mode.
    #[strum(to_string = "prefer")]
    Prefer,

    /// A server TLS config is built and `TlsOption::should_force_tls`
    /// returns `true` for this mode.
    #[strum(to_string = "require")]
    Require,

    // TODO(SSebo): Implement the following 2 TLS modes described in
    // ["34.19.3. Protection Provided in Different Modes"](https://www.postgresql.org/docs/current/libpq-ssl.html)
    // NOTE: `TlsOption::setup` currently builds the same no-client-auth config
    // for these modes as for `Require`.
    #[strum(to_string = "verify-ca")]
    VerifyCa,

    #[strum(to_string = "verify-full")]
    VerifyFull,
}
55
/// TLS settings: the mode plus the PEM file locations used to build a rustls
/// server config (see `TlsOption::setup`).
#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub struct TlsOption {
    /// Whether and how TLS is used. Has no `serde(default)`, so it must be
    /// present when deserializing.
    pub mode: TlsMode,
    /// Path to the PEM-encoded server certificate chain.
    #[serde(default)]
    pub cert_path: String,
    /// Path to the PEM file holding the server private key.
    #[serde(default)]
    pub key_path: String,
    /// Path to a CA certificate.
    /// NOTE(review): not read by `setup`; presumably reserved for client
    /// certificate verification — confirm.
    #[serde(default)]
    pub ca_cert_path: String,
    /// Reload the cert/key when their files change. Only effective when
    /// `mode` is not `Disable` (see `watch_enabled`).
    #[serde(default)]
    pub watch: bool,
}
69
70impl TlsOption {
71    pub fn new(mode: Option<TlsMode>, cert_path: Option<String>, key_path: Option<String>) -> Self {
72        let mut tls_option = TlsOption::default();
73
74        if let Some(mode) = mode {
75            tls_option.mode = mode
76        };
77
78        if let Some(cert_path) = cert_path {
79            tls_option.cert_path = cert_path
80        };
81
82        if let Some(key_path) = key_path {
83            tls_option.key_path = key_path
84        };
85
86        tls_option
87    }
88
89    pub fn setup(&self) -> Result<Option<ServerConfig>> {
90        if let TlsMode::Disable = self.mode {
91            return Ok(None);
92        }
93        let cert = certs(&mut BufReader::new(
94            File::open(&self.cert_path)
95                .inspect_err(|e| error!(e; "Failed to open {}", self.cert_path))
96                .context(InternalIoSnafu)?,
97        ))
98        .collect::<std::result::Result<Vec<CertificateDer>, IoError>>()
99        .context(InternalIoSnafu)?;
100
101        let mut key_reader = BufReader::new(
102            File::open(&self.key_path)
103                .inspect_err(|e| error!(e; "Failed to open {}", self.key_path))
104                .context(InternalIoSnafu)?,
105        );
106        let key = match read_one(&mut key_reader)
107            .inspect_err(|e| error!(e; "Failed to read {}", self.key_path))
108            .context(InternalIoSnafu)?
109        {
110            Some(Item::Pkcs1Key(key)) => PrivateKeyDer::from(key),
111            Some(Item::Pkcs8Key(key)) => PrivateKeyDer::from(key),
112            Some(Item::Sec1Key(key)) => PrivateKeyDer::from(key),
113            _ => {
114                return Err(IoError::new(ErrorKind::InvalidInput, "invalid key"))
115                    .context(InternalIoSnafu);
116            }
117        };
118
119        // TODO(SSebo): with_client_cert_verifier if TlsMode is Required.
120        let config = ServerConfig::builder()
121            .with_no_client_auth()
122            .with_single_cert(cert, key)
123            .map_err(|err| std::io::Error::new(ErrorKind::InvalidInput, err))?;
124
125        Ok(Some(config))
126    }
127
128    pub fn should_force_tls(&self) -> bool {
129        !matches!(self.mode, TlsMode::Disable | TlsMode::Prefer)
130    }
131
132    pub fn cert_path(&self) -> &Path {
133        Path::new(&self.cert_path)
134    }
135
136    pub fn key_path(&self) -> &Path {
137        Path::new(&self.key_path)
138    }
139
140    pub fn watch_enabled(&self) -> bool {
141        self.mode != TlsMode::Disable && self.watch
142    }
143}
144
/// A mutable container for a TLS server config.
///
/// This struct allows dynamic reloading of server certificates and keys:
/// `reload` rebuilds the config from `tls_option` and swaps it in under a lock.
pub struct ReloadableTlsServerConfig {
    /// Settings the config is (re)built from.
    tls_option: TlsOption,
    /// Current config; `None` when the mode is `Disable`.
    config: RwLock<Option<Arc<ServerConfig>>>,
    /// Number of successful reloads since construction.
    version: AtomicUsize,
}
153
154impl ReloadableTlsServerConfig {
155    /// Create server config by loading configuration from `TlsOption`
156    pub fn try_new(tls_option: TlsOption) -> Result<ReloadableTlsServerConfig> {
157        let server_config = tls_option.setup()?;
158        Ok(Self {
159            tls_option,
160            config: RwLock::new(server_config.map(Arc::new)),
161            version: AtomicUsize::new(0),
162        })
163    }
164
165    /// Reread server certificates and keys from file system.
166    pub fn reload(&self) -> Result<()> {
167        let server_config = self.tls_option.setup()?;
168        *self.config.write().unwrap() = server_config.map(Arc::new);
169        self.version.fetch_add(1, Ordering::Relaxed);
170        Ok(())
171    }
172
173    /// Get the server config hold by this container
174    pub fn get_server_config(&self) -> Option<Arc<ServerConfig>> {
175        self.config.read().unwrap().clone()
176    }
177
178    /// Get associated `TlsOption`
179    pub fn get_tls_option(&self) -> &TlsOption {
180        &self.tls_option
181    }
182
183    /// Get version of current config
184    ///
185    /// this version will auto increase when server config get reloaded.
186    pub fn get_version(&self) -> usize {
187        self.version.load(Ordering::Relaxed)
188    }
189}
190
191pub fn maybe_watch_tls_config(tls_server_config: Arc<ReloadableTlsServerConfig>) -> Result<()> {
192    if !tls_server_config.get_tls_option().watch_enabled() {
193        return Ok(());
194    }
195
196    let tls_server_config_for_watcher = tls_server_config.clone();
197
198    let (tx, rx) = channel::<notify::Result<notify::Event>>();
199    let mut watcher = notify::recommended_watcher(tx).context(FileWatchSnafu { path: "<none>" })?;
200
201    let cert_path = tls_server_config.get_tls_option().cert_path();
202    watcher
203        .watch(cert_path, RecursiveMode::NonRecursive)
204        .with_context(|_| FileWatchSnafu {
205            path: cert_path.display().to_string(),
206        })?;
207
208    let key_path = tls_server_config.get_tls_option().key_path();
209    watcher
210        .watch(key_path, RecursiveMode::NonRecursive)
211        .with_context(|_| FileWatchSnafu {
212            path: key_path.display().to_string(),
213        })?;
214
215    std::thread::spawn(move || {
216        let _watcher = watcher;
217        while let Ok(res) = rx.recv() {
218            if let Ok(event) = res {
219                match event.kind {
220                    EventKind::Modify(_) | EventKind::Create(_) => {
221                        info!("Detected TLS cert/key file change: {:?}", event);
222                        if let Err(err) = tls_server_config_for_watcher.reload() {
223                            error!(err; "Failed to reload TLS server config");
224                        } else {
225                            info!("Reloaded TLS cert/key file successfully.");
226                        }
227                    }
228                    _ => {}
229                }
230            }
231        }
232    });
233
234    Ok(())
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use crate::install_ring_crypto_provider;
    use crate::tls::TlsMode::Disable;

    /// `TlsOption::new` keeps defaults for `None` arguments and applies any
    /// provided overrides.
    #[test]
    fn test_new_tls_option() {
        assert_eq!(TlsOption::default(), TlsOption::new(None, None, None));
        assert_eq!(
            TlsOption {
                mode: Disable,
                ..Default::default()
            },
            TlsOption::new(Some(Disable), None, None)
        );
        assert_eq!(
            TlsOption {
                mode: Disable,
                cert_path: "/path/to/cert_path".to_string(),
                key_path: "/path/to/key_path".to_string(),
                ca_cert_path: String::new(),
                watch: false
            },
            TlsOption::new(
                Some(Disable),
                Some("/path/to/cert_path".to_string()),
                Some("/path/to/key_path".to_string())
            )
        );
    }

    /// `disable` mode: TLS not forced, paths default to empty, and `setup`
    /// yields no config without touching the file system.
    #[test]
    fn test_tls_option_disable() {
        let s = r#"
        {
            "mode": "disable"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(!t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::Disable));
        assert!(t.key_path.is_empty());
        assert!(t.cert_path.is_empty());
        assert!(!t.watch_enabled());

        let setup = t.setup();
        let setup = setup.unwrap();
        assert!(setup.is_none());
    }

    /// `prefer` mode deserializes paths but does not force TLS.
    #[test]
    fn test_tls_option_prefer() {
        let s = r#"
        {
            "mode": "prefer",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(!t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::Prefer));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(!t.watch_enabled());
    }

    /// `require` mode forces TLS.
    #[test]
    fn test_tls_option_require() {
        let s = r#"
        {
            "mode": "require",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::Require));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(!t.watch_enabled());
    }

    /// `verify_ca` (serde spelling) forces TLS.
    #[test]
    fn test_tls_option_verify_ca() {
        let s = r#"
        {
            "mode": "verify_ca",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::VerifyCa));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(!t.watch_enabled());
    }

    /// `verify_full` (serde spelling) forces TLS.
    #[test]
    fn test_tls_option_verify_full() {
        let s = r#"
        {
            "mode": "verify_full",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::VerifyFull));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(!t.watch_enabled());
    }

    /// `"watch": true` with a non-disabled mode enables watching.
    #[test]
    fn test_tls_option_watch_enabled() {
        let s = r#"
        {
            "mode": "verify_full",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key",
            "watch": true
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::VerifyFull));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(t.watch_enabled());
    }

    /// End-to-end check that replacing the key file triggers a reload through
    /// `maybe_watch_tls_config`, observed as a version bump.
    #[test]
    fn test_tls_file_change_watch() {
        common_telemetry::init_default_ut_logging();
        // Result ignored: the provider may already be installed by another test.
        let _ = install_ring_crypto_provider();

        // Work on copies in a temp dir so the fixtures under tests/ssl are not modified.
        let dir = tempfile::tempdir().unwrap();
        let cert_path = dir.path().join("server.crt");
        let key_path = dir.path().join("server.key");

        std::fs::copy("tests/ssl/server.crt", &cert_path).expect("failed to copy cert to tmpdir");
        std::fs::copy("tests/ssl/server-rsa.key", &key_path).expect("failed to copy key to tmpdir");

        assert!(std::fs::exists(&cert_path).unwrap());
        assert!(std::fs::exists(&key_path).unwrap());

        let server_tls = TlsOption {
            mode: TlsMode::Require,
            cert_path: cert_path
                .clone()
                .into_os_string()
                .into_string()
                .expect("failed to convert path to string"),
            key_path: key_path
                .clone()
                .into_os_string()
                .into_string()
                .expect("failed to convert path to string"),
            ca_cert_path: String::new(),
            watch: true,
        };

        let server_config = Arc::new(
            ReloadableTlsServerConfig::try_new(server_tls).expect("failed to create server config"),
        );
        maybe_watch_tls_config(server_config.clone()).expect("failed to watch server config");

        assert_eq!(0, server_config.get_version());
        assert!(server_config.get_server_config().is_some());

        // Replace the key by writing a temp file and renaming it over the
        // original, so the key path is swapped in a single rename.
        let tmp_file = key_path.with_extension("tmp");
        std::fs::copy("tests/ssl/server-pkcs8.key", &tmp_file)
            .expect("Failed to copy temp key file");
        std::fs::rename(&tmp_file, &key_path).expect("Failed to rename temp key file");

        // Poll for up to ~3s (30 x 100ms) for the watcher thread to reload.
        const MAX_RETRIES: usize = 30;
        let mut retries = 0;
        let mut version_updated = false;

        while retries < MAX_RETRIES {
            if server_config.get_version() > 0 {
                version_updated = true;
                break;
            }
            std::thread::sleep(std::time::Duration::from_millis(100));
            retries += 1;
        }

        assert!(version_updated, "TLS config did not reload in time");
        assert!(server_config.get_version() > 0);
        assert!(server_config.get_server_config().is_some());
    }
}