servers/postgres/
auth_handler.rs1use std::fmt::Debug;
16use std::sync::Exclusive;
17
18use ::auth::{Identity, Password, UserInfoRef, UserProviderRef, userinfo_by_name};
19use async_trait::async_trait;
20use common_catalog::parse_catalog_and_schema_from_db_string;
21use common_error::ext::ErrorExt;
22use common_time::Timezone;
23use futures::{Sink, SinkExt};
24use pgwire::api::auth::StartupHandler;
25use pgwire::api::{ClientInfo, PgWireConnectionState, auth};
26use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
27use pgwire::messages::response::ErrorResponse;
28use pgwire::messages::startup::{Authentication, SecretKey};
29use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage};
30use session::Session;
31use snafu::IntoError;
32
33use crate::error::{AuthSnafu, Result};
34use crate::metrics::METRIC_AUTH_FAILURE;
35use crate::postgres::PostgresServerHandlerInner;
36use crate::postgres::types::PgErrorCode;
37use crate::postgres::utils::convert_err;
38use crate::query_handler::sql::ServerSqlQueryHandlerRef;
39
40pub(crate) struct PgLoginVerifier {
41 user_provider: Option<UserProviderRef>,
42}
43
44impl PgLoginVerifier {
45 pub(crate) fn new(user_provider: Option<UserProviderRef>) -> Self {
46 Self { user_provider }
47 }
48}
49
50#[allow(dead_code)]
51struct LoginInfo {
52 user: Option<String>,
53 catalog: Option<String>,
54 schema: Option<String>,
55 host: String,
56}
57
58impl LoginInfo {
59 pub fn from_client_info<C>(client: &C) -> LoginInfo
60 where
61 C: ClientInfo,
62 {
63 LoginInfo {
64 user: client.metadata().get(super::METADATA_USER).map(Into::into),
65 catalog: client
66 .metadata()
67 .get(super::METADATA_CATALOG)
68 .map(Into::into),
69 schema: client
70 .metadata()
71 .get(super::METADATA_SCHEMA)
72 .map(Into::into),
73 host: client.socket_addr().ip().to_string(),
74 }
75 }
76}
77
78impl PgLoginVerifier {
79 async fn auth(&self, login: &LoginInfo, password: &str) -> Result<Option<UserInfoRef>> {
80 let user_provider = match &self.user_provider {
81 Some(provider) => provider,
82 None => return Ok(None),
83 };
84
85 let user_name = match &login.user {
86 Some(name) => name,
87 None => return Ok(None),
88 };
89 let catalog = match &login.catalog {
90 Some(name) => name,
91 None => return Ok(None),
92 };
93 let schema = match &login.schema {
94 Some(name) => name,
95 None => return Ok(None),
96 };
97
98 match user_provider
99 .auth(
100 Identity::UserId(user_name, None),
101 Password::PlainText(password.to_string().into()),
102 catalog,
103 schema,
104 )
105 .await
106 {
107 Err(e) => {
108 METRIC_AUTH_FAILURE
109 .with_label_values(&[e.status_code().as_ref()])
110 .inc();
111 Err(AuthSnafu.into_error(e))
112 }
113 Ok(user_info) => Ok(Some(user_info)),
114 }
115 }
116}
117
118fn set_client_info<C>(client: &mut C, session: &Session)
119where
120 C: ClientInfo,
121{
122 if let Some(current_catalog) = client.metadata().get(super::METADATA_CATALOG) {
123 session.set_catalog(current_catalog.clone());
124 }
125 if let Some(current_schema) = client.metadata().get(super::METADATA_SCHEMA) {
126 session.set_schema(current_schema.clone());
127 }
128
129 client.set_pid_and_secret_key(0, SecretKey::I32(0));
133 }
135
136#[async_trait]
137impl StartupHandler for PostgresServerHandlerInner {
138 async fn on_startup<C>(
139 &self,
140 client: &mut C,
141 message: PgWireFrontendMessage,
142 ) -> PgWireResult<()>
143 where
144 C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send,
145 C::Error: Debug,
146 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
147 {
148 match message {
149 PgWireFrontendMessage::Startup(ref startup) => {
150 if !client.is_secure() && self.force_tls {
152 send_error(
153 client,
154 PgErrorCode::Ec28000.to_err_info("No encryption".to_string()),
155 )
156 .await?;
157 return Ok(());
158 }
159
160 auth::save_startup_parameters_to_metadata(client, startup);
161
162 match resolve_db_info(Exclusive::new(client), self.query_handler.clone()).await? {
164 DbResolution::Resolved(catalog, schema) => {
165 let metadata = client.metadata_mut();
166 let _ = metadata.insert(super::METADATA_CATALOG.to_owned(), catalog);
167 let _ = metadata.insert(super::METADATA_SCHEMA.to_owned(), schema);
168 }
169 DbResolution::NotFound(msg) => {
170 send_error(client, PgErrorCode::Ec3D000.to_err_info(msg)).await?;
171 return Ok(());
172 }
173 }
174
175 if let Some(tz) = client.metadata().get("TimeZone") {
177 match Timezone::from_tz_string(tz) {
178 Ok(tz) => self.session.set_timezone(tz),
179 Err(_) => {
180 send_error(
181 client,
182 PgErrorCode::Ec22023
183 .to_err_info(format!("Invalid TimeZone: {}", tz)),
184 )
185 .await?;
186
187 return Ok(());
188 }
189 }
190 }
191
192 if self.login_verifier.user_provider.is_some() {
193 client.set_state(PgWireConnectionState::AuthenticationInProgress);
194 client
195 .send(PgWireBackendMessage::Authentication(
196 Authentication::CleartextPassword,
197 ))
198 .await?;
199 } else {
200 self.session.set_user_info(userinfo_by_name(
201 client.metadata().get(super::METADATA_USER).cloned(),
202 ));
203 set_client_info(client, &self.session);
204 auth::finish_authentication(client, self.param_provider.as_ref()).await?;
205 }
206 }
207 PgWireFrontendMessage::PasswordMessageFamily(pwd) => {
208 let pwd = pwd.into_password()?;
212
213 let login_info = LoginInfo::from_client_info(client);
214
215 let auth_result = self.login_verifier.auth(&login_info, &pwd.password).await;
217
218 if let Ok(Some(user_info)) = auth_result {
219 self.session.set_user_info(user_info);
220 set_client_info(client, &self.session);
221 auth::finish_authentication(client, self.param_provider.as_ref()).await?;
222 } else {
223 return send_error(
224 client,
225 PgErrorCode::Ec28P01
226 .to_err_info("password authentication failed".to_string()),
227 )
228 .await;
229 }
230 }
231 _ => {}
232 }
233 Ok(())
234 }
235}
236
237async fn send_error<C>(client: &mut C, err_info: ErrorInfo) -> PgWireResult<()>
238where
239 C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send,
240 C::Error: Debug,
241 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
242{
243 let error = ErrorResponse::from(err_info);
244 client
245 .feed(PgWireBackendMessage::ErrorResponse(error))
246 .await?;
247 client.close().await?;
248 Ok(())
249}
250
251enum DbResolution {
252 Resolved(String, String),
253 NotFound(String),
254}
255
256async fn resolve_db_info<C>(
258 client: Exclusive<&mut C>,
259 query_handler: ServerSqlQueryHandlerRef,
260) -> PgWireResult<DbResolution>
261where
262 C: ClientInfo + Unpin + Send,
263{
264 let db_ref = client.into_inner().metadata().get(super::METADATA_DATABASE);
265 if let Some(db) = db_ref {
266 let (catalog, schema) = parse_catalog_and_schema_from_db_string(db);
267 if query_handler
268 .is_valid_schema(&catalog, &schema)
269 .await
270 .map_err(convert_err)?
271 {
272 Ok(DbResolution::Resolved(catalog, schema))
273 } else {
274 Ok(DbResolution::NotFound(format!("Database not found: {db}")))
275 }
276 } else {
277 Ok(DbResolution::NotFound("Database not specified".to_owned()))
278 }
279}