1use std::collections::HashMap;
16use std::fmt::{Debug, Formatter};
17use std::net::SocketAddr;
18use std::sync::Arc;
19
20use async_trait::async_trait;
21use common_runtime::Runtime;
22use common_telemetry::{error, info};
23use futures::future::{try_join_all, AbortHandle, AbortRegistration, Abortable};
24use snafu::{ensure, ResultExt};
25use strum::Display;
26use tokio::sync::Mutex;
27use tokio::task::JoinHandle;
28use tokio_stream::wrappers::TcpListenerStream;
29
30use crate::error::{self, Result};
31
/// A TCP listener stream wrapped so that it can be stopped through an `AbortHandle`.
pub(crate) type AbortableStream = Abortable<TcpListenerStream>;

/// A server paired with the socket address it is started on by `ServerHandlers::start_all`.
pub type ServerHandler = (Box<dyn Server>, SocketAddr);
35
36#[derive(Clone, Display)]
38pub enum ServerHandlers {
39 Init(Arc<std::sync::Mutex<HashMap<String, ServerHandler>>>),
40 Started(Arc<HashMap<String, Box<dyn Server>>>),
41}
42
43impl Debug for ServerHandlers {
44 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
45 write!(f, "ServerHandlers::{}", self)
46 }
47}
48
49impl Default for ServerHandlers {
50 fn default() -> Self {
51 Self::Init(Arc::new(std::sync::Mutex::new(HashMap::new())))
52 }
53}
54
55impl ServerHandlers {
56 pub fn insert(&self, handler: ServerHandler) {
58 assert!(
61 matches!(self, ServerHandlers::Init(_)),
62 "unexpected: insert when `ServerHandlers` is not during initialization"
63 );
64 let ServerHandlers::Init(handlers) = self else {
65 unreachable!("guarded by the assertion above");
66 };
67 let mut handlers = handlers.lock().unwrap();
68 handlers.insert(handler.0.name().to_string(), handler);
69 }
70
71 pub fn addr(&self, name: &str) -> Option<SocketAddr> {
81 let ServerHandlers::Started(handlers) = self else {
82 return None;
83 };
84 handlers.get(name).and_then(|x| x.bind_addr())
85 }
86
87 pub async fn start_all(&mut self) -> Result<()> {
90 let ServerHandlers::Init(handlers) = self else {
91 return Ok(());
93 };
94
95 let mut handlers = {
96 let mut handlers = handlers.lock().unwrap();
97 std::mem::take(&mut *handlers)
98 };
99
100 try_join_all(handlers.values_mut().map(|(server, addr)| async move {
101 server.start(*addr).await?;
102 info!("Server {} is started", server.name());
103 Ok::<(), error::Error>(())
104 }))
105 .await?;
106
107 let handlers = handlers
108 .into_iter()
109 .map(|(k, v)| (k, v.0))
110 .collect::<HashMap<_, _>>();
111 *self = ServerHandlers::Started(Arc::new(handlers));
112 Ok(())
113 }
114
115 pub async fn shutdown_all(&mut self) -> Result<()> {
117 let ServerHandlers::Started(handlers) = self else {
118 return Ok(());
120 };
121
122 let handlers = std::mem::take(handlers);
123 try_join_all(handlers.values().map(|server| async move {
124 server.shutdown().await?;
125 info!("Service {} is shutdown!", server.name());
126 Ok::<(), error::Error>(())
127 }))
128 .await?;
129 Ok(())
130 }
131}
132
/// A server whose lifecycle is driven by `ServerHandlers`.
#[async_trait]
pub trait Server: Send + Sync {
    /// Shuts the server down.
    async fn shutdown(&self) -> Result<()>;

    /// Starts the server on `listening`.
    async fn start(&mut self, listening: SocketAddr) -> Result<()>;

    /// Returns the server's name; `ServerHandlers::insert` registers servers under it.
    fn name(&self) -> &str;

    /// Returns the address the server reports as bound, if any.
    ///
    /// The default implementation returns `None`.
    fn bind_addr(&self) -> Option<SocketAddr> {
        None
    }
}
151
152struct AcceptTask {
153 abort_handle: AbortHandle,
158 abort_registration: Option<AbortRegistration>,
159
160 join_handle: Option<JoinHandle<()>>,
162}
163
164impl AcceptTask {
165 async fn shutdown(&mut self, name: &str) -> Result<()> {
166 match self.join_handle.take() {
167 Some(join_handle) => {
168 self.abort_handle.abort();
169
170 if let Err(error) = join_handle.await {
171 error!(
173 "Unexpected error during shutdown {} server, error: {:?}",
174 name, error
175 );
176 } else {
177 info!("{name} server is shutdown.");
178 }
179 Ok(())
180 }
181 None => error::InternalSnafu {
182 err_msg: format!("{name} server is not started."),
183 }
184 .fail()?,
185 }
186 }
187
188 async fn bind(
189 &mut self,
190 addr: SocketAddr,
191 name: &str,
192 keep_alive_secs: u64,
193 ) -> Result<(Abortable<TcpListenerStream>, SocketAddr)> {
194 match self.abort_registration.take() {
195 Some(registration) => {
196 let listener =
197 tokio::net::TcpListener::bind(addr)
198 .await
199 .context(error::TokioIoSnafu {
200 err_msg: format!("{name} failed to bind addr {addr}"),
201 })?;
202 let addr = listener.local_addr()?;
204 info!("{name} server started at {addr}");
205
206 if keep_alive_secs > 0 {
208 let socket_ref = socket2::SockRef::from(&listener);
209 let keep_alive = socket2::TcpKeepalive::new()
210 .with_time(std::time::Duration::from_secs(keep_alive_secs))
211 .with_interval(std::time::Duration::from_secs(keep_alive_secs));
212 socket_ref.set_tcp_keepalive(&keep_alive)?;
213 }
214
215 let stream = TcpListenerStream::new(listener);
216 let stream = Abortable::new(stream, registration);
217 Ok((stream, addr))
218 }
219 None => error::InternalSnafu {
220 err_msg: format!("{name} server has been started."),
221 }
222 .fail()?,
223 }
224 }
225
226 fn start_with(&mut self, join_handle: JoinHandle<()>, name: &str) -> Result<()> {
227 ensure!(
228 self.join_handle.is_none(),
229 error::InternalSnafu {
230 err_msg: format!("{name} server has been started."),
231 }
232 );
233 let _handle = self.join_handle.get_or_insert(join_handle);
234 Ok(())
235 }
236}
237
238pub(crate) struct BaseTcpServer {
239 name: String,
240 accept_task: Mutex<AcceptTask>,
241 io_runtime: Runtime,
242}
243
244impl BaseTcpServer {
245 pub(crate) fn create_server(name: impl Into<String>, io_runtime: Runtime) -> Self {
246 let (abort_handle, registration) = AbortHandle::new_pair();
247 Self {
248 name: name.into(),
249 accept_task: Mutex::new(AcceptTask {
250 abort_handle,
251 abort_registration: Some(registration),
252 join_handle: None,
253 }),
254 io_runtime,
255 }
256 }
257
258 pub(crate) async fn shutdown(&self) -> Result<()> {
259 let mut task = self.accept_task.lock().await;
260 task.shutdown(&self.name).await
261 }
262
263 pub(crate) async fn bind(
267 &self,
268 addr: SocketAddr,
269 keep_alive_secs: u64,
270 ) -> Result<(Abortable<TcpListenerStream>, SocketAddr)> {
271 let mut task = self.accept_task.lock().await;
272 task.bind(addr, &self.name, keep_alive_secs).await
273 }
274
275 pub(crate) async fn start_with(&self, join_handle: JoinHandle<()>) -> Result<()> {
276 let mut task = self.accept_task.lock().await;
277 task.start_with(join_handle, &self.name)
278 }
279
280 pub(crate) fn io_runtime(&self) -> Runtime {
281 self.io_runtime.clone()
282 }
283}