servers/postgres/
server.rs1use std::any::Any;
16use std::future::Future;
17use std::net::SocketAddr;
18use std::sync::Arc;
19
20use ::auth::UserProviderRef;
21use async_trait::async_trait;
22use catalog::process_manager::ProcessManagerRef;
23use common_runtime::Runtime;
24use common_runtime::runtime::RuntimeTrait;
25use common_telemetry::{debug, warn};
26use futures::StreamExt;
27use pgwire::tokio::process_socket;
28use tokio_rustls::TlsAcceptor;
29
30use crate::error::Result;
31use crate::postgres::{MakePostgresServerHandler, MakePostgresServerHandlerBuilder};
32use crate::query_handler::sql::ServerSqlQueryHandlerRef;
33use crate::server::{AbortableStream, BaseTcpServer, Server};
34use crate::tls::ReloadableTlsServerConfig;
35
36pub struct PostgresServer {
37 base_server: BaseTcpServer,
38 make_handler: Arc<MakePostgresServerHandler>,
39 tls_server_config: Arc<ReloadableTlsServerConfig>,
40 keep_alive_secs: u64,
41 bind_addr: Option<SocketAddr>,
42 process_manager: Option<ProcessManagerRef>,
43}
44
45impl PostgresServer {
46 pub fn new(
48 query_handler: ServerSqlQueryHandlerRef,
49 force_tls: bool,
50 tls_server_config: Arc<ReloadableTlsServerConfig>,
51 keep_alive_secs: u64,
52 io_runtime: Runtime,
53 user_provider: Option<UserProviderRef>,
54 process_manager: Option<ProcessManagerRef>,
55 ) -> PostgresServer {
56 let make_handler = Arc::new(
57 MakePostgresServerHandlerBuilder::default()
58 .query_handler(query_handler.clone())
59 .user_provider(user_provider.clone())
60 .force_tls(force_tls)
61 .build()
62 .unwrap(),
63 );
64 PostgresServer {
65 base_server: BaseTcpServer::create_server("Postgres", io_runtime),
66 make_handler,
67 tls_server_config,
68 keep_alive_secs,
69 bind_addr: None,
70 process_manager,
71 }
72 }
73
74 fn accept(
75 &self,
76 io_runtime: Runtime,
77 accepting_stream: AbortableStream,
78 ) -> impl Future<Output = ()> + use<> {
79 let handler_maker = self.make_handler.clone();
80 let tls_server_config = self.tls_server_config.clone();
81 let process_manager = self.process_manager.clone();
82 accepting_stream.for_each(move |tcp_stream| {
83 let io_runtime = io_runtime.clone();
84 let tls_acceptor = tls_server_config.get_config().map(TlsAcceptor::from);
85 let handler_maker = handler_maker.clone();
86 let process_id = process_manager.as_ref().map(|p| p.next_id()).unwrap_or(0);
87
88 async move {
89 match tcp_stream {
90 Err(error) => debug!("Broken pipe: {}", error), Ok(io_stream) => {
92 let addr = match io_stream.peer_addr() {
93 Ok(addr) => {
94 debug!("PostgreSQL client coming from {}", addr);
95 Some(addr)
96 }
97 Err(e) => {
98 warn!(e; "Failed to get PostgreSQL client addr");
99 None
100 }
101 };
102
103 let _handle = io_runtime.spawn(async move {
104 crate::metrics::METRIC_POSTGRES_CONNECTIONS.inc();
105 let pg_handler = Arc::new(handler_maker.make(addr, process_id));
106 let r =
107 process_socket(io_stream, tls_acceptor.clone(), pg_handler).await;
108 crate::metrics::METRIC_POSTGRES_CONNECTIONS.dec();
109 r
110 });
111 }
112 };
113 }
114 })
115 }
116}
117
/// Name under which the Postgres server identifies itself; returned by its `Server::name`.
pub const POSTGRES_SERVER: &str = "POSTGRES_SERVER";
119
120#[async_trait]
121impl Server for PostgresServer {
122 async fn shutdown(&self) -> Result<()> {
123 self.base_server.shutdown().await
124 }
125
126 async fn start(&mut self, listening: SocketAddr) -> Result<()> {
127 let (stream, addr) = self
128 .base_server
129 .bind(listening, self.keep_alive_secs)
130 .await?;
131
132 let io_runtime = self.base_server.io_runtime();
133 let join_handle = common_runtime::spawn_global(self.accept(io_runtime, stream));
134
135 self.base_server.start_with(join_handle).await?;
136
137 self.bind_addr = Some(addr);
138 Ok(())
139 }
140
141 fn name(&self) -> &str {
142 POSTGRES_SERVER
143 }
144
145 fn bind_addr(&self) -> Option<SocketAddr> {
146 self.bind_addr
147 }
148
149 fn as_any(&self) -> &dyn Any {
150 self
151 }
152}