common_procedure/
rwlock.rs1use std::collections::HashMap;
16use std::hash::Hash;
17use std::sync::{Arc, Mutex};
18
19use tokio::sync::{OwnedRwLockReadGuard, OwnedRwLockWriteGuard, RwLock};
20
21pub enum OwnedKeyRwLockGuard {
26 Read { _guard: OwnedRwLockReadGuard<()> },
29
30 Write { _guard: OwnedRwLockWriteGuard<()> },
34}
35
36impl From<OwnedRwLockReadGuard<()>> for OwnedKeyRwLockGuard {
37 fn from(guard: OwnedRwLockReadGuard<()>) -> Self {
38 OwnedKeyRwLockGuard::Read { _guard: guard }
39 }
40}
41
42impl From<OwnedRwLockWriteGuard<()>> for OwnedKeyRwLockGuard {
43 fn from(guard: OwnedRwLockWriteGuard<()>) -> Self {
44 OwnedKeyRwLockGuard::Write { _guard: guard }
45 }
46}
47
48#[derive(Debug, Default)]
50pub struct KeyRwLock<K> {
51 inner: Mutex<HashMap<K, Arc<RwLock<()>>>>,
53}
54
55impl<K> KeyRwLock<K>
56where
57 K: Eq + Hash + Clone,
58{
59 pub fn new() -> Self {
60 KeyRwLock {
61 inner: Default::default(),
62 }
63 }
64
65 pub async fn read(&self, key: K) -> OwnedRwLockReadGuard<()> {
67 let lock = {
68 let mut locks = self.inner.lock().unwrap();
69 locks.entry(key).or_default().clone()
70 };
71
72 lock.read_owned().await
73 }
74
75 pub async fn write(&self, key: K) -> OwnedRwLockWriteGuard<()> {
77 let lock = {
78 let mut locks = self.inner.lock().unwrap();
79 locks.entry(key).or_default().clone()
80 };
81
82 lock.write_owned().await
83 }
84
85 pub fn clean_keys<'a>(&'a self, iter: impl IntoIterator<Item = &'a K>) {
91 let mut locks = self.inner.lock().unwrap();
92 let mut keys = Vec::new();
93 for key in iter {
94 if let Some(lock) = locks.get(key) {
95 if lock.try_write().is_ok() {
96 debug_assert_eq!(Arc::weak_count(lock), 0);
97 if Arc::strong_count(lock) == 1 {
99 keys.push(key);
100 }
101 }
102 }
103 }
104
105 for key in keys {
106 locks.remove(key);
107 }
108 }
109}
110
#[cfg(test)]
impl<K> KeyRwLock<K>
where
    K: Eq + Hash + Clone,
{
    /// Test helper: tries to take shared access to `key` without waiting.
    ///
    /// The entry for `key` is inserted first if missing, so it remains in the map even
    /// when the attempt fails.
    pub fn try_read(&self, key: K) -> Result<OwnedRwLockReadGuard<()>, tokio::sync::TryLockError> {
        let shared = Arc::clone(self.inner.lock().unwrap().entry(key).or_default());
        shared.try_read_owned()
    }

    /// Test helper: tries to take exclusive access to `key` without waiting.
    ///
    /// The entry for `key` is inserted first if missing, so it remains in the map even
    /// when the attempt fails.
    pub fn try_write(
        &self,
        key: K,
    ) -> Result<OwnedRwLockWriteGuard<()>, tokio::sync::TryLockError> {
        let exclusive = Arc::clone(self.inner.lock().unwrap().entry(key).or_default());
        exclusive.try_write_owned()
    }

    /// Test helper: number of keys that currently have a lock entry.
    pub fn len(&self) -> usize {
        let locks = self.inner.lock().unwrap();
        locks.len()
    }

    /// Test helper: `true` when no key currently has a lock entry.
    pub fn is_empty(&self) -> bool {
        self.inner.lock().unwrap().is_empty()
    }
}
149
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_naive() {
        let lock_key = KeyRwLock::new();

        {
            let _guard = lock_key.read("test1").await;
            assert_eq!(lock_key.len(), 1);
            assert!(lock_key.try_read("test1").is_ok());
            assert!(lock_key.try_write("test1").is_err());
        }

        {
            let _guard0 = lock_key.write("test2").await;
            let _guard = lock_key.write("test1").await;
            assert_eq!(lock_key.len(), 2);
            assert!(lock_key.try_read("test1").is_err());
            assert!(lock_key.try_write("test1").is_err());
        }

        assert_eq!(lock_key.len(), 2);

        // Every guard has been dropped, so both entries are reclaimable.
        lock_key.clean_keys(&["test1", "test2"]);
        assert!(lock_key.is_empty());

        let mut guards = Vec::new();
        for key in ["test1", "test2"] {
            guards.push(lock_key.read(key).await);
        }
        // Release all read guards before cleaning.
        guards.clear();
        lock_key.clean_keys([&"test1", &"test2"]);
        assert_eq!(lock_key.len(), 0);
    }

    #[tokio::test]
    async fn test_clean_keys() {
        let lock_key = KeyRwLock::<&str>::new();
        // Clones the map's lock for `key`, inserting it if missing (as `read`/`write` do).
        let entry_lock = |key: &'static str| {
            Arc::clone(lock_key.inner.lock().unwrap().entry(key).or_default())
        };
        // Strong count of the `Arc` stored in the map under "test".
        let map_count = || Arc::strong_count(lock_key.inner.lock().unwrap().get("test").unwrap());

        {
            let rwlock = entry_lock("test");
            // The map's handle plus our clone.
            assert_eq!(Arc::strong_count(&rwlock), 2);
            // `read_owned` moves our clone into the guard, so the count stays at 2.
            let _guard = rwlock.read_owned().await;
            assert_eq!(map_count(), 2);
        }

        {
            let rwlock = entry_lock("test");
            assert_eq!(Arc::strong_count(&rwlock), 2);
            let _guard = rwlock.write_owned().await;
            assert_eq!(map_count(), 2);
        }

        // Both guards are gone: only the map's handle remains.
        assert_eq!(map_count(), 1);

        // An outstanding clone (like one held by a pending `read`/`write`) keeps the entry.
        let rwlock = entry_lock("test");
        assert_eq!(Arc::strong_count(&rwlock), 2);
        lock_key.clean_keys([&"test"]);
        assert!(lock_key.inner.lock().unwrap().contains_key("test"));
        drop(rwlock);
    }
}