servers/postgres/
auth_handler.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Debug;
16use std::sync::Exclusive;
17
18use ::auth::{userinfo_by_name, Identity, Password, UserInfoRef, UserProviderRef};
19use async_trait::async_trait;
20use common_catalog::parse_catalog_and_schema_from_db_string;
21use common_error::ext::ErrorExt;
22use futures::{Sink, SinkExt};
23use pgwire::api::auth::StartupHandler;
24use pgwire::api::{auth, ClientInfo, PgWireConnectionState};
25use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
26use pgwire::messages::response::ErrorResponse;
27use pgwire::messages::startup::Authentication;
28use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage};
29use session::Session;
30use snafu::IntoError;
31
32use crate::error::{AuthSnafu, Result};
33use crate::metrics::METRIC_AUTH_FAILURE;
34use crate::postgres::types::PgErrorCode;
35use crate::postgres::PostgresServerHandlerInner;
36use crate::query_handler::sql::ServerSqlQueryHandlerRef;
37
38pub(crate) struct PgLoginVerifier {
39    user_provider: Option<UserProviderRef>,
40}
41
42impl PgLoginVerifier {
43    pub(crate) fn new(user_provider: Option<UserProviderRef>) -> Self {
44        Self { user_provider }
45    }
46}
47
48#[allow(dead_code)]
49struct LoginInfo {
50    user: Option<String>,
51    catalog: Option<String>,
52    schema: Option<String>,
53    host: String,
54}
55
56impl LoginInfo {
57    pub fn from_client_info<C>(client: &C) -> LoginInfo
58    where
59        C: ClientInfo,
60    {
61        LoginInfo {
62            user: client.metadata().get(super::METADATA_USER).map(Into::into),
63            catalog: client
64                .metadata()
65                .get(super::METADATA_CATALOG)
66                .map(Into::into),
67            schema: client
68                .metadata()
69                .get(super::METADATA_SCHEMA)
70                .map(Into::into),
71            host: client.socket_addr().ip().to_string(),
72        }
73    }
74}
75
76impl PgLoginVerifier {
77    async fn auth(&self, login: &LoginInfo, password: &str) -> Result<Option<UserInfoRef>> {
78        let user_provider = match &self.user_provider {
79            Some(provider) => provider,
80            None => return Ok(None),
81        };
82
83        let user_name = match &login.user {
84            Some(name) => name,
85            None => return Ok(None),
86        };
87        let catalog = match &login.catalog {
88            Some(name) => name,
89            None => return Ok(None),
90        };
91        let schema = match &login.schema {
92            Some(name) => name,
93            None => return Ok(None),
94        };
95
96        match user_provider
97            .auth(
98                Identity::UserId(user_name, None),
99                Password::PlainText(password.to_string().into()),
100                catalog,
101                schema,
102            )
103            .await
104        {
105            Err(e) => {
106                METRIC_AUTH_FAILURE
107                    .with_label_values(&[e.status_code().as_ref()])
108                    .inc();
109                Err(AuthSnafu.into_error(e))
110            }
111            Ok(user_info) => Ok(Some(user_info)),
112        }
113    }
114}
115
116fn set_client_info<C>(client: &C, session: &Session)
117where
118    C: ClientInfo,
119{
120    if let Some(current_catalog) = client.metadata().get(super::METADATA_CATALOG) {
121        session.set_catalog(current_catalog.clone());
122    }
123    if let Some(current_schema) = client.metadata().get(super::METADATA_SCHEMA) {
124        session.set_schema(current_schema.clone());
125    }
126    // set userinfo outside
127}
128
129#[async_trait]
130impl StartupHandler for PostgresServerHandlerInner {
131    async fn on_startup<C>(
132        &self,
133        client: &mut C,
134        message: PgWireFrontendMessage,
135    ) -> PgWireResult<()>
136    where
137        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send,
138        C::Error: Debug,
139        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
140    {
141        match message {
142            PgWireFrontendMessage::Startup(ref startup) => {
143                // check ssl requirement
144                if !client.is_secure() && self.force_tls {
145                    send_error(
146                        client,
147                        PgErrorCode::Ec28000.to_err_info("No encryption".to_string()),
148                    )
149                    .await?;
150                    return Ok(());
151                }
152
153                auth::save_startup_parameters_to_metadata(client, startup);
154
155                // check if db is valid
156                match resolve_db_info(Exclusive::new(client), self.query_handler.clone()).await? {
157                    DbResolution::Resolved(catalog, schema) => {
158                        let metadata = client.metadata_mut();
159                        let _ = metadata.insert(super::METADATA_CATALOG.to_owned(), catalog);
160                        let _ = metadata.insert(super::METADATA_SCHEMA.to_owned(), schema);
161                    }
162                    DbResolution::NotFound(msg) => {
163                        send_error(client, PgErrorCode::Ec3D000.to_err_info(msg)).await?;
164                        return Ok(());
165                    }
166                }
167
168                if self.login_verifier.user_provider.is_some() {
169                    client.set_state(PgWireConnectionState::AuthenticationInProgress);
170                    client
171                        .send(PgWireBackendMessage::Authentication(
172                            Authentication::CleartextPassword,
173                        ))
174                        .await?;
175                } else {
176                    self.session.set_user_info(userinfo_by_name(
177                        client.metadata().get(super::METADATA_USER).cloned(),
178                    ));
179                    set_client_info(client, &self.session);
180                    auth::finish_authentication(client, self.param_provider.as_ref()).await?;
181                }
182            }
183            PgWireFrontendMessage::PasswordMessageFamily(pwd) => {
184                // the newer version of pgwire has a few variant password
185                // message like cleartext/md5 password, saslresponse, etc. Here
186                // we must manually coerce it into password
187                let pwd = pwd.into_password()?;
188
189                let login_info = LoginInfo::from_client_info(client);
190
191                // do authenticate
192                let auth_result = self.login_verifier.auth(&login_info, &pwd.password).await;
193
194                if let Ok(Some(user_info)) = auth_result {
195                    self.session.set_user_info(user_info);
196                    set_client_info(client, &self.session);
197                    auth::finish_authentication(client, self.param_provider.as_ref()).await?;
198                } else {
199                    return send_error(
200                        client,
201                        PgErrorCode::Ec28P01
202                            .to_err_info("password authentication failed".to_string()),
203                    )
204                    .await;
205                }
206            }
207            _ => {}
208        }
209        Ok(())
210    }
211}
212
213async fn send_error<C>(client: &mut C, err_info: ErrorInfo) -> PgWireResult<()>
214where
215    C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send,
216    C::Error: Debug,
217    PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
218{
219    let error = ErrorResponse::from(err_info);
220    client
221        .feed(PgWireBackendMessage::ErrorResponse(error))
222        .await?;
223    client.close().await?;
224    Ok(())
225}
226
227enum DbResolution {
228    Resolved(String, String),
229    NotFound(String),
230}
231
232/// A function extracted to resolve lifetime and readability issues:
233async fn resolve_db_info<C>(
234    client: Exclusive<&mut C>,
235    query_handler: ServerSqlQueryHandlerRef,
236) -> PgWireResult<DbResolution>
237where
238    C: ClientInfo + Unpin + Send,
239{
240    let db_ref = client.into_inner().metadata().get(super::METADATA_DATABASE);
241    if let Some(db) = db_ref {
242        let (catalog, schema) = parse_catalog_and_schema_from_db_string(db);
243        if query_handler
244            .is_valid_schema(&catalog, &schema)
245            .await
246            .map_err(|e| PgWireError::ApiError(Box::new(e)))?
247        {
248            Ok(DbResolution::Resolved(catalog, schema))
249        } else {
250            Ok(DbResolution::NotFound(format!("Database not found: {db}")))
251        }
252    } else {
253        Ok(DbResolution::NotFound("Database not specified".to_owned()))
254    }
255}