auth/
common.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use common_base::secrets::SecretString;
18use digest::Digest;
19use sha1::Sha1;
20use snafu::{ensure, OptionExt};
21
22use crate::error::{IllegalParamSnafu, InvalidConfigSnafu, Result, UserPasswordMismatchSnafu};
23use crate::user_info::DefaultUserInfo;
24use crate::user_provider::static_user_provider::{StaticUserProvider, STATIC_USER_PROVIDER};
25use crate::user_provider::watch_file_user_provider::{
26    WatchFileUserProvider, WATCH_FILE_USER_PROVIDER,
27};
28use crate::{UserInfoRef, UserProviderRef};
29
/// Fallback username, used by [`userinfo_by_name`] when the caller supplies none.
pub(crate) const DEFAULT_USERNAME: &str = "greptime";
31
32/// construct a [`UserInfo`](crate::user_info::UserInfo) impl with name
33/// use default username `greptime` if None is provided
34pub fn userinfo_by_name(username: Option<String>) -> UserInfoRef {
35    DefaultUserInfo::with_name(username.unwrap_or_else(|| DEFAULT_USERNAME.to_string()))
36}
37
38pub fn user_provider_from_option(opt: &String) -> Result<UserProviderRef> {
39    let (name, content) = opt.split_once(':').with_context(|| InvalidConfigSnafu {
40        value: opt.to_string(),
41        msg: "UserProviderOption must be in format `<option>:<value>`",
42    })?;
43    match name {
44        STATIC_USER_PROVIDER => {
45            let provider =
46                StaticUserProvider::new(content).map(|p| Arc::new(p) as UserProviderRef)?;
47            Ok(provider)
48        }
49        WATCH_FILE_USER_PROVIDER => {
50            WatchFileUserProvider::new(content).map(|p| Arc::new(p) as UserProviderRef)
51        }
52        _ => InvalidConfigSnafu {
53            value: name.to_string(),
54            msg: "Invalid UserProviderOption",
55        }
56        .fail(),
57    }
58}
59
60pub fn static_user_provider_from_option(opt: &String) -> Result<StaticUserProvider> {
61    let (name, content) = opt.split_once(':').with_context(|| InvalidConfigSnafu {
62        value: opt.to_string(),
63        msg: "UserProviderOption must be in format `<option>:<value>`",
64    })?;
65    match name {
66        STATIC_USER_PROVIDER => {
67            let provider = StaticUserProvider::new(content)?;
68            Ok(provider)
69        }
70        _ => InvalidConfigSnafu {
71            value: name.to_string(),
72            msg: format!("Invalid UserProviderOption, expect only {STATIC_USER_PROVIDER}"),
73        }
74        .fail(),
75    }
76}
77
/// Borrowed user name.
type Username<'a> = &'a str;
/// Borrowed host name or IP address, as text.
type HostOrIp<'a> = &'a str;

/// Identifies who is attempting to authenticate.
#[derive(Debug, Clone)]
pub enum Identity<'a> {
    /// A user name together with an optional host name or IP address.
    UserId(Username<'a>, Option<HostOrIp<'a>>),
}
85
/// Borrowed bytes of a hashed password.
pub type HashedPassword<'a> = &'a [u8];
/// Borrowed bytes of the salt that accompanies a hashed password.
pub type Salt<'a> = &'a [u8];

/// Authentication information sent by the client.
pub enum Password<'a> {
    /// Password held as a [`SecretString`].
    PlainText(SecretString),
    /// Hashed password and salt, reported by [`Password::type`] as `mysql_native_password`.
    MysqlNativePassword(HashedPassword<'a>, Salt<'a>),
    /// Hashed password and salt, reported by [`Password::type`] as `pg_md5`.
    PgMD5(HashedPassword<'a>, Salt<'a>),
}
95
96impl Password<'_> {
97    pub fn r#type(&self) -> &str {
98        match self {
99            Password::PlainText(_) => "plain_text",
100            Password::MysqlNativePassword(_, _) => "mysql_native_password",
101            Password::PgMD5(_, _) => "pg_md5",
102        }
103    }
104}
105
106pub fn auth_mysql(
107    auth_data: HashedPassword,
108    salt: Salt,
109    username: &str,
110    save_pwd: &[u8],
111) -> Result<()> {
112    ensure!(
113        auth_data.len() == 20,
114        IllegalParamSnafu {
115            msg: "Illegal mysql password length"
116        }
117    );
118    // ref: https://github.com/mysql/mysql-server/blob/a246bad76b9271cb4333634e954040a970222e0a/sql/auth/password.cc#L62
119    let hash_stage_2 = double_sha1(save_pwd);
120    let tmp = sha1_two(salt, &hash_stage_2);
121    // xor auth_data and tmp
122    let mut xor_result = [0u8; 20];
123    for i in 0..20 {
124        xor_result[i] = auth_data[i] ^ tmp[i];
125    }
126    let candidate_stage_2 = sha1_one(&xor_result);
127    if candidate_stage_2 == hash_stage_2 {
128        Ok(())
129    } else {
130        UserPasswordMismatchSnafu {
131            username: username.to_string(),
132        }
133        .fail()
134    }
135}
136
137fn sha1_two(input_1: &[u8], input_2: &[u8]) -> Vec<u8> {
138    let mut hasher = Sha1::new();
139    hasher.update(input_1);
140    hasher.update(input_2);
141    hasher.finalize().to_vec()
142}
143
144fn sha1_one(data: &[u8]) -> Vec<u8> {
145    let mut hasher = Sha1::new();
146    hasher.update(data);
147    hasher.finalize().to_vec()
148}
149
150fn double_sha1(data: &[u8]) -> Vec<u8> {
151    sha1_one(&sha1_one(data))
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the SHA-1 helpers against precomputed digests.
    #[test]
    fn test_sha() {
        const INPUT: &[u8] = b"123456";

        // Single SHA-1 pass.
        let expected_single: [u8; 20] = [
            124, 74, 141, 9, 202, 55, 98, 175, 97, 229, 149, 32, 148, 61, 194, 100, 148, 248, 148,
            27,
        ];
        assert_eq!(sha1_one(INPUT), expected_single);

        // SHA-1 applied twice.
        let expected_double: [u8; 20] = [
            107, 180, 131, 126, 183, 67, 41, 16, 94, 228, 86, 141, 218, 125, 198, 126, 210, 202,
            42, 217,
        ];
        assert_eq!(double_sha1(INPUT), expected_double);

        // SHA-1 over two inputs fed back to back.
        let expected_pair: [u8; 20] = [
            132, 115, 215, 211, 99, 186, 164, 206, 168, 152, 217, 192, 117, 47, 240, 252, 142, 244,
            37, 204,
        ];
        assert_eq!(sha1_two(INPUT, b"654321"), expected_pair);
    }
}