common_base/
cancellation.rs1use std::fmt::{Debug, Display, Formatter};
20use std::future::Future;
21use std::pin::Pin;
22use std::sync::atomic::{AtomicBool, Ordering};
23use std::sync::Arc;
24use std::task::{Context, Poll};
25
26use futures::task::AtomicWaker;
27use pin_project::pin_project;
28
29#[derive(Default)]
30pub struct CancellationHandle {
31 waker: AtomicWaker,
32 cancelled: AtomicBool,
33}
34
35impl Debug for CancellationHandle {
36 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
37 f.debug_struct("CancellationHandle")
38 .field("cancelled", &self.is_cancelled())
39 .finish()
40 }
41}
42
43impl CancellationHandle {
44 pub fn waker(&self) -> &AtomicWaker {
45 &self.waker
46 }
47
48 pub fn cancel(&self) {
50 if self
51 .cancelled
52 .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
53 .is_ok()
54 {
55 self.waker.wake();
56 }
57 }
58
59 pub fn is_cancelled(&self) -> bool {
61 self.cancelled.load(Ordering::Relaxed)
62 }
63}
64
65#[pin_project]
66#[derive(Debug, Clone)]
67pub struct CancellableFuture<T> {
68 #[pin]
69 fut: T,
70 handle: Arc<CancellationHandle>,
71}
72
73impl<T> CancellableFuture<T> {
74 pub fn new(fut: T, handle: Arc<CancellationHandle>) -> Self {
75 Self { fut, handle }
76 }
77}
78
79impl<T> Future for CancellableFuture<T>
80where
81 T: Future,
82{
83 type Output = Result<T::Output, Cancelled>;
84
85 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
86 let this = self.as_mut().project();
87 if this.handle.is_cancelled() {
89 return Poll::Ready(Err(Cancelled));
90 }
91
92 if let Poll::Ready(x) = this.fut.poll(cx) {
93 return Poll::Ready(Ok(x));
94 }
95
96 this.handle.waker().register(cx.waker());
97 if this.handle.is_cancelled() {
98 return Poll::Ready(Err(Cancelled));
99 }
100 Poll::Pending
101 }
102}
103
/// Error produced by `CancellableFuture` when its cancellation handle was
/// cancelled before the wrapped future completed.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct Cancelled;

impl Display for Cancelled {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str("Future has been cancelled")
    }
}

/// Lets `Cancelled` be boxed as `dyn Error` and flow through `?` into
/// error-wrapping types.
impl std::error::Error for Cancelled {}
112
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Duration;

    use tokio::time::{sleep, timeout};

    use super::{CancellableFuture, CancellationHandle, Cancelled};

    /// Upper bound for tests that expect `cancel()` to wake a pending future.
    /// It turns a lost wakeup into a prompt failure instead of a long wait.
    const WAKE_DEADLINE: Duration = Duration::from_secs(5);

    #[tokio::test]
    async fn test_cancellable_future_completes_normally() {
        let handle = Arc::new(CancellationHandle::default());
        let cancellable = CancellableFuture::new(async { 42 }, handle);

        assert!(matches!(cancellable.await, Ok(42)));
    }

    #[tokio::test]
    async fn test_cancellable_future_cancelled_before_start() {
        let handle = Arc::new(CancellationHandle::default());
        handle.cancel();

        let cancellable = CancellableFuture::new(async { 42 }, handle);

        assert!(matches!(cancellable.await, Err(Cancelled)));
    }

    #[tokio::test]
    async fn test_cancellable_future_cancelled_during_execution() {
        let handle = Arc::new(CancellationHandle::default());
        let canceller = Arc::clone(&handle);

        let future = async {
            sleep(Duration::from_secs(10)).await;
            42
        };
        let cancellable = CancellableFuture::new(future, handle);

        tokio::spawn(async move {
            sleep(Duration::from_millis(50)).await;
            canceller.cancel();
        });

        // Without a wakeup from cancel(), the inner 10s sleep would outlast the
        // deadline, so this fails fast instead of hanging.
        let result = timeout(WAKE_DEADLINE, cancellable)
            .await
            .expect("cancel() should wake the pending future before the deadline");
        assert!(matches!(result, Err(Cancelled)));
    }

    #[tokio::test]
    async fn test_cancellable_future_completes_before_cancellation() {
        let handle = Arc::new(CancellationHandle::default());
        let canceller = Arc::clone(&handle);

        let future = async {
            sleep(Duration::from_millis(10)).await;
            42
        };
        let cancellable = CancellableFuture::new(future, handle);

        tokio::spawn(async move {
            sleep(Duration::from_millis(100)).await;
            canceller.cancel();
        });

        assert!(matches!(cancellable.await, Ok(42)));
    }

    #[test]
    fn test_cancellation_handle_is_cancelled() {
        let handle = CancellationHandle::default();
        assert!(!handle.is_cancelled());

        handle.cancel();
        assert!(handle.is_cancelled());

        // A second cancel is a no-op and must leave the flag set.
        handle.cancel();
        assert!(handle.is_cancelled());
    }

    #[test]
    fn test_cancellation_handle_debug() {
        let handle = CancellationHandle::default();
        assert_eq!(
            format!("{:?}", handle),
            "CancellationHandle { cancelled: false }"
        );

        handle.cancel();
        assert_eq!(
            format!("{:?}", handle),
            "CancellationHandle { cancelled: true }"
        );
    }

    #[tokio::test]
    async fn test_multiple_cancellable_futures_with_same_handle() {
        let handle = Arc::new(CancellationHandle::default());

        let future1 = CancellableFuture::new(async { 1 }, Arc::clone(&handle));
        let future2 = CancellableFuture::new(async { 2 }, Arc::clone(&handle));

        handle.cancel();

        let (result1, result2) = tokio::join!(future1, future2);

        assert!(matches!(result1, Err(Cancelled)));
        assert!(matches!(result2, Err(Cancelled)));
    }

    #[tokio::test]
    async fn test_cancellable_future_with_timeout() {
        let handle = Arc::new(CancellationHandle::default());
        let future = async {
            sleep(Duration::from_secs(1)).await;
            42
        };
        let cancellable = CancellableFuture::new(future, Arc::clone(&handle));

        let result = timeout(Duration::from_millis(100), cancellable).await;

        // The outer timeout fires first. Dropping the wrapper must not cancel
        // the shared handle.
        assert!(result.is_err());
        assert!(!handle.is_cancelled());
    }

    #[test]
    fn test_cancelled_display() {
        assert_eq!(Cancelled.to_string(), "Future has been cancelled");
    }
}