1use std::fs::File;
16use std::io::{BufReader, Error as IoError, ErrorKind};
17use std::path::Path;
18use std::sync::atomic::{AtomicUsize, Ordering};
19use std::sync::mpsc::channel;
20use std::sync::{Arc, RwLock};
21
22use common_telemetry::{error, info};
23use notify::{EventKind, RecursiveMode, Watcher};
24use rustls::ServerConfig;
25use rustls_pemfile::{Item, certs, read_one};
26use rustls_pki_types::{CertificateDer, PrivateKeyDer};
27use serde::{Deserialize, Serialize};
28use snafu::ResultExt;
29use strum::EnumString;
30
31use crate::error::{FileWatchSnafu, InternalIoSnafu, Result};
32
33#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq, Eq, EnumString)]
35#[serde(rename_all = "snake_case")]
36pub enum TlsMode {
37 #[default]
38 #[strum(to_string = "disable")]
39 Disable,
40
41 #[strum(to_string = "prefer")]
42 Prefer,
43
44 #[strum(to_string = "require")]
45 Require,
46
47 #[strum(to_string = "verify-ca")]
50 VerifyCa,
51
52 #[strum(to_string = "verify-full")]
53 VerifyFull,
54}
55
56#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq, Eq)]
57#[serde(rename_all = "snake_case")]
58pub struct TlsOption {
59 pub mode: TlsMode,
60 #[serde(default)]
61 pub cert_path: String,
62 #[serde(default)]
63 pub key_path: String,
64 #[serde(default)]
65 pub ca_cert_path: String,
66 #[serde(default)]
67 pub watch: bool,
68}
69
70impl TlsOption {
71 pub fn new(
72 mode: Option<TlsMode>,
73 cert_path: Option<String>,
74 key_path: Option<String>,
75 watch: bool,
76 ) -> Self {
77 let mut tls_option = TlsOption::default();
78
79 if let Some(mode) = mode {
80 tls_option.mode = mode
81 };
82
83 if let Some(cert_path) = cert_path {
84 tls_option.cert_path = cert_path
85 };
86
87 if let Some(key_path) = key_path {
88 tls_option.key_path = key_path
89 };
90
91 tls_option.watch = watch;
92
93 tls_option
94 }
95
96 pub fn setup(&self) -> Result<Option<ServerConfig>> {
97 if let TlsMode::Disable = self.mode {
98 return Ok(None);
99 }
100 let cert = certs(&mut BufReader::new(
101 File::open(&self.cert_path)
102 .inspect_err(|e| error!(e; "Failed to open {}", self.cert_path))
103 .context(InternalIoSnafu)?,
104 ))
105 .collect::<std::result::Result<Vec<CertificateDer>, IoError>>()
106 .context(InternalIoSnafu)?;
107
108 let mut key_reader = BufReader::new(
109 File::open(&self.key_path)
110 .inspect_err(|e| error!(e; "Failed to open {}", self.key_path))
111 .context(InternalIoSnafu)?,
112 );
113 let key = match read_one(&mut key_reader)
114 .inspect_err(|e| error!(e; "Failed to read {}", self.key_path))
115 .context(InternalIoSnafu)?
116 {
117 Some(Item::Pkcs1Key(key)) => PrivateKeyDer::from(key),
118 Some(Item::Pkcs8Key(key)) => PrivateKeyDer::from(key),
119 Some(Item::Sec1Key(key)) => PrivateKeyDer::from(key),
120 _ => {
121 return Err(IoError::new(ErrorKind::InvalidInput, "invalid key"))
122 .context(InternalIoSnafu);
123 }
124 };
125
126 let config = ServerConfig::builder()
128 .with_no_client_auth()
129 .with_single_cert(cert, key)
130 .map_err(|err| std::io::Error::new(ErrorKind::InvalidInput, err))?;
131
132 Ok(Some(config))
133 }
134
135 pub fn should_force_tls(&self) -> bool {
136 !matches!(self.mode, TlsMode::Disable | TlsMode::Prefer)
137 }
138
139 pub fn cert_path(&self) -> &Path {
140 Path::new(&self.cert_path)
141 }
142
143 pub fn key_path(&self) -> &Path {
144 Path::new(&self.key_path)
145 }
146
147 pub fn watch_enabled(&self) -> bool {
148 self.mode != TlsMode::Disable && self.watch
149 }
150}
151
152pub struct ReloadableTlsServerConfig {
156 tls_option: TlsOption,
157 config: RwLock<Option<Arc<ServerConfig>>>,
158 version: AtomicUsize,
159}
160
161impl ReloadableTlsServerConfig {
162 pub fn try_new(tls_option: TlsOption) -> Result<ReloadableTlsServerConfig> {
164 let server_config = tls_option.setup()?;
165 Ok(Self {
166 tls_option,
167 config: RwLock::new(server_config.map(Arc::new)),
168 version: AtomicUsize::new(0),
169 })
170 }
171
172 pub fn reload(&self) -> Result<()> {
174 let server_config = self.tls_option.setup()?;
175 *self.config.write().unwrap() = server_config.map(Arc::new);
176 self.version.fetch_add(1, Ordering::Relaxed);
177 Ok(())
178 }
179
180 pub fn get_server_config(&self) -> Option<Arc<ServerConfig>> {
182 self.config.read().unwrap().clone()
183 }
184
185 pub fn get_tls_option(&self) -> &TlsOption {
187 &self.tls_option
188 }
189
190 pub fn get_version(&self) -> usize {
194 self.version.load(Ordering::Relaxed)
195 }
196}
197
198pub fn maybe_watch_tls_config(tls_server_config: Arc<ReloadableTlsServerConfig>) -> Result<()> {
199 if !tls_server_config.get_tls_option().watch_enabled() {
200 return Ok(());
201 }
202
203 let tls_server_config_for_watcher = tls_server_config.clone();
204
205 let (tx, rx) = channel::<notify::Result<notify::Event>>();
206 let mut watcher = notify::recommended_watcher(tx).context(FileWatchSnafu { path: "<none>" })?;
207
208 let cert_path = tls_server_config.get_tls_option().cert_path();
209 watcher
210 .watch(cert_path, RecursiveMode::NonRecursive)
211 .with_context(|_| FileWatchSnafu {
212 path: cert_path.display().to_string(),
213 })?;
214
215 let key_path = tls_server_config.get_tls_option().key_path();
216 watcher
217 .watch(key_path, RecursiveMode::NonRecursive)
218 .with_context(|_| FileWatchSnafu {
219 path: key_path.display().to_string(),
220 })?;
221
222 std::thread::spawn(move || {
223 let _watcher = watcher;
224 while let Ok(res) = rx.recv() {
225 if let Ok(event) = res {
226 match event.kind {
227 EventKind::Modify(_) | EventKind::Create(_) => {
228 info!("Detected TLS cert/key file change: {:?}", event);
229 if let Err(err) = tls_server_config_for_watcher.reload() {
230 error!(err; "Failed to reload TLS server config");
231 } else {
232 info!("Reloaded TLS cert/key file successfully.");
233 }
234 }
235 _ => {}
236 }
237 }
238 }
239 });
240
241 Ok(())
242}
243
#[cfg(test)]
mod tests {
    use std::time::{Duration, Instant};

    use super::*;
    use crate::install_ring_crypto_provider;
    use crate::tls::TlsMode::Disable;

    /// How long the watch test waits for the background watcher to finish a reload.
    /// Generous on purpose: file event delivery latency differs between platforms.
    const RELOAD_TIMEOUT: Duration = Duration::from_secs(10);

    #[test]
    fn test_new_tls_option() {
        assert_eq!(
            TlsOption::default(),
            TlsOption::new(None, None, None, false)
        );
        assert_eq!(
            TlsOption {
                mode: Disable,
                ..Default::default()
            },
            TlsOption::new(Some(Disable), None, None, false)
        );
        assert_eq!(
            TlsOption {
                mode: Disable,
                cert_path: "/path/to/cert_path".to_string(),
                key_path: "/path/to/key_path".to_string(),
                ca_cert_path: String::new(),
                watch: false
            },
            TlsOption::new(
                Some(Disable),
                Some("/path/to/cert_path".to_string()),
                Some("/path/to/key_path".to_string()),
                false
            )
        );
    }

    #[test]
    fn test_tls_mode_from_str() {
        // `FromStr` uses the strum spellings, which are kebab-case for the verify modes.
        assert_eq!(TlsMode::Disable, "disable".parse::<TlsMode>().unwrap());
        assert_eq!(TlsMode::Prefer, "prefer".parse::<TlsMode>().unwrap());
        assert_eq!(TlsMode::Require, "require".parse::<TlsMode>().unwrap());
        assert_eq!(TlsMode::VerifyCa, "verify-ca".parse::<TlsMode>().unwrap());
        assert_eq!(TlsMode::VerifyFull, "verify-full".parse::<TlsMode>().unwrap());
        assert!("bogus".parse::<TlsMode>().is_err());
    }

    #[test]
    fn test_tls_option_requires_mode() {
        // A missing mode must be an error rather than a silent fallback to `Disable`,
        // which would turn off TLS for a config that clearly intended to use it.
        let s = r#"
        {
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        assert!(serde_json::from_str::<TlsOption>(s).is_err());
    }

    #[test]
    fn test_tls_option_disable() {
        let s = r#"
        {
            "mode": "disable"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(!t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::Disable));
        assert!(t.key_path.is_empty());
        assert!(t.cert_path.is_empty());
        assert!(!t.watch_enabled());

        let setup = t.setup();
        let setup = setup.unwrap();
        assert!(setup.is_none());
    }

    #[test]
    fn test_tls_option_prefer() {
        let s = r#"
        {
            "mode": "prefer",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(!t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::Prefer));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(!t.watch_enabled());
    }

    #[test]
    fn test_tls_option_require() {
        let s = r#"
        {
            "mode": "require",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::Require));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(!t.watch_enabled());
    }

    #[test]
    fn test_tls_option_verify_ca() {
        let s = r#"
        {
            "mode": "verify_ca",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::VerifyCa));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(!t.watch_enabled());
    }

    #[test]
    fn test_tls_option_verify_full() {
        let s = r#"
        {
            "mode": "verify_full",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::VerifyFull));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(!t.watch_enabled());
    }

    #[test]
    fn test_tls_option_watch_enabled() {
        let s = r#"
        {
            "mode": "verify_full",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key",
            "watch": true
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::VerifyFull));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(t.watch_enabled());
    }

    #[test]
    fn test_tls_file_change_watch() {
        common_telemetry::init_default_ut_logging();
        let _ = install_ring_crypto_provider();

        let dir = tempfile::tempdir().unwrap();
        let cert_path = dir.path().join("server.crt");
        let key_path = dir.path().join("server.key");

        std::fs::copy("tests/ssl/server.crt", &cert_path).expect("failed to copy cert to tmpdir");
        std::fs::copy("tests/ssl/server-rsa.key", &key_path).expect("failed to copy key to tmpdir");

        assert!(std::fs::exists(&cert_path).unwrap());
        assert!(std::fs::exists(&key_path).unwrap());

        let server_tls = TlsOption {
            mode: TlsMode::Require,
            cert_path: cert_path
                .clone()
                .into_os_string()
                .into_string()
                .expect("failed to convert path to string"),
            key_path: key_path
                .clone()
                .into_os_string()
                .into_string()
                .expect("failed to convert path to string"),
            ca_cert_path: String::new(),
            watch: true,
        };

        let server_config = Arc::new(
            ReloadableTlsServerConfig::try_new(server_tls).expect("failed to create server config"),
        );
        maybe_watch_tls_config(server_config.clone()).expect("failed to watch server config");

        assert_eq!(0, server_config.get_version());
        assert!(server_config.get_server_config().is_some());

        // Replace the key atomically: write a temporary file, then rename it over the key.
        let tmp_file = key_path.with_extension("tmp");
        std::fs::copy("tests/ssl/server-pkcs8.key", &tmp_file)
            .expect("Failed to copy temp key file");
        std::fs::rename(&tmp_file, &key_path).expect("Failed to rename temp key file");

        // Poll against a wall-clock deadline; the reload happens on the watcher thread.
        let deadline = Instant::now() + RELOAD_TIMEOUT;
        while server_config.get_version() == 0 && Instant::now() < deadline {
            std::thread::sleep(Duration::from_millis(100));
        }

        assert!(
            server_config.get_version() > 0,
            "TLS config did not reload within {RELOAD_TIMEOUT:?}"
        );
        assert!(server_config.get_server_config().is_some());
    }
}