common_runtime/
runtime_throttleable.rs1use std::fmt::Debug;
16use std::future::Future;
17use std::pin::Pin;
18use std::sync::Arc;
19use std::task::{Context, Poll};
20use std::time::Duration;
21
22use futures::FutureExt;
23use ratelimit::Ratelimiter;
24use snafu::ResultExt;
25use tokio::runtime::Handle;
26pub use tokio::task::JoinHandle;
27use tokio::time::Sleep;
28
29use crate::error::{BuildRuntimeRateLimiterSnafu, Result};
30use crate::runtime::{Dropper, Priority, RuntimeTrait};
31use crate::Builder;
32
33struct RuntimeRateLimiter {
34 pub ratelimiter: Option<Ratelimiter>,
35}
36
37impl Debug for RuntimeRateLimiter {
38 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39 f.debug_struct("RuntimeThrottleShareWithFuture")
40 .field(
41 "ratelimiter_max_tokens",
42 &self.ratelimiter.as_ref().map(|v| v.max_tokens()),
43 )
44 .field(
45 "ratelimiter_refill_amount",
46 &self.ratelimiter.as_ref().map(|v| v.refill_amount()),
47 )
48 .finish()
49 }
50}
51
52#[derive(Clone, Debug)]
54pub struct ThrottleableRuntime {
55 name: String,
56 handle: Handle,
57 shared_with_future: Arc<RuntimeRateLimiter>,
58 _dropper: Arc<Dropper>,
60}
61
62impl ThrottleableRuntime {
63 pub(crate) fn new(
64 name: &str,
65 priority: Priority,
66 handle: Handle,
67 dropper: Arc<Dropper>,
68 ) -> Result<Self> {
69 Ok(Self {
70 name: name.to_string(),
71 handle,
72 shared_with_future: Arc::new(RuntimeRateLimiter {
73 ratelimiter: priority.ratelimiter_count()?,
74 }),
75 _dropper: dropper,
76 })
77 }
78}
79
80impl RuntimeTrait for ThrottleableRuntime {
81 fn builder() -> Builder {
82 Builder::default()
83 }
84
85 fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
89 where
90 F: Future + Send + 'static,
91 F::Output: Send + 'static,
92 {
93 self.handle
94 .spawn(ThrottleFuture::new(self.shared_with_future.clone(), future))
95 }
96
97 fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R>
100 where
101 F: FnOnce() -> R + Send + 'static,
102 R: Send + 'static,
103 {
104 self.handle.spawn_blocking(func)
105 }
106
107 fn block_on<F: Future>(&self, future: F) -> F::Output {
109 self.handle.block_on(future)
110 }
111
112 fn name(&self) -> &str {
113 &self.name
114 }
115}
116
117enum State {
118 Pollable,
119 Throttled(Pin<Box<Sleep>>),
120}
121
122impl State {
123 fn unwrap_backoff(&mut self) -> &mut Pin<Box<Sleep>> {
124 match self {
125 State::Throttled(sleep) => sleep,
126 _ => panic!("unwrap_backoff failed"),
127 }
128 }
129}
130
131#[pin_project::pin_project]
132pub struct ThrottleFuture<F: Future + Send + 'static> {
133 #[pin]
134 future: F,
135
136 handle: Arc<RuntimeRateLimiter>,
138
139 state: State,
140}
141
142impl<F> ThrottleFuture<F>
143where
144 F: Future + Send + 'static,
145 F::Output: Send + 'static,
146{
147 fn new(handle: Arc<RuntimeRateLimiter>, future: F) -> Self {
148 Self {
149 future,
150 handle,
151 state: State::Pollable,
152 }
153 }
154}
155
156impl<F> Future for ThrottleFuture<F>
157where
158 F: Future + Send + 'static,
159 F::Output: Send + 'static,
160{
161 type Output = F::Output;
162
163 fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
164 let this = self.project();
165
166 match this.state {
167 State::Pollable => {}
168 State::Throttled(ref mut sleep) => match sleep.poll_unpin(cx) {
169 Poll::Ready(_) => {
170 *this.state = State::Pollable;
171 }
172 Poll::Pending => return Poll::Pending,
173 },
174 };
175
176 if let Some(ratelimiter) = &this.handle.ratelimiter {
177 if let Err(wait) = ratelimiter.try_wait() {
178 *this.state = State::Throttled(Box::pin(tokio::time::sleep(wait)));
179 match this.state.unwrap_backoff().poll_unpin(cx) {
180 Poll::Ready(_) => {
181 *this.state = State::Pollable;
182 }
183 Poll::Pending => {
184 return Poll::Pending;
185 }
186 }
187 }
188 }
189
190 let poll_res = this.future.poll(cx);
191
192 match poll_res {
193 Poll::Ready(r) => Poll::Ready(r),
194 Poll::Pending => Poll::Pending,
195 }
196 }
197}
198
199impl Priority {
200 fn ratelimiter_count(&self) -> Result<Option<Ratelimiter>> {
201 let max = 8000;
202 let gen_per_10ms = match self {
203 Priority::VeryLow => Some(2000),
204 Priority::Low => Some(4000),
205 Priority::Middle => Some(6000),
206 Priority::High => Some(8000),
207 Priority::VeryHigh => None,
208 };
209 if let Some(gen_per_10ms) = gen_per_10ms {
210 Ratelimiter::builder(gen_per_10ms, Duration::from_millis(10)) .max_tokens(max) .build()
213 .context(BuildRuntimeRateLimiterSnafu)
214 .map(Some)
215 } else {
216 Ok(None)
217 }
218 }
219}
220
#[cfg(test)]
mod tests {

    use tokio::fs::File;
    use tokio::io::AsyncWriteExt;
    use tokio::time::Duration;

    use super::*;
    use crate::runtime::BuilderBuild;

    /// Every priority level, from most to least throttled.
    fn all_priorities() -> [Priority; 5] {
        [
            Priority::VeryLow,
            Priority::Low,
            Priority::Middle,
            Priority::High,
            Priority::VeryHigh,
        ]
    }

    /// Builds an 8-worker throttleable runtime with the given priority.
    fn build_runtime(priority: Priority) -> ThrottleableRuntime {
        Builder::default()
            .runtime_name("test")
            .thread_name("test")
            .worker_threads(8)
            .priority(priority)
            .build()
            .expect("Fail to create runtime")
    }

    /// A spawned task that only sleeps must complete at every priority.
    #[tokio::test]
    async fn test_throttleable_runtime_spawn_simple() {
        for priority in all_priorities() {
            let runtime = build_runtime(priority);

            let task = runtime.spawn(async {
                tokio::time::sleep(Duration::from_millis(10)).await;
                42
            });

            let value = task.await.expect("Task panicked");
            assert_eq!(value, 42);
        }
    }

    /// A spawned task that performs file I/O must complete at every priority.
    #[tokio::test]
    async fn test_throttleable_runtime_spawn_complex() {
        let tempdir = tempfile::tempdir().unwrap();
        for priority in all_priorities() {
            let runtime = build_runtime(priority);
            let target = tempdir.path().to_path_buf().join("test.txt");

            let task = runtime.spawn(async move {
                let mut file = File::create(target).await.unwrap();
                file.write_all(b"Hello, world!").await.unwrap();
                42
            });

            let value = task.await.expect("Task panicked");
            assert_eq!(value, 42);
        }
    }
}