servers/
tls.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fs::File;
16use std::io::{BufReader, Error as IoError, ErrorKind};
17use std::path::Path;
18use std::sync::Arc;
19
20use common_grpc::reloadable_tls::{ReloadableTlsConfig, TlsConfigLoader};
21use common_telemetry::error;
22use rustls::ServerConfig;
23use rustls_pemfile::{Item, certs, read_one};
24use rustls_pki_types::{CertificateDer, PrivateKeyDer};
25use serde::{Deserialize, Serialize};
26use snafu::ResultExt;
27use strum::EnumString;
28
29use crate::error::{InternalIoSnafu, Result};
30
31/// TlsMode is used for Mysql and Postgres server start up.
32#[derive(Debug, Default, Serialize, Deserialize, Clone, Copy, PartialEq, Eq, EnumString)]
33#[serde(rename_all = "snake_case")]
34pub enum TlsMode {
35    #[default]
36    #[strum(to_string = "disable")]
37    Disable,
38
39    #[strum(to_string = "prefer")]
40    Prefer,
41
42    #[strum(to_string = "require")]
43    Require,
44
45    // TODO(SSebo): Implement the following 2 TSL mode described in
46    // ["34.19.3. Protection Provided in Different Modes"](https://www.postgresql.org/docs/current/libpq-ssl.html)
47    #[strum(to_string = "verify-ca")]
48    VerifyCa,
49
50    #[strum(to_string = "verify-full")]
51    VerifyFull,
52}
53
54#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq, Eq)]
55#[serde(rename_all = "snake_case")]
56pub struct TlsOption {
57    pub mode: TlsMode,
58    #[serde(default)]
59    pub cert_path: String,
60    #[serde(default)]
61    pub key_path: String,
62    #[serde(default)]
63    pub ca_cert_path: String,
64    #[serde(default)]
65    pub watch: bool,
66}
67
68impl TlsOption {
69    pub fn new(
70        mode: Option<TlsMode>,
71        cert_path: Option<String>,
72        key_path: Option<String>,
73        watch: bool,
74    ) -> Self {
75        let mut tls_option = TlsOption::default();
76
77        if let Some(mode) = mode {
78            tls_option.mode = mode
79        };
80
81        if let Some(cert_path) = cert_path {
82            tls_option.cert_path = cert_path
83        };
84
85        if let Some(key_path) = key_path {
86            tls_option.key_path = key_path
87        };
88
89        tls_option.watch = watch;
90
91        tls_option
92    }
93
94    /// Creates a new TLS option with the prefer mode.
95    pub fn prefer() -> Self {
96        Self {
97            mode: TlsMode::Prefer,
98            cert_path: String::new(),
99            key_path: String::new(),
100            ca_cert_path: String::new(),
101            watch: false,
102        }
103    }
104
105    /// Validates the TLS configuration.
106    ///
107    /// Returns an error if:
108    /// - TLS mode is enabled (not `Disable`) but `cert_path` or `key_path` is empty
109    /// - TLS mode is `VerifyCa` or `VerifyFull` but `ca_cert_path` is empty
110    pub fn validate(&self) -> Result<()> {
111        if self.mode == TlsMode::Disable {
112            return Ok(());
113        }
114
115        // When TLS is enabled, cert_path and key_path are required
116        if self.cert_path.is_empty() {
117            return Err(crate::error::Error::Internal {
118                err_msg: format!(
119                    "TLS mode is {:?} but cert_path is not configured",
120                    self.mode
121                ),
122            });
123        }
124
125        if self.key_path.is_empty() {
126            return Err(crate::error::Error::Internal {
127                err_msg: format!("TLS mode is {:?} but key_path is not configured", self.mode),
128            });
129        }
130
131        // For VerifyCa and VerifyFull modes, ca_cert_path is required for client verification
132        if matches!(self.mode, TlsMode::VerifyCa | TlsMode::VerifyFull)
133            && self.ca_cert_path.is_empty()
134        {
135            return Err(crate::error::Error::Internal {
136                err_msg: format!(
137                    "TLS mode is {:?} but ca_cert_path is not configured",
138                    self.mode
139                ),
140            });
141        }
142
143        Ok(())
144    }
145
146    pub fn setup(&self) -> Result<Option<ServerConfig>> {
147        if let TlsMode::Disable = self.mode {
148            return Ok(None);
149        }
150        let cert = certs(&mut BufReader::new(
151            File::open(&self.cert_path)
152                .inspect_err(|e| error!(e; "Failed to open {}", self.cert_path))
153                .context(InternalIoSnafu)?,
154        ))
155        .collect::<std::result::Result<Vec<CertificateDer>, IoError>>()
156        .context(InternalIoSnafu)?;
157
158        let mut key_reader = BufReader::new(
159            File::open(&self.key_path)
160                .inspect_err(|e| error!(e; "Failed to open {}", self.key_path))
161                .context(InternalIoSnafu)?,
162        );
163        let key = match read_one(&mut key_reader)
164            .inspect_err(|e| error!(e; "Failed to read {}", self.key_path))
165            .context(InternalIoSnafu)?
166        {
167            Some(Item::Pkcs1Key(key)) => PrivateKeyDer::from(key),
168            Some(Item::Pkcs8Key(key)) => PrivateKeyDer::from(key),
169            Some(Item::Sec1Key(key)) => PrivateKeyDer::from(key),
170            _ => {
171                return Err(IoError::new(ErrorKind::InvalidInput, "invalid key"))
172                    .context(InternalIoSnafu);
173            }
174        };
175
176        // TODO(SSebo): with_client_cert_verifier if TlsMode is Required.
177        let config = ServerConfig::builder()
178            .with_no_client_auth()
179            .with_single_cert(cert, key)
180            .map_err(|err| std::io::Error::new(ErrorKind::InvalidInput, err))?;
181
182        Ok(Some(config))
183    }
184
185    pub fn should_force_tls(&self) -> bool {
186        !matches!(self.mode, TlsMode::Disable | TlsMode::Prefer)
187    }
188
189    pub fn cert_path(&self) -> &Path {
190        Path::new(&self.cert_path)
191    }
192
193    pub fn key_path(&self) -> &Path {
194        Path::new(&self.key_path)
195    }
196
197    pub fn watch_enabled(&self) -> bool {
198        self.mode != TlsMode::Disable && self.watch
199    }
200}
201
202pub fn merge_tls_option(main: &TlsOption, other: TlsOption) -> TlsOption {
203    if other.mode != TlsMode::Disable && other.validate().is_ok() {
204        return other;
205    }
206    main.clone()
207}
208
209impl TlsConfigLoader<Arc<ServerConfig>> for TlsOption {
210    type Error = crate::error::Error;
211
212    fn load(&self) -> Result<Option<Arc<ServerConfig>>> {
213        Ok(self.setup()?.map(Arc::new))
214    }
215
216    fn watch_paths(&self) -> Vec<&Path> {
217        vec![self.cert_path(), self.key_path()]
218    }
219
220    fn watch_enabled(&self) -> bool {
221        self.mode != TlsMode::Disable && self.watch
222    }
223}
224
225/// Type alias for server-side reloadable TLS config
226pub type ReloadableTlsServerConfig = ReloadableTlsConfig<Arc<ServerConfig>, TlsOption>;
227
228/// Convenience function for watching server TLS configuration
229pub fn maybe_watch_server_tls_config(
230    tls_server_config: Arc<ReloadableTlsServerConfig>,
231) -> Result<()> {
232    common_grpc::reloadable_tls::maybe_watch_tls_config(tls_server_config, || {}).map_err(|e| {
233        crate::error::Error::Internal {
234            err_msg: format!("Failed to watch TLS config: {}", e),
235        }
236    })
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use crate::install_ring_crypto_provider;
    use crate::tls::TlsMode::Disable;

    // --- `TlsOption::validate` -------------------------------------------------

    // Disable mode skips all path checks, even when every path is empty.
    #[test]
    fn test_validate_disable_mode() {
        let tls = TlsOption {
            mode: TlsMode::Disable,
            cert_path: String::new(),
            key_path: String::new(),
            ca_cert_path: String::new(),
            watch: false,
        };
        assert!(tls.validate().is_ok());
    }

    // An enabled mode with an empty cert_path must be rejected, naming the field.
    #[test]
    fn test_validate_missing_cert_path() {
        let tls = TlsOption {
            mode: TlsMode::Require,
            cert_path: String::new(),
            key_path: "/path/to/key".to_string(),
            ca_cert_path: String::new(),
            watch: false,
        };
        let err = tls.validate().unwrap_err();
        assert!(err.to_string().contains("cert_path"));
    }

    // An enabled mode with an empty key_path must be rejected, naming the field.
    #[test]
    fn test_validate_missing_key_path() {
        let tls = TlsOption {
            mode: TlsMode::Require,
            cert_path: "/path/to/cert".to_string(),
            key_path: String::new(),
            ca_cert_path: String::new(),
            watch: false,
        };
        let err = tls.validate().unwrap_err();
        assert!(err.to_string().contains("key_path"));
    }

    // Require only needs cert and key paths; ca_cert_path may stay empty.
    // Paths are not opened by validate, so non-existent paths are fine here.
    #[test]
    fn test_validate_require_mode_success() {
        let tls = TlsOption {
            mode: TlsMode::Require,
            cert_path: "/path/to/cert".to_string(),
            key_path: "/path/to/key".to_string(),
            ca_cert_path: String::new(),
            watch: false,
        };
        assert!(tls.validate().is_ok());
    }

    // VerifyCa additionally requires ca_cert_path.
    #[test]
    fn test_validate_verify_ca_missing_ca_cert() {
        let tls = TlsOption {
            mode: TlsMode::VerifyCa,
            cert_path: "/path/to/cert".to_string(),
            key_path: "/path/to/key".to_string(),
            ca_cert_path: String::new(),
            watch: false,
        };
        let err = tls.validate().unwrap_err();
        assert!(err.to_string().contains("ca_cert_path"));
    }

    // VerifyFull additionally requires ca_cert_path.
    #[test]
    fn test_validate_verify_full_missing_ca_cert() {
        let tls = TlsOption {
            mode: TlsMode::VerifyFull,
            cert_path: "/path/to/cert".to_string(),
            key_path: "/path/to/key".to_string(),
            ca_cert_path: String::new(),
            watch: false,
        };
        let err = tls.validate().unwrap_err();
        assert!(err.to_string().contains("ca_cert_path"));
    }

    #[test]
    fn test_validate_verify_ca_success() {
        let tls = TlsOption {
            mode: TlsMode::VerifyCa,
            cert_path: "/path/to/cert".to_string(),
            key_path: "/path/to/key".to_string(),
            ca_cert_path: "/path/to/ca".to_string(),
            watch: false,
        };
        assert!(tls.validate().is_ok());
    }

    #[test]
    fn test_validate_verify_full_success() {
        let tls = TlsOption {
            mode: TlsMode::VerifyFull,
            cert_path: "/path/to/cert".to_string(),
            key_path: "/path/to/key".to_string(),
            ca_cert_path: "/path/to/ca".to_string(),
            watch: false,
        };
        assert!(tls.validate().is_ok());
    }

    // Prefer is an enabled mode, so it follows the same cert/key rules as Require.
    #[test]
    fn test_validate_prefer_mode() {
        let tls = TlsOption {
            mode: TlsMode::Prefer,
            cert_path: "/path/to/cert".to_string(),
            key_path: "/path/to/key".to_string(),
            ca_cert_path: String::new(),
            watch: false,
        };
        assert!(tls.validate().is_ok());
    }

    // --- Construction -----------------------------------------------------------

    // `TlsOption::new` with `None`s must match `Default`; `Some` values override fields.
    #[test]
    fn test_new_tls_option() {
        assert_eq!(
            TlsOption::default(),
            TlsOption::new(None, None, None, false)
        );
        assert_eq!(
            TlsOption {
                mode: Disable,
                ..Default::default()
            },
            TlsOption::new(Some(Disable), None, None, false)
        );
        assert_eq!(
            TlsOption {
                mode: Disable,
                cert_path: "/path/to/cert_path".to_string(),
                key_path: "/path/to/key_path".to_string(),
                ca_cert_path: String::new(),
                watch: false,
            },
            TlsOption::new(
                Some(Disable),
                Some("/path/to/cert_path".to_string()),
                Some("/path/to/key_path".to_string()),
                false,
            )
        );
    }

    // --- Deserialization and derived flags ---------------------------------------

    // Disable: paths default to empty, TLS is not forced, and setup yields no config.
    #[test]
    fn test_tls_option_disable() {
        let s = r#"
        {
            "mode": "disable"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(!t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::Disable));
        assert!(t.key_path.is_empty());
        assert!(t.cert_path.is_empty());
        assert!(!t.watch_enabled());

        let setup = t.setup();
        let setup = setup.unwrap();
        assert!(setup.is_none());
    }

    // Prefer does not force TLS; setup is not called because the paths do not exist.
    #[test]
    fn test_tls_option_prefer() {
        let s = r#"
        {
            "mode": "prefer",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(!t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::Prefer));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(!t.watch_enabled());
    }

    #[test]
    fn test_tls_option_require() {
        let s = r#"
        {
            "mode": "require",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::Require));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(!t.watch_enabled());
    }

    // serde uses the snake_case spelling ("verify_ca") because of `rename_all`.
    #[test]
    fn test_tls_option_verify_ca() {
        let s = r#"
        {
            "mode": "verify_ca",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::VerifyCa));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(!t.watch_enabled());
    }

    #[test]
    fn test_tls_option_verify_full() {
        let s = r#"
        {
            "mode": "verify_full",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::VerifyFull));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(!t.watch_enabled());
    }

    // `watch: true` combined with an enabled mode turns watching on.
    #[test]
    fn test_tls_option_watch_enabled() {
        let s = r#"
        {
            "mode": "verify_full",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key",
            "watch": true
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::VerifyFull));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(t.watch_enabled());
    }

    // --- Hot reload ----------------------------------------------------------------

    // End-to-end reload check using the fixtures under tests/ssl:
    // 1. Load a config from an RSA key with watching enabled (version 0).
    // 2. Replace the key file with a PKCS#8 key.
    // 3. Poll until the reloadable config reports a newer version.
    #[test]
    fn test_tls_file_change_watch() {
        common_telemetry::init_default_ut_logging();
        // Result ignored: presumably the provider may already be installed by another test.
        let _ = install_ring_crypto_provider();

        // Work on copies in a temp dir so the fixtures themselves are never modified.
        let dir = tempfile::tempdir().unwrap();
        let cert_path = dir.path().join("server.crt");
        let key_path = dir.path().join("server.key");

        std::fs::copy("tests/ssl/server.crt", &cert_path).expect("failed to copy cert to tmpdir");
        std::fs::copy("tests/ssl/server-rsa.key", &key_path).expect("failed to copy key to tmpdir");

        assert!(std::fs::exists(&cert_path).unwrap());
        assert!(std::fs::exists(&key_path).unwrap());

        let server_tls = TlsOption {
            mode: TlsMode::Require,
            cert_path: cert_path
                .clone()
                .into_os_string()
                .into_string()
                .expect("failed to convert path to string"),
            key_path: key_path
                .clone()
                .into_os_string()
                .into_string()
                .expect("failed to convert path to string"),
            ca_cert_path: String::new(),
            watch: true,
        };

        let server_config = Arc::new(
            ReloadableTlsServerConfig::try_new(server_tls).expect("failed to create server config"),
        );
        maybe_watch_server_tls_config(server_config.clone())
            .expect("failed to watch server config");

        // Freshly loaded config: version 0 and a config present.
        assert_eq!(0, server_config.get_version());
        assert!(server_config.get_config().is_some());

        // Write the new key to a sibling temp file, then rename it over the watched key.
        // Presumably this keeps the watcher from observing a half-written key file.
        let tmp_file = key_path.with_extension("tmp");
        std::fs::copy("tests/ssl/server-pkcs8.key", &tmp_file)
            .expect("Failed to copy temp key file");
        std::fs::rename(&tmp_file, &key_path).expect("Failed to rename temp key file");

        // Poll for up to ~3 seconds (30 x 100 ms) for the watcher to reload.
        const MAX_RETRIES: usize = 30;
        let mut retries = 0;
        let mut version_updated = false;

        while retries < MAX_RETRIES {
            if server_config.get_version() > 0 {
                version_updated = true;
                break;
            }
            std::thread::sleep(std::time::Duration::from_millis(100));
            retries += 1;
        }

        assert!(version_updated, "TLS config did not reload in time");
        assert!(server_config.get_version() > 0);
        assert!(server_config.get_config().is_some());
    }
}