servers/http/
authorize.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use ::auth::UserProviderRef;
16use axum::extract::{Request, State};
17use axum::http::{self, StatusCode};
18use axum::middleware::Next;
19use axum::response::{IntoResponse, Response};
20use base64::prelude::BASE64_STANDARD;
21use base64::Engine;
22use common_base::secrets::SecretString;
23use common_catalog::consts::DEFAULT_SCHEMA_NAME;
24use common_catalog::parse_catalog_and_schema_from_db_string;
25use common_error::ext::ErrorExt;
26use common_telemetry::warn;
27use common_time::timezone::parse_timezone;
28use common_time::Timezone;
29use headers::Header;
30use session::context::QueryContextBuilder;
31use snafu::{ensure, OptionExt, ResultExt};
32
33use crate::error::{
34    self, InvalidAuthHeaderInvisibleASCIISnafu, InvalidAuthHeaderSnafu, InvalidParameterSnafu,
35    NotFoundInfluxAuthSnafu, Result, UnsupportedAuthSchemeSnafu, UrlDecodeSnafu,
36};
37use crate::http::header::{GreptimeDbName, GREPTIME_TIMEZONE_HEADER_NAME};
38use crate::http::result::error_result::ErrorResponse;
39use crate::http::{AUTHORIZATION_HEADER, HTTP_API_PREFIX, PUBLIC_APIS};
40use crate::influxdb::{is_influxdb_request, is_influxdb_v2_request};
41
42/// AuthState is a holder state for [`UserProviderRef`]
43/// during [`check_http_auth`] function in axum's middleware
44#[derive(Clone)]
45pub struct AuthState {
46    user_provider: Option<UserProviderRef>,
47}
48
49impl AuthState {
50    pub fn new(user_provider: Option<UserProviderRef>) -> Self {
51        Self { user_provider }
52    }
53}
54
55pub async fn inner_auth<B>(
56    user_provider: Option<UserProviderRef>,
57    mut req: Request<B>,
58) -> std::result::Result<Request<B>, Response> {
59    // 1. prepare
60    let (catalog, schema) = extract_catalog_and_schema(&req);
61    // TODO(ruihang): move this out of auth module
62    let timezone = extract_timezone(&req);
63    let query_ctx_builder = QueryContextBuilder::default()
64        .current_catalog(catalog.clone())
65        .current_schema(schema.clone())
66        .timezone(timezone);
67
68    let query_ctx = query_ctx_builder.build();
69    let need_auth = need_auth(&req);
70
71    // 2. check if auth is needed
72    let user_provider = if let Some(user_provider) = user_provider.filter(|_| need_auth) {
73        user_provider
74    } else {
75        query_ctx.set_current_user(auth::userinfo_by_name(None));
76        let _ = req.extensions_mut().insert(query_ctx);
77        return Ok(req);
78    };
79
80    // 3. get username and pwd
81    let (username, password) = match extract_username_and_password(&req) {
82        Ok((username, password)) => (username, password),
83        Err(e) => {
84            warn!(e; "extract username and password failed");
85            crate::metrics::METRIC_AUTH_FAILURE
86                .with_label_values(&[e.status_code().as_ref()])
87                .inc();
88            return Err(err_response(e));
89        }
90    };
91
92    // 4. auth
93    match user_provider
94        .auth(
95            auth::Identity::UserId(&username, None),
96            auth::Password::PlainText(password),
97            &catalog,
98            &schema,
99        )
100        .await
101    {
102        Ok(userinfo) => {
103            query_ctx.set_current_user(userinfo);
104            let _ = req.extensions_mut().insert(query_ctx);
105            Ok(req)
106        }
107        Err(e) => {
108            warn!(e; "authenticate failed");
109            crate::metrics::METRIC_AUTH_FAILURE
110                .with_label_values(&[e.status_code().as_ref()])
111                .inc();
112            Err(err_response(e))
113        }
114    }
115}
116
117pub async fn check_http_auth(
118    State(auth_state): State<AuthState>,
119    req: Request,
120    next: Next,
121) -> Response {
122    match inner_auth(auth_state.user_provider, req).await {
123        Ok(req) => next.run(req).await,
124        Err(resp) => resp,
125    }
126}
127
128fn err_response(err: impl ErrorExt) -> Response {
129    (StatusCode::UNAUTHORIZED, ErrorResponse::from_error(err)).into_response()
130}
131
132pub fn extract_catalog_and_schema<B>(request: &Request<B>) -> (String, String) {
133    // parse database from header
134    let dbname = request
135        .headers()
136        .get(GreptimeDbName::name())
137        // eat this invalid ascii error and give user the final IllegalParam error
138        .and_then(|header| header.to_str().ok())
139        .or_else(|| {
140            let query = request.uri().query().unwrap_or_default();
141            if is_influxdb_v2_request(request) {
142                extract_db_from_query(query).or_else(|| extract_bucket_from_query(query))
143            } else {
144                extract_db_from_query(query)
145            }
146        })
147        .unwrap_or(DEFAULT_SCHEMA_NAME);
148
149    parse_catalog_and_schema_from_db_string(dbname)
150}
151
152fn extract_timezone<B>(request: &Request<B>) -> Timezone {
153    // parse timezone from header
154    let timezone = request
155        .headers()
156        .get(&GREPTIME_TIMEZONE_HEADER_NAME)
157        // eat this invalid ascii error and give user the final IllegalParam error
158        .and_then(|header| header.to_str().ok())
159        .unwrap_or("");
160    parse_timezone(Some(timezone))
161}
162
163fn get_influxdb_credentials<B>(request: &Request<B>) -> Result<Option<(Username, Password)>> {
164    // compat with influxdb v2 and v1
165    if let Some(header) = request.headers().get(http::header::AUTHORIZATION) {
166        // try header
167        let (auth_scheme, credential) = header
168            .to_str()
169            .context(InvalidAuthHeaderInvisibleASCIISnafu)?
170            .split_once(' ')
171            .context(InvalidAuthHeaderSnafu)?;
172
173        let (username, password) = match auth_scheme.to_lowercase().as_str() {
174            "token" => {
175                let (u, p) = credential.split_once(':').context(InvalidAuthHeaderSnafu)?;
176                (u.to_string(), p.to_string().into())
177            }
178            "basic" => decode_basic(credential)?,
179            _ => UnsupportedAuthSchemeSnafu { name: auth_scheme }.fail()?,
180        };
181
182        Ok(Some((username, password)))
183    } else {
184        // try u and p in query
185        let Some(query_str) = request.uri().query() else {
186            return Ok(None);
187        };
188
189        let query_str = urlencoding::decode(query_str).context(UrlDecodeSnafu)?;
190
191        match extract_influxdb_user_from_query(&query_str) {
192            (None, None) => Ok(None),
193            (Some(username), Some(password)) => {
194                Ok(Some((username.to_string(), password.to_string().into())))
195            }
196            _ => InvalidParameterSnafu {
197                reason: "influxdb auth: username and password must be provided together"
198                    .to_string(),
199            }
200            .fail(),
201        }
202    }
203}
204
205pub fn extract_username_and_password<B>(request: &Request<B>) -> Result<(Username, Password)> {
206    Ok(if is_influxdb_request(request) {
207        // compatible with influxdb auth
208        get_influxdb_credentials(request)?.context(NotFoundInfluxAuthSnafu)?
209    } else {
210        // normal http auth
211        let scheme = auth_header(request)?;
212        match scheme {
213            AuthScheme::Basic(username, password) => (username, password),
214        }
215    })
216}
217
218#[derive(Debug)]
219pub enum AuthScheme {
220    Basic(Username, Password),
221}
222
223type Username = String;
224type Password = SecretString;
225
226impl TryFrom<&str> for AuthScheme {
227    type Error = error::Error;
228
229    fn try_from(value: &str) -> Result<Self> {
230        let (scheme, encoded_credentials) =
231            value.split_once(' ').context(InvalidAuthHeaderSnafu)?;
232
233        ensure!(!encoded_credentials.contains(' '), InvalidAuthHeaderSnafu);
234
235        match scheme.to_lowercase().as_str() {
236            "basic" => decode_basic(encoded_credentials)
237                .map(|(username, password)| AuthScheme::Basic(username, password)),
238            other => UnsupportedAuthSchemeSnafu { name: other }.fail(),
239        }
240    }
241}
242
243type Credential<'a> = &'a str;
244
245fn auth_header<B>(req: &Request<B>) -> Result<AuthScheme> {
246    let auth_header = req
247        .headers()
248        .get(AUTHORIZATION_HEADER)
249        .or_else(|| req.headers().get(http::header::AUTHORIZATION))
250        .context(error::NotFoundAuthHeaderSnafu)?
251        .to_str()
252        .context(InvalidAuthHeaderInvisibleASCIISnafu)?;
253
254    auth_header.try_into()
255}
256
257fn decode_basic(credential: Credential) -> Result<(Username, Password)> {
258    let decoded = BASE64_STANDARD
259        .decode(credential)
260        .context(error::InvalidBase64ValueSnafu)?;
261    let as_utf8 =
262        String::from_utf8(decoded).context(error::InvalidAuthHeaderInvalidUtf8ValueSnafu)?;
263
264    if let Some((user_id, password)) = as_utf8.split_once(':') {
265        return Ok((user_id.to_string(), password.to_string().into()));
266    }
267
268    InvalidAuthHeaderSnafu {}.fail()
269}
270
271fn need_auth<B>(req: &Request<B>) -> bool {
272    let path = req.uri().path();
273
274    for api in PUBLIC_APIS {
275        if path.starts_with(api) {
276            return false;
277        }
278    }
279
280    path.starts_with(HTTP_API_PREFIX)
281}
282
/// Returns the raw (not percent-decoded) value of the first `param=` pair in `query`.
///
/// Only the first matching pair is considered. If its value is empty the result is
/// `None`, even when a later pair repeats `param` with a value.
fn extract_param_from_query<'a>(query: &'a str, param: &'a str) -> Option<&'a str> {
    query
        .split('&')
        .find_map(|pair| pair.strip_prefix(param)?.strip_prefix('='))
        .filter(|value| !value.is_empty())
}
292
293fn extract_db_from_query(query: &str) -> Option<&str> {
294    extract_param_from_query(query, "db")
295}
296
297/// InfluxDB v2 uses "bucket" instead of "db"
298/// https://docs.influxdata.com/influxdb/v1/tools/api/#apiv2write-http-endpoint
299fn extract_bucket_from_query(query: &str) -> Option<&str> {
300    extract_param_from_query(query, "bucket")
301}
302
/// Scans `query` for the InfluxDB `u` (username) and `p` (password) parameters.
///
/// Values are returned raw, without percent-decoding. Empty values are ignored. When a
/// parameter appears more than once, the last non-empty occurrence wins.
fn extract_influxdb_user_from_query(query: &str) -> (Option<&str>, Option<&str>) {
    query.split('&').fold((None, None), |(user, pass), pair| {
        if let Some(name) = pair.strip_prefix("u=").filter(|v| !v.is_empty()) {
            (Some(name), pass)
        } else if let Some(secret) = pair.strip_prefix("p=").filter(|v| !v.is_empty()) {
            (user, Some(secret))
        } else {
            (user, pass)
        }
    })
}
316
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use common_base::secrets::ExposeSecret;

    use super::*;

    /// Builds a body-less request for `uri`.
    fn request_for(uri: &str) -> Request<()> {
        Request::builder().uri(uri).body(()).unwrap()
    }

    #[test]
    fn test_need_auth() {
        // InfluxDB ping/health are public; write requires credentials.
        let cases = [
            ("http://127.0.0.1/v1/influxdb/ping", false),
            ("http://127.0.0.1/v1/influxdb/health", false),
            ("http://127.0.0.1/v1/influxdb/write", true),
        ];
        for (uri, expected) in cases {
            assert_eq!(need_auth(&request_for(uri)), expected);
        }
    }

    #[test]
    fn test_decode_basic() {
        // base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ="
        let (username, pwd) = decode_basic("dXNlcm5hbWU6cGFzc3dvcmQ=").unwrap();
        assert_eq!("username", username);
        assert_eq!("password", pwd.expose_secret());

        // An embedded space is not valid base64.
        assert_matches!(
            decode_basic("dXNlcm5hbWU6cG Fzc3dvcmQ="),
            Err(error::Error::InvalidBase64Value { .. })
        );
    }

    #[test]
    fn test_try_into_auth_scheme() {
        // A scheme without credentials is malformed.
        for malformed in ["basic", "digest"] {
            assert!(AuthScheme::try_from(malformed).is_err());
        }

        let scheme = AuthScheme::try_from("basic dGVzdDp0ZXN0").unwrap();
        assert_matches!(scheme, AuthScheme::Basic(username, pwd) if username == "test" && pwd.expose_secret() == "test");
    }

    #[test]
    fn test_auth_header() {
        // base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ="
        let req = mock_http_request(Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap();
        assert_matches!(
            auth_header(&req),
            Ok(AuthScheme::Basic(username, pwd)) if username == "username" && pwd.expose_secret() == "password"
        );

        // A second space inside the credentials is rejected.
        let req = mock_http_request(Some("Basic dXNlcm5hbWU6 cGFzc3dvcmQ="), None).unwrap();
        assert_matches!(auth_header(&req), Err(error::Error::InvalidAuthHeader { .. }));

        let req = mock_http_request(Some("Digest dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap();
        assert_matches!(
            auth_header(&req),
            Err(error::Error::UnsupportedAuthScheme { .. })
        );
    }

    /// Builds a body-less request (defaulting to the SQL endpoint) with an optional
    /// `Authorization` header.
    fn mock_http_request(auth_header: Option<&str>, uri: Option<&str>) -> Result<Request<()>> {
        let default_uri = format!("http://localhost/{}/sql", crate::http::HTTP_API_VERSION);
        let builder = Request::builder().uri(uri.unwrap_or(default_uri.as_str()));
        let builder = match auth_header {
            Some(value) => builder.header(http::header::AUTHORIZATION, value),
            None => builder,
        };
        Ok(builder.body(()).unwrap())
    }

    #[test]
    fn test_db_name_header() {
        let uri = format!("http://localhost/{}/sql", crate::http::HTTP_API_VERSION);
        let req = Request::builder()
            .uri(uri.as_str())
            .header(GreptimeDbName::name(), "greptime-tomcat")
            .body(())
            .unwrap();

        let (catalog, schema) = extract_catalog_and_schema(&req);
        assert_eq!(catalog, "greptime");
        assert_eq!(schema, "tomcat");
    }

    #[test]
    fn test_extract_db() {
        let db_cases: &[(&str, Option<&str>)] = &[
            ("", None),
            ("&", None),
            ("db=", None),
            ("db=foo", Some("foo")),
            ("name=bar", None),
            ("db=&name=bar", None),
            ("db=foo&name=bar", Some("foo")),
            ("name=bar&db=", None),
            ("name=bar&db=foo", Some("foo")),
            ("name=bar&db=&name=bar", None),
            ("name=bar&db=foo&name=bar", Some("foo")),
        ];
        for &(query, expected) in db_cases {
            assert_eq!(extract_db_from_query(query), expected);
        }

        let bucket_cases: &[(&str, Option<&str>)] = &[
            ("bucket=", None),
            ("db=foo", None),
            ("bucket=foo", Some("foo")),
            ("db=foo&bucket=bar", Some("bar")),
        ];
        for &(query, expected) in bucket_cases {
            assert_eq!(extract_bucket_from_query(query), expected);
        }
    }

    #[test]
    fn test_extract_user() {
        let cases: &[(&str, (Option<&str>, Option<&str>))] = &[
            ("", (None, None)),
            ("u=", (None, None)),
            ("u=123", (Some("123"), None)),
            ("u=123&p=", (Some("123"), None)),
            ("u=123&p=4", (Some("123"), Some("4"))),
            ("p=", (None, None)),
            ("p=4", (None, Some("4"))),
            ("p=4&u=", (None, Some("4"))),
            ("p=4&u=123", (Some("123"), Some("4"))),
        ];
        for &(query, expected) in cases {
            assert_eq!(extract_influxdb_user_from_query(query), expected);
        }
    }
}