auth/user_provider/
watch_file_user_provider.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::{Arc, Mutex};
16
17use async_trait::async_trait;
18use common_config::file_watcher::{FileWatcherBuilder, FileWatcherConfig};
19use common_telemetry::{info, warn};
20use snafu::ResultExt;
21
22use crate::error::{FileWatchSnafu, Result};
23use crate::user_provider::{UserInfoMap, authenticate_with_credential, load_credential_from_file};
24use crate::{Identity, Password, UserInfoRef, UserProvider};
25
/// Identifier reported by [`WatchFileUserProvider`] through `UserProvider::name`.
pub(crate) const WATCH_FILE_USER_PROVIDER: &str = "watch_file_user_provider";
27
28type WatchedCredentialRef = Arc<Mutex<UserInfoMap>>;
29
30/// A user provider that reads user credential from a file and watches the file for changes.
31///
32/// Both empty file and non-existent file are invalid and will cause initialization to fail.
33#[derive(Debug)]
34pub(crate) struct WatchFileUserProvider {
35    users: WatchedCredentialRef,
36}
37
38impl WatchFileUserProvider {
39    pub fn new(filepath: &str) -> Result<Self> {
40        let credential = load_credential_from_file(filepath)?;
41        let users = Arc::new(Mutex::new(credential));
42
43        let users_clone = users.clone();
44        let filepath_owned = filepath.to_string();
45
46        FileWatcherBuilder::new()
47            .watch_path(filepath)
48            .context(FileWatchSnafu)?
49            .config(FileWatcherConfig::new())
50            .spawn(move || match load_credential_from_file(&filepath_owned) {
51                Ok(credential) => {
52                    let mut users = users_clone.lock().expect("users credential must be valid");
53                    #[cfg(not(test))]
54                    info!("User provider file {} reloaded", &filepath_owned);
55                    #[cfg(test)]
56                    info!(
57                        "User provider file {} reloaded: {:?}",
58                        &filepath_owned, credential
59                    );
60                    *users = credential;
61                }
62                Err(err) => {
63                    warn!(
64                        ?err,
65                        "Fail to load credential from file {}; keep the old one", &filepath_owned
66                    )
67                }
68            })
69            .context(FileWatchSnafu)?;
70
71        Ok(WatchFileUserProvider { users })
72    }
73}
74
75#[async_trait]
76impl UserProvider for WatchFileUserProvider {
77    fn name(&self) -> &str {
78        WATCH_FILE_USER_PROVIDER
79    }
80
81    async fn authenticate(&self, id: Identity<'_>, password: Password<'_>) -> Result<UserInfoRef> {
82        let users = self.users.lock().expect("users credential must be valid");
83        authenticate_with_credential(&users, id, password)
84    }
85
86    async fn authorize(&self, _: &str, _: &str, _: &UserInfoRef) -> Result<()> {
87        // default allow all
88        Ok(())
89    }
90}
91
#[cfg(test)]
pub mod test {
    use std::time::{Duration, Instant};

    use common_test_util::temp_dir::create_temp_dir;
    use tokio::time::sleep;

    use crate::UserProvider;
    use crate::user_provider::watch_file_user_provider::WatchFileUserProvider;
    use crate::user_provider::{Identity, Password};

    /// Makes one plain-text login attempt and reports whether it was accepted.
    async fn try_login(provider: &dyn UserProvider, username: &str, password: &str) -> bool {
        provider
            .authenticate(
                Identity::UserId(username, None),
                Password::PlainText(password.to_string().into()),
            )
            .await
            .is_ok()
    }

    /// Asserts that logging in as `username`/`password` yields `ok`.
    ///
    /// With no `timeout` the check is made exactly once. With a `timeout` the attempt is
    /// retried every 100ms until it matches `ok`, which gives the file watcher time to pick up
    /// on-disk changes, and the function panics once the deadline has passed.
    async fn test_authenticate(
        provider: &dyn UserProvider,
        username: &str,
        password: &str,
        ok: bool,
        timeout: Option<Duration>,
    ) {
        match timeout {
            None => {
                let accepted = try_login(provider, username, password).await;
                assert_eq!(
                    accepted, ok,
                    "username: {}, password: {}",
                    username, password
                );
            }
            Some(limit) => {
                let deadline = Instant::now().checked_add(limit).unwrap();
                while try_login(provider, username, password).await != ok {
                    if Instant::now() >= deadline {
                        panic!(
                            "timeout (username: {username}, password: {password}, expected: {ok})"
                        );
                    }
                    sleep(Duration::from_millis(100)).await;
                }
            }
        }
    }

    /// Runs [`test_authenticate`] for every `(username, password, ok)` case, in order.
    async fn expect_logins(
        provider: &dyn UserProvider,
        cases: &[(&str, &str, bool)],
        timeout: Option<Duration>,
    ) {
        for &(username, password, ok) in cases {
            test_authenticate(provider, username, password, ok, timeout).await;
        }
    }

    #[tokio::test]
    async fn test_file_provider_initialization_with_missing_file() {
        common_telemetry::init_default_ut_logging();

        let dir = create_temp_dir("test_missing_file");
        let file_path = format!("{}/non_existent_file", dir.path().to_str().unwrap());

        // Constructing a provider over a file that does not exist must be rejected.
        let error = WatchFileUserProvider::new(file_path.as_str())
            .expect_err("provider creation should fail for a missing file");
        assert!(error.to_string().contains("UserProvider file must exist"));
    }

    #[tokio::test]
    async fn test_file_provider() {
        common_telemetry::init_default_ut_logging();

        let dir = create_temp_dir("test_file_provider");
        let file_path = format!("{}/test_file_provider", dir.path().to_str().unwrap());
        let wait = Some(Duration::from_secs(60));

        // Initial contents are loaded synchronously by the constructor.
        std::fs::write(&file_path, "root=123456\nadmin=654321\n").unwrap();
        let provider = WatchFileUserProvider::new(file_path.as_str()).unwrap();
        expect_logins(
            &provider,
            &[
                ("root", "123456", true),
                ("admin", "654321", true),
                ("root", "654321", false),
            ],
            None,
        )
        .await;

        // Rewriting the file is picked up by the watcher.
        std::fs::write(&file_path, "root=654321\n").unwrap();
        expect_logins(
            &provider,
            &[
                ("root", "123456", false),
                ("root", "654321", true),
                ("admin", "654321", false),
            ],
            wait,
        )
        .await;

        // Deleting the file at runtime keeps the last known good credentials.
        std::fs::remove_file(&file_path).unwrap();
        expect_logins(
            &provider,
            &[
                ("root", "654321", true),
                ("root", "123456", false),
                ("admin", "654321", false),
            ],
            wait,
        )
        .await;

        // Recreating the file makes its new contents effective again.
        std::fs::write(&file_path, "root=123456\n").unwrap();
        expect_logins(
            &provider,
            &[
                ("root", "123456", true),
                ("root", "654321", false),
                ("admin", "654321", false),
            ],
            wait,
        )
        .await;
    }
}