mito2/schedule/scheduler.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicU8, Ordering};
use std::sync::{Arc, RwLock};

use common_telemetry::warn;
use snafu::{ensure, OptionExt, ResultExt};
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;

use crate::error::{InvalidSchedulerStateSnafu, InvalidSenderSnafu, Result, StopSchedulerSnafu};

pub type Job = Pin<Box<dyn Future<Output = ()> + Send>>;

/// The states of the scheduler.
const STATE_RUNNING: u8 = 0;
const STATE_STOP: u8 = 1;
const STATE_AWAIT_TERMINATION: u8 = 2;

/// [Scheduler] defines a set of APIs to schedule jobs.
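///
/// A minimal usage sketch via a [SchedulerRef] (illustrative only; the import
/// path below is an assumption and may differ):
///
/// ```ignore
/// use mito2::schedule::scheduler::{Job, SchedulerRef};
///
/// async fn submit(scheduler: SchedulerRef) {
///     // Any `Future<Output = ()> + Send` future can be boxed into a `Job`.
///     let job: Job = Box::pin(async {
///         // ... do some asynchronous work ...
///     });
///     scheduler.schedule(job).unwrap();
/// }
/// ```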
#[async_trait::async_trait]
pub trait Scheduler: Send + Sync {
    /// Schedules a job.
    fn schedule(&self, job: Job) -> Result<()>;

    /// Stops the scheduler. If `await_termination` is set to true, the scheduler
    /// waits until all scheduled jobs are processed.
    async fn stop(&self, await_termination: bool) -> Result<()>;
}

pub type SchedulerRef = Arc<dyn Scheduler>;

/// A [Scheduler] that runs jobs locally on a fixed number of worker tasks.
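///
/// A lifecycle sketch (illustrative only; the import path is an assumption):
///
/// ```ignore
/// use mito2::schedule::scheduler::{LocalScheduler, Scheduler};
///
/// async fn run() {
///     // Four worker tasks consume jobs from a shared unbounded channel.
///     let scheduler = LocalScheduler::new(4);
///     scheduler
///         .schedule(Box::pin(async {
///             // ... do some asynchronous work ...
///         }))
///         .unwrap();
///     // Wait for every queued job to finish, then shut the workers down.
///     scheduler.stop(true).await.unwrap();
/// }
/// ```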
pub struct LocalScheduler {
    /// Sends jobs to the worker tasks through an unbounded channel.
    sender: RwLock<Option<async_channel::Sender<Job>>>,
    /// Task handles
    handles: Mutex<Vec<JoinHandle<()>>>,
    /// Token used to halt the scheduler
    cancel_token: CancellationToken,
    /// State of the scheduler
    state: Arc<AtomicU8>,
}

impl LocalScheduler {
    /// Starts a new scheduler.
    ///
    /// `concurrency`: the number of worker tasks that consume jobs concurrently.
    pub fn new(concurrency: usize) -> Self {
        let (tx, rx) = async_channel::unbounded();
        let token = CancellationToken::new();
        let state = Arc::new(AtomicU8::new(STATE_RUNNING));

        let mut handles = Vec::with_capacity(concurrency);

        for _ in 0..concurrency {
            let child = token.child_token();
            let receiver = rx.clone();
            let state_clone = state.clone();
            let handle = common_runtime::spawn_global(async move {
                while state_clone.load(Ordering::Relaxed) == STATE_RUNNING {
                    tokio::select! {
                        _ = child.cancelled() => {
                            break;
                        }
                        req_opt = receiver.recv() => {
                            if let Ok(job) = req_opt {
                                job.await;
                            }
                        }
                    }
                }
                // When the scheduler is stopped with `await_termination`, wait for
                // all remaining jobs in the channel to finish.
                if state_clone.load(Ordering::Relaxed) == STATE_AWAIT_TERMINATION {
                    // `recv` keeps yielding jobs until the channel is closed and
                    // drained, i.e. until the sender has been dropped.
                    while let Ok(job) = receiver.recv().await {
                        job.await;
                    }
                    state_clone.store(STATE_STOP, Ordering::Relaxed);
                }
            });
            handles.push(handle);
        }

        Self {
            sender: RwLock::new(Some(tx)),
            cancel_token: token,
            handles: Mutex::new(handles),
            state,
        }
    }

    #[inline]
    fn is_running(&self) -> bool {
        self.state.load(Ordering::Relaxed) == STATE_RUNNING
    }
}

#[async_trait::async_trait]
impl Scheduler for LocalScheduler {
    fn schedule(&self, job: Job) -> Result<()> {
        ensure!(self.is_running(), InvalidSchedulerStateSnafu);

        self.sender
            .read()
            .unwrap()
            .as_ref()
            .context(InvalidSchedulerStateSnafu)?
            .try_send(job)
            .map_err(|_| InvalidSenderSnafu {}.build())
    }

    /// If `await_termination` is true, the scheduler waits for all scheduled jobs
    /// to finish before stopping.
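    ///
    /// A sketch of the two stop modes, given a running `scheduler` (illustrative only):
    ///
    /// ```ignore
    /// // Cancel the workers as soon as possible; queued jobs may be dropped.
    /// scheduler.stop(false).await?;
    /// // Or: drain and run every queued job before returning.
    /// scheduler.stop(true).await?;
    /// ```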
    async fn stop(&self, await_termination: bool) -> Result<()> {
        ensure!(self.is_running(), InvalidSchedulerStateSnafu);
        let state = if await_termination {
            STATE_AWAIT_TERMINATION
        } else {
            STATE_STOP
        };
        self.sender.write().unwrap().take();
        self.state.store(state, Ordering::Relaxed);
        self.cancel_token.cancel();

        futures::future::join_all(self.handles.lock().await.drain(..))
            .await
            .into_iter()
            .collect::<std::result::Result<Vec<_>, _>>()
            .context(StopSchedulerSnafu)?;

        Ok(())
    }
}

impl Drop for LocalScheduler {
    fn drop(&mut self) {
        if self.state.load(Ordering::Relaxed) != STATE_STOP {
            warn!("scheduler should be stopped before being dropped, i.e. its state must be STATE_STOP");

            // We didn't call `stop()`, so cancel all background workers here.
            self.sender.write().unwrap().take();
            self.cancel_token.cancel();
        }
    }
}

#[cfg(test)]
mod tests {
    use std::sync::atomic::AtomicI32;
    use std::sync::Arc;

    use tokio::sync::Barrier;
    use tokio::time::Duration;

    use super::*;

    #[tokio::test]
    async fn test_sum_cap() {
        let task_size = 1000;
        let sum = Arc::new(AtomicI32::new(0));
        let local = LocalScheduler::new(task_size);

        for _ in 0..task_size {
            let sum_clone = sum.clone();
            local
                .schedule(Box::pin(async move {
                    sum_clone.fetch_add(1, Ordering::Relaxed);
                }))
                .unwrap();
        }
        local.stop(true).await.unwrap();
        assert_eq!(sum.load(Ordering::Relaxed), 1000);
    }

    #[tokio::test]
    async fn test_sum_consumer_num() {
        let task_size = 1000;
        let sum = Arc::new(AtomicI32::new(0));
        let local = LocalScheduler::new(3);
        let mut target = 0;
        for _ in 0..task_size {
            let sum_clone = sum.clone();
            let ok = local
                .schedule(Box::pin(async move {
                    sum_clone.fetch_add(1, Ordering::Relaxed);
                }))
                .is_ok();
            if ok {
                target += 1;
            }
        }
        local.stop(true).await.unwrap();
        assert_eq!(sum.load(Ordering::Relaxed), target);
    }

    #[tokio::test]
    async fn test_scheduler_many() {
        let task_size = 1000;

        let barrier = Arc::new(Barrier::new(task_size + 1));
        let local: LocalScheduler = LocalScheduler::new(task_size);

        for _ in 0..task_size {
            let barrier_clone = barrier.clone();
            local
                .schedule(Box::pin(async move {
                    barrier_clone.wait().await;
                }))
                .unwrap();
        }
        barrier.wait().await;
        local.stop(true).await.unwrap();
    }

    #[tokio::test]
    async fn test_scheduler_continuous_stop() {
        let sum = Arc::new(AtomicI32::new(0));
        let local = Arc::new(LocalScheduler::new(1000));

        let barrier = Arc::new(Barrier::new(2));
        let barrier_clone = barrier.clone();
        let local_stop = local.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(5)).await;
            local_stop.stop(true).await.unwrap();
            barrier_clone.wait().await;
        });

        let target = Arc::new(AtomicI32::new(0));
        let local_task = local.clone();
        let target_clone = target.clone();
        let sum_clone = sum.clone();
        tokio::spawn(async move {
            loop {
                let sum_c = sum_clone.clone();
                let ok = local_task
                    .schedule(Box::pin(async move {
                        sum_c.fetch_add(1, Ordering::Relaxed);
                    }))
                    .is_ok();
                if ok {
                    target_clone.fetch_add(1, Ordering::Relaxed);
                } else {
                    break;
                }
                tokio::task::yield_now().await;
            }
        });
        barrier.wait().await;
        assert_eq!(sum.load(Ordering::Relaxed), target.load(Ordering::Relaxed));
    }
}