common_grpc/
reloadable_tls.rs1use std::path::Path;
16use std::result::Result as StdResult;
17use std::sync::atomic::{AtomicUsize, Ordering};
18use std::sync::{Arc, RwLock};
19
20use common_config::file_watcher::{FileWatcherBuilder, FileWatcherConfig};
21use common_telemetry::{error, info};
22use snafu::ResultExt;
23
24use crate::error::{FileWatchSnafu, Result};
25
/// Source of a TLS configuration of type `T`, plus the files that back it.
///
/// [`ReloadableTlsConfig`] wraps an implementor. It calls [`load`](Self::load) once when it is
/// constructed and again on every reload.
pub trait TlsConfigLoader<T> {
    /// Error produced when the configuration cannot be built.
    type Error;

    /// Whether the files listed by [`watch_paths`](Self::watch_paths) should be watched.
    ///
    /// When this is `false`, no watcher is started and the configuration is never reloaded
    /// automatically.
    fn watch_enabled(&self) -> bool;

    /// Files whose modification should trigger a reload (e.g. certificate and key files).
    fn watch_paths(&self) -> Vec<&Path>;

    /// Builds the configuration from the current contents of the backing files.
    ///
    /// `Ok(None)` is a valid outcome and is stored as "no configuration".
    /// NOTE(review): presumably this means TLS is disabled; confirm against implementors.
    fn load(&self) -> StdResult<Option<T>, Self::Error>;
}
39
40#[derive(Debug)]
46pub struct ReloadableTlsConfig<T, O>
47where
48 O: TlsConfigLoader<T>,
49{
50 tls_option: O,
51 config: RwLock<Option<T>>,
52 version: AtomicUsize,
53}
54
55impl<T, O> ReloadableTlsConfig<T, O>
56where
57 O: TlsConfigLoader<T>,
58{
59 pub fn try_new(tls_option: O) -> StdResult<Self, O::Error> {
61 let config = tls_option.load()?;
62 Ok(Self {
63 tls_option,
64 config: RwLock::new(config),
65 version: AtomicUsize::new(0),
66 })
67 }
68
69 pub fn reload(&self) -> StdResult<(), O::Error> {
71 let config = self.tls_option.load()?;
72 *self.config.write().unwrap() = config;
73 self.version.fetch_add(1, Ordering::Relaxed);
74 Ok(())
75 }
76
77 pub fn get_config(&self) -> Option<T>
79 where
80 T: Clone,
81 {
82 self.config.read().unwrap().clone()
83 }
84
85 pub fn get_tls_option(&self) -> &O {
87 &self.tls_option
88 }
89
90 pub fn get_version(&self) -> usize {
94 self.version.load(Ordering::Relaxed)
95 }
96}
97
98pub fn maybe_watch_tls_config<T, O, F, E>(
108 tls_config: Arc<ReloadableTlsConfig<T, O>>,
109 on_reload: F,
110) -> Result<()>
111where
112 T: Send + Sync + 'static,
113 O: TlsConfigLoader<T, Error = E> + Send + Sync + 'static,
114 E: std::error::Error + Send + Sync + 'static,
115 F: Fn() + Send + 'static,
116{
117 if !tls_config.get_tls_option().watch_enabled() {
118 return Ok(());
119 }
120
121 let watch_paths: Vec<_> = tls_config
122 .get_tls_option()
123 .watch_paths()
124 .iter()
125 .map(|p| p.to_path_buf())
126 .collect();
127
128 let tls_config_for_watcher = tls_config.clone();
129
130 FileWatcherBuilder::new()
131 .watch_paths(&watch_paths)
132 .context(FileWatchSnafu)?
133 .config(FileWatcherConfig::new())
134 .spawn(move || {
135 if let Err(err) = tls_config_for_watcher.reload() {
136 error!("Failed to reload TLS config: {}", err);
137 } else {
138 info!("Reloaded TLS cert/key file successfully.");
139 on_reload();
140 }
141 })
142 .context(FileWatchSnafu)?;
143
144 Ok(())
145}