servers/postgres/
server.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::future::Future;
16use std::net::SocketAddr;
17use std::sync::Arc;
18
19use ::auth::UserProviderRef;
20use async_trait::async_trait;
21use common_runtime::runtime::RuntimeTrait;
22use common_runtime::Runtime;
23use common_telemetry::{debug, warn};
24use futures::StreamExt;
25use pgwire::tokio::process_socket;
26use tokio_rustls::TlsAcceptor;
27
28use crate::error::Result;
29use crate::postgres::{MakePostgresServerHandler, MakePostgresServerHandlerBuilder};
30use crate::query_handler::sql::ServerSqlQueryHandlerRef;
31use crate::server::{AbortableStream, BaseTcpServer, Server};
32use crate::tls::ReloadableTlsServerConfig;
33
34pub struct PostgresServer {
35    base_server: BaseTcpServer,
36    make_handler: Arc<MakePostgresServerHandler>,
37    tls_server_config: Arc<ReloadableTlsServerConfig>,
38    keep_alive_secs: u64,
39    bind_addr: Option<SocketAddr>,
40}
41
42impl PostgresServer {
43    /// Creates a new Postgres server with provided query_handler and async runtime
44    pub fn new(
45        query_handler: ServerSqlQueryHandlerRef,
46        force_tls: bool,
47        tls_server_config: Arc<ReloadableTlsServerConfig>,
48        keep_alive_secs: u64,
49        io_runtime: Runtime,
50        user_provider: Option<UserProviderRef>,
51    ) -> PostgresServer {
52        let make_handler = Arc::new(
53            MakePostgresServerHandlerBuilder::default()
54                .query_handler(query_handler.clone())
55                .user_provider(user_provider.clone())
56                .force_tls(force_tls)
57                .build()
58                .unwrap(),
59        );
60        PostgresServer {
61            base_server: BaseTcpServer::create_server("Postgres", io_runtime),
62            make_handler,
63            tls_server_config,
64            keep_alive_secs,
65            bind_addr: None,
66        }
67    }
68
69    fn accept(
70        &self,
71        io_runtime: Runtime,
72        accepting_stream: AbortableStream,
73    ) -> impl Future<Output = ()> {
74        let handler_maker = self.make_handler.clone();
75        let tls_server_config = self.tls_server_config.clone();
76        accepting_stream.for_each(move |tcp_stream| {
77            let io_runtime = io_runtime.clone();
78
79            let tls_acceptor = tls_server_config.get_server_config().map(TlsAcceptor::from);
80
81            let handler_maker = handler_maker.clone();
82
83            async move {
84                match tcp_stream {
85                    Err(error) => debug!("Broken pipe: {}", error), // IoError doesn't impl ErrorExt.
86                    Ok(io_stream) => {
87                        let addr = match io_stream.peer_addr() {
88                            Ok(addr) => {
89                                debug!("PostgreSQL client coming from {}", addr);
90                                Some(addr)
91                            }
92                            Err(e) => {
93                                warn!(e; "Failed to get PostgreSQL client addr");
94                                None
95                            }
96                        };
97
98                        let _handle = io_runtime.spawn(async move {
99                            crate::metrics::METRIC_POSTGRES_CONNECTIONS.inc();
100                            let pg_handler = Arc::new(handler_maker.make(addr));
101                            let r =
102                                process_socket(io_stream, tls_acceptor.clone(), pg_handler).await;
103                            crate::metrics::METRIC_POSTGRES_CONNECTIONS.dec();
104                            r
105                        });
106                    }
107                };
108            }
109        })
110    }
111}
112
/// Name reported by [`PostgresServer`] through `Server::name`.
pub const POSTGRES_SERVER: &str = "POSTGRES_SERVER";
114
115#[async_trait]
116impl Server for PostgresServer {
117    async fn shutdown(&self) -> Result<()> {
118        self.base_server.shutdown().await
119    }
120
121    async fn start(&mut self, listening: SocketAddr) -> Result<()> {
122        let (stream, addr) = self
123            .base_server
124            .bind(listening, self.keep_alive_secs)
125            .await?;
126
127        let io_runtime = self.base_server.io_runtime();
128        let join_handle = common_runtime::spawn_global(self.accept(io_runtime, stream));
129
130        self.base_server.start_with(join_handle).await?;
131
132        self.bind_addr = Some(addr);
133        Ok(())
134    }
135
136    fn name(&self) -> &str {
137        POSTGRES_SERVER
138    }
139
140    fn bind_addr(&self) -> Option<SocketAddr> {
141        self.bind_addr
142    }
143}