1use std::fs::File;
16use std::io::{BufReader, Error as IoError, ErrorKind};
17use std::path::Path;
18use std::sync::atomic::{AtomicUsize, Ordering};
19use std::sync::mpsc::channel;
20use std::sync::{Arc, RwLock};
21
22use common_telemetry::{error, info};
23use notify::{EventKind, RecursiveMode, Watcher};
24use rustls::ServerConfig;
25use rustls_pemfile::{certs, read_one, Item};
26use rustls_pki_types::{CertificateDer, PrivateKeyDer};
27use serde::{Deserialize, Serialize};
28use snafu::ResultExt;
29use strum::EnumString;
30
31use crate::error::{FileWatchSnafu, InternalIoSnafu, Result};
32
33#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq, Eq, EnumString)]
35#[serde(rename_all = "snake_case")]
36pub enum TlsMode {
37 #[default]
38 #[strum(to_string = "disable")]
39 Disable,
40
41 #[strum(to_string = "prefer")]
42 Prefer,
43
44 #[strum(to_string = "require")]
45 Require,
46
47 #[strum(to_string = "verify-ca")]
50 VerifyCa,
51
52 #[strum(to_string = "verify-full")]
53 VerifyFull,
54}
55
56#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq, Eq)]
57#[serde(rename_all = "snake_case")]
58pub struct TlsOption {
59 pub mode: TlsMode,
60 #[serde(default)]
61 pub cert_path: String,
62 #[serde(default)]
63 pub key_path: String,
64 #[serde(default)]
65 pub watch: bool,
66}
67
68impl TlsOption {
69 pub fn new(mode: Option<TlsMode>, cert_path: Option<String>, key_path: Option<String>) -> Self {
70 let mut tls_option = TlsOption::default();
71
72 if let Some(mode) = mode {
73 tls_option.mode = mode
74 };
75
76 if let Some(cert_path) = cert_path {
77 tls_option.cert_path = cert_path
78 };
79
80 if let Some(key_path) = key_path {
81 tls_option.key_path = key_path
82 };
83
84 tls_option
85 }
86
87 pub fn setup(&self) -> Result<Option<ServerConfig>> {
88 if let TlsMode::Disable = self.mode {
89 return Ok(None);
90 }
91 let cert = certs(&mut BufReader::new(
92 File::open(&self.cert_path)
93 .inspect_err(|e| error!(e; "Failed to open {}", self.cert_path))
94 .context(InternalIoSnafu)?,
95 ))
96 .collect::<std::result::Result<Vec<CertificateDer>, IoError>>()
97 .context(InternalIoSnafu)?;
98
99 let mut key_reader = BufReader::new(
100 File::open(&self.key_path)
101 .inspect_err(|e| error!(e; "Failed to open {}", self.key_path))
102 .context(InternalIoSnafu)?,
103 );
104 let key = match read_one(&mut key_reader)
105 .inspect_err(|e| error!(e; "Failed to read {}", self.key_path))
106 .context(InternalIoSnafu)?
107 {
108 Some(Item::Pkcs1Key(key)) => PrivateKeyDer::from(key),
109 Some(Item::Pkcs8Key(key)) => PrivateKeyDer::from(key),
110 Some(Item::Sec1Key(key)) => PrivateKeyDer::from(key),
111 _ => {
112 return Err(IoError::new(ErrorKind::InvalidInput, "invalid key"))
113 .context(InternalIoSnafu);
114 }
115 };
116
117 let config = ServerConfig::builder()
119 .with_no_client_auth()
120 .with_single_cert(cert, key)
121 .map_err(|err| std::io::Error::new(ErrorKind::InvalidInput, err))?;
122
123 Ok(Some(config))
124 }
125
126 pub fn should_force_tls(&self) -> bool {
127 !matches!(self.mode, TlsMode::Disable | TlsMode::Prefer)
128 }
129
130 pub fn cert_path(&self) -> &Path {
131 Path::new(&self.cert_path)
132 }
133
134 pub fn key_path(&self) -> &Path {
135 Path::new(&self.key_path)
136 }
137
138 pub fn watch_enabled(&self) -> bool {
139 self.mode != TlsMode::Disable && self.watch
140 }
141}
142
143pub struct ReloadableTlsServerConfig {
147 tls_option: TlsOption,
148 config: RwLock<Option<Arc<ServerConfig>>>,
149 version: AtomicUsize,
150}
151
152impl ReloadableTlsServerConfig {
153 pub fn try_new(tls_option: TlsOption) -> Result<ReloadableTlsServerConfig> {
155 let server_config = tls_option.setup()?;
156 Ok(Self {
157 tls_option,
158 config: RwLock::new(server_config.map(Arc::new)),
159 version: AtomicUsize::new(0),
160 })
161 }
162
163 pub fn reload(&self) -> Result<()> {
165 let server_config = self.tls_option.setup()?;
166 *self.config.write().unwrap() = server_config.map(Arc::new);
167 self.version.fetch_add(1, Ordering::Relaxed);
168 Ok(())
169 }
170
171 pub fn get_server_config(&self) -> Option<Arc<ServerConfig>> {
173 self.config.read().unwrap().clone()
174 }
175
176 pub fn get_tls_option(&self) -> &TlsOption {
178 &self.tls_option
179 }
180
181 pub fn get_version(&self) -> usize {
185 self.version.load(Ordering::Relaxed)
186 }
187}
188
189pub fn maybe_watch_tls_config(tls_server_config: Arc<ReloadableTlsServerConfig>) -> Result<()> {
190 if !tls_server_config.get_tls_option().watch_enabled() {
191 return Ok(());
192 }
193
194 let tls_server_config_for_watcher = tls_server_config.clone();
195
196 let (tx, rx) = channel::<notify::Result<notify::Event>>();
197 let mut watcher = notify::recommended_watcher(tx).context(FileWatchSnafu { path: "<none>" })?;
198
199 let cert_path = tls_server_config.get_tls_option().cert_path();
200 watcher
201 .watch(cert_path, RecursiveMode::NonRecursive)
202 .with_context(|_| FileWatchSnafu {
203 path: cert_path.display().to_string(),
204 })?;
205
206 let key_path = tls_server_config.get_tls_option().key_path();
207 watcher
208 .watch(key_path, RecursiveMode::NonRecursive)
209 .with_context(|_| FileWatchSnafu {
210 path: key_path.display().to_string(),
211 })?;
212
213 std::thread::spawn(move || {
214 let _watcher = watcher;
215 while let Ok(res) = rx.recv() {
216 if let Ok(event) = res {
217 match event.kind {
218 EventKind::Modify(_) | EventKind::Create(_) => {
219 info!("Detected TLS cert/key file change: {:?}", event);
220 if let Err(err) = tls_server_config_for_watcher.reload() {
221 error!(err; "Failed to reload TLS server config");
222 } else {
223 info!("Reloaded TLS cert/key file successfully.");
224 }
225 }
226 _ => {}
227 }
228 }
229 }
230 });
231
232 Ok(())
233}
234
#[cfg(test)]
mod tests {
    use super::*;
    use crate::install_ring_crypto_provider;
    use crate::tls::TlsMode::Disable;

    /// Deserializes `json` into a [`TlsOption`], asserts the mode and the flags derived from
    /// it, and returns the option for further checks.
    fn parse_and_check(json: &str, mode: TlsMode, force_tls: bool, watch: bool) -> TlsOption {
        let t: TlsOption = serde_json::from_str(json).unwrap();
        assert_eq!(mode, t.mode);
        assert_eq!(force_tls, t.should_force_tls());
        assert_eq!(watch, t.watch_enabled());
        t
    }

    /// Asserts that both the certificate and the key path were populated.
    fn assert_paths_set(t: &TlsOption) {
        assert!(!t.cert_path.is_empty());
        assert!(!t.key_path.is_empty());
    }

    /// Converts a UTF-8 path into the `String` form stored in [`TlsOption`].
    fn path_string(path: &Path) -> String {
        path.to_str()
            .expect("failed to convert path to string")
            .to_string()
    }

    #[test]
    fn test_new_tls_option() {
        assert_eq!(TlsOption::default(), TlsOption::new(None, None, None));
        assert_eq!(
            TlsOption {
                mode: Disable,
                ..Default::default()
            },
            TlsOption::new(Some(Disable), None, None)
        );
        assert_eq!(
            TlsOption {
                mode: Disable,
                cert_path: "/path/to/cert_path".to_string(),
                key_path: "/path/to/key_path".to_string(),
                watch: false
            },
            TlsOption::new(
                Some(Disable),
                Some("/path/to/cert_path".to_string()),
                Some("/path/to/key_path".to_string())
            )
        );
    }

    #[test]
    fn test_tls_option_disable() {
        let s = r#"
        {
            "mode": "disable"
        }
        "#;

        let t = parse_and_check(s, TlsMode::Disable, false, false);
        assert!(t.key_path.is_empty());
        assert!(t.cert_path.is_empty());

        let setup = t.setup().unwrap();
        assert!(setup.is_none());
    }

    #[test]
    fn test_tls_option_prefer() {
        let s = r#"
        {
            "mode": "prefer",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        let t = parse_and_check(s, TlsMode::Prefer, false, false);
        assert_paths_set(&t);
    }

    #[test]
    fn test_tls_option_require() {
        let s = r#"
        {
            "mode": "require",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        let t = parse_and_check(s, TlsMode::Require, true, false);
        assert_paths_set(&t);
    }

    #[test]
    fn test_tls_option_verify_ca() {
        let s = r#"
        {
            "mode": "verify_ca",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        let t = parse_and_check(s, TlsMode::VerifyCa, true, false);
        assert_paths_set(&t);
    }

    #[test]
    fn test_tls_option_verify_full() {
        let s = r#"
        {
            "mode": "verify_full",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        let t = parse_and_check(s, TlsMode::VerifyFull, true, false);
        assert_paths_set(&t);
    }

    #[test]
    fn test_tls_option_watch_enabled() {
        let s = r#"
        {
            "mode": "verify_full",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key",
            "watch": true
        }
        "#;

        let t = parse_and_check(s, TlsMode::VerifyFull, true, true);
        assert_paths_set(&t);
    }

    #[test]
    fn test_tls_file_change_watch() {
        common_telemetry::init_default_ut_logging();
        let _ = install_ring_crypto_provider();

        let dir = tempfile::tempdir().unwrap();
        let cert_path = dir.path().join("server.crt");
        let key_path = dir.path().join("server.key");

        std::fs::copy("tests/ssl/server.crt", &cert_path).expect("failed to copy cert to tmpdir");
        std::fs::copy("tests/ssl/server-rsa.key", &key_path).expect("failed to copy key to tmpdir");

        assert!(std::fs::exists(&cert_path).unwrap());
        assert!(std::fs::exists(&key_path).unwrap());

        let server_tls = TlsOption {
            mode: TlsMode::Require,
            cert_path: path_string(&cert_path),
            key_path: path_string(&key_path),
            watch: true,
        };

        let server_config = Arc::new(
            ReloadableTlsServerConfig::try_new(server_tls).expect("failed to create server config"),
        );
        maybe_watch_tls_config(server_config.clone()).expect("failed to watch server config");

        assert_eq!(0, server_config.get_version());
        assert!(server_config.get_server_config().is_some());

        // Replace the key by writing a temp file and renaming it over the original.
        let tmp_file = key_path.with_extension("tmp");
        std::fs::copy("tests/ssl/server-pkcs8.key", &tmp_file)
            .expect("Failed to copy temp key file");
        std::fs::rename(&tmp_file, &key_path).expect("Failed to rename temp key file");

        // Poll (up to MAX_RETRIES * 100ms) for the watcher thread to pick up the change.
        const MAX_RETRIES: usize = 30;
        let version_updated = (0..MAX_RETRIES).any(|_| {
            if server_config.get_version() > 0 {
                return true;
            }
            std::thread::sleep(std::time::Duration::from_millis(100));
            false
        });

        assert!(version_updated, "TLS config did not reload in time");
        assert!(server_config.get_version() > 0);
        assert!(server_config.get_server_config().is_some());
    }
}
449}