servers/http/
authorize.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use ::auth::UserProviderRef;
16use api::v1::Basic;
17use axum::extract::{Request, State};
18use axum::http::{self, StatusCode};
19use axum::middleware::Next;
20use axum::response::{IntoResponse, Response};
21use base64::Engine;
22use base64::prelude::BASE64_STANDARD;
23use common_base::secrets::{ExposeSecret, SecretString};
24use common_catalog::consts::DEFAULT_SCHEMA_NAME;
25use common_catalog::parse_catalog_and_schema_from_db_string;
26use common_error::ext::ErrorExt;
27use common_telemetry::warn;
28use common_time::Timezone;
29use common_time::timezone::parse_timezone;
30use headers::Header;
31use session::context::QueryContextBuilder;
32use snafu::{OptionExt, ResultExt, ensure};
33
34use crate::error::{
35    self, InvalidAuthHeaderInvisibleASCIISnafu, InvalidAuthHeaderSnafu, InvalidParameterSnafu,
36    NotFoundInfluxAuthSnafu, Result, UnsupportedAuthSchemeSnafu, UrlDecodeSnafu,
37};
38use crate::http::header::{GREPTIME_TIMEZONE_HEADER_NAME, GreptimeDbName};
39use crate::http::result::error_result::ErrorResponse;
40use crate::http::{AUTHORIZATION_HEADER, HTTP_API_PREFIX, PUBLIC_APIS};
41use crate::influxdb::{is_influxdb_request, is_influxdb_v2_request};
42
43/// AuthState is a holder state for [`UserProviderRef`]
44/// during [`check_http_auth`] function in axum's middleware
45#[derive(Clone)]
46pub struct AuthState {
47    user_provider: Option<UserProviderRef>,
48}
49
50impl AuthState {
51    pub fn new(user_provider: Option<UserProviderRef>) -> Self {
52        Self { user_provider }
53    }
54}
55
56pub async fn inner_auth<B>(
57    user_provider: Option<UserProviderRef>,
58    mut req: Request<B>,
59) -> std::result::Result<Request<B>, Response> {
60    // 1. prepare
61    let (catalog, schema) = extract_catalog_and_schema(&req);
62    // TODO(ruihang): move this out of auth module
63    let timezone = extract_timezone(&req);
64    let query_ctx_builder = QueryContextBuilder::default()
65        .current_catalog(catalog.clone())
66        .current_schema(schema.clone())
67        .timezone(timezone);
68
69    let query_ctx = query_ctx_builder.build();
70    let need_auth = need_auth(&req);
71
72    // 2. check if auth is needed
73    let user_provider = if let Some(user_provider) = user_provider.filter(|_| need_auth) {
74        user_provider
75    } else {
76        query_ctx.set_current_user(auth::userinfo_by_name(None));
77        let _ = req.extensions_mut().insert(query_ctx);
78        return Ok(req);
79    };
80
81    // 3. get username and pwd
82    let (username, password) = match extract_username_and_password(&req) {
83        Ok((username, password)) => (username, password),
84        Err(e) => {
85            warn!(e; "extract username and password failed");
86            crate::metrics::METRIC_AUTH_FAILURE
87                .with_label_values(&[e.status_code().as_ref()])
88                .inc();
89            return Err(err_response(e));
90        }
91    };
92
93    // 4. auth
94    match user_provider
95        .auth(
96            auth::Identity::UserId(&username, None),
97            auth::Password::PlainText(password),
98            &catalog,
99            &schema,
100        )
101        .await
102    {
103        Ok(userinfo) => {
104            query_ctx.set_current_user(userinfo);
105            let _ = req.extensions_mut().insert(query_ctx);
106            Ok(req)
107        }
108        Err(e) => {
109            warn!(e; "authenticate failed");
110            crate::metrics::METRIC_AUTH_FAILURE
111                .with_label_values(&[e.status_code().as_ref()])
112                .inc();
113            Err(err_response(e))
114        }
115    }
116}
117
118pub async fn check_http_auth(
119    State(auth_state): State<AuthState>,
120    req: Request,
121    next: Next,
122) -> Response {
123    match inner_auth(auth_state.user_provider, req).await {
124        Ok(req) => next.run(req).await,
125        Err(resp) => resp,
126    }
127}
128
129fn err_response(err: impl ErrorExt) -> Response {
130    (StatusCode::UNAUTHORIZED, ErrorResponse::from_error(err)).into_response()
131}
132
133pub fn extract_catalog_and_schema<B>(request: &Request<B>) -> (String, String) {
134    // parse database from header
135    let dbname = request
136        .headers()
137        .get(GreptimeDbName::name())
138        // eat this invalid ascii error and give user the final IllegalParam error
139        .and_then(|header| header.to_str().ok())
140        .or_else(|| {
141            let query = request.uri().query().unwrap_or_default();
142            if is_influxdb_v2_request(request) {
143                extract_db_from_query(query).or_else(|| extract_bucket_from_query(query))
144            } else {
145                extract_db_from_query(query)
146            }
147        })
148        .unwrap_or(DEFAULT_SCHEMA_NAME);
149
150    parse_catalog_and_schema_from_db_string(dbname)
151}
152
153fn extract_timezone<B>(request: &Request<B>) -> Timezone {
154    // parse timezone from header
155    let timezone = request
156        .headers()
157        .get(&GREPTIME_TIMEZONE_HEADER_NAME)
158        // eat this invalid ascii error and give user the final IllegalParam error
159        .and_then(|header| header.to_str().ok())
160        .unwrap_or("");
161    parse_timezone(Some(timezone))
162}
163
164fn get_influxdb_credentials<B>(request: &Request<B>) -> Result<Option<(Username, Password)>> {
165    // compat with influxdb v2 and v1
166    if let Some(header) = request.headers().get(http::header::AUTHORIZATION) {
167        // try header
168        let (auth_scheme, credential) = header
169            .to_str()
170            .context(InvalidAuthHeaderInvisibleASCIISnafu)?
171            .split_once(' ')
172            .context(InvalidAuthHeaderSnafu)?;
173
174        let (username, password) = match auth_scheme.to_lowercase().as_str() {
175            "token" => {
176                let (u, p) = credential.split_once(':').context(InvalidAuthHeaderSnafu)?;
177                (u.to_string(), p.to_string().into())
178            }
179            "basic" => decode_basic(credential)?,
180            _ => UnsupportedAuthSchemeSnafu { name: auth_scheme }.fail()?,
181        };
182
183        Ok(Some((username, password)))
184    } else {
185        // try u and p in query
186        let Some(query_str) = request.uri().query() else {
187            return Ok(None);
188        };
189
190        let query_str = urlencoding::decode(query_str).context(UrlDecodeSnafu)?;
191
192        match extract_influxdb_user_from_query(&query_str) {
193            (None, None) => Ok(None),
194            (Some(username), Some(password)) => {
195                Ok(Some((username.to_string(), password.to_string().into())))
196            }
197            _ => InvalidParameterSnafu {
198                reason: "influxdb auth: username and password must be provided together"
199                    .to_string(),
200            }
201            .fail(),
202        }
203    }
204}
205
206pub fn extract_username_and_password<B>(request: &Request<B>) -> Result<(Username, Password)> {
207    Ok(if is_influxdb_request(request) {
208        // compatible with influxdb auth
209        get_influxdb_credentials(request)?.context(NotFoundInfluxAuthSnafu)?
210    } else {
211        // normal http auth
212        let scheme = auth_header(request)?;
213        match scheme {
214            AuthScheme::Basic(username, password) => (username, password),
215        }
216    })
217}
218
219#[derive(Debug)]
220pub enum AuthScheme {
221    Basic(Username, Password),
222}
223
224type Username = String;
225type Password = SecretString;
226
227impl TryFrom<&str> for AuthScheme {
228    type Error = error::Error;
229
230    fn try_from(value: &str) -> Result<Self> {
231        let (scheme, encoded_credentials) =
232            value.split_once(' ').context(InvalidAuthHeaderSnafu)?;
233
234        ensure!(!encoded_credentials.contains(' '), InvalidAuthHeaderSnafu);
235
236        match scheme.to_lowercase().as_str() {
237            "basic" => decode_basic(encoded_credentials)
238                .map(|(username, password)| AuthScheme::Basic(username, password)),
239            other => UnsupportedAuthSchemeSnafu { name: other }.fail(),
240        }
241    }
242}
243
244impl From<AuthScheme> for api::v1::auth_header::AuthScheme {
245    fn from(value: AuthScheme) -> Self {
246        match value {
247            AuthScheme::Basic(username, password) => {
248                api::v1::auth_header::AuthScheme::Basic(Basic {
249                    username,
250                    password: password.expose_secret().to_string(),
251                })
252            }
253        }
254    }
255}
256
257type Credential<'a> = &'a str;
258
259fn auth_header<B>(req: &Request<B>) -> Result<AuthScheme> {
260    let auth_header = req
261        .headers()
262        .get(AUTHORIZATION_HEADER)
263        .or_else(|| req.headers().get(http::header::AUTHORIZATION))
264        .context(error::NotFoundAuthHeaderSnafu)?
265        .to_str()
266        .context(InvalidAuthHeaderInvisibleASCIISnafu)?;
267
268    auth_header.try_into()
269}
270
271fn decode_basic(credential: Credential) -> Result<(Username, Password)> {
272    let decoded = BASE64_STANDARD
273        .decode(credential)
274        .context(error::InvalidBase64ValueSnafu)?;
275    let as_utf8 =
276        String::from_utf8(decoded).context(error::InvalidAuthHeaderInvalidUtf8ValueSnafu)?;
277
278    if let Some((user_id, password)) = as_utf8.split_once(':') {
279        return Ok((user_id.to_string(), password.to_string().into()));
280    }
281
282    InvalidAuthHeaderSnafu {}.fail()
283}
284
285fn need_auth<B>(req: &Request<B>) -> bool {
286    let path = req.uri().path();
287
288    for api in PUBLIC_APIS {
289        if path.starts_with(api) {
290            return false;
291        }
292    }
293
294    path.starts_with(HTTP_API_PREFIX)
295}
296
/// Looks up `param` in a query string without percent-decoding it.
///
/// Scans the `&`-separated pairs for the first one shaped like `param=value` and
/// returns `value`. If that first match has an empty value, the result is `None`
/// and later pairs are not consulted.
fn extract_param_from_query<'a>(query: &'a str, param: &'a str) -> Option<&'a str> {
    query
        .split('&')
        // `strip_prefix(param)` followed by `strip_prefix('=')` matches exactly the
        // pairs starting with `param=`, without allocating the combined prefix.
        .find_map(|pair| pair.strip_prefix(param)?.strip_prefix('='))
        .filter(|value| !value.is_empty())
}
306
307fn extract_db_from_query(query: &str) -> Option<&str> {
308    extract_param_from_query(query, "db")
309}
310
311/// InfluxDB v2 uses "bucket" instead of "db"
312/// https://docs.influxdata.com/influxdb/v1/tools/api/#apiv2write-http-endpoint
313fn extract_bucket_from_query(query: &str) -> Option<&str> {
314    extract_param_from_query(query, "bucket")
315}
316
/// Extracts the InfluxDB `u` (username) and `p` (password) query parameters.
///
/// Values are returned exactly as they appear in `query`, with no percent-decoding.
/// Pairs with an empty value are ignored. When a parameter repeats, the last
/// non-empty occurrence wins.
fn extract_influxdb_user_from_query(query: &str) -> (Option<&str>, Option<&str>) {
    query.split('&').fold((None, None), |(user, pass), pair| {
        if let Some(value) = pair.strip_prefix("u=").filter(|v| !v.is_empty()) {
            (Some(value), pass)
        } else if let Some(value) = pair.strip_prefix("p=").filter(|v| !v.is_empty()) {
            (user, Some(value))
        } else {
            (user, pass)
        }
    })
}
330
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use common_base::secrets::ExposeSecret;

    use super::*;

    /// Builds a bodiless request. `uri` defaults to the versioned `/sql`
    /// endpoint. An `Authorization` header is added when `auth_header` is `Some`.
    fn mock_http_request(auth_header: Option<&str>, uri: Option<&str>) -> Result<Request<()>> {
        let default_uri = format!("http://localhost/{}/sql", crate::http::HTTP_API_VERSION);
        let builder = Request::builder().uri(uri.unwrap_or(default_uri.as_str()));
        let builder = match auth_header {
            Some(value) => builder.header(http::header::AUTHORIZATION, value),
            None => builder,
        };
        Ok(builder.body(()).unwrap())
    }

    #[test]
    fn test_need_auth() {
        let cases = [
            ("http://127.0.0.1/v1/influxdb/ping", false),
            ("http://127.0.0.1/v1/influxdb/health", false),
            ("http://127.0.0.1/v1/influxdb/write", true),
        ];
        for (uri, expected) in cases {
            let req = Request::builder().uri(uri).body(()).unwrap();
            assert_eq!(need_auth(&req), expected, "uri: {uri}");
        }
    }

    #[test]
    fn test_decode_basic() {
        // "dXNlcm5hbWU6cGFzc3dvcmQ=" is base64("username:password").
        let (user, secret) = decode_basic("dXNlcm5hbWU6cGFzc3dvcmQ=").unwrap();
        assert_eq!("username", user);
        assert_eq!("password", secret.expose_secret());

        // An embedded space is not valid base64.
        assert_matches!(
            decode_basic("dXNlcm5hbWU6cG Fzc3dvcmQ=").err(),
            Some(error::Error::InvalidBase64Value { .. })
        );
    }

    #[test]
    fn test_try_into_auth_scheme() {
        // A scheme without any credentials is rejected.
        assert!(AuthScheme::try_from("basic").is_err());

        let scheme = AuthScheme::try_from("basic dGVzdDp0ZXN0").unwrap();
        assert_matches!(scheme, AuthScheme::Basic(username, pwd) if username == "test" && pwd.expose_secret() == "test");

        assert!(AuthScheme::try_from("digest").is_err());
    }

    #[test]
    fn test_auth_header() {
        // "dXNlcm5hbWU6cGFzc3dvcmQ=" is base64("username:password").
        let valid = mock_http_request(Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap();
        assert_matches!(auth_header(&valid).unwrap(), AuthScheme::Basic(username, pwd) if username == "username" && pwd.expose_secret() == "password");

        // Credentials containing a space are malformed.
        let spaced = mock_http_request(Some("Basic dXNlcm5hbWU6 cGFzc3dvcmQ="), None).unwrap();
        assert_matches!(
            auth_header(&spaced).err(),
            Some(error::Error::InvalidAuthHeader { .. })
        );

        // Only the Basic scheme is supported.
        let digest = mock_http_request(Some("Digest dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap();
        assert_matches!(
            auth_header(&digest).err(),
            Some(error::Error::UnsupportedAuthScheme { .. })
        );
    }

    #[test]
    fn test_db_name_header() {
        let uri = format!("http://localhost/{}/sql", crate::http::HTTP_API_VERSION);
        let req = Request::builder()
            .uri(uri.as_str())
            .header(GreptimeDbName::name(), "greptime-tomcat")
            .body(())
            .unwrap();

        assert_eq!(
            extract_catalog_and_schema(&req),
            ("greptime".to_string(), "tomcat".to_string())
        );
    }

    #[test]
    fn test_extract_db() {
        let db_cases: &[(&str, Option<&str>)] = &[
            ("", None),
            ("&", None),
            ("db=", None),
            ("db=foo", Some("foo")),
            ("name=bar", None),
            ("db=&name=bar", None),
            ("db=foo&name=bar", Some("foo")),
            ("name=bar&db=", None),
            ("name=bar&db=foo", Some("foo")),
            ("name=bar&db=&name=bar", None),
            ("name=bar&db=foo&name=bar", Some("foo")),
        ];
        for &(query, expected) in db_cases {
            assert_eq!(extract_db_from_query(query), expected, "query: {query:?}");
        }

        let bucket_cases: &[(&str, Option<&str>)] = &[
            ("bucket=", None),
            ("db=foo", None),
            ("bucket=foo", Some("foo")),
            ("db=foo&bucket=bar", Some("bar")),
        ];
        for &(query, expected) in bucket_cases {
            assert_eq!(extract_bucket_from_query(query), expected, "query: {query:?}");
        }
    }

    #[test]
    fn test_extract_user() {
        let cases: &[(&str, (Option<&str>, Option<&str>))] = &[
            ("", (None, None)),
            ("u=", (None, None)),
            ("u=123", (Some("123"), None)),
            ("u=123&p=", (Some("123"), None)),
            ("u=123&p=4", (Some("123"), Some("4"))),
            ("p=", (None, None)),
            ("p=4", (None, Some("4"))),
            ("p=4&u=", (None, Some("4"))),
            ("p=4&u=123", (Some("123"), Some("4"))),
        ];
        for &(query, expected) in cases {
            assert_eq!(
                extract_influxdb_user_from_query(query),
                expected,
                "query: {query:?}"
            );
        }
    }
}