1use ::auth::UserProviderRef;
16use api::v1::Basic;
17use axum::extract::{Request, State};
18use axum::http::{self, StatusCode};
19use axum::middleware::Next;
20use axum::response::{IntoResponse, Response};
21use base64::Engine;
22use base64::prelude::BASE64_STANDARD;
23use common_base::secrets::{ExposeSecret, SecretString};
24use common_catalog::consts::DEFAULT_SCHEMA_NAME;
25use common_catalog::parse_catalog_and_schema_from_db_string;
26use common_error::ext::ErrorExt;
27use common_telemetry::warn;
28use common_time::Timezone;
29use common_time::timezone::parse_timezone;
30use headers::Header;
31use session::context::QueryContextBuilder;
32use snafu::{OptionExt, ResultExt, ensure};
33
34use crate::error::{
35 self, InvalidAuthHeaderInvisibleASCIISnafu, InvalidAuthHeaderSnafu, InvalidParameterSnafu,
36 NotFoundInfluxAuthSnafu, Result, UnsupportedAuthSchemeSnafu, UrlDecodeSnafu,
37};
38use crate::http::header::{GREPTIME_TIMEZONE_HEADER_NAME, GreptimeDbName};
39use crate::http::result::error_result::ErrorResponse;
40use crate::http::{AUTHORIZATION_HEADER, HTTP_API_PREFIX, PUBLIC_APIS};
41use crate::influxdb::{is_influxdb_request, is_influxdb_v2_request};
42
43#[derive(Clone)]
46pub struct AuthState {
47 user_provider: Option<UserProviderRef>,
48}
49
50impl AuthState {
51 pub fn new(user_provider: Option<UserProviderRef>) -> Self {
52 Self { user_provider }
53 }
54}
55
56pub async fn inner_auth<B>(
57 user_provider: Option<UserProviderRef>,
58 mut req: Request<B>,
59) -> std::result::Result<Request<B>, Response> {
60 let (catalog, schema) = extract_catalog_and_schema(&req);
62 let timezone = extract_timezone(&req);
64 let query_ctx_builder = QueryContextBuilder::default()
65 .current_catalog(catalog.clone())
66 .current_schema(schema.clone())
67 .timezone(timezone);
68
69 let query_ctx = query_ctx_builder.build();
70 let need_auth = need_auth(&req);
71
72 let user_provider = if let Some(user_provider) = user_provider.filter(|_| need_auth) {
74 user_provider
75 } else {
76 query_ctx.set_current_user(auth::userinfo_by_name(None));
77 let _ = req.extensions_mut().insert(query_ctx);
78 return Ok(req);
79 };
80
81 let (username, password) = match extract_username_and_password(&req) {
83 Ok((username, password)) => (username, password),
84 Err(e) => {
85 warn!(e; "extract username and password failed");
86 crate::metrics::METRIC_AUTH_FAILURE
87 .with_label_values(&[e.status_code().as_ref()])
88 .inc();
89 return Err(err_response(e));
90 }
91 };
92
93 match user_provider
95 .auth(
96 auth::Identity::UserId(&username, None),
97 auth::Password::PlainText(password),
98 &catalog,
99 &schema,
100 )
101 .await
102 {
103 Ok(userinfo) => {
104 query_ctx.set_current_user(userinfo);
105 let _ = req.extensions_mut().insert(query_ctx);
106 Ok(req)
107 }
108 Err(e) => {
109 warn!(e; "authenticate failed");
110 crate::metrics::METRIC_AUTH_FAILURE
111 .with_label_values(&[e.status_code().as_ref()])
112 .inc();
113 Err(err_response(e))
114 }
115 }
116}
117
118pub async fn check_http_auth(
119 State(auth_state): State<AuthState>,
120 req: Request,
121 next: Next,
122) -> Response {
123 match inner_auth(auth_state.user_provider, req).await {
124 Ok(req) => next.run(req).await,
125 Err(resp) => resp,
126 }
127}
128
129fn err_response(err: impl ErrorExt) -> Response {
130 (StatusCode::UNAUTHORIZED, ErrorResponse::from_error(err)).into_response()
131}
132
133pub fn extract_catalog_and_schema<B>(request: &Request<B>) -> (String, String) {
134 let dbname = request
136 .headers()
137 .get(GreptimeDbName::name())
138 .and_then(|header| header.to_str().ok())
140 .or_else(|| {
141 let query = request.uri().query().unwrap_or_default();
142 if is_influxdb_v2_request(request) {
143 extract_db_from_query(query).or_else(|| extract_bucket_from_query(query))
144 } else {
145 extract_db_from_query(query)
146 }
147 })
148 .unwrap_or(DEFAULT_SCHEMA_NAME);
149
150 parse_catalog_and_schema_from_db_string(dbname)
151}
152
153fn extract_timezone<B>(request: &Request<B>) -> Timezone {
154 let timezone = request
156 .headers()
157 .get(&GREPTIME_TIMEZONE_HEADER_NAME)
158 .and_then(|header| header.to_str().ok())
160 .unwrap_or("");
161 parse_timezone(Some(timezone))
162}
163
164fn get_influxdb_credentials<B>(request: &Request<B>) -> Result<Option<(Username, Password)>> {
165 if let Some(header) = request.headers().get(http::header::AUTHORIZATION) {
167 let (auth_scheme, credential) = header
169 .to_str()
170 .context(InvalidAuthHeaderInvisibleASCIISnafu)?
171 .split_once(' ')
172 .context(InvalidAuthHeaderSnafu)?;
173
174 let (username, password) = match auth_scheme.to_lowercase().as_str() {
175 "token" => {
176 let (u, p) = credential.split_once(':').context(InvalidAuthHeaderSnafu)?;
177 (u.to_string(), p.to_string().into())
178 }
179 "basic" => decode_basic(credential)?,
180 _ => UnsupportedAuthSchemeSnafu { name: auth_scheme }.fail()?,
181 };
182
183 Ok(Some((username, password)))
184 } else {
185 let Some(query_str) = request.uri().query() else {
187 return Ok(None);
188 };
189
190 let query_str = urlencoding::decode(query_str).context(UrlDecodeSnafu)?;
191
192 match extract_influxdb_user_from_query(&query_str) {
193 (None, None) => Ok(None),
194 (Some(username), Some(password)) => {
195 Ok(Some((username.to_string(), password.to_string().into())))
196 }
197 _ => InvalidParameterSnafu {
198 reason: "influxdb auth: username and password must be provided together"
199 .to_string(),
200 }
201 .fail(),
202 }
203 }
204}
205
206pub fn extract_username_and_password<B>(request: &Request<B>) -> Result<(Username, Password)> {
207 Ok(if is_influxdb_request(request) {
208 get_influxdb_credentials(request)?.context(NotFoundInfluxAuthSnafu)?
210 } else {
211 let scheme = auth_header(request)?;
213 match scheme {
214 AuthScheme::Basic(username, password) => (username, password),
215 }
216 })
217}
218
219#[derive(Debug)]
220pub enum AuthScheme {
221 Basic(Username, Password),
222}
223
224type Username = String;
225type Password = SecretString;
226
227impl TryFrom<&str> for AuthScheme {
228 type Error = error::Error;
229
230 fn try_from(value: &str) -> Result<Self> {
231 let (scheme, encoded_credentials) =
232 value.split_once(' ').context(InvalidAuthHeaderSnafu)?;
233
234 ensure!(!encoded_credentials.contains(' '), InvalidAuthHeaderSnafu);
235
236 match scheme.to_lowercase().as_str() {
237 "basic" => decode_basic(encoded_credentials)
238 .map(|(username, password)| AuthScheme::Basic(username, password)),
239 other => UnsupportedAuthSchemeSnafu { name: other }.fail(),
240 }
241 }
242}
243
244impl From<AuthScheme> for api::v1::auth_header::AuthScheme {
245 fn from(value: AuthScheme) -> Self {
246 match value {
247 AuthScheme::Basic(username, password) => {
248 api::v1::auth_header::AuthScheme::Basic(Basic {
249 username,
250 password: password.expose_secret().to_string(),
251 })
252 }
253 }
254 }
255}
256
257type Credential<'a> = &'a str;
258
259fn auth_header<B>(req: &Request<B>) -> Result<AuthScheme> {
260 let auth_header = req
261 .headers()
262 .get(AUTHORIZATION_HEADER)
263 .or_else(|| req.headers().get(http::header::AUTHORIZATION))
264 .context(error::NotFoundAuthHeaderSnafu)?
265 .to_str()
266 .context(InvalidAuthHeaderInvisibleASCIISnafu)?;
267
268 auth_header.try_into()
269}
270
271fn decode_basic(credential: Credential) -> Result<(Username, Password)> {
272 let decoded = BASE64_STANDARD
273 .decode(credential)
274 .context(error::InvalidBase64ValueSnafu)?;
275 let as_utf8 =
276 String::from_utf8(decoded).context(error::InvalidAuthHeaderInvalidUtf8ValueSnafu)?;
277
278 if let Some((user_id, password)) = as_utf8.split_once(':') {
279 return Ok((user_id.to_string(), password.to_string().into()));
280 }
281
282 InvalidAuthHeaderSnafu {}.fail()
283}
284
285fn need_auth<B>(req: &Request<B>) -> bool {
286 let path = req.uri().path();
287
288 for api in PUBLIC_APIS {
289 if path.starts_with(api) {
290 return false;
291 }
292 }
293
294 path.starts_with(HTTP_API_PREFIX)
295}
296
/// Returns the value of `param` from a raw (not percent-decoded) query string.
///
/// Only the first `param=value` pair is considered. If its value is empty,
/// `None` is returned even when a later pair carries a non-empty value. The
/// returned slice borrows from `query` only, so `param` may be short-lived.
fn extract_param_from_query<'a>(query: &'a str, param: &str) -> Option<&'a str> {
    query
        .split('&')
        // Matches `param=` exactly without allocating a prefix string per call.
        .find_map(|pair| pair.strip_prefix(param)?.strip_prefix('='))
        .filter(|value| !value.is_empty())
}
306
307fn extract_db_from_query(query: &str) -> Option<&str> {
308 extract_param_from_query(query, "db")
309}
310
311fn extract_bucket_from_query(query: &str) -> Option<&str> {
314 extract_param_from_query(query, "bucket")
315}
316
/// Extracts the raw `u` (username) and `p` (password) query parameters.
///
/// Empty values are ignored. When a key appears multiple times, the last
/// non-empty occurrence wins. Values are returned without percent-decoding.
fn extract_influxdb_user_from_query(query: &str) -> (Option<&str>, Option<&str>) {
    query.split('&').fold((None, None), |(user, pass), pair| {
        if let Some(value) = pair.strip_prefix("u=").filter(|v| !v.is_empty()) {
            (Some(value), pass)
        } else if let Some(value) = pair.strip_prefix("p=").filter(|v| !v.is_empty()) {
            (user, Some(value))
        } else {
            (user, pass)
        }
    })
}
330
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use common_base::secrets::ExposeSecret;

    use super::*;

    // `/v1/influxdb/ping` and `/v1/influxdb/health` must not require auth,
    // while `/v1/influxdb/write` must.
    #[test]
    fn test_need_auth() {
        let req = Request::builder()
            .uri("http://127.0.0.1/v1/influxdb/ping")
            .body(())
            .unwrap();

        assert!(!need_auth(&req));

        let req = Request::builder()
            .uri("http://127.0.0.1/v1/influxdb/health")
            .body(())
            .unwrap();

        assert!(!need_auth(&req));

        let req = Request::builder()
            .uri("http://127.0.0.1/v1/influxdb/write")
            .body(())
            .unwrap();

        assert!(need_auth(&req));
    }

    #[test]
    fn test_decode_basic() {
        // base64 of "username:password".
        let credential = "dXNlcm5hbWU6cGFzc3dvcmQ=";
        let (username, pwd) = decode_basic(credential).unwrap();
        assert_eq!("username", username);
        assert_eq!("password", pwd.expose_secret());

        // An embedded space makes the input invalid base64.
        let wrong_credential = "dXNlcm5hbWU6cG Fzc3dvcmQ=";
        let result = decode_basic(wrong_credential);
        assert_matches!(result.err(), Some(error::Error::InvalidBase64Value { .. }));
    }

    #[test]
    fn test_try_into_auth_scheme() {
        // No space separator: cannot split into scheme and credentials.
        let auth_scheme_str = "basic";
        let re: Result<AuthScheme> = auth_scheme_str.try_into();
        assert!(re.is_err());

        // base64 of "test:test".
        let auth_scheme_str = "basic dGVzdDp0ZXN0";
        let scheme: AuthScheme = auth_scheme_str.try_into().unwrap();
        assert_matches!(scheme, AuthScheme::Basic(username, pwd) if username == "test" && pwd.expose_secret() == "test");

        // Also has no space separator, so it fails before the scheme is checked.
        let unsupported = "digest";
        let auth_scheme: Result<AuthScheme> = unsupported.try_into();
        assert!(auth_scheme.is_err());
    }

    #[test]
    fn test_auth_header() {
        // base64 of "username:password" in the standard Authorization header.
        let req = mock_http_request(Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap();

        let auth_scheme = auth_header(&req).unwrap();
        assert_matches!(auth_scheme, AuthScheme::Basic(username, pwd) if username == "username" && pwd.expose_secret() == "password");

        // A space inside the credentials part is rejected as an invalid header.
        let wrong_req = mock_http_request(Some("Basic dXNlcm5hbWU6 cGFzc3dvcmQ="), None).unwrap();
        let res = auth_header(&wrong_req);
        assert_matches!(res.err(), Some(error::Error::InvalidAuthHeader { .. }));

        // Well-formed header with a scheme other than `basic`.
        let wrong_req = mock_http_request(Some("Digest dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap();
        let res = auth_header(&wrong_req);
        assert_matches!(res.err(), Some(error::Error::UnsupportedAuthScheme { .. }));
    }

    // Builds a body-less request. The URI defaults to
    // `http://localhost/{HTTP_API_VERSION}/sql`; an `Authorization` header is
    // added only when `auth_header` is given.
    fn mock_http_request(auth_header: Option<&str>, uri: Option<&str>) -> Result<Request<()>> {
        let http_api_version = crate::http::HTTP_API_VERSION;
        let mut req = Request::builder()
            .uri(uri.unwrap_or(format!("http://localhost/{http_api_version}/sql").as_str()));
        if let Some(auth_header) = auth_header {
            req = req.header(http::header::AUTHORIZATION, auth_header);
        }

        Ok(req.body(()).unwrap())
    }

    // The db-name header "greptime-tomcat" resolves to catalog "greptime", schema "tomcat".
    #[test]
    fn test_db_name_header() {
        let http_api_version = crate::http::HTTP_API_VERSION;
        let req = Request::builder()
            .uri(format!("http://localhost/{http_api_version}/sql").as_str())
            .header(GreptimeDbName::name(), "greptime-tomcat")
            .body(())
            .unwrap();

        let db = extract_catalog_and_schema(&req);
        assert_eq!(db, ("greptime".to_string(), "tomcat".to_string()));
    }

    // Only the first matching `db=`/`bucket=` pair counts, and an empty value yields None.
    #[test]
    fn test_extract_db() {
        assert_matches!(extract_db_from_query(""), None);
        assert_matches!(extract_db_from_query("&"), None);
        assert_matches!(extract_db_from_query("db="), None);
        assert_matches!(extract_bucket_from_query("bucket="), None);
        assert_matches!(extract_bucket_from_query("db=foo"), None);
        assert_matches!(extract_db_from_query("db=foo"), Some("foo"));
        assert_matches!(extract_bucket_from_query("bucket=foo"), Some("foo"));
        assert_matches!(extract_db_from_query("name=bar"), None);
        assert_matches!(extract_db_from_query("db=&name=bar"), None);
        assert_matches!(extract_db_from_query("db=foo&name=bar"), Some("foo"));
        assert_matches!(extract_bucket_from_query("db=foo&bucket=bar"), Some("bar"));
        assert_matches!(extract_db_from_query("name=bar&db="), None);
        assert_matches!(extract_db_from_query("name=bar&db=foo"), Some("foo"));
        assert_matches!(extract_db_from_query("name=bar&db=&name=bar"), None);
        assert_matches!(
            extract_db_from_query("name=bar&db=foo&name=bar"),
            Some("foo")
        );
    }

    // `u`/`p` are extracted independently; empty values are ignored.
    #[test]
    fn test_extract_user() {
        assert_matches!(extract_influxdb_user_from_query(""), (None, None));
        assert_matches!(extract_influxdb_user_from_query("u="), (None, None));
        assert_matches!(
            extract_influxdb_user_from_query("u=123"),
            (Some("123"), None)
        );
        assert_matches!(
            extract_influxdb_user_from_query("u=123&p="),
            (Some("123"), None)
        );
        assert_matches!(
            extract_influxdb_user_from_query("u=123&p=4"),
            (Some("123"), Some("4"))
        );
        assert_matches!(extract_influxdb_user_from_query("p="), (None, None));
        assert_matches!(extract_influxdb_user_from_query("p=4"), (None, Some("4")));
        assert_matches!(
            extract_influxdb_user_from_query("p=4&u="),
            (None, Some("4"))
        );
        assert_matches!(
            extract_influxdb_user_from_query("p=4&u=123"),
            (Some("123"), Some("4"))
        );
    }
}