servers/grpc/
cancellation.rs1use std::future::Future;
16
17use tokio::select;
18use tokio_util::sync::CancellationToken;
19
20type Result<T> = std::result::Result<tonic::Response<T>, tonic::Status>;
21
22pub(crate) async fn with_cancellation_handler<Request, Cancellation, Response>(
23 request: Request,
24 cancellation: Cancellation,
25) -> Result<Response>
26where
27 Request: Future<Output = Result<Response>> + Send + 'static,
28 Cancellation: Future<Output = Result<Response>> + Send + 'static,
29 Response: Send + 'static,
30{
31 let token = CancellationToken::new();
32 let _drop_guard = token.clone().drop_guard();
34 let select_task = tokio::spawn(async move {
35 select! {
38 res = request => res,
39 _ = token.cancelled() => cancellation.await,
40 }
41 });
42
43 select_task.await.unwrap()
44}
45
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::sync::mpsc;
    use tokio::time;
    use tonic::Response;

    use super::*;

    /// When the request resolves right away, its response is what the caller
    /// receives.
    #[tokio::test]
    async fn test_request_completes_first() {
        let fast_request = async { Ok(Response::new("Request Completed")) };
        let slow_cancellation = async {
            time::sleep(Duration::from_secs(1)).await;
            Ok(Response::new("Cancelled"))
        };

        let response = with_cancellation_handler(fast_request, slow_cancellation)
            .await
            .unwrap();
        assert_eq!(response.into_inner(), "Request Completed");
    }

    /// Dropping the handler future (here via a timeout) must trigger the
    /// cancellation future, observed through the first message on the channel.
    #[tokio::test]
    async fn test_cancellation_when_dropped() {
        let (cancel_tx, mut events) = mpsc::channel(2);
        let request_tx = cancel_tx.clone();

        // Takes far longer than the timeout below, so it never gets to report.
        let slow_request = async move {
            time::sleep(Duration::from_secs(1)).await;
            request_tx.send("Request Completed").await.unwrap();
            Ok(Response::new("Completed"))
        };
        let on_cancel = async move {
            cancel_tx.send("Request Cancelled").await.unwrap();
            Ok(Response::new("Cancelled"))
        };

        // The timeout elapses first and drops the wrapped future.
        let timed_out = time::timeout(
            Duration::from_millis(50),
            with_cancellation_handler(slow_request, on_cancel),
        )
        .await
        .is_err();

        assert!(timed_out, "Expected timeout error");
        assert_eq!("Request Cancelled", events.recv().await.unwrap());
    }
}