common_runtime/
runtime_throttleable.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Debug;
16use std::future::Future;
17use std::pin::Pin;
18use std::sync::Arc;
19use std::task::{Context, Poll};
20use std::time::Duration;
21
22use futures::FutureExt;
23use ratelimit::Ratelimiter;
24use snafu::ResultExt;
25use tokio::runtime::Handle;
26pub use tokio::task::JoinHandle;
27use tokio::time::Sleep;
28
29use crate::error::{BuildRuntimeRateLimiterSnafu, Result};
30use crate::runtime::{Dropper, Priority, RuntimeTrait};
31use crate::Builder;
32
33struct RuntimeRateLimiter {
34    pub ratelimiter: Option<Ratelimiter>,
35}
36
37impl Debug for RuntimeRateLimiter {
38    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39        f.debug_struct("RuntimeThrottleShareWithFuture")
40            .field(
41                "ratelimiter_max_tokens",
42                &self.ratelimiter.as_ref().map(|v| v.max_tokens()),
43            )
44            .field(
45                "ratelimiter_refill_amount",
46                &self.ratelimiter.as_ref().map(|v| v.refill_amount()),
47            )
48            .finish()
49    }
50}
51
52/// A runtime to run future tasks
53#[derive(Clone, Debug)]
54pub struct ThrottleableRuntime {
55    name: String,
56    handle: Handle,
57    shared_with_future: Arc<RuntimeRateLimiter>,
58    // Used to receive a drop signal when dropper is dropped, inspired by databend
59    _dropper: Arc<Dropper>,
60}
61
62impl ThrottleableRuntime {
63    pub(crate) fn new(
64        name: &str,
65        priority: Priority,
66        handle: Handle,
67        dropper: Arc<Dropper>,
68    ) -> Result<Self> {
69        Ok(Self {
70            name: name.to_string(),
71            handle,
72            shared_with_future: Arc::new(RuntimeRateLimiter {
73                ratelimiter: priority.ratelimiter_count()?,
74            }),
75            _dropper: dropper,
76        })
77    }
78}
79
80impl RuntimeTrait for ThrottleableRuntime {
81    fn builder() -> Builder {
82        Builder::default()
83    }
84
85    /// Spawn a future and execute it in this thread pool
86    ///
87    /// Similar to tokio::runtime::Runtime::spawn()
88    fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
89    where
90        F: Future + Send + 'static,
91        F::Output: Send + 'static,
92    {
93        self.handle
94            .spawn(ThrottleFuture::new(self.shared_with_future.clone(), future))
95    }
96
97    /// Run the provided function on an executor dedicated to blocking
98    /// operations.
99    fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R>
100    where
101        F: FnOnce() -> R + Send + 'static,
102        R: Send + 'static,
103    {
104        self.handle.spawn_blocking(func)
105    }
106
107    /// Run a future to complete, this is the runtime's entry point
108    fn block_on<F: Future>(&self, future: F) -> F::Output {
109        self.handle.block_on(future)
110    }
111
112    fn name(&self) -> &str {
113        &self.name
114    }
115}
116
117enum State {
118    Pollable,
119    Throttled(Pin<Box<Sleep>>),
120}
121
122impl State {
123    fn unwrap_backoff(&mut self) -> &mut Pin<Box<Sleep>> {
124        match self {
125            State::Throttled(sleep) => sleep,
126            _ => panic!("unwrap_backoff failed"),
127        }
128    }
129}
130
131#[pin_project::pin_project]
132pub struct ThrottleFuture<F: Future + Send + 'static> {
133    #[pin]
134    future: F,
135
136    /// RateLimiter of this future
137    handle: Arc<RuntimeRateLimiter>,
138
139    state: State,
140}
141
142impl<F> ThrottleFuture<F>
143where
144    F: Future + Send + 'static,
145    F::Output: Send + 'static,
146{
147    fn new(handle: Arc<RuntimeRateLimiter>, future: F) -> Self {
148        Self {
149            future,
150            handle,
151            state: State::Pollable,
152        }
153    }
154}
155
156impl<F> Future for ThrottleFuture<F>
157where
158    F: Future + Send + 'static,
159    F::Output: Send + 'static,
160{
161    type Output = F::Output;
162
163    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
164        let this = self.project();
165
166        match this.state {
167            State::Pollable => {}
168            State::Throttled(ref mut sleep) => match sleep.poll_unpin(cx) {
169                Poll::Ready(_) => {
170                    *this.state = State::Pollable;
171                }
172                Poll::Pending => return Poll::Pending,
173            },
174        };
175
176        if let Some(ratelimiter) = &this.handle.ratelimiter {
177            if let Err(wait) = ratelimiter.try_wait() {
178                *this.state = State::Throttled(Box::pin(tokio::time::sleep(wait)));
179                match this.state.unwrap_backoff().poll_unpin(cx) {
180                    Poll::Ready(_) => {
181                        *this.state = State::Pollable;
182                    }
183                    Poll::Pending => {
184                        return Poll::Pending;
185                    }
186                }
187            }
188        }
189
190        let poll_res = this.future.poll(cx);
191
192        match poll_res {
193            Poll::Ready(r) => Poll::Ready(r),
194            Poll::Pending => Poll::Pending,
195        }
196    }
197}
198
199impl Priority {
200    fn ratelimiter_count(&self) -> Result<Option<Ratelimiter>> {
201        let max = 8000;
202        let gen_per_10ms = match self {
203            Priority::VeryLow => Some(2000),
204            Priority::Low => Some(4000),
205            Priority::Middle => Some(6000),
206            Priority::High => Some(8000),
207            Priority::VeryHigh => None,
208        };
209        if let Some(gen_per_10ms) = gen_per_10ms {
210            Ratelimiter::builder(gen_per_10ms, Duration::from_millis(10)) // generate poll count per 10ms
211                .max_tokens(max) // reserved token for batch request
212                .build()
213                .context(BuildRuntimeRateLimiterSnafu)
214                .map(Some)
215        } else {
216            Ok(None)
217        }
218    }
219}
220
#[cfg(test)]
mod tests {

    use tokio::fs::File;
    use tokio::io::AsyncWriteExt;
    use tokio::time::Duration;

    use super::*;
    use crate::runtime::BuilderBuild;

    /// Every priority level, from lowest to highest.
    fn all_priorities() -> [Priority; 5] {
        [
            Priority::VeryLow,
            Priority::Low,
            Priority::Middle,
            Priority::High,
            Priority::VeryHigh,
        ]
    }

    /// Builds an 8-worker throttleable runtime with the given priority.
    fn build_runtime(priority: Priority) -> ThrottleableRuntime {
        Builder::default()
            .runtime_name("test")
            .thread_name("test")
            .worker_threads(8)
            .priority(priority)
            .build()
            .expect("Fail to create runtime")
    }

    /// A spawned future that only sleeps still completes with its value at
    /// every priority.
    #[tokio::test]
    async fn test_throttleable_runtime_spawn_simple() {
        for priority in all_priorities() {
            let runtime = build_runtime(priority);

            // Spawn a simple future that returns 42
            let task = runtime.spawn(async {
                tokio::time::sleep(Duration::from_millis(10)).await;
                42
            });

            let value = task.await.expect("Task panicked");
            assert_eq!(value, 42);
        }
    }

    /// A spawned future that performs file I/O still completes with its
    /// value at every priority.
    #[tokio::test]
    async fn test_throttleable_runtime_spawn_complex() {
        let tempdir = tempfile::tempdir().unwrap();
        for priority in all_priorities() {
            let runtime = build_runtime(priority);
            let target = tempdir.path().to_path_buf().join("test.txt");

            let task = runtime.spawn(async move {
                let mut file = File::create(target).await.unwrap();
                file.write_all(b"Hello, world!").await.unwrap();
                42
            });

            let value = task.await.expect("Task panicked");
            assert_eq!(value, 42);
        }
    }
}