Skip to main content

servers/http/
authorize.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use ::auth::UserProviderRef;
16use api::v1::Basic;
17use axum::extract::{Request, State};
18use axum::http::{self, StatusCode};
19use axum::middleware::Next;
20use axum::response::{IntoResponse, Response};
21use base64::Engine;
22use base64::prelude::BASE64_STANDARD;
23use common_base::secrets::{ExposeSecret, SecretString};
24use common_catalog::consts::DEFAULT_SCHEMA_NAME;
25use common_catalog::parse_catalog_and_schema_from_db_string;
26use common_error::ext::ErrorExt;
27use common_telemetry::warn;
28use common_time::Timezone;
29use common_time::timezone::parse_timezone;
30use headers::Header;
31use session::context::QueryContextBuilder;
32use snafu::{OptionExt, ResultExt, ensure};
33
34use crate::error::{
35    self, InvalidAuthHeaderInvisibleASCIISnafu, InvalidAuthHeaderSnafu, InvalidParameterSnafu,
36    NotFoundAuthHeaderSnafu, NotFoundInfluxAuthSnafu, Result, UnsupportedAuthSchemeSnafu,
37    UrlDecodeSnafu,
38};
39use crate::http::header::{GREPTIME_TIMEZONE_HEADER_NAME, GreptimeDbName};
40use crate::http::result::error_result::ErrorResponse;
41use crate::http::splunk::is_splunk_request;
42use crate::http::{AUTHORIZATION_HEADER, HTTP_API_PREFIX, PUBLIC_API_PREFIX};
43use crate::influxdb::{is_influxdb_request, is_influxdb_v2_request};
44
45/// AuthState is a holder state for [`UserProviderRef`]
46/// during [`check_http_auth`] function in axum's middleware
47#[derive(Clone)]
48pub struct AuthState {
49    user_provider: Option<UserProviderRef>,
50}
51
52impl AuthState {
53    pub fn new(user_provider: Option<UserProviderRef>) -> Self {
54        Self { user_provider }
55    }
56}
57
58pub async fn inner_auth<B>(
59    user_provider: Option<UserProviderRef>,
60    mut req: Request<B>,
61) -> std::result::Result<Request<B>, Response> {
62    // 1. prepare
63    let (catalog, schema) = extract_catalog_and_schema(&req);
64    // TODO(ruihang): move this out of auth module
65    let timezone = extract_timezone(&req);
66    let query_ctx_builder = QueryContextBuilder::default()
67        .current_catalog(catalog.clone())
68        .current_schema(schema.clone())
69        .timezone(timezone);
70
71    let query_ctx = query_ctx_builder.build();
72    let need_auth = need_auth(&req);
73
74    // 2. check if auth is needed
75    let user_provider = if let Some(user_provider) = user_provider.filter(|_| need_auth) {
76        user_provider
77    } else {
78        query_ctx.set_current_user(auth::userinfo_by_name(None));
79        let _ = req.extensions_mut().insert(query_ctx);
80        return Ok(req);
81    };
82
83    // 3. get username and pwd
84    let (username, password) = match extract_username_and_password(&req) {
85        Ok((username, password)) => (username, password),
86        Err(e) => {
87            warn!(e; "extract username and password failed");
88            crate::metrics::METRIC_AUTH_FAILURE
89                .with_label_values(&[e.status_code().as_ref()])
90                .inc();
91            if is_splunk_request(&req) {
92                // HEC: missing header -> 2 ("token is required"), else 4 ("invalid token").
93                let (status, code) = match &e {
94                    error::Error::NotFoundAuthHeader { .. } => (StatusCode::UNAUTHORIZED, 2),
95                    _ => (StatusCode::FORBIDDEN, 4),
96                };
97                return Err(splunk_hec_err(status, code));
98            }
99            return Err(err_response(e));
100        }
101    };
102
103    // 4. auth
104    match user_provider
105        .auth(
106            auth::Identity::UserId(&username, None),
107            auth::Password::PlainText(password),
108            &catalog,
109            &schema,
110        )
111        .await
112    {
113        Ok(userinfo) => {
114            query_ctx.set_current_user(userinfo);
115            let _ = req.extensions_mut().insert(query_ctx);
116            Ok(req)
117        }
118        Err(e) => {
119            warn!(e; "authenticate failed");
120            crate::metrics::METRIC_AUTH_FAILURE
121                .with_label_values(&[e.status_code().as_ref()])
122                .inc();
123            // HEC: bad credentials -> 4 ("invalid token", 403).
124            if is_splunk_request(&req) {
125                return Err(splunk_hec_err(StatusCode::FORBIDDEN, 4));
126            }
127            Err(err_response(e))
128        }
129    }
130}
131
132pub async fn check_http_auth(
133    State(auth_state): State<AuthState>,
134    req: Request,
135    next: Next,
136) -> Response {
137    match inner_auth(auth_state.user_provider, req).await {
138        Ok(req) => next.run(req).await,
139        Err(resp) => resp,
140    }
141}
142
143/// HEC-shaped auth error (`{"text","code"}`) so Splunk clients can branch on `code`.
144fn splunk_hec_err(status: StatusCode, code: u32) -> Response {
145    let text = match code {
146        2 => "Token is required",
147        4 => "Invalid token",
148        _ => "Unauthorized",
149    };
150    (
151        status,
152        axum::Json(serde_json::json!({ "text": text, "code": code })),
153    )
154        .into_response()
155}
156
157fn err_response(err: impl ErrorExt) -> Response {
158    (StatusCode::UNAUTHORIZED, ErrorResponse::from_error(err)).into_response()
159}
160
161pub fn extract_catalog_and_schema<B>(request: &Request<B>) -> (String, String) {
162    // parse database from header
163    let dbname = request
164        .headers()
165        .get(GreptimeDbName::name())
166        // eat this invalid ascii error and give user the final IllegalParam error
167        .and_then(|header| header.to_str().ok())
168        .or_else(|| {
169            let query = request.uri().query().unwrap_or_default();
170            if is_influxdb_v2_request(request) {
171                extract_db_from_query(query).or_else(|| extract_bucket_from_query(query))
172            } else {
173                extract_db_from_query(query)
174            }
175        })
176        .unwrap_or(DEFAULT_SCHEMA_NAME);
177
178    parse_catalog_and_schema_from_db_string(dbname)
179}
180
181fn extract_timezone<B>(request: &Request<B>) -> Timezone {
182    // parse timezone from header
183    let timezone = request
184        .headers()
185        .get(&GREPTIME_TIMEZONE_HEADER_NAME)
186        // eat this invalid ascii error and give user the final IllegalParam error
187        .and_then(|header| header.to_str().ok())
188        .unwrap_or("");
189    parse_timezone(Some(timezone))
190}
191
192fn get_influxdb_credentials<B>(request: &Request<B>) -> Result<Option<(Username, Password)>> {
193    // compat with influxdb v2 and v1
194    if let Some(header) = request.headers().get(http::header::AUTHORIZATION) {
195        // try header
196        let (auth_scheme, credential) = header
197            .to_str()
198            .context(InvalidAuthHeaderInvisibleASCIISnafu)?
199            .split_once(' ')
200            .context(InvalidAuthHeaderSnafu)?;
201
202        let (username, password) = match auth_scheme.to_lowercase().as_str() {
203            "token" => {
204                let (u, p) = credential.split_once(':').context(InvalidAuthHeaderSnafu)?;
205                (u.to_string(), p.to_string().into())
206            }
207            "basic" => decode_basic(credential)?,
208            _ => UnsupportedAuthSchemeSnafu { name: auth_scheme }.fail()?,
209        };
210
211        Ok(Some((username, password)))
212    } else {
213        // try u and p in query
214        let Some(query_str) = request.uri().query() else {
215            return Ok(None);
216        };
217
218        let query_str = urlencoding::decode(query_str).context(UrlDecodeSnafu)?;
219
220        match extract_influxdb_user_from_query(&query_str) {
221            (None, None) => Ok(None),
222            (Some(username), Some(password)) => {
223                Ok(Some((username.to_string(), password.to_string().into())))
224            }
225            _ => InvalidParameterSnafu {
226                reason: "influxdb auth: username and password must be provided together"
227                    .to_string(),
228            }
229            .fail(),
230        }
231    }
232}
233
234fn get_splunk_credentials<B>(request: &Request<B>) -> Result<Option<(Username, Password)>> {
235    let Some(header) = request.headers().get(http::header::AUTHORIZATION) else {
236        return Ok(None);
237    };
238    let (auth_scheme, credential) = header
239        .to_str()
240        .context(InvalidAuthHeaderInvisibleASCIISnafu)?
241        .split_once(' ')
242        .context(InvalidAuthHeaderSnafu)?;
243
244    let (username, password) = match auth_scheme.to_lowercase().as_str() {
245        "splunk" => {
246            let (u, p) = credential.split_once(':').context(InvalidAuthHeaderSnafu)?;
247            (u.to_string(), p.to_string().into())
248        }
249        "basic" => decode_basic(credential)?,
250        _ => UnsupportedAuthSchemeSnafu { name: auth_scheme }.fail()?,
251    };
252    Ok(Some((username, password)))
253}
254
255pub fn extract_username_and_password<B>(request: &Request<B>) -> Result<(Username, Password)> {
256    Ok(if is_influxdb_request(request) {
257        // compatible with influxdb auth
258        get_influxdb_credentials(request)?.context(NotFoundInfluxAuthSnafu)?
259    } else if is_splunk_request(request) {
260        get_splunk_credentials(request)?.context(NotFoundAuthHeaderSnafu)?
261    } else {
262        // normal http auth
263        let scheme = auth_header(request)?;
264        match scheme {
265            AuthScheme::Basic(username, password) => (username, password),
266        }
267    })
268}
269
270#[derive(Debug)]
271pub enum AuthScheme {
272    Basic(Username, Password),
273}
274
275type Username = String;
276type Password = SecretString;
277
278impl TryFrom<&str> for AuthScheme {
279    type Error = error::Error;
280
281    fn try_from(value: &str) -> Result<Self> {
282        let (scheme, encoded_credentials) =
283            value.split_once(' ').context(InvalidAuthHeaderSnafu)?;
284
285        ensure!(!encoded_credentials.contains(' '), InvalidAuthHeaderSnafu);
286
287        match scheme.to_lowercase().as_str() {
288            "basic" => decode_basic(encoded_credentials)
289                .map(|(username, password)| AuthScheme::Basic(username, password)),
290            other => UnsupportedAuthSchemeSnafu { name: other }.fail(),
291        }
292    }
293}
294
295impl From<AuthScheme> for api::v1::auth_header::AuthScheme {
296    fn from(value: AuthScheme) -> Self {
297        match value {
298            AuthScheme::Basic(username, password) => {
299                api::v1::auth_header::AuthScheme::Basic(Basic {
300                    username,
301                    password: password.expose_secret().clone(),
302                })
303            }
304        }
305    }
306}
307
308type Credential<'a> = &'a str;
309
310fn auth_header<B>(req: &Request<B>) -> Result<AuthScheme> {
311    let auth_header = req
312        .headers()
313        .get(AUTHORIZATION_HEADER)
314        .or_else(|| req.headers().get(http::header::AUTHORIZATION))
315        .context(error::NotFoundAuthHeaderSnafu)?
316        .to_str()
317        .context(InvalidAuthHeaderInvisibleASCIISnafu)?;
318
319    auth_header.try_into()
320}
321
322fn decode_basic(credential: Credential) -> Result<(Username, Password)> {
323    let decoded = BASE64_STANDARD
324        .decode(credential)
325        .context(error::InvalidBase64ValueSnafu)?;
326    let as_utf8 =
327        String::from_utf8(decoded).context(error::InvalidAuthHeaderInvalidUtf8ValueSnafu)?;
328
329    if let Some((user_id, password)) = as_utf8.split_once(':') {
330        return Ok((user_id.to_string(), password.to_string().into()));
331    }
332
333    InvalidAuthHeaderSnafu {}.fail()
334}
335
336fn need_auth<B>(req: &Request<B>) -> bool {
337    let path = req.uri().path();
338
339    for api in PUBLIC_API_PREFIX {
340        if path.starts_with(api) {
341            return false;
342        }
343    }
344
345    path.starts_with(HTTP_API_PREFIX)
346}
347
/// Looks up `param` in a raw (not URL-decoded) query string.
///
/// Only the first `param=...` pair is considered. Returns `None` when the
/// parameter is absent or when that first occurrence has an empty value.
fn extract_param_from_query<'a>(query: &'a str, param: &'a str) -> Option<&'a str> {
    query
        .split('&')
        .find_map(|pair| pair.strip_prefix(param)?.strip_prefix('='))
        .filter(|value| !value.is_empty())
}

/// Extracts the `db` query parameter (same rules as `extract_param_from_query`).
fn extract_db_from_query(query: &str) -> Option<&str> {
    extract_param_from_query(query, "db")
}

/// InfluxDB v2 names the target database `bucket` instead of `db`; see
/// <https://docs.influxdata.com/influxdb/v1/tools/api/#apiv2write-http-endpoint>.
fn extract_bucket_from_query(query: &str) -> Option<&str> {
    extract_param_from_query(query, "bucket")
}
367
/// Scans a query string for InfluxDB v1 `u=<user>` / `p=<password>` parameters.
///
/// Pairs with an empty value are ignored. If a key repeats, the last non-empty
/// value wins. Values are returned exactly as they appear, without URL decoding.
fn extract_influxdb_user_from_query(query: &str) -> (Option<&str>, Option<&str>) {
    query
        .split('&')
        .fold((None, None), |(username, password), pair| {
            if let Some(user) = pair.strip_prefix("u=").filter(|v| !v.is_empty()) {
                (Some(user), password)
            } else if let Some(pass) = pair.strip_prefix("p=").filter(|v| !v.is_empty()) {
                (username, Some(pass))
            } else {
                (username, password)
            }
        })
}
381
// Unit tests for the auth gate and the header/query parsing helpers in this file.
#[cfg(test)]
mod tests {
    use std::assert_matches;

    use common_base::secrets::ExposeSecret;

    use super::*;

    // The InfluxDB `ping` and `health` endpoints are asserted to be public;
    // `write` is asserted to require authentication.
    #[test]
    fn test_need_auth() {
        let req = Request::builder()
            .uri("http://127.0.0.1/v1/influxdb/ping")
            .body(())
            .unwrap();

        assert!(!need_auth(&req));

        let req = Request::builder()
            .uri("http://127.0.0.1/v1/influxdb/health")
            .body(())
            .unwrap();

        assert!(!need_auth(&req));

        let req = Request::builder()
            .uri("http://127.0.0.1/v1/influxdb/write")
            .body(())
            .unwrap();

        assert!(need_auth(&req));
    }

    // Covers:
    // - Splunk HEC request detection;
    // - credential parsing for the `Splunk` and `Basic` schemes;
    // - dispatch through `extract_username_and_password`.
    #[test]
    fn test_splunk_auth() {
        let splunk_uri = "http://127.0.0.1/v1/splunk/services/collector/event";
        // Builds a request to the Splunk collector endpoint, optionally carrying
        // an `Authorization` header.
        let splunk_req = |auth: Option<&str>| {
            let mut req = Request::builder().uri(splunk_uri);
            if let Some(auth) = auth {
                req = req.header(http::header::AUTHORIZATION, auth);
            }
            req.body(()).unwrap()
        };

        // is_splunk_request matches our mount, not other endpoints.
        assert!(is_splunk_request(&splunk_req(None)));
        assert!(!is_splunk_request(
            &Request::builder()
                .uri("http://127.0.0.1/v1/influxdb/write")
                .body(())
                .unwrap()
        ));
        assert!(!is_splunk_request(
            &Request::builder()
                .uri("http://127.0.0.1/v1/sql")
                .body(())
                .unwrap()
        ));

        // `Splunk <user:pass>` -> (user, pass).
        let (username, password) =
            get_splunk_credentials(&splunk_req(Some("Splunk teamA:secretA")))
                .unwrap()
                .unwrap();
        assert_eq!(username, "teamA");
        assert_eq!(password.expose_secret(), "secretA");

        // standard Basic is also accepted (parity with influxdb).
        let basic = basic_auth("u", "p");
        let (username, password) = get_splunk_credentials(&splunk_req(Some(&basic)))
            .unwrap()
            .unwrap();
        assert_eq!(username, "u");
        assert_eq!(password.expose_secret(), "p");

        // missing header -> None; token without ':' -> error.
        assert!(get_splunk_credentials(&splunk_req(None)).unwrap().is_none());
        assert!(get_splunk_credentials(&splunk_req(Some("Splunk no_colon_token"))).is_err());

        // full dispatch routes a splunk request through the splunk scheme.
        let (username, password) =
            extract_username_and_password(&splunk_req(Some("Splunk teamA:secretA"))).unwrap();
        assert_eq!(username, "teamA");
        assert_eq!(password.expose_secret(), "secretA");
    }

    // Round-trips a Basic credential. It then checks that a space injected into
    // the base64 payload surfaces as `InvalidBase64Value`.
    #[test]
    fn test_decode_basic() {
        let credential = basic_auth_credentials("username", "password");
        let (username, pwd) = decode_basic(&credential).unwrap();
        assert_eq!("username", username);
        assert_eq!("password", pwd.expose_secret());

        // Corrupt the payload by inserting a space after the first 'c'.
        let wrong_credential = credential.replacen('c', "c ", 1);
        let result = decode_basic(&wrong_credential);
        assert_matches!(result.err(), Some(error::Error::InvalidBase64Value { .. }));
    }

    // A scheme without credentials and an unsupported scheme are rejected; a
    // well-formed Basic header parses into `AuthScheme::Basic`.
    #[test]
    fn test_try_into_auth_scheme() {
        let auth_scheme_str = "basic";
        let re: Result<AuthScheme> = auth_scheme_str.try_into();
        assert!(re.is_err());

        let auth_scheme_str = basic_auth("test", "test");
        let scheme: AuthScheme = auth_scheme_str.as_str().try_into().unwrap();
        assert_matches!(scheme, AuthScheme::Basic(username, pwd) if username == "test" && pwd.expose_secret() == "test");

        let unsupported = "digest";
        let auth_scheme: Result<AuthScheme> = unsupported.try_into();
        assert!(auth_scheme.is_err());
    }

    // With no user provider, `inner_auth` still attaches a QueryContext.
    // NOTE(review): `inner_auth` never sets a remote query id explicitly. This
    // assertion presumably relies on `QueryContextBuilder::default()`/`build()`
    // generating one; confirm if the builder changes.
    #[test]
    fn test_inner_auth_assigns_remote_query_id() {
        let req =
            mock_http_request(None, Some("http://127.0.0.1/v1/sql?db=greptime-public")).unwrap();
        let req = futures::executor::block_on(inner_auth::<()>(None, req)).unwrap();
        let query_ctx = req
            .extensions()
            .get::<session::context::QueryContext>()
            .unwrap();

        assert!(query_ctx.remote_query_id().is_some());
    }

    // Cases: a valid Basic header; an extra space in the credentials
    // (`InvalidAuthHeader`); the unsupported `Digest` scheme
    // (`UnsupportedAuthScheme`).
    #[test]
    fn test_auth_header() {
        let header_value = basic_auth("username", "password");
        let req = mock_http_request(Some(&header_value), None).unwrap();

        let auth_scheme = auth_header(&req).unwrap();
        assert_matches!(auth_scheme, AuthScheme::Basic(username, pwd) if username == "username" && pwd.expose_secret() == "password");

        let wrong_auth_header = header_value.replacen('c', "c ", 1);
        let wrong_req = mock_http_request(Some(&wrong_auth_header), None).unwrap();
        let res = auth_header(&wrong_req);
        assert_matches!(res.err(), Some(error::Error::InvalidAuthHeader { .. }));

        let wrong_req = mock_http_request(
            Some(&format!(
                "Digest {}",
                basic_auth_credentials("username", "password")
            )),
            None,
        )
        .unwrap();
        let res = auth_header(&wrong_req);
        assert_matches!(res.err(), Some(error::Error::UnsupportedAuthScheme { .. }));
    }

    // Full header value: `Basic <base64(username:password)>`.
    fn basic_auth(username: &str, password: &str) -> String {
        format!("Basic {}", basic_auth_credentials(username, password))
    }

    // Base64 (standard alphabet) of `username:password`.
    fn basic_auth_credentials(username: &str, password: &str) -> String {
        BASE64_STANDARD.encode(format!("{username}:{password}"))
    }

    // Builds a body-less request to `uri` (default:
    // `http://localhost/{HTTP_API_VERSION}/sql`). When `auth_header` is given,
    // it is set as the standard `Authorization` header.
    fn mock_http_request(auth_header: Option<&str>, uri: Option<&str>) -> Result<Request<()>> {
        let http_api_version = crate::http::HTTP_API_VERSION;
        let mut req = Request::builder()
            .uri(uri.unwrap_or(format!("http://localhost/{http_api_version}/sql").as_str()));
        if let Some(auth_header) = auth_header {
            req = req.header(http::header::AUTHORIZATION, auth_header);
        }

        Ok(req.body(()).unwrap())
    }

    // The db-name header `greptime-tomcat` resolves to catalog `greptime`, schema `tomcat`.
    #[test]
    fn test_db_name_header() {
        let http_api_version = crate::http::HTTP_API_VERSION;
        let req = Request::builder()
            .uri(format!("http://localhost/{http_api_version}/sql").as_str())
            .header(GreptimeDbName::name(), "greptime-tomcat")
            .body(())
            .unwrap();

        let db = extract_catalog_and_schema(&req);
        assert_eq!(db, ("greptime".to_string(), "tomcat".to_string()));
    }

    // Empty values count as absent, and `db`/`bucket` are looked up independently.
    #[test]
    fn test_extract_db() {
        assert_matches!(extract_db_from_query(""), None);
        assert_matches!(extract_db_from_query("&"), None);
        assert_matches!(extract_db_from_query("db="), None);
        assert_matches!(extract_bucket_from_query("bucket="), None);
        assert_matches!(extract_bucket_from_query("db=foo"), None);
        assert_matches!(extract_db_from_query("db=foo"), Some("foo"));
        assert_matches!(extract_bucket_from_query("bucket=foo"), Some("foo"));
        assert_matches!(extract_db_from_query("name=bar"), None);
        assert_matches!(extract_db_from_query("db=&name=bar"), None);
        assert_matches!(extract_db_from_query("db=foo&name=bar"), Some("foo"));
        assert_matches!(extract_bucket_from_query("db=foo&bucket=bar"), Some("bar"));
        assert_matches!(extract_db_from_query("name=bar&db="), None);
        assert_matches!(extract_db_from_query("name=bar&db=foo"), Some("foo"));
        assert_matches!(extract_db_from_query("name=bar&db=&name=bar"), None);
        assert_matches!(
            extract_db_from_query("name=bar&db=foo&name=bar"),
            Some("foo")
        );
    }

    // Empty `u=`/`p=` values are ignored; parameter order does not matter.
    #[test]
    fn test_extract_user() {
        assert_matches!(extract_influxdb_user_from_query(""), (None, None));
        assert_matches!(extract_influxdb_user_from_query("u="), (None, None));
        assert_matches!(
            extract_influxdb_user_from_query("u=123"),
            (Some("123"), None)
        );
        assert_matches!(
            extract_influxdb_user_from_query("u=123&p="),
            (Some("123"), None)
        );
        assert_matches!(
            extract_influxdb_user_from_query("u=123&p=4"),
            (Some("123"), Some("4"))
        );
        assert_matches!(extract_influxdb_user_from_query("p="), (None, None));
        assert_matches!(extract_influxdb_user_from_query("p=4"), (None, Some("4")));
        assert_matches!(
            extract_influxdb_user_from_query("p=4&u="),
            (None, Some("4"))
        );
        assert_matches!(
            extract_influxdb_user_from_query("p=4&u=123"),
            (Some("123"), Some("4"))
        );
    }
}