servers/postgres/
server.rs1use std::future::Future;
16use std::net::SocketAddr;
17use std::sync::Arc;
18
19use ::auth::UserProviderRef;
20use async_trait::async_trait;
21use common_runtime::runtime::RuntimeTrait;
22use common_runtime::Runtime;
23use common_telemetry::{debug, warn};
24use futures::StreamExt;
25use pgwire::tokio::process_socket;
26use tokio_rustls::TlsAcceptor;
27
28use crate::error::Result;
29use crate::postgres::{MakePostgresServerHandler, MakePostgresServerHandlerBuilder};
30use crate::query_handler::sql::ServerSqlQueryHandlerRef;
31use crate::server::{AbortableStream, BaseTcpServer, Server};
32use crate::tls::ReloadableTlsServerConfig;
33
34pub struct PostgresServer {
35 base_server: BaseTcpServer,
36 make_handler: Arc<MakePostgresServerHandler>,
37 tls_server_config: Arc<ReloadableTlsServerConfig>,
38 keep_alive_secs: u64,
39 bind_addr: Option<SocketAddr>,
40}
41
42impl PostgresServer {
43 pub fn new(
45 query_handler: ServerSqlQueryHandlerRef,
46 force_tls: bool,
47 tls_server_config: Arc<ReloadableTlsServerConfig>,
48 keep_alive_secs: u64,
49 io_runtime: Runtime,
50 user_provider: Option<UserProviderRef>,
51 ) -> PostgresServer {
52 let make_handler = Arc::new(
53 MakePostgresServerHandlerBuilder::default()
54 .query_handler(query_handler.clone())
55 .user_provider(user_provider.clone())
56 .force_tls(force_tls)
57 .build()
58 .unwrap(),
59 );
60 PostgresServer {
61 base_server: BaseTcpServer::create_server("Postgres", io_runtime),
62 make_handler,
63 tls_server_config,
64 keep_alive_secs,
65 bind_addr: None,
66 }
67 }
68
69 fn accept(
70 &self,
71 io_runtime: Runtime,
72 accepting_stream: AbortableStream,
73 ) -> impl Future<Output = ()> {
74 let handler_maker = self.make_handler.clone();
75 let tls_server_config = self.tls_server_config.clone();
76 accepting_stream.for_each(move |tcp_stream| {
77 let io_runtime = io_runtime.clone();
78
79 let tls_acceptor = tls_server_config.get_server_config().map(TlsAcceptor::from);
80
81 let handler_maker = handler_maker.clone();
82
83 async move {
84 match tcp_stream {
85 Err(error) => debug!("Broken pipe: {}", error), Ok(io_stream) => {
87 let addr = match io_stream.peer_addr() {
88 Ok(addr) => {
89 debug!("PostgreSQL client coming from {}", addr);
90 Some(addr)
91 }
92 Err(e) => {
93 warn!(e; "Failed to get PostgreSQL client addr");
94 None
95 }
96 };
97
98 let _handle = io_runtime.spawn(async move {
99 crate::metrics::METRIC_POSTGRES_CONNECTIONS.inc();
100 let pg_handler = Arc::new(handler_maker.make(addr));
101 let r =
102 process_socket(io_stream, tls_acceptor.clone(), pg_handler).await;
103 crate::metrics::METRIC_POSTGRES_CONNECTIONS.dec();
104 r
105 });
106 }
107 };
108 }
109 })
110 }
111}
112
/// Identifier returned by `Server::name` for [`PostgresServer`].
pub const POSTGRES_SERVER: &str = "POSTGRES_SERVER";
114
115#[async_trait]
116impl Server for PostgresServer {
117 async fn shutdown(&self) -> Result<()> {
118 self.base_server.shutdown().await
119 }
120
121 async fn start(&mut self, listening: SocketAddr) -> Result<()> {
122 let (stream, addr) = self
123 .base_server
124 .bind(listening, self.keep_alive_secs)
125 .await?;
126
127 let io_runtime = self.base_server.io_runtime();
128 let join_handle = common_runtime::spawn_global(self.accept(io_runtime, stream));
129
130 self.base_server.start_with(join_handle).await?;
131
132 self.bind_addr = Some(addr);
133 Ok(())
134 }
135
136 fn name(&self) -> &str {
137 POSTGRES_SERVER
138 }
139
140 fn bind_addr(&self) -> Option<SocketAddr> {
141 self.bind_addr
142 }
143}