servers/postgres/
auth_handler.rs1use std::fmt::Debug;
16use std::sync::Exclusive;
17
18use ::auth::{userinfo_by_name, Identity, Password, UserInfoRef, UserProviderRef};
19use async_trait::async_trait;
20use common_catalog::parse_catalog_and_schema_from_db_string;
21use common_error::ext::ErrorExt;
22use futures::{Sink, SinkExt};
23use pgwire::api::auth::StartupHandler;
24use pgwire::api::{auth, ClientInfo, PgWireConnectionState};
25use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
26use pgwire::messages::response::ErrorResponse;
27use pgwire::messages::startup::Authentication;
28use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage};
29use session::Session;
30use snafu::IntoError;
31
32use crate::error::{AuthSnafu, Result};
33use crate::metrics::METRIC_AUTH_FAILURE;
34use crate::postgres::types::PgErrorCode;
35use crate::postgres::PostgresServerHandlerInner;
36use crate::query_handler::sql::ServerSqlQueryHandlerRef;
37
38pub(crate) struct PgLoginVerifier {
39 user_provider: Option<UserProviderRef>,
40}
41
42impl PgLoginVerifier {
43 pub(crate) fn new(user_provider: Option<UserProviderRef>) -> Self {
44 Self { user_provider }
45 }
46}
47
/// Snapshot of the connection properties that take part in authentication.
// `host` is populated but not read anywhere in the visible code, hence the
// `dead_code` allowance.
#[allow(dead_code)]
struct LoginInfo {
    /// User name found in the client metadata, if any.
    user: Option<String>,
    /// Catalog found in the client metadata, if any.
    catalog: Option<String>,
    /// Schema found in the client metadata, if any.
    schema: Option<String>,
    /// Textual IP address of the client socket.
    host: String,
}
55
56impl LoginInfo {
57 pub fn from_client_info<C>(client: &C) -> LoginInfo
58 where
59 C: ClientInfo,
60 {
61 LoginInfo {
62 user: client.metadata().get(super::METADATA_USER).map(Into::into),
63 catalog: client
64 .metadata()
65 .get(super::METADATA_CATALOG)
66 .map(Into::into),
67 schema: client
68 .metadata()
69 .get(super::METADATA_SCHEMA)
70 .map(Into::into),
71 host: client.socket_addr().ip().to_string(),
72 }
73 }
74}
75
76impl PgLoginVerifier {
77 async fn auth(&self, login: &LoginInfo, password: &str) -> Result<Option<UserInfoRef>> {
78 let user_provider = match &self.user_provider {
79 Some(provider) => provider,
80 None => return Ok(None),
81 };
82
83 let user_name = match &login.user {
84 Some(name) => name,
85 None => return Ok(None),
86 };
87 let catalog = match &login.catalog {
88 Some(name) => name,
89 None => return Ok(None),
90 };
91 let schema = match &login.schema {
92 Some(name) => name,
93 None => return Ok(None),
94 };
95
96 match user_provider
97 .auth(
98 Identity::UserId(user_name, None),
99 Password::PlainText(password.to_string().into()),
100 catalog,
101 schema,
102 )
103 .await
104 {
105 Err(e) => {
106 METRIC_AUTH_FAILURE
107 .with_label_values(&[e.status_code().as_ref()])
108 .inc();
109 Err(AuthSnafu.into_error(e))
110 }
111 Ok(user_info) => Ok(Some(user_info)),
112 }
113 }
114}
115
116fn set_client_info<C>(client: &mut C, session: &Session)
117where
118 C: ClientInfo,
119{
120 if let Some(current_catalog) = client.metadata().get(super::METADATA_CATALOG) {
121 session.set_catalog(current_catalog.clone());
122 }
123 if let Some(current_schema) = client.metadata().get(super::METADATA_SCHEMA) {
124 session.set_schema(current_schema.clone());
125 }
126
127 client.set_pid_and_secret_key(session.process_id() as i32, rand::random::<i32>());
130 }
132
133#[async_trait]
134impl StartupHandler for PostgresServerHandlerInner {
135 async fn on_startup<C>(
136 &self,
137 client: &mut C,
138 message: PgWireFrontendMessage,
139 ) -> PgWireResult<()>
140 where
141 C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send,
142 C::Error: Debug,
143 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
144 {
145 match message {
146 PgWireFrontendMessage::Startup(ref startup) => {
147 if !client.is_secure() && self.force_tls {
149 send_error(
150 client,
151 PgErrorCode::Ec28000.to_err_info("No encryption".to_string()),
152 )
153 .await?;
154 return Ok(());
155 }
156
157 auth::save_startup_parameters_to_metadata(client, startup);
158
159 match resolve_db_info(Exclusive::new(client), self.query_handler.clone()).await? {
161 DbResolution::Resolved(catalog, schema) => {
162 let metadata = client.metadata_mut();
163 let _ = metadata.insert(super::METADATA_CATALOG.to_owned(), catalog);
164 let _ = metadata.insert(super::METADATA_SCHEMA.to_owned(), schema);
165 }
166 DbResolution::NotFound(msg) => {
167 send_error(client, PgErrorCode::Ec3D000.to_err_info(msg)).await?;
168 return Ok(());
169 }
170 }
171
172 if self.login_verifier.user_provider.is_some() {
173 client.set_state(PgWireConnectionState::AuthenticationInProgress);
174 client
175 .send(PgWireBackendMessage::Authentication(
176 Authentication::CleartextPassword,
177 ))
178 .await?;
179 } else {
180 self.session.set_user_info(userinfo_by_name(
181 client.metadata().get(super::METADATA_USER).cloned(),
182 ));
183 set_client_info(client, &self.session);
184 auth::finish_authentication(client, self.param_provider.as_ref()).await?;
185 }
186 }
187 PgWireFrontendMessage::PasswordMessageFamily(pwd) => {
188 let pwd = pwd.into_password()?;
192
193 let login_info = LoginInfo::from_client_info(client);
194
195 let auth_result = self.login_verifier.auth(&login_info, &pwd.password).await;
197
198 if let Ok(Some(user_info)) = auth_result {
199 self.session.set_user_info(user_info);
200 set_client_info(client, &self.session);
201 auth::finish_authentication(client, self.param_provider.as_ref()).await?;
202 } else {
203 return send_error(
204 client,
205 PgErrorCode::Ec28P01
206 .to_err_info("password authentication failed".to_string()),
207 )
208 .await;
209 }
210 }
211 _ => {}
212 }
213 Ok(())
214 }
215}
216
217async fn send_error<C>(client: &mut C, err_info: ErrorInfo) -> PgWireResult<()>
218where
219 C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send,
220 C::Error: Debug,
221 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
222{
223 let error = ErrorResponse::from(err_info);
224 client
225 .feed(PgWireBackendMessage::ErrorResponse(error))
226 .await?;
227 client.close().await?;
228 Ok(())
229}
230
/// Outcome of resolving the database named in the startup parameters.
enum DbResolution {
    /// The database maps to a valid `(catalog, schema)` pair.
    Resolved(String, String),
    /// Resolution failed; the payload is the message reported to the client.
    NotFound(String),
}
235
236async fn resolve_db_info<C>(
238 client: Exclusive<&mut C>,
239 query_handler: ServerSqlQueryHandlerRef,
240) -> PgWireResult<DbResolution>
241where
242 C: ClientInfo + Unpin + Send,
243{
244 let db_ref = client.into_inner().metadata().get(super::METADATA_DATABASE);
245 if let Some(db) = db_ref {
246 let (catalog, schema) = parse_catalog_and_schema_from_db_string(db);
247 if query_handler
248 .is_valid_schema(&catalog, &schema)
249 .await
250 .map_err(|e| PgWireError::ApiError(Box::new(e)))?
251 {
252 Ok(DbResolution::Resolved(catalog, schema))
253 } else {
254 Ok(DbResolution::NotFound(format!("Database not found: {db}")))
255 }
256 } else {
257 Ok(DbResolution::NotFound("Database not specified".to_owned()))
258 }
259}