mito2/schedule/scheduler.rs

use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicU8, Ordering};
use std::sync::{Arc, RwLock};

use common_telemetry::warn;
use snafu::{OptionExt, ResultExt, ensure};
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;

use crate::error::{InvalidSchedulerStateSnafu, InvalidSenderSnafu, Result, StopSchedulerSnafu};

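/// A boxed and pinned future that a scheduler runs to completion.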
pub type Job = Pin<Box<dyn Future<Output = ()> + Send>>;

/// The scheduler is accepting and running jobs.
const STATE_RUNNING: u8 = 0;
/// The scheduler is stopped; pending jobs are dropped.
const STATE_STOP: u8 = 1;
/// The scheduler is draining pending jobs before it stops.
const STATE_AWAIT_TERMINATION: u8 = 2;

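/// A handle to submit [Job]s and stop the workers that run them.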
#[async_trait::async_trait]
pub trait Scheduler: Send + Sync {
    /// Schedules a job to run in the background, returning an error if the
    /// scheduler is not running or the job queue is unavailable.
    fn schedule(&self, job: Job) -> Result<()>;

    /// Stops the scheduler. If `await_termination` is `true`, waits until all
    /// queued jobs have finished; otherwise pending jobs are discarded.
    async fn stop(&self, await_termination: bool) -> Result<()>;
}

pub type SchedulerRef = Arc<dyn Scheduler>;

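/// A [Scheduler] that runs jobs on a fixed-size pool of local worker tasks.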
pub struct LocalScheduler {
    /// Sending side of the job queue; `None` once the scheduler is stopped.
    sender: RwLock<Option<async_channel::Sender<Job>>>,
    /// Handles of the spawned worker tasks.
    handles: Mutex<Vec<JoinHandle<()>>>,
    /// Token used to notify workers to stop.
    cancel_token: CancellationToken,
    /// Current state of the scheduler, shared with every worker.
    state: Arc<AtomicU8>,
}

impl LocalScheduler {
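    /// Creates a new scheduler that executes jobs on `concurrency` worker
    /// tasks spawned on the global runtime.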
    pub fn new(concurrency: usize) -> Self {
        let (tx, rx) = async_channel::unbounded();
        let token = CancellationToken::new();
        let state = Arc::new(AtomicU8::new(STATE_RUNNING));

        let mut handles = Vec::with_capacity(concurrency);

        for _ in 0..concurrency {
            let child = token.child_token();
            let receiver = rx.clone();
            let state_clone = state.clone();
            // Each worker pulls jobs off the shared channel until it is
            // cancelled or the scheduler leaves the running state.
            let handle = common_runtime::spawn_global(async move {
                while state_clone.load(Ordering::Relaxed) == STATE_RUNNING {
                    tokio::select! {
                        _ = child.cancelled() => {
                            break;
                        }
                        req_opt = receiver.recv() => {
                            if let Ok(job) = req_opt {
                                job.await;
                            }
                        }
                    }
                }
                // When stopping with `await_termination`, drain the remaining
                // jobs before marking the scheduler fully stopped.
                if state_clone.load(Ordering::Relaxed) == STATE_AWAIT_TERMINATION {
                    while let Ok(job) = receiver.recv().await {
                        job.await;
                    }
                    state_clone.store(STATE_STOP, Ordering::Relaxed);
                }
            });
            handles.push(handle);
        }

        Self {
            sender: RwLock::new(Some(tx)),
            cancel_token: token,
            handles: Mutex::new(handles),
            state,
        }
    }

    /// Returns `true` while the scheduler is accepting new jobs.
    #[inline]
    fn is_running(&self) -> bool {
        self.state.load(Ordering::Relaxed) == STATE_RUNNING
    }
}

#[async_trait::async_trait]
impl Scheduler for LocalScheduler {
    fn schedule(&self, job: Job) -> Result<()> {
        ensure!(self.is_running(), InvalidSchedulerStateSnafu);

        self.sender
            .read()
            .unwrap()
            .as_ref()
            .context(InvalidSchedulerStateSnafu)?
            .try_send(job)
            .map_err(|_| InvalidSenderSnafu {}.build())
    }

    async fn stop(&self, await_termination: bool) -> Result<()> {
        ensure!(self.is_running(), InvalidSchedulerStateSnafu);
        let state = if await_termination {
            STATE_AWAIT_TERMINATION
        } else {
            STATE_STOP
        };
        // Close the job queue, publish the new state, and cancel the workers
        // so they exit their receive loops.
        self.sender.write().unwrap().take();
        self.state.store(state, Ordering::Relaxed);
        self.cancel_token.cancel();

        // Wait for every worker task to finish before reporting success.
        futures::future::join_all(self.handles.lock().await.drain(..))
            .await
            .into_iter()
            .collect::<std::result::Result<Vec<_>, _>>()
            .context(StopSchedulerSnafu)?;

        Ok(())
    }
}

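// Last-resort cleanup if the scheduler is dropped while still running.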
impl Drop for LocalScheduler {
    fn drop(&mut self) {
        if self.state.load(Ordering::Relaxed) != STATE_STOP {
            warn!("scheduler dropped without being stopped; its state should be STATE_STOP before dropping");

            // Fall back to closing the channel and cancelling the workers.
            self.sender.write().unwrap().take();
            self.cancel_token.cancel();
        }
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::AtomicI32;

    use tokio::sync::Barrier;
    use tokio::time::Duration;

    use super::*;

    // Every scheduled job should run when stopping with `await_termination`.
    #[tokio::test]
    async fn test_sum_cap() {
        let task_size = 1000;
        let sum = Arc::new(AtomicI32::new(0));
        let local = LocalScheduler::new(task_size);

        for _ in 0..task_size {
            let sum_clone = sum.clone();
            local
                .schedule(Box::pin(async move {
                    sum_clone.fetch_add(1, Ordering::Relaxed);
                }))
                .unwrap();
        }
        local.stop(true).await.unwrap();
        assert_eq!(sum.load(Ordering::Relaxed), 1000);
    }

    // With far fewer workers than jobs, every successfully scheduled job
    // should still run to completion.
    #[tokio::test]
    async fn test_sum_consumer_num() {
        let task_size = 1000;
        let sum = Arc::new(AtomicI32::new(0));
        let local = LocalScheduler::new(3);
        let mut target = 0;
        for _ in 0..task_size {
            let sum_clone = sum.clone();
            let ok = local
                .schedule(Box::pin(async move {
                    sum_clone.fetch_add(1, Ordering::Relaxed);
                }))
                .is_ok();
            if ok {
                target += 1;
            }
        }
        local.stop(true).await.unwrap();
        assert_eq!(sum.load(Ordering::Relaxed), target);
    }

    // Scheduled jobs run concurrently: each blocks on a shared barrier that
    // only opens once all of them (plus the test itself) have arrived.
    #[tokio::test]
    async fn test_scheduler_many() {
        let task_size = 1000;

        let barrier = Arc::new(Barrier::new(task_size + 1));
        let local: LocalScheduler = LocalScheduler::new(task_size);

        for _ in 0..task_size {
            let barrier_clone = barrier.clone();
            local
                .schedule(Box::pin(async move {
                    barrier_clone.wait().await;
                }))
                .unwrap();
        }
        barrier.wait().await;
        local.stop(true).await.unwrap();
    }

    // Stop the scheduler while another task keeps scheduling: every job
    // accepted before the stop must still run, so the counters end up equal.
    #[tokio::test]
    async fn test_scheduler_continuous_stop() {
        let sum = Arc::new(AtomicI32::new(0));
        let local = Arc::new(LocalScheduler::new(1000));

        let barrier = Arc::new(Barrier::new(2));
        let barrier_clone = barrier.clone();
        let local_stop = local.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(5)).await;
            local_stop.stop(true).await.unwrap();
            barrier_clone.wait().await;
        });

        let target = Arc::new(AtomicI32::new(0));
        let local_task = local.clone();
        let target_clone = target.clone();
        let sum_clone = sum.clone();
        tokio::spawn(async move {
            loop {
                let sum_c = sum_clone.clone();
                let ok = local_task
                    .schedule(Box::pin(async move {
                        sum_c.fetch_add(1, Ordering::Relaxed);
                    }))
                    .is_ok();
                if ok {
                    target_clone.fetch_add(1, Ordering::Relaxed);
                } else {
                    break;
                }
                tokio::task::yield_now().await;
            }
        });
        barrier.wait().await;
        assert_eq!(sum.load(Ordering::Relaxed), target.load(Ordering::Relaxed));
    }
}