1#![allow(clippy::print_stdout)]
16
17use std::io;
18
19use auth::{
20 DEFAULT_PBKDF2_SHA256_ITERATIONS, MAX_PBKDF2_SHA256_SALT_LEN,
21 format_mysql_native_password_verifier, format_pbkdf2_sha256_password_verifier,
22};
23use clap::{ArgGroup, Parser, Subcommand, ValueEnum};
24use rand::RngCore;
25use snafu::ResultExt;
26
27use crate::error::{self, Result};
28
// Top-level command wrapper that carries exactly one selected subcommand (currently only
// `SubCommand::HashPassword`); `Command::run` dispatches to it.
// Plain `//` comments are used on purpose: `///` doc comments on clap-derived items become
// user-visible help text.
#[derive(Debug, Parser)]
pub struct Command {
    // The subcommand chosen on the command line.
    #[clap(subcommand)]
    pub subcmd: SubCommand,
}
34
35#[derive(Debug, Subcommand)]
36pub enum SubCommand {
37 HashPassword(HashPasswordCommand),
39}
40
41impl Command {
42 pub fn run(self) -> Result<()> {
43 match self.subcmd {
44 SubCommand::HashPassword(cmd) => cmd.run(),
45 }
46 }
47}
48
49#[derive(Debug, Parser)]
50#[clap(group(
51 ArgGroup::new("password-input")
52 .required(true)
53 .args(["password", "password_stdin"])
54))]
55pub struct HashPasswordCommand {
56 #[clap(long, value_enum, default_value = "pbkdf2_sha256")]
58 format: PasswordFormat,
59
60 #[clap(long)]
62 password: Option<String>,
63
64 #[clap(long)]
66 password_stdin: bool,
67
68 #[clap(long, default_value_t = DEFAULT_PBKDF2_SHA256_ITERATIONS)]
70 iterations: u32,
71
72 #[clap(long, default_value_t = 16)]
74 salt_len: usize,
75
76 #[clap(long)]
78 salt_hex: Option<String>,
79}
80
81#[derive(Clone, Copy, Debug, ValueEnum)]
82#[clap(rename_all = "snake_case")]
83enum PasswordFormat {
84 Pbkdf2Sha256,
85 MysqlNativePassword,
86}
87
88impl HashPasswordCommand {
89 fn run(self) -> Result<()> {
90 let password = self.read_password()?;
91 let verifier = match self.format {
92 PasswordFormat::Pbkdf2Sha256 => {
93 let salt = self.pbkdf2_salt()?;
94 format_pbkdf2_sha256_password_verifier(password.as_bytes(), &salt, self.iterations)
95 .map_err(common_error::ext::BoxedError::new)
96 .context(error::OtherSnafu)?
97 }
98 PasswordFormat::MysqlNativePassword => {
99 format_mysql_native_password_verifier(password.as_bytes())
100 }
101 };
102
103 println!("{verifier}");
104 Ok(())
105 }
106
107 fn read_password(&self) -> Result<String> {
108 let password = if let Some(password) = self.password.as_ref() {
109 password.clone()
110 } else {
111 let mut password = String::new();
112 io::stdin()
113 .read_line(&mut password)
114 .context(error::FileIoSnafu)?;
115 password.trim_end_matches(['\r', '\n']).to_string()
116 };
117
118 if password.is_empty() {
122 return error::IllegalConfigSnafu {
123 msg: "password must not be empty",
124 }
125 .fail();
126 }
127
128 Ok(password)
129 }
130
131 fn pbkdf2_salt(&self) -> Result<Vec<u8>> {
132 if let Some(salt_hex) = self.salt_hex.as_ref() {
133 let salt = hex::decode(salt_hex).map_err(|err| {
134 error::IllegalConfigSnafu {
135 msg: format!("invalid --salt-hex: {err}"),
136 }
137 .build()
138 })?;
139 Self::ensure_salt_len(salt.len())?;
140 return Ok(salt);
141 }
142
143 Self::ensure_salt_len(self.salt_len)?;
144 let mut salt = vec![0u8; self.salt_len];
145 rand::rng().fill_bytes(&mut salt);
146 Ok(salt)
147 }
148
149 fn ensure_salt_len(salt_len: usize) -> Result<()> {
150 if salt_len == 0 || salt_len > MAX_PBKDF2_SHA256_SALT_LEN {
151 return error::IllegalConfigSnafu {
152 msg: format!("salt length must be in 1..={}", MAX_PBKDF2_SHA256_SALT_LEN),
153 }
154 .fail();
155 }
156 Ok(())
157 }
158}
159
#[cfg(test)]
mod tests {
    use auth::MAX_PBKDF2_SHA256_ITERATIONS;

    use super::*;

    /// Builds a PBKDF2 command directly, bypassing argument parsing.
    fn pbkdf2_command(
        password: Option<&str>,
        salt_len: usize,
        salt_hex: Option<&str>,
    ) -> HashPasswordCommand {
        HashPasswordCommand {
            format: PasswordFormat::Pbkdf2Sha256,
            password: password.map(str::to_string),
            password_stdin: false,
            iterations: 4096,
            salt_len,
            salt_hex: salt_hex.map(str::to_string),
        }
    }

    #[test]
    fn test_cli_definition_is_valid() {
        <HashPasswordCommand as clap::CommandFactory>::command().debug_assert();
    }

    #[test]
    fn test_parse_defaults() {
        // Also proves the "pbkdf2_sha256" default string matches the ValueEnum rename.
        let cmd =
            HashPasswordCommand::try_parse_from(["hash-password", "--password", "secret"]).unwrap();

        assert!(matches!(cmd.format, PasswordFormat::Pbkdf2Sha256));
        assert_eq!(Some("secret"), cmd.password.as_deref());
        assert!(!cmd.password_stdin);
        assert_eq!(DEFAULT_PBKDF2_SHA256_ITERATIONS, cmd.iterations);
        assert_eq!(16, cmd.salt_len);
        assert!(cmd.salt_hex.is_none());
    }

    #[test]
    fn test_parse_mysql_native_password_from_stdin() {
        let cmd = HashPasswordCommand::try_parse_from([
            "hash-password",
            "--format",
            "mysql_native_password",
            "--password-stdin",
        ])
        .unwrap();

        assert!(matches!(cmd.format, PasswordFormat::MysqlNativePassword));
        assert!(cmd.password_stdin);
        assert!(cmd.password.is_none());
    }

    #[test]
    fn test_parse_requires_a_password_source() {
        assert!(HashPasswordCommand::try_parse_from(["hash-password"]).is_err());
    }

    #[test]
    fn test_parse_rejects_both_password_sources() {
        let result = HashPasswordCommand::try_parse_from([
            "hash-password",
            "--password",
            "secret",
            "--password-stdin",
        ]);

        assert!(result.is_err());
    }

    #[test]
    fn test_hash_password_command_with_pbkdf2_sha256() {
        let cmd = pbkdf2_command(Some("password"), 16, Some("73616c74"));

        let password = cmd.read_password().unwrap();
        let salt = cmd.pbkdf2_salt().unwrap();
        let verifier =
            format_pbkdf2_sha256_password_verifier(password.as_bytes(), &salt, cmd.iterations)
                .unwrap();

        assert_eq!(
            "pbkdf2_sha256:4096:73616c74:c5e478d59288c841aa530db6845c4c8d962893a001ce4e11a4963873aa98134a",
            verifier
        );
    }

    #[test]
    fn test_hash_password_command_with_mysql_native_password() {
        let verifier = format_mysql_native_password_verifier("123456".as_bytes());

        assert_eq!(
            "mysql_native_password:6bb4837eb74329105ee4568dda7dc67ed2ca2ad9",
            verifier
        );
    }

    #[test]
    fn test_random_salt_has_requested_length() {
        let cmd = pbkdf2_command(Some("password"), 16, None);

        let first = cmd.pbkdf2_salt().unwrap();
        let second = cmd.pbkdf2_salt().unwrap();

        assert_eq!(16, first.len());
        assert_eq!(16, second.len());
        // Two independent 128-bit random salts colliding is astronomically unlikely.
        assert_ne!(first, second);
    }

    #[test]
    fn test_reject_empty_salt() {
        let cmd = pbkdf2_command(Some("password"), 0, None);

        assert!(cmd.pbkdf2_salt().is_err());
    }

    #[test]
    fn test_reject_salt_longer_than_max() {
        let cmd = pbkdf2_command(Some("password"), MAX_PBKDF2_SHA256_SALT_LEN + 1, None);

        assert!(cmd.pbkdf2_salt().is_err());
    }

    #[test]
    fn test_reject_invalid_salt_hex() {
        let cmd = pbkdf2_command(Some("password"), 16, Some("not-hex"));

        assert!(cmd.pbkdf2_salt().is_err());
    }

    #[test]
    fn test_reject_empty_salt_hex() {
        // An empty hex string decodes to zero bytes, which must fail the length check.
        let cmd = pbkdf2_command(Some("password"), 16, Some(""));

        assert!(cmd.pbkdf2_salt().is_err());
    }

    #[test]
    fn test_reject_empty_password() {
        let cmd = pbkdf2_command(Some(""), 16, Some("73616c74"));

        assert!(cmd.read_password().is_err());
    }

    #[test]
    fn test_reject_too_many_iterations() {
        let result = format_pbkdf2_sha256_password_verifier(
            b"password",
            b"salt",
            MAX_PBKDF2_SHA256_ITERATIONS + 1,
        );

        assert!(result.is_err());
    }
}