1use std::fs::File;
16use std::io::{BufReader, Error as IoError, ErrorKind};
17use std::path::Path;
18use std::sync::Arc;
19
20use common_grpc::reloadable_tls::{ReloadableTlsConfig, TlsConfigLoader};
21use common_telemetry::error;
22use rustls::ServerConfig;
23use rustls_pemfile::{Item, certs, read_one};
24use rustls_pki_types::{CertificateDer, PrivateKeyDer};
25use serde::{Deserialize, Serialize};
26use snafu::ResultExt;
27use strum::EnumString;
28
29use crate::error::{InternalIoSnafu, Result};
30
31#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq, Eq, EnumString)]
33#[serde(rename_all = "snake_case")]
34pub enum TlsMode {
35 #[default]
36 #[strum(to_string = "disable")]
37 Disable,
38
39 #[strum(to_string = "prefer")]
40 Prefer,
41
42 #[strum(to_string = "require")]
43 Require,
44
45 #[strum(to_string = "verify-ca")]
48 VerifyCa,
49
50 #[strum(to_string = "verify-full")]
51 VerifyFull,
52}
53
54#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq, Eq)]
55#[serde(rename_all = "snake_case")]
56pub struct TlsOption {
57 pub mode: TlsMode,
58 #[serde(default)]
59 pub cert_path: String,
60 #[serde(default)]
61 pub key_path: String,
62 #[serde(default)]
63 pub ca_cert_path: String,
64 #[serde(default)]
65 pub watch: bool,
66}
67
68impl TlsOption {
69 pub fn new(
70 mode: Option<TlsMode>,
71 cert_path: Option<String>,
72 key_path: Option<String>,
73 watch: bool,
74 ) -> Self {
75 let mut tls_option = TlsOption::default();
76
77 if let Some(mode) = mode {
78 tls_option.mode = mode
79 };
80
81 if let Some(cert_path) = cert_path {
82 tls_option.cert_path = cert_path
83 };
84
85 if let Some(key_path) = key_path {
86 tls_option.key_path = key_path
87 };
88
89 tls_option.watch = watch;
90
91 tls_option
92 }
93
94 pub fn setup(&self) -> Result<Option<ServerConfig>> {
95 if let TlsMode::Disable = self.mode {
96 return Ok(None);
97 }
98 let cert = certs(&mut BufReader::new(
99 File::open(&self.cert_path)
100 .inspect_err(|e| error!(e; "Failed to open {}", self.cert_path))
101 .context(InternalIoSnafu)?,
102 ))
103 .collect::<std::result::Result<Vec<CertificateDer>, IoError>>()
104 .context(InternalIoSnafu)?;
105
106 let mut key_reader = BufReader::new(
107 File::open(&self.key_path)
108 .inspect_err(|e| error!(e; "Failed to open {}", self.key_path))
109 .context(InternalIoSnafu)?,
110 );
111 let key = match read_one(&mut key_reader)
112 .inspect_err(|e| error!(e; "Failed to read {}", self.key_path))
113 .context(InternalIoSnafu)?
114 {
115 Some(Item::Pkcs1Key(key)) => PrivateKeyDer::from(key),
116 Some(Item::Pkcs8Key(key)) => PrivateKeyDer::from(key),
117 Some(Item::Sec1Key(key)) => PrivateKeyDer::from(key),
118 _ => {
119 return Err(IoError::new(ErrorKind::InvalidInput, "invalid key"))
120 .context(InternalIoSnafu);
121 }
122 };
123
124 let config = ServerConfig::builder()
126 .with_no_client_auth()
127 .with_single_cert(cert, key)
128 .map_err(|err| std::io::Error::new(ErrorKind::InvalidInput, err))?;
129
130 Ok(Some(config))
131 }
132
133 pub fn should_force_tls(&self) -> bool {
134 !matches!(self.mode, TlsMode::Disable | TlsMode::Prefer)
135 }
136
137 pub fn cert_path(&self) -> &Path {
138 Path::new(&self.cert_path)
139 }
140
141 pub fn key_path(&self) -> &Path {
142 Path::new(&self.key_path)
143 }
144
145 pub fn watch_enabled(&self) -> bool {
146 self.mode != TlsMode::Disable && self.watch
147 }
148}
149
150impl TlsConfigLoader<Arc<ServerConfig>> for TlsOption {
151 type Error = crate::error::Error;
152
153 fn load(&self) -> Result<Option<Arc<ServerConfig>>> {
154 Ok(self.setup()?.map(Arc::new))
155 }
156
157 fn watch_paths(&self) -> Vec<&Path> {
158 vec![self.cert_path(), self.key_path()]
159 }
160
161 fn watch_enabled(&self) -> bool {
162 self.mode != TlsMode::Disable && self.watch
163 }
164}
165
166pub type ReloadableTlsServerConfig = ReloadableTlsConfig<Arc<ServerConfig>, TlsOption>;
168
169pub fn maybe_watch_server_tls_config(
171 tls_server_config: Arc<ReloadableTlsServerConfig>,
172) -> Result<()> {
173 common_grpc::reloadable_tls::maybe_watch_tls_config(tls_server_config, || {}).map_err(|e| {
174 crate::error::Error::Internal {
175 err_msg: format!("Failed to watch TLS config: {}", e),
176 }
177 })
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use crate::install_ring_crypto_provider;
    use crate::tls::TlsMode::Disable;

    /// Deserializes `json` into a [`TlsOption`]; panics on malformed input.
    fn parse(json: &str) -> TlsOption {
        serde_json::from_str(json).unwrap()
    }

    /// Checks the expectations shared by every fixture that sets both paths.
    fn assert_with_paths(t: &TlsOption, mode: TlsMode, forced: bool, watching: bool) {
        assert_eq!(forced, t.should_force_tls());
        assert_eq!(mode, t.mode);
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert_eq!(watching, t.watch_enabled());
    }

    /// Converts a temp-dir path into the `String` form stored in [`TlsOption`].
    fn path_to_string(path: &Path) -> String {
        path.to_path_buf()
            .into_os_string()
            .into_string()
            .expect("failed to convert path to string")
    }

    #[test]
    fn test_new_tls_option() {
        let defaults = TlsOption::new(None, None, None, false);
        assert_eq!(TlsOption::default(), defaults);

        let disabled = TlsOption::new(Some(Disable), None, None, false);
        assert_eq!(
            TlsOption {
                mode: Disable,
                ..Default::default()
            },
            disabled
        );

        let with_paths = TlsOption::new(
            Some(Disable),
            Some("/path/to/cert_path".to_string()),
            Some("/path/to/key_path".to_string()),
            false,
        );
        let expected = TlsOption {
            mode: Disable,
            cert_path: "/path/to/cert_path".to_string(),
            key_path: "/path/to/key_path".to_string(),
            ca_cert_path: String::new(),
            watch: false,
        };
        assert_eq!(expected, with_paths);
    }

    #[test]
    fn test_tls_option_disable() {
        let t = parse(r#"{ "mode": "disable" }"#);

        assert!(!t.should_force_tls());
        assert_eq!(TlsMode::Disable, t.mode);
        assert!(t.key_path.is_empty());
        assert!(t.cert_path.is_empty());
        assert!(!t.watch_enabled());
        assert!(t.setup().unwrap().is_none());
    }

    #[test]
    fn test_tls_option_prefer() {
        let t = parse(
            r#"{
                "mode": "prefer",
                "cert_path": "/some_dir/some.crt",
                "key_path": "/some_dir/some.key"
            }"#,
        );
        assert_with_paths(&t, TlsMode::Prefer, false, false);
    }

    #[test]
    fn test_tls_option_require() {
        let t = parse(
            r#"{
                "mode": "require",
                "cert_path": "/some_dir/some.crt",
                "key_path": "/some_dir/some.key"
            }"#,
        );
        assert_with_paths(&t, TlsMode::Require, true, false);
    }

    #[test]
    fn test_tls_option_verify_ca() {
        let t = parse(
            r#"{
                "mode": "verify_ca",
                "cert_path": "/some_dir/some.crt",
                "key_path": "/some_dir/some.key"
            }"#,
        );
        assert_with_paths(&t, TlsMode::VerifyCa, true, false);
    }

    #[test]
    fn test_tls_option_verify_full() {
        let t = parse(
            r#"{
                "mode": "verify_full",
                "cert_path": "/some_dir/some.crt",
                "key_path": "/some_dir/some.key"
            }"#,
        );
        assert_with_paths(&t, TlsMode::VerifyFull, true, false);
    }

    #[test]
    fn test_tls_option_watch_enabled() {
        let t = parse(
            r#"{
                "mode": "verify_full",
                "cert_path": "/some_dir/some.crt",
                "key_path": "/some_dir/some.key",
                "watch": true
            }"#,
        );
        assert_with_paths(&t, TlsMode::VerifyFull, true, true);
    }

    #[test]
    fn test_tls_file_change_watch() {
        common_telemetry::init_default_ut_logging();
        let _ = install_ring_crypto_provider();

        let dir = tempfile::tempdir().unwrap();
        let cert_path = dir.path().join("server.crt");
        let key_path = dir.path().join("server.key");

        std::fs::copy("tests/ssl/server.crt", &cert_path).expect("failed to copy cert to tmpdir");
        std::fs::copy("tests/ssl/server-rsa.key", &key_path).expect("failed to copy key to tmpdir");
        assert!(std::fs::exists(&cert_path).unwrap());
        assert!(std::fs::exists(&key_path).unwrap());

        let server_tls = TlsOption {
            mode: TlsMode::Require,
            cert_path: path_to_string(&cert_path),
            key_path: path_to_string(&key_path),
            ca_cert_path: String::new(),
            watch: true,
        };

        let server_config = Arc::new(
            ReloadableTlsServerConfig::try_new(server_tls).expect("failed to create server config"),
        );
        maybe_watch_server_tls_config(server_config.clone())
            .expect("failed to watch server config");

        assert_eq!(0, server_config.get_version());
        assert!(server_config.get_config().is_some());

        // Replace the key via copy-then-rename, presumably so the watcher never sees a
        // partially written file.
        let tmp_file = key_path.with_extension("tmp");
        std::fs::copy("tests/ssl/server-pkcs8.key", &tmp_file)
            .expect("Failed to copy temp key file");
        std::fs::rename(&tmp_file, &key_path).expect("Failed to rename temp key file");

        // Poll up to MAX_RETRIES times, sleeping 100ms after each miss.
        const MAX_RETRIES: usize = 30;
        let version_updated = (0..MAX_RETRIES).any(|_| {
            if server_config.get_version() > 0 {
                return true;
            }
            std::thread::sleep(std::time::Duration::from_millis(100));
            false
        });

        assert!(version_updated, "TLS config did not reload in time");
        assert!(server_config.get_version() > 0);
        assert!(server_config.get_config().is_some());
    }
}