servers/
request_limiter.rs1use std::sync::Arc;
18use std::sync::atomic::{AtomicUsize, Ordering};
19
20use crate::error::{Result, TooManyConcurrentRequestsSnafu};
21
22#[derive(Clone, Default)]
27pub struct RequestMemoryLimiter {
28 inner: Option<Arc<LimiterInner>>,
29}
30
/// State shared by every clone of an enabled [`RequestMemoryLimiter`] and by the
/// guards it hands out.
struct LimiterInner {
    /// Upper bound that `current_usage` may not exceed; set once at construction.
    max_memory: usize,
    /// Sum of the sizes of all reservations that are currently held.
    current_usage: AtomicUsize,
}
35
36impl RequestMemoryLimiter {
37 pub fn new(max_memory: usize) -> Self {
42 if max_memory == 0 {
43 return Self { inner: None };
44 }
45
46 Self {
47 inner: Some(Arc::new(LimiterInner {
48 current_usage: AtomicUsize::new(0),
49 max_memory,
50 })),
51 }
52 }
53
54 pub fn try_acquire(&self, request_size: usize) -> Result<Option<RequestMemoryGuard>> {
59 let Some(inner) = self.inner.as_ref() else {
60 return Ok(None);
61 };
62
63 let mut new_usage = 0;
64 let result =
65 inner
66 .current_usage
67 .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
68 new_usage = current.saturating_add(request_size);
69 if new_usage <= inner.max_memory {
70 Some(new_usage)
71 } else {
72 None
73 }
74 });
75
76 match result {
77 Ok(_) => Ok(Some(RequestMemoryGuard {
78 size: request_size,
79 limiter: Arc::clone(inner),
80 usage_snapshot: new_usage,
81 })),
82 Err(_current) => TooManyConcurrentRequestsSnafu {
83 limit: inner.max_memory,
84 request_size,
85 }
86 .fail(),
87 }
88 }
89
90 pub fn is_enabled(&self) -> bool {
92 self.inner.is_some()
93 }
94
95 pub fn current_usage(&self) -> usize {
97 self.inner
98 .as_ref()
99 .map(|inner| inner.current_usage.load(Ordering::Relaxed))
100 .unwrap_or(0)
101 }
102
103 pub fn max_memory(&self) -> usize {
105 self.inner
106 .as_ref()
107 .map(|inner| inner.max_memory)
108 .unwrap_or(0)
109 }
110}
111
112pub struct RequestMemoryGuard {
114 size: usize,
115 limiter: Arc<LimiterInner>,
116 usage_snapshot: usize,
117}
118
119impl RequestMemoryGuard {
120 pub fn current_usage(&self) -> usize {
122 self.usage_snapshot
123 }
124}
125
126impl Drop for RequestMemoryGuard {
127 fn drop(&mut self) {
128 self.limiter
129 .current_usage
130 .fetch_sub(self.size, Ordering::Release);
131 }
132}
133
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_limiter_disabled() {
        let limiter = RequestMemoryLimiter::new(0);
        assert!(!limiter.is_enabled());
        assert_eq!(limiter.max_memory(), 0);
        assert!(limiter.try_acquire(1000000).unwrap().is_none());
        assert_eq!(limiter.current_usage(), 0);
    }

    #[test]
    fn test_limiter_basic() {
        let limiter = RequestMemoryLimiter::new(1000);
        assert!(limiter.is_enabled());
        assert_eq!(limiter.max_memory(), 1000);
        assert_eq!(limiter.current_usage(), 0);

        let guard1 = limiter.try_acquire(400).unwrap().unwrap();
        assert_eq!(guard1.current_usage(), 400);
        assert_eq!(limiter.current_usage(), 400);

        let guard2 = limiter.try_acquire(500).unwrap().unwrap();
        assert_eq!(guard2.current_usage(), 900);
        assert_eq!(limiter.current_usage(), 900);

        // 900 + 200 exceeds the limit; a rejected request must not change usage.
        let result = limiter.try_acquire(200);
        assert!(result.is_err());
        assert_eq!(limiter.current_usage(), 900);

        drop(guard1);
        assert_eq!(limiter.current_usage(), 500);

        let guard3 = limiter.try_acquire(200).unwrap().unwrap();
        assert_eq!(limiter.current_usage(), 700);

        drop(guard2);
        drop(guard3);
        assert_eq!(limiter.current_usage(), 0);
    }

    #[test]
    fn test_limiter_exact_limit() {
        let limiter = RequestMemoryLimiter::new(1000);

        let guard = limiter.try_acquire(1000).unwrap().unwrap();
        assert_eq!(guard.current_usage(), 1000);
        assert_eq!(limiter.current_usage(), 1000);

        let result = limiter.try_acquire(1);
        assert!(result.is_err());
        assert_eq!(limiter.current_usage(), 1000);

        drop(guard);
        assert_eq!(limiter.current_usage(), 0);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_limiter_concurrent() {
        let limiter = RequestMemoryLimiter::new(1000);
        let handles: Vec<_> = (0..10)
            .map(|_| {
                let limiter = limiter.clone();
                tokio::spawn(async move { limiter.try_acquire(200) })
            })
            .collect();

        // Keep every successful guard alive until all tasks are joined. Dropping a
        // guard inside the loop would free budget for tasks that may not have run
        // yet, making the number of successes nondeterministic.
        let mut guards = Vec::new();
        let mut fail_count = 0;

        for handle in handles {
            match handle.await.unwrap() {
                Ok(Some(guard)) => guards.push(guard),
                Err(_) => fail_count += 1,
                Ok(None) => unreachable!("limiter is enabled"),
            }
        }

        assert_eq!(guards.len(), 5);
        assert_eq!(fail_count, 5);
        assert_eq!(limiter.current_usage(), 1000);

        drop(guards);
        assert_eq!(limiter.current_usage(), 0);
    }
}