1use std::any::Any;
16use std::collections::HashMap;
17use std::fmt::{Debug, Formatter};
18use std::net::SocketAddr;
19use std::sync::Arc;
20
21use async_trait::async_trait;
22use common_runtime::Runtime;
23use common_telemetry::{error, info};
24use futures::future::{AbortHandle, AbortRegistration, Abortable, try_join_all};
25use snafu::{ResultExt, ensure};
26use strum::Display;
27use tokio::sync::Mutex;
28use tokio::task::JoinHandle;
29use tokio_stream::wrappers::TcpListenerStream;
30
31use crate::error::{self, Result};
32
/// A [`TcpListenerStream`] wrapped in [`Abortable`], so it can be stopped through
/// the [`AbortHandle`] paired with its registration.
pub(crate) type AbortableStream = Abortable<TcpListenerStream>;

/// A server together with the socket address that [`ServerHandlers::start_all`]
/// passes to [`Server::start`].
pub type ServerHandler = (Box<dyn Server>, SocketAddr);
36
/// Registry of servers, tracked through a two-phase lifecycle.
///
/// Servers are registered with [`ServerHandlers::insert`] while in the `Init`
/// state. [`ServerHandlers::start_all`] starts them and switches to `Started`.
/// Both variants wrap an `Arc`, so clones share the same underlying map.
#[derive(Clone, Display)]
pub enum ServerHandlers {
    /// Servers registered but not yet started, keyed by [`Server::name`].
    Init(Arc<std::sync::Mutex<HashMap<String, ServerHandler>>>),
    /// Servers after a successful [`ServerHandlers::start_all`], keyed by
    /// [`Server::name`]. The map is emptied by [`ServerHandlers::shutdown_all`].
    Started(Arc<HashMap<String, Box<dyn Server>>>),
}
43
44impl Debug for ServerHandlers {
45 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
46 write!(f, "ServerHandlers::{}", self)
47 }
48}
49
50impl Default for ServerHandlers {
51 fn default() -> Self {
52 Self::Init(Arc::new(std::sync::Mutex::new(HashMap::new())))
53 }
54}
55
56impl ServerHandlers {
57 pub fn insert(&self, handler: ServerHandler) {
59 assert!(
62 matches!(self, ServerHandlers::Init(_)),
63 "unexpected: insert when `ServerHandlers` is not during initialization"
64 );
65 let ServerHandlers::Init(handlers) = self else {
66 unreachable!("guarded by the assertion above");
67 };
68 let mut handlers = handlers.lock().unwrap();
69 handlers.insert(handler.0.name().to_string(), handler);
70 }
71
72 pub fn addr(&self, name: &str) -> Option<SocketAddr> {
82 let ServerHandlers::Started(handlers) = self else {
83 return None;
84 };
85 handlers.get(name).and_then(|x| x.bind_addr())
86 }
87
88 pub async fn start_all(&mut self) -> Result<()> {
91 let ServerHandlers::Init(handlers) = self else {
92 return Ok(());
94 };
95
96 let mut handlers = {
97 let mut handlers = handlers.lock().unwrap();
98 std::mem::take(&mut *handlers)
99 };
100
101 try_join_all(handlers.values_mut().map(|(server, addr)| async move {
102 server.start(*addr).await?;
103 info!("Server {} is started", server.name());
104 Ok::<(), error::Error>(())
105 }))
106 .await?;
107
108 let handlers = handlers
109 .into_iter()
110 .map(|(k, v)| (k, v.0))
111 .collect::<HashMap<_, _>>();
112 *self = ServerHandlers::Started(Arc::new(handlers));
113 Ok(())
114 }
115
116 pub async fn shutdown_all(&mut self) -> Result<()> {
118 let ServerHandlers::Started(handlers) = self else {
119 return Ok(());
121 };
122
123 let handlers = std::mem::take(handlers);
124 try_join_all(handlers.values().map(|server| async move {
125 server.shutdown().await?;
126 info!("Service {} is shutdown!", server.name());
127 Ok::<(), error::Error>(())
128 }))
129 .await?;
130 Ok(())
131 }
132}
133
/// A network server whose lifecycle is driven by [`ServerHandlers`].
#[async_trait]
pub trait Server: Send + Sync {
    /// Stops the server.
    async fn shutdown(&self) -> Result<()>;

    /// Starts the server on `listening`, the address registered alongside it in
    /// [`ServerHandlers`].
    async fn start(&mut self, listening: SocketAddr) -> Result<()>;

    /// Name of the server. Used as its key in [`ServerHandlers`] and in log messages.
    fn name(&self) -> &str;

    /// Address the server is bound to, if the implementation reports one.
    ///
    /// Defaults to `None`. [`ServerHandlers::addr`] returns this value.
    fn bind_addr(&self) -> Option<SocketAddr> {
        None
    }

    /// Returns `self` as `&dyn Any` (presumably so callers can downcast to the
    /// concrete server type).
    fn as_any(&self) -> &dyn Any;
}
154
/// Accept-loop state owned by a [`BaseTcpServer`].
struct AcceptTask {
    /// Aborts the listener stream built from `abort_registration`.
    abort_handle: AbortHandle,
    /// Registration paired with `abort_handle`. It is taken by `AcceptTask::bind`
    /// to build the abortable listener stream.
    abort_registration: Option<AbortRegistration>,

    /// Handle of the spawned accept task. Set by `AcceptTask::start_with` and
    /// taken by `AcceptTask::shutdown`.
    join_handle: Option<JoinHandle<()>>,
}
166
167impl AcceptTask {
168 async fn shutdown(&mut self, name: &str) -> Result<()> {
169 match self.join_handle.take() {
170 Some(join_handle) => {
171 self.abort_handle.abort();
172
173 if let Err(error) = join_handle.await {
174 error!(
176 "Unexpected error during shutdown {} server, error: {:?}",
177 name, error
178 );
179 } else {
180 info!("{name} server is shutdown.");
181 }
182 Ok(())
183 }
184 None => error::InternalSnafu {
185 err_msg: format!("{name} server is not started."),
186 }
187 .fail()?,
188 }
189 }
190
191 async fn bind(
192 &mut self,
193 addr: SocketAddr,
194 name: &str,
195 keep_alive_secs: u64,
196 ) -> Result<(Abortable<TcpListenerStream>, SocketAddr)> {
197 match self.abort_registration.take() {
198 Some(registration) => {
199 let listener =
200 tokio::net::TcpListener::bind(addr)
201 .await
202 .context(error::TokioIoSnafu {
203 err_msg: format!("{name} failed to bind addr {addr}"),
204 })?;
205 let addr = listener.local_addr()?;
207 info!("{name} server started at {addr}");
208
209 if keep_alive_secs > 0 {
211 let socket_ref = socket2::SockRef::from(&listener);
212 let keep_alive = socket2::TcpKeepalive::new()
213 .with_time(std::time::Duration::from_secs(keep_alive_secs))
214 .with_interval(std::time::Duration::from_secs(keep_alive_secs));
215 socket_ref.set_tcp_keepalive(&keep_alive)?;
216 }
217
218 let stream = TcpListenerStream::new(listener);
219 let stream = Abortable::new(stream, registration);
220 Ok((stream, addr))
221 }
222 None => error::InternalSnafu {
223 err_msg: format!("{name} server has been started."),
224 }
225 .fail()?,
226 }
227 }
228
229 fn start_with(&mut self, join_handle: JoinHandle<()>, name: &str) -> Result<()> {
230 ensure!(
231 self.join_handle.is_none(),
232 error::InternalSnafu {
233 err_msg: format!("{name} server has been started."),
234 }
235 );
236 let _handle = self.join_handle.get_or_insert(join_handle);
237 Ok(())
238 }
239}
240
/// Shared plumbing for TCP-based servers: a name, the accept-loop state and an
/// I/O runtime.
pub(crate) struct BaseTcpServer {
    /// Server name, used in log and error messages.
    name: String,
    /// Accept-loop state. Guarded by an async (tokio) mutex because its methods
    /// `.await` while the lock is held.
    accept_task: Mutex<AcceptTask>,
    /// Runtime handed out to callers by `io_runtime()`.
    io_runtime: Runtime,
}
246
247impl BaseTcpServer {
248 pub(crate) fn create_server(name: impl Into<String>, io_runtime: Runtime) -> Self {
249 let (abort_handle, registration) = AbortHandle::new_pair();
250 Self {
251 name: name.into(),
252 accept_task: Mutex::new(AcceptTask {
253 abort_handle,
254 abort_registration: Some(registration),
255 join_handle: None,
256 }),
257 io_runtime,
258 }
259 }
260
261 pub(crate) async fn shutdown(&self) -> Result<()> {
262 let mut task = self.accept_task.lock().await;
263 task.shutdown(&self.name).await
264 }
265
266 pub(crate) async fn bind(
270 &self,
271 addr: SocketAddr,
272 keep_alive_secs: u64,
273 ) -> Result<(Abortable<TcpListenerStream>, SocketAddr)> {
274 let mut task = self.accept_task.lock().await;
275 task.bind(addr, &self.name, keep_alive_secs).await
276 }
277
278 pub(crate) async fn start_with(&self, join_handle: JoinHandle<()>) -> Result<()> {
279 let mut task = self.accept_task.lock().await;
280 task.start_with(join_handle, &self.name)
281 }
282
283 pub(crate) fn io_runtime(&self) -> Runtime {
284 self.io_runtime.clone()
285 }
286}