1use std::fs::File;
16use std::io::{BufReader, Error as IoError, ErrorKind};
17use std::path::Path;
18use std::sync::atomic::{AtomicUsize, Ordering};
19use std::sync::mpsc::channel;
20use std::sync::{Arc, RwLock};
21
22use common_telemetry::{error, info};
23use notify::{EventKind, RecursiveMode, Watcher};
24use rustls::ServerConfig;
25use rustls_pemfile::{certs, read_one, Item};
26use rustls_pki_types::{CertificateDer, PrivateKeyDer};
27use serde::{Deserialize, Serialize};
28use snafu::ResultExt;
29use strum::EnumString;
30
31use crate::error::{FileWatchSnafu, InternalIoSnafu, Result};
32
33#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq, Eq, EnumString)]
35#[serde(rename_all = "snake_case")]
36pub enum TlsMode {
37 #[default]
38 #[strum(to_string = "disable")]
39 Disable,
40
41 #[strum(to_string = "prefer")]
42 Prefer,
43
44 #[strum(to_string = "require")]
45 Require,
46
47 #[strum(to_string = "verify-ca")]
50 VerifyCa,
51
52 #[strum(to_string = "verify-full")]
53 VerifyFull,
54}
55
56#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq, Eq)]
57#[serde(rename_all = "snake_case")]
58pub struct TlsOption {
59 pub mode: TlsMode,
60 #[serde(default)]
61 pub cert_path: String,
62 #[serde(default)]
63 pub key_path: String,
64 #[serde(default)]
65 pub watch: bool,
66}
67
68impl TlsOption {
69 pub fn new(mode: Option<TlsMode>, cert_path: Option<String>, key_path: Option<String>) -> Self {
70 let mut tls_option = TlsOption::default();
71
72 if let Some(mode) = mode {
73 tls_option.mode = mode
74 };
75
76 if let Some(cert_path) = cert_path {
77 tls_option.cert_path = cert_path
78 };
79
80 if let Some(key_path) = key_path {
81 tls_option.key_path = key_path
82 };
83
84 tls_option
85 }
86
87 pub fn setup(&self) -> Result<Option<ServerConfig>> {
88 if let TlsMode::Disable = self.mode {
89 return Ok(None);
90 }
91 let cert = certs(&mut BufReader::new(
92 File::open(&self.cert_path).context(InternalIoSnafu)?,
93 ))
94 .collect::<std::result::Result<Vec<CertificateDer>, IoError>>()
95 .context(InternalIoSnafu)?;
96
97 let mut key_reader = BufReader::new(File::open(&self.key_path).context(InternalIoSnafu)?);
98 let key = match read_one(&mut key_reader).context(InternalIoSnafu)? {
99 Some(Item::Pkcs1Key(key)) => PrivateKeyDer::from(key),
100 Some(Item::Pkcs8Key(key)) => PrivateKeyDer::from(key),
101 Some(Item::Sec1Key(key)) => PrivateKeyDer::from(key),
102 _ => {
103 return Err(IoError::new(ErrorKind::InvalidInput, "invalid key"))
104 .context(InternalIoSnafu);
105 }
106 };
107
108 let config = ServerConfig::builder()
110 .with_no_client_auth()
111 .with_single_cert(cert, key)
112 .map_err(|err| std::io::Error::new(ErrorKind::InvalidInput, err))?;
113
114 Ok(Some(config))
115 }
116
117 pub fn should_force_tls(&self) -> bool {
118 !matches!(self.mode, TlsMode::Disable | TlsMode::Prefer)
119 }
120
121 pub fn cert_path(&self) -> &Path {
122 Path::new(&self.cert_path)
123 }
124
125 pub fn key_path(&self) -> &Path {
126 Path::new(&self.key_path)
127 }
128
129 pub fn watch_enabled(&self) -> bool {
130 self.mode != TlsMode::Disable && self.watch
131 }
132}
133
134pub struct ReloadableTlsServerConfig {
138 tls_option: TlsOption,
139 config: RwLock<Option<Arc<ServerConfig>>>,
140 version: AtomicUsize,
141}
142
143impl ReloadableTlsServerConfig {
144 pub fn try_new(tls_option: TlsOption) -> Result<ReloadableTlsServerConfig> {
146 let server_config = tls_option.setup()?;
147 Ok(Self {
148 tls_option,
149 config: RwLock::new(server_config.map(Arc::new)),
150 version: AtomicUsize::new(0),
151 })
152 }
153
154 pub fn reload(&self) -> Result<()> {
156 let server_config = self.tls_option.setup()?;
157 *self.config.write().unwrap() = server_config.map(Arc::new);
158 self.version.fetch_add(1, Ordering::Relaxed);
159 Ok(())
160 }
161
162 pub fn get_server_config(&self) -> Option<Arc<ServerConfig>> {
164 self.config.read().unwrap().clone()
165 }
166
167 pub fn get_tls_option(&self) -> &TlsOption {
169 &self.tls_option
170 }
171
172 pub fn get_version(&self) -> usize {
176 self.version.load(Ordering::Relaxed)
177 }
178}
179
180pub fn maybe_watch_tls_config(tls_server_config: Arc<ReloadableTlsServerConfig>) -> Result<()> {
181 if !tls_server_config.get_tls_option().watch_enabled() {
182 return Ok(());
183 }
184
185 let tls_server_config_for_watcher = tls_server_config.clone();
186
187 let (tx, rx) = channel::<notify::Result<notify::Event>>();
188 let mut watcher = notify::recommended_watcher(tx).context(FileWatchSnafu { path: "<none>" })?;
189
190 let cert_path = tls_server_config.get_tls_option().cert_path();
191 watcher
192 .watch(cert_path, RecursiveMode::NonRecursive)
193 .with_context(|_| FileWatchSnafu {
194 path: cert_path.display().to_string(),
195 })?;
196
197 let key_path = tls_server_config.get_tls_option().key_path();
198 watcher
199 .watch(key_path, RecursiveMode::NonRecursive)
200 .with_context(|_| FileWatchSnafu {
201 path: key_path.display().to_string(),
202 })?;
203
204 std::thread::spawn(move || {
205 let _watcher = watcher;
206 while let Ok(res) = rx.recv() {
207 if let Ok(event) = res {
208 match event.kind {
209 EventKind::Modify(_) | EventKind::Create(_) => {
210 info!("Detected TLS cert/key file change: {:?}", event);
211 if let Err(err) = tls_server_config_for_watcher.reload() {
212 error!(err; "Failed to reload TLS server config");
213 }
214 }
215 _ => {}
216 }
217 }
218 }
219 });
220
221 Ok(())
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use crate::install_ring_crypto_provider;
    use crate::tls::TlsMode::Disable;

    #[test]
    fn test_new_tls_option() {
        assert_eq!(TlsOption::default(), TlsOption::new(None, None, None));
        assert_eq!(
            TlsOption {
                mode: Disable,
                ..Default::default()
            },
            TlsOption::new(Some(Disable), None, None)
        );
        assert_eq!(
            TlsOption {
                mode: Disable,
                cert_path: "/path/to/cert_path".to_string(),
                key_path: "/path/to/key_path".to_string(),
                watch: false
            },
            TlsOption::new(
                Some(Disable),
                Some("/path/to/cert_path".to_string()),
                Some("/path/to/key_path".to_string())
            )
        );
    }

    #[test]
    fn test_tls_option_disable() {
        let s = r#"
        {
            "mode": "disable"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(!t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::Disable));
        assert!(t.key_path.is_empty());
        assert!(t.cert_path.is_empty());
        assert!(!t.watch_enabled());

        let setup = t.setup();
        let setup = setup.unwrap();
        assert!(setup.is_none());
    }

    #[test]
    fn test_tls_option_prefer() {
        let s = r#"
        {
            "mode": "prefer",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(!t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::Prefer));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(!t.watch_enabled());
    }

    #[test]
    fn test_tls_option_require() {
        let s = r#"
        {
            "mode": "require",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::Require));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(!t.watch_enabled());
    }

    #[test]
    fn test_tls_option_verify_ca() {
        let s = r#"
        {
            "mode": "verify_ca",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::VerifyCa));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(!t.watch_enabled());
    }

    #[test]
    fn test_tls_option_verify_full() {
        let s = r#"
        {
            "mode": "verify_full",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::VerifyFull));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(!t.watch_enabled());
    }

    #[test]
    fn test_tls_option_watch_enabled() {
        let s = r#"
        {
            "mode": "verify_full",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key",
            "watch": true
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::VerifyFull));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(t.watch_enabled());
    }

    #[test]
    fn test_tls_file_change_watch() {
        common_telemetry::init_default_ut_logging();
        let _ = install_ring_crypto_provider();

        let dir = tempfile::tempdir().unwrap();
        let cert_path = dir.path().join("server.crt");
        let key_path = dir.path().join("server.key");

        std::fs::copy("tests/ssl/server.crt", &cert_path).expect("failed to copy cert to tmpdir");
        std::fs::copy("tests/ssl/server-rsa.key", &key_path).expect("failed to copy key to tmpdir");

        let server_tls = TlsOption {
            mode: TlsMode::Require,
            cert_path: cert_path
                .clone()
                .into_os_string()
                .into_string()
                .expect("failed to convert path to string"),
            key_path: key_path
                .clone()
                .into_os_string()
                .into_string()
                .expect("failed to convert path to string"),
            watch: true,
        };

        let server_config = Arc::new(
            ReloadableTlsServerConfig::try_new(server_tls).expect("failed to create server config"),
        );
        maybe_watch_tls_config(server_config.clone()).expect("failed to watch server config");

        assert_eq!(0, server_config.get_version());
        assert!(server_config.get_server_config().is_some());

        std::fs::copy("tests/ssl/server-pkcs8.key", &key_path)
            .expect("failed to copy key to tmpdir");

        // A single copy can emit several events, and reloads triggered while
        // the file is half-written fail. So the exact number of successful
        // reloads is not deterministic; wait for at least one instead of
        // sleeping a fixed time and expecting a specific count.
        let deadline = std::time::Instant::now() + std::time::Duration::from_secs(10);
        while server_config.get_version() == 0 && std::time::Instant::now() < deadline {
            std::thread::sleep(std::time::Duration::from_millis(50));
        }

        assert!(
            server_config.get_version() > 0,
            "TLS server config was not reloaded after the key file changed"
        );
        assert!(server_config.get_server_config().is_some());
    }
}