common_runtime/
repeated_task.rs1use std::sync::atomic::{AtomicBool, Ordering};
16use std::sync::Mutex;
17use std::time::Duration;
18
19use common_error::ext::ErrorExt;
20use common_telemetry::{debug, error};
21use snafu::{ensure, ResultExt};
22use tokio::task::JoinHandle;
23use tokio_util::sync::CancellationToken;
24
25use crate::error::{IllegalStateSnafu, Result, WaitGcTaskStopSnafu};
26use crate::runtime::RuntimeTrait;
27use crate::Runtime;
28
29#[async_trait::async_trait]
31pub trait TaskFunction<E> {
32 async fn call(&mut self) -> std::result::Result<(), E>;
34
35 fn name(&self) -> &str;
37}
38
39pub type BoxedTaskFunction<E> = Box<dyn TaskFunction<E> + Send + Sync + 'static>;
40
41struct TaskInner<E> {
42 task_handle: Option<JoinHandle<()>>,
44
45 task_fn: Option<BoxedTaskFunction<E>>,
47}
48
49pub struct RepeatedTask<E> {
50 name: String,
51 cancel_token: CancellationToken,
52 inner: Mutex<TaskInner<E>>,
53 started: AtomicBool,
54 interval: Duration,
55 initial_delay: Option<Duration>,
56}
57
58impl<E> std::fmt::Display for RepeatedTask<E> {
59 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60 write!(f, "RepeatedTask({})", self.name)
61 }
62}
63
64impl<E> std::fmt::Debug for RepeatedTask<E> {
65 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66 f.debug_tuple("RepeatedTask").field(&self.name).finish()
67 }
68}
69
70impl<E> Drop for RepeatedTask<E> {
71 fn drop(&mut self) {
72 let inner = self.inner.get_mut().unwrap();
73 if inner.task_handle.is_some() {
74 self.cancel_token.cancel();
76 }
77 }
78}
79
80impl<E: ErrorExt + 'static> RepeatedTask<E> {
81 pub fn new(interval: Duration, task_fn: BoxedTaskFunction<E>) -> Self {
85 Self {
86 name: task_fn.name().to_string(),
87 cancel_token: CancellationToken::new(),
88 inner: Mutex::new(TaskInner {
89 task_handle: None,
90 task_fn: Some(task_fn),
91 }),
92 started: AtomicBool::new(false),
93 interval,
94 initial_delay: None,
95 }
96 }
97
98 pub fn with_initial_delay(mut self, initial_delay: Option<Duration>) -> Self {
99 self.initial_delay = initial_delay;
100 self
101 }
102
103 pub fn started(&self) -> bool {
104 self.started.load(Ordering::Relaxed)
105 }
106
107 pub fn start(&self, runtime: Runtime) -> Result<()> {
108 let mut inner = self.inner.lock().unwrap();
109 ensure!(
110 inner.task_fn.is_some(),
111 IllegalStateSnafu { name: &self.name }
112 );
113
114 let child = self.cancel_token.child_token();
115 let mut task_fn = inner.task_fn.take().unwrap();
117 let interval = self.interval;
118 let mut initial_delay = self.initial_delay;
119 let handle = runtime.spawn(async move {
121 loop {
122 let sleep_time = initial_delay.take().unwrap_or(interval);
123 if sleep_time > Duration::ZERO {
124 tokio::select! {
125 _ = tokio::time::sleep(sleep_time) => {}
126 _ = child.cancelled() => {
127 return;
128 }
129 }
130 }
131 if let Err(e) = task_fn.call().await {
132 error!(e; "Failed to run repeated task: {}", task_fn.name());
133 }
134 }
135 });
136 inner.task_handle = Some(handle);
137 self.started.store(true, Ordering::Relaxed);
138
139 debug!(
140 "Repeated task {} started with interval: {:?}",
141 self.name, self.interval
142 );
143
144 Ok(())
145 }
146
147 pub async fn stop(&self) -> Result<()> {
148 let handle = {
149 let mut inner = self.inner.lock().unwrap();
150 if inner.task_handle.is_none() {
151 return Ok(());
153 }
154
155 self.cancel_token.cancel();
156 self.started.store(false, Ordering::Relaxed);
157 inner.task_handle.take().unwrap()
159 };
160
161 handle
162 .await
163 .context(WaitGcTaskStopSnafu { name: &self.name })?;
164
165 debug!("Repeated task {} stopped", self.name);
166
167 Ok(())
168 }
169}
170
#[cfg(test)]
mod tests {
    use std::sync::atomic::AtomicI32;
    use std::sync::Arc;

    use super::*;
    use crate::error::Error;

    /// Task function that counts how many times it has been invoked.
    struct TickTask {
        counter: Arc<AtomicI32>,
    }

    #[async_trait::async_trait]
    impl TaskFunction<Error> for TickTask {
        fn name(&self) -> &str {
            "test"
        }

        async fn call(&mut self) -> Result<()> {
            let _ = self.counter.fetch_add(1, Ordering::Relaxed);
            Ok(())
        }
    }

    /// Starts `task` on the global runtime, lets it run for 550ms, stops it, and
    /// returns the tick count read from `counter`.
    async fn ticks_after_550ms(task: &RepeatedTask<Error>, counter: &AtomicI32) -> i32 {
        task.start(crate::global_runtime()).unwrap();
        tokio::time::sleep(Duration::from_millis(550)).await;
        task.stop().await.unwrap();
        counter.load(Ordering::Relaxed)
    }

    #[tokio::test]
    async fn test_repeated_task() {
        common_telemetry::init_default_ut_logging();

        let counter = Arc::new(AtomicI32::new(0));
        let tick = TickTask {
            counter: Arc::clone(&counter),
        };
        let task = RepeatedTask::new(Duration::from_millis(100), Box::new(tick));

        // 100ms interval and the first run after one interval: at least 3 ticks.
        assert!(ticks_after_550ms(&task, &counter).await >= 3);
    }

    #[tokio::test]
    async fn test_repeated_task_prior_exec() {
        common_telemetry::init_default_ut_logging();

        let counter = Arc::new(AtomicI32::new(0));
        let tick = TickTask {
            counter: Arc::clone(&counter),
        };
        let task = RepeatedTask::new(Duration::from_millis(100), Box::new(tick))
            .with_initial_delay(Some(Duration::ZERO));

        // A zero initial delay adds an immediate run on top of the interval runs.
        assert!(ticks_after_550ms(&task, &counter).await >= 4);
    }
}