servers/
server.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::any::Any;
16use std::collections::HashMap;
17use std::fmt::{Debug, Formatter};
18use std::net::SocketAddr;
19use std::sync::Arc;
20
21use async_trait::async_trait;
22use common_runtime::Runtime;
23use common_telemetry::{error, info};
24use futures::future::{AbortHandle, AbortRegistration, Abortable, try_join_all};
25use snafu::{ResultExt, ensure};
26use strum::Display;
27use tokio::sync::Mutex;
28use tokio::task::JoinHandle;
29use tokio_stream::wrappers::TcpListenerStream;
30
31use crate::error::{self, Result};
32
33pub(crate) type AbortableStream = Abortable<TcpListenerStream>;
34
35pub type ServerHandler = (Box<dyn Server>, SocketAddr);
36
37/// [ServerHandlers] is used to manage the lifecycle of all the services like http or grpc in the GreptimeDB server.
38#[derive(Clone, Display)]
39pub enum ServerHandlers {
40    Init(Arc<std::sync::Mutex<HashMap<String, ServerHandler>>>),
41    Started(Arc<HashMap<String, Box<dyn Server>>>),
42}
43
44impl Debug for ServerHandlers {
45    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
46        write!(f, "ServerHandlers::{}", self)
47    }
48}
49
50impl Default for ServerHandlers {
51    fn default() -> Self {
52        Self::Init(Arc::new(std::sync::Mutex::new(HashMap::new())))
53    }
54}
55
56impl ServerHandlers {
57    /// Inserts a [ServerHandler] **before** the [ServerHandlers] is started.
58    pub fn insert(&self, handler: ServerHandler) {
59        // Inserts more to ServerHandlers while it is not in the initialization state
60        // is considered a bug.
61        assert!(
62            matches!(self, ServerHandlers::Init(_)),
63            "unexpected: insert when `ServerHandlers` is not during initialization"
64        );
65        let ServerHandlers::Init(handlers) = self else {
66            unreachable!("guarded by the assertion above");
67        };
68        let mut handlers = handlers.lock().unwrap();
69        handlers.insert(handler.0.name().to_string(), handler);
70    }
71
72    /// Finds the __actual__ bound address of the service by its name.
73    ///
74    /// This is useful in testing. We can configure the service to bind to port 0 first, then start
75    /// the server to get the real bound port number. This way we avoid doing careful assignment of
76    /// the port number to the service in the test.
77    ///
78    /// Note that the address is only retrievable after the [ServerHandlers] is started (the
79    /// `start_all` method is called successfully). Otherwise you may find the address still be
80    /// `None` even if you are certain the server was inserted before.
81    pub fn addr(&self, name: &str) -> Option<SocketAddr> {
82        let ServerHandlers::Started(handlers) = self else {
83            return None;
84        };
85        handlers.get(name).and_then(|x| x.bind_addr())
86    }
87
88    /// Starts all the managed services. It will block until all the services are started.
89    /// And it will set the actual bound address to the service.
90    pub async fn start_all(&mut self) -> Result<()> {
91        let ServerHandlers::Init(handlers) = self else {
92            // If already started, do nothing.
93            return Ok(());
94        };
95
96        let mut handlers = {
97            let mut handlers = handlers.lock().unwrap();
98            std::mem::take(&mut *handlers)
99        };
100
101        try_join_all(handlers.values_mut().map(|(server, addr)| async move {
102            server.start(*addr).await?;
103            info!("Server {} is started", server.name());
104            Ok::<(), error::Error>(())
105        }))
106        .await?;
107
108        let handlers = handlers
109            .into_iter()
110            .map(|(k, v)| (k, v.0))
111            .collect::<HashMap<_, _>>();
112        *self = ServerHandlers::Started(Arc::new(handlers));
113        Ok(())
114    }
115
116    /// Shutdown all the managed services. It will block until all the services are shutdown.
117    pub async fn shutdown_all(&mut self) -> Result<()> {
118        let ServerHandlers::Started(handlers) = self else {
119            // If not started, do nothing.
120            return Ok(());
121        };
122
123        let handlers = std::mem::take(handlers);
124        try_join_all(handlers.values().map(|server| async move {
125            server.shutdown().await?;
126            info!("Service {} is shutdown!", server.name());
127            Ok::<(), error::Error>(())
128        }))
129        .await?;
130        Ok(())
131    }
132}
133
134#[async_trait]
135pub trait Server: Send + Sync {
136    /// Shutdown the server gracefully.
137    async fn shutdown(&self) -> Result<()>;
138
139    /// Starts the server and binds on `listening`.
140    ///
141    /// Caller should ensure `start()` is only invoked once.
142    async fn start(&mut self, listening: SocketAddr) -> Result<()>;
143
144    fn name(&self) -> &str;
145
146    /// Finds the actual bind address of this server.
147    /// If not found (returns `None`), maybe it's not started yet, or just don't have it.
148    fn bind_addr(&self) -> Option<SocketAddr> {
149        None
150    }
151
152    fn as_any(&self) -> &dyn Any;
153}
154
155struct AcceptTask {
156    // `abort_handle` and `abort_registration` are used in pairs in shutting down the server.
157    // They work like sender and receiver for aborting stream. When the server is shutting down,
158    // calling `abort_handle.abort()` will "notify" `abort_registration` to stop emitting new
159    // elements in the stream.
160    abort_handle: AbortHandle,
161    abort_registration: Option<AbortRegistration>,
162
163    // A handle holding the TCP accepting task.
164    join_handle: Option<JoinHandle<()>>,
165}
166
167impl AcceptTask {
168    async fn shutdown(&mut self, name: &str) -> Result<()> {
169        match self.join_handle.take() {
170            Some(join_handle) => {
171                self.abort_handle.abort();
172
173                if let Err(error) = join_handle.await {
174                    // Couldn't use `error!(e; xxx)` because JoinError doesn't implement ErrorExt.
175                    error!(
176                        "Unexpected error during shutdown {} server, error: {:?}",
177                        name, error
178                    );
179                } else {
180                    info!("{name} server is shutdown.");
181                }
182                Ok(())
183            }
184            None => error::InternalSnafu {
185                err_msg: format!("{name} server is not started."),
186            }
187            .fail()?,
188        }
189    }
190
191    async fn bind(
192        &mut self,
193        addr: SocketAddr,
194        name: &str,
195        keep_alive_secs: u64,
196    ) -> Result<(Abortable<TcpListenerStream>, SocketAddr)> {
197        match self.abort_registration.take() {
198            Some(registration) => {
199                let listener =
200                    tokio::net::TcpListener::bind(addr)
201                        .await
202                        .context(error::TokioIoSnafu {
203                            err_msg: format!("{name} failed to bind addr {addr}"),
204                        })?;
205                // get actually bond addr in case input addr use port 0
206                let addr = listener.local_addr()?;
207                info!("{name} server started at {addr}");
208
209                // set keep-alive
210                if keep_alive_secs > 0 {
211                    let socket_ref = socket2::SockRef::from(&listener);
212                    let keep_alive = socket2::TcpKeepalive::new()
213                        .with_time(std::time::Duration::from_secs(keep_alive_secs))
214                        .with_interval(std::time::Duration::from_secs(keep_alive_secs));
215                    socket_ref.set_tcp_keepalive(&keep_alive)?;
216                }
217
218                let stream = TcpListenerStream::new(listener);
219                let stream = Abortable::new(stream, registration);
220                Ok((stream, addr))
221            }
222            None => error::InternalSnafu {
223                err_msg: format!("{name} server has been started."),
224            }
225            .fail()?,
226        }
227    }
228
229    fn start_with(&mut self, join_handle: JoinHandle<()>, name: &str) -> Result<()> {
230        ensure!(
231            self.join_handle.is_none(),
232            error::InternalSnafu {
233                err_msg: format!("{name} server has been started."),
234            }
235        );
236        let _handle = self.join_handle.get_or_insert(join_handle);
237        Ok(())
238    }
239}
240
241pub(crate) struct BaseTcpServer {
242    name: String,
243    accept_task: Mutex<AcceptTask>,
244    io_runtime: Runtime,
245}
246
247impl BaseTcpServer {
248    pub(crate) fn create_server(name: impl Into<String>, io_runtime: Runtime) -> Self {
249        let (abort_handle, registration) = AbortHandle::new_pair();
250        Self {
251            name: name.into(),
252            accept_task: Mutex::new(AcceptTask {
253                abort_handle,
254                abort_registration: Some(registration),
255                join_handle: None,
256            }),
257            io_runtime,
258        }
259    }
260
261    pub(crate) async fn shutdown(&self) -> Result<()> {
262        let mut task = self.accept_task.lock().await;
263        task.shutdown(&self.name).await
264    }
265
266    /// Bind the server to the given address and set the keep-alive time.
267    ///
268    /// If `keep_alive_secs` is 0, the keep-alive will not be set.
269    pub(crate) async fn bind(
270        &self,
271        addr: SocketAddr,
272        keep_alive_secs: u64,
273    ) -> Result<(Abortable<TcpListenerStream>, SocketAddr)> {
274        let mut task = self.accept_task.lock().await;
275        task.bind(addr, &self.name, keep_alive_secs).await
276    }
277
278    pub(crate) async fn start_with(&self, join_handle: JoinHandle<()>) -> Result<()> {
279        let mut task = self.accept_task.lock().await;
280        task.start_with(join_handle, &self.name)
281    }
282
283    pub(crate) fn io_runtime(&self) -> Runtime {
284        self.io_runtime.clone()
285    }
286}