1use std::collections::HashMap;
16use std::path::Path;
17use std::sync::mpsc::channel;
18use std::sync::{Arc, Mutex};
19
20use async_trait::async_trait;
21use common_telemetry::{info, warn};
22use notify::{EventKind, RecursiveMode, Watcher};
23use snafu::{ensure, ResultExt};
24
25use crate::error::{FileWatchSnafu, InvalidConfigSnafu, Result};
26use crate::user_info::DefaultUserInfo;
27use crate::user_provider::{authenticate_with_credential, load_credential_from_file};
28use crate::{Identity, Password, UserInfoRef, UserProvider};
29
/// Identifier returned by `UserProvider::name` for [`WatchFileUserProvider`].
pub(crate) const WATCH_FILE_USER_PROVIDER: &str = "watch_file_user_provider";
31
/// Shared, lock-protected credential table keyed by user name.
///
/// The `Vec<u8>` value is the stored credential for that user (presumably the password
/// bytes — see `load_credential_from_file`). `None` means no credentials are loaded.
type WatchedCredentialRef = Arc<Mutex<Option<HashMap<String, Vec<u8>>>>>;
33
34pub(crate) struct WatchFileUserProvider {
38 users: WatchedCredentialRef,
39}
40
41impl WatchFileUserProvider {
42 pub fn new(filepath: &str) -> Result<Self> {
43 let credential = load_credential_from_file(filepath)?;
44 let users = Arc::new(Mutex::new(credential));
45 let this = WatchFileUserProvider {
46 users: users.clone(),
47 };
48
49 let (tx, rx) = channel::<notify::Result<notify::Event>>();
50 let mut debouncer =
51 notify::recommended_watcher(tx).context(FileWatchSnafu { path: "<none>" })?;
52 let mut dir = Path::new(filepath).to_path_buf();
53 ensure!(
54 dir.pop(),
55 InvalidConfigSnafu {
56 value: filepath,
57 msg: "UserProvider path must be a file path",
58 }
59 );
60 debouncer
61 .watch(&dir, RecursiveMode::NonRecursive)
62 .context(FileWatchSnafu { path: filepath })?;
63
64 let filepath = filepath.to_string();
65 std::thread::spawn(move || {
66 let filename = Path::new(&filepath).file_name();
67 let _hold = debouncer;
68 while let Ok(res) = rx.recv() {
69 if let Ok(event) = res {
70 let is_this_file = event.paths.iter().any(|p| p.file_name() == filename);
71 let is_relevant_event = matches!(
72 event.kind,
73 EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_)
74 );
75 if is_this_file && is_relevant_event {
76 info!(?event.kind, "User provider file {} changed", &filepath);
77 match load_credential_from_file(&filepath) {
78 Ok(credential) => {
79 let mut users =
80 users.lock().expect("users credential must be valid");
81 #[cfg(not(test))]
82 info!("User provider file {filepath} reloaded");
83 #[cfg(test)]
84 info!("User provider file {filepath} reloaded: {credential:?}");
85 *users = credential;
86 }
87 Err(err) => {
88 warn!(
89 ?err,
90 "Fail to load credential from file {filepath}; keep the old one",
91 )
92 }
93 }
94 }
95 }
96 }
97 });
98
99 Ok(this)
100 }
101}
102
103#[async_trait]
104impl UserProvider for WatchFileUserProvider {
105 fn name(&self) -> &str {
106 WATCH_FILE_USER_PROVIDER
107 }
108
109 async fn authenticate(&self, id: Identity<'_>, password: Password<'_>) -> Result<UserInfoRef> {
110 let users = self.users.lock().expect("users credential must be valid");
111 if let Some(users) = users.as_ref() {
112 authenticate_with_credential(users, id, password)
113 } else {
114 match id {
115 Identity::UserId(id, _) => {
116 warn!(id, "User provider file not exist, allow all users");
117 Ok(DefaultUserInfo::with_name(id))
118 }
119 }
120 }
121 }
122
123 async fn authorize(&self, _: &str, _: &str, _: &UserInfoRef) -> Result<()> {
124 Ok(())
126 }
127}
128
// These tests write a real credential file into a temporary directory and rely on the
// background watcher to notice edits, so post-edit checks retry until a deadline.
#[cfg(test)]
pub mod test {
    use std::time::{Duration, Instant};

    use common_test_util::temp_dir::create_temp_dir;
    use tokio::time::sleep;

    use crate::user_provider::watch_file_user_provider::WatchFileUserProvider;
    use crate::user_provider::{Identity, Password};
    use crate::UserProvider;

    /// Makes one plain-text login attempt and reports whether it was accepted.
    async fn login_accepted(
        provider: &dyn UserProvider,
        username: &str,
        password: &str,
    ) -> bool {
        provider
            .authenticate(
                Identity::UserId(username, None),
                Password::PlainText(password.to_string().into()),
            )
            .await
            .is_ok()
    }

    /// Asserts that logging in as `username`/`password` yields `ok`.
    ///
    /// Without `timeout`, a single attempt must match. With `timeout`, the attempt is
    /// repeated every 100ms until it matches, panicking once the deadline has passed.
    async fn test_authenticate(
        provider: &dyn UserProvider,
        username: &str,
        password: &str,
        ok: bool,
        timeout: Option<Duration>,
    ) {
        let timeout = match timeout {
            Some(timeout) => timeout,
            None => {
                let accepted = login_accepted(provider, username, password).await;
                assert_eq!(
                    accepted,
                    ok,
                    "username: {}, password: {}",
                    username,
                    password
                );
                return;
            }
        };

        let deadline = Instant::now().checked_add(timeout).unwrap();
        while login_accepted(provider, username, password).await != ok {
            if Instant::now() >= deadline {
                panic!("timeout (username: {username}, password: {password}, expected: {ok})");
            }
            sleep(Duration::from_millis(100)).await;
        }
    }

    /// Runs [`test_authenticate`] for each `(username, password, expected)` case, in order.
    async fn check_logins(
        provider: &dyn UserProvider,
        cases: &[(&str, &str, bool)],
        timeout: Option<Duration>,
    ) {
        for &(username, password, ok) in cases {
            test_authenticate(provider, username, password, ok, timeout).await;
        }
    }

    #[tokio::test]
    async fn test_file_provider() {
        common_telemetry::init_default_ut_logging();

        let dir = create_temp_dir("test_file_provider");
        let file_path = format!("{}/test_file_provider", dir.path().to_str().unwrap());

        assert!(std::fs::write(&file_path, "root=123456\nadmin=654321\n").is_ok());
        let provider = WatchFileUserProvider::new(file_path.as_str()).unwrap();
        let wait = Some(Duration::from_secs(60));

        // The initial file is loaded eagerly, so the first checks do not wait.
        check_logins(
            &provider,
            &[("root", "123456", true), ("admin", "654321", true), ("root", "654321", false)],
            None,
        )
        .await;

        // Rewritten file: root's password changes and admin disappears.
        assert!(std::fs::write(&file_path, "root=654321\n").is_ok());
        check_logins(
            &provider,
            &[("root", "123456", false), ("root", "654321", true), ("admin", "654321", false)],
            wait,
        )
        .await;

        // Removed file: every user is let in.
        assert!(std::fs::remove_file(&file_path).is_ok());
        check_logins(
            &provider,
            &[("root", "123456", true), ("root", "654321", true), ("admin", "654321", true)],
            wait,
        )
        .await;

        // Recreated file: credentials are enforced again.
        assert!(std::fs::write(&file_path, "root=123456\n").is_ok());
        check_logins(
            &provider,
            &[("root", "123456", true), ("root", "654321", false), ("admin", "654321", false)],
            wait,
        )
        .await;
    }
}