servers/
server.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15use std::collections::HashMap;
16use std::fmt::{Debug, Formatter};
17use std::net::SocketAddr;
18use std::sync::Arc;
19
20use async_trait::async_trait;
21use common_runtime::Runtime;
22use common_telemetry::{error, info};
23use futures::future::{try_join_all, AbortHandle, AbortRegistration, Abortable};
24use snafu::{ensure, ResultExt};
25use strum::Display;
26use tokio::sync::Mutex;
27use tokio::task::JoinHandle;
28use tokio_stream::wrappers::TcpListenerStream;
29
30use crate::error::{self, Result};
31
32pub(crate) type AbortableStream = Abortable<TcpListenerStream>;
33
34pub type ServerHandler = (Box<dyn Server>, SocketAddr);
35
36/// [ServerHandlers] is used to manage the lifecycle of all the services like http or grpc in the GreptimeDB server.
37#[derive(Clone, Display)]
38pub enum ServerHandlers {
39    Init(Arc<std::sync::Mutex<HashMap<String, ServerHandler>>>),
40    Started(Arc<HashMap<String, Box<dyn Server>>>),
41}
42
43impl Debug for ServerHandlers {
44    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
45        write!(f, "ServerHandlers::{}", self)
46    }
47}
48
49impl Default for ServerHandlers {
50    fn default() -> Self {
51        Self::Init(Arc::new(std::sync::Mutex::new(HashMap::new())))
52    }
53}
54
55impl ServerHandlers {
56    /// Inserts a [ServerHandler] **before** the [ServerHandlers] is started.
57    pub fn insert(&self, handler: ServerHandler) {
58        // Inserts more to ServerHandlers while it is not in the initialization state
59        // is considered a bug.
60        assert!(
61            matches!(self, ServerHandlers::Init(_)),
62            "unexpected: insert when `ServerHandlers` is not during initialization"
63        );
64        let ServerHandlers::Init(handlers) = self else {
65            unreachable!("guarded by the assertion above");
66        };
67        let mut handlers = handlers.lock().unwrap();
68        handlers.insert(handler.0.name().to_string(), handler);
69    }
70
71    /// Finds the __actual__ bound address of the service by its name.
72    ///
73    /// This is useful in testing. We can configure the service to bind to port 0 first, then start
74    /// the server to get the real bound port number. This way we avoid doing careful assignment of
75    /// the port number to the service in the test.
76    ///
77    /// Note that the address is only retrievable after the [ServerHandlers] is started (the
78    /// `start_all` method is called successfully). Otherwise you may find the address still be
79    /// `None` even if you are certain the server was inserted before.
80    pub fn addr(&self, name: &str) -> Option<SocketAddr> {
81        let ServerHandlers::Started(handlers) = self else {
82            return None;
83        };
84        handlers.get(name).and_then(|x| x.bind_addr())
85    }
86
87    /// Starts all the managed services. It will block until all the services are started.
88    /// And it will set the actual bound address to the service.
89    pub async fn start_all(&mut self) -> Result<()> {
90        let ServerHandlers::Init(handlers) = self else {
91            // If already started, do nothing.
92            return Ok(());
93        };
94
95        let mut handlers = {
96            let mut handlers = handlers.lock().unwrap();
97            std::mem::take(&mut *handlers)
98        };
99
100        try_join_all(handlers.values_mut().map(|(server, addr)| async move {
101            server.start(*addr).await?;
102            info!("Server {} is started", server.name());
103            Ok::<(), error::Error>(())
104        }))
105        .await?;
106
107        let handlers = handlers
108            .into_iter()
109            .map(|(k, v)| (k, v.0))
110            .collect::<HashMap<_, _>>();
111        *self = ServerHandlers::Started(Arc::new(handlers));
112        Ok(())
113    }
114
115    /// Shutdown all the managed services. It will block until all the services are shutdown.
116    pub async fn shutdown_all(&mut self) -> Result<()> {
117        let ServerHandlers::Started(handlers) = self else {
118            // If not started, do nothing.
119            return Ok(());
120        };
121
122        let handlers = std::mem::take(handlers);
123        try_join_all(handlers.values().map(|server| async move {
124            server.shutdown().await?;
125            info!("Service {} is shutdown!", server.name());
126            Ok::<(), error::Error>(())
127        }))
128        .await?;
129        Ok(())
130    }
131}
132
133#[async_trait]
134pub trait Server: Send + Sync {
135    /// Shutdown the server gracefully.
136    async fn shutdown(&self) -> Result<()>;
137
138    /// Starts the server and binds on `listening`.
139    ///
140    /// Caller should ensure `start()` is only invoked once.
141    async fn start(&mut self, listening: SocketAddr) -> Result<()>;
142
143    fn name(&self) -> &str;
144
145    /// Finds the actual bind address of this server.
146    /// If not found (returns `None`), maybe it's not started yet, or just don't have it.
147    fn bind_addr(&self) -> Option<SocketAddr> {
148        None
149    }
150}
151
152struct AcceptTask {
153    // `abort_handle` and `abort_registration` are used in pairs in shutting down the server.
154    // They work like sender and receiver for aborting stream. When the server is shutting down,
155    // calling `abort_handle.abort()` will "notify" `abort_registration` to stop emitting new
156    // elements in the stream.
157    abort_handle: AbortHandle,
158    abort_registration: Option<AbortRegistration>,
159
160    // A handle holding the TCP accepting task.
161    join_handle: Option<JoinHandle<()>>,
162}
163
164impl AcceptTask {
165    async fn shutdown(&mut self, name: &str) -> Result<()> {
166        match self.join_handle.take() {
167            Some(join_handle) => {
168                self.abort_handle.abort();
169
170                if let Err(error) = join_handle.await {
171                    // Couldn't use `error!(e; xxx)` because JoinError doesn't implement ErrorExt.
172                    error!(
173                        "Unexpected error during shutdown {} server, error: {:?}",
174                        name, error
175                    );
176                } else {
177                    info!("{name} server is shutdown.");
178                }
179                Ok(())
180            }
181            None => error::InternalSnafu {
182                err_msg: format!("{name} server is not started."),
183            }
184            .fail()?,
185        }
186    }
187
188    async fn bind(
189        &mut self,
190        addr: SocketAddr,
191        name: &str,
192        keep_alive_secs: u64,
193    ) -> Result<(Abortable<TcpListenerStream>, SocketAddr)> {
194        match self.abort_registration.take() {
195            Some(registration) => {
196                let listener =
197                    tokio::net::TcpListener::bind(addr)
198                        .await
199                        .context(error::TokioIoSnafu {
200                            err_msg: format!("{name} failed to bind addr {addr}"),
201                        })?;
202                // get actually bond addr in case input addr use port 0
203                let addr = listener.local_addr()?;
204                info!("{name} server started at {addr}");
205
206                // set keep-alive
207                if keep_alive_secs > 0 {
208                    let socket_ref = socket2::SockRef::from(&listener);
209                    let keep_alive = socket2::TcpKeepalive::new()
210                        .with_time(std::time::Duration::from_secs(keep_alive_secs))
211                        .with_interval(std::time::Duration::from_secs(keep_alive_secs));
212                    socket_ref.set_tcp_keepalive(&keep_alive)?;
213                }
214
215                let stream = TcpListenerStream::new(listener);
216                let stream = Abortable::new(stream, registration);
217                Ok((stream, addr))
218            }
219            None => error::InternalSnafu {
220                err_msg: format!("{name} server has been started."),
221            }
222            .fail()?,
223        }
224    }
225
226    fn start_with(&mut self, join_handle: JoinHandle<()>, name: &str) -> Result<()> {
227        ensure!(
228            self.join_handle.is_none(),
229            error::InternalSnafu {
230                err_msg: format!("{name} server has been started."),
231            }
232        );
233        let _handle = self.join_handle.get_or_insert(join_handle);
234        Ok(())
235    }
236}
237
238pub(crate) struct BaseTcpServer {
239    name: String,
240    accept_task: Mutex<AcceptTask>,
241    io_runtime: Runtime,
242}
243
244impl BaseTcpServer {
245    pub(crate) fn create_server(name: impl Into<String>, io_runtime: Runtime) -> Self {
246        let (abort_handle, registration) = AbortHandle::new_pair();
247        Self {
248            name: name.into(),
249            accept_task: Mutex::new(AcceptTask {
250                abort_handle,
251                abort_registration: Some(registration),
252                join_handle: None,
253            }),
254            io_runtime,
255        }
256    }
257
258    pub(crate) async fn shutdown(&self) -> Result<()> {
259        let mut task = self.accept_task.lock().await;
260        task.shutdown(&self.name).await
261    }
262
263    /// Bind the server to the given address and set the keep-alive time.
264    ///
265    /// If `keep_alive_secs` is 0, the keep-alive will not be set.
266    pub(crate) async fn bind(
267        &self,
268        addr: SocketAddr,
269        keep_alive_secs: u64,
270    ) -> Result<(Abortable<TcpListenerStream>, SocketAddr)> {
271        let mut task = self.accept_task.lock().await;
272        task.bind(addr, &self.name, keep_alive_secs).await
273    }
274
275    pub(crate) async fn start_with(&self, join_handle: JoinHandle<()>) -> Result<()> {
276        let mut task = self.accept_task.lock().await;
277        task.start_with(join_handle, &self.name)
278    }
279
280    pub(crate) fn io_runtime(&self) -> Runtime {
281        self.io_runtime.clone()
282    }
283}