servers/grpc/cancellation.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::future::Future;
16
17use tokio::select;
18use tokio_util::sync::CancellationToken;
19
20type Result<T> = std::result::Result<tonic::Response<T>, tonic::Status>;
21
22pub(crate) async fn with_cancellation_handler<Request, Cancellation, Response>(
23    request: Request,
24    cancellation: Cancellation,
25) -> Result<Response>
26where
27    Request: Future<Output = Result<Response>> + Send + 'static,
28    Cancellation: Future<Output = Result<Response>> + Send + 'static,
29    Response: Send + 'static,
30{
31    let token = CancellationToken::new();
32    // Will call token.cancel() when the future is dropped, such as when the client cancels the request
33    let _drop_guard = token.clone().drop_guard();
34    let select_task = tokio::spawn(async move {
35        // Can select on token cancellation on any cancellable future while handling the request,
36        // allowing for custom cleanup code or monitoring
37        select! {
38            res = request => res,
39            _ = token.cancelled() => cancellation.await,
40        }
41    });
42
43    select_task.await.unwrap()
44}
45
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::sync::mpsc;
    use tokio::time;
    use tonic::Response;

    use super::*;

    /// A request that resolves immediately wins the race. Its response is returned, and the
    /// slow cancellation future is never awaited.
    #[tokio::test]
    async fn test_request_completes_first() {
        let fast_request = async { Ok(Response::new("Request Completed")) };
        let slow_cancellation = async {
            time::sleep(Duration::from_secs(1)).await;
            Ok(Response::new("Cancelled"))
        };

        let response = with_cancellation_handler(fast_request, slow_cancellation)
            .await
            .unwrap();
        assert_eq!(response.into_inner(), "Request Completed");
    }

    /// Dropping the handler future before the request finishes must run the cancellation
    /// future. The test observes this through an mpsc channel.
    #[tokio::test]
    async fn test_cancellation_when_dropped() {
        let (sender, mut receiver) = mpsc::channel(2);
        let request_sender = sender.clone();
        let slow_request = async move {
            time::sleep(Duration::from_secs(1)).await;
            request_sender.send("Request Completed").await.unwrap();
            Ok(Response::new("Completed"))
        };
        let cancellation = async move {
            sender.send("Request Cancelled").await.unwrap();
            Ok(Response::new("Cancelled"))
        };

        // The 50ms timeout expires long before the 1s request. `timeout` then drops the
        // handler future, whose drop guard cancels the token, so the spawned task runs
        // `cancellation`.
        let outcome = time::timeout(
            Duration::from_millis(50),
            with_cancellation_handler(slow_request, cancellation),
        )
        .await;

        assert!(outcome.is_err(), "Expected timeout error");
        let first_message = receiver.recv().await.unwrap();
        assert_eq!("Request Cancelled", first_message)
    }
}