servers/
tls.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fs::File;
16use std::io::{BufReader, Error as IoError, ErrorKind};
17use std::path::Path;
18use std::sync::Arc;
19
20use common_grpc::reloadable_tls::{ReloadableTlsConfig, TlsConfigLoader};
21use common_telemetry::error;
22use rustls::ServerConfig;
23use rustls_pemfile::{Item, certs, read_one};
24use rustls_pki_types::{CertificateDer, PrivateKeyDer};
25use serde::{Deserialize, Serialize};
26use snafu::ResultExt;
27use strum::EnumString;
28
29use crate::error::{InternalIoSnafu, Result};
30
31/// TlsMode is used for Mysql and Postgres server start up.
32#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq, Eq, EnumString)]
33#[serde(rename_all = "snake_case")]
34pub enum TlsMode {
35    #[default]
36    #[strum(to_string = "disable")]
37    Disable,
38
39    #[strum(to_string = "prefer")]
40    Prefer,
41
42    #[strum(to_string = "require")]
43    Require,
44
45    // TODO(SSebo): Implement the following 2 TSL mode described in
46    // ["34.19.3. Protection Provided in Different Modes"](https://www.postgresql.org/docs/current/libpq-ssl.html)
47    #[strum(to_string = "verify-ca")]
48    VerifyCa,
49
50    #[strum(to_string = "verify-full")]
51    VerifyFull,
52}
53
/// TLS settings used to build a rustls server configuration (see `TlsOption::setup`).
///
/// `mode` has no `#[serde(default)]`, so it must be present when deserializing; every other
/// field falls back to its `Default` value (empty string / `false`) when omitted.
#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub struct TlsOption {
    /// Whether TLS is used and whether it is mandatory; see [`TlsMode`].
    pub mode: TlsMode,
    /// Path to the PEM file holding the server certificate chain.
    #[serde(default)]
    pub cert_path: String,
    /// Path to the PEM file holding the private key (PKCS#1, PKCS#8 or SEC1).
    #[serde(default)]
    pub key_path: String,
    /// Path to a CA certificate.
    // NOTE(review): not read by `TlsOption::setup` in this file; presumably reserved for
    // client-certificate verification (`verify_ca` / `verify_full`) — confirm.
    #[serde(default)]
    pub ca_cert_path: String,
    /// Reload the certificate/key when the files change (ignored when `mode` is `disable`).
    #[serde(default)]
    pub watch: bool,
}
67
68impl TlsOption {
69    pub fn new(
70        mode: Option<TlsMode>,
71        cert_path: Option<String>,
72        key_path: Option<String>,
73        watch: bool,
74    ) -> Self {
75        let mut tls_option = TlsOption::default();
76
77        if let Some(mode) = mode {
78            tls_option.mode = mode
79        };
80
81        if let Some(cert_path) = cert_path {
82            tls_option.cert_path = cert_path
83        };
84
85        if let Some(key_path) = key_path {
86            tls_option.key_path = key_path
87        };
88
89        tls_option.watch = watch;
90
91        tls_option
92    }
93
94    pub fn setup(&self) -> Result<Option<ServerConfig>> {
95        if let TlsMode::Disable = self.mode {
96            return Ok(None);
97        }
98        let cert = certs(&mut BufReader::new(
99            File::open(&self.cert_path)
100                .inspect_err(|e| error!(e; "Failed to open {}", self.cert_path))
101                .context(InternalIoSnafu)?,
102        ))
103        .collect::<std::result::Result<Vec<CertificateDer>, IoError>>()
104        .context(InternalIoSnafu)?;
105
106        let mut key_reader = BufReader::new(
107            File::open(&self.key_path)
108                .inspect_err(|e| error!(e; "Failed to open {}", self.key_path))
109                .context(InternalIoSnafu)?,
110        );
111        let key = match read_one(&mut key_reader)
112            .inspect_err(|e| error!(e; "Failed to read {}", self.key_path))
113            .context(InternalIoSnafu)?
114        {
115            Some(Item::Pkcs1Key(key)) => PrivateKeyDer::from(key),
116            Some(Item::Pkcs8Key(key)) => PrivateKeyDer::from(key),
117            Some(Item::Sec1Key(key)) => PrivateKeyDer::from(key),
118            _ => {
119                return Err(IoError::new(ErrorKind::InvalidInput, "invalid key"))
120                    .context(InternalIoSnafu);
121            }
122        };
123
124        // TODO(SSebo): with_client_cert_verifier if TlsMode is Required.
125        let config = ServerConfig::builder()
126            .with_no_client_auth()
127            .with_single_cert(cert, key)
128            .map_err(|err| std::io::Error::new(ErrorKind::InvalidInput, err))?;
129
130        Ok(Some(config))
131    }
132
133    pub fn should_force_tls(&self) -> bool {
134        !matches!(self.mode, TlsMode::Disable | TlsMode::Prefer)
135    }
136
137    pub fn cert_path(&self) -> &Path {
138        Path::new(&self.cert_path)
139    }
140
141    pub fn key_path(&self) -> &Path {
142        Path::new(&self.key_path)
143    }
144
145    pub fn watch_enabled(&self) -> bool {
146        self.mode != TlsMode::Disable && self.watch
147    }
148}
149
150impl TlsConfigLoader<Arc<ServerConfig>> for TlsOption {
151    type Error = crate::error::Error;
152
153    fn load(&self) -> Result<Option<Arc<ServerConfig>>> {
154        Ok(self.setup()?.map(Arc::new))
155    }
156
157    fn watch_paths(&self) -> Vec<&Path> {
158        vec![self.cert_path(), self.key_path()]
159    }
160
161    fn watch_enabled(&self) -> bool {
162        self.mode != TlsMode::Disable && self.watch
163    }
164}
165
/// Type alias for server-side reloadable TLS config.
///
/// Pairs the shared rustls [`ServerConfig`] with the [`TlsOption`] that (re)loads it through
/// its `TlsConfigLoader<Arc<ServerConfig>>` implementation.
pub type ReloadableTlsServerConfig = ReloadableTlsConfig<Arc<ServerConfig>, TlsOption>;
168
169/// Convenience function for watching server TLS configuration
170pub fn maybe_watch_server_tls_config(
171    tls_server_config: Arc<ReloadableTlsServerConfig>,
172) -> Result<()> {
173    common_grpc::reloadable_tls::maybe_watch_tls_config(tls_server_config, || {}).map_err(|e| {
174        crate::error::Error::Internal {
175            err_msg: format!("Failed to watch TLS config: {}", e),
176        }
177    })
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use crate::install_ring_crypto_provider;
    use crate::tls::TlsMode::Disable;

    /// `TlsOption::new` with no overrides equals the default, and given values are copied in.
    #[test]
    fn test_new_tls_option() {
        assert_eq!(
            TlsOption::default(),
            TlsOption::new(None, None, None, false)
        );
        assert_eq!(
            TlsOption {
                mode: Disable,
                ..Default::default()
            },
            TlsOption::new(Some(Disable), None, None, false)
        );
        assert_eq!(
            TlsOption {
                mode: Disable,
                cert_path: "/path/to/cert_path".to_string(),
                key_path: "/path/to/key_path".to_string(),
                ca_cert_path: String::new(),
                watch: false
            },
            TlsOption::new(
                Some(Disable),
                Some("/path/to/cert_path".to_string()),
                Some("/path/to/key_path".to_string()),
                false
            )
        );
    }

    /// `disable` deserializes with empty paths, does not force TLS, and `setup` yields `None`.
    #[test]
    fn test_tls_option_disable() {
        let s = r#"
        {
            "mode": "disable"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(!t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::Disable));
        assert!(t.key_path.is_empty());
        assert!(t.cert_path.is_empty());
        assert!(!t.watch_enabled());

        let setup = t.setup();
        let setup = setup.unwrap();
        assert!(setup.is_none());
    }

    /// `prefer` does not force TLS; watching stays off when `watch` is omitted.
    #[test]
    fn test_tls_option_prefer() {
        let s = r#"
        {
            "mode": "prefer",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(!t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::Prefer));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(!t.watch_enabled());
    }

    /// `require` forces TLS.
    #[test]
    fn test_tls_option_require() {
        let s = r#"
        {
            "mode": "require",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::Require));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(!t.watch_enabled());
    }

    /// `verify_ca` (serde snake_case spelling) forces TLS.
    #[test]
    fn test_tls_option_verify_ca() {
        let s = r#"
        {
            "mode": "verify_ca",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::VerifyCa));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(!t.watch_enabled());
    }

    /// `verify_full` (serde snake_case spelling) forces TLS.
    #[test]
    fn test_tls_option_verify_full() {
        let s = r#"
        {
            "mode": "verify_full",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key"
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::VerifyFull));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(!t.watch_enabled());
    }

    /// `watch: true` with a non-disabled mode enables file watching.
    #[test]
    fn test_tls_option_watch_enabled() {
        let s = r#"
        {
            "mode": "verify_full",
            "cert_path": "/some_dir/some.crt",
            "key_path": "/some_dir/some.key",
            "watch": true
        }
        "#;

        let t: TlsOption = serde_json::from_str(s).unwrap();

        assert!(t.should_force_tls());

        assert!(matches!(t.mode, TlsMode::VerifyFull));
        assert!(!t.key_path.is_empty());
        assert!(!t.cert_path.is_empty());
        assert!(t.watch_enabled());
    }

    /// Replacing the watched key file triggers a reload that bumps the config version.
    #[test]
    fn test_tls_file_change_watch() {
        common_telemetry::init_default_ut_logging();
        // Result ignored: presumably installing fails harmlessly if another test already did.
        let _ = install_ring_crypto_provider();

        // Work on copies in a temp dir so the fixtures under tests/ssl are never modified.
        let dir = tempfile::tempdir().unwrap();
        let cert_path = dir.path().join("server.crt");
        let key_path = dir.path().join("server.key");

        std::fs::copy("tests/ssl/server.crt", &cert_path).expect("failed to copy cert to tmpdir");
        std::fs::copy("tests/ssl/server-rsa.key", &key_path).expect("failed to copy key to tmpdir");

        assert!(std::fs::exists(&cert_path).unwrap());
        assert!(std::fs::exists(&key_path).unwrap());

        let server_tls = TlsOption {
            mode: TlsMode::Require,
            cert_path: cert_path
                .clone()
                .into_os_string()
                .into_string()
                .expect("failed to convert path to string"),
            key_path: key_path
                .clone()
                .into_os_string()
                .into_string()
                .expect("failed to convert path to string"),
            ca_cert_path: String::new(),
            watch: true,
        };

        let server_config = Arc::new(
            ReloadableTlsServerConfig::try_new(server_tls).expect("failed to create server config"),
        );
        maybe_watch_server_tls_config(server_config.clone())
            .expect("failed to watch server config");

        // Initial load: version 0 with a config present.
        assert_eq!(0, server_config.get_version());
        assert!(server_config.get_config().is_some());

        // Swap in a PKCS#8 key via copy-to-temp + rename, so the watched path is replaced in
        // one step instead of being written in place.
        let tmp_file = key_path.with_extension("tmp");
        std::fs::copy("tests/ssl/server-pkcs8.key", &tmp_file)
            .expect("Failed to copy temp key file");
        std::fs::rename(&tmp_file, &key_path).expect("Failed to rename temp key file");

        // Poll for up to MAX_RETRIES * 100ms (~3s) for the watcher to bump the version.
        const MAX_RETRIES: usize = 30;
        let mut retries = 0;
        let mut version_updated = false;

        while retries < MAX_RETRIES {
            if server_config.get_version() > 0 {
                version_updated = true;
                break;
            }
            std::thread::sleep(std::time::Duration::from_millis(100));
            retries += 1;
        }

        assert!(version_updated, "TLS config did not reload in time");
        assert!(server_config.get_version() > 0);
        assert!(server_config.get_config().is_some());
    }
}