1use std::path::Path;
16use std::sync::mpsc::channel;
17use std::sync::{Arc, Mutex};
18
19use async_trait::async_trait;
20use common_telemetry::{info, warn};
21use notify::{EventKind, RecursiveMode, Watcher};
22use snafu::{ResultExt, ensure};
23
24use crate::error::{FileWatchSnafu, InvalidConfigSnafu, Result};
25use crate::user_provider::{UserInfoMap, authenticate_with_credential, load_credential_from_file};
26use crate::{Identity, Password, UserInfoRef, UserProvider};
27
/// Name reported by [`WatchFileUserProvider`] through [`UserProvider::name`].
pub(crate) const WATCH_FILE_USER_PROVIDER: &str = "watch_file_user_provider";

/// Credential map shared between the provider and the background reloader spawned in
/// `WatchFileUserProvider::new`; the mutex serializes reads during authentication with
/// replacements performed on reload.
type WatchedCredentialRef = Arc<Mutex<UserInfoMap>>;
31
/// A [`UserProvider`] that loads credentials from a file and reloads them whenever that
/// file is modified, created, or removed.
///
/// If a reload fails (for example after the file has been removed), the previously loaded
/// credentials remain in effect.
#[derive(Debug)]
pub(crate) struct WatchFileUserProvider {
    /// Current credential snapshot; replaced wholesale each time the file is reloaded.
    users: WatchedCredentialRef,
}
39
40impl WatchFileUserProvider {
41 pub fn new(filepath: &str) -> Result<Self> {
42 let credential = load_credential_from_file(filepath)?;
43 let users = Arc::new(Mutex::new(credential));
44 let this = WatchFileUserProvider {
45 users: users.clone(),
46 };
47
48 let (tx, rx) = channel::<notify::Result<notify::Event>>();
49 let mut debouncer =
50 notify::recommended_watcher(tx).context(FileWatchSnafu { path: "<none>" })?;
51 let mut dir = Path::new(filepath).to_path_buf();
52 ensure!(
53 dir.pop(),
54 InvalidConfigSnafu {
55 value: filepath,
56 msg: "UserProvider path must be a file path",
57 }
58 );
59 debouncer
60 .watch(&dir, RecursiveMode::NonRecursive)
61 .context(FileWatchSnafu { path: filepath })?;
62
63 let filepath = filepath.to_string();
64 std::thread::spawn(move || {
65 let filename = Path::new(&filepath).file_name();
66 let _hold = debouncer;
67 while let Ok(res) = rx.recv() {
68 if let Ok(event) = res {
69 let is_this_file = event.paths.iter().any(|p| p.file_name() == filename);
70 let is_relevant_event = matches!(
71 event.kind,
72 EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_)
73 );
74 if is_this_file && is_relevant_event {
75 info!(?event.kind, "User provider file {} changed", &filepath);
76 match load_credential_from_file(&filepath) {
77 Ok(credential) => {
78 let mut users =
79 users.lock().expect("users credential must be valid");
80 #[cfg(not(test))]
81 info!("User provider file {filepath} reloaded");
82 #[cfg(test)]
83 info!("User provider file {filepath} reloaded: {credential:?}");
84 *users = credential;
85 }
86 Err(err) => {
87 warn!(
88 ?err,
89 "Fail to load credential from file {filepath}; keep the old one",
90 )
91 }
92 }
93 }
94 }
95 }
96 });
97
98 Ok(this)
99 }
100}
101
102#[async_trait]
103impl UserProvider for WatchFileUserProvider {
104 fn name(&self) -> &str {
105 WATCH_FILE_USER_PROVIDER
106 }
107
108 async fn authenticate(&self, id: Identity<'_>, password: Password<'_>) -> Result<UserInfoRef> {
109 let users = self.users.lock().expect("users credential must be valid");
110 authenticate_with_credential(&users, id, password)
111 }
112
113 async fn authorize(&self, _: &str, _: &str, _: &UserInfoRef) -> Result<()> {
114 Ok(())
116 }
117}
118
#[cfg(test)]
pub mod test {
    use std::time::{Duration, Instant};

    use common_test_util::temp_dir::create_temp_dir;
    use tokio::time::sleep;

    use crate::UserProvider;
    use crate::user_provider::watch_file_user_provider::WatchFileUserProvider;
    use crate::user_provider::{Identity, Password};

    /// Makes one plain-text authentication attempt and reports whether it succeeded.
    async fn try_authenticate(provider: &dyn UserProvider, username: &str, password: &str) -> bool {
        provider
            .authenticate(
                Identity::UserId(username, None),
                Password::PlainText(password.to_string().into()),
            )
            .await
            .is_ok()
    }

    /// Asserts that authenticating `username`/`password` against `provider` succeeds
    /// exactly when `ok` is true.
    ///
    /// Without a `timeout` a single attempt is made. With a `timeout`, the attempt is
    /// retried every 100ms until its outcome matches `ok`, panicking once the deadline has
    /// passed. This gives the file watcher time to pick up changes.
    async fn test_authenticate(
        provider: &dyn UserProvider,
        username: &str,
        password: &str,
        ok: bool,
        timeout: Option<Duration>,
    ) {
        match timeout {
            None => {
                let authenticated = try_authenticate(provider, username, password).await;
                assert_eq!(
                    authenticated, ok,
                    "username: {}, password: {}",
                    username, password
                );
            }
            Some(timeout) => {
                let deadline = Instant::now().checked_add(timeout).unwrap();
                while try_authenticate(provider, username, password).await != ok {
                    if Instant::now() >= deadline {
                        panic!(
                            "timeout (username: {username}, password: {password}, expected: {ok})"
                        );
                    }
                    sleep(Duration::from_millis(100)).await;
                }
            }
        }
    }

    #[tokio::test]
    async fn test_file_provider_initialization_with_missing_file() {
        common_telemetry::init_default_ut_logging();

        let dir = create_temp_dir("test_missing_file");
        let file_path = format!("{}/non_existent_file", dir.path().to_str().unwrap());

        // Construction must fail up front when the credential file does not exist.
        let result = WatchFileUserProvider::new(file_path.as_str());
        assert!(result.is_err());

        let message = result.unwrap_err().to_string();
        assert!(message.contains("UserProvider file must exist"));
    }

    #[tokio::test]
    async fn test_file_provider() {
        common_telemetry::init_default_ut_logging();

        let dir = create_temp_dir("test_file_provider");
        let file_path = format!("{}/test_file_provider", dir.path().to_str().unwrap());
        let timeout = Duration::from_secs(60);

        assert!(std::fs::write(&file_path, "root=123456\nadmin=654321\n").is_ok());
        let provider = WatchFileUserProvider::new(file_path.as_str()).unwrap();

        // Initially loaded credentials are effective immediately.
        for (username, password, ok) in [
            ("root", "123456", true),
            ("admin", "654321", true),
            ("root", "654321", false),
        ] {
            test_authenticate(&provider, username, password, ok, None).await;
        }

        // Rewriting the file replaces the whole credential set.
        assert!(std::fs::write(&file_path, "root=654321\n").is_ok());
        for (username, password, ok) in [
            ("root", "123456", false),
            ("root", "654321", true),
            ("admin", "654321", false),
        ] {
            test_authenticate(&provider, username, password, ok, Some(timeout)).await;
        }

        // Removing the file keeps the last successfully loaded credentials.
        assert!(std::fs::remove_file(&file_path).is_ok());
        for (username, password, ok) in [
            ("root", "654321", true),
            ("root", "123456", false),
            ("admin", "654321", false),
        ] {
            test_authenticate(&provider, username, password, ok, Some(timeout)).await;
        }

        // Re-creating the file is picked up again.
        assert!(std::fs::write(&file_path, "root=123456\n").is_ok());
        for (username, password, ok) in [
            ("root", "123456", true),
            ("root", "654321", false),
            ("admin", "654321", false),
        ] {
            test_authenticate(&provider, username, password, ok, Some(timeout)).await;
        }
    }
}