1use std::sync::Arc;
16
17use common_base::secrets::SecretString;
18use digest::Digest;
19use sha1::Sha1;
20use snafu::{ensure, OptionExt};
21
22use crate::error::{IllegalParamSnafu, InvalidConfigSnafu, Result, UserPasswordMismatchSnafu};
23use crate::user_info::DefaultUserInfo;
24use crate::user_provider::static_user_provider::{StaticUserProvider, STATIC_USER_PROVIDER};
25use crate::user_provider::watch_file_user_provider::{
26 WatchFileUserProvider, WATCH_FILE_USER_PROVIDER,
27};
28use crate::{UserInfoRef, UserProviderRef};
29
30pub(crate) const DEFAULT_USERNAME: &str = "greptime";
31
32pub fn userinfo_by_name(username: Option<String>) -> UserInfoRef {
35 DefaultUserInfo::with_name(username.unwrap_or_else(|| DEFAULT_USERNAME.to_string()))
36}
37
38pub fn user_provider_from_option(opt: &String) -> Result<UserProviderRef> {
39 let (name, content) = opt.split_once(':').with_context(|| InvalidConfigSnafu {
40 value: opt.to_string(),
41 msg: "UserProviderOption must be in format `<option>:<value>`",
42 })?;
43 match name {
44 STATIC_USER_PROVIDER => {
45 let provider =
46 StaticUserProvider::new(content).map(|p| Arc::new(p) as UserProviderRef)?;
47 Ok(provider)
48 }
49 WATCH_FILE_USER_PROVIDER => {
50 WatchFileUserProvider::new(content).map(|p| Arc::new(p) as UserProviderRef)
51 }
52 _ => InvalidConfigSnafu {
53 value: name.to_string(),
54 msg: "Invalid UserProviderOption",
55 }
56 .fail(),
57 }
58}
59
60pub fn static_user_provider_from_option(opt: &String) -> Result<StaticUserProvider> {
61 let (name, content) = opt.split_once(':').with_context(|| InvalidConfigSnafu {
62 value: opt.to_string(),
63 msg: "UserProviderOption must be in format `<option>:<value>`",
64 })?;
65 match name {
66 STATIC_USER_PROVIDER => {
67 let provider = StaticUserProvider::new(content)?;
68 Ok(provider)
69 }
70 _ => InvalidConfigSnafu {
71 value: name.to_string(),
72 msg: format!("Invalid UserProviderOption, expect only {STATIC_USER_PROVIDER}"),
73 }
74 .fail(),
75 }
76}
77
/// Borrowed user name.
type Username<'a> = &'a str;
/// Borrowed host name or IP address.
type HostOrIp<'a> = &'a str;

/// Identity of a user being authenticated.
///
/// NOTE(review): presumably built from what the connecting client reports; confirm with callers.
#[derive(Debug, Clone)]
pub enum Identity<'a> {
    /// A user name, optionally paired with a host name or IP address.
    UserId(Username<'a>, Option<HostOrIp<'a>>),
}
85
/// Borrowed bytes of a hashed (or scrambled) password.
pub type HashedPassword<'a> = &'a [u8];
/// Borrowed salt bytes that accompany a [`HashedPassword`].
pub type Salt<'a> = &'a [u8];

/// A password credential in one of several encodings.
pub enum Password<'a> {
    /// A cleartext password, held in a `SecretString`.
    PlainText(SecretString),
    /// A MySQL `mysql_native_password` scramble together with the salt it was computed with.
    /// `auth_mysql` in this module verifies this encoding.
    MysqlNativePassword(HashedPassword<'a>, Salt<'a>),
    /// A PostgreSQL MD5-hashed password together with its salt.
    /// NOTE(review): the verification routine for this variant is not in this file.
    PgMD5(HashedPassword<'a>, Salt<'a>),
}
95
96impl Password<'_> {
97 pub fn r#type(&self) -> &str {
98 match self {
99 Password::PlainText(_) => "plain_text",
100 Password::MysqlNativePassword(_, _) => "mysql_native_password",
101 Password::PgMD5(_, _) => "pg_md5",
102 }
103 }
104}
105
106pub fn auth_mysql(
107 auth_data: HashedPassword,
108 salt: Salt,
109 username: &str,
110 save_pwd: &[u8],
111) -> Result<()> {
112 ensure!(
113 auth_data.len() == 20,
114 IllegalParamSnafu {
115 msg: "Illegal mysql password length"
116 }
117 );
118 let hash_stage_2 = double_sha1(save_pwd);
120 let tmp = sha1_two(salt, &hash_stage_2);
121 let mut xor_result = [0u8; 20];
123 for i in 0..20 {
124 xor_result[i] = auth_data[i] ^ tmp[i];
125 }
126 let candidate_stage_2 = sha1_one(&xor_result);
127 if candidate_stage_2 == hash_stage_2 {
128 Ok(())
129 } else {
130 UserPasswordMismatchSnafu {
131 username: username.to_string(),
132 }
133 .fail()
134 }
135}
136
137fn sha1_two(input_1: &[u8], input_2: &[u8]) -> Vec<u8> {
138 let mut hasher = Sha1::new();
139 hasher.update(input_1);
140 hasher.update(input_2);
141 hasher.finalize().to_vec()
142}
143
144fn sha1_one(data: &[u8]) -> Vec<u8> {
145 let mut hasher = Sha1::new();
146 hasher.update(data);
147 hasher.finalize().to_vec()
148}
149
150fn double_sha1(data: &[u8]) -> Vec<u8> {
151 sha1_one(&sha1_one(data))
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    /// Known-answer checks for the SHA-1 helpers used by MySQL native password auth.
    #[test]
    fn test_sha() {
        const PLAIN: &[u8] = b"123456";
        const OTHER: &[u8] = b"654321";

        // SHA1("123456")
        let expected_single: [u8; 20] = [
            124, 74, 141, 9, 202, 55, 98, 175, 97, 229, 149, 32, 148, 61, 194, 100, 148, 248, 148,
            27,
        ];
        assert_eq!(sha1_one(PLAIN), expected_single);

        // SHA1(SHA1("123456"))
        let expected_double: [u8; 20] = [
            107, 180, 131, 126, 183, 67, 41, 16, 94, 228, 86, 141, 218, 125, 198, 126, 210, 202,
            42, 217,
        ];
        assert_eq!(double_sha1(PLAIN), expected_double);

        // SHA1("123456" || "654321")
        let expected_concat: [u8; 20] = [
            132, 115, 215, 211, 99, 186, 164, 206, 168, 152, 217, 192, 117, 47, 240, 252, 142, 244,
            37, 204,
        ];
        assert_eq!(sha1_two(PLAIN, OTHER), expected_concat);
    }
}