servers/postgres/
auth_handler.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Debug;
16use std::sync::Exclusive;
17
18use ::auth::{Identity, Password, UserInfoRef, UserProviderRef, userinfo_by_name};
19use async_trait::async_trait;
20use common_catalog::parse_catalog_and_schema_from_db_string;
21use common_error::ext::ErrorExt;
22use common_time::Timezone;
23use futures::{Sink, SinkExt};
24use pgwire::api::auth::StartupHandler;
25use pgwire::api::{ClientInfo, PgWireConnectionState, auth};
26use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
27use pgwire::messages::response::ErrorResponse;
28use pgwire::messages::startup::{Authentication, SecretKey};
29use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage};
30use session::Session;
31use snafu::IntoError;
32
33use crate::error::{AuthSnafu, Result};
34use crate::metrics::METRIC_AUTH_FAILURE;
35use crate::postgres::PostgresServerHandlerInner;
36use crate::postgres::types::PgErrorCode;
37use crate::postgres::utils::convert_err;
38use crate::query_handler::sql::ServerSqlQueryHandlerRef;
39
40pub(crate) struct PgLoginVerifier {
41    user_provider: Option<UserProviderRef>,
42}
43
44impl PgLoginVerifier {
45    pub(crate) fn new(user_provider: Option<UserProviderRef>) -> Self {
46        Self { user_provider }
47    }
48}
49
50#[allow(dead_code)]
51struct LoginInfo {
52    user: Option<String>,
53    catalog: Option<String>,
54    schema: Option<String>,
55    host: String,
56}
57
58impl LoginInfo {
59    pub fn from_client_info<C>(client: &C) -> LoginInfo
60    where
61        C: ClientInfo,
62    {
63        LoginInfo {
64            user: client.metadata().get(super::METADATA_USER).map(Into::into),
65            catalog: client
66                .metadata()
67                .get(super::METADATA_CATALOG)
68                .map(Into::into),
69            schema: client
70                .metadata()
71                .get(super::METADATA_SCHEMA)
72                .map(Into::into),
73            host: client.socket_addr().ip().to_string(),
74        }
75    }
76}
77
78impl PgLoginVerifier {
79    async fn auth(&self, login: &LoginInfo, password: &str) -> Result<Option<UserInfoRef>> {
80        let user_provider = match &self.user_provider {
81            Some(provider) => provider,
82            None => return Ok(None),
83        };
84
85        let user_name = match &login.user {
86            Some(name) => name,
87            None => return Ok(None),
88        };
89        let catalog = match &login.catalog {
90            Some(name) => name,
91            None => return Ok(None),
92        };
93        let schema = match &login.schema {
94            Some(name) => name,
95            None => return Ok(None),
96        };
97
98        match user_provider
99            .auth(
100                Identity::UserId(user_name, None),
101                Password::PlainText(password.to_string().into()),
102                catalog,
103                schema,
104            )
105            .await
106        {
107            Err(e) => {
108                METRIC_AUTH_FAILURE
109                    .with_label_values(&[e.status_code().as_ref()])
110                    .inc();
111                Err(AuthSnafu.into_error(e))
112            }
113            Ok(user_info) => Ok(Some(user_info)),
114        }
115    }
116}
117
118fn set_client_info<C>(client: &mut C, session: &Session)
119where
120    C: ClientInfo,
121{
122    if let Some(current_catalog) = client.metadata().get(super::METADATA_CATALOG) {
123        session.set_catalog(current_catalog.clone());
124    }
125    if let Some(current_schema) = client.metadata().get(super::METADATA_SCHEMA) {
126        session.set_schema(current_schema.clone());
127    }
128
129    // pass generated process id and secret key to client, this information will
130    // be sent to postgres client for query cancellation.
131    // use all 0 before we actually supported query cancellation
132    client.set_pid_and_secret_key(0, SecretKey::I32(0));
133    // set userinfo outside
134}
135
136#[async_trait]
137impl StartupHandler for PostgresServerHandlerInner {
138    async fn on_startup<C>(
139        &self,
140        client: &mut C,
141        message: PgWireFrontendMessage,
142    ) -> PgWireResult<()>
143    where
144        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send,
145        C::Error: Debug,
146        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
147    {
148        match message {
149            PgWireFrontendMessage::Startup(ref startup) => {
150                // check ssl requirement
151                if !client.is_secure() && self.force_tls {
152                    send_error(
153                        client,
154                        PgErrorCode::Ec28000.to_err_info("No encryption".to_string()),
155                    )
156                    .await?;
157                    return Ok(());
158                }
159
160                auth::save_startup_parameters_to_metadata(client, startup);
161
162                // check if db is valid
163                match resolve_db_info(Exclusive::new(client), self.query_handler.clone()).await? {
164                    DbResolution::Resolved(catalog, schema) => {
165                        let metadata = client.metadata_mut();
166                        let _ = metadata.insert(super::METADATA_CATALOG.to_owned(), catalog);
167                        let _ = metadata.insert(super::METADATA_SCHEMA.to_owned(), schema);
168                    }
169                    DbResolution::NotFound(msg) => {
170                        send_error(client, PgErrorCode::Ec3D000.to_err_info(msg)).await?;
171                        return Ok(());
172                    }
173                }
174
175                // try to set TimeZone
176                if let Some(tz) = client.metadata().get("TimeZone") {
177                    match Timezone::from_tz_string(tz) {
178                        Ok(tz) => self.session.set_timezone(tz),
179                        Err(_) => {
180                            send_error(
181                                client,
182                                PgErrorCode::Ec22023
183                                    .to_err_info(format!("Invalid TimeZone: {}", tz)),
184                            )
185                            .await?;
186
187                            return Ok(());
188                        }
189                    }
190                }
191
192                if self.login_verifier.user_provider.is_some() {
193                    client.set_state(PgWireConnectionState::AuthenticationInProgress);
194                    client
195                        .send(PgWireBackendMessage::Authentication(
196                            Authentication::CleartextPassword,
197                        ))
198                        .await?;
199                } else {
200                    self.session.set_user_info(userinfo_by_name(
201                        client.metadata().get(super::METADATA_USER).cloned(),
202                    ));
203                    set_client_info(client, &self.session);
204                    auth::finish_authentication(client, self.param_provider.as_ref()).await?;
205                }
206            }
207            PgWireFrontendMessage::PasswordMessageFamily(pwd) => {
208                // the newer version of pgwire has a few variant password
209                // message like cleartext/md5 password, saslresponse, etc. Here
210                // we must manually coerce it into password
211                let pwd = pwd.into_password()?;
212
213                let login_info = LoginInfo::from_client_info(client);
214
215                // do authenticate
216                let auth_result = self.login_verifier.auth(&login_info, &pwd.password).await;
217
218                if let Ok(Some(user_info)) = auth_result {
219                    self.session.set_user_info(user_info);
220                    set_client_info(client, &self.session);
221                    auth::finish_authentication(client, self.param_provider.as_ref()).await?;
222                } else {
223                    return send_error(
224                        client,
225                        PgErrorCode::Ec28P01
226                            .to_err_info("password authentication failed".to_string()),
227                    )
228                    .await;
229                }
230            }
231            _ => {}
232        }
233        Ok(())
234    }
235}
236
237async fn send_error<C>(client: &mut C, err_info: ErrorInfo) -> PgWireResult<()>
238where
239    C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send,
240    C::Error: Debug,
241    PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
242{
243    let error = ErrorResponse::from(err_info);
244    client
245        .feed(PgWireBackendMessage::ErrorResponse(error))
246        .await?;
247    client.close().await?;
248    Ok(())
249}
250
251enum DbResolution {
252    Resolved(String, String),
253    NotFound(String),
254}
255
256/// A function extracted to resolve lifetime and readability issues:
257async fn resolve_db_info<C>(
258    client: Exclusive<&mut C>,
259    query_handler: ServerSqlQueryHandlerRef,
260) -> PgWireResult<DbResolution>
261where
262    C: ClientInfo + Unpin + Send,
263{
264    let db_ref = client.into_inner().metadata().get(super::METADATA_DATABASE);
265    if let Some(db) = db_ref {
266        let (catalog, schema) = parse_catalog_and_schema_from_db_string(db);
267        if query_handler
268            .is_valid_schema(&catalog, &schema)
269            .await
270            .map_err(convert_err)?
271        {
272            Ok(DbResolution::Resolved(catalog, schema))
273        } else {
274            Ok(DbResolution::NotFound(format!("Database not found: {db}")))
275        }
276    } else {
277        Ok(DbResolution::NotFound("Database not specified".to_owned()))
278    }
279}