auth/
common.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::sync::Arc;
16
17use common_base::secrets::SecretString;
18use digest::Digest;
19use sha1::Sha1;
20use snafu::{ensure, OptionExt};
21
22use crate::error::{IllegalParamSnafu, InvalidConfigSnafu, Result, UserPasswordMismatchSnafu};
23use crate::user_info::DefaultUserInfo;
24use crate::user_provider::static_user_provider::{StaticUserProvider, STATIC_USER_PROVIDER};
25use crate::user_provider::watch_file_user_provider::{
26    WatchFileUserProvider, WATCH_FILE_USER_PROVIDER,
27};
28use crate::{UserInfoRef, UserProviderRef};
29
/// User name that [`userinfo_by_name`] falls back to when no name is supplied.
pub(crate) const DEFAULT_USERNAME: &str = "greptime";
31
32/// construct a [`UserInfo`](crate::user_info::UserInfo) impl with name
33/// use default username `greptime` if None is provided
34pub fn userinfo_by_name(username: Option<String>) -> UserInfoRef {
35    DefaultUserInfo::with_name(username.unwrap_or_else(|| DEFAULT_USERNAME.to_string()))
36}
37
38pub fn user_provider_from_option(opt: &String) -> Result<UserProviderRef> {
39    let (name, content) = opt.split_once(':').context(InvalidConfigSnafu {
40        value: opt.to_string(),
41        msg: "UserProviderOption must be in format `<option>:<value>`",
42    })?;
43    match name {
44        STATIC_USER_PROVIDER => {
45            let provider =
46                StaticUserProvider::new(content).map(|p| Arc::new(p) as UserProviderRef)?;
47            Ok(provider)
48        }
49        WATCH_FILE_USER_PROVIDER => {
50            WatchFileUserProvider::new(content).map(|p| Arc::new(p) as UserProviderRef)
51        }
52        _ => InvalidConfigSnafu {
53            value: name.to_string(),
54            msg: "Invalid UserProviderOption",
55        }
56        .fail(),
57    }
58}
59
/// Borrowed user name carried by an [`Identity`].
type Username<'a> = &'a str;
/// Borrowed host name or IP address carried by an [`Identity`].
type HostOrIp<'a> = &'a str;

/// Identifies who is being authenticated.
#[derive(Clone, Debug)]
pub enum Identity<'a> {
    /// A user name together with an optional host name or IP address.
    UserId(Username<'a>, Option<HostOrIp<'a>>),
}
67
/// Borrowed bytes of a hashed password.
pub type HashedPassword<'a> = &'a [u8];
/// Borrowed bytes of the salt that accompanies a [`HashedPassword`].
pub type Salt<'a> = &'a [u8];
70
71/// Authentication information sent by the client.
72pub enum Password<'a> {
73    PlainText(SecretString),
74    MysqlNativePassword(HashedPassword<'a>, Salt<'a>),
75    PgMD5(HashedPassword<'a>, Salt<'a>),
76}
77
78impl Password<'_> {
79    pub fn r#type(&self) -> &str {
80        match self {
81            Password::PlainText(_) => "plain_text",
82            Password::MysqlNativePassword(_, _) => "mysql_native_password",
83            Password::PgMD5(_, _) => "pg_md5",
84        }
85    }
86}
87
88pub fn auth_mysql(
89    auth_data: HashedPassword,
90    salt: Salt,
91    username: &str,
92    save_pwd: &[u8],
93) -> Result<()> {
94    ensure!(
95        auth_data.len() == 20,
96        IllegalParamSnafu {
97            msg: "Illegal mysql password length"
98        }
99    );
100    // ref: https://github.com/mysql/mysql-server/blob/a246bad76b9271cb4333634e954040a970222e0a/sql/auth/password.cc#L62
101    let hash_stage_2 = double_sha1(save_pwd);
102    let tmp = sha1_two(salt, &hash_stage_2);
103    // xor auth_data and tmp
104    let mut xor_result = [0u8; 20];
105    for i in 0..20 {
106        xor_result[i] = auth_data[i] ^ tmp[i];
107    }
108    let candidate_stage_2 = sha1_one(&xor_result);
109    if candidate_stage_2 == hash_stage_2 {
110        Ok(())
111    } else {
112        UserPasswordMismatchSnafu {
113            username: username.to_string(),
114        }
115        .fail()
116    }
117}
118
119fn sha1_two(input_1: &[u8], input_2: &[u8]) -> Vec<u8> {
120    let mut hasher = Sha1::new();
121    hasher.update(input_1);
122    hasher.update(input_2);
123    hasher.finalize().to_vec()
124}
125
126fn sha1_one(data: &[u8]) -> Vec<u8> {
127    let mut hasher = Sha1::new();
128    hasher.update(data);
129    hasher.finalize().to_vec()
130}
131
132fn double_sha1(data: &[u8]) -> Vec<u8> {
133    sha1_one(&sha1_one(data))
134}
135
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the SHA-1 helpers against precomputed digests.
    #[test]
    fn test_sha() {
        const PASSWORD: &[u8] = b"123456";

        // SHA1("123456")
        let expected_single: [u8; 20] = [
            124, 74, 141, 9, 202, 55, 98, 175, 97, 229, 149, 32, 148, 61, 194, 100, 148, 248, 148,
            27,
        ];
        assert_eq!(sha1_one(PASSWORD), expected_single);

        // SHA1(SHA1("123456"))
        let expected_double: [u8; 20] = [
            107, 180, 131, 126, 183, 67, 41, 16, 94, 228, 86, 141, 218, 125, 198, 126, 210, 202,
            42, 217,
        ];
        assert_eq!(double_sha1(PASSWORD), expected_double);

        // SHA1("123456" followed by "654321")
        let expected_pair: [u8; 20] = [
            132, 115, 215, 211, 99, 186, 164, 206, 168, 152, 217, 192, 117, 47, 240, 252, 142, 244,
            37, 204,
        ];
        assert_eq!(sha1_two(PASSWORD, b"654321"), expected_pair);
    }
}
163}