servers/postgres/
auth_handler.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Debug;
16use std::sync::Exclusive;
17
18use ::auth::{userinfo_by_name, Identity, Password, UserInfoRef, UserProviderRef};
19use async_trait::async_trait;
20use common_catalog::parse_catalog_and_schema_from_db_string;
21use common_error::ext::ErrorExt;
22use futures::{Sink, SinkExt};
23use pgwire::api::auth::StartupHandler;
24use pgwire::api::{auth, ClientInfo, PgWireConnectionState};
25use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
26use pgwire::messages::response::ErrorResponse;
27use pgwire::messages::startup::Authentication;
28use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage};
29use session::Session;
30use snafu::IntoError;
31
32use crate::error::{AuthSnafu, Result};
33use crate::metrics::METRIC_AUTH_FAILURE;
34use crate::postgres::types::PgErrorCode;
35use crate::postgres::PostgresServerHandlerInner;
36use crate::query_handler::sql::ServerSqlQueryHandlerRef;
37
38pub(crate) struct PgLoginVerifier {
39    user_provider: Option<UserProviderRef>,
40}
41
42impl PgLoginVerifier {
43    pub(crate) fn new(user_provider: Option<UserProviderRef>) -> Self {
44        Self { user_provider }
45    }
46}
47
48#[allow(dead_code)]
49struct LoginInfo {
50    user: Option<String>,
51    catalog: Option<String>,
52    schema: Option<String>,
53    host: String,
54}
55
56impl LoginInfo {
57    pub fn from_client_info<C>(client: &C) -> LoginInfo
58    where
59        C: ClientInfo,
60    {
61        LoginInfo {
62            user: client.metadata().get(super::METADATA_USER).map(Into::into),
63            catalog: client
64                .metadata()
65                .get(super::METADATA_CATALOG)
66                .map(Into::into),
67            schema: client
68                .metadata()
69                .get(super::METADATA_SCHEMA)
70                .map(Into::into),
71            host: client.socket_addr().ip().to_string(),
72        }
73    }
74}
75
76impl PgLoginVerifier {
77    async fn auth(&self, login: &LoginInfo, password: &str) -> Result<Option<UserInfoRef>> {
78        let user_provider = match &self.user_provider {
79            Some(provider) => provider,
80            None => return Ok(None),
81        };
82
83        let user_name = match &login.user {
84            Some(name) => name,
85            None => return Ok(None),
86        };
87        let catalog = match &login.catalog {
88            Some(name) => name,
89            None => return Ok(None),
90        };
91        let schema = match &login.schema {
92            Some(name) => name,
93            None => return Ok(None),
94        };
95
96        match user_provider
97            .auth(
98                Identity::UserId(user_name, None),
99                Password::PlainText(password.to_string().into()),
100                catalog,
101                schema,
102            )
103            .await
104        {
105            Err(e) => {
106                METRIC_AUTH_FAILURE
107                    .with_label_values(&[e.status_code().as_ref()])
108                    .inc();
109                Err(AuthSnafu.into_error(e))
110            }
111            Ok(user_info) => Ok(Some(user_info)),
112        }
113    }
114}
115
116fn set_client_info<C>(client: &mut C, session: &Session)
117where
118    C: ClientInfo,
119{
120    if let Some(current_catalog) = client.metadata().get(super::METADATA_CATALOG) {
121        session.set_catalog(current_catalog.clone());
122    }
123    if let Some(current_schema) = client.metadata().get(super::METADATA_SCHEMA) {
124        session.set_schema(current_schema.clone());
125    }
126
127    // pass generated process id and secret key to client, this information will
128    // be sent to postgres client for query cancellation.
129    client.set_pid_and_secret_key(session.process_id() as i32, rand::random::<i32>());
130    // set userinfo outside
131}
132
// pgwire startup hook: runs the connection handshake (TLS check, database
// validation, cleartext-password authentication).
#[async_trait]
impl StartupHandler for PostgresServerHandlerInner {
    /// Handles one frontend message received during connection startup.
    ///
    /// - `Startup`:
    ///   - Rejects non-TLS clients when `force_tls` is set.
    ///   - Saves the startup parameters into the client metadata.
    ///   - Validates the requested database and stores the resolved
    ///     catalog/schema back into the metadata.
    ///   - Then either requests a cleartext password (when a user provider is
    ///     configured) or finishes authentication immediately.
    /// - `PasswordMessageFamily`: checks the password through `login_verifier`
    ///   and finishes authentication on success.
    /// - Any other message is ignored.
    ///
    /// Rejections (no TLS, missing/unknown database, failed authentication) are
    /// reported through `send_error`, which sends an `ErrorResponse` and closes
    /// the connection. The method then returns `Ok(())` unless sending failed.
    async fn on_startup<C>(
        &self,
        client: &mut C,
        message: PgWireFrontendMessage,
    ) -> PgWireResult<()>
    where
        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send,
        C::Error: Debug,
        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
    {
        match message {
            PgWireFrontendMessage::Startup(ref startup) => {
                // Refuse plaintext connections when TLS is mandatory.
                if !client.is_secure() && self.force_tls {
                    send_error(
                        client,
                        PgErrorCode::Ec28000.to_err_info("No encryption".to_string()),
                    )
                    .await?;
                    return Ok(());
                }

                // Copy the startup parameters into the client metadata; this
                // presumably populates the METADATA_USER / METADATA_DATABASE
                // entries read below (pgwire helper — confirm).
                auth::save_startup_parameters_to_metadata(client, startup);

                // Validate the requested database. The client is wrapped in
                // `Exclusive` for `resolve_db_info`, which was split out to work
                // around lifetime issues. On success, the resolved catalog and
                // schema are stored in the metadata for later use by
                // `LoginInfo` and `set_client_info`.
                match resolve_db_info(Exclusive::new(client), self.query_handler.clone()).await? {
                    DbResolution::Resolved(catalog, schema) => {
                        let metadata = client.metadata_mut();
                        let _ = metadata.insert(super::METADATA_CATALOG.to_owned(), catalog);
                        let _ = metadata.insert(super::METADATA_SCHEMA.to_owned(), schema);
                    }
                    DbResolution::NotFound(msg) => {
                        send_error(client, PgErrorCode::Ec3D000.to_err_info(msg)).await?;
                        return Ok(());
                    }
                }

                if self.login_verifier.user_provider.is_some() {
                    // Authentication enabled: ask for a cleartext password. The
                    // answer arrives later as a `PasswordMessageFamily` message.
                    client.set_state(PgWireConnectionState::AuthenticationInProgress);
                    client
                        .send(PgWireBackendMessage::Authentication(
                            Authentication::CleartextPassword,
                        ))
                        .await?;
                } else {
                    // Authentication disabled: derive user info from the startup
                    // user name (no password check) and finish right away.
                    self.session.set_user_info(userinfo_by_name(
                        client.metadata().get(super::METADATA_USER).cloned(),
                    ));
                    set_client_info(client, &self.session);
                    auth::finish_authentication(client, self.param_provider.as_ref()).await?;
                }
            }
            PgWireFrontendMessage::PasswordMessageFamily(pwd) => {
                // Newer pgwire versions model several password-like messages
                // (cleartext/MD5 password, SASL response, ...) as one family.
                // Only cleartext is requested above, so coerce it into a plain
                // password message here.
                let pwd = pwd.into_password()?;

                let login_info = LoginInfo::from_client_info(client);

                // Verify the credentials against the configured user provider.
                let auth_result = self.login_verifier.auth(&login_info, &pwd.password).await;

                if let Ok(Some(user_info)) = auth_result {
                    self.session.set_user_info(user_info);
                    set_client_info(client, &self.session);
                    auth::finish_authentication(client, self.param_provider.as_ref()).await?;
                } else {
                    // Provider errors and incomplete login info (`Ok(None)`) both
                    // end up here. The client gets the same generic message and
                    // the error detail is discarded.
                    return send_error(
                        client,
                        PgErrorCode::Ec28P01
                            .to_err_info("password authentication failed".to_string()),
                    )
                    .await;
                }
            }
            _ => {}
        }
        Ok(())
    }
}
216
217async fn send_error<C>(client: &mut C, err_info: ErrorInfo) -> PgWireResult<()>
218where
219    C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send,
220    C::Error: Debug,
221    PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
222{
223    let error = ErrorResponse::from(err_info);
224    client
225        .feed(PgWireBackendMessage::ErrorResponse(error))
226        .await?;
227    client.close().await?;
228    Ok(())
229}
230
231enum DbResolution {
232    Resolved(String, String),
233    NotFound(String),
234}
235
236/// A function extracted to resolve lifetime and readability issues:
237async fn resolve_db_info<C>(
238    client: Exclusive<&mut C>,
239    query_handler: ServerSqlQueryHandlerRef,
240) -> PgWireResult<DbResolution>
241where
242    C: ClientInfo + Unpin + Send,
243{
244    let db_ref = client.into_inner().metadata().get(super::METADATA_DATABASE);
245    if let Some(db) = db_ref {
246        let (catalog, schema) = parse_catalog_and_schema_from_db_string(db);
247        if query_handler
248            .is_valid_schema(&catalog, &schema)
249            .await
250            .map_err(|e| PgWireError::ApiError(Box::new(e)))?
251        {
252            Ok(DbResolution::Resolved(catalog, schema))
253        } else {
254            Ok(DbResolution::NotFound(format!("Database not found: {db}")))
255        }
256    } else {
257        Ok(DbResolution::NotFound("Database not specified".to_owned()))
258    }
259}