servers/postgres/
server.rs1use std::future::Future;
16use std::net::SocketAddr;
17use std::sync::Arc;
18
19use ::auth::UserProviderRef;
20use async_trait::async_trait;
21use catalog::process_manager::ProcessManagerRef;
22use common_runtime::runtime::RuntimeTrait;
23use common_runtime::Runtime;
24use common_telemetry::{debug, warn};
25use futures::StreamExt;
26use pgwire::tokio::process_socket;
27use tokio_rustls::TlsAcceptor;
28
29use crate::error::Result;
30use crate::postgres::{MakePostgresServerHandler, MakePostgresServerHandlerBuilder};
31use crate::query_handler::sql::ServerSqlQueryHandlerRef;
32use crate::server::{AbortableStream, BaseTcpServer, Server};
33use crate::tls::ReloadableTlsServerConfig;
34
35pub struct PostgresServer {
36 base_server: BaseTcpServer,
37 make_handler: Arc<MakePostgresServerHandler>,
38 tls_server_config: Arc<ReloadableTlsServerConfig>,
39 keep_alive_secs: u64,
40 bind_addr: Option<SocketAddr>,
41 process_manager: Option<ProcessManagerRef>,
42}
43
44impl PostgresServer {
45 pub fn new(
47 query_handler: ServerSqlQueryHandlerRef,
48 force_tls: bool,
49 tls_server_config: Arc<ReloadableTlsServerConfig>,
50 keep_alive_secs: u64,
51 io_runtime: Runtime,
52 user_provider: Option<UserProviderRef>,
53 process_manager: Option<ProcessManagerRef>,
54 ) -> PostgresServer {
55 let make_handler = Arc::new(
56 MakePostgresServerHandlerBuilder::default()
57 .query_handler(query_handler.clone())
58 .user_provider(user_provider.clone())
59 .force_tls(force_tls)
60 .build()
61 .unwrap(),
62 );
63 PostgresServer {
64 base_server: BaseTcpServer::create_server("Postgres", io_runtime),
65 make_handler,
66 tls_server_config,
67 keep_alive_secs,
68 bind_addr: None,
69 process_manager,
70 }
71 }
72
73 fn accept(
74 &self,
75 io_runtime: Runtime,
76 accepting_stream: AbortableStream,
77 ) -> impl Future<Output = ()> {
78 let handler_maker = self.make_handler.clone();
79 let tls_server_config = self.tls_server_config.clone();
80 let process_manager = self.process_manager.clone();
81 accepting_stream.for_each(move |tcp_stream| {
82 let io_runtime = io_runtime.clone();
83 let tls_acceptor = tls_server_config.get_server_config().map(TlsAcceptor::from);
84 let handler_maker = handler_maker.clone();
85 let process_id = process_manager.as_ref().map(|p| p.next_id()).unwrap_or(0);
86
87 async move {
88 match tcp_stream {
89 Err(error) => debug!("Broken pipe: {}", error), Ok(io_stream) => {
91 let addr = match io_stream.peer_addr() {
92 Ok(addr) => {
93 debug!("PostgreSQL client coming from {}", addr);
94 Some(addr)
95 }
96 Err(e) => {
97 warn!(e; "Failed to get PostgreSQL client addr");
98 None
99 }
100 };
101
102 let _handle = io_runtime.spawn(async move {
103 crate::metrics::METRIC_POSTGRES_CONNECTIONS.inc();
104 let pg_handler = Arc::new(handler_maker.make(addr, process_id));
105 let r =
106 process_socket(io_stream, tls_acceptor.clone(), pg_handler).await;
107 crate::metrics::METRIC_POSTGRES_CONNECTIONS.dec();
108 r
109 });
110 }
111 };
112 }
113 })
114 }
115}
116
/// Name under which the Postgres server reports itself through `Server::name`.
pub const POSTGRES_SERVER: &str = "POSTGRES_SERVER";
118
119#[async_trait]
120impl Server for PostgresServer {
121 async fn shutdown(&self) -> Result<()> {
122 self.base_server.shutdown().await
123 }
124
125 async fn start(&mut self, listening: SocketAddr) -> Result<()> {
126 let (stream, addr) = self
127 .base_server
128 .bind(listening, self.keep_alive_secs)
129 .await?;
130
131 let io_runtime = self.base_server.io_runtime();
132 let join_handle = common_runtime::spawn_global(self.accept(io_runtime, stream));
133
134 self.base_server.start_with(join_handle).await?;
135
136 self.bind_addr = Some(addr);
137 Ok(())
138 }
139
140 fn name(&self) -> &str {
141 POSTGRES_SERVER
142 }
143
144 fn bind_addr(&self) -> Option<SocketAddr> {
145 self.bind_addr
146 }
147}