servers/postgres/
server.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::any::Any;
16use std::future::Future;
17use std::net::SocketAddr;
18use std::sync::Arc;
19
20use ::auth::UserProviderRef;
21use async_trait::async_trait;
22use catalog::process_manager::ProcessManagerRef;
23use common_runtime::Runtime;
24use common_runtime::runtime::RuntimeTrait;
25use common_telemetry::{debug, warn};
26use futures::StreamExt;
27use pgwire::tokio::process_socket;
28use tokio_rustls::TlsAcceptor;
29
30use crate::error::Result;
31use crate::postgres::{MakePostgresServerHandler, MakePostgresServerHandlerBuilder};
32use crate::query_handler::sql::ServerSqlQueryHandlerRef;
33use crate::server::{AbortableStream, BaseTcpServer, Server};
34use crate::tls::ReloadableTlsServerConfig;
35
/// Server that accepts TCP connections and serves each one through
/// `pgwire::tokio::process_socket`.
pub struct PostgresServer {
    /// Underlying TCP server used for binding, starting and shutting down.
    base_server: BaseTcpServer,
    /// Factory that builds one handler per accepted connection.
    make_handler: Arc<MakePostgresServerHandler>,
    /// Reloadable TLS configuration; read again for every accepted connection.
    tls_server_config: Arc<ReloadableTlsServerConfig>,
    /// Keep-alive value (in seconds, per the name) passed to `BaseTcpServer::bind`.
    keep_alive_secs: u64,
    /// Address returned by `bind`; `None` until `start` has succeeded.
    bind_addr: Option<SocketAddr>,
    /// Optional source of per-connection process ids; when absent, `0` is used.
    process_manager: Option<ProcessManagerRef>,
}
44
45impl PostgresServer {
46    /// Creates a new Postgres server with provided query_handler and async runtime
47    pub fn new(
48        query_handler: ServerSqlQueryHandlerRef,
49        force_tls: bool,
50        tls_server_config: Arc<ReloadableTlsServerConfig>,
51        keep_alive_secs: u64,
52        io_runtime: Runtime,
53        user_provider: Option<UserProviderRef>,
54        process_manager: Option<ProcessManagerRef>,
55    ) -> PostgresServer {
56        let make_handler = Arc::new(
57            MakePostgresServerHandlerBuilder::default()
58                .query_handler(query_handler.clone())
59                .user_provider(user_provider.clone())
60                .force_tls(force_tls)
61                .build()
62                .unwrap(),
63        );
64        PostgresServer {
65            base_server: BaseTcpServer::create_server("Postgres", io_runtime),
66            make_handler,
67            tls_server_config,
68            keep_alive_secs,
69            bind_addr: None,
70            process_manager,
71        }
72    }
73
74    fn accept(
75        &self,
76        io_runtime: Runtime,
77        accepting_stream: AbortableStream,
78    ) -> impl Future<Output = ()> + use<> {
79        let handler_maker = self.make_handler.clone();
80        let tls_server_config = self.tls_server_config.clone();
81        let process_manager = self.process_manager.clone();
82        accepting_stream.for_each(move |tcp_stream| {
83            let io_runtime = io_runtime.clone();
84            let tls_acceptor = tls_server_config.get_config().map(TlsAcceptor::from);
85            let handler_maker = handler_maker.clone();
86            let process_id = process_manager.as_ref().map(|p| p.next_id()).unwrap_or(0);
87
88            async move {
89                match tcp_stream {
90                    Err(error) => debug!("Broken pipe: {}", error), // IoError doesn't impl ErrorExt.
91                    Ok(io_stream) => {
92                        let addr = match io_stream.peer_addr() {
93                            Ok(addr) => {
94                                debug!("PostgreSQL client coming from {}", addr);
95                                Some(addr)
96                            }
97                            Err(e) => {
98                                warn!(e; "Failed to get PostgreSQL client addr");
99                                None
100                            }
101                        };
102
103                        let _handle = io_runtime.spawn(async move {
104                            crate::metrics::METRIC_POSTGRES_CONNECTIONS.inc();
105                            let pg_handler = Arc::new(handler_maker.make(addr, process_id));
106                            let r =
107                                process_socket(io_stream, tls_acceptor.clone(), pg_handler).await;
108                            crate::metrics::METRIC_POSTGRES_CONNECTIONS.dec();
109                            r
110                        });
111                    }
112                };
113            }
114        })
115    }
116}
117
/// Name reported by `PostgresServer` through its `Server::name` implementation.
pub const POSTGRES_SERVER: &str = "POSTGRES_SERVER";
119
120#[async_trait]
121impl Server for PostgresServer {
122    async fn shutdown(&self) -> Result<()> {
123        self.base_server.shutdown().await
124    }
125
126    async fn start(&mut self, listening: SocketAddr) -> Result<()> {
127        let (stream, addr) = self
128            .base_server
129            .bind(listening, self.keep_alive_secs)
130            .await?;
131
132        let io_runtime = self.base_server.io_runtime();
133        let join_handle = common_runtime::spawn_global(self.accept(io_runtime, stream));
134
135        self.base_server.start_with(join_handle).await?;
136
137        self.bind_addr = Some(addr);
138        Ok(())
139    }
140
141    fn name(&self) -> &str {
142        POSTGRES_SERVER
143    }
144
145    fn bind_addr(&self) -> Option<SocketAddr> {
146        self.bind_addr
147    }
148
149    fn as_any(&self) -> &dyn Any {
150        self
151    }
152}