common_runtime/
repeated_task.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::atomic::{AtomicBool, Ordering};
16use std::sync::Mutex;
17use std::time::Duration;
18
19use common_error::ext::ErrorExt;
20use common_telemetry::{debug, error};
21use snafu::{ensure, ResultExt};
22use tokio::task::JoinHandle;
23use tokio_util::sync::CancellationToken;
24
25use crate::error::{IllegalStateSnafu, Result, WaitGcTaskStopSnafu};
26use crate::runtime::RuntimeTrait;
27use crate::Runtime;
28
29/// Task to execute repeatedly.
30#[async_trait::async_trait]
31pub trait TaskFunction<E> {
32    /// Invoke the task.
33    async fn call(&mut self) -> std::result::Result<(), E>;
34
35    /// Name of the task.
36    fn name(&self) -> &str;
37}
38
39pub type BoxedTaskFunction<E> = Box<dyn TaskFunction<E> + Send + Sync + 'static>;
40
41struct TaskInner<E> {
42    /// The repeated task handle. This handle is Some if the task is started.
43    task_handle: Option<JoinHandle<()>>,
44
45    /// The task_fn to run. This is Some if the task is not started.
46    task_fn: Option<BoxedTaskFunction<E>>,
47}
48
49pub struct RepeatedTask<E> {
50    name: String,
51    cancel_token: CancellationToken,
52    inner: Mutex<TaskInner<E>>,
53    started: AtomicBool,
54    interval: Duration,
55    initial_delay: Option<Duration>,
56}
57
58impl<E> std::fmt::Display for RepeatedTask<E> {
59    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60        write!(f, "RepeatedTask({})", self.name)
61    }
62}
63
64impl<E> std::fmt::Debug for RepeatedTask<E> {
65    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66        f.debug_tuple("RepeatedTask").field(&self.name).finish()
67    }
68}
69
70impl<E> Drop for RepeatedTask<E> {
71    fn drop(&mut self) {
72        let inner = self.inner.get_mut().unwrap();
73        if inner.task_handle.is_some() {
74            // Cancel the background task.
75            self.cancel_token.cancel();
76        }
77    }
78}
79
80impl<E: ErrorExt + 'static> RepeatedTask<E> {
81    /// Creates a new repeated task. The `initial_delay` is the delay before the first execution.
82    /// `initial_delay` default is None, the initial interval uses the `interval`.
83    /// You can use `with_initial_delay` to set the `initial_delay`.
84    pub fn new(interval: Duration, task_fn: BoxedTaskFunction<E>) -> Self {
85        Self {
86            name: task_fn.name().to_string(),
87            cancel_token: CancellationToken::new(),
88            inner: Mutex::new(TaskInner {
89                task_handle: None,
90                task_fn: Some(task_fn),
91            }),
92            started: AtomicBool::new(false),
93            interval,
94            initial_delay: None,
95        }
96    }
97
98    pub fn with_initial_delay(mut self, initial_delay: Option<Duration>) -> Self {
99        self.initial_delay = initial_delay;
100        self
101    }
102
103    pub fn started(&self) -> bool {
104        self.started.load(Ordering::Relaxed)
105    }
106
107    pub fn start(&self, runtime: Runtime) -> Result<()> {
108        let mut inner = self.inner.lock().unwrap();
109        ensure!(
110            inner.task_fn.is_some(),
111            IllegalStateSnafu { name: &self.name }
112        );
113
114        let child = self.cancel_token.child_token();
115        // Safety: The task is not started.
116        let mut task_fn = inner.task_fn.take().unwrap();
117        let interval = self.interval;
118        let mut initial_delay = self.initial_delay;
119        // TODO(hl): Maybe spawn to a blocking runtime.
120        let handle = runtime.spawn(async move {
121            loop {
122                let sleep_time = initial_delay.take().unwrap_or(interval);
123                if sleep_time > Duration::ZERO {
124                    tokio::select! {
125                        _ = tokio::time::sleep(sleep_time) => {}
126                        _ = child.cancelled() => {
127                            return;
128                        }
129                    }
130                }
131                if let Err(e) = task_fn.call().await {
132                    error!(e; "Failed to run repeated task: {}", task_fn.name());
133                }
134            }
135        });
136        inner.task_handle = Some(handle);
137        self.started.store(true, Ordering::Relaxed);
138
139        debug!(
140            "Repeated task {} started with interval: {:?}",
141            self.name, self.interval
142        );
143
144        Ok(())
145    }
146
147    pub async fn stop(&self) -> Result<()> {
148        let handle = {
149            let mut inner = self.inner.lock().unwrap();
150            if inner.task_handle.is_none() {
151                // We allow stop the task multiple times.
152                return Ok(());
153            }
154
155            self.cancel_token.cancel();
156            self.started.store(false, Ordering::Relaxed);
157            // Safety: The task is not stopped.
158            inner.task_handle.take().unwrap()
159        };
160
161        handle
162            .await
163            .context(WaitGcTaskStopSnafu { name: &self.name })?;
164
165        debug!("Repeated task {} stopped", self.name);
166
167        Ok(())
168    }
169}
170
#[cfg(test)]
mod tests {
    use std::sync::atomic::AtomicI32;
    use std::sync::Arc;

    use super::*;
    use crate::error::Error;

    /// Task function that increments a shared counter each time it is called.
    struct TickTask {
        n: Arc<AtomicI32>,
    }

    #[async_trait::async_trait]
    impl TaskFunction<Error> for TickTask {
        async fn call(&mut self) -> Result<()> {
            let _ = self.n.fetch_add(1, Ordering::Relaxed);
            Ok(())
        }

        fn name(&self) -> &str {
            "test"
        }
    }

    /// Builds a task ticking every 100ms with the given initial delay.
    fn tick_task(n: Arc<AtomicI32>, initial_delay: Option<Duration>) -> RepeatedTask<Error> {
        RepeatedTask::new(Duration::from_millis(100), Box::new(TickTask { n }))
            .with_initial_delay(initial_delay)
    }

    /// Starts `task` on the global runtime, lets it run for 550ms, stops it and
    /// returns how many times the task function ran.
    async fn run_and_count(task: &RepeatedTask<Error>, counter: &AtomicI32) -> i32 {
        task.start(crate::global_runtime()).unwrap();
        tokio::time::sleep(Duration::from_millis(550)).await;
        task.stop().await.unwrap();
        counter.load(Ordering::Relaxed)
    }

    #[tokio::test]
    async fn test_repeated_task() {
        common_telemetry::init_default_ut_logging();

        let counter = Arc::new(AtomicI32::new(0));
        let task = tick_task(Arc::clone(&counter), None);

        assert!(run_and_count(&task, &counter).await >= 3);
    }

    #[tokio::test]
    async fn test_repeated_task_prior_exec() {
        common_telemetry::init_default_ut_logging();

        let counter = Arc::new(AtomicI32::new(0));
        // A zero initial delay runs the task once right after start.
        let task = tick_task(Arc::clone(&counter), Some(Duration::ZERO));

        assert!(run_and_count(&task, &counter).await >= 4);
    }
}