1use std::sync::Arc;
16
17use common_base::secrets::SecretString;
18use digest::Digest;
19use sha1::Sha1;
20use snafu::{ensure, OptionExt};
21
22use crate::error::{IllegalParamSnafu, InvalidConfigSnafu, Result, UserPasswordMismatchSnafu};
23use crate::user_info::DefaultUserInfo;
24use crate::user_provider::static_user_provider::{StaticUserProvider, STATIC_USER_PROVIDER};
25use crate::user_provider::watch_file_user_provider::{
26 WatchFileUserProvider, WATCH_FILE_USER_PROVIDER,
27};
28use crate::{UserInfoRef, UserProviderRef};
29
30pub(crate) const DEFAULT_USERNAME: &str = "greptime";
31
32pub fn userinfo_by_name(username: Option<String>) -> UserInfoRef {
35 DefaultUserInfo::with_name(username.unwrap_or_else(|| DEFAULT_USERNAME.to_string()))
36}
37
38pub fn user_provider_from_option(opt: &String) -> Result<UserProviderRef> {
39 let (name, content) = opt.split_once(':').context(InvalidConfigSnafu {
40 value: opt.to_string(),
41 msg: "UserProviderOption must be in format `<option>:<value>`",
42 })?;
43 match name {
44 STATIC_USER_PROVIDER => {
45 let provider =
46 StaticUserProvider::new(content).map(|p| Arc::new(p) as UserProviderRef)?;
47 Ok(provider)
48 }
49 WATCH_FILE_USER_PROVIDER => {
50 WatchFileUserProvider::new(content).map(|p| Arc::new(p) as UserProviderRef)
51 }
52 _ => InvalidConfigSnafu {
53 value: name.to_string(),
54 msg: "Invalid UserProviderOption",
55 }
56 .fail(),
57 }
58}
59
/// Borrowed user name component of an [`Identity`].
type Username<'a> = &'a str;
/// Borrowed host name or IP address component of an [`Identity`].
type HostOrIp<'a> = &'a str;

/// Identity of a user to be authenticated.
///
/// NOTE(review): presumably supplied by the connecting client; confirm with callers.
#[derive(Debug, Clone)]
pub enum Identity<'a> {
    /// A user name, optionally paired with a host name or IP address.
    UserId(Username<'a>, Option<HostOrIp<'a>>),
}
67
/// Borrowed hashed-password bytes carried by the hashed [`Password`] variants.
pub type HashedPassword<'a> = &'a [u8];
/// Borrowed salt bytes carried by the hashed [`Password`] variants.
pub type Salt<'a> = &'a [u8];

/// A password credential in one of the supported encodings.
pub enum Password<'a> {
    /// Password in plain text, held in a `SecretString`.
    PlainText(SecretString),
    /// `mysql_native_password` data: the hashed password and the salt used with it.
    MysqlNativePassword(HashedPassword<'a>, Salt<'a>),
    /// PostgreSQL MD5 data: the hashed password and the salt used with it.
    PgMD5(HashedPassword<'a>, Salt<'a>),
}
77
78impl Password<'_> {
79 pub fn r#type(&self) -> &str {
80 match self {
81 Password::PlainText(_) => "plain_text",
82 Password::MysqlNativePassword(_, _) => "mysql_native_password",
83 Password::PgMD5(_, _) => "pg_md5",
84 }
85 }
86}
87
88pub fn auth_mysql(
89 auth_data: HashedPassword,
90 salt: Salt,
91 username: &str,
92 save_pwd: &[u8],
93) -> Result<()> {
94 ensure!(
95 auth_data.len() == 20,
96 IllegalParamSnafu {
97 msg: "Illegal mysql password length"
98 }
99 );
100 let hash_stage_2 = double_sha1(save_pwd);
102 let tmp = sha1_two(salt, &hash_stage_2);
103 let mut xor_result = [0u8; 20];
105 for i in 0..20 {
106 xor_result[i] = auth_data[i] ^ tmp[i];
107 }
108 let candidate_stage_2 = sha1_one(&xor_result);
109 if candidate_stage_2 == hash_stage_2 {
110 Ok(())
111 } else {
112 UserPasswordMismatchSnafu {
113 username: username.to_string(),
114 }
115 .fail()
116 }
117}
118
119fn sha1_two(input_1: &[u8], input_2: &[u8]) -> Vec<u8> {
120 let mut hasher = Sha1::new();
121 hasher.update(input_1);
122 hasher.update(input_2);
123 hasher.finalize().to_vec()
124}
125
126fn sha1_one(data: &[u8]) -> Vec<u8> {
127 let mut hasher = Sha1::new();
128 hasher.update(data);
129 hasher.finalize().to_vec()
130}
131
132fn double_sha1(data: &[u8]) -> Vec<u8> {
133 sha1_one(&sha1_one(data))
134}
135
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sha() {
        // SHA1("123456") = 7c4a8d09ca3762af61e59520943dc26494f8941b
        let sha_1_answer: Vec<u8> = vec![
            124, 74, 141, 9, 202, 55, 98, 175, 97, 229, 149, 32, 148, 61, 194, 100, 148, 248, 148,
            27,
        ];
        let sha_1 = sha1_one("123456".as_bytes());
        assert_eq!(sha_1, sha_1_answer);

        let double_sha1_answer: Vec<u8> = vec![
            107, 180, 131, 126, 183, 67, 41, 16, 94, 228, 86, 141, 218, 125, 198, 126, 210, 202,
            42, 217,
        ];
        let double_sha1 = double_sha1("123456".as_bytes());
        assert_eq!(double_sha1, double_sha1_answer);

        let sha1_2_answer: Vec<u8> = vec![
            132, 115, 215, 211, 99, 186, 164, 206, 168, 152, 217, 192, 117, 47, 240, 252, 142, 244,
            37, 204,
        ];
        let sha1_2 = sha1_two("123456".as_bytes(), "654321".as_bytes());
        assert_eq!(sha1_2, sha1_2_answer);
    }

    /// Computes the `mysql_native_password` response a MySQL client would send:
    /// `SHA1(password) XOR SHA1(salt ++ SHA1(SHA1(password)))`.
    fn mysql_client_scramble(password: &[u8], salt: &[u8]) -> Vec<u8> {
        let stage_1 = sha1_one(password);
        let stage_2 = sha1_one(&stage_1);
        let mask = sha1_two(salt, &stage_2);
        stage_1.iter().zip(&mask).map(|(s, m)| s ^ m).collect()
    }

    #[test]
    fn test_auth_mysql() {
        let salt = b"abcdefghijklmnopqrst";
        let password = b"123456";
        let scramble = mysql_client_scramble(password, salt);

        assert!(auth_mysql(&scramble, salt, "greptime", password).is_ok());
        assert!(auth_mysql(&scramble, salt, "greptime", b"654321").is_err());
        // A response computed for one salt must not verify under another.
        assert!(auth_mysql(&scramble, b"tsrqponmlkjihgfedcba", "greptime", password).is_err());
        // Responses that are not exactly one SHA-1 digest long are rejected up front.
        assert!(auth_mysql(&scramble[..19], salt, "greptime", password).is_err());
        assert!(auth_mysql(&[], salt, "greptime", password).is_err());
    }
}