1use ::auth::UserProviderRef;
16use axum::extract::{Request, State};
17use axum::http::{self, StatusCode};
18use axum::middleware::Next;
19use axum::response::{IntoResponse, Response};
20use base64::prelude::BASE64_STANDARD;
21use base64::Engine;
22use common_base::secrets::SecretString;
23use common_catalog::consts::DEFAULT_SCHEMA_NAME;
24use common_catalog::parse_catalog_and_schema_from_db_string;
25use common_error::ext::ErrorExt;
26use common_telemetry::warn;
27use common_time::timezone::parse_timezone;
28use common_time::Timezone;
29use headers::Header;
30use session::context::QueryContextBuilder;
31use snafu::{ensure, OptionExt, ResultExt};
32
33use crate::error::{
34 self, InvalidAuthHeaderInvisibleASCIISnafu, InvalidAuthHeaderSnafu, InvalidParameterSnafu,
35 NotFoundInfluxAuthSnafu, Result, UnsupportedAuthSchemeSnafu, UrlDecodeSnafu,
36};
37use crate::http::header::{GreptimeDbName, GREPTIME_TIMEZONE_HEADER_NAME};
38use crate::http::result::error_result::ErrorResponse;
39use crate::http::{AUTHORIZATION_HEADER, HTTP_API_PREFIX, PUBLIC_APIS};
40use crate::influxdb::{is_influxdb_request, is_influxdb_v2_request};
41
42#[derive(Clone)]
45pub struct AuthState {
46 user_provider: Option<UserProviderRef>,
47}
48
49impl AuthState {
50 pub fn new(user_provider: Option<UserProviderRef>) -> Self {
51 Self { user_provider }
52 }
53}
54
55pub async fn inner_auth<B>(
56 user_provider: Option<UserProviderRef>,
57 mut req: Request<B>,
58) -> std::result::Result<Request<B>, Response> {
59 let (catalog, schema) = extract_catalog_and_schema(&req);
61 let timezone = extract_timezone(&req);
63 let query_ctx_builder = QueryContextBuilder::default()
64 .current_catalog(catalog.clone())
65 .current_schema(schema.clone())
66 .timezone(timezone);
67
68 let query_ctx = query_ctx_builder.build();
69 let need_auth = need_auth(&req);
70
71 let user_provider = if let Some(user_provider) = user_provider.filter(|_| need_auth) {
73 user_provider
74 } else {
75 query_ctx.set_current_user(auth::userinfo_by_name(None));
76 let _ = req.extensions_mut().insert(query_ctx);
77 return Ok(req);
78 };
79
80 let (username, password) = match extract_username_and_password(&req) {
82 Ok((username, password)) => (username, password),
83 Err(e) => {
84 warn!(e; "extract username and password failed");
85 crate::metrics::METRIC_AUTH_FAILURE
86 .with_label_values(&[e.status_code().as_ref()])
87 .inc();
88 return Err(err_response(e));
89 }
90 };
91
92 match user_provider
94 .auth(
95 auth::Identity::UserId(&username, None),
96 auth::Password::PlainText(password),
97 &catalog,
98 &schema,
99 )
100 .await
101 {
102 Ok(userinfo) => {
103 query_ctx.set_current_user(userinfo);
104 let _ = req.extensions_mut().insert(query_ctx);
105 Ok(req)
106 }
107 Err(e) => {
108 warn!(e; "authenticate failed");
109 crate::metrics::METRIC_AUTH_FAILURE
110 .with_label_values(&[e.status_code().as_ref()])
111 .inc();
112 Err(err_response(e))
113 }
114 }
115}
116
117pub async fn check_http_auth(
118 State(auth_state): State<AuthState>,
119 req: Request,
120 next: Next,
121) -> Response {
122 match inner_auth(auth_state.user_provider, req).await {
123 Ok(req) => next.run(req).await,
124 Err(resp) => resp,
125 }
126}
127
128fn err_response(err: impl ErrorExt) -> Response {
129 (StatusCode::UNAUTHORIZED, ErrorResponse::from_error(err)).into_response()
130}
131
132pub fn extract_catalog_and_schema<B>(request: &Request<B>) -> (String, String) {
133 let dbname = request
135 .headers()
136 .get(GreptimeDbName::name())
137 .and_then(|header| header.to_str().ok())
139 .or_else(|| {
140 let query = request.uri().query().unwrap_or_default();
141 if is_influxdb_v2_request(request) {
142 extract_db_from_query(query).or_else(|| extract_bucket_from_query(query))
143 } else {
144 extract_db_from_query(query)
145 }
146 })
147 .unwrap_or(DEFAULT_SCHEMA_NAME);
148
149 parse_catalog_and_schema_from_db_string(dbname)
150}
151
152fn extract_timezone<B>(request: &Request<B>) -> Timezone {
153 let timezone = request
155 .headers()
156 .get(&GREPTIME_TIMEZONE_HEADER_NAME)
157 .and_then(|header| header.to_str().ok())
159 .unwrap_or("");
160 parse_timezone(Some(timezone))
161}
162
163fn get_influxdb_credentials<B>(request: &Request<B>) -> Result<Option<(Username, Password)>> {
164 if let Some(header) = request.headers().get(http::header::AUTHORIZATION) {
166 let (auth_scheme, credential) = header
168 .to_str()
169 .context(InvalidAuthHeaderInvisibleASCIISnafu)?
170 .split_once(' ')
171 .context(InvalidAuthHeaderSnafu)?;
172
173 let (username, password) = match auth_scheme.to_lowercase().as_str() {
174 "token" => {
175 let (u, p) = credential.split_once(':').context(InvalidAuthHeaderSnafu)?;
176 (u.to_string(), p.to_string().into())
177 }
178 "basic" => decode_basic(credential)?,
179 _ => UnsupportedAuthSchemeSnafu { name: auth_scheme }.fail()?,
180 };
181
182 Ok(Some((username, password)))
183 } else {
184 let Some(query_str) = request.uri().query() else {
186 return Ok(None);
187 };
188
189 let query_str = urlencoding::decode(query_str).context(UrlDecodeSnafu)?;
190
191 match extract_influxdb_user_from_query(&query_str) {
192 (None, None) => Ok(None),
193 (Some(username), Some(password)) => {
194 Ok(Some((username.to_string(), password.to_string().into())))
195 }
196 _ => InvalidParameterSnafu {
197 reason: "influxdb auth: username and password must be provided together"
198 .to_string(),
199 }
200 .fail(),
201 }
202 }
203}
204
205pub fn extract_username_and_password<B>(request: &Request<B>) -> Result<(Username, Password)> {
206 Ok(if is_influxdb_request(request) {
207 get_influxdb_credentials(request)?.context(NotFoundInfluxAuthSnafu)?
209 } else {
210 let scheme = auth_header(request)?;
212 match scheme {
213 AuthScheme::Basic(username, password) => (username, password),
214 }
215 })
216}
217
218#[derive(Debug)]
219pub enum AuthScheme {
220 Basic(Username, Password),
221}
222
223type Username = String;
224type Password = SecretString;
225
226impl TryFrom<&str> for AuthScheme {
227 type Error = error::Error;
228
229 fn try_from(value: &str) -> Result<Self> {
230 let (scheme, encoded_credentials) =
231 value.split_once(' ').context(InvalidAuthHeaderSnafu)?;
232
233 ensure!(!encoded_credentials.contains(' '), InvalidAuthHeaderSnafu);
234
235 match scheme.to_lowercase().as_str() {
236 "basic" => decode_basic(encoded_credentials)
237 .map(|(username, password)| AuthScheme::Basic(username, password)),
238 other => UnsupportedAuthSchemeSnafu { name: other }.fail(),
239 }
240 }
241}
242
243type Credential<'a> = &'a str;
244
245fn auth_header<B>(req: &Request<B>) -> Result<AuthScheme> {
246 let auth_header = req
247 .headers()
248 .get(AUTHORIZATION_HEADER)
249 .or_else(|| req.headers().get(http::header::AUTHORIZATION))
250 .context(error::NotFoundAuthHeaderSnafu)?
251 .to_str()
252 .context(InvalidAuthHeaderInvisibleASCIISnafu)?;
253
254 auth_header.try_into()
255}
256
257fn decode_basic(credential: Credential) -> Result<(Username, Password)> {
258 let decoded = BASE64_STANDARD
259 .decode(credential)
260 .context(error::InvalidBase64ValueSnafu)?;
261 let as_utf8 =
262 String::from_utf8(decoded).context(error::InvalidAuthHeaderInvalidUtf8ValueSnafu)?;
263
264 if let Some((user_id, password)) = as_utf8.split_once(':') {
265 return Ok((user_id.to_string(), password.to_string().into()));
266 }
267
268 InvalidAuthHeaderSnafu {}.fail()
269}
270
271fn need_auth<B>(req: &Request<B>) -> bool {
272 let path = req.uri().path();
273
274 for api in PUBLIC_APIS {
275 if path.starts_with(api) {
276 return false;
277 }
278 }
279
280 path.starts_with(HTTP_API_PREFIX)
281}
282
/// Returns the raw (not percent-decoded) value of `param` in `query`.
///
/// `query` is split on `&`, and the first pair of the form `param=value` is
/// selected. The key must equal `param` exactly, so `dbx=1` does not match
/// `db`.
///
/// If that first matching pair has an empty value, `None` is returned and no
/// later pair is consulted (`db=&db=foo` yields `None`).
///
/// `param` only needs to live for the duration of the call; the returned slice
/// borrows from `query` alone.
fn extract_param_from_query<'a>(query: &'a str, param: &str) -> Option<&'a str> {
    query
        .split('&')
        // Matching `param` and then `=` avoids building a `"{param}="` string.
        .find_map(|pair| pair.strip_prefix(param)?.strip_prefix('='))
        .filter(|value| !value.is_empty())
}

/// Returns the raw value of the `db` query parameter, if present and non-empty.
fn extract_db_from_query(query: &str) -> Option<&str> {
    extract_param_from_query(query, "db")
}

/// Returns the raw value of the `bucket` query parameter, if present and
/// non-empty. It is consulted as a fallback for InfluxDB v2 requests.
fn extract_bucket_from_query(query: &str) -> Option<&str> {
    extract_param_from_query(query, "bucket")
}
302
/// Scans `query` for the `u` (username) and `p` (password) parameters.
///
/// Pairs with an empty value are ignored. When a key repeats, the last
/// non-empty occurrence wins. Values are returned raw (not percent-decoded).
fn extract_influxdb_user_from_query(query: &str) -> (Option<&str>, Option<&str>) {
    query.split('&').fold((None, None), |(user, pass), pair| {
        if let Some(value) = pair.strip_prefix("u=").filter(|v| !v.is_empty()) {
            (Some(value), pass)
        } else if let Some(value) = pair.strip_prefix("p=").filter(|v| !v.is_empty()) {
            (user, Some(value))
        } else {
            (user, pass)
        }
    })
}
316
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use common_base::secrets::ExposeSecret;

    use super::*;

    #[test]
    fn test_need_auth() {
        let req = Request::builder()
            .uri("http://127.0.0.1/v1/influxdb/ping")
            .body(())
            .unwrap();

        assert!(!need_auth(&req));

        let req = Request::builder()
            .uri("http://127.0.0.1/v1/influxdb/health")
            .body(())
            .unwrap();

        assert!(!need_auth(&req));

        let req = Request::builder()
            .uri("http://127.0.0.1/v1/influxdb/write")
            .body(())
            .unwrap();

        assert!(need_auth(&req));

        // Paths outside the HTTP API prefix never require authentication.
        let req = Request::builder()
            .uri("http://127.0.0.1/dashboard")
            .body(())
            .unwrap();

        assert!(!need_auth(&req));
    }

    #[test]
    fn test_decode_basic() {
        // base64("username:password")
        let credential = "dXNlcm5hbWU6cGFzc3dvcmQ=";
        let (username, pwd) = decode_basic(credential).unwrap();
        assert_eq!("username", username);
        assert_eq!("password", pwd.expose_secret());

        let wrong_credential = "dXNlcm5hbWU6cG Fzc3dvcmQ=";
        let result = decode_basic(wrong_credential);
        assert_matches!(result.err(), Some(error::Error::InvalidBase64Value { .. }));

        // base64("username"): valid base64, but there is no `:` separator.
        let result = decode_basic("dXNlcm5hbWU=");
        assert_matches!(result.err(), Some(error::Error::InvalidAuthHeader { .. }));

        // base64([0xff, 0xfe]): valid base64, but not valid UTF-8.
        let result = decode_basic("//4=");
        assert_matches!(
            result.err(),
            Some(error::Error::InvalidAuthHeaderInvalidUtf8Value { .. })
        );
    }

    #[test]
    fn test_try_into_auth_scheme() {
        let auth_scheme_str = "basic";
        let re: Result<AuthScheme> = auth_scheme_str.try_into();
        assert!(re.is_err());

        let auth_scheme_str = "basic dGVzdDp0ZXN0";
        let scheme: AuthScheme = auth_scheme_str.try_into().unwrap();
        assert_matches!(scheme, AuthScheme::Basic(username, pwd) if username == "test" && pwd.expose_secret() == "test");

        let unsupported = "digest";
        let auth_scheme: Result<AuthScheme> = unsupported.try_into();
        assert!(auth_scheme.is_err());
    }

    #[test]
    fn test_auth_header() {
        // base64("username:password")
        let req = mock_http_request(Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap();

        let auth_scheme = auth_header(&req).unwrap();
        assert_matches!(auth_scheme, AuthScheme::Basic(username, pwd) if username == "username" && pwd.expose_secret() == "password");

        let wrong_req = mock_http_request(Some("Basic dXNlcm5hbWU6 cGFzc3dvcmQ="), None).unwrap();
        let res = auth_header(&wrong_req);
        assert_matches!(res.err(), Some(error::Error::InvalidAuthHeader { .. }));

        let wrong_req = mock_http_request(Some("Digest dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap();
        let res = auth_header(&wrong_req);
        assert_matches!(res.err(), Some(error::Error::UnsupportedAuthScheme { .. }));
    }

    #[test]
    fn test_auth_header_missing() {
        let req = mock_http_request(None, None).unwrap();
        assert_matches!(
            auth_header(&req).err(),
            Some(error::Error::NotFoundAuthHeader { .. })
        );
    }

    #[test]
    fn test_auth_header_precedence() {
        // `AUTHORIZATION_HEADER` is consulted before the standard header.
        let http_api_version = crate::http::HTTP_API_VERSION;
        let req = Request::builder()
            .uri(format!("http://localhost/{http_api_version}/sql").as_str())
            .header(AUTHORIZATION_HEADER, "Basic dXNlcm5hbWU6cGFzc3dvcmQ=")
            .header(http::header::AUTHORIZATION, "Basic dGVzdDp0ZXN0")
            .body(())
            .unwrap();

        let auth_scheme = auth_header(&req).unwrap();
        assert_matches!(auth_scheme, AuthScheme::Basic(username, pwd) if username == "username" && pwd.expose_secret() == "password");
    }

    /// Builds a request against `/{HTTP_API_VERSION}/sql`, or against `uri`
    /// when one is given. When `auth_header` is provided, it is set as the
    /// standard `Authorization` header.
    fn mock_http_request(auth_header: Option<&str>, uri: Option<&str>) -> Result<Request<()>> {
        let http_api_version = crate::http::HTTP_API_VERSION;
        let mut req = Request::builder()
            .uri(uri.unwrap_or(format!("http://localhost/{http_api_version}/sql").as_str()));
        if let Some(auth_header) = auth_header {
            req = req.header(http::header::AUTHORIZATION, auth_header);
        }

        Ok(req.body(()).unwrap())
    }

    #[test]
    fn test_db_name_header() {
        let http_api_version = crate::http::HTTP_API_VERSION;
        let req = Request::builder()
            .uri(format!("http://localhost/{http_api_version}/sql").as_str())
            .header(GreptimeDbName::name(), "greptime-tomcat")
            .body(())
            .unwrap();

        let db = extract_catalog_and_schema(&req);
        assert_eq!(db, ("greptime".to_string(), "tomcat".to_string()));
    }

    #[test]
    fn test_db_name_from_query() {
        let http_api_version = crate::http::HTTP_API_VERSION;
        let uri = format!("http://localhost/{http_api_version}/sql?db=greptime-tomcat");
        let req = mock_http_request(None, Some(uri.as_str())).unwrap();
        assert_eq!(
            extract_catalog_and_schema(&req),
            ("greptime".to_string(), "tomcat".to_string())
        );

        // Without a header or a `db` parameter, the default schema is used.
        let req = mock_http_request(None, None).unwrap();
        assert_eq!(
            extract_catalog_and_schema(&req),
            parse_catalog_and_schema_from_db_string(DEFAULT_SCHEMA_NAME)
        );
    }

    #[test]
    fn test_extract_db() {
        assert_matches!(extract_db_from_query(""), None);
        assert_matches!(extract_db_from_query("&"), None);
        assert_matches!(extract_db_from_query("db="), None);
        assert_matches!(extract_bucket_from_query("bucket="), None);
        assert_matches!(extract_bucket_from_query("db=foo"), None);
        assert_matches!(extract_db_from_query("db=foo"), Some("foo"));
        assert_matches!(extract_bucket_from_query("bucket=foo"), Some("foo"));
        assert_matches!(extract_db_from_query("name=bar"), None);
        assert_matches!(extract_db_from_query("db=&name=bar"), None);
        assert_matches!(extract_db_from_query("db=foo&name=bar"), Some("foo"));
        assert_matches!(extract_bucket_from_query("db=foo&bucket=bar"), Some("bar"));
        assert_matches!(extract_db_from_query("name=bar&db="), None);
        assert_matches!(extract_db_from_query("name=bar&db=foo"), Some("foo"));
        assert_matches!(extract_db_from_query("name=bar&db=&name=bar"), None);
        assert_matches!(
            extract_db_from_query("name=bar&db=foo&name=bar"),
            Some("foo")
        );
    }

    #[test]
    fn test_extract_user() {
        assert_matches!(extract_influxdb_user_from_query(""), (None, None));
        assert_matches!(extract_influxdb_user_from_query("u="), (None, None));
        assert_matches!(
            extract_influxdb_user_from_query("u=123"),
            (Some("123"), None)
        );
        assert_matches!(
            extract_influxdb_user_from_query("u=123&p="),
            (Some("123"), None)
        );
        assert_matches!(
            extract_influxdb_user_from_query("u=123&p=4"),
            (Some("123"), Some("4"))
        );
        assert_matches!(extract_influxdb_user_from_query("p="), (None, None));
        assert_matches!(extract_influxdb_user_from_query("p=4"), (None, Some("4")));
        assert_matches!(
            extract_influxdb_user_from_query("p=4&u="),
            (None, Some("4"))
        );
        assert_matches!(
            extract_influxdb_user_from_query("p=4&u=123"),
            (Some("123"), Some("4"))
        );
    }

    #[test]
    fn test_influxdb_credentials_from_query() {
        let req = mock_http_request(
            None,
            Some("http://localhost/v1/influxdb/write?db=public&u=greptime&p=secret"),
        )
        .unwrap();
        let (username, password) = extract_username_and_password(&req).unwrap();
        assert_eq!("greptime", username);
        assert_eq!("secret", password.expose_secret());

        // A username without a password is rejected.
        let req =
            mock_http_request(None, Some("http://localhost/v1/influxdb/write?u=greptime")).unwrap();
        assert_matches!(
            extract_username_and_password(&req).err(),
            Some(error::Error::InvalidParameter { .. })
        );

        // No credentials anywhere.
        let req =
            mock_http_request(None, Some("http://localhost/v1/influxdb/write?db=public")).unwrap();
        assert_matches!(
            extract_username_and_password(&req).err(),
            Some(error::Error::NotFoundInfluxAuth { .. })
        );
    }
}