Skip to main content

cmd/
user.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#![allow(clippy::print_stdout)]
16
17use std::io;
18
19use auth::{
20    DEFAULT_PBKDF2_SHA256_ITERATIONS, MAX_PBKDF2_SHA256_SALT_LEN,
21    format_mysql_native_password_verifier, format_pbkdf2_sha256_password_verifier,
22};
23use clap::{ArgGroup, Parser, Subcommand, ValueEnum};
24use rand::RngCore;
25use snafu::ResultExt;
26
27use crate::error::{self, Result};
28
// Root arguments for this command group; the operation to perform is chosen
// by `subcmd` and dispatched in `Command::run`.
//
// NOTE: plain `//` comments are used deliberately on clap-derived items —
// clap's derive turns `///` doc comments into user-visible `--help` text.
#[derive(Debug, Parser)]
pub struct Command {
    // The selected subcommand (currently only `hash-password`).
    #[clap(subcommand)]
    pub subcmd: SubCommand,
}
34
// Subcommands available under `Command`.
//
// NOTE: the `///` line on each variant is rendered by clap as that
// subcommand's help text, so it is runtime output, not just documentation.
#[derive(Debug, Subcommand)]
pub enum SubCommand {
    /// Generate a password verifier for user_provider.
    HashPassword(HashPasswordCommand),
}
40
41impl Command {
42    pub fn run(self) -> Result<()> {
43        match self.subcmd {
44            SubCommand::HashPassword(cmd) => cmd.run(),
45        }
46    }
47}
48
// Arguments for the `hash-password` subcommand.
//
// The "password-input" group is required, and clap groups do not allow more
// than one member by default, so exactly one of `--password` or
// `--password-stdin` must be supplied.
//
// NOTE: the `///` lines below are clap help text (runtime output); keep
// explanatory notes in `//` comments.
#[derive(Debug, Parser)]
#[clap(group(
    ArgGroup::new("password-input")
        .required(true)
        .args(["password", "password_stdin"])
))]
pub struct HashPasswordCommand {
    /// Password verifier format to generate.
    #[clap(long, value_enum, default_value = "pbkdf2_sha256")]
    format: PasswordFormat,

    /// Plaintext password. Prefer --password-stdin to avoid shell history leaks.
    #[clap(long)]
    password: Option<String>,

    /// Read the plaintext password from stdin.
    #[clap(long)]
    password_stdin: bool,

    // Only used by the `pbkdf2_sha256` format; upper-bound validation is left
    // to `format_pbkdf2_sha256_password_verifier`.
    /// PBKDF2-SHA256 iteration count.
    #[clap(long, default_value_t = DEFAULT_PBKDF2_SHA256_ITERATIONS)]
    iterations: u32,

    // Only used by `pbkdf2_sha256`, and ignored when `--salt-hex` is given.
    // Must be in 1..=MAX_PBKDF2_SHA256_SALT_LEN.
    /// PBKDF2-SHA256 random salt length in bytes.
    #[clap(long, default_value_t = 16)]
    salt_len: usize,

    // Only used by `pbkdf2_sha256`; takes precedence over `--salt-len`. The
    // decoded length must also be in 1..=MAX_PBKDF2_SHA256_SALT_LEN.
    /// PBKDF2-SHA256 salt as hex. Mainly useful for deterministic automation.
    #[clap(long)]
    salt_hex: Option<String>,
}
80
// Verifier formats accepted by `--format`.
//
// `rename_all = "snake_case"` makes the CLI values `pbkdf2_sha256` and
// `mysql_native_password`, which must stay in sync with the `default_value`
// on `HashPasswordCommand::format`. Variant order is the order clap lists the
// possible values in `--help`, and `///` comments here would become
// per-value help text, so notes stay in `//` comments.
#[derive(Clone, Copy, Debug, ValueEnum)]
#[clap(rename_all = "snake_case")]
enum PasswordFormat {
    // Salted PBKDF2-SHA256; uses --iterations and --salt-len / --salt-hex.
    Pbkdf2Sha256,
    // Unsalted MySQL native password hash; the salt/iteration options are
    // not consulted for this format.
    MysqlNativePassword,
}
87
88impl HashPasswordCommand {
89    fn run(self) -> Result<()> {
90        let password = self.read_password()?;
91        let verifier = match self.format {
92            PasswordFormat::Pbkdf2Sha256 => {
93                let salt = self.pbkdf2_salt()?;
94                format_pbkdf2_sha256_password_verifier(password.as_bytes(), &salt, self.iterations)
95                    .map_err(common_error::ext::BoxedError::new)
96                    .context(error::OtherSnafu)?
97            }
98            PasswordFormat::MysqlNativePassword => {
99                format_mysql_native_password_verifier(password.as_bytes())
100            }
101        };
102
103        println!("{verifier}");
104        Ok(())
105    }
106
107    fn read_password(&self) -> Result<String> {
108        let password = if let Some(password) = self.password.as_ref() {
109            password.clone()
110        } else {
111            let mut password = String::new();
112            io::stdin()
113                .read_line(&mut password)
114                .context(error::FileIoSnafu)?;
115            password.trim_end_matches(['\r', '\n']).to_string()
116        };
117
118        // A blank password is rejected by the user provider before verifier
119        // comparison, so a verifier built from it would be unusable. Fail fast
120        // instead of emitting a dead verifier (e.g. on EOF or an empty line).
121        if password.is_empty() {
122            return error::IllegalConfigSnafu {
123                msg: "password must not be empty",
124            }
125            .fail();
126        }
127
128        Ok(password)
129    }
130
131    fn pbkdf2_salt(&self) -> Result<Vec<u8>> {
132        if let Some(salt_hex) = self.salt_hex.as_ref() {
133            let salt = hex::decode(salt_hex).map_err(|err| {
134                error::IllegalConfigSnafu {
135                    msg: format!("invalid --salt-hex: {err}"),
136                }
137                .build()
138            })?;
139            Self::ensure_salt_len(salt.len())?;
140            return Ok(salt);
141        }
142
143        Self::ensure_salt_len(self.salt_len)?;
144        let mut salt = vec![0u8; self.salt_len];
145        rand::rng().fill_bytes(&mut salt);
146        Ok(salt)
147    }
148
149    fn ensure_salt_len(salt_len: usize) -> Result<()> {
150        if salt_len == 0 || salt_len > MAX_PBKDF2_SHA256_SALT_LEN {
151            return error::IllegalConfigSnafu {
152                msg: format!("salt length must be in 1..={}", MAX_PBKDF2_SHA256_SALT_LEN),
153            }
154            .fail();
155        }
156        Ok(())
157    }
158}
159
#[cfg(test)]
mod tests {
    use auth::MAX_PBKDF2_SHA256_ITERATIONS;

    use super::*;

    /// Builds a `pbkdf2_sha256` command with a fixed iteration count of 4096.
    fn pbkdf2_cmd(password: &str, salt_len: usize, salt_hex: Option<&str>) -> HashPasswordCommand {
        HashPasswordCommand {
            format: PasswordFormat::Pbkdf2Sha256,
            password: Some(password.to_string()),
            password_stdin: false,
            iterations: 4096,
            salt_len,
            salt_hex: salt_hex.map(str::to_string),
        }
    }

    #[test]
    fn test_hash_password_command_with_pbkdf2_sha256() {
        let cmd = pbkdf2_cmd("password", 16, Some("73616c74"));

        let password = cmd.read_password().unwrap();
        let salt = cmd.pbkdf2_salt().unwrap();
        let verifier =
            format_pbkdf2_sha256_password_verifier(password.as_bytes(), &salt, cmd.iterations)
                .unwrap();

        assert_eq!(
            "pbkdf2_sha256:4096:73616c74:c5e478d59288c841aa530db6845c4c8d962893a001ce4e11a4963873aa98134a",
            verifier
        );
    }

    #[test]
    fn test_hash_password_command_with_mysql_native_password() {
        let verifier = format_mysql_native_password_verifier("123456".as_bytes());

        assert_eq!(
            "mysql_native_password:6bb4837eb74329105ee4568dda7dc67ed2ca2ad9",
            verifier
        );
    }

    #[test]
    fn test_reject_empty_salt() {
        let cmd = pbkdf2_cmd("password", 0, None);

        assert!(cmd.pbkdf2_salt().is_err());
    }

    #[test]
    fn test_reject_too_long_salt() {
        let cmd = pbkdf2_cmd("password", MAX_PBKDF2_SHA256_SALT_LEN + 1, None);
        assert!(cmd.pbkdf2_salt().is_err());

        // The same bound applies to an explicitly supplied hex salt.
        let too_long_hex = "00".repeat(MAX_PBKDF2_SHA256_SALT_LEN + 1);
        let cmd = pbkdf2_cmd("password", 16, Some(&too_long_hex));
        assert!(cmd.pbkdf2_salt().is_err());
    }

    #[test]
    fn test_reject_invalid_salt_hex() {
        let cmd = pbkdf2_cmd("password", 16, Some("not-hex"));

        assert!(cmd.pbkdf2_salt().is_err());
    }

    #[test]
    fn test_random_salt_has_requested_len() {
        let cmd = pbkdf2_cmd("password", 16, None);

        assert_eq!(16, cmd.pbkdf2_salt().unwrap().len());
    }

    #[test]
    fn test_reject_empty_password() {
        let cmd = pbkdf2_cmd("", 16, Some("73616c74"));

        assert!(cmd.read_password().is_err());
    }

    #[test]
    fn test_reject_too_many_iterations() {
        let result = format_pbkdf2_sha256_password_verifier(
            b"password",
            b"salt",
            MAX_PBKDF2_SHA256_ITERATIONS + 1,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_cli_requires_exactly_one_password_input() {
        // Neither input given: the "password-input" group is required.
        assert!(HashPasswordCommand::try_parse_from(["hash-password"]).is_err());

        // Both inputs given: group members conflict with each other.
        assert!(
            HashPasswordCommand::try_parse_from([
                "hash-password",
                "--password",
                "secret",
                "--password-stdin",
            ])
            .is_err()
        );

        // Exactly one input: parses, and defaults are applied.
        let cmd =
            HashPasswordCommand::try_parse_from(["hash-password", "--password", "secret"]).unwrap();
        assert!(matches!(cmd.format, PasswordFormat::Pbkdf2Sha256));
        assert_eq!(DEFAULT_PBKDF2_SHA256_ITERATIONS, cmd.iterations);
        assert_eq!(16, cmd.salt_len);
        assert_eq!(Some("secret"), cmd.password.as_deref());
        assert!(!cmd.password_stdin);
    }

    #[test]
    fn test_cli_parses_mysql_native_password_format() {
        let cmd = HashPasswordCommand::try_parse_from([
            "hash-password",
            "--format",
            "mysql_native_password",
            "--password-stdin",
        ])
        .unwrap();

        assert!(matches!(cmd.format, PasswordFormat::MysqlNativePassword));
        assert!(cmd.password_stdin);
        assert!(cmd.password.is_none());
    }
}