datanode/heartbeat/
task_tracker.rs1use std::collections::HashMap;
16use std::sync::Arc;
17use std::time::Duration;
18
19use futures_util::future::BoxFuture;
20use snafu::ResultExt;
21use store_api::storage::RegionId;
22use tokio::sync::watch::{self, Receiver};
23use tokio::sync::RwLock;
24
25use crate::error::{self, Error, Result};
26
27#[derive(Debug, Default, Clone)]
29pub(crate) enum TaskState<T: Send + Sync + Clone> {
30 Error(Arc<Error>),
31 #[default]
32 Running,
33 Done(T),
34}
35
36pub(crate) type TaskWatcher<T> = Receiver<TaskState<T>>;
37
38async fn wait<T: Send + Sync + Clone>(watcher: &mut TaskWatcher<T>) -> Result<T> {
39 loop {
40 watcher
41 .changed()
42 .await
43 .context(error::WatchAsyncTaskChangeSnafu)?;
44
45 let r = &*watcher.borrow();
46 match r {
47 TaskState::Error(err) => return Err(err.clone()).context(error::AsyncTaskExecuteSnafu),
48 TaskState::Running => {}
49 TaskState::Done(value) => return Ok(value.clone()),
50 }
51 }
52}
53
54pub(crate) struct Task<T: Send + Sync + Clone> {
56 watcher: TaskWatcher<T>,
57}
58
59pub(crate) struct TaskTrackerInner<T: Send + Sync + Clone> {
60 state: HashMap<RegionId, Task<T>>,
61}
62
63impl<T: Send + Sync + Clone> Default for TaskTrackerInner<T> {
64 fn default() -> Self {
65 TaskTrackerInner {
66 state: HashMap::new(),
67 }
68 }
69}
70
71#[derive(Clone)]
73pub(crate) struct TaskTracker<T: Send + Sync + Clone> {
74 inner: Arc<RwLock<TaskTrackerInner<T>>>,
75}
76
77pub(crate) enum RegisterResult<T: Send + Sync + Clone> {
79 Busy(TaskWatcher<T>),
81 Running(TaskWatcher<T>),
83}
84
85impl<T: Send + Sync + Clone> RegisterResult<T> {
86 pub(crate) fn into_watcher(self) -> TaskWatcher<T> {
87 match self {
88 RegisterResult::Busy(inner) => inner,
89 RegisterResult::Running(inner) => inner,
90 }
91 }
92
93 pub(crate) fn is_busy(&self) -> bool {
95 matches!(self, RegisterResult::Busy(_))
96 }
97
98 #[cfg(test)]
99 pub(crate) fn is_running(&self) -> bool {
101 matches!(self, RegisterResult::Running(_))
102 }
103}
104
105pub(crate) enum WaitResult<T> {
107 Timeout,
108 Finish(Result<T>),
109}
110
#[cfg(test)]
impl<T> WaitResult<T> {
    /// Returns `true` if this is the [`WaitResult::Timeout`] variant.
    pub(crate) fn is_timeout(&self) -> bool {
        match self {
            Self::Timeout => true,
            Self::Finish(_) => false,
        }
    }

    /// Consumes the value, returning the wrapped result for [`WaitResult::Finish`]
    /// and `None` for [`WaitResult::Timeout`].
    pub(crate) fn into_finish(self) -> Option<Result<T>> {
        if let Self::Finish(result) = self {
            Some(result)
        } else {
            None
        }
    }
}
126
127impl<T: Send + Sync + Clone + 'static> TaskTracker<T> {
128 pub(crate) fn new() -> Self {
130 Self {
131 inner: Arc::new(RwLock::new(TaskTrackerInner::default())),
132 }
133 }
134
135 pub(crate) async fn wait(
137 &self,
138 watcher: &mut TaskWatcher<T>,
139 timeout: Duration,
140 ) -> WaitResult<T> {
141 match tokio::time::timeout(timeout, wait(watcher)).await {
142 Ok(result) => WaitResult::Finish(result),
143 Err(_) => WaitResult::Timeout,
144 }
145 }
146
147 pub(crate) async fn wait_until_finish(&self, watcher: &mut TaskWatcher<T>) -> Result<T> {
149 wait(watcher).await
150 }
151
152 pub(crate) async fn try_register(
154 &self,
155 region_id: RegionId,
156 fut: BoxFuture<'static, Result<T>>,
157 ) -> RegisterResult<T> {
158 let mut inner = self.inner.write().await;
159 if let Some(task) = inner.state.get(®ion_id) {
160 RegisterResult::Busy(task.watcher.clone())
161 } else {
162 let moved_inner = self.inner.clone();
163 let (tx, rx) = watch::channel(TaskState::<T>::Running);
164 common_runtime::spawn_global(async move {
165 match fut.await {
166 Ok(result) => {
167 let _ = tx.send(TaskState::Done(result));
168 }
169 Err(err) => {
170 let _ = tx.send(TaskState::Error(Arc::new(err)));
171 }
172 };
173 moved_inner.write().await.state.remove(®ion_id);
174 });
175 inner.state.insert(
176 region_id,
177 Task {
178 watcher: rx.clone(),
179 },
180 );
181
182 RegisterResult::Running(rx.clone())
183 }
184 }
185
186 #[cfg(test)]
187 async fn watcher(&self, region_id: RegionId) -> Option<TaskWatcher<T>> {
188 self.inner
189 .read()
190 .await
191 .state
192 .get(®ion_id)
193 .map(|task| task.watcher.clone())
194 }
195}
196
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use store_api::storage::RegionId;
    use tokio::sync::oneshot;

    use crate::heartbeat::task_tracker::{wait, TaskTracker};

    /// Simple task output used to tell the registered futures apart.
    #[derive(Debug, Clone, PartialEq, Eq)]
    struct TestResult {
        value: i32,
    }

    /// A region with a pending task reports busy; once that task finishes,
    /// a new task can be registered for the same region.
    #[tokio::test]
    async fn test_async_task_tracker_register() {
        let tracker = TaskTracker::<TestResult>::new();
        let region_id = RegionId::new(1024, 1);
        let (release, gate) = oneshot::channel::<()>();

        // The first future blocks on `gate`, so the region stays occupied.
        let first = tracker
            .try_register(
                region_id,
                Box::pin(async move {
                    let _ = gate.await;
                    Ok(TestResult { value: 1024 })
                }),
            )
            .await;
        assert!(first.is_running());

        // A second registration for the same region must be rejected as busy.
        let second = tracker
            .try_register(
                region_id,
                Box::pin(async move { Ok(TestResult { value: 1023 }) }),
            )
            .await;
        assert!(second.is_busy());

        let mut watcher = tracker.watcher(region_id).await.unwrap();
        // Unblock the first future so it can publish its result.
        release.send(()).unwrap();

        let finished = wait(&mut watcher).await.unwrap();
        assert_eq!(TestResult { value: 1024 }, finished);

        // With the first task done, the region accepts a new task.
        let third = tracker
            .try_register(
                region_id,
                Box::pin(async move { Ok(TestResult { value: 1022 }) }),
            )
            .await;
        assert!(third.is_running());
    }

    /// Waiting on a pending task times out; after the task is released the same
    /// watcher yields the result and the region's entry is gone.
    #[tokio::test]
    async fn test_async_task_tracker_wait_timeout() {
        let tracker = TaskTracker::<TestResult>::new();
        let region_id = RegionId::new(1024, 1);
        let (release, gate) = oneshot::channel::<()>();

        let registered = tracker
            .try_register(
                region_id,
                Box::pin(async move {
                    let _ = gate.await;
                    Ok(TestResult { value: 1024 })
                }),
            )
            .await;
        let mut watcher = registered.into_watcher();

        // The task is still blocked on `gate`, so a short wait must time out.
        let pending = tracker.wait(&mut watcher, Duration::from_millis(100)).await;
        assert!(pending.is_timeout());

        // Unblock the task; the next wait on the same watcher must see its output.
        release.send(()).unwrap();
        let output = tracker
            .wait(&mut watcher, Duration::from_millis(100))
            .await
            .into_finish()
            .unwrap()
            .unwrap();
        assert_eq!(TestResult { value: 1024 }, output);
        assert!(tracker.watcher(region_id).await.is_none());
    }
}