auth/user_provider/
watch_file_user_provider.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::path::Path;
17use std::sync::mpsc::channel;
18use std::sync::{Arc, Mutex};
19
20use async_trait::async_trait;
21use common_telemetry::{info, warn};
22use notify::{EventKind, RecursiveMode, Watcher};
23use snafu::{ensure, ResultExt};
24
25use crate::error::{FileWatchSnafu, InvalidConfigSnafu, Result};
26use crate::user_info::DefaultUserInfo;
27use crate::user_provider::{authenticate_with_credential, load_credential_from_file};
28use crate::{Identity, Password, UserInfoRef, UserProvider};
29
/// Name reported by [`WatchFileUserProvider`] through `UserProvider::name`.
pub(crate) const WATCH_FILE_USER_PROVIDER: &str = "watch_file_user_provider";
31
/// Credential table shared between the provider and its file-watching thread.
///
/// `None` means the credential file does not exist, in which case every user is allowed.
/// NOTE(review): the map presumably goes from user name to password bytes; confirm against
/// `load_credential_from_file`.
type WatchedCredentialRef = Arc<Mutex<Option<HashMap<String, Vec<u8>>>>>;
33
/// A user provider that reads user credentials from a file and watches that file for changes.
///
/// An empty file is invalid, but a file that does not exist means every user can be
/// authenticated.
pub(crate) struct WatchFileUserProvider {
    /// Latest credentials loaded from the file; replaced by the watcher thread after each
    /// successful reload.
    users: WatchedCredentialRef,
}
40
41impl WatchFileUserProvider {
42    pub fn new(filepath: &str) -> Result<Self> {
43        let credential = load_credential_from_file(filepath)?;
44        let users = Arc::new(Mutex::new(credential));
45        let this = WatchFileUserProvider {
46            users: users.clone(),
47        };
48
49        let (tx, rx) = channel::<notify::Result<notify::Event>>();
50        let mut debouncer =
51            notify::recommended_watcher(tx).context(FileWatchSnafu { path: "<none>" })?;
52        let mut dir = Path::new(filepath).to_path_buf();
53        ensure!(
54            dir.pop(),
55            InvalidConfigSnafu {
56                value: filepath,
57                msg: "UserProvider path must be a file path",
58            }
59        );
60        debouncer
61            .watch(&dir, RecursiveMode::NonRecursive)
62            .context(FileWatchSnafu { path: filepath })?;
63
64        let filepath = filepath.to_string();
65        std::thread::spawn(move || {
66            let filename = Path::new(&filepath).file_name();
67            let _hold = debouncer;
68            while let Ok(res) = rx.recv() {
69                if let Ok(event) = res {
70                    let is_this_file = event.paths.iter().any(|p| p.file_name() == filename);
71                    let is_relevant_event = matches!(
72                        event.kind,
73                        EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_)
74                    );
75                    if is_this_file && is_relevant_event {
76                        info!(?event.kind, "User provider file {} changed", &filepath);
77                        match load_credential_from_file(&filepath) {
78                            Ok(credential) => {
79                                let mut users =
80                                    users.lock().expect("users credential must be valid");
81                                #[cfg(not(test))]
82                                info!("User provider file {filepath} reloaded");
83                                #[cfg(test)]
84                                info!("User provider file {filepath} reloaded: {credential:?}");
85                                *users = credential;
86                            }
87                            Err(err) => {
88                                warn!(
89                                    ?err,
90                                    "Fail to load credential from file {filepath}; keep the old one",
91                                )
92                            }
93                        }
94                    }
95                }
96            }
97        });
98
99        Ok(this)
100    }
101}
102
103#[async_trait]
104impl UserProvider for WatchFileUserProvider {
105    fn name(&self) -> &str {
106        WATCH_FILE_USER_PROVIDER
107    }
108
109    async fn authenticate(&self, id: Identity<'_>, password: Password<'_>) -> Result<UserInfoRef> {
110        let users = self.users.lock().expect("users credential must be valid");
111        if let Some(users) = users.as_ref() {
112            authenticate_with_credential(users, id, password)
113        } else {
114            match id {
115                Identity::UserId(id, _) => {
116                    warn!(id, "User provider file not exist, allow all users");
117                    Ok(DefaultUserInfo::with_name(id))
118                }
119            }
120        }
121    }
122
123    async fn authorize(&self, _: &str, _: &str, _: &UserInfoRef) -> Result<()> {
124        // default allow all
125        Ok(())
126    }
127}
128
#[cfg(test)]
pub mod test {
    use std::time::{Duration, Instant};

    use common_test_util::temp_dir::create_temp_dir;
    use tokio::time::sleep;

    use crate::user_provider::watch_file_user_provider::WatchFileUserProvider;
    use crate::user_provider::{Identity, Password};
    use crate::UserProvider;

    /// Runs a single plain-text authentication attempt and reports whether it succeeded.
    async fn login_succeeds(provider: &dyn UserProvider, username: &str, password: &str) -> bool {
        provider
            .authenticate(
                Identity::UserId(username, None),
                Password::PlainText(password.to_string().into()),
            )
            .await
            .is_ok()
    }

    /// Asserts that authenticating `username`/`password` succeeds exactly when `ok` is true.
    ///
    /// Without a `timeout` the check is performed once. With a `timeout` the attempt is
    /// repeated every 100 ms until it yields the expected outcome, panicking once the
    /// deadline has passed; this gives the watcher thread time to pick up file changes.
    async fn test_authenticate(
        provider: &dyn UserProvider,
        username: &str,
        password: &str,
        ok: bool,
        timeout: Option<Duration>,
    ) {
        let deadline = match timeout {
            Some(limit) => Instant::now().checked_add(limit).unwrap(),
            None => {
                assert_eq!(
                    login_succeeds(provider, username, password).await,
                    ok,
                    "username: {}, password: {}",
                    username,
                    password
                );
                return;
            }
        };

        loop {
            if login_succeeds(provider, username, password).await == ok {
                return;
            }
            if Instant::now() >= deadline {
                panic!("timeout (username: {username}, password: {password}, expected: {ok})");
            }
            sleep(Duration::from_millis(100)).await;
        }
    }

    /// Walks the credential file through create, update, remove and re-create, checking
    /// after each step that the provider reflects the new state.
    #[tokio::test]
    async fn test_file_provider() {
        common_telemetry::init_default_ut_logging();

        let dir = create_temp_dir("test_file_provider");
        let file_path = format!("{}/test_file_provider", dir.path().to_str().unwrap());

        // Initial content: two users.
        assert!(std::fs::write(&file_path, "root=123456\nadmin=654321\n").is_ok());
        let provider = WatchFileUserProvider::new(file_path.as_str()).unwrap();
        let timeout = Duration::from_secs(60);

        for (user, pwd, expected) in [
            ("root", "123456", true),
            ("admin", "654321", true),
            ("root", "654321", false),
        ] {
            test_authenticate(&provider, user, pwd, expected, None).await;
        }

        // Update: only `root` remains, with a new password.
        assert!(std::fs::write(&file_path, "root=654321\n").is_ok());
        for (user, pwd, expected) in [
            ("root", "123456", false),
            ("root", "654321", true),
            ("admin", "654321", false),
        ] {
            test_authenticate(&provider, user, pwd, expected, Some(timeout)).await;
        }

        // Removal: without a credential file every user is accepted.
        assert!(std::fs::remove_file(&file_path).is_ok());
        for (user, pwd, expected) in [
            ("root", "123456", true),
            ("root", "654321", true),
            ("admin", "654321", true),
        ] {
            test_authenticate(&provider, user, pwd, expected, Some(timeout)).await;
        }

        // Re-creation: credentials are enforced again.
        assert!(std::fs::write(&file_path, "root=123456\n").is_ok());
        for (user, pwd, expected) in [
            ("root", "123456", true),
            ("root", "654321", false),
            ("admin", "654321", false),
        ] {
            test_authenticate(&provider, user, pwd, expected, Some(timeout)).await;
        }
    }
}