1use std::fs::File;
16use std::io::{BufReader, Error as IoError, ErrorKind};
17use std::path::Path;
18use std::sync::atomic::{AtomicUsize, Ordering};
19use std::sync::mpsc::channel;
20use std::sync::{Arc, RwLock};
21
22use common_telemetry::{error, info};
23use notify::{EventKind, RecursiveMode, Watcher};
24use rustls::ServerConfig;
25use rustls_pemfile::{certs, read_one, Item};
26use rustls_pki_types::{CertificateDer, PrivateKeyDer};
27use serde::{Deserialize, Serialize};
28use snafu::ResultExt;
29use strum::EnumString;
30
31use crate::error::{FileWatchSnafu, InternalIoSnafu, Result};
32
33#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq, Eq, EnumString)]
35#[serde(rename_all = "snake_case")]
36pub enum TlsMode {
37 #[default]
38 #[strum(to_string = "disable")]
39 Disable,
40
41 #[strum(to_string = "prefer")]
42 Prefer,
43
44 #[strum(to_string = "require")]
45 Require,
46
47 #[strum(to_string = "verify-ca")]
50 VerifyCa,
51
52 #[strum(to_string = "verify-full")]
53 VerifyFull,
54}
55
56#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq, Eq)]
57#[serde(rename_all = "snake_case")]
58pub struct TlsOption {
59 pub mode: TlsMode,
60 #[serde(default)]
61 pub cert_path: String,
62 #[serde(default)]
63 pub key_path: String,
64 #[serde(default)]
65 pub ca_cert_path: String,
66 #[serde(default)]
67 pub watch: bool,
68}
69
70impl TlsOption {
71 pub fn new(mode: Option<TlsMode>, cert_path: Option<String>, key_path: Option<String>) -> Self {
72 let mut tls_option = TlsOption::default();
73
74 if let Some(mode) = mode {
75 tls_option.mode = mode
76 };
77
78 if let Some(cert_path) = cert_path {
79 tls_option.cert_path = cert_path
80 };
81
82 if let Some(key_path) = key_path {
83 tls_option.key_path = key_path
84 };
85
86 tls_option
87 }
88
89 pub fn setup(&self) -> Result<Option<ServerConfig>> {
90 if let TlsMode::Disable = self.mode {
91 return Ok(None);
92 }
93 let cert = certs(&mut BufReader::new(
94 File::open(&self.cert_path)
95 .inspect_err(|e| error!(e; "Failed to open {}", self.cert_path))
96 .context(InternalIoSnafu)?,
97 ))
98 .collect::<std::result::Result<Vec<CertificateDer>, IoError>>()
99 .context(InternalIoSnafu)?;
100
101 let mut key_reader = BufReader::new(
102 File::open(&self.key_path)
103 .inspect_err(|e| error!(e; "Failed to open {}", self.key_path))
104 .context(InternalIoSnafu)?,
105 );
106 let key = match read_one(&mut key_reader)
107 .inspect_err(|e| error!(e; "Failed to read {}", self.key_path))
108 .context(InternalIoSnafu)?
109 {
110 Some(Item::Pkcs1Key(key)) => PrivateKeyDer::from(key),
111 Some(Item::Pkcs8Key(key)) => PrivateKeyDer::from(key),
112 Some(Item::Sec1Key(key)) => PrivateKeyDer::from(key),
113 _ => {
114 return Err(IoError::new(ErrorKind::InvalidInput, "invalid key"))
115 .context(InternalIoSnafu);
116 }
117 };
118
119 let config = ServerConfig::builder()
121 .with_no_client_auth()
122 .with_single_cert(cert, key)
123 .map_err(|err| std::io::Error::new(ErrorKind::InvalidInput, err))?;
124
125 Ok(Some(config))
126 }
127
128 pub fn should_force_tls(&self) -> bool {
129 !matches!(self.mode, TlsMode::Disable | TlsMode::Prefer)
130 }
131
132 pub fn cert_path(&self) -> &Path {
133 Path::new(&self.cert_path)
134 }
135
136 pub fn key_path(&self) -> &Path {
137 Path::new(&self.key_path)
138 }
139
140 pub fn watch_enabled(&self) -> bool {
141 self.mode != TlsMode::Disable && self.watch
142 }
143}
144
145pub struct ReloadableTlsServerConfig {
149 tls_option: TlsOption,
150 config: RwLock<Option<Arc<ServerConfig>>>,
151 version: AtomicUsize,
152}
153
154impl ReloadableTlsServerConfig {
155 pub fn try_new(tls_option: TlsOption) -> Result<ReloadableTlsServerConfig> {
157 let server_config = tls_option.setup()?;
158 Ok(Self {
159 tls_option,
160 config: RwLock::new(server_config.map(Arc::new)),
161 version: AtomicUsize::new(0),
162 })
163 }
164
165 pub fn reload(&self) -> Result<()> {
167 let server_config = self.tls_option.setup()?;
168 *self.config.write().unwrap() = server_config.map(Arc::new);
169 self.version.fetch_add(1, Ordering::Relaxed);
170 Ok(())
171 }
172
173 pub fn get_server_config(&self) -> Option<Arc<ServerConfig>> {
175 self.config.read().unwrap().clone()
176 }
177
178 pub fn get_tls_option(&self) -> &TlsOption {
180 &self.tls_option
181 }
182
183 pub fn get_version(&self) -> usize {
187 self.version.load(Ordering::Relaxed)
188 }
189}
190
191pub fn maybe_watch_tls_config(tls_server_config: Arc<ReloadableTlsServerConfig>) -> Result<()> {
192 if !tls_server_config.get_tls_option().watch_enabled() {
193 return Ok(());
194 }
195
196 let tls_server_config_for_watcher = tls_server_config.clone();
197
198 let (tx, rx) = channel::<notify::Result<notify::Event>>();
199 let mut watcher = notify::recommended_watcher(tx).context(FileWatchSnafu { path: "<none>" })?;
200
201 let cert_path = tls_server_config.get_tls_option().cert_path();
202 watcher
203 .watch(cert_path, RecursiveMode::NonRecursive)
204 .with_context(|_| FileWatchSnafu {
205 path: cert_path.display().to_string(),
206 })?;
207
208 let key_path = tls_server_config.get_tls_option().key_path();
209 watcher
210 .watch(key_path, RecursiveMode::NonRecursive)
211 .with_context(|_| FileWatchSnafu {
212 path: key_path.display().to_string(),
213 })?;
214
215 std::thread::spawn(move || {
216 let _watcher = watcher;
217 while let Ok(res) = rx.recv() {
218 if let Ok(event) = res {
219 match event.kind {
220 EventKind::Modify(_) | EventKind::Create(_) => {
221 info!("Detected TLS cert/key file change: {:?}", event);
222 if let Err(err) = tls_server_config_for_watcher.reload() {
223 error!(err; "Failed to reload TLS server config");
224 } else {
225 info!("Reloaded TLS cert/key file successfully.");
226 }
227 }
228 _ => {}
229 }
230 }
231 }
232 });
233
234 Ok(())
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use crate::install_ring_crypto_provider;
    use crate::tls::TlsMode::Disable;

    // `TlsOption::new` only overrides the fields passed as `Some`; everything
    // else matches `TlsOption::default()`.
    #[test]
    fn test_new_tls_option() {
        assert_eq!(TlsOption::default(), TlsOption::new(None, None, None));
        assert_eq!(
            TlsOption {
                mode: Disable,
                ..Default::default()
            },
            TlsOption::new(Some(Disable), None, None)
        );
        assert_eq!(
            TlsOption {
                mode: Disable,
                cert_path: "/path/to/cert_path".to_string(),
                key_path: "/path/to/key_path".to_string(),
                ca_cert_path: String::new(),
                watch: false
            },
            TlsOption::new(
                Some(Disable),
                Some("/path/to/cert_path".to_string()),
                Some("/path/to/key_path".to_string())
            )
        );
    }

    // `disable`: TLS is not forced, paths default to empty, watching is off,
    // and `setup` returns `Ok(None)` without touching the filesystem.
    #[test]
    fn test_tls_option_disable() {
        let s = r#"
        {
            "mode": "disable"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(!t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::Disable));
        assert!(t.key_path.is_empty());
        assert!(t.cert_path.is_empty());
        assert!(!t.watch_enabled());

        let setup = t.setup();
        let setup = setup.unwrap();
        assert!(setup.is_none());
    }

    // `prefer`: TLS is offered but not forced. `setup` is not called because
    // the referenced files do not exist.
    #[test]
    fn test_tls_option_prefer() {
        let s = r#"
        {
            "mode": "prefer",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(!t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::Prefer));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(!t.watch_enabled());
    }

    // `require`: TLS is forced; `watch` defaults to false when omitted.
    #[test]
    fn test_tls_option_require() {
        let s = r#"
        {
            "mode": "require",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::Require));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(!t.watch_enabled());
    }

    // `verify_ca` (serde snake_case spelling) deserializes to `VerifyCa` and forces TLS.
    #[test]
    fn test_tls_option_verify_ca() {
        let s = r#"
        {
            "mode": "verify_ca",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::VerifyCa));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(!t.watch_enabled());
    }

    // `verify_full` (serde snake_case spelling) deserializes to `VerifyFull` and forces TLS.
    #[test]
    fn test_tls_option_verify_full() {
        let s = r#"
        {
            "mode": "verify_full",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::VerifyFull));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(!t.watch_enabled());
    }

    // `watch: true` with a non-disabled mode enables watching.
    #[test]
    fn test_tls_option_watch_enabled() {
        let s = r#"
        {
            "mode": "verify_full",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key",
            "watch": true
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::VerifyFull));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(t.watch_enabled());
    }

    // End-to-end: load real fixtures from a temp dir, start the watcher, replace
    // the key file, and wait for the reload to bump the version.
    // Relies on the `tests/ssl/*` fixtures relative to the crate root.
    #[test]
    fn test_tls_file_change_watch() {
        common_telemetry::init_default_ut_logging();
        // Ignore the result: the provider may already be installed by another test.
        let _ = install_ring_crypto_provider();

        let dir = tempfile::tempdir().unwrap();
        let cert_path = dir.path().join("server.crt");
        let key_path = dir.path().join("server.key");

        std::fs::copy("tests/ssl/server.crt", &cert_path).expect("failed to copy cert to tmpdir");
        std::fs::copy("tests/ssl/server-rsa.key", &key_path).expect("failed to copy key to tmpdir");

        assert!(std::fs::exists(&cert_path).unwrap());
        assert!(std::fs::exists(&key_path).unwrap());

        let server_tls = TlsOption {
            mode: TlsMode::Require,
            cert_path: cert_path
                .clone()
                .into_os_string()
                .into_string()
                .expect("failed to convert path to string"),
            key_path: key_path
                .clone()
                .into_os_string()
                .into_string()
                .expect("failed to convert path to string"),
            ca_cert_path: String::new(),
            watch: true,
        };

        let server_config = Arc::new(
            ReloadableTlsServerConfig::try_new(server_tls).expect("failed to create server config"),
        );
        maybe_watch_tls_config(server_config.clone()).expect("failed to watch server config");

        // No reload has happened yet; the initial config was built by `try_new`.
        assert_eq!(0, server_config.get_version());
        assert!(server_config.get_server_config().is_some());

        // Rotate the key atomically: write a PKCS#8 key to a temp file next to it,
        // then rename it over the watched path.
        let tmp_file = key_path.with_extension("tmp");
        std::fs::copy("tests/ssl/server-pkcs8.key", &tmp_file)
            .expect("Failed to copy temp key file");
        std::fs::rename(&tmp_file, &key_path).expect("Failed to rename temp key file");

        // Poll for up to MAX_RETRIES * 100ms (~3s) for the watcher thread to reload.
        const MAX_RETRIES: usize = 30;
        let mut retries = 0;
        let mut version_updated = false;

        while retries < MAX_RETRIES {
            if server_config.get_version() > 0 {
                version_updated = true;
                break;
            }
            std::thread::sleep(std::time::Duration::from_millis(100));
            retries += 1;
        }

        assert!(version_updated, "TLS config did not reload in time");
        assert!(server_config.get_version() > 0);
        assert!(server_config.get_server_config().is_some());
    }
}