1use std::fs::File;
16use std::io::{BufReader, Error as IoError, ErrorKind};
17use std::path::Path;
18use std::sync::Arc;
19
20use common_grpc::reloadable_tls::{ReloadableTlsConfig, TlsConfigLoader};
21use common_telemetry::error;
22use rustls::ServerConfig;
23use rustls_pemfile::{Item, certs, read_one};
24use rustls_pki_types::{CertificateDer, PrivateKeyDer};
25use serde::{Deserialize, Serialize};
26use snafu::ResultExt;
27use strum::EnumString;
28
29use crate::error::{InternalIoSnafu, Result};
30
31#[derive(Debug, Default, Serialize, Deserialize, Clone, Copy, PartialEq, Eq, EnumString)]
33#[serde(rename_all = "snake_case")]
34pub enum TlsMode {
35 #[default]
36 #[strum(to_string = "disable")]
37 Disable,
38
39 #[strum(to_string = "prefer")]
40 Prefer,
41
42 #[strum(to_string = "require")]
43 Require,
44
45 #[strum(to_string = "verify-ca")]
48 VerifyCa,
49
50 #[strum(to_string = "verify-full")]
51 VerifyFull,
52}
53
54#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq, Eq)]
55#[serde(rename_all = "snake_case")]
56pub struct TlsOption {
57 pub mode: TlsMode,
58 #[serde(default)]
59 pub cert_path: String,
60 #[serde(default)]
61 pub key_path: String,
62 #[serde(default)]
63 pub ca_cert_path: String,
64 #[serde(default)]
65 pub watch: bool,
66}
67
68impl TlsOption {
69 pub fn new(
70 mode: Option<TlsMode>,
71 cert_path: Option<String>,
72 key_path: Option<String>,
73 watch: bool,
74 ) -> Self {
75 let mut tls_option = TlsOption::default();
76
77 if let Some(mode) = mode {
78 tls_option.mode = mode
79 };
80
81 if let Some(cert_path) = cert_path {
82 tls_option.cert_path = cert_path
83 };
84
85 if let Some(key_path) = key_path {
86 tls_option.key_path = key_path
87 };
88
89 tls_option.watch = watch;
90
91 tls_option
92 }
93
94 pub fn prefer() -> Self {
96 Self {
97 mode: TlsMode::Prefer,
98 cert_path: String::new(),
99 key_path: String::new(),
100 ca_cert_path: String::new(),
101 watch: false,
102 }
103 }
104
105 pub fn validate(&self) -> Result<()> {
111 if self.mode == TlsMode::Disable {
112 return Ok(());
113 }
114
115 if self.cert_path.is_empty() {
117 return Err(crate::error::Error::Internal {
118 err_msg: format!(
119 "TLS mode is {:?} but cert_path is not configured",
120 self.mode
121 ),
122 });
123 }
124
125 if self.key_path.is_empty() {
126 return Err(crate::error::Error::Internal {
127 err_msg: format!("TLS mode is {:?} but key_path is not configured", self.mode),
128 });
129 }
130
131 if matches!(self.mode, TlsMode::VerifyCa | TlsMode::VerifyFull)
133 && self.ca_cert_path.is_empty()
134 {
135 return Err(crate::error::Error::Internal {
136 err_msg: format!(
137 "TLS mode is {:?} but ca_cert_path is not configured",
138 self.mode
139 ),
140 });
141 }
142
143 Ok(())
144 }
145
146 pub fn setup(&self) -> Result<Option<ServerConfig>> {
147 if let TlsMode::Disable = self.mode {
148 return Ok(None);
149 }
150 let cert = certs(&mut BufReader::new(
151 File::open(&self.cert_path)
152 .inspect_err(|e| error!(e; "Failed to open {}", self.cert_path))
153 .context(InternalIoSnafu)?,
154 ))
155 .collect::<std::result::Result<Vec<CertificateDer>, IoError>>()
156 .context(InternalIoSnafu)?;
157
158 let mut key_reader = BufReader::new(
159 File::open(&self.key_path)
160 .inspect_err(|e| error!(e; "Failed to open {}", self.key_path))
161 .context(InternalIoSnafu)?,
162 );
163 let key = match read_one(&mut key_reader)
164 .inspect_err(|e| error!(e; "Failed to read {}", self.key_path))
165 .context(InternalIoSnafu)?
166 {
167 Some(Item::Pkcs1Key(key)) => PrivateKeyDer::from(key),
168 Some(Item::Pkcs8Key(key)) => PrivateKeyDer::from(key),
169 Some(Item::Sec1Key(key)) => PrivateKeyDer::from(key),
170 _ => {
171 return Err(IoError::new(ErrorKind::InvalidInput, "invalid key"))
172 .context(InternalIoSnafu);
173 }
174 };
175
176 let config = ServerConfig::builder()
178 .with_no_client_auth()
179 .with_single_cert(cert, key)
180 .map_err(|err| std::io::Error::new(ErrorKind::InvalidInput, err))?;
181
182 Ok(Some(config))
183 }
184
185 pub fn should_force_tls(&self) -> bool {
186 !matches!(self.mode, TlsMode::Disable | TlsMode::Prefer)
187 }
188
189 pub fn cert_path(&self) -> &Path {
190 Path::new(&self.cert_path)
191 }
192
193 pub fn key_path(&self) -> &Path {
194 Path::new(&self.key_path)
195 }
196
197 pub fn watch_enabled(&self) -> bool {
198 self.mode != TlsMode::Disable && self.watch
199 }
200}
201
202pub fn merge_tls_option(main: &TlsOption, other: TlsOption) -> TlsOption {
203 if other.mode != TlsMode::Disable && other.validate().is_ok() {
204 return other;
205 }
206 main.clone()
207}
208
209impl TlsConfigLoader<Arc<ServerConfig>> for TlsOption {
210 type Error = crate::error::Error;
211
212 fn load(&self) -> Result<Option<Arc<ServerConfig>>> {
213 Ok(self.setup()?.map(Arc::new))
214 }
215
216 fn watch_paths(&self) -> Vec<&Path> {
217 vec![self.cert_path(), self.key_path()]
218 }
219
220 fn watch_enabled(&self) -> bool {
221 self.mode != TlsMode::Disable && self.watch
222 }
223}
224
/// A [`ReloadableTlsConfig`] holding an `Arc<ServerConfig>` that is loaded from a
/// [`TlsOption`].
pub type ReloadableTlsServerConfig = ReloadableTlsConfig<Arc<ServerConfig>, TlsOption>;
227
228pub fn maybe_watch_server_tls_config(
230 tls_server_config: Arc<ReloadableTlsServerConfig>,
231) -> Result<()> {
232 common_grpc::reloadable_tls::maybe_watch_tls_config(tls_server_config, || {}).map_err(|e| {
233 crate::error::Error::Internal {
234 err_msg: format!("Failed to watch TLS config: {}", e),
235 }
236 })
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use crate::install_ring_crypto_provider;
    use crate::tls::TlsMode::Disable;

    /// Builds a non-watching [`TlsOption`]; an empty string leaves the path unset.
    fn option_with(mode: TlsMode, cert: &str, key: &str, ca_cert: &str) -> TlsOption {
        TlsOption {
            mode,
            cert_path: cert.to_string(),
            key_path: key.to_string(),
            ca_cert_path: ca_cert.to_string(),
            watch: false,
        }
    }

    /// Asserts that validation fails and that the error message mentions `field`.
    #[track_caller]
    fn assert_rejected_for(tls: TlsOption, field: &str) {
        let err = tls.validate().unwrap_err();
        assert!(err.to_string().contains(field));
    }

    /// Parses `json` and checks mode, enforcement, watch flag and that both paths are set.
    #[track_caller]
    fn assert_parsed(json: &str, mode: TlsMode, force_tls: bool, watch: bool) {
        let t: TlsOption = serde_json::from_str(json).unwrap();

        assert_eq!(force_tls, t.should_force_tls());

        assert_eq!(mode, t.mode);
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert_eq!(watch, t.watch_enabled());
    }

    #[test]
    fn test_validate_disable_mode() {
        let tls = option_with(TlsMode::Disable, "", "", "");
        assert!(tls.validate().is_ok());
    }

    #[test]
    fn test_validate_missing_cert_path() {
        assert_rejected_for(
            option_with(TlsMode::Require, "", "/path/to/key", ""),
            "cert_path",
        );
    }

    #[test]
    fn test_validate_missing_key_path() {
        assert_rejected_for(
            option_with(TlsMode::Require, "/path/to/cert", "", ""),
            "key_path",
        );
    }

    #[test]
    fn test_validate_require_mode_success() {
        let tls = option_with(TlsMode::Require, "/path/to/cert", "/path/to/key", "");
        assert!(tls.validate().is_ok());
    }

    #[test]
    fn test_validate_verify_ca_missing_ca_cert() {
        assert_rejected_for(
            option_with(TlsMode::VerifyCa, "/path/to/cert", "/path/to/key", ""),
            "ca_cert_path",
        );
    }

    #[test]
    fn test_validate_verify_full_missing_ca_cert() {
        assert_rejected_for(
            option_with(TlsMode::VerifyFull, "/path/to/cert", "/path/to/key", ""),
            "ca_cert_path",
        );
    }

    #[test]
    fn test_validate_verify_ca_success() {
        let tls = option_with(
            TlsMode::VerifyCa,
            "/path/to/cert",
            "/path/to/key",
            "/path/to/ca",
        );
        assert!(tls.validate().is_ok());
    }

    #[test]
    fn test_validate_verify_full_success() {
        let tls = option_with(
            TlsMode::VerifyFull,
            "/path/to/cert",
            "/path/to/key",
            "/path/to/ca",
        );
        assert!(tls.validate().is_ok());
    }

    #[test]
    fn test_validate_prefer_mode() {
        let tls = option_with(TlsMode::Prefer, "/path/to/cert", "/path/to/key", "");
        assert!(tls.validate().is_ok());
    }

    #[test]
    fn test_new_tls_option() {
        // Nothing supplied: identical to the default.
        assert_eq!(
            TlsOption::default(),
            TlsOption::new(None, None, None, false)
        );

        // Only the mode supplied.
        let mode_only = TlsOption {
            mode: Disable,
            ..Default::default()
        };
        assert_eq!(mode_only, TlsOption::new(Some(Disable), None, None, false));

        // Mode and both paths supplied.
        let with_paths = option_with(Disable, "/path/to/cert_path", "/path/to/key_path", "");
        assert_eq!(
            with_paths,
            TlsOption::new(
                Some(Disable),
                Some("/path/to/cert_path".to_string()),
                Some("/path/to/key_path".to_string()),
                false,
            )
        );
    }

    #[test]
    fn test_tls_option_disable() {
        let s = r#"
        {
            "mode": "disable"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(!t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::Disable));
        assert!(t.key_path.is_empty());
        assert!(t.cert_path.is_empty());
        assert!(!t.watch_enabled());

        // A disabled option yields no server config at all.
        let config = t.setup().unwrap();
        assert!(config.is_none());
    }

    #[test]
    fn test_tls_option_prefer() {
        let s = r#"
        {
            "mode": "prefer",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        assert_parsed(s, TlsMode::Prefer, false, false);
    }

    #[test]
    fn test_tls_option_require() {
        let s = r#"
        {
            "mode": "require",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        assert_parsed(s, TlsMode::Require, true, false);
    }

    #[test]
    fn test_tls_option_verify_ca() {
        let s = r#"
        {
            "mode": "verify_ca",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        assert_parsed(s, TlsMode::VerifyCa, true, false);
    }

    #[test]
    fn test_tls_option_verify_full() {
        let s = r#"
        {
            "mode": "verify_full",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        assert_parsed(s, TlsMode::VerifyFull, true, false);
    }

    #[test]
    fn test_tls_option_watch_enabled() {
        let s = r#"
        {
            "mode": "verify_full",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key",
            "watch": true
        }
        "#;

        assert_parsed(s, TlsMode::VerifyFull, true, true);
    }

    #[test]
    fn test_tls_file_change_watch() {
        common_telemetry::init_default_ut_logging();
        let _ = install_ring_crypto_provider();

        // Work on copies inside a temp dir so the fixtures stay untouched.
        let dir = tempfile::tempdir().unwrap();
        let cert_path = dir.path().join("server.crt");
        let key_path = dir.path().join("server.key");

        std::fs::copy("tests/ssl/server.crt", &cert_path).expect("failed to copy cert to tmpdir");
        std::fs::copy("tests/ssl/server-rsa.key", &key_path).expect("failed to copy key to tmpdir");

        assert!(std::fs::exists(&cert_path).unwrap());
        assert!(std::fs::exists(&key_path).unwrap());

        let cert_path_string = cert_path
            .to_str()
            .expect("failed to convert path to string")
            .to_string();
        let key_path_string = key_path
            .to_str()
            .expect("failed to convert path to string")
            .to_string();

        let server_tls = TlsOption {
            mode: TlsMode::Require,
            cert_path: cert_path_string,
            key_path: key_path_string,
            ca_cert_path: String::new(),
            watch: true,
        };

        let server_config = Arc::new(
            ReloadableTlsServerConfig::try_new(server_tls).expect("failed to create server config"),
        );
        maybe_watch_server_tls_config(server_config.clone())
            .expect("failed to watch server config");

        assert_eq!(0, server_config.get_version());
        assert!(server_config.get_config().is_some());

        // Swap the key in via copy-then-rename.
        let tmp_file = key_path.with_extension("tmp");
        std::fs::copy("tests/ssl/server-pkcs8.key", &tmp_file)
            .expect("Failed to copy temp key file");
        std::fs::rename(&tmp_file, &key_path).expect("Failed to rename temp key file");

        // Poll for the version bump: up to MAX_RETRIES checks, 100 ms apart.
        const MAX_RETRIES: usize = 30;
        let version_updated = (0..MAX_RETRIES).any(|_| {
            if server_config.get_version() > 0 {
                return true;
            }
            std::thread::sleep(std::time::Duration::from_millis(100));
            false
        });

        assert!(version_updated, "TLS config did not reload in time");
        assert!(server_config.get_version() > 0);
        assert!(server_config.get_config().is_some());
    }
}