auth/
user_provider.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15pub(crate) mod static_user_provider;
16pub(crate) mod watch_file_user_provider;
17
18use std::collections::HashMap;
19use std::fs::File;
20use std::io;
21use std::io::BufRead;
22use std::path::Path;
23
24use common_base::secrets::ExposeSecret;
25use snafu::{ensure, OptionExt, ResultExt};
26
27use crate::common::{Identity, Password};
28use crate::error::{
29    IllegalParamSnafu, InvalidConfigSnafu, IoSnafu, Result, UnsupportedPasswordTypeSnafu,
30    UserNotFoundSnafu, UserPasswordMismatchSnafu,
31};
32use crate::user_info::DefaultUserInfo;
33use crate::{auth_mysql, UserInfoRef};
34
35#[async_trait::async_trait]
36pub trait UserProvider: Send + Sync {
37    fn name(&self) -> &str;
38
39    /// Checks whether a user is valid and allowed to access the database.
40    async fn authenticate(&self, id: Identity<'_>, password: Password<'_>) -> Result<UserInfoRef>;
41
42    /// Checks whether a connection request
43    /// from a certain user to a certain catalog/schema is legal.
44    /// This method should be called after [authenticate()](UserProvider::authenticate()).
45    async fn authorize(&self, catalog: &str, schema: &str, user_info: &UserInfoRef) -> Result<()>;
46
47    /// Combination of [authenticate()](UserProvider::authenticate()) and [authorize()](UserProvider::authorize()).
48    /// In most cases it's preferred for both convenience and performance.
49    async fn auth(
50        &self,
51        id: Identity<'_>,
52        password: Password<'_>,
53        catalog: &str,
54        schema: &str,
55    ) -> Result<UserInfoRef> {
56        let user_info = self.authenticate(id, password).await?;
57        self.authorize(catalog, schema, &user_info).await?;
58        Ok(user_info)
59    }
60
61    /// Returns whether this user provider implementation is backed by an external system.
62    fn external(&self) -> bool {
63        false
64    }
65}
66
67fn load_credential_from_file(filepath: &str) -> Result<Option<HashMap<String, Vec<u8>>>> {
68    // check valid path
69    let path = Path::new(filepath);
70    if !path.exists() {
71        return Ok(None);
72    }
73
74    ensure!(
75        path.is_file(),
76        InvalidConfigSnafu {
77            value: filepath,
78            msg: "UserProvider file must be a file",
79        }
80    );
81    let file = File::open(path).context(IoSnafu)?;
82    let credential = io::BufReader::new(file)
83        .lines()
84        .map_while(std::result::Result::ok)
85        .filter_map(|line| {
86            if let Some((k, v)) = line.split_once('=') {
87                Some((k.to_string(), v.as_bytes().to_vec()))
88            } else {
89                None
90            }
91        })
92        .collect::<HashMap<String, Vec<u8>>>();
93
94    ensure!(
95        !credential.is_empty(),
96        InvalidConfigSnafu {
97            value: filepath,
98            msg: "UserProvider's file must contains at least one valid credential",
99        }
100    );
101
102    Ok(Some(credential))
103}
104
105fn authenticate_with_credential(
106    users: &HashMap<String, Vec<u8>>,
107    input_id: Identity<'_>,
108    input_pwd: Password<'_>,
109) -> Result<UserInfoRef> {
110    match input_id {
111        Identity::UserId(username, _) => {
112            ensure!(
113                !username.is_empty(),
114                IllegalParamSnafu {
115                    msg: "blank username"
116                }
117            );
118            let save_pwd = users.get(username).context(UserNotFoundSnafu {
119                username: username.to_string(),
120            })?;
121
122            match input_pwd {
123                Password::PlainText(pwd) => {
124                    ensure!(
125                        !pwd.expose_secret().is_empty(),
126                        IllegalParamSnafu {
127                            msg: "blank password"
128                        }
129                    );
130                    if save_pwd == pwd.expose_secret().as_bytes() {
131                        Ok(DefaultUserInfo::with_name(username))
132                    } else {
133                        UserPasswordMismatchSnafu {
134                            username: username.to_string(),
135                        }
136                        .fail()
137                    }
138                }
139                Password::MysqlNativePassword(auth_data, salt) => {
140                    auth_mysql(auth_data, salt, username, save_pwd)
141                        .map(|_| DefaultUserInfo::with_name(username))
142                }
143                Password::PgMD5(_, _) => UnsupportedPasswordTypeSnafu {
144                    password_type: "pg_md5",
145                }
146                .fail(),
147            }
148        }
149    }
150}