// mito2/schedule/scheduler.rs

use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicU8, Ordering};
use std::sync::{Arc, RwLock};

use common_telemetry::warn;
use snafu::{ensure, OptionExt, ResultExt};
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;

use crate::error::{InvalidSchedulerStateSnafu, InvalidSenderSnafu, Result, StopSchedulerSnafu};

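/// A boxed future that the scheduler runs to completion as one unit of work.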
pub type Job = Pin<Box<dyn Future<Output = ()> + Send>>;

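// Possible states of the scheduler.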
const STATE_RUNNING: u8 = 0;
const STATE_STOP: u8 = 1;
const STATE_AWAIT_TERMINATION: u8 = 2;

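/// [Scheduler] defines the interface for scheduling [Job]s.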
#[async_trait::async_trait]
pub trait Scheduler: Send + Sync {
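    /// Schedules a job to be run by the scheduler.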
    fn schedule(&self, job: Job) -> Result<()>;

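    /// Stops the scheduler. If `await_termination` is true, waits until all
    /// queued jobs have been processed before returning.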
    async fn stop(&self, await_termination: bool) -> Result<()>;
}

pub type SchedulerRef = Arc<dyn Scheduler>;

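/// A [Scheduler] that runs jobs on a fixed-size pool of local worker tasks.
///
/// A minimal usage sketch (assuming a Tokio runtime, mirroring the tests below):
///
/// ```ignore
/// let scheduler = LocalScheduler::new(4);
/// scheduler.schedule(Box::pin(async { /* do some work */ }))?;
/// // Wait for all queued jobs to finish before shutting down.
/// scheduler.stop(true).await?;
/// ```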
pub struct LocalScheduler {
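    /// Sends jobs to the workers; taken (set to `None`) once the scheduler stops.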
    sender: RwLock<Option<async_channel::Sender<Job>>>,
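    /// Join handles of the spawned worker tasks.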
    handles: Mutex<Vec<JoinHandle<()>>>,
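    /// Token used to halt the worker tasks.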
    cancel_token: CancellationToken,
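    /// Current state, shared with the worker tasks.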
    state: Arc<AtomicU8>,
}

impl LocalScheduler {
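    /// Starts a new scheduler whose jobs are consumed by `concurrency` worker tasks.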
    pub fn new(concurrency: usize) -> Self {
        let (tx, rx) = async_channel::unbounded();
        let token = CancellationToken::new();
        let state = Arc::new(AtomicU8::new(STATE_RUNNING));

        let mut handles = Vec::with_capacity(concurrency);

        for _ in 0..concurrency {
            let child = token.child_token();
            let receiver = rx.clone();
            let state_clone = state.clone();
            let handle = common_runtime::spawn_global(async move {
                while state_clone.load(Ordering::Relaxed) == STATE_RUNNING {
                    tokio::select! {
                        _ = child.cancelled() => {
                            break;
                        }
                        req_opt = receiver.recv() => {
                            if let Ok(job) = req_opt {
                                job.await;
                            }
                        }
                    }
                }
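                // `stop(true)` sets STATE_AWAIT_TERMINATION: drain the queued
                // jobs before marking this worker as stopped.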
                if state_clone.load(Ordering::Relaxed) == STATE_AWAIT_TERMINATION {
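                    // The sender has been dropped by `stop`, so `recv` returns
                    // `Err` once the channel is empty.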
                    while let Ok(job) = receiver.recv().await {
                        job.await;
                    }
                    state_clone.store(STATE_STOP, Ordering::Relaxed);
                }
            });
            handles.push(handle);
        }

        Self {
            sender: RwLock::new(Some(tx)),
            cancel_token: token,
            handles: Mutex::new(handles),
            state,
        }
    }

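    /// Returns whether the scheduler is still in the running state.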
    #[inline]
    fn is_running(&self) -> bool {
        self.state.load(Ordering::Relaxed) == STATE_RUNNING
    }
}

#[async_trait::async_trait]
impl Scheduler for LocalScheduler {
    fn schedule(&self, job: Job) -> Result<()> {
        ensure!(self.is_running(), InvalidSchedulerStateSnafu);

        self.sender
            .read()
            .unwrap()
            .as_ref()
            .context(InvalidSchedulerStateSnafu)?
            .try_send(job)
            .map_err(|_| InvalidSenderSnafu {}.build())
    }

    async fn stop(&self, await_termination: bool) -> Result<()> {
        ensure!(self.is_running(), InvalidSchedulerStateSnafu);
        let state = if await_termination {
            STATE_AWAIT_TERMINATION
        } else {
            STATE_STOP
        };
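        // Drop the sender first so draining workers observe the channel
        // closing, then publish the new state and cancel the workers.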
        self.sender.write().unwrap().take();
        self.state.store(state, Ordering::Relaxed);
        self.cancel_token.cancel();

        futures::future::join_all(self.handles.lock().await.drain(..))
            .await
            .into_iter()
            .collect::<std::result::Result<Vec<_>, _>>()
            .context(StopSchedulerSnafu)?;

        Ok(())
    }
}

impl Drop for LocalScheduler {
    fn drop(&mut self) {
        if self.state.load(Ordering::Relaxed) != STATE_STOP {
            warn!("scheduler is dropped before reaching STATE_STOP; call `stop` before dropping it");

            self.sender.write().unwrap().take();
            self.cancel_token.cancel();
        }
    }
}

#[cfg(test)]
mod tests {
    use std::sync::atomic::AtomicI32;
    use std::sync::Arc;

    use tokio::sync::Barrier;
    use tokio::time::Duration;

    use super::*;

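    // With one worker per task, every job scheduled before `stop(true)`
    // should have run by the time `stop` returns.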
    #[tokio::test]
    async fn test_sum_cap() {
        let task_size = 1000;
        let sum = Arc::new(AtomicI32::new(0));
        let local = LocalScheduler::new(task_size);

        for _ in 0..task_size {
            let sum_clone = sum.clone();
            local
                .schedule(Box::pin(async move {
                    sum_clone.fetch_add(1, Ordering::Relaxed);
                }))
                .unwrap();
        }
        local.stop(true).await.unwrap();
        assert_eq!(sum.load(Ordering::Relaxed), 1000);
    }

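    // A small worker pool must still run every accepted job exactly once.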
    #[tokio::test]
    async fn test_sum_consumer_num() {
        let task_size = 1000;
        let sum = Arc::new(AtomicI32::new(0));
        let local = LocalScheduler::new(3);
        let mut target = 0;
        for _ in 0..task_size {
            let sum_clone = sum.clone();
            let ok = local
                .schedule(Box::pin(async move {
                    sum_clone.fetch_add(1, Ordering::Relaxed);
                }))
                .is_ok();
            if ok {
                target += 1;
            }
        }
        local.stop(true).await.unwrap();
        assert_eq!(sum.load(Ordering::Relaxed), target);
    }

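    // All jobs must make progress concurrently: the barrier only opens once
    // every scheduled job (plus the test itself) is waiting on it.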
    #[tokio::test]
    async fn test_scheduler_many() {
        let task_size = 1000;

        let barrier = Arc::new(Barrier::new(task_size + 1));
        let local = LocalScheduler::new(task_size);

        for _ in 0..task_size {
            let barrier_clone = barrier.clone();
            local
                .schedule(Box::pin(async move {
                    barrier_clone.wait().await;
                }))
                .unwrap();
        }
        barrier.wait().await;
        local.stop(true).await.unwrap();
    }

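    // Stopping while another task keeps scheduling: every job accepted
    // before the stop must still run, so `sum` should equal `target`.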
    #[tokio::test]
    async fn test_scheduler_continuous_stop() {
        let sum = Arc::new(AtomicI32::new(0));
        let local = Arc::new(LocalScheduler::new(1000));

        let barrier = Arc::new(Barrier::new(2));
        let barrier_clone = barrier.clone();
        let local_stop = local.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(5)).await;
            local_stop.stop(true).await.unwrap();
            barrier_clone.wait().await;
        });

        let target = Arc::new(AtomicI32::new(0));
        let local_task = local.clone();
        let target_clone = target.clone();
        let sum_clone = sum.clone();
        tokio::spawn(async move {
            loop {
                let sum_c = sum_clone.clone();
                let ok = local_task
                    .schedule(Box::pin(async move {
                        sum_c.fetch_add(1, Ordering::Relaxed);
                    }))
                    .is_ok();
                if ok {
                    target_clone.fetch_add(1, Ordering::Relaxed);
                } else {
                    break;
                }
                tokio::task::yield_now().await;
            }
        });
        barrier.wait().await;
        assert_eq!(sum.load(Ordering::Relaxed), target.load(Ordering::Relaxed));
    }
}