servers/postgres/
auth_handler.rs1use std::fmt::Debug;
16use std::sync::Exclusive;
17
18use ::auth::{userinfo_by_name, Identity, Password, UserInfoRef, UserProviderRef};
19use async_trait::async_trait;
20use common_catalog::parse_catalog_and_schema_from_db_string;
21use common_error::ext::ErrorExt;
22use futures::{Sink, SinkExt};
23use pgwire::api::auth::StartupHandler;
24use pgwire::api::{auth, ClientInfo, PgWireConnectionState};
25use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
26use pgwire::messages::response::ErrorResponse;
27use pgwire::messages::startup::Authentication;
28use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage};
29use session::Session;
30use snafu::IntoError;
31
32use crate::error::{AuthSnafu, Result};
33use crate::metrics::METRIC_AUTH_FAILURE;
34use crate::postgres::types::PgErrorCode;
35use crate::postgres::PostgresServerHandlerInner;
36use crate::query_handler::sql::ServerSqlQueryHandlerRef;
37
38pub(crate) struct PgLoginVerifier {
39 user_provider: Option<UserProviderRef>,
40}
41
42impl PgLoginVerifier {
43 pub(crate) fn new(user_provider: Option<UserProviderRef>) -> Self {
44 Self { user_provider }
45 }
46}
47
48#[allow(dead_code)]
49struct LoginInfo {
50 user: Option<String>,
51 catalog: Option<String>,
52 schema: Option<String>,
53 host: String,
54}
55
56impl LoginInfo {
57 pub fn from_client_info<C>(client: &C) -> LoginInfo
58 where
59 C: ClientInfo,
60 {
61 LoginInfo {
62 user: client.metadata().get(super::METADATA_USER).map(Into::into),
63 catalog: client
64 .metadata()
65 .get(super::METADATA_CATALOG)
66 .map(Into::into),
67 schema: client
68 .metadata()
69 .get(super::METADATA_SCHEMA)
70 .map(Into::into),
71 host: client.socket_addr().ip().to_string(),
72 }
73 }
74}
75
76impl PgLoginVerifier {
77 async fn auth(&self, login: &LoginInfo, password: &str) -> Result<Option<UserInfoRef>> {
78 let user_provider = match &self.user_provider {
79 Some(provider) => provider,
80 None => return Ok(None),
81 };
82
83 let user_name = match &login.user {
84 Some(name) => name,
85 None => return Ok(None),
86 };
87 let catalog = match &login.catalog {
88 Some(name) => name,
89 None => return Ok(None),
90 };
91 let schema = match &login.schema {
92 Some(name) => name,
93 None => return Ok(None),
94 };
95
96 match user_provider
97 .auth(
98 Identity::UserId(user_name, None),
99 Password::PlainText(password.to_string().into()),
100 catalog,
101 schema,
102 )
103 .await
104 {
105 Err(e) => {
106 METRIC_AUTH_FAILURE
107 .with_label_values(&[e.status_code().as_ref()])
108 .inc();
109 Err(AuthSnafu.into_error(e))
110 }
111 Ok(user_info) => Ok(Some(user_info)),
112 }
113 }
114}
115
116fn set_client_info<C>(client: &C, session: &Session)
117where
118 C: ClientInfo,
119{
120 if let Some(current_catalog) = client.metadata().get(super::METADATA_CATALOG) {
121 session.set_catalog(current_catalog.clone());
122 }
123 if let Some(current_schema) = client.metadata().get(super::METADATA_SCHEMA) {
124 session.set_schema(current_schema.clone());
125 }
126 }
128
129#[async_trait]
130impl StartupHandler for PostgresServerHandlerInner {
131 async fn on_startup<C>(
132 &self,
133 client: &mut C,
134 message: PgWireFrontendMessage,
135 ) -> PgWireResult<()>
136 where
137 C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send,
138 C::Error: Debug,
139 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
140 {
141 match message {
142 PgWireFrontendMessage::Startup(ref startup) => {
143 if !client.is_secure() && self.force_tls {
145 send_error(
146 client,
147 PgErrorCode::Ec28000.to_err_info("No encryption".to_string()),
148 )
149 .await?;
150 return Ok(());
151 }
152
153 auth::save_startup_parameters_to_metadata(client, startup);
154
155 match resolve_db_info(Exclusive::new(client), self.query_handler.clone()).await? {
157 DbResolution::Resolved(catalog, schema) => {
158 let metadata = client.metadata_mut();
159 let _ = metadata.insert(super::METADATA_CATALOG.to_owned(), catalog);
160 let _ = metadata.insert(super::METADATA_SCHEMA.to_owned(), schema);
161 }
162 DbResolution::NotFound(msg) => {
163 send_error(client, PgErrorCode::Ec3D000.to_err_info(msg)).await?;
164 return Ok(());
165 }
166 }
167
168 if self.login_verifier.user_provider.is_some() {
169 client.set_state(PgWireConnectionState::AuthenticationInProgress);
170 client
171 .send(PgWireBackendMessage::Authentication(
172 Authentication::CleartextPassword,
173 ))
174 .await?;
175 } else {
176 self.session.set_user_info(userinfo_by_name(
177 client.metadata().get(super::METADATA_USER).cloned(),
178 ));
179 set_client_info(client, &self.session);
180 auth::finish_authentication(client, self.param_provider.as_ref()).await?;
181 }
182 }
183 PgWireFrontendMessage::PasswordMessageFamily(pwd) => {
184 let pwd = pwd.into_password()?;
188
189 let login_info = LoginInfo::from_client_info(client);
190
191 let auth_result = self.login_verifier.auth(&login_info, &pwd.password).await;
193
194 if let Ok(Some(user_info)) = auth_result {
195 self.session.set_user_info(user_info);
196 set_client_info(client, &self.session);
197 auth::finish_authentication(client, self.param_provider.as_ref()).await?;
198 } else {
199 return send_error(
200 client,
201 PgErrorCode::Ec28P01
202 .to_err_info("password authentication failed".to_string()),
203 )
204 .await;
205 }
206 }
207 _ => {}
208 }
209 Ok(())
210 }
211}
212
213async fn send_error<C>(client: &mut C, err_info: ErrorInfo) -> PgWireResult<()>
214where
215 C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send,
216 C::Error: Debug,
217 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
218{
219 let error = ErrorResponse::from(err_info);
220 client
221 .feed(PgWireBackendMessage::ErrorResponse(error))
222 .await?;
223 client.close().await?;
224 Ok(())
225}
226
/// Result of resolving the database a client asked for at startup.
enum DbResolution {
    /// The database maps to an existing schema: `(catalog, schema)`.
    Resolved(String, String),
    /// No usable database; the payload is the human-readable reason.
    NotFound(String),
}
231
232async fn resolve_db_info<C>(
234 client: Exclusive<&mut C>,
235 query_handler: ServerSqlQueryHandlerRef,
236) -> PgWireResult<DbResolution>
237where
238 C: ClientInfo + Unpin + Send,
239{
240 let db_ref = client.into_inner().metadata().get(super::METADATA_DATABASE);
241 if let Some(db) = db_ref {
242 let (catalog, schema) = parse_catalog_and_schema_from_db_string(db);
243 if query_handler
244 .is_valid_schema(&catalog, &schema)
245 .await
246 .map_err(|e| PgWireError::ApiError(Box::new(e)))?
247 {
248 Ok(DbResolution::Resolved(catalog, schema))
249 } else {
250 Ok(DbResolution::NotFound(format!("Database not found: {db}")))
251 }
252 } else {
253 Ok(DbResolution::NotFound("Database not specified".to_owned()))
254 }
255}