auth/user_provider/
watch_file_user_provider.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::path::Path;
16use std::sync::mpsc::channel;
17use std::sync::{Arc, Mutex};
18
19use async_trait::async_trait;
20use common_telemetry::{info, warn};
21use notify::{EventKind, RecursiveMode, Watcher};
22use snafu::{ResultExt, ensure};
23
24use crate::error::{FileWatchSnafu, InvalidConfigSnafu, Result};
25use crate::user_provider::{UserInfoMap, authenticate_with_credential, load_credential_from_file};
26use crate::{Identity, Password, UserInfoRef, UserProvider};
27
/// Name under which [`WatchFileUserProvider`] identifies itself via `UserProvider::name`.
pub(crate) const WATCH_FILE_USER_PROVIDER: &str = "watch_file_user_provider";
29
30type WatchedCredentialRef = Arc<Mutex<UserInfoMap>>;
31
32/// A user provider that reads user credential from a file and watches the file for changes.
33///
34/// Both empty file and non-existent file are invalid and will cause initialization to fail.
35#[derive(Debug)]
36pub(crate) struct WatchFileUserProvider {
37    users: WatchedCredentialRef,
38}
39
40impl WatchFileUserProvider {
41    pub fn new(filepath: &str) -> Result<Self> {
42        let credential = load_credential_from_file(filepath)?;
43        let users = Arc::new(Mutex::new(credential));
44        let this = WatchFileUserProvider {
45            users: users.clone(),
46        };
47
48        let (tx, rx) = channel::<notify::Result<notify::Event>>();
49        let mut debouncer =
50            notify::recommended_watcher(tx).context(FileWatchSnafu { path: "<none>" })?;
51        let mut dir = Path::new(filepath).to_path_buf();
52        ensure!(
53            dir.pop(),
54            InvalidConfigSnafu {
55                value: filepath,
56                msg: "UserProvider path must be a file path",
57            }
58        );
59        debouncer
60            .watch(&dir, RecursiveMode::NonRecursive)
61            .context(FileWatchSnafu { path: filepath })?;
62
63        let filepath = filepath.to_string();
64        std::thread::spawn(move || {
65            let filename = Path::new(&filepath).file_name();
66            let _hold = debouncer;
67            while let Ok(res) = rx.recv() {
68                if let Ok(event) = res {
69                    let is_this_file = event.paths.iter().any(|p| p.file_name() == filename);
70                    let is_relevant_event = matches!(
71                        event.kind,
72                        EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_)
73                    );
74                    if is_this_file && is_relevant_event {
75                        info!(?event.kind, "User provider file {} changed", &filepath);
76                        match load_credential_from_file(&filepath) {
77                            Ok(credential) => {
78                                let mut users =
79                                    users.lock().expect("users credential must be valid");
80                                #[cfg(not(test))]
81                                info!("User provider file {filepath} reloaded");
82                                #[cfg(test)]
83                                info!("User provider file {filepath} reloaded: {credential:?}");
84                                *users = credential;
85                            }
86                            Err(err) => {
87                                warn!(
88                                    ?err,
89                                    "Fail to load credential from file {filepath}; keep the old one",
90                                )
91                            }
92                        }
93                    }
94                }
95            }
96        });
97
98        Ok(this)
99    }
100}
101
102#[async_trait]
103impl UserProvider for WatchFileUserProvider {
104    fn name(&self) -> &str {
105        WATCH_FILE_USER_PROVIDER
106    }
107
108    async fn authenticate(&self, id: Identity<'_>, password: Password<'_>) -> Result<UserInfoRef> {
109        let users = self.users.lock().expect("users credential must be valid");
110        authenticate_with_credential(&users, id, password)
111    }
112
113    async fn authorize(&self, _: &str, _: &str, _: &UserInfoRef) -> Result<()> {
114        // default allow all
115        Ok(())
116    }
117}
118
#[cfg(test)]
pub mod test {
    use std::time::{Duration, Instant};

    use common_test_util::temp_dir::create_temp_dir;
    use tokio::time::sleep;

    use crate::UserProvider;
    use crate::user_provider::watch_file_user_provider::WatchFileUserProvider;
    use crate::user_provider::{Identity, Password};

    /// Returns whether `provider` currently accepts `username` with the plain-text `password`.
    async fn authenticates(provider: &dyn UserProvider, username: &str, password: &str) -> bool {
        provider
            .authenticate(
                Identity::UserId(username, None),
                Password::PlainText(password.to_string().into()),
            )
            .await
            .is_ok()
    }

    /// Asserts that authenticating `username`/`password` against `provider` yields `ok`.
    ///
    /// Without a `timeout` the check runs exactly once. With a `timeout` it is retried every
    /// 100ms until it matches, panicking once the deadline has passed. The retries give the file
    /// watcher time to pick up changes.
    async fn test_authenticate(
        provider: &dyn UserProvider,
        username: &str,
        password: &str,
        ok: bool,
        timeout: Option<Duration>,
    ) {
        let Some(timeout) = timeout else {
            let authenticated = authenticates(provider, username, password).await;
            assert_eq!(
                authenticated,
                ok,
                "username: {}, password: {}",
                username,
                password
            );
            return;
        };

        let deadline = Instant::now().checked_add(timeout).unwrap();
        while authenticates(provider, username, password).await != ok {
            assert!(
                Instant::now() < deadline,
                "timeout (username: {username}, password: {password}, expected: {ok})"
            );
            sleep(Duration::from_millis(100)).await;
        }
    }

    /// Runs [`test_authenticate`] for every `(username, password, ok)` case, in order.
    async fn test_authenticate_all(
        provider: &dyn UserProvider,
        cases: &[(&str, &str, bool)],
        timeout: Option<Duration>,
    ) {
        for &(username, password, ok) in cases {
            test_authenticate(provider, username, password, ok, timeout).await;
        }
    }

    #[tokio::test]
    async fn test_file_provider_initialization_with_missing_file() {
        common_telemetry::init_default_ut_logging();

        let dir = create_temp_dir("test_missing_file");
        let file_path = format!("{}/non_existent_file", dir.path().to_str().unwrap());

        // Creating a provider for a file that does not exist must fail.
        let result = WatchFileUserProvider::new(&file_path);
        assert!(result.is_err());

        let error = result.unwrap_err();
        assert!(error.to_string().contains("UserProvider file must exist"));
    }

    #[tokio::test]
    async fn test_file_provider() {
        common_telemetry::init_default_ut_logging();

        let dir = create_temp_dir("test_file_provider");
        let file_path = format!("{}/test_file_provider", dir.path().to_str().unwrap());

        // Initial contents are read synchronously by `new`, so no waiting is needed yet.
        assert!(std::fs::write(&file_path, "root=123456\nadmin=654321\n").is_ok());
        let provider = WatchFileUserProvider::new(&file_path).unwrap();
        // Reloads happen asynchronously; later checks are retried for up to this long.
        let wait = Some(Duration::from_secs(60));

        test_authenticate_all(
            &provider,
            &[
                ("root", "123456", true),
                ("admin", "654321", true),
                ("root", "654321", false),
            ],
            None,
        )
        .await;

        // Rewriting the file replaces the whole credential set.
        assert!(std::fs::write(&file_path, "root=654321\n").is_ok());
        test_authenticate_all(
            &provider,
            &[
                ("root", "123456", false),
                ("root", "654321", true),
                ("admin", "654321", false),
            ],
            wait,
        )
        .await;

        // Deleting the file at runtime keeps the last successfully loaded credentials.
        assert!(std::fs::remove_file(&file_path).is_ok());
        test_authenticate_all(
            &provider,
            &[
                ("root", "654321", true),
                ("root", "123456", false),
                ("admin", "654321", false),
            ],
            wait,
        )
        .await;

        // Re-creating the file is picked up again.
        assert!(std::fs::write(&file_path, "root=123456\n").is_ok());
        test_authenticate_all(
            &provider,
            &[
                ("root", "123456", true),
                ("root", "654321", false),
                ("admin", "654321", false),
            ],
            wait,
        )
        .await;
    }
}