servers/grpc/
authorize.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::pin::Pin;
16use std::result::Result as StdResult;
17use std::task::{Context, Poll};
18
19use auth::UserProviderRef;
20use session::context::{Channel, QueryContext};
21use tonic::body::BoxBody;
22use tonic::server::NamedService;
23use tower::{Layer, Service};
24
25use crate::http::authorize::{extract_catalog_and_schema, extract_username_and_password};
26
27#[derive(Clone)]
28pub struct AuthMiddlewareLayer {
29    user_provider: Option<UserProviderRef>,
30}
31
32impl<S> Layer<S> for AuthMiddlewareLayer {
33    type Service = AuthMiddleware<S>;
34
35    fn layer(&self, service: S) -> Self::Service {
36        AuthMiddleware {
37            inner: service,
38            user_provider: self.user_provider.clone(),
39        }
40    }
41}
42
43/// This middleware is responsible for authenticating the user and setting the user
44/// info in the request extension.
45///
46/// Detail: Authorization information is passed in through the Authorization request
47/// header.
48#[derive(Clone)]
49pub struct AuthMiddleware<S> {
50    inner: S,
51    user_provider: Option<UserProviderRef>,
52}
53
54impl<S> NamedService for AuthMiddleware<S>
55where
56    S: NamedService,
57{
58    const NAME: &'static str = S::NAME;
59}
60
/// A pinned, heap-allocated, type-erased future that is `Send` and valid for `'a`.
///
/// Used as the `Future` associated type of the middleware's `Service` impl,
/// because the `async` block built in `call` has an unnameable type.
type BoxFuture<'a, T> =
    Pin<Box<dyn std::future::Future<Output = T> + Send + 'a>>;
62
63impl<S> Service<http::Request<BoxBody>> for AuthMiddleware<S>
64where
65    S: Service<http::Request<BoxBody>, Response = http::Response<BoxBody>> + Clone + Send + 'static,
66    S::Future: Send + 'static,
67{
68    type Response = S::Response;
69    type Error = S::Error;
70    type Future = BoxFuture<'static, StdResult<Self::Response, Self::Error>>;
71
72    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<StdResult<(), Self::Error>> {
73        self.inner.poll_ready(cx)
74    }
75
76    fn call(&mut self, mut req: http::Request<BoxBody>) -> Self::Future {
77        // This is necessary because tonic internally uses `tower::buffer::Buffer`.
78        // See https://github.com/tower-rs/tower/issues/547#issuecomment-767629149
79        // for details on why this is necessary.
80        let clone = self.inner.clone();
81        let mut inner = std::mem::replace(&mut self.inner, clone);
82
83        let user_provider = self.user_provider.clone();
84
85        Box::pin(async move {
86            if let Err(status) = do_auth(&mut req, user_provider).await {
87                return Ok(status.into_http());
88            }
89            inner.call(req).await
90        })
91    }
92}
93
94async fn do_auth<T>(
95    req: &mut http::Request<T>,
96    user_provider: Option<UserProviderRef>,
97) -> Result<(), tonic::Status> {
98    let (catalog, schema) = extract_catalog_and_schema(req);
99
100    let query_ctx = QueryContext::with_channel(&catalog, &schema, Channel::Grpc);
101
102    let Some(user_provider) = user_provider else {
103        query_ctx.set_current_user(auth::userinfo_by_name(None));
104        let _ = req.extensions_mut().insert(query_ctx);
105        return Ok(());
106    };
107
108    let (username, password) = extract_username_and_password(req)
109        .map_err(|e| tonic::Status::invalid_argument(e.to_string()))?;
110
111    let id = auth::Identity::UserId(&username, None);
112    let pwd = auth::Password::PlainText(password);
113
114    let user_info = user_provider
115        .auth(id, pwd, &catalog, &schema)
116        .await
117        .map_err(|e| tonic::Status::unauthenticated(e.to_string()))?;
118
119    query_ctx.set_current_user(user_info);
120    let _ = req.extensions_mut().insert(query_ctx);
121
122    Ok(())
123}
124
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use auth::tests::MockUserProvider;
    use base64::engine::general_purpose::STANDARD;
    use base64::Engine;
    use headers::Header;
    use hyper::Request;
    use session::context::QueryContext;

    use crate::grpc::authorize::do_auth;
    use crate::http::header::GreptimeDbName;

    #[tokio::test]
    async fn test_do_auth_with_user_provider() {
        let user_provider = Arc::new(MockUserProvider::default());

        // auth success
        let authorization_val = format!("Basic {}", STANDARD.encode("greptime:greptime"));
        let mut req = Request::new(());
        req.headers_mut()
            .insert("authorization", authorization_val.parse().unwrap());

        let auth_result = do_auth(&mut req, Some(user_provider.clone())).await;

        assert!(auth_result.is_ok());
        check_req(&req, "greptime", "public", "greptime");

        // auth failed, err: user not exist.
        let authorization_val = format!("Basic {}", STANDARD.encode("greptime2:greptime2"));
        let mut req = Request::new(());
        req.headers_mut()
            .insert("authorization", authorization_val.parse().unwrap());

        let auth_result = do_auth(&mut req, Some(user_provider.clone())).await;
        let status = auth_result.unwrap_err();
        assert_eq!(tonic::Code::Unauthenticated, status.code());
        // A rejected request must not carry a query context.
        assert!(req.extensions().get::<QueryContext>().is_none());

        // No credentials at all: extraction fails before the provider is consulted.
        let mut req = Request::new(());
        let auth_result = do_auth(&mut req, Some(user_provider)).await;
        let status = auth_result.unwrap_err();
        assert_eq!(tonic::Code::InvalidArgument, status.code());
        assert!(req.extensions().get::<QueryContext>().is_none());
    }

    #[tokio::test]
    async fn test_do_auth_without_user_provider() {
        // Even a malformed `authorization` header is ignored when no provider is set.
        let mut req = Request::new(());
        req.headers_mut()
            .insert("authorization", "pwd".parse().unwrap());
        let auth_result = do_auth(&mut req, None).await;
        assert!(auth_result.is_ok());
        check_req(&req, "greptime", "public", "greptime");

        // No headers at all.
        let mut req = Request::new(());
        let auth_result = do_auth(&mut req, None).await;
        assert!(auth_result.is_ok());
        check_req(&req, "greptime", "public", "greptime");

        // Catalog and schema are taken from the db-name header.
        let mut req = Request::new(());
        req.headers_mut()
            .insert(GreptimeDbName::name(), "catalog-schema".parse().unwrap());
        let auth_result = do_auth(&mut req, None).await;
        assert!(auth_result.is_ok());
        check_req(&req, "catalog", "schema", "greptime");
    }

    /// Asserts that `req` carries a `QueryContext` with the expected catalog,
    /// schema and current user name.
    fn check_req<T>(
        req: &Request<T>,
        expected_catalog: &str,
        expected_schema: &str,
        expected_user_name: &str,
    ) {
        let ctx = req.extensions().get::<QueryContext>().unwrap();
        assert_eq!(expected_catalog, ctx.current_catalog());
        assert_eq!(expected_schema, ctx.current_schema());

        let user_info = ctx.current_user();
        assert_eq!(expected_user_name, user_info.username());
    }
}