Skip to main content

cli/data/
retry.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#![allow(dead_code)]
16
17use std::time::Duration;
18
19use backon::ExponentialBuilder;
20
21pub(crate) fn default_retry_policy() -> ExponentialBuilder {
22    ExponentialBuilder::default()
23        .with_min_delay(Duration::from_secs(1))
24        .with_max_delay(Duration::from_secs(300))
25        .with_factor(2.0)
26        // This is the number of retries after the initial attempt.
27        .with_max_times(3)
28        .with_jitter()
29}
30
#[cfg(test)]
mod tests {
    use std::future::ready;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use backon::Retryable;

    use super::*;

    /// A retryable error is retried until the operation eventually succeeds.
    #[tokio::test]
    async fn test_retry_policy_retries_retryable_error_until_success() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let shared = Arc::clone(&call_count);

        // Fails on the first two calls, then succeeds on the third.
        let operation = move || {
            let calls = Arc::clone(&shared);
            async move {
                match calls.fetch_add(1, Ordering::SeqCst) {
                    0 | 1 => Err("retryable"),
                    _ => Ok("done"),
                }
            }
        };

        let outcome = operation
            .retry(default_retry_policy())
            .when(|err| *err == "retryable")
            // Replace the real backoff sleep with an immediately-ready future.
            .sleep(|_| ready(()))
            .await;

        assert_eq!(outcome, Ok("done"));
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    /// An error rejected by the `when` predicate is returned after one call.
    #[tokio::test]
    async fn test_retry_policy_stops_on_non_retryable_error() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let shared = Arc::clone(&call_count);

        // Always fails with an error the predicate below does not accept.
        let operation = move || {
            let calls = Arc::clone(&shared);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Err::<(), &str>("fatal")
            }
        };

        let outcome = operation
            .retry(default_retry_policy())
            .when(|err| *err == "retryable")
            .sleep(|_| ready(()))
            .await;

        assert_eq!(outcome, Err("fatal"));
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    /// Once the retry limit is reached, the error from the final call is returned.
    #[tokio::test]
    async fn test_retry_policy_returns_last_error_after_reaching_limit() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let shared = Arc::clone(&call_count);

        // Every call fails, reporting the zero-based index of that call.
        let operation = move || {
            let calls = Arc::clone(&shared);
            async move {
                let index = calls.fetch_add(1, Ordering::SeqCst);
                Err::<(), usize>(index)
            }
        };

        // Two retries plus the initial attempt gives three calls in total.
        let policy = default_retry_policy().with_max_times(2);

        let outcome = operation
            .retry(policy)
            .when(|_| true)
            .sleep(|_| ready(()))
            .await;

        assert_eq!(outcome, Err(2));
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }
}