datanode/heartbeat/
task_tracker.rs1use std::collections::HashMap;
16use std::sync::Arc;
17use std::time::Duration;
18
19use futures_util::future::BoxFuture;
20use snafu::ResultExt;
21use store_api::storage::RegionId;
22use tokio::sync::RwLock;
23use tokio::sync::watch::{self, Receiver};
24
25use crate::error::{self, Error, Result};
26
27#[derive(Debug, Default, Clone)]
29pub(crate) enum TaskState<T: Send + Sync + Clone> {
30    Error(Arc<Error>),
31    #[default]
32    Running,
33    Done(T),
34}
35
36pub(crate) type TaskWatcher<T> = Receiver<TaskState<T>>;
37
38async fn wait<T: Send + Sync + Clone>(watcher: &mut TaskWatcher<T>) -> Result<T> {
39    loop {
40        watcher
41            .changed()
42            .await
43            .context(error::WatchAsyncTaskChangeSnafu)?;
44
45        let r = &*watcher.borrow();
46        match r {
47            TaskState::Error(err) => return Err(err.clone()).context(error::AsyncTaskExecuteSnafu),
48            TaskState::Running => {}
49            TaskState::Done(value) => return Ok(value.clone()),
50        }
51    }
52}
53
54pub(crate) struct Task<T: Send + Sync + Clone> {
56    watcher: TaskWatcher<T>,
57}
58
59pub(crate) struct TaskTrackerInner<T: Send + Sync + Clone> {
60    state: HashMap<RegionId, Task<T>>,
61}
62
63impl<T: Send + Sync + Clone> Default for TaskTrackerInner<T> {
64    fn default() -> Self {
65        TaskTrackerInner {
66            state: HashMap::new(),
67        }
68    }
69}
70
71#[derive(Clone)]
73pub(crate) struct TaskTracker<T: Send + Sync + Clone> {
74    inner: Arc<RwLock<TaskTrackerInner<T>>>,
75}
76
77pub(crate) enum RegisterResult<T: Send + Sync + Clone> {
79    Busy(TaskWatcher<T>),
81    Running(TaskWatcher<T>),
83}
84
85impl<T: Send + Sync + Clone> RegisterResult<T> {
86    pub(crate) fn into_watcher(self) -> TaskWatcher<T> {
87        match self {
88            RegisterResult::Busy(inner) => inner,
89            RegisterResult::Running(inner) => inner,
90        }
91    }
92
93    pub(crate) fn is_busy(&self) -> bool {
95        matches!(self, RegisterResult::Busy(_))
96    }
97
98    #[cfg(test)]
99    pub(crate) fn is_running(&self) -> bool {
101        matches!(self, RegisterResult::Running(_))
102    }
103}
104
105pub(crate) enum WaitResult<T> {
107    Timeout,
108    Finish(Result<T>),
109}
110
111#[cfg(test)]
112impl<T> WaitResult<T> {
113    pub(crate) fn is_timeout(&self) -> bool {
115        matches!(self, WaitResult::Timeout)
116    }
117
118    pub(crate) fn into_finish(self) -> Option<Result<T>> {
120        match self {
121            WaitResult::Timeout => None,
122            WaitResult::Finish(result) => Some(result),
123        }
124    }
125}
126
127impl<T: Send + Sync + Clone + 'static> TaskTracker<T> {
128    pub(crate) fn new() -> Self {
130        Self {
131            inner: Arc::new(RwLock::new(TaskTrackerInner::default())),
132        }
133    }
134
135    pub(crate) async fn wait(
137        &self,
138        watcher: &mut TaskWatcher<T>,
139        timeout: Duration,
140    ) -> WaitResult<T> {
141        match tokio::time::timeout(timeout, wait(watcher)).await {
142            Ok(result) => WaitResult::Finish(result),
143            Err(_) => WaitResult::Timeout,
144        }
145    }
146
147    pub(crate) async fn wait_until_finish(&self, watcher: &mut TaskWatcher<T>) -> Result<T> {
149        wait(watcher).await
150    }
151
152    pub(crate) async fn try_register(
154        &self,
155        region_id: RegionId,
156        fut: BoxFuture<'static, Result<T>>,
157    ) -> RegisterResult<T> {
158        let mut inner = self.inner.write().await;
159        if let Some(task) = inner.state.get(®ion_id) {
160            RegisterResult::Busy(task.watcher.clone())
161        } else {
162            let moved_inner = self.inner.clone();
163            let (tx, rx) = watch::channel(TaskState::<T>::Running);
164            common_runtime::spawn_global(async move {
165                match fut.await {
166                    Ok(result) => {
167                        let _ = tx.send(TaskState::Done(result));
168                    }
169                    Err(err) => {
170                        let _ = tx.send(TaskState::Error(Arc::new(err)));
171                    }
172                };
173                moved_inner.write().await.state.remove(®ion_id);
174            });
175            inner.state.insert(
176                region_id,
177                Task {
178                    watcher: rx.clone(),
179                },
180            );
181
182            RegisterResult::Running(rx.clone())
183        }
184    }
185
186    #[cfg(test)]
187    async fn watcher(&self, region_id: RegionId) -> Option<TaskWatcher<T>> {
188        self.inner
189            .read()
190            .await
191            .state
192            .get(®ion_id)
193            .map(|task| task.watcher.clone())
194    }
195}
196
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use store_api::storage::RegionId;
    use tokio::sync::oneshot;

    use crate::heartbeat::task_tracker::{TaskTracker, wait};

    /// Minimal cloneable payload used as the task output in these tests.
    #[derive(Debug, Clone, PartialEq, Eq)]
    struct TestResult {
        value: i32,
    }

    #[tokio::test]
    async fn test_async_task_tracker_register() {
        let tracker = TaskTracker::<TestResult>::new();
        let region_id = RegionId::new(1024, 1);
        // The first task blocks on this channel, so it stays registered until released.
        let (release_tx, release_rx) = oneshot::channel::<()>();

        let first = tracker
            .try_register(
                region_id,
                Box::pin(async move {
                    let _ = release_rx.await;
                    Ok(TestResult { value: 1024 })
                }),
            )
            .await;
        assert!(first.is_running());

        // A second registration for the same region is rejected while the first runs.
        let second = tracker
            .try_register(
                region_id,
                Box::pin(async move { Ok(TestResult { value: 1023 }) }),
            )
            .await;
        assert!(second.is_busy());

        let mut watcher = tracker.watcher(region_id).await.unwrap();
        release_tx.send(()).unwrap();

        let expected = TestResult { value: 1024 };
        let finished = wait(&mut watcher).await.unwrap();
        assert_eq!(expected, finished);

        // After the first task finished, the region accepts a new task.
        let third = tracker
            .try_register(
                region_id,
                Box::pin(async move { Ok(TestResult { value: 1022 }) }),
            )
            .await;
        assert!(third.is_running());
    }

    #[tokio::test]
    async fn test_async_task_tracker_wait_timeout() {
        let tracker = TaskTracker::<TestResult>::new();
        let region_id = RegionId::new(1024, 1);
        let (release_tx, release_rx) = oneshot::channel::<()>();

        let mut watcher = tracker
            .try_register(
                region_id,
                Box::pin(async move {
                    let _ = release_rx.await;
                    Ok(TestResult { value: 1024 })
                }),
            )
            .await
            .into_watcher();

        // The task is still blocked, so a short wait must time out.
        let timed_out = tracker.wait(&mut watcher, Duration::from_millis(100)).await;
        assert!(timed_out.is_timeout());

        // Once released, the same watcher observes the task's output.
        release_tx.send(()).unwrap();
        let output = tracker
            .wait(&mut watcher, Duration::from_millis(100))
            .await
            .into_finish()
            .unwrap()
            .unwrap();
        assert_eq!(TestResult { value: 1024 }, output);

        // The finished task is no longer registered.
        assert!(tracker.watcher(region_id).await.is_none());
    }
}