// mito2/schedule/scheduler.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::future::Future;
16use std::pin::Pin;
17use std::sync::atomic::{AtomicU8, Ordering};
18use std::sync::{Arc, RwLock};
19
20use common_telemetry::warn;
21use snafu::{OptionExt, ResultExt, ensure};
22use tokio::sync::Mutex;
23use tokio::task::JoinHandle;
24use tokio_util::sync::CancellationToken;
25
26use crate::error::{InvalidSchedulerStateSnafu, InvalidSenderSnafu, Result, StopSchedulerSnafu};
27
/// A unit of schedulable work: a pinned, boxed future that produces no output.
pub type Job = Pin<Box<dyn Future<Output = ()> + Send>>;

/// The state of scheduler: accepting and executing jobs.
const STATE_RUNNING: u8 = 0;
/// The state of scheduler: stopped; new jobs are rejected.
const STATE_STOP: u8 = 1;
/// The state of scheduler: stop was requested with `await_termination`;
/// workers drain the remaining queued jobs before moving to [STATE_STOP].
const STATE_AWAIT_TERMINATION: u8 = 2;
34
/// [Scheduler] defines a set of API to schedule [Job]s.
#[async_trait::async_trait]
pub trait Scheduler: Send + Sync {
    /// Schedules a [Job] for asynchronous execution.
    ///
    /// Returns an error if the scheduler is no longer running or the job
    /// cannot be handed to a worker.
    fn schedule(&self, job: Job) -> Result<()>;

    /// Stops scheduler. If `await_termination` is set to true, the scheduler will wait until all tasks are processed.
    async fn stop(&self, await_termination: bool) -> Result<()>;
}

/// A shared, dynamically dispatched [Scheduler] handle.
pub type SchedulerRef = Arc<dyn Scheduler>;
46
/// Request scheduler based on local state.
pub struct LocalScheduler {
    /// Sends jobs to the workers over an unbounded `async_channel`;
    /// `None` once the scheduler has been stopped (taking/dropping the sender
    /// closes the channel so draining workers can exit).
    sender: RwLock<Option<async_channel::Sender<Job>>>,
    /// Join handles of the spawned worker tasks, awaited in `stop`.
    handles: Mutex<Vec<JoinHandle<()>>>,
    /// Token used to halt the scheduler's workers.
    cancel_token: CancellationToken,
    /// State of scheduler (one of the `STATE_*` constants), shared with workers.
    state: Arc<AtomicU8>,
}
58
impl LocalScheduler {
    /// Starts a new scheduler with `concurrency` background worker tasks.
    ///
    /// Each worker pulls jobs from a shared channel and runs them one at a
    /// time, so overall parallelism is bounded by `concurrency`.
    pub fn new(concurrency: usize) -> Self {
        // The channel is unbounded: `schedule` never blocks or rejects on a
        // full queue; there is no back-pressure here.
        let (tx, rx) = async_channel::unbounded();
        let token = CancellationToken::new();
        let state = Arc::new(AtomicU8::new(STATE_RUNNING));

        let mut handles = Vec::with_capacity(concurrency);

        for _ in 0..concurrency {
            let child = token.child_token();
            let receiver = rx.clone();
            let state_clone = state.clone();
            let handle = common_runtime::spawn_global(async move {
                // Run jobs until the scheduler leaves STATE_RUNNING or the
                // cancellation token fires.
                while state_clone.load(Ordering::Relaxed) == STATE_RUNNING {
                    tokio::select! {
                        _ = child.cancelled() => {
                            break;
                        }
                        req_opt = receiver.recv() =>{
                            if let Ok(job) = req_opt {
                                job.await;
                            }
                        }
                    }
                }
                // On graceful stop (`stop(true)`), drain every remaining job
                // before marking the scheduler fully stopped.
                if state_clone.load(Ordering::Relaxed) == STATE_AWAIT_TERMINATION {
                    // `recv()` keeps yielding queued jobs and only returns Err
                    // once every sender has been dropped (`stop` takes it) and
                    // the channel is empty.
                    while let Ok(job) = receiver.recv().await {
                        job.await;
                    }
                    state_clone.store(STATE_STOP, Ordering::Relaxed);
                }
            });
            handles.push(handle);
        }

        Self {
            sender: RwLock::new(Some(tx)),
            cancel_token: token,
            handles: Mutex::new(handles),
            state,
        }
    }

    /// Returns true while the scheduler accepts new jobs, i.e. its state is
    /// [STATE_RUNNING].
    #[inline]
    fn is_running(&self) -> bool {
        self.state.load(Ordering::Relaxed) == STATE_RUNNING
    }
}
112
#[async_trait::async_trait]
impl Scheduler for LocalScheduler {
    /// Submits `job` to the worker pool.
    ///
    /// Fails with an invalid-state error when the scheduler is stopped (or a
    /// concurrent `stop` has already taken the sender), and with an invalid
    /// sender error when the channel rejects the job.
    fn schedule(&self, job: Job) -> Result<()> {
        ensure!(self.is_running(), InvalidSchedulerStateSnafu);

        self.sender
            .read()
            .unwrap()
            .as_ref()
            // `stop` takes the sender, so a racing stop surfaces here.
            .context(InvalidSchedulerStateSnafu)?
            .try_send(job)
            // The channel is unbounded, so try_send should only fail once the
            // channel has been closed.
            .map_err(|_| InvalidSenderSnafu {}.build())
    }

    /// If `await_termination` is true, the scheduler waits for all queued
    /// tasks to finish before stopping.
    async fn stop(&self, await_termination: bool) -> Result<()> {
        ensure!(self.is_running(), InvalidSchedulerStateSnafu);
        let state = if await_termination {
            STATE_AWAIT_TERMINATION
        } else {
            STATE_STOP
        };
        // Order matters: taking (dropping) the sender closes the channel so
        // draining workers eventually see it empty; the state store and the
        // cancellation then wake workers blocked in their select loops.
        self.sender.write().unwrap().take();
        self.state.store(state, Ordering::Relaxed);
        self.cancel_token.cancel();

        // Wait for every worker task to exit; a failed join is reported as a
        // StopScheduler error.
        futures::future::join_all(self.handles.lock().await.drain(..))
            .await
            .into_iter()
            .collect::<std::result::Result<Vec<_>, _>>()
            .context(StopSchedulerSnafu)?;

        Ok(())
    }
}
148
149impl Drop for LocalScheduler {
150    fn drop(&mut self) {
151        if self.state.load(Ordering::Relaxed) != STATE_STOP {
152            warn!(
153                "scheduler should be stopped before dropping, which means the state of scheduler must be STATE_STOP"
154            );
155
156            // We didn't call `stop()` so we cancel all background workers here.
157            self.sender.write().unwrap().take();
158            self.cancel_token.cancel();
159        }
160    }
161}
162
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::AtomicI32;

    use tokio::sync::Barrier;
    use tokio::time::Duration;

    use super::*;

    /// With as many workers as jobs, every scheduled job runs exactly once.
    #[tokio::test]
    async fn test_sum_cap() {
        let job_count = 1000;
        let counter = Arc::new(AtomicI32::new(0));
        let scheduler = LocalScheduler::new(job_count);

        for _ in 0..job_count {
            let counter = counter.clone();
            let job: Job = Box::pin(async move {
                counter.fetch_add(1, Ordering::Relaxed);
            });
            scheduler.schedule(job).unwrap();
        }

        // Graceful stop drains every queued job before returning.
        scheduler.stop(true).await.unwrap();
        assert_eq!(counter.load(Ordering::Relaxed), 1000);
    }

    /// A small worker pool still executes every successfully scheduled job.
    #[tokio::test]
    async fn test_sum_consumer_num() {
        let job_count = 1000;
        let counter = Arc::new(AtomicI32::new(0));
        let scheduler = LocalScheduler::new(3);

        let mut accepted = 0;
        for _ in 0..job_count {
            let counter = counter.clone();
            let job: Job = Box::pin(async move {
                counter.fetch_add(1, Ordering::Relaxed);
            });
            if scheduler.schedule(job).is_ok() {
                accepted += 1;
            }
        }

        scheduler.stop(true).await.unwrap();
        assert_eq!(counter.load(Ordering::Relaxed), accepted);
    }

    /// Many concurrent jobs can all rendezvous on a shared barrier.
    #[tokio::test]
    async fn test_scheduler_many() {
        let job_count = 1000;
        let barrier = Arc::new(Barrier::new(job_count + 1));
        let scheduler: LocalScheduler = LocalScheduler::new(job_count);

        for _ in 0..job_count {
            let barrier = barrier.clone();
            let job: Job = Box::pin(async move {
                barrier.wait().await;
            });
            scheduler.schedule(job).unwrap();
        }

        barrier.wait().await;
        scheduler.stop(true).await.unwrap();
    }

    /// Stop the scheduler while another task keeps scheduling: after
    /// `stop(true)` returns, every accepted job must have executed.
    #[tokio::test]
    async fn test_scheduler_continuous_stop() {
        let executed = Arc::new(AtomicI32::new(0));
        let scheduler = Arc::new(LocalScheduler::new(1000));

        let rendezvous = Arc::new(Barrier::new(2));
        let stop_rendezvous = rendezvous.clone();
        let stop_handle = scheduler.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(5)).await;
            stop_handle.stop(true).await.unwrap();
            stop_rendezvous.wait().await;
        });

        let accepted = Arc::new(AtomicI32::new(0));
        let producer = scheduler.clone();
        let producer_accepted = accepted.clone();
        let producer_executed = executed.clone();
        tokio::spawn(async move {
            loop {
                let executed = producer_executed.clone();
                let job: Job = Box::pin(async move {
                    executed.fetch_add(1, Ordering::Relaxed);
                });
                if producer.schedule(job).is_err() {
                    // The scheduler has been stopped; no more jobs accepted.
                    break;
                }
                producer_accepted.fetch_add(1, Ordering::Relaxed);
                tokio::task::yield_now().await;
            }
        });

        rendezvous.wait().await;
        assert_eq!(
            executed.load(Ordering::Relaxed),
            accepted.load(Ordering::Relaxed)
        );
    }
}
268}