servers/grpc/
authorize.rs1use std::pin::Pin;
16use std::result::Result as StdResult;
17use std::task::{Context, Poll};
18
19use auth::UserProviderRef;
20use session::context::{Channel, QueryContext};
21use tonic::body::BoxBody;
22use tonic::server::NamedService;
23use tower::{Layer, Service};
24
25use crate::http::authorize::{extract_catalog_and_schema, extract_username_and_password};
26
27#[derive(Clone)]
28pub struct AuthMiddlewareLayer {
29 user_provider: Option<UserProviderRef>,
30}
31
32impl<S> Layer<S> for AuthMiddlewareLayer {
33 type Service = AuthMiddleware<S>;
34
35 fn layer(&self, service: S) -> Self::Service {
36 AuthMiddleware {
37 inner: service,
38 user_provider: self.user_provider.clone(),
39 }
40 }
41}
42
43#[derive(Clone)]
49pub struct AuthMiddleware<S> {
50 inner: S,
51 user_provider: Option<UserProviderRef>,
52}
53
54impl<S> NamedService for AuthMiddleware<S>
55where
56 S: NamedService,
57{
58 const NAME: &'static str = S::NAME;
59}
60
61type BoxFuture<'a, T> = Pin<Box<dyn std::future::Future<Output = T> + Send + 'a>>;
62
63impl<S> Service<http::Request<BoxBody>> for AuthMiddleware<S>
64where
65 S: Service<http::Request<BoxBody>, Response = http::Response<BoxBody>> + Clone + Send + 'static,
66 S::Future: Send + 'static,
67{
68 type Response = S::Response;
69 type Error = S::Error;
70 type Future = BoxFuture<'static, StdResult<Self::Response, Self::Error>>;
71
72 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<StdResult<(), Self::Error>> {
73 self.inner.poll_ready(cx)
74 }
75
76 fn call(&mut self, mut req: http::Request<BoxBody>) -> Self::Future {
77 let clone = self.inner.clone();
81 let mut inner = std::mem::replace(&mut self.inner, clone);
82
83 let user_provider = self.user_provider.clone();
84
85 Box::pin(async move {
86 if let Err(status) = do_auth(&mut req, user_provider).await {
87 return Ok(status.into_http());
88 }
89 inner.call(req).await
90 })
91 }
92}
93
94async fn do_auth<T>(
95 req: &mut http::Request<T>,
96 user_provider: Option<UserProviderRef>,
97) -> Result<(), tonic::Status> {
98 let (catalog, schema) = extract_catalog_and_schema(req);
99
100 let query_ctx = QueryContext::with_channel(&catalog, &schema, Channel::Grpc);
101
102 let Some(user_provider) = user_provider else {
103 query_ctx.set_current_user(auth::userinfo_by_name(None));
104 let _ = req.extensions_mut().insert(query_ctx);
105 return Ok(());
106 };
107
108 let (username, password) = extract_username_and_password(req)
109 .map_err(|e| tonic::Status::invalid_argument(e.to_string()))?;
110
111 let id = auth::Identity::UserId(&username, None);
112 let pwd = auth::Password::PlainText(password);
113
114 let user_info = user_provider
115 .auth(id, pwd, &catalog, &schema)
116 .await
117 .map_err(|e| tonic::Status::unauthenticated(e.to_string()))?;
118
119 query_ctx.set_current_user(user_info);
120 let _ = req.extensions_mut().insert(query_ctx);
121
122 Ok(())
123}
124
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use auth::tests::MockUserProvider;
    use base64::engine::general_purpose::STANDARD;
    use base64::Engine;
    use headers::Header;
    use hyper::Request;
    use session::context::QueryContext;

    use crate::grpc::authorize::do_auth;
    use crate::http::header::GreptimeDbName;

    /// Builds a request whose `authorization` header carries HTTP Basic
    /// credentials for `username:password`.
    fn basic_auth_request(username: &str, password: &str) -> Request<()> {
        let authorization_val = format!(
            "Basic {}",
            STANDARD.encode(format!("{username}:{password}"))
        );
        let mut req = Request::new(());
        req.headers_mut()
            .insert("authorization", authorization_val.parse().unwrap());
        req
    }

    #[tokio::test]
    async fn test_do_auth_with_user_provider() {
        let user_provider = Arc::new(MockUserProvider::default());

        // Valid credentials: the context is attached with the authenticated user.
        let mut req = basic_auth_request("greptime", "greptime");
        let auth_result = do_auth(&mut req, Some(user_provider.clone())).await;
        assert!(auth_result.is_ok());
        check_req(&req, "greptime", "public", "greptime");

        // Wrong credentials: rejected, and no context is attached.
        let mut req = basic_auth_request("greptime2", "greptime2");
        let auth_result = do_auth(&mut req, Some(user_provider.clone())).await;
        assert!(auth_result.is_err());
        assert!(req.extensions().get::<QueryContext>().is_none());

        // A provider is configured but no credentials are sent: rejected.
        let mut req = Request::new(());
        let auth_result = do_auth(&mut req, Some(user_provider)).await;
        assert!(auth_result.is_err());
        assert!(req.extensions().get::<QueryContext>().is_none());
    }

    #[tokio::test]
    async fn test_do_auth_without_user_provider() {
        // Without a provider, unrelated headers are ignored.
        let mut req = Request::new(());
        req.headers_mut()
            .insert("authentication", "pwd".parse().unwrap());
        let auth_result = do_auth(&mut req, None).await;
        assert!(auth_result.is_ok());
        check_req(&req, "greptime", "public", "greptime");

        // No headers at all.
        let mut req = Request::new(());
        let auth_result = do_auth(&mut req, None).await;
        assert!(auth_result.is_ok());
        check_req(&req, "greptime", "public", "greptime");

        // The db-name header selects the catalog and schema.
        let mut req = Request::new(());
        req.headers_mut()
            .insert(GreptimeDbName::name(), "catalog-schema".parse().unwrap());
        let auth_result = do_auth(&mut req, None).await;
        assert!(auth_result.is_ok());
        check_req(&req, "catalog", "schema", "greptime");
    }

    /// Asserts that `req` carries a `QueryContext` with the expected catalog,
    /// schema and user name.
    fn check_req<T>(
        req: &Request<T>,
        expected_catalog: &str,
        expected_schema: &str,
        expected_user_name: &str,
    ) {
        let ctx = req.extensions().get::<QueryContext>().unwrap();
        assert_eq!(expected_catalog, ctx.current_catalog());
        assert_eq!(expected_schema, ctx.current_schema());

        let user_info = ctx.current_user();
        assert_eq!(expected_user_name, user_info.username());
    }
}