servers/postgres/server.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::future::Future;
16use std::net::SocketAddr;
17use std::sync::Arc;
18
19use ::auth::UserProviderRef;
20use async_trait::async_trait;
21use catalog::process_manager::ProcessManagerRef;
22use common_runtime::runtime::RuntimeTrait;
23use common_runtime::Runtime;
24use common_telemetry::{debug, warn};
25use futures::StreamExt;
26use pgwire::tokio::process_socket;
27use tokio_rustls::TlsAcceptor;
28
29use crate::error::Result;
30use crate::postgres::{MakePostgresServerHandler, MakePostgresServerHandlerBuilder};
31use crate::query_handler::sql::ServerSqlQueryHandlerRef;
32use crate::server::{AbortableStream, BaseTcpServer, Server};
33use crate::tls::ReloadableTlsServerConfig;
34
/// PostgreSQL wire-protocol server: accepts TCP connections through a [`BaseTcpServer`]
/// and serves each one with a handler created by `make_handler`.
pub struct PostgresServer {
    /// Generic TCP server state used for binding, obtaining the IO runtime, start and shutdown.
    base_server: BaseTcpServer,
    /// Shared factory that creates one handler per accepted connection.
    make_handler: Arc<MakePostgresServerHandler>,
    /// TLS configuration; queried once per accepted connection. When it yields no config,
    /// no TLS acceptor is passed to `process_socket`.
    tls_server_config: Arc<ReloadableTlsServerConfig>,
    /// Keep-alive setting handed to `BaseTcpServer::bind` when the server starts.
    keep_alive_secs: u64,
    /// Address returned by `bind` in `start`; `None` until `start` succeeds.
    bind_addr: Option<SocketAddr>,
    /// Optional allocator of per-connection process ids; without it every id is 0.
    process_manager: Option<ProcessManagerRef>,
}
43
44impl PostgresServer {
45    /// Creates a new Postgres server with provided query_handler and async runtime
46    pub fn new(
47        query_handler: ServerSqlQueryHandlerRef,
48        force_tls: bool,
49        tls_server_config: Arc<ReloadableTlsServerConfig>,
50        keep_alive_secs: u64,
51        io_runtime: Runtime,
52        user_provider: Option<UserProviderRef>,
53        process_manager: Option<ProcessManagerRef>,
54    ) -> PostgresServer {
55        let make_handler = Arc::new(
56            MakePostgresServerHandlerBuilder::default()
57                .query_handler(query_handler.clone())
58                .user_provider(user_provider.clone())
59                .force_tls(force_tls)
60                .build()
61                .unwrap(),
62        );
63        PostgresServer {
64            base_server: BaseTcpServer::create_server("Postgres", io_runtime),
65            make_handler,
66            tls_server_config,
67            keep_alive_secs,
68            bind_addr: None,
69            process_manager,
70        }
71    }
72
73    fn accept(
74        &self,
75        io_runtime: Runtime,
76        accepting_stream: AbortableStream,
77    ) -> impl Future<Output = ()> {
78        let handler_maker = self.make_handler.clone();
79        let tls_server_config = self.tls_server_config.clone();
80        let process_manager = self.process_manager.clone();
81        accepting_stream.for_each(move |tcp_stream| {
82            let io_runtime = io_runtime.clone();
83            let tls_acceptor = tls_server_config.get_server_config().map(TlsAcceptor::from);
84            let handler_maker = handler_maker.clone();
85            let process_id = process_manager.as_ref().map(|p| p.next_id()).unwrap_or(0);
86
87            async move {
88                match tcp_stream {
89                    Err(error) => debug!("Broken pipe: {}", error), // IoError doesn't impl ErrorExt.
90                    Ok(io_stream) => {
91                        let addr = match io_stream.peer_addr() {
92                            Ok(addr) => {
93                                debug!("PostgreSQL client coming from {}", addr);
94                                Some(addr)
95                            }
96                            Err(e) => {
97                                warn!(e; "Failed to get PostgreSQL client addr");
98                                None
99                            }
100                        };
101
102                        let _handle = io_runtime.spawn(async move {
103                            crate::metrics::METRIC_POSTGRES_CONNECTIONS.inc();
104                            let pg_handler = Arc::new(handler_maker.make(addr, process_id));
105                            let r =
106                                process_socket(io_stream, tls_acceptor.clone(), pg_handler).await;
107                            crate::metrics::METRIC_POSTGRES_CONNECTIONS.dec();
108                            r
109                        });
110                    }
111                };
112            }
113        })
114    }
115}
116
117pub const POSTGRES_SERVER: &str = "POSTGRES_SERVER";
118
119#[async_trait]
120impl Server for PostgresServer {
121    async fn shutdown(&self) -> Result<()> {
122        self.base_server.shutdown().await
123    }
124
125    async fn start(&mut self, listening: SocketAddr) -> Result<()> {
126        let (stream, addr) = self
127            .base_server
128            .bind(listening, self.keep_alive_secs)
129            .await?;
130
131        let io_runtime = self.base_server.io_runtime();
132        let join_handle = common_runtime::spawn_global(self.accept(io_runtime, stream));
133
134        self.base_server.start_with(join_handle).await?;
135
136        self.bind_addr = Some(addr);
137        Ok(())
138    }
139
140    fn name(&self) -> &str {
141        POSTGRES_SERVER
142    }
143
144    fn bind_addr(&self) -> Option<SocketAddr> {
145        self.bind_addr
146    }
147}