common_base/
cancellation.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
//! [CancellationHandle] can be composed with manual implementations of [futures::future::Future]
//! or [futures::stream::Stream] to make them cancellable.
//! See `frontend::stream_wrapper::CancellableStreamWrapper` and [CancellableFuture] for examples.
18
19use std::fmt::{Debug, Display, Formatter};
20use std::future::Future;
21use std::pin::Pin;
22use std::sync::atomic::{AtomicBool, Ordering};
23use std::sync::Arc;
24use std::task::{Context, Poll};
25
26use futures::task::AtomicWaker;
27use pin_project::pin_project;
28
29#[derive(Default)]
30pub struct CancellationHandle {
31    waker: AtomicWaker,
32    cancelled: AtomicBool,
33}
34
35impl Debug for CancellationHandle {
36    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
37        f.debug_struct("CancellationHandle")
38            .field("cancelled", &self.is_cancelled())
39            .finish()
40    }
41}
42
43impl CancellationHandle {
44    pub fn waker(&self) -> &AtomicWaker {
45        &self.waker
46    }
47
48    /// Cancels a future or stream.
49    pub fn cancel(&self) {
50        if self
51            .cancelled
52            .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
53            .is_ok()
54        {
55            self.waker.wake();
56        }
57    }
58
59    /// Is this handle cancelled.
60    pub fn is_cancelled(&self) -> bool {
61        self.cancelled.load(Ordering::Relaxed)
62    }
63}
64
/// Future adapter that yields the inner future's output unless its handle is cancelled.
///
/// Resolves to `Ok(output)` when the inner future completes, or to `Err(Cancelled)` when
/// the handle is observed as cancelled during a poll. Once cancellation has been observed,
/// the inner future is not polled again.
#[pin_project]
#[derive(Debug, Clone)]
pub struct CancellableFuture<T> {
    /// The wrapped future; structurally pinned so it can be polled in place.
    #[pin]
    fut: T,
    /// Cancellation state shared with whoever may call `cancel`.
    handle: Arc<CancellationHandle>,
}
72
73impl<T> CancellableFuture<T> {
74    pub fn new(fut: T, handle: Arc<CancellationHandle>) -> Self {
75        Self { fut, handle }
76    }
77}
78
79impl<T> Future for CancellableFuture<T>
80where
81    T: Future,
82{
83    type Output = Result<T::Output, Cancelled>;
84
85    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
86        let this = self.as_mut().project();
87        // Check if the task has been aborted
88        if this.handle.is_cancelled() {
89            return Poll::Ready(Err(Cancelled));
90        }
91
92        if let Poll::Ready(x) = this.fut.poll(cx) {
93            return Poll::Ready(Ok(x));
94        }
95
96        this.handle.waker().register(cx.waker());
97        if this.handle.is_cancelled() {
98            return Poll::Ready(Err(Cancelled));
99        }
100        Poll::Pending
101    }
102}
103
/// Error returned by [`CancellableFuture`] when its [`CancellationHandle`] is cancelled
/// before the inner future completes.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct Cancelled;

impl Display for Cancelled {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "Future has been cancelled")
    }
}

// Implementing `Error` lets callers propagate `Cancelled` with `?` into `Box<dyn Error>`
// or attach it as the source of their own error types.
impl std::error::Error for Cancelled {}
112
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Duration;

    use tokio::time::{sleep, timeout};

    use crate::cancellation::{CancellableFuture, CancellationHandle, Cancelled};

    #[tokio::test]
    async fn test_cancellable_future_completes_normally() {
        let handle = Arc::new(CancellationHandle::default());
        let cancellable = CancellableFuture::new(async { 42 }, handle);

        assert_eq!(cancellable.await.unwrap(), 42);
    }

    #[tokio::test]
    async fn test_cancellable_future_cancelled_before_start() {
        let handle = Arc::new(CancellationHandle::default());
        handle.cancel();

        let cancellable = CancellableFuture::new(async { 42 }, handle);

        assert!(matches!(cancellable.await, Err(Cancelled)));
    }

    #[tokio::test]
    async fn test_cancellable_future_cancelled_during_execution() {
        let handle = Arc::new(CancellationHandle::default());
        let handle_clone = handle.clone();

        // The inner future runs far longer than the cancellation delay.
        let future = async {
            sleep(Duration::from_secs(10)).await;
            42
        };
        let cancellable = CancellableFuture::new(future, handle);

        tokio::spawn(async move {
            sleep(Duration::from_millis(50)).await;
            handle_clone.cancel();
        });

        assert!(matches!(cancellable.await, Err(Cancelled)));
    }

    #[tokio::test]
    async fn test_cancel_wakes_never_completing_future() {
        let handle = Arc::new(CancellationHandle::default());
        let cancellable = CancellableFuture::new(std::future::pending::<()>(), handle.clone());

        let canceller = tokio::spawn(async move {
            sleep(Duration::from_millis(10)).await;
            handle.cancel();
        });

        // Bound the wait so a lost wake-up fails the test instead of hanging it.
        let result = timeout(Duration::from_secs(5), cancellable)
            .await
            .expect("cancel() should wake the pending future");
        assert!(matches!(result, Err(Cancelled)));
        canceller.await.unwrap();
    }

    #[tokio::test]
    async fn test_cancellable_future_completes_before_cancellation() {
        let handle = Arc::new(CancellationHandle::default());
        let handle_clone = handle.clone();

        // The inner future finishes well before the cancel fires.
        let future = async {
            sleep(Duration::from_millis(10)).await;
            42
        };
        let cancellable = CancellableFuture::new(future, handle);

        tokio::spawn(async move {
            sleep(Duration::from_millis(100)).await;
            handle_clone.cancel();
        });

        assert_eq!(cancellable.await.unwrap(), 42);
    }

    #[test]
    fn test_cancellation_handle_is_cancelled() {
        let handle = CancellationHandle::default();
        assert!(!handle.is_cancelled());

        handle.cancel();
        assert!(handle.is_cancelled());

        // A second cancel is a no-op and must not reset the flag.
        handle.cancel();
        assert!(handle.is_cancelled());
    }

    #[test]
    fn test_cancellation_handle_debug() {
        let handle = CancellationHandle::default();
        assert_eq!(format!("{handle:?}"), "CancellationHandle { cancelled: false }");

        handle.cancel();
        assert_eq!(format!("{handle:?}"), "CancellationHandle { cancelled: true }");
    }

    #[tokio::test]
    async fn test_multiple_cancellable_futures_with_same_handle() {
        let handle = Arc::new(CancellationHandle::default());

        let future1 = CancellableFuture::new(async { 1 }, handle.clone());
        let future2 = CancellableFuture::new(async { 2 }, handle.clone());

        // Cancel before either future is polled.
        handle.cancel();

        let (result1, result2) = tokio::join!(future1, future2);

        assert!(matches!(result1, Err(Cancelled)));
        assert!(matches!(result2, Err(Cancelled)));
    }

    #[tokio::test]
    async fn test_cancellable_future_with_timeout() {
        let handle = Arc::new(CancellationHandle::default());
        let future = async {
            sleep(Duration::from_secs(1)).await;
            42
        };
        let cancellable = CancellableFuture::new(future, handle.clone());

        // The inner future needs 1s, so the 100ms timeout fires first.
        let result = timeout(Duration::from_millis(100), cancellable).await;
        assert!(result.is_err());

        // Dropping the future on timeout must not flip the shared handle.
        assert!(!handle.is_cancelled());
    }

    #[test]
    fn test_cancelled_display() {
        assert_eq!(Cancelled.to_string(), "Future has been cancelled");
    }
}