servers/
request_limiter.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Request memory limiter for controlling total memory usage of concurrent requests.
16
17use std::sync::Arc;
18use std::sync::atomic::{AtomicUsize, Ordering};
19
20use crate::error::{Result, TooManyConcurrentRequestsSnafu};
21
22/// Limiter for total memory usage of concurrent request bodies.
23///
24/// Tracks the total memory used by all concurrent request bodies
25/// and rejects new requests when the limit is reached.
26#[derive(Clone, Default)]
27pub struct RequestMemoryLimiter {
28    inner: Option<Arc<LimiterInner>>,
29}
30
/// Accounting state shared by a limiter, its clones and every live guard.
struct LimiterInner {
    /// Upper bound for `current_usage`, in bytes.
    max_memory: usize,
    /// Sum of the sizes of all currently admitted requests, in bytes.
    current_usage: AtomicUsize,
}
35
36impl RequestMemoryLimiter {
37    /// Create a new memory limiter.
38    ///
39    /// # Arguments
40    /// * `max_memory` - Maximum total memory for all concurrent request bodies in bytes (0 = unlimited)
41    pub fn new(max_memory: usize) -> Self {
42        if max_memory == 0 {
43            return Self { inner: None };
44        }
45
46        Self {
47            inner: Some(Arc::new(LimiterInner {
48                current_usage: AtomicUsize::new(0),
49                max_memory,
50            })),
51        }
52    }
53
54    /// Try to acquire memory for a request of given size.
55    ///
56    /// Returns `Ok(RequestMemoryGuard)` if memory was acquired successfully.
57    /// Returns `Err` if the memory limit would be exceeded.
58    pub fn try_acquire(&self, request_size: usize) -> Result<Option<RequestMemoryGuard>> {
59        let Some(inner) = self.inner.as_ref() else {
60            return Ok(None);
61        };
62
63        let mut new_usage = 0;
64        let result =
65            inner
66                .current_usage
67                .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
68                    new_usage = current.saturating_add(request_size);
69                    if new_usage <= inner.max_memory {
70                        Some(new_usage)
71                    } else {
72                        None
73                    }
74                });
75
76        match result {
77            Ok(_) => Ok(Some(RequestMemoryGuard {
78                size: request_size,
79                limiter: Arc::clone(inner),
80                usage_snapshot: new_usage,
81            })),
82            Err(_current) => TooManyConcurrentRequestsSnafu {
83                limit: inner.max_memory,
84                request_size,
85            }
86            .fail(),
87        }
88    }
89
90    /// Check if limiter is enabled
91    pub fn is_enabled(&self) -> bool {
92        self.inner.is_some()
93    }
94
95    /// Get current memory usage
96    pub fn current_usage(&self) -> usize {
97        self.inner
98            .as_ref()
99            .map(|inner| inner.current_usage.load(Ordering::Relaxed))
100            .unwrap_or(0)
101    }
102
103    /// Get max memory limit
104    pub fn max_memory(&self) -> usize {
105        self.inner
106            .as_ref()
107            .map(|inner| inner.max_memory)
108            .unwrap_or(0)
109    }
110}
111
112/// RAII guard that releases memory when dropped
113pub struct RequestMemoryGuard {
114    size: usize,
115    limiter: Arc<LimiterInner>,
116    usage_snapshot: usize,
117}
118
119impl RequestMemoryGuard {
120    /// Returns the total memory usage snapshot at the time this guard was acquired.
121    pub fn current_usage(&self) -> usize {
122        self.usage_snapshot
123    }
124}
125
126impl Drop for RequestMemoryGuard {
127    fn drop(&mut self) {
128        self.limiter
129            .current_usage
130            .fetch_sub(self.size, Ordering::Release);
131    }
132}
133
#[cfg(test)]
mod tests {
    use super::*;

    /// A limit of 0 disables the limiter: nothing is tracked or rejected.
    #[test]
    fn test_limiter_disabled() {
        let limiter = RequestMemoryLimiter::new(0);
        assert!(!limiter.is_enabled());
        assert!(limiter.try_acquire(1000000).unwrap().is_none());
        assert_eq!(limiter.current_usage(), 0);
    }

    /// Acquire, reject, release and re-acquire against a 1000-byte limit.
    #[test]
    fn test_limiter_basic() {
        let limiter = RequestMemoryLimiter::new(1000);
        assert!(limiter.is_enabled());
        assert_eq!(limiter.max_memory(), 1000);
        assert_eq!(limiter.current_usage(), 0);

        // Acquire 400 bytes
        let guard1 = limiter.try_acquire(400).unwrap();
        assert_eq!(limiter.current_usage(), 400);

        // Acquire another 500 bytes
        let _guard2 = limiter.try_acquire(500).unwrap();
        assert_eq!(limiter.current_usage(), 900);

        // Try to acquire 200 bytes - should fail (900 + 200 > 1000)
        let result = limiter.try_acquire(200);
        assert!(result.is_err());
        assert_eq!(limiter.current_usage(), 900);

        // Drop first guard
        drop(guard1);
        assert_eq!(limiter.current_usage(), 500);

        // Now we can acquire 200 bytes
        let _guard3 = limiter.try_acquire(200).unwrap();
        assert_eq!(limiter.current_usage(), 700);
    }

    /// A request that lands exactly on the limit is admitted; one more byte is not.
    #[test]
    fn test_limiter_exact_limit() {
        let limiter = RequestMemoryLimiter::new(1000);

        // Acquire exactly the limit
        let _guard = limiter.try_acquire(1000).unwrap();
        assert_eq!(limiter.current_usage(), 1000);

        // Try to acquire 1 more byte - should fail
        let result = limiter.try_acquire(1);
        assert!(result.is_err());
    }

    /// Each guard reports the total usage as of its own acquisition.
    #[test]
    fn test_guard_usage_snapshot() {
        let limiter = RequestMemoryLimiter::new(1000);

        let first = limiter.try_acquire(400).unwrap().unwrap();
        assert_eq!(first.current_usage(), 400);

        let second = limiter.try_acquire(500).unwrap().unwrap();
        assert_eq!(second.current_usage(), 900);

        // The snapshot is not refreshed when other guards are released.
        drop(first);
        assert_eq!(second.current_usage(), 900);
        assert_eq!(limiter.current_usage(), 500);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_limiter_concurrent() {
        let limiter = RequestMemoryLimiter::new(1000);
        let mut handles = vec![];

        // Spawn 10 tasks each trying to acquire 200 bytes
        for _ in 0..10 {
            let limiter_clone = limiter.clone();
            let handle = tokio::spawn(async move { limiter_clone.try_acquire(200) });
            handles.push(handle);
        }

        // Hold every acquired guard until all tasks have been joined. Dropping
        // a guard while later tasks may not have run yet would free memory and
        // could let more than five acquisitions succeed.
        let mut guards = Vec::new();
        let mut fail_count = 0;

        for handle in handles {
            match handle.await.unwrap() {
                Ok(Some(guard)) => guards.push(guard),
                Err(_) => fail_count += 1,
                Ok(None) => unreachable!(),
            }
        }

        // Only 5 tasks should succeed (5 * 200 = 1000)
        assert_eq!(guards.len(), 5);
        assert_eq!(fail_count, 5);
        assert_eq!(limiter.current_usage(), 1000);

        // Releasing all guards returns the counter to zero.
        drop(guards);
        assert_eq!(limiter.current_usage(), 0);
    }
}