auth/user_provider/
static_user_provider.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16
17use async_trait::async_trait;
18use snafu::{OptionExt, ResultExt};
19
20use crate::error::{FromUtf8Snafu, InvalidConfigSnafu, Result};
21use crate::user_provider::{authenticate_with_credential, load_credential_from_file};
22use crate::{Identity, Password, UserInfoRef, UserProvider};
23
/// Identifier returned by `UserProvider::name` for [`StaticUserProvider`].
pub(crate) const STATIC_USER_PROVIDER: &str = "static_user_provider";

/// A user provider whose credentials are fixed when it is constructed.
///
/// `users` maps each username to its stored password bytes. For `cmd:` options
/// these are the UTF-8 bytes of the password text as written in the option.
pub struct StaticUserProvider {
    users: HashMap<String, Vec<u8>>,
}
29
30impl StaticUserProvider {
31    pub(crate) fn new(value: &str) -> Result<Self> {
32        let (mode, content) = value.split_once(':').context(InvalidConfigSnafu {
33            value: value.to_string(),
34            msg: "StaticUserProviderOption must be in format `<option>:<value>`",
35        })?;
36        match mode {
37            "file" => {
38                let users = load_credential_from_file(content)?
39                    .context(InvalidConfigSnafu {
40                        value: content.to_string(),
41                        msg: "StaticFileUserProvider must be a valid file path",
42                    })?;
43                Ok(StaticUserProvider { users })
44            }
45            "cmd" => content
46                .split(',')
47                .map(|kv| {
48                    let (k, v) = kv.split_once('=').context(InvalidConfigSnafu {
49                        value: kv.to_string(),
50                        msg: "StaticUserProviderOption cmd values must be in format `user=pwd[,user=pwd]`",
51                    })?;
52                    Ok((k.to_string(), v.as_bytes().to_vec()))
53                })
54                .collect::<Result<HashMap<String, Vec<u8>>>>()
55                .map(|users| StaticUserProvider { users }),
56            _ => InvalidConfigSnafu {
57                value: mode.to_string(),
58                msg: "StaticUserProviderOption must be in format `file:<path>` or `cmd:<values>`",
59            }
60                .fail(),
61        }
62    }
63
64    /// Return a random username/password pair
65    /// This is useful for invoking from other components in the cluster
66    pub fn get_one_user_pwd(&self) -> Result<(String, String)> {
67        let kv = self.users.iter().next().context(InvalidConfigSnafu {
68            value: "",
69            msg: "Expect at least one pair of username and password",
70        })?;
71        let username = kv.0;
72        let pwd = String::from_utf8(kv.1.clone()).context(FromUtf8Snafu)?;
73        Ok((username.clone(), pwd))
74    }
75}
76
77#[async_trait]
78impl UserProvider for StaticUserProvider {
79    fn name(&self) -> &str {
80        STATIC_USER_PROVIDER
81    }
82
83    async fn authenticate(&self, id: Identity<'_>, pwd: Password<'_>) -> Result<UserInfoRef> {
84        authenticate_with_credential(&self.users, id, pwd)
85    }
86
87    async fn authorize(
88        &self,
89        _catalog: &str,
90        _schema: &str,
91        _user_info: &UserInfoRef,
92    ) -> Result<()> {
93        // default allow all
94        Ok(())
95    }
96}
97
#[cfg(test)]
pub mod test {
    use std::collections::HashMap;
    use std::fs::File;
    use std::io::{LineWriter, Write};

    use common_test_util::temp_dir::create_temp_dir;

    use crate::user_info::DefaultUserInfo;
    use crate::user_provider::static_user_provider::StaticUserProvider;
    use crate::user_provider::{Identity, Password};
    use crate::UserProvider;

    /// Authenticates `username` against `provider` with a plain-text `password`.
    async fn authenticate_plain(
        provider: &dyn UserProvider,
        username: &str,
        password: &str,
    ) -> crate::error::Result<crate::UserInfoRef> {
        provider
            .authenticate(
                Identity::UserId(username, None),
                Password::PlainText(password.to_string().into()),
            )
            .await
    }

    /// Asserts that `username`/`password` is accepted by `provider`.
    async fn test_authenticate(provider: &dyn UserProvider, username: &str, password: &str) {
        let _ = authenticate_plain(provider, username, password)
            .await
            .unwrap();
    }

    /// Asserts that `username`/`password` is rejected by `provider`.
    async fn test_authenticate_rejected(
        provider: &dyn UserProvider,
        username: &str,
        password: &str,
    ) {
        let re = authenticate_plain(provider, username, password).await;
        assert!(
            re.is_err(),
            "`{username}` should not authenticate with `{password}`"
        );
    }

    #[tokio::test]
    async fn test_authorize() {
        let user_info = DefaultUserInfo::with_name("root");
        let provider = StaticUserProvider::new("cmd:root=123456,admin=654321").unwrap();
        provider
            .authorize("catalog", "schema", &user_info)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_inline_provider() {
        let provider = StaticUserProvider::new("cmd:root=123456,admin=654321").unwrap();
        test_authenticate(&provider, "root", "123456").await;
        test_authenticate(&provider, "admin", "654321").await;
        // A password only works for its own user, and unknown users are refused.
        test_authenticate_rejected(&provider, "root", "654321").await;
        test_authenticate_rejected(&provider, "guest", "123456").await;
    }

    #[test]
    fn test_invalid_options() {
        // Missing the `<option>:` prefix.
        assert!(StaticUserProvider::new("root=123456").is_err());
        // Unknown option.
        assert!(StaticUserProvider::new("ldap:root=123456").is_err());
        // Inline entry without `=`.
        assert!(StaticUserProvider::new("cmd:root=123456,admin").is_err());
        // Empty inline list splits into a single empty entry, which has no `=`.
        assert!(StaticUserProvider::new("cmd:").is_err());
    }

    #[test]
    fn test_get_one_user_pwd() {
        let provider = StaticUserProvider::new("cmd:root=123456").unwrap();
        let (user, pwd) = provider.get_one_user_pwd().unwrap();
        assert_eq!(user, "root");
        assert_eq!(pwd, "123456");

        // An empty table cannot be built through `new`, so construct it directly.
        let empty = StaticUserProvider {
            users: HashMap::new(),
        };
        assert!(empty.get_one_user_pwd().is_err());
    }

    #[tokio::test]
    async fn test_file_provider() {
        let dir = create_temp_dir("test_file_provider");
        let file_path = dir.path().join("test_file_provider");
        {
            // Write a temporary credential file, one `user=pwd` pair per line.
            let mut lw = LineWriter::new(File::create(&file_path).unwrap());
            lw.write_all(b"root=123456\nadmin=654321").unwrap();
            lw.flush().unwrap();
        }

        let param = format!("file:{}", file_path.to_str().unwrap());
        let provider = StaticUserProvider::new(param.as_str()).unwrap();
        test_authenticate(&provider, "root", "123456").await;
        test_authenticate(&provider, "admin", "654321").await;
        test_authenticate_rejected(&provider, "admin", "123456").await;
    }
}