servers/postgres/
auth_handler.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Debug;
16use std::sync::Exclusive;
17
18use ::auth::{userinfo_by_name, Identity, Password, UserInfoRef, UserProviderRef};
19use async_trait::async_trait;
20use common_catalog::parse_catalog_and_schema_from_db_string;
21use common_error::ext::ErrorExt;
22use futures::{Sink, SinkExt};
23use pgwire::api::auth::StartupHandler;
24use pgwire::api::{auth, ClientInfo, PgWireConnectionState};
25use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
26use pgwire::messages::response::ErrorResponse;
27use pgwire::messages::startup::{Authentication, SecretKey};
28use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage};
29use session::Session;
30use snafu::IntoError;
31
32use crate::error::{AuthSnafu, Result};
33use crate::metrics::METRIC_AUTH_FAILURE;
34use crate::postgres::types::PgErrorCode;
35use crate::postgres::utils::convert_err;
36use crate::postgres::PostgresServerHandlerInner;
37use crate::query_handler::sql::ServerSqlQueryHandlerRef;
38
39pub(crate) struct PgLoginVerifier {
40    user_provider: Option<UserProviderRef>,
41}
42
43impl PgLoginVerifier {
44    pub(crate) fn new(user_provider: Option<UserProviderRef>) -> Self {
45        Self { user_provider }
46    }
47}
48
49#[allow(dead_code)]
50struct LoginInfo {
51    user: Option<String>,
52    catalog: Option<String>,
53    schema: Option<String>,
54    host: String,
55}
56
57impl LoginInfo {
58    pub fn from_client_info<C>(client: &C) -> LoginInfo
59    where
60        C: ClientInfo,
61    {
62        LoginInfo {
63            user: client.metadata().get(super::METADATA_USER).map(Into::into),
64            catalog: client
65                .metadata()
66                .get(super::METADATA_CATALOG)
67                .map(Into::into),
68            schema: client
69                .metadata()
70                .get(super::METADATA_SCHEMA)
71                .map(Into::into),
72            host: client.socket_addr().ip().to_string(),
73        }
74    }
75}
76
77impl PgLoginVerifier {
78    async fn auth(&self, login: &LoginInfo, password: &str) -> Result<Option<UserInfoRef>> {
79        let user_provider = match &self.user_provider {
80            Some(provider) => provider,
81            None => return Ok(None),
82        };
83
84        let user_name = match &login.user {
85            Some(name) => name,
86            None => return Ok(None),
87        };
88        let catalog = match &login.catalog {
89            Some(name) => name,
90            None => return Ok(None),
91        };
92        let schema = match &login.schema {
93            Some(name) => name,
94            None => return Ok(None),
95        };
96
97        match user_provider
98            .auth(
99                Identity::UserId(user_name, None),
100                Password::PlainText(password.to_string().into()),
101                catalog,
102                schema,
103            )
104            .await
105        {
106            Err(e) => {
107                METRIC_AUTH_FAILURE
108                    .with_label_values(&[e.status_code().as_ref()])
109                    .inc();
110                Err(AuthSnafu.into_error(e))
111            }
112            Ok(user_info) => Ok(Some(user_info)),
113        }
114    }
115}
116
117fn set_client_info<C>(client: &mut C, session: &Session)
118where
119    C: ClientInfo,
120{
121    if let Some(current_catalog) = client.metadata().get(super::METADATA_CATALOG) {
122        session.set_catalog(current_catalog.clone());
123    }
124    if let Some(current_schema) = client.metadata().get(super::METADATA_SCHEMA) {
125        session.set_schema(current_schema.clone());
126    }
127
128    // pass generated process id and secret key to client, this information will
129    // be sent to postgres client for query cancellation.
130    // use all 0 before we actually supported query cancellation
131    client.set_pid_and_secret_key(0, SecretKey::I32(0));
132    // set userinfo outside
133}
134
135#[async_trait]
136impl StartupHandler for PostgresServerHandlerInner {
137    async fn on_startup<C>(
138        &self,
139        client: &mut C,
140        message: PgWireFrontendMessage,
141    ) -> PgWireResult<()>
142    where
143        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send,
144        C::Error: Debug,
145        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
146    {
147        match message {
148            PgWireFrontendMessage::Startup(ref startup) => {
149                // check ssl requirement
150                if !client.is_secure() && self.force_tls {
151                    send_error(
152                        client,
153                        PgErrorCode::Ec28000.to_err_info("No encryption".to_string()),
154                    )
155                    .await?;
156                    return Ok(());
157                }
158
159                auth::save_startup_parameters_to_metadata(client, startup);
160
161                // check if db is valid
162                match resolve_db_info(Exclusive::new(client), self.query_handler.clone()).await? {
163                    DbResolution::Resolved(catalog, schema) => {
164                        let metadata = client.metadata_mut();
165                        let _ = metadata.insert(super::METADATA_CATALOG.to_owned(), catalog);
166                        let _ = metadata.insert(super::METADATA_SCHEMA.to_owned(), schema);
167                    }
168                    DbResolution::NotFound(msg) => {
169                        send_error(client, PgErrorCode::Ec3D000.to_err_info(msg)).await?;
170                        return Ok(());
171                    }
172                }
173
174                if self.login_verifier.user_provider.is_some() {
175                    client.set_state(PgWireConnectionState::AuthenticationInProgress);
176                    client
177                        .send(PgWireBackendMessage::Authentication(
178                            Authentication::CleartextPassword,
179                        ))
180                        .await?;
181                } else {
182                    self.session.set_user_info(userinfo_by_name(
183                        client.metadata().get(super::METADATA_USER).cloned(),
184                    ));
185                    set_client_info(client, &self.session);
186                    auth::finish_authentication(client, self.param_provider.as_ref()).await?;
187                }
188            }
189            PgWireFrontendMessage::PasswordMessageFamily(pwd) => {
190                // the newer version of pgwire has a few variant password
191                // message like cleartext/md5 password, saslresponse, etc. Here
192                // we must manually coerce it into password
193                let pwd = pwd.into_password()?;
194
195                let login_info = LoginInfo::from_client_info(client);
196
197                // do authenticate
198                let auth_result = self.login_verifier.auth(&login_info, &pwd.password).await;
199
200                if let Ok(Some(user_info)) = auth_result {
201                    self.session.set_user_info(user_info);
202                    set_client_info(client, &self.session);
203                    auth::finish_authentication(client, self.param_provider.as_ref()).await?;
204                } else {
205                    return send_error(
206                        client,
207                        PgErrorCode::Ec28P01
208                            .to_err_info("password authentication failed".to_string()),
209                    )
210                    .await;
211                }
212            }
213            _ => {}
214        }
215        Ok(())
216    }
217}
218
219async fn send_error<C>(client: &mut C, err_info: ErrorInfo) -> PgWireResult<()>
220where
221    C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send,
222    C::Error: Debug,
223    PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
224{
225    let error = ErrorResponse::from(err_info);
226    client
227        .feed(PgWireBackendMessage::ErrorResponse(error))
228        .await?;
229    client.close().await?;
230    Ok(())
231}
232
233enum DbResolution {
234    Resolved(String, String),
235    NotFound(String),
236}
237
238/// A function extracted to resolve lifetime and readability issues:
239async fn resolve_db_info<C>(
240    client: Exclusive<&mut C>,
241    query_handler: ServerSqlQueryHandlerRef,
242) -> PgWireResult<DbResolution>
243where
244    C: ClientInfo + Unpin + Send,
245{
246    let db_ref = client.into_inner().metadata().get(super::METADATA_DATABASE);
247    if let Some(db) = db_ref {
248        let (catalog, schema) = parse_catalog_and_schema_from_db_string(db);
249        if query_handler
250            .is_valid_schema(&catalog, &schema)
251            .await
252            .map_err(convert_err)?
253        {
254            Ok(DbResolution::Resolved(catalog, schema))
255        } else {
256            Ok(DbResolution::NotFound(format!("Database not found: {db}")))
257        }
258    } else {
259        Ok(DbResolution::NotFound("Database not specified".to_owned()))
260    }
261}