servers/
tls.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fs::File;
16use std::io::{BufReader, Error as IoError, ErrorKind};
17use std::path::Path;
18use std::sync::atomic::{AtomicUsize, Ordering};
19use std::sync::mpsc::channel;
20use std::sync::{Arc, RwLock};
21
22use common_telemetry::{error, info};
23use notify::{EventKind, RecursiveMode, Watcher};
24use rustls::ServerConfig;
25use rustls_pemfile::{Item, certs, read_one};
26use rustls_pki_types::{CertificateDer, PrivateKeyDer};
27use serde::{Deserialize, Serialize};
28use snafu::ResultExt;
29use strum::EnumString;
30
31use crate::error::{FileWatchSnafu, InternalIoSnafu, Result};
32
33/// TlsMode is used for Mysql and Postgres server start up.
34#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq, Eq, EnumString)]
35#[serde(rename_all = "snake_case")]
36pub enum TlsMode {
37    #[default]
38    #[strum(to_string = "disable")]
39    Disable,
40
41    #[strum(to_string = "prefer")]
42    Prefer,
43
44    #[strum(to_string = "require")]
45    Require,
46
47    // TODO(SSebo): Implement the following 2 TSL mode described in
48    // ["34.19.3. Protection Provided in Different Modes"](https://www.postgresql.org/docs/current/libpq-ssl.html)
49    #[strum(to_string = "verify-ca")]
50    VerifyCa,
51
52    #[strum(to_string = "verify-full")]
53    VerifyFull,
54}
55
56#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq, Eq)]
57#[serde(rename_all = "snake_case")]
58pub struct TlsOption {
59    pub mode: TlsMode,
60    #[serde(default)]
61    pub cert_path: String,
62    #[serde(default)]
63    pub key_path: String,
64    #[serde(default)]
65    pub ca_cert_path: String,
66    #[serde(default)]
67    pub watch: bool,
68}
69
70impl TlsOption {
71    pub fn new(
72        mode: Option<TlsMode>,
73        cert_path: Option<String>,
74        key_path: Option<String>,
75        watch: bool,
76    ) -> Self {
77        let mut tls_option = TlsOption::default();
78
79        if let Some(mode) = mode {
80            tls_option.mode = mode
81        };
82
83        if let Some(cert_path) = cert_path {
84            tls_option.cert_path = cert_path
85        };
86
87        if let Some(key_path) = key_path {
88            tls_option.key_path = key_path
89        };
90
91        tls_option.watch = watch;
92
93        tls_option
94    }
95
96    pub fn setup(&self) -> Result<Option<ServerConfig>> {
97        if let TlsMode::Disable = self.mode {
98            return Ok(None);
99        }
100        let cert = certs(&mut BufReader::new(
101            File::open(&self.cert_path)
102                .inspect_err(|e| error!(e; "Failed to open {}", self.cert_path))
103                .context(InternalIoSnafu)?,
104        ))
105        .collect::<std::result::Result<Vec<CertificateDer>, IoError>>()
106        .context(InternalIoSnafu)?;
107
108        let mut key_reader = BufReader::new(
109            File::open(&self.key_path)
110                .inspect_err(|e| error!(e; "Failed to open {}", self.key_path))
111                .context(InternalIoSnafu)?,
112        );
113        let key = match read_one(&mut key_reader)
114            .inspect_err(|e| error!(e; "Failed to read {}", self.key_path))
115            .context(InternalIoSnafu)?
116        {
117            Some(Item::Pkcs1Key(key)) => PrivateKeyDer::from(key),
118            Some(Item::Pkcs8Key(key)) => PrivateKeyDer::from(key),
119            Some(Item::Sec1Key(key)) => PrivateKeyDer::from(key),
120            _ => {
121                return Err(IoError::new(ErrorKind::InvalidInput, "invalid key"))
122                    .context(InternalIoSnafu);
123            }
124        };
125
126        // TODO(SSebo): with_client_cert_verifier if TlsMode is Required.
127        let config = ServerConfig::builder()
128            .with_no_client_auth()
129            .with_single_cert(cert, key)
130            .map_err(|err| std::io::Error::new(ErrorKind::InvalidInput, err))?;
131
132        Ok(Some(config))
133    }
134
135    pub fn should_force_tls(&self) -> bool {
136        !matches!(self.mode, TlsMode::Disable | TlsMode::Prefer)
137    }
138
139    pub fn cert_path(&self) -> &Path {
140        Path::new(&self.cert_path)
141    }
142
143    pub fn key_path(&self) -> &Path {
144        Path::new(&self.key_path)
145    }
146
147    pub fn watch_enabled(&self) -> bool {
148        self.mode != TlsMode::Disable && self.watch
149    }
150}
151
152/// A mutable container for TLS server config
153///
154/// This struct allows dynamic reloading of server certificates and keys
155pub struct ReloadableTlsServerConfig {
156    tls_option: TlsOption,
157    config: RwLock<Option<Arc<ServerConfig>>>,
158    version: AtomicUsize,
159}
160
161impl ReloadableTlsServerConfig {
162    /// Create server config by loading configuration from `TlsOption`
163    pub fn try_new(tls_option: TlsOption) -> Result<ReloadableTlsServerConfig> {
164        let server_config = tls_option.setup()?;
165        Ok(Self {
166            tls_option,
167            config: RwLock::new(server_config.map(Arc::new)),
168            version: AtomicUsize::new(0),
169        })
170    }
171
172    /// Reread server certificates and keys from file system.
173    pub fn reload(&self) -> Result<()> {
174        let server_config = self.tls_option.setup()?;
175        *self.config.write().unwrap() = server_config.map(Arc::new);
176        self.version.fetch_add(1, Ordering::Relaxed);
177        Ok(())
178    }
179
180    /// Get the server config hold by this container
181    pub fn get_server_config(&self) -> Option<Arc<ServerConfig>> {
182        self.config.read().unwrap().clone()
183    }
184
185    /// Get associated `TlsOption`
186    pub fn get_tls_option(&self) -> &TlsOption {
187        &self.tls_option
188    }
189
190    /// Get version of current config
191    ///
192    /// this version will auto increase when server config get reloaded.
193    pub fn get_version(&self) -> usize {
194        self.version.load(Ordering::Relaxed)
195    }
196}
197
198pub fn maybe_watch_tls_config(tls_server_config: Arc<ReloadableTlsServerConfig>) -> Result<()> {
199    if !tls_server_config.get_tls_option().watch_enabled() {
200        return Ok(());
201    }
202
203    let tls_server_config_for_watcher = tls_server_config.clone();
204
205    let (tx, rx) = channel::<notify::Result<notify::Event>>();
206    let mut watcher = notify::recommended_watcher(tx).context(FileWatchSnafu { path: "<none>" })?;
207
208    let cert_path = tls_server_config.get_tls_option().cert_path();
209    watcher
210        .watch(cert_path, RecursiveMode::NonRecursive)
211        .with_context(|_| FileWatchSnafu {
212            path: cert_path.display().to_string(),
213        })?;
214
215    let key_path = tls_server_config.get_tls_option().key_path();
216    watcher
217        .watch(key_path, RecursiveMode::NonRecursive)
218        .with_context(|_| FileWatchSnafu {
219            path: key_path.display().to_string(),
220        })?;
221
222    std::thread::spawn(move || {
223        let _watcher = watcher;
224        while let Ok(res) = rx.recv() {
225            if let Ok(event) = res {
226                match event.kind {
227                    EventKind::Modify(_) | EventKind::Create(_) => {
228                        info!("Detected TLS cert/key file change: {:?}", event);
229                        if let Err(err) = tls_server_config_for_watcher.reload() {
230                            error!(err; "Failed to reload TLS server config");
231                        } else {
232                            info!("Reloaded TLS cert/key file successfully.");
233                        }
234                    }
235                    _ => {}
236                }
237            }
238        }
239    });
240
241    Ok(())
242}
243
// Unit tests for `TlsOption` parsing/construction and TLS hot reloading.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::install_ring_crypto_provider;
    use crate::tls::TlsMode::Disable;

    // `TlsOption::new` falls back to the default value for every `None` argument.
    #[test]
    fn test_new_tls_option() {
        assert_eq!(
            TlsOption::default(),
            TlsOption::new(None, None, None, false)
        );
        assert_eq!(
            TlsOption {
                mode: Disable,
                ..Default::default()
            },
            TlsOption::new(Some(Disable), None, None, false)
        );
        assert_eq!(
            TlsOption {
                mode: Disable,
                cert_path: "/path/to/cert_path".to_string(),
                key_path: "/path/to/key_path".to_string(),
                ca_cert_path: String::new(),
                watch: false
            },
            TlsOption::new(
                Some(Disable),
                Some("/path/to/cert_path".to_string()),
                Some("/path/to/key_path".to_string()),
                false
            )
        );
    }

    // With TLS disabled the paths default to empty and `setup` returns no
    // config, because it returns before opening any file.
    #[test]
    fn test_tls_option_disable() {
        let s = r#"
        {
            "mode": "disable"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(!t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::Disable));
        assert!(t.key_path.is_empty());
        assert!(t.cert_path.is_empty());
        assert!(!t.watch_enabled());

        let setup = t.setup();
        let setup = setup.unwrap();
        assert!(setup.is_none());
    }

    // "prefer" offers TLS without forcing it. Only parsing is checked here;
    // `setup` is not called, so the example paths need not exist.
    #[test]
    fn test_tls_option_prefer() {
        let s = r#"
        {
            "mode": "prefer",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(!t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::Prefer));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(!t.watch_enabled());
    }

    // "require" forces TLS; `watch` defaults to false when omitted.
    #[test]
    fn test_tls_option_require() {
        let s = r#"
        {
            "mode": "require",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::Require));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(!t.watch_enabled());
    }

    // serde uses the snake_case spelling "verify_ca" (not the hyphenated one).
    #[test]
    fn test_tls_option_verify_ca() {
        let s = r#"
        {
            "mode": "verify_ca",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::VerifyCa));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(!t.watch_enabled());
    }

    // serde uses the snake_case spelling "verify_full" (not the hyphenated one).
    #[test]
    fn test_tls_option_verify_full() {
        let s = r#"
        {
            "mode": "verify_full",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::VerifyFull));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(!t.watch_enabled());
    }

    // `watch: true` together with a non-disabled mode enables watching.
    #[test]
    fn test_tls_option_watch_enabled() {
        let s = r#"
        {
            "mode": "verify_full",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key",
            "watch": true
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::VerifyFull));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(t.watch_enabled());
    }

    // End-to-end hot reload: load a cert/key from a temp dir, start watching,
    // atomically replace the key file, and wait for the version to be bumped.
    #[test]
    fn test_tls_file_change_watch() {
        common_telemetry::init_default_ut_logging();
        // Building a rustls ServerConfig needs a crypto provider; a second
        // install attempt from another test is expected to fail, hence `let _`.
        let _ = install_ring_crypto_provider();

        let dir = tempfile::tempdir().unwrap();
        let cert_path = dir.path().join("server.crt");
        let key_path = dir.path().join("server.key");

        // Fixtures are resolved relative to the crate directory the test runs in.
        std::fs::copy("tests/ssl/server.crt", &cert_path).expect("failed to copy cert to tmpdir");
        std::fs::copy("tests/ssl/server-rsa.key", &key_path).expect("failed to copy key to tmpdir");

        assert!(std::fs::exists(&cert_path).unwrap());
        assert!(std::fs::exists(&key_path).unwrap());

        let server_tls = TlsOption {
            mode: TlsMode::Require,
            cert_path: cert_path
                .clone()
                .into_os_string()
                .into_string()
                .expect("failed to convert path to string"),
            key_path: key_path
                .clone()
                .into_os_string()
                .into_string()
                .expect("failed to convert path to string"),
            ca_cert_path: String::new(),
            watch: true,
        };

        let server_config = Arc::new(
            ReloadableTlsServerConfig::try_new(server_tls).expect("failed to create server config"),
        );
        maybe_watch_tls_config(server_config.clone()).expect("failed to watch server config");

        assert_eq!(0, server_config.get_version());
        assert!(server_config.get_server_config().is_some());

        // Replace the key atomically: write a temp file, then rename it over
        // the original path (the pattern used by many deployment tools).
        let tmp_file = key_path.with_extension("tmp");
        std::fs::copy("tests/ssl/server-pkcs8.key", &tmp_file)
            .expect("Failed to copy temp key file");
        std::fs::rename(&tmp_file, &key_path).expect("Failed to rename temp key file");

        // Poll for up to ~3 seconds (30 x 100 ms) for the watcher thread to reload.
        const MAX_RETRIES: usize = 30;
        let mut retries = 0;
        let mut version_updated = false;

        while retries < MAX_RETRIES {
            if server_config.get_version() > 0 {
                version_updated = true;
                break;
            }
            std::thread::sleep(std::time::Duration::from_millis(100));
            retries += 1;
        }

        assert!(version_updated, "TLS config did not reload in time");
        assert!(server_config.get_version() > 0);
        assert!(server_config.get_server_config().is_some());
    }
}
464}