common_grpc/
reloadable_tls.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::path::Path;
16use std::result::Result as StdResult;
17use std::sync::atomic::{AtomicUsize, Ordering};
18use std::sync::{Arc, RwLock};
19
20use common_config::file_watcher::{FileWatcherBuilder, FileWatcherConfig};
21use common_telemetry::{error, info};
22use snafu::ResultExt;
23
24use crate::error::{FileWatchSnafu, Result};
25
/// A trait for loading TLS configuration from an option type.
///
/// The implementor is the option type (e.g. `TlsOption`, `ClientTlsOption`)
/// and `T` is the config type it loads (e.g. `ServerConfig`, `ClientTlsConfig`).
pub trait TlsConfigLoader<T> {
    /// Error returned when loading the configuration fails.
    type Error;

    /// Load the TLS configuration.
    ///
    /// NOTE(review): `Ok(None)` presumably means TLS is not configured for this
    /// option — confirm against the implementors.
    fn load(&self) -> StdResult<Option<T>, Self::Error>;

    /// Get paths to certificate files for watching.
    fn watch_paths(&self) -> Vec<&Path>;

    /// Check if watching is enabled.
    fn watch_enabled(&self) -> bool;
}

/// A mutable container for TLS config.
///
/// This struct allows dynamic reloading of certificates and keys.
/// It's generic over the config type `T` (e.g., `ServerConfig`, `ClientTlsConfig`)
/// and the option type `O` (e.g., `TlsOption`, `ClientTlsOption`) the config is
/// loaded from.
#[derive(Debug)]
pub struct ReloadableTlsConfig<T, O>
where
    O: TlsConfigLoader<T>,
{
    /// Option the config is (re)loaded from.
    tls_option: O,
    /// Most recently successfully loaded config.
    config: RwLock<Option<T>>,
    /// Number of successful reloads since construction.
    ///
    /// Incremented with `Release` after `config` has been replaced and read
    /// with `Acquire`, so observing a new version implies the matching config
    /// is visible to a subsequent `get_config`.
    version: AtomicUsize,
}

impl<T, O> ReloadableTlsConfig<T, O>
where
    O: TlsConfigLoader<T>,
{
    /// Create the container by loading the initial config from `tls_option`.
    ///
    /// The version starts at `0`.
    ///
    /// # Errors
    ///
    /// Returns the loader's error if the initial load fails.
    pub fn try_new(tls_option: O) -> StdResult<Self, O::Error> {
        let config = tls_option.load()?;
        Ok(Self {
            tls_option,
            config: RwLock::new(config),
            version: AtomicUsize::new(0),
        })
    }

    /// Reread certificates and keys from the file system.
    ///
    /// On success the stored config is replaced and the version is increased
    /// by one. On failure the previously loaded config and the version are
    /// left untouched.
    ///
    /// # Errors
    ///
    /// Returns the loader's error if loading fails.
    ///
    /// # Panics
    ///
    /// Panics if the config lock is poisoned.
    pub fn reload(&self) -> StdResult<(), O::Error> {
        // Load before taking the lock so readers are not blocked while the
        // files are being read.
        let config = self.tls_option.load()?;
        {
            let mut guard = self.config.write().unwrap();
            *guard = config;
        }
        // `Release` pairs with the `Acquire` load in `get_version`: a caller
        // that observes this new version is guaranteed to see the config
        // written above when it subsequently calls `get_config`. With
        // `Relaxed`, a caller could pair the new version with the old config.
        self.version.fetch_add(1, Ordering::Release);
        Ok(())
    }

    /// Get a clone of the config currently held by this container.
    ///
    /// # Panics
    ///
    /// Panics if the config lock is poisoned.
    pub fn get_config(&self) -> Option<T>
    where
        T: Clone,
    {
        self.config.read().unwrap().clone()
    }

    /// Get the option the config is loaded from.
    pub fn get_tls_option(&self) -> &O {
        &self.tls_option
    }

    /// Get the version of the current config.
    ///
    /// The version starts at `0` and increases by one on every successful
    /// [`reload`](Self::reload). A caller can compare it against a previously
    /// observed value to detect that the config has changed.
    pub fn get_version(&self) -> usize {
        // Pairs with the `Release` increment in `reload`.
        self.version.load(Ordering::Acquire)
    }
}
97
98/// Watch TLS configuration files for changes and reload automatically
99///
100/// This is a generic function that works with any ReloadableTlsConfig.
101/// When changes are detected, it calls the provided callback after reloading.
102///
103/// T: the original TLS config
104/// O: the compiled TLS option
105/// F: the hook function to be called after reloading
106/// E: the error type for the loading operation
107pub fn maybe_watch_tls_config<T, O, F, E>(
108    tls_config: Arc<ReloadableTlsConfig<T, O>>,
109    on_reload: F,
110) -> Result<()>
111where
112    T: Send + Sync + 'static,
113    O: TlsConfigLoader<T, Error = E> + Send + Sync + 'static,
114    E: std::error::Error + Send + Sync + 'static,
115    F: Fn() + Send + 'static,
116{
117    if !tls_config.get_tls_option().watch_enabled() {
118        return Ok(());
119    }
120
121    let watch_paths: Vec<_> = tls_config
122        .get_tls_option()
123        .watch_paths()
124        .iter()
125        .map(|p| p.to_path_buf())
126        .collect();
127
128    let tls_config_for_watcher = tls_config.clone();
129
130    FileWatcherBuilder::new()
131        .watch_paths(&watch_paths)
132        .context(FileWatchSnafu)?
133        .config(FileWatcherConfig::new())
134        .spawn(move || {
135            if let Err(err) = tls_config_for_watcher.reload() {
136                error!("Failed to reload TLS config: {}", err);
137            } else {
138                info!("Reloaded TLS cert/key file successfully.");
139                on_reload();
140            }
141        })
142        .context(FileWatchSnafu)?;
143
144    Ok(())
145}