servers/
request_limiter.rs1use std::sync::Arc;
18use std::sync::atomic::{AtomicUsize, Ordering};
19
20use crate::error::{Result, TooManyConcurrentRequestsSnafu};
21
22#[derive(Clone, Default)]
27pub struct RequestMemoryLimiter {
28 inner: Option<Arc<LimiterInner>>,
29}
30
/// State shared by every clone of an enabled `RequestMemoryLimiter`.
struct LimiterInner {
    /// Upper bound that `current_usage` may reach; never mutated after construction.
    max_memory: usize,
    /// Sum of the sizes currently reserved by live guards.
    current_usage: AtomicUsize,
}
35
36impl RequestMemoryLimiter {
37 pub fn new(max_memory: usize) -> Self {
42 if max_memory == 0 {
43 return Self { inner: None };
44 }
45
46 Self {
47 inner: Some(Arc::new(LimiterInner {
48 current_usage: AtomicUsize::new(0),
49 max_memory,
50 })),
51 }
52 }
53
54 pub fn try_acquire(&self, request_size: usize) -> Result<Option<RequestMemoryGuard>> {
59 let Some(inner) = self.inner.as_ref() else {
60 return Ok(None);
61 };
62
63 let mut new_usage = 0;
64 let result =
65 inner
66 .current_usage
67 .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
68 new_usage = current.saturating_add(request_size);
69 if new_usage <= inner.max_memory {
70 Some(new_usage)
71 } else {
72 None
73 }
74 });
75
76 match result {
77 Ok(_) => Ok(Some(RequestMemoryGuard {
78 size: request_size,
79 limiter: Arc::clone(inner),
80 usage_snapshot: new_usage,
81 })),
82 Err(_current) => TooManyConcurrentRequestsSnafu {
83 limit: inner.max_memory,
84 request_size,
85 }
86 .fail(),
87 }
88 }
89
90 pub fn is_enabled(&self) -> bool {
92 self.inner.is_some()
93 }
94
95 pub fn current_usage(&self) -> usize {
97 self.inner
98 .as_ref()
99 .map(|inner| inner.current_usage.load(Ordering::Relaxed))
100 .unwrap_or(0)
101 }
102
103 pub fn max_memory(&self) -> usize {
105 self.inner
106 .as_ref()
107 .map(|inner| inner.max_memory)
108 .unwrap_or(0)
109 }
110}
111
112pub struct RequestMemoryGuard {
114 size: usize,
115 limiter: Arc<LimiterInner>,
116 usage_snapshot: usize,
117}
118
119impl RequestMemoryGuard {
120 pub fn current_usage(&self) -> usize {
122 self.usage_snapshot
123 }
124}
125
126impl Drop for RequestMemoryGuard {
127 fn drop(&mut self) {
128 self.limiter
129 .current_usage
130 .fetch_sub(self.size, Ordering::Release);
131 }
132}
133
#[cfg(test)]
mod tests {
    use tokio::sync::Barrier;

    use super::*;

    #[test]
    fn test_limiter_disabled() {
        let limiter = RequestMemoryLimiter::new(0);
        assert!(!limiter.is_enabled());
        assert_eq!(limiter.max_memory(), 0);
        assert!(limiter.try_acquire(1_000_000).unwrap().is_none());
        assert_eq!(limiter.current_usage(), 0);
    }

    #[test]
    fn test_limiter_default_is_disabled() {
        let limiter = RequestMemoryLimiter::default();
        assert!(!limiter.is_enabled());
        assert!(limiter.try_acquire(usize::MAX).unwrap().is_none());
        assert_eq!(limiter.current_usage(), 0);
    }

    #[test]
    fn test_limiter_basic() {
        let limiter = RequestMemoryLimiter::new(1000);
        assert!(limiter.is_enabled());
        assert_eq!(limiter.max_memory(), 1000);
        assert_eq!(limiter.current_usage(), 0);

        let guard1 = limiter.try_acquire(400).unwrap().expect("limiter is enabled");
        assert_eq!(guard1.current_usage(), 400);
        assert_eq!(limiter.current_usage(), 400);

        let guard2 = limiter.try_acquire(500).unwrap().expect("limiter is enabled");
        assert_eq!(guard2.current_usage(), 900);
        assert_eq!(limiter.current_usage(), 900);

        // 900 + 200 would exceed the budget of 1000.
        assert!(limiter.try_acquire(200).is_err());
        // A rejected request must not leak into the counter.
        assert_eq!(limiter.current_usage(), 900);

        drop(guard1);
        assert_eq!(limiter.current_usage(), 500);

        let guard3 = limiter.try_acquire(200).unwrap().expect("limiter is enabled");
        assert_eq!(guard3.current_usage(), 700);
        assert_eq!(limiter.current_usage(), 700);

        drop(guard2);
        drop(guard3);
        assert_eq!(limiter.current_usage(), 0);
    }

    #[test]
    fn test_limiter_exact_limit() {
        let limiter = RequestMemoryLimiter::new(1000);

        let guard = limiter.try_acquire(1000).unwrap().expect("limiter is enabled");
        assert_eq!(limiter.current_usage(), 1000);

        assert!(limiter.try_acquire(1).is_err());
        assert_eq!(limiter.current_usage(), 1000);

        drop(guard);
        assert_eq!(limiter.current_usage(), 0);
    }

    #[test]
    fn test_limiter_clones_share_budget() {
        let limiter = RequestMemoryLimiter::new(100);
        let clone = limiter.clone();

        let guard = clone.try_acquire(60).unwrap().expect("limiter is enabled");
        assert_eq!(limiter.current_usage(), 60);
        assert!(limiter.try_acquire(50).is_err());

        drop(guard);
        assert_eq!(limiter.current_usage(), 0);
        assert_eq!(clone.current_usage(), 0);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_limiter_concurrent() {
        const TASKS: usize = 10;
        const REQUEST_SIZE: usize = 200;
        const BUDGET: usize = 1000;

        let limiter = RequestMemoryLimiter::new(BUDGET);
        // One extra party for this test task so all spawned tasks are released together.
        let barrier = Arc::new(Barrier::new(TASKS + 1));

        let handles: Vec<_> = (0..TASKS)
            .map(|_| {
                let limiter = limiter.clone();
                let barrier = Arc::clone(&barrier);
                tokio::spawn(async move {
                    barrier.wait().await;
                    limiter.try_acquire(REQUEST_SIZE)
                })
            })
            .collect();

        barrier.wait().await;

        let mut guards = Vec::with_capacity(TASKS);
        let mut fail_count = 0;
        for handle in handles {
            match handle.await.unwrap() {
                Ok(Some(guard)) => guards.push(guard),
                Ok(None) => unreachable!("limiter is enabled"),
                Err(_) => fail_count += 1,
            }
        }

        // Every successful guard is still held here, so exactly BUDGET / REQUEST_SIZE
        // reservations can have succeeded.
        assert_eq!(guards.len(), BUDGET / REQUEST_SIZE);
        assert_eq!(fail_count, TASKS - BUDGET / REQUEST_SIZE);
        assert_eq!(limiter.current_usage(), BUDGET);

        drop(guards);
        assert_eq!(limiter.current_usage(), 0);
    }
}