common_runtime/
runtime.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::future::Future;
16use std::sync::atomic::{AtomicUsize, Ordering};
17use std::sync::Arc;
18use std::thread;
19use std::time::Duration;
20
21use snafu::ResultExt;
22use tokio::runtime::Builder as RuntimeBuilder;
23use tokio::sync::oneshot;
24pub use tokio::task::{JoinError, JoinHandle};
25
26use crate::error::*;
27use crate::metrics::*;
28use crate::runtime_default::DefaultRuntime;
29use crate::runtime_throttleable::ThrottleableRuntime;
30
31// configurations
32pub type Runtime = DefaultRuntime;
33
34static RUNTIME_ID: AtomicUsize = AtomicUsize::new(0);
35
36/// Dropping the dropper will cause runtime to shutdown.
37#[derive(Debug)]
38pub struct Dropper {
39    close: Option<oneshot::Sender<()>>,
40}
41
42impl Drop for Dropper {
43    fn drop(&mut self) {
44        // Send a signal to say i am dropping.
45        let _ = self.close.take().map(|v| v.send(()));
46    }
47}
48
49pub trait RuntimeTrait {
50    /// Get a runtime builder
51    fn builder() -> Builder {
52        Builder::default()
53    }
54
55    /// Spawn a future and execute it in this thread pool
56    ///
57    /// Similar to tokio::runtime::Runtime::spawn()
58    fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
59    where
60        F: Future + Send + 'static,
61        F::Output: Send + 'static;
62
63    /// Run the provided function on an executor dedicated to blocking
64    /// operations.
65    fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R>
66    where
67        F: FnOnce() -> R + Send + 'static,
68        R: Send + 'static;
69
70    /// Run a future to complete, this is the runtime's entry point
71    fn block_on<F: Future>(&self, future: F) -> F::Output;
72
73    /// Get the name of the runtime
74    fn name(&self) -> &str;
75}
76
77pub trait BuilderBuild<R: RuntimeTrait> {
78    fn build(&mut self) -> Result<R>;
79}
80
81pub struct Builder {
82    runtime_name: String,
83    thread_name: String,
84    priority: Priority,
85    builder: RuntimeBuilder,
86}
87
88impl Default for Builder {
89    fn default() -> Self {
90        Self {
91            runtime_name: format!("runtime-{}", RUNTIME_ID.fetch_add(1, Ordering::Relaxed)),
92            thread_name: "default-worker".to_string(),
93            builder: RuntimeBuilder::new_multi_thread(),
94            priority: Priority::VeryHigh,
95        }
96    }
97}
98
99impl Builder {
100    pub fn priority(&mut self, priority: Priority) -> &mut Self {
101        self.priority = priority;
102        self
103    }
104
105    /// Sets the number of worker threads the Runtime will use.
106    ///
107    /// This can be any number above 0. The default value is the number of cores available to the system.
108    pub fn worker_threads(&mut self, val: usize) -> &mut Self {
109        let _ = self.builder.worker_threads(val);
110        self
111    }
112
113    /// Specifies the limit for additional threads spawned by the Runtime.
114    ///
115    /// These threads are used for blocking operations like tasks spawned through spawn_blocking,
116    /// they are not always active and will exit if left idle for too long, You can change this timeout duration
117    /// with thread_keep_alive. The default value is 512.
118    pub fn max_blocking_threads(&mut self, val: usize) -> &mut Self {
119        let _ = self.builder.max_blocking_threads(val);
120        self
121    }
122
123    /// Sets a custom timeout for a thread in the blocking pool.
124    ///
125    /// By default, the timeout for a thread is set to 10 seconds.
126    pub fn thread_keep_alive(&mut self, duration: Duration) -> &mut Self {
127        let _ = self.builder.thread_keep_alive(duration);
128        self
129    }
130
131    pub fn runtime_name(&mut self, val: impl Into<String>) -> &mut Self {
132        self.runtime_name = val.into();
133        self
134    }
135
136    /// Sets name of threads spawned by the Runtime thread pool
137    pub fn thread_name(&mut self, val: impl Into<String>) -> &mut Self {
138        self.thread_name = val.into();
139        self
140    }
141}
142
143impl BuilderBuild<DefaultRuntime> for Builder {
144    fn build(&mut self) -> Result<DefaultRuntime> {
145        let builder = self
146            .builder
147            .enable_all()
148            .thread_name(self.thread_name.clone())
149            .on_thread_start(on_thread_start(self.thread_name.clone()))
150            .on_thread_stop(on_thread_stop(self.thread_name.clone()))
151            .on_thread_park(on_thread_park(self.thread_name.clone()))
152            .on_thread_unpark(on_thread_unpark(self.thread_name.clone()));
153        let runtime = if cfg!(debug_assertions) {
154            // Set the stack size to 8MB for the thread so it wouldn't overflow on large stack usage in debug mode
155            // This is necessary to avoid stack overflow while running sqlness.
156            // https://github.com/rust-lang/rust/issues/34283
157            builder
158                .thread_stack_size(8 * 1024 * 1024)
159                .build()
160                .context(BuildRuntimeSnafu)?
161        } else {
162            builder.build().context(BuildRuntimeSnafu)?
163        };
164
165        let name = self.runtime_name.clone();
166        let handle = runtime.handle().clone();
167        let (send_stop, recv_stop) = oneshot::channel();
168        // Block the runtime to shutdown.
169        let _ = thread::Builder::new()
170            .name(format!("{}-blocker", self.thread_name))
171            .spawn(move || runtime.block_on(recv_stop));
172
173        #[cfg(tokio_unstable)]
174        register_collector(name.clone(), &handle);
175
176        Ok(DefaultRuntime::new(
177            &name,
178            handle,
179            Arc::new(Dropper {
180                close: Some(send_stop),
181            }),
182        ))
183    }
184}
185
186impl BuilderBuild<ThrottleableRuntime> for Builder {
187    fn build(&mut self) -> Result<ThrottleableRuntime> {
188        let runtime = self
189            .builder
190            .enable_all()
191            .thread_name(self.thread_name.clone())
192            .on_thread_start(on_thread_start(self.thread_name.clone()))
193            .on_thread_stop(on_thread_stop(self.thread_name.clone()))
194            .on_thread_park(on_thread_park(self.thread_name.clone()))
195            .on_thread_unpark(on_thread_unpark(self.thread_name.clone()))
196            .build()
197            .context(BuildRuntimeSnafu)?;
198
199        let name = self.runtime_name.clone();
200        let handle = runtime.handle().clone();
201        let (send_stop, recv_stop) = oneshot::channel();
202        // Block the runtime to shutdown.
203        let _ = thread::Builder::new()
204            .name(format!("{}-blocker", self.thread_name))
205            .spawn(move || runtime.block_on(recv_stop));
206
207        #[cfg(tokio_unstable)]
208        register_collector(name.clone(), &handle);
209
210        ThrottleableRuntime::new(
211            &name,
212            self.priority,
213            handle,
214            Arc::new(Dropper {
215                close: Some(send_stop),
216            }),
217        )
218    }
219}
220
/// Registers a tokio runtime metrics collector for `handle` with the default Prometheus
/// registry, using `name` (with every `-` replaced by `_`) as the collector name.
///
/// A registration failure is ignored.
#[cfg(tokio_unstable)]
pub fn register_collector(name: String, handle: &tokio::runtime::Handle) {
    // Presumably needed because `-` is not valid in Prometheus metric names.
    let metric_name = name.replace("-", "_");
    let collector = tokio_metrics_collector::RuntimeCollector::new(
        tokio_metrics::RuntimeMonitor::new(handle),
        metric_name,
    );
    let _ = prometheus::register(Box::new(collector));
}
228
229fn on_thread_start(thread_name: String) -> impl Fn() + 'static {
230    move || {
231        METRIC_RUNTIME_THREADS_ALIVE
232            .with_label_values(&[thread_name.as_str()])
233            .inc();
234    }
235}
236
237fn on_thread_stop(thread_name: String) -> impl Fn() + 'static {
238    move || {
239        METRIC_RUNTIME_THREADS_ALIVE
240            .with_label_values(&[thread_name.as_str()])
241            .dec();
242    }
243}
244
245fn on_thread_park(thread_name: String) -> impl Fn() + 'static {
246    move || {
247        METRIC_RUNTIME_THREADS_IDLE
248            .with_label_values(&[thread_name.as_str()])
249            .inc();
250    }
251}
252
253fn on_thread_unpark(thread_name: String) -> impl Fn() + 'static {
254    move || {
255        METRIC_RUNTIME_THREADS_IDLE
256            .with_label_values(&[thread_name.as_str()])
257            .dec();
258    }
259}
260
/// Priority level of a runtime, from `VeryLow` (0) up to `VeryHigh` (4).
///
/// The value is stored in [`Builder`] and handed to the throttleable runtime when it is built;
/// how a level translates into throttling is decided by that runtime, not here.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum Priority {
    /// Lowest level, discriminant 0.
    VeryLow = 0,
    /// Discriminant 1.
    Low = 1,
    /// Discriminant 2.
    Middle = 2,
    /// Discriminant 3.
    High = 3,
    /// Highest level, discriminant 4; the default of `Builder::default()`.
    VeryHigh = 4,
}
269
#[cfg(test)]
mod tests {

    use std::sync::Arc;
    use std::thread;
    use std::time::Duration;

    use common_telemetry::dump_metrics;
    use tokio::sync::oneshot;
    use tokio_test::assert_ok;

    use super::*;

    /// Builds a small two-worker runtime shared by the spawn/block_on tests.
    fn runtime() -> Arc<Runtime> {
        let runtime = Builder::default()
            .worker_threads(2)
            .thread_name("test_spawn_join")
            .build();
        Arc::new(runtime.unwrap())
    }

    /// Checks that the thread metrics (and, under `tokio_unstable`, the tokio runtime metrics)
    /// show up in the metrics dump.
    #[test]
    fn test_metric() {
        // Use an explicit runtime name: the generated `runtime-<id>` name depends on how many
        // builders other tests, which run in parallel in the same process, created before
        // this one, so asserting on a fixed id would be order-dependent.
        let runtime: Runtime = Builder::default()
            .worker_threads(5)
            .runtime_name("test_metric_runtime")
            .thread_name("test_runtime_metric")
            .build()
            .unwrap();
        // wait threads create
        thread::sleep(Duration::from_millis(50));

        // Deliberately block a worker thread so that it is busy while metrics are dumped.
        let _handle = runtime.spawn(async {
            thread::sleep(Duration::from_millis(50));
        });

        thread::sleep(Duration::from_millis(10));

        let metric_text = dump_metrics().unwrap();

        assert!(metric_text.contains("runtime_threads_idle{thread_name=\"test_runtime_metric\"}"));
        assert!(metric_text.contains("runtime_threads_alive{thread_name=\"test_runtime_metric\"}"));

        #[cfg(tokio_unstable)]
        {
            assert!(
                metric_text.contains("test_metric_runtime_tokio_budget_forced_yield_count 0")
            );
            assert!(metric_text.contains("test_metric_runtime_tokio_injection_queue_depth 0"));
            assert!(metric_text.contains("test_metric_runtime_tokio_workers_count 5"));
        }
    }

    /// `block_on` must drive a future that is completed from another OS thread.
    #[test]
    fn block_on_async() {
        let runtime = runtime();

        let out = runtime.block_on(async {
            let (tx, rx) = oneshot::channel();

            let _ = thread::spawn(move || {
                thread::sleep(Duration::from_millis(50));
                tx.send("ZOMG").unwrap();
            });

            assert_ok!(rx.await)
        });

        assert_eq!(out, "ZOMG");
    }

    /// A task can be spawned onto the runtime from inside a `spawn_blocking` closure.
    #[test]
    fn spawn_from_blocking() {
        let runtime = runtime();
        let runtime1 = runtime.clone();
        let out = runtime.block_on(async move {
            let runtime2 = runtime1.clone();
            let inner = assert_ok!(
                runtime1
                    .spawn_blocking(move || { runtime2.spawn(async move { "hello" }) })
                    .await
            );

            assert_ok!(inner.await)
        });

        assert_eq!(out, "hello")
    }

    /// A spawned task's handle can be awaited through `block_on`.
    #[test]
    fn test_spawn_join() {
        let runtime = runtime();
        let handle = runtime.spawn(async { 1 + 1 });

        assert_eq!(2, runtime.block_on(handle).unwrap());
    }
}